Flow Matching: A Visual Introduction


Flow Matching (FM) has become a prevalent technique for training a certain class of generative models. In this post we'll explore the intuition behind flow matching and how it works.

We'll use this notebook to build a simple flow matching model that illustrates linear flow matching on a minimal toy example. Our goal is to keep things simple, intuitive, and visual. We won't dive deep into the mathematical details of the model; if you're interested in those, I recommend checking out the references at the end of this post.

In [1]:

# Imports and setup
import base64
import functools
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from IPython.display import HTML
from matplotlib.animation import FuncAnimation
from tqdm import tqdm

sns.set_style("darkgrid")  # Set the style of the plots
pd.options.display.float_format = "{:,.3f}".format  # Table display format

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(626)

# PyTorch Device configuration
DEVICE: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Flow matching

Flow matching is a technique to learn how to transport samples from one distribution to another. For example, we could learn how to transport samples from a simple distribution we can easily sample from (e.g. Gaussian noise) to a complex distribution (e.g. images, videos, robot actions, etc.).

Toy Example: Mapping Gaussian noise to a bimodal distribution

In this post we'll build a simple toy example of a generative model using flow matching. For illustrative purposes we'll start with a simple 1D bimodal target distribution $π_1$ and learn how to transport samples from a 1D Gaussian noise distribution $π_0$ to this target distribution.

In practice the target points $x_1 \sim π_1$ are approximated by sampling from a limited dataset of training points $X_1$ and the noise points $x_0 \sim π_0$ are sampled from a chosen noise distribution $π_0$ that is easy to sample from (e.g. Gaussian noise).
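As a minimal sketch of what this sampling looks like in code, the snippet below draws one training batch with NumPy. The helper `sample_batch`, the array `dataset_x1`, and its stand-in values are assumptions made for illustration and are not part of the notebook.

```python
import numpy as np


def sample_batch(dataset_x1: np.ndarray, batch_size: int, rng: np.random.Generator):
    """Draw one batch of (noise, target) points.

    Hypothetical helper: x_1 is drawn by picking random entries of a finite
    dataset X_1, x_0 is drawn fresh from a standard Gaussian.
    """
    idx = rng.integers(0, len(dataset_x1), size=batch_size)  # sample with replacement
    x_1 = dataset_x1[idx]
    x_0 = rng.standard_normal(size=x_1.shape)
    return x_0, x_1


rng = np.random.default_rng(0)
dataset_x1 = np.array([-1.2, -0.7, -0.9, 1.4, 1.6])  # stand-in for the training points X_1
x_0, x_1 = sample_batch(dataset_x1, batch_size=8, rng=rng)
print(x_0.shape, x_1.shape)
```

The target points can only ever be values that occur in the dataset, while the noise points are new draws every time.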

In [2]:

# Define 1D bimodal target distribution
mixture_prob = np.array([0.55, 0.45], dtype=float)  # Mixture weights
mixture_mus = np.array([-0.85, 1.5], dtype=float)  # Means of the two Gaussian modes
mixture_sigmas = np.array([0.65, 0.25], dtype=float)  # Standard deviations of the modes


def mixture_pdf(x: np.ndarray) -> np.ndarray:
    """Compute the PDF of a mixture of Gaussians."""
    comps = scipy.stats.norm.pdf(x[None, :], loc=mixture_mus[:, None], scale=mixture_sigmas[:, None])
    return np.sum(mixture_prob[:, None] * comps, axis=0)


def mixture_sample(size: int) -> np.ndarray:
    """Sample from a mixture of Gaussians."""
    rand_idx = np.random.choice(range(len(mixture_prob)), size=size, p=mixture_prob)
    means = mixture_mus[rand_idx]
    stds = mixture_sigmas[rand_idx]
    return np.random.normal(loc=means, scale=stds)


# Plot data distribution. This is the TARGET distribution (π₁)
fig, ax = plt.subplots(1, 1, figsize=(8, 3), constrained_layout=True, dpi=100)
x_all_steps = np.linspace(-3, 3, 1000)
pdf_noise = scipy.stats.norm.pdf(x_all_steps, loc=0, scale=1)
ax.plot(x_all_steps, pdf_noise, label="PDF Noise π₀", color="tab:orange")
ax.fill_between(x_all_steps, pdf_noise, alpha=0.4, color="tab:orange")
pdf_target = mixture_pdf(x=x_all_steps)
ax.plot(x_all_steps, pdf_target, label="PDF Target π₁", color="tab:blue")
ax.fill_between(x_all_steps, pdf_target, alpha=0.4, color="tab:blue")
ax.legend()
ax.set_title("Toy Example Data: Noise (π₀) vs Target (π₁) Distributions")
ax.set_xlabel("x")
ax.set_ylabel("density")
plt.show()
del fig, ax, x_all_steps, pdf_noise, pdf_target

Figure: Toy example data, noise (π₀) vs target (π₁) distributions.

The flow matching model predicts a velocity field

A flow matching model does not predict flow paths directly, but instead predicts a velocity field that can be used to sample the flow paths. The velocity field describes how to move a sample from the noise distribution to the target distribution.

We can describe the flow matching model with learnable parameters $\theta$ as a function: $${FM}_{\theta}(x_t, t) = v(x_t, t)$$ This function takes a sample $x_t$ at flow step $t$ and predicts the velocity vector $v(x_t, t) = dx_t / dt$ that describes how to move the sample $x_t$ closer to the target distribution at step $t$.

The step $t$ is a value between 0 and 1 that describes the progress of the sample $x_t$ along the flow path from the noise distribution to the target distribution. When $t=0$ the sample $x_t = x_0$ is a sample from the noise distribution $π_0$ and when $t=1$ the sample $x_t = x_1$ is a sample from the target distribution $π_1$.

At inference time we can sample a starting point $x_0$ from the noise distribution $π_0$ and then use the predicted velocity field ${FM}_{\theta}(x_t, t)$ to iteratively move the sample towards the target distribution $π_1$ in small steps $dt$.
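Written as an update rule, one such step is a forward Euler step, which is also the update the plotting code in this notebook uses to trace its pathlines. The fixed step size $dt = 1/N$ for $N$ steps is an assumption of this sketch; other step schedules or integrators are possible.

$$ x_{t + dt} = x_t + {FM}_{\theta}(x_t, t) \, dt, \qquad t = 0, \, dt, \, 2\,dt, \, \ldots, \, 1 - dt $$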

This is illustrated in the following animation (generated further down in the notebook), which shows the integration of a sample from the noise distribution $π_0$ on the left towards the target distribution $π_1$ on the right using the predicted velocity field ${FM}_{\theta}(x_t, t)$. The velocity field is visualized as a heatmap where the vertical axis represents the position of the sample $x_t$ and the horizontal axis represents the flow step $t$ going from 0 on the left to 1 on the right. Red means a positive velocity (sample pushed up towards higher $x$) and blue means a negative velocity (sample pulled down towards lower $x$).

In [3]:

# Display the animation in the notebook
# This animation is generated further down in the notebook, if it doesn't exist yet we'll skip the display
# Embed the video directly in the notebook by encoding the bytes as base64, this way it should hopefully also be exported
ANIMATION_FILE = Path("flow_matching_path_integration.mp4")
if ANIMATION_FILE.exists():
    with ANIMATION_FILE.open("rb") as f:
        gif_data = f.read()
    display(
        HTML(f"""
        <video alt="Flow matching path integration" loop="true" autoplay="autoplay" muted>
            <source type="video/mp4" src="data:video/mp4;base64,{base64.b64encode(gif_data).decode()}">
        </video>
        """)
    )
else:
    print(
        "Animation file not yet created, it is generated further down in the notebook. Run the full notebook to generate it."
    )

Training the flow matching model is learning the velocity field

Since the flow matching model ${FM}_{\theta}(x_t, t)$ should predict the velocity field $v(x_t, t) = dx_t / dt$ we can train the model on samples of velocity vectors $\mathbf{v}(x_t, t)$.

The flow matching training objective is to minimize the expected reconstruction error of the velocity field: $$ \underset{\theta}{\text{argmin}} \; \mathbb{E}_{t, x_t} \Big\| {FM}_{\theta}(x_t, t) - v(x_t, t) \Big\|^2 $$

with $t \sim \mathcal{U}[0, 1]$ and $x_t$ taken from a sampled reference path evaluated at flow step $t$.

We'll be using straight line reference paths in this post since they are simple and common.

Training: Straight line reference paths

We're going to focus on a common variant of flow matching where we learn a flow matching model based on straight line reference paths. Training flow matching with straight-line conditional paths and independent couplings is also equivalent to the rectified flow training objective.

Linear (straight line) flow matching is trained on a set of reference paths between the noise and target distributions. More specifically, linear flow matching prefers learning from straight line trajectories between the noise and target distributions because they tend to give straighter paths that require fewer steps to reconstruct the target distribution.
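The link between straight paths and few integration steps can be seen on two hand-written velocity fields. The sketch below is not the learned model: `euler_integrate`, `straight`, and `curved` are hypothetical names, and the two fields are chosen only because their exact solutions are known.

```python
import math
from typing import Callable


def euler_integrate(velocity: Callable[[float, float], float], x_0: float, n_steps: int) -> float:
    """Integrate dx/dt = velocity(x, t) from t=0 to t=1 with fixed-size Euler steps."""
    dt = 1.0 / n_steps
    x = x_0
    for k in range(n_steps):
        x = x + velocity(x, k * dt) * dt
    return x


def straight(x: float, t: float) -> float:
    return 1.5  # constant velocity: the path is a straight line in the (t, x) plane


def curved(x: float, t: float) -> float:
    return x  # velocity depends on position: the exact path is x_0 * exp(t)


x_0 = 1.0
straight_1 = euler_integrate(straight, x_0, n_steps=1)
straight_100 = euler_integrate(straight, x_0, n_steps=100)
curved_1 = euler_integrate(curved, x_0, n_steps=1)
curved_100 = euler_integrate(curved, x_0, n_steps=100)
curved_exact = x_0 * math.exp(1.0)

print(f"straight path: 1 step -> {straight_1:.4f}, 100 steps -> {straight_100:.4f}")
print(f"curved path:   1 step -> {curved_1:.4f}, 100 steps -> {curved_100:.4f}, exact -> {curved_exact:.4f}")
```

On the straight path a single Euler step already lands on the same end point as 100 steps, while on the curved path the one-step result is noticeably off and only approaches the exact value as the step count grows.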

To sample a reference path we can independently sample a target point $x_1$ from our target distribution $π_1$ and a noise point $x_0$ from the noise distribution $π_0$. This gives us a single coupling $(x_0, x_1)$ that defines a straight line reference path between the noise and target samples. During training we'll sample a large set of such couplings $(X_0, X_1)$, each defining its own reference path, and use these to train the flow matching model.

The following code illustrates how we define the straight line reference path between a noise and target sample.

In [4]:

def interpolate_linear(x_0, x_1, t):
    """Evaluates the linear interpolation path between x_0 and x_1 at step t."""
    x_t = ((1 - t) * x_0) + (t * x_1)
    return x_t

The following figure shows a few sampled straight-line reference paths, as well as the reference path distribution approximated by sampling a large number of straight-line reference paths.

In [5]:

# Illustration of the sampled reference paths
# Set up the plot
fig, ((ax11, ax12, ax13), (ax21, ax22, ax23)) = plt.subplots(
    2,
    3,
    figsize=(12, 8),
    gridspec_kw={"width_ratios": [1, 5, 1]},
    sharey=True,
    dpi=100,
)
fig.subplots_adjust(wspace=0)
x_min, x_max = -2.5, 2.5
x_all_steps = np.linspace(x_min, x_max, 1000)

# Sample set of noise and target points
data_size: int = 100_000
np.random.seed(1)  # Set random seeds for reproducibility
data_x_0 = np.random.randn(data_size)
data_x_1 = mixture_sample(size=data_size)

# Plot a few sample paths ##########################################
# Plot Noise distribution π₀
ax11.set_title("Noise Distribution π₀")
pdf_noise = scipy.stats.norm.pdf(x_all_steps, loc=0, scale=1)
ax11.plot(pdf_noise, x_all_steps, label="PDF Noise π₀", color="tab:orange")
ax11.fill_betweenx(x_all_steps, pdf_noise, alpha=0.4, color="tab:orange")
ax11.invert_xaxis()
ax11.set_ylabel("x")
ax11.set_xlabel("")
# ax11.set_xlim(0, 1)

# Plot final distribution x1
ax13.set_title("Target Distribution π₁")
pdf_target = mixture_pdf(x=x_all_steps)
ax13.plot(pdf_target, x_all_steps, label="PDF Target π₁", color="tab:blue")
ax13.fill_betweenx(x_all_steps, pdf_target, alpha=0.4, color="tab:blue")
ax13.set_xlabel("")
ax13.yaxis.set_label_position("right")
ax13.set_ylabel("x")
# Also show y-axis values on the right side
ax13.yaxis.set_visible(True)
ax13.yaxis.set_tick_params(labelright=True)
# ax13.set_xlim(0, 1)

# Plot the Sample paths
nb_samples = 7
t = np.arange(0, 1, 0.01)
ax12.set_title("Sample of straight line reference paths")
colors = plt.cm.tab10.colors
for i in range(nb_samples):
    color = colors[i % len(colors)]
    ax12.plot(t, interpolate_linear(x_0=data_x_0[i], x_1=data_x_1[i], t=t), alpha=0.5, color=color)
    ax12.scatter([0, 1], [data_x_0[i], data_x_1[i]], color=color)
ax12.set_ylim(x_min, x_max)
ax12.set_xlim(0, 1)
ax12.set_xlabel("$t$: flow step")

# Plot the full data distribution ##################################
# Plot Noise samples X0
ax21.set_title("Noise Samples X₀")
ax21.hist(
    data_x_0.flatten(),
    bins=100,
    alpha=0.5,
    label="Noise π₀",
    color="tab:orange",
    density=True,
    orientation="horizontal",
)
ax21.invert_xaxis()
ax21.set_ylabel("x")
ax21.set_xlabel("density")
ax21.sharex(ax11)
# ax21.set_xlim(0, 1)

# Plot target data distribution x1
ax23.set_title("Target Data X₁")
ax23.hist(
    data_x_1.flatten(), bins=100, alpha=0.5, label="Target π₁", color="tab:blue", density=True, orientation="horizontal"
)
ax23.set_xlabel("density")
ax23.yaxis.set_label_position("right")
ax23.set_ylabel("x")
# Also show y-axis values on the right side
ax23.yaxis.set_visible(True)
ax23.yaxis.set_tick_params(labelright=True)
ax23.sharex(ax13)
# ax23.set_xlim(0, 1)

# Plot path density
n_samples = int(data_x_1.shape[0])
dt: float = 0.01  # Step size for Euler integration
n_flow_steps = int(1 / dt)
# Set up the path density histogram parameters
img_hist_size = 480
path_density_bins = np.zeros((n_flow_steps + 1, img_hist_size))
flow_field_img_x_bin_edges = np.linspace(x_min, x_max, img_hist_size + 1)
# Add the histogram of the initial distribution
path_density_bins[0] = np.histogram(data_x_0, bins=flow_field_img_x_bin_edges)[0]
path_density_bins[-1] = np.histogram(data_x_1, bins=flow_field_img_x_bin_edges)[0]
# Build up the histogram of the reference paths by going over the discretized t-bins
for i in range(n_flow_steps):
    t = np.full((n_samples,), i * dt)
    x_t = interpolate_linear(x_0=data_x_0, x_1=data_x_1, t=t)
    path_density_bins[i] = np.histogram(x_t, bins=flow_field_img_x_bin_edges)[0]
im = ax22.imshow(path_density_bins.T, aspect="auto", origin="lower", extent=[0, 1, x_min, x_max], cmap="viridis")
ax22.set_xlabel("$t$: flow step")
ax22.set_title("Reference path density between X₀ and X₁")
ax22.grid(False)

plt.tight_layout()
plt.show()
del (fig, ax11, ax12, ax13, ax21, ax22, ax23, x_all_steps, data_x_0, data_x_1, n_samples, dt, n_flow_steps,  # fmt: skip
     img_hist_size, path_density_bins, flow_field_img_x_bin_edges, i, t, x_t, im)  # fmt: skip

Figure: A sample of straight line reference paths (top) and the reference path density between X₀ and X₁ (bottom).

Training: Sampling velocity vectors

Since we are using straight-line reference paths, the sampled velocity vectors $\mathbf{v}(x_t, t)$ have a very simple form. Given a sample $x_0$ from the noise distribution and a sample $x_1$ from the target distribution, the point on the straight line connecting them is $x_t = (1 - t) x_0 + t x_1$. Differentiating with respect to $t$ gives the conditional velocity vector along that line: $\mathbf{v}(x_t, t) = dx_t / dt = x_1 - x_0$, which is constant in $t$, as illustrated in the following code and figure.

In [6]:

def get_target_velocity(x_0, x_1):
    """
    Get the velocity for a given pair of noise and target points.
    This is the per-pair (conditional) velocity along the straight path.
    """
    return x_1 - x_0

In [7]:

# Illustrate the flow matching target velocity vector
# Set up the plot
fig, (ax1, ax2, ax3) = plt.subplots(
    1,
    3,
    figsize=(12, 4),
    gridspec_kw={"width_ratios": [1, 5, 1]},
    sharey=True,
    dpi=100,
)
fig.subplots_adjust(wspace=0)
x_min, x_max = -2.5, 2.5
x_all_steps = np.linspace(x_min, x_max, 1000)

# Plot a few sample paths ##########################################
# Plot Noise distribution π₀
ax1.set_title("Noise Distribution π₀")
pdf_noise = scipy.stats.norm.pdf(x_all_steps, loc=0, scale=1)
ax1.plot(pdf_noise, x_all_steps, label="PDF Noise π₀", color="tab:orange")
ax1.fill_betweenx(x_all_steps, pdf_noise, alpha=0.4, color="tab:orange")
ax1.invert_xaxis()
ax1.set_ylabel("x")
ax1.set_xlabel("")
# ax11.set_xlim(0, 1)

# Plot final distribution x1
ax3.set_title("Target Distribution π₁")
pdf_target = mixture_pdf(x=x_all_steps)
ax3.plot(pdf_target, x_all_steps, label="PDF Target π₁", color="tab:blue")
ax3.fill_betweenx(x_all_steps, pdf_target, alpha=0.4, color="tab:blue")
ax3.set_xlabel("")
ax3.yaxis.set_label_position("right")
ax3.set_ylabel("x")
# Also show y-axis values on the right side
ax3.yaxis.set_visible(True)
ax3.yaxis.set_tick_params(labelright=True)
# ax13.set_xlim(0, 1)

# Plot the Sample paths
x_0_examplar = 1.32
x_1_examplar = -1.92
t_examplar = 0.67
x_t_examplar = interpolate_linear(x_0=x_0_examplar, x_1=x_1_examplar, t=t_examplar)

# Annotate the path between x_0 and x_1
ax2.set_title("Flow matching target for a sample path: velocity $v(x_t, t) = x_1 - x_0$")
ax2.plot([0, 1], [x_0_examplar, x_1_examplar], alpha=0.8, color="tab:green", label="Path(x₀, x₁)")
ax2.scatter([0, 1], [x_0_examplar, x_1_examplar], color="tab:green")
# ax2.scatter([1], [x_0_examplar], color="tab:green", marker="_")
ax2.legend()
ax2.plot([0, 1], [x_0_examplar, x_0_examplar], alpha=0.5, color="tab:green", linestyle="dotted")
ax2.annotate(
    "$x_0$",
    xy=(0, x_0_examplar),
    xycoords="data",
    xytext=(-20, x_0_examplar - 3),  # Shift the text to the left of the point
    textcoords="offset points",
    fontsize=16,
    color="tab:green",
    annotation_clip=False,
)
ax2.annotate(
    " $x_0$",
    xy=(1, x_0_examplar),
    xycoords="data",
    xytext=(1, x_0_examplar - 3),
    textcoords="offset points",
    fontsize=16,
    color="tab:green",
    annotation_clip=False,
)
ax2.annotate(
    " $x_1$",
    xy=(1, x_1_examplar),
    xycoords="data",
    xytext=(1, x_1_examplar - 3),
    textcoords="offset points",
    fontsize=16,
    color="tab:green",
    annotation_clip=False,
)

# Annotate x_t
ax2.annotate(
    "$x_t$",
    xy=(0, x_t_examplar),
    xycoords="data",
    xytext=(-20, x_t_examplar),  # Shift the text to the left of the point
    textcoords="offset points",
    fontsize=16,
    color="tab:gray",
    annotation_clip=False,
)
ax2.plot([0, t_examplar], [x_t_examplar, x_t_examplar], linestyle=":", color="tab:gray")

# Annotate t
ax2.annotate(
    "$t$",
    xy=(t_examplar, x_min),
    xycoords="data",
    xytext=(t_examplar - 4, x_min - 12),  # Shift the text below the point
    textcoords="offset points",
    fontsize=16,
    color="tab:gray",
    annotation_clip=False,
)
ax2.plot([t_examplar, t_examplar], [x_min, x_t_examplar], linestyle=":", color="tab:gray")

# Annotate the velocity vector
ax2.annotate(
    "",
    xy=(1, x_1_examplar),
    xycoords="data",
    xytext=(1, x_0_examplar),
    arrowprops=dict(arrowstyle="->", color="tab:red", linewidth=2),
    annotation_clip=False,
)
ax2.annotate(
    " $x_1 - x_0$",
    xy=(1, (x_0_examplar + x_1_examplar) / 2),
    xycoords="data",
    xytext=(1, (x_0_examplar + x_1_examplar) / 2),
    textcoords="offset points",
    fontsize=16,
    color="tab:red",
    annotation_clip=False,
)
ax2.scatter([t_examplar], [x_t_examplar], color="tab:red", marker="D", zorder=10)
ax2.annotate(
    r" $\mathbf{v}(x_t, t) = x_1 - x_0$",
    xy=(t_examplar, x_t_examplar),
    xycoords="data",
    xytext=(t_examplar, x_t_examplar),
    textcoords="offset points",
    fontsize=16,
    color="tab:red",
    annotation_clip=False,
)
ax2.set_ylim(x_min, x_max)
ax2.set_xlim(0, 1)
ax2.set_xlabel(r"$t$: flow step")

plt.tight_layout()
plt.show()
del (fig, ax1, ax2, ax3, x_min, x_max, x_all_steps,  # fmt: skip
     x_0_examplar, x_1_examplar, t_examplar, x_t_examplar)  # fmt: skip

Figure: Flow matching target for a sample path, velocity $v(x_t, t) = x_1 - x_0$.

Training: Flow matching objective

We can now write out our objective as a function of the samples $X_0$ from the noise distribution and $X_1$ from the target distribution: $$ \underset{\theta}{\text{argmin}} \; \mathbb{E}_{t, X_0, X_1} \Big\| {FM}_{\theta}(x_t, t) - (X_1 - X_0) \Big\|^2 $$ with $t \sim \mathcal{U}[0, 1]$, $X_0 \sim \pi_0$, $X_1 \sim \pi_1$, and $x_t = (1 - t) X_0 + t X_1$.

Note that the flow matching model ${FM}_{\theta}(x_t, t)$ is trained conditionally on specific straight-line couplings $(X_0, X_1)$, but since these are averaged out in the training objective, the flow matching model will learn an approximation of the velocity field independent of any specific coupling.
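To make this averaging concrete, here is a small numerical check on a case where the averaged field has a closed form. It assumes, purely for illustration, that both $π_0$ and $π_1$ are standard Gaussians (not the bimodal target of this post). Then $x_t$ and $X_1 - X_0$ are jointly Gaussian, and the average velocity at position $x$ is $\frac{2t - 1}{(1 - t)^2 + t^2} \, x$. All variable names in the snippet are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000
t = 0.75  # a single fixed flow step

# Independent couplings (X_0, X_1); both standard Gaussian in this toy check
x_0 = rng.standard_normal(n)
x_1 = rng.standard_normal(n)

x_t = (1 - t) * x_0 + t * x_1  # point on each straight-line path at step t
v = x_1 - x_0  # per-pair (conditional) velocity

# Least-squares slope of v against x_t. For jointly Gaussian variables the
# conditional mean E[v | x_t] is linear in x_t, so this slope describes the averaged field.
slope_estimate = np.sum(v * x_t) / np.sum(x_t * x_t)
slope_closed_form = (2 * t - 1) / ((1 - t) ** 2 + t**2)

# The averaged field fits the per-pair velocities better than predicting zero,
# but its loss stays above zero because individual pairs still disagree.
loss_zero_field = np.mean(v**2)
loss_averaged_field = np.mean((slope_closed_form * x_t - v) ** 2)

print(f"slope: estimated {slope_estimate:.3f}, closed form {slope_closed_form:.3f}")
print(f"loss: zero field {loss_zero_field:.3f}, averaged field {loss_averaged_field:.3f}")
```

The remaining loss of the averaged field is the spread of per-pair velocities around their mean. This is one reason a flow matching training loss is not expected to reach zero even for a well-fitted model.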

For this simple toy example we could even approximate the flow field directly by sampling a large number of reference paths and computing the average velocity for fixed bins over the flow field. This approximated expectation is illustrated in the following figure, which shows the average flow field. Red means a positive velocity (sample pushed up towards higher $x$) and blue means a negative velocity (sample pulled down towards lower $x$).

In [8]:

# Flow field illustration, approximate the flow field by sampling a large number of reference paths over discretized flow step bins.
# Sample set of noise and target points
data_size: int = 100_000
np.random.seed(1)  # Set random seeds for reproducibility
data_x_0 = np.random.randn(data_size)
data_x_1 = mixture_sample(size=data_size)

# Set up the plot
fig, (ax1, ax2, ax3) = plt.subplots(
    1,
    3,
    figsize=(12, 5),
    gridspec_kw={"width_ratios": [1, 5, 1]},
    sharey=True,
    dpi=100,
    constrained_layout=True,
)
x_min, x_max = -2.2, 2.2  # Narrow view for flow field image
x_min_wide, x_max_wide = -3.2, 3.2  # Wide view for sample paths

# Plot the velocity field ##########################################
# Plot Noise samples X0
ax1.set_title("Noise samples X₀")
ax1.hist(
    data_x_0.flatten(),
    bins=100,
    alpha=0.5,
    label="Noise π₀",
    color="tab:orange",
    density=True,
    orientation="horizontal",
)
ax1.invert_xaxis()
ax1.set_ylabel("x")
ax1.set_xlabel("density")

# Plot target data distribution x1
ax3.set_title("Target data X₁")
ax3.hist(
    data_x_1.flatten(), bins=100, alpha=0.5, label="Target π₁", color="tab:blue", density=True, orientation="horizontal"
)
ax3.set_xlabel("density")
ax3.yaxis.set_label_position("right")
ax3.set_ylabel("x")
# Also show y-axis values on the right side
ax3.yaxis.set_visible(True)
ax3.yaxis.set_tick_params(labelright=True)

# Plot path density
n_samples = int(data_x_1.shape[0])
dt: float = 0.05  # Step size for Euler integration
n_flow_steps = int(1 / dt)
# Set up the path density histogram parameters
img_hist_size = 200
# Narrow view for flow field image
flow_field_img_bins = np.zeros((n_flow_steps + 1, img_hist_size))
flow_field_img_x_bin_edges = np.linspace(x_min, x_max, img_hist_size + 1)
# Wide view to compute the pathlines, need to be wide enough because paths can flow out of the narrow view
flow_field_for_path_bins = np.zeros((n_flow_steps + 1, img_hist_size))
flow_field_for_path_x_bin_edges = np.linspace(x_min_wide, x_max_wide, img_hist_size + 1)

# Build up the histogram of the reference paths by going over the discretized t-bins
# We're building 2 histograms:
# - `flow_field_img_bins` narrow view for the flow field image, might have NaNs if there are no samples in a bin
# - `flow_field_for_path_bins` wide view for the pathlines, avoid NaNs by setting the velocity to 0 if there are no samples in a bin
v = get_target_velocity(x_0=data_x_0, x_1=data_x_1)
for i in range(n_flow_steps + 1):
    t = np.full((n_samples,), i * dt)
    x_t = interpolate_linear(x_0=data_x_0, x_1=data_x_1, t=t)
    # Get the average velocity for each bin in the narrow view for the flow field image
    x_t_bin_indices = np.digitize(x=x_t, bins=flow_field_img_x_bin_edges[1:-1], right=False)
    counts = np.bincount(x_t_bin_indices, minlength=img_hist_size)
    sums = np.bincount(x_t_bin_indices, weights=v, minlength=img_hist_size)
    flow_field_img_bins[i] = np.divide(
        sums,
        counts,
        out=np.full(img_hist_size, np.nan, dtype=float),
        where=counts > 0,
    )
    # Get the average velocity for each bin in the wide view for the pathlines
    x_t_bin_indices_wide = np.digitize(x=x_t, bins=flow_field_for_path_x_bin_edges[1:-1], right=False)
    counts_wide = np.bincount(x_t_bin_indices_wide, minlength=img_hist_size)
    sums_wide = np.bincount(x_t_bin_indices_wide, weights=v, minlength=img_hist_size)
    flow_field_for_path_bins[i] = np.divide(
        sums_wide,
        counts_wide,
        out=np.zeros(img_hist_size, dtype=float),  # Avoid NaNs for any path sampling
        where=counts_wide > 0,
    )

# Plot the flow field
max_abs_flow_field = np.nanmax(np.abs(flow_field_img_bins))
color_norm = matplotlib.colors.TwoSlopeNorm(vmin=-max_abs_flow_field, vcenter=0.0, vmax=max_abs_flow_field)
im = ax2.imshow(
    flow_field_img_bins.T,
    aspect="auto",
    origin="lower",
    extent=[0, 1, x_min, x_max],
    cmap="coolwarm",
    norm=color_norm,
)
ax2.set_ylim(x_min, x_max)
cbar = fig.colorbar(
    im,
    ax=[ax1, ax2, ax3],
    orientation="horizontal",
    fraction=0.08,
    aspect=40,
    pad=0.04,
)
cbar.set_label("Velocity field value (red pushes up, blue pulls down)")

# Sample some paths from the mean flow field to show the pathlines
n_paths = 9
paths = np.zeros((n_flow_steps + 1, n_paths))
paths[0] = np.linspace(start=x_min_wide, stop=x_max_wide, num=n_paths)
for i in range(n_flow_steps):
    t = np.full((n_samples,), i * dt)
    x_tm1 = paths[i]
    x_tm1_bin_indices = np.digitize(x=x_tm1, bins=flow_field_for_path_x_bin_edges[1:-1], right=False)
    v = flow_field_for_path_bins[i, x_tm1_bin_indices]
    paths[i + 1] = x_tm1 + v * dt

# Plot the pathlines with arrows using quiver, following best practices
t_coords = np.linspace(0, 1, n_flow_steps + 1)
arrow_stride = 7  # space out arrows for clarity
for i in range(n_paths):
    y = paths[:, i]
    idx = np.arange(0, len(t_coords) - 1, arrow_stride)
    x0 = t_coords[idx]
    y0 = y[idx]
    u = np.diff(t_coords)[idx]
    v = np.diff(y)[idx]
    ax2.quiver(
        x0,
        y0,
        u,
        v,
        angles="xy",
        scale_units="xy",
        scale=0.6,
        units="inches",
        width=0.015,
        headwidth=6,
        headlength=9,
        headaxislength=7,
        pivot="tail",
        color="dimgray",
        alpha=0.9,
        zorder=2,
    )
    # overlay original line on top of arrows for clarity
    ax2.plot(
        t_coords,
        y,
        linestyle="-",
        color="dimgray",
        linewidth=1.5,
        alpha=0.7,
        zorder=3,
        label=f"Sample path {i + 1}",
    )
ax2.set_xlabel("$t$: flow step")
ax2.set_title("Average velocity field with pathlines")
ax2.grid(False)

plt.show()
del (fig, ax1, ax2, ax3, x_min, x_max, x_min_wide, x_max_wide, data_size,  # fmt: skip
     data_x_0, data_x_1, n_samples, dt, n_flow_steps, img_hist_size,  # fmt: skip
     flow_field_img_bins, flow_field_img_x_bin_edges, flow_field_for_path_bins,  # fmt: skip
     flow_field_for_path_x_bin_edges, v, t, x_t, x_t_bin_indices, counts, sums,  # fmt: skip
     x_t_bin_indices_wide, counts_wide, sums_wide)  # fmt: skip

Figure: Average velocity field with pathlines.

Training the Flow Matching model

Now that we have defined our optimization objective and how to sample the data to train the model, we can define the flow matching model and train it. We'll create a small neural network (an MLP with two hidden layers) that we can train to predict the velocity field.

In [9]:

class FlowMatchingModel(nn.Module):
    """
    Flow Matching model to predict the velocity field at time t and position x_t.
    """

    def __init__(self, data_dim: int, hidden_dim: int) -> None:
        super().__init__()
        # Simple MLP
        self.net: nn.Sequential = nn.Sequential(
            nn.Linear(data_dim + 1, hidden_dim),  # +1 for time embedding
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, data_dim),
        )

    def forward(
        self,
        t: torch.Tensor,  # Denoising step [batch_size, 1]
        x_t: torch.Tensor,  # Interpolated samples [batch_size, data_dim]
    ) -> torch.Tensor:  # [batch_size, data_dim]
        """
        Predicts the velocity field at time t and position x_t.
        """
        tx: torch.Tensor = torch.cat([t, x_t], dim=-1)
        return self.net(tx)

We can now define the loss function as a function of the flow matching model, the noise samples $X_0$, the target samples $X_1$, and the flow steps $T$:

In [10]:

def compute_loss(
    flow_matching_model: FlowMatchingModel,
    x_0: torch.Tensor,
    x_1: torch.Tensor,
    t: torch.Tensor,
) -> torch.Tensor:
    """
    Compute the loss for a single batch of (X_0, X_1) couplings and flow steps T.
    """
    # Interpolate the data at the sampled time step
    x_t = interpolate_linear(x_0=x_0, x_1=x_1, t=t)
    # Get the target velocity
    v_target = get_target_velocity(x_0=x_0, x_1=x_1)
    # Predict the velocity
    v_pred = flow_matching_model(t=t, x_t=x_t)
    # Compute the loss
    loss = ((v_pred - v_target) ** 2).mean()
    return loss

Using this loss function we can now train the flow matching model in a straightforward gradient-based optimization loop. We'll use a standard Adam optimizer to optimize the model parameters.

In [11]:

# Train the flow matching model
# Hyperparameters
data_dim: int = 1  # 1D data
hidden_dim: int = 64
nb_train_iterations: int = 10_000
lr: float = 1e-3
batch_size: int = 256

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(626)

# Initialize the vector field network and optimizer
flow_matching_model = FlowMatchingModel(data_dim=data_dim, hidden_dim=hidden_dim).to(DEVICE).train()
optimizer = optim.Adam(flow_matching_model.parameters(), lr=lr)

# Training loop
losses: list[float] = []
with tqdm(range(nb_train_iterations), desc="Training", unit="iteration") as progress_bar:
    for i in progress_bar:
        # Sample a batch of target and noise samples
        x_1 = torch.from_numpy(mixture_sample(size=batch_size)).to(dtype=torch.float32, device=DEVICE).unsqueeze(-1)
        x_0 = torch.randn_like(x_1)
        # Sample a random time step for each sample in the batch
        t = torch.rand(x_1.shape[0], device=DEVICE).unsqueeze(-1)
        # Compute the loss
        loss = compute_loss(flow_matching_model=flow_matching_model, x_0=x_0, x_1=x_1, t=t)
        # Backpropagate the loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        progress_bar.set_postfix({"Loss": f"{loss.item():.2f}"})

In [12]:

# Plot loss curve after training
fig, ax = plt.subplots(figsize=(12, 3), dpi=100)
ax.plot(losses, color="tab:blue", alpha=0.5, label="Loss")
ax.set_xlabel("Iteration")
ax.set_ylabel("Loss")
ax.set_title("Training Loss Curve")
# Plot a smoothed loss curve using a simple moving average
window_size = 100
smoothed_losses = np.convolve(losses, np.ones(window_size) / window_size, mode="valid")
ax.plot(np.arange(window_size - 1, len(losses)), smoothed_losses, color="tab:blue", label="Loss (moving avg)")
ax.legend(loc="upper right")
ax.set_xlim(0, len(losses))
ax.grid(True)
plt.show()
del fig, ax, window_size, smoothed_losses

Figure: Training loss curve.

Visualizing the trained flow matching model

Now that we have trained this simple flow matching model we can visualize the learned velocity field by getting the predicted velocity field ${FM}_{\theta}(x_t, t)$ at a grid of points $(t, x_t)$ and plotting this grid of velocities as a color image. Red means a positive velocity (sample pushed up towards higher $x$) and blue means a negative velocity (sample pulled down towards lower $x$).

In [13]:

# Flow field visualization of trained model
# Set up the plot
fig, (ax1, ax2, ax3) = plt.subplots(
    1,
    3,
    figsize=(12, 5),
    gridspec_kw={"width_ratios": [1, 5, 1]},
    sharey=True,
    dpi=100,
    constrained_layout=True,
)
x_min, x_max = -2.2, 2.2  # Narrow view for flow field image
x_min_wide, x_max_wide = -3.2, 3.2  # Wide view for sample paths

# Sample set of noise and target points
data_size: int = 100_000
np.random.seed(1)  # Set random seeds for reproducibility
data_x_0 = np.random.randn(data_size)

# Plot the velocity field ##########################################
# Plot Noise samples X0
ax1.set_title("Noise Samples X₀")
ax1.hist(
    data_x_0.flatten(),
    bins=100,
    alpha=0.5,
    label="Noise Samples X₀",
    color="tab:orange",
    density=True,
    orientation="horizontal",
)
ax1.invert_xaxis()
ax1.set_ylabel("x")
ax1.set_xlabel("density")

# Plot Flow Field
n_flow_steps = 100
# Set up the flow field histogram parameters
img_hist_size = 200
# Narrow view for flow field image
flow_field_img_bins = np.zeros((n_flow_steps + 1, img_hist_size))
flow_field_img_x_bin_edges = np.linspace(x_min, x_max, img_hist_size + 1)
flow_field_img_x_bin_centers = (
    torch.from_numpy((flow_field_img_x_bin_edges[:-1] + flow_field_img_x_bin_edges[1:]) / 2)
    .float()
    .to(DEVICE)
    .unsqueeze(-1)
)
# Wide view to compute the pathlines, need to be wide enough because paths can flow out of the narrow view
flow_field_for_path_bins = np.zeros((n_flow_steps + 1, img_hist_size))
flow_field_for_path_x_bin_edges = np.linspace(x_min_wide, x_max_wide, img_hist_size + 1)
flow_field_for_path_x_bin_centers = (
    torch.from_numpy((flow_field_for_path_x_bin_edges[:-1] + flow_field_for_path_x_bin_edges[1:]) / 2)
    .float()
    .to(DEVICE)
    .unsqueeze(-1)
)

# Build up the histogram of the reference paths by going over the discretized t-bins
# We're building 2 histograms:
# - `flow_field_img_bins` narrow view for the flow field image, might have NaNs if there are no samples in a bin
# - `flow_field_for_path_bins` wide view for the pathlines, avoid NaNs by setting the velocity to 0 if there are no samples in a bin
with torch.inference_mode():
    flow_matching_model.eval()
    for i, t in enumerate(torch.linspace(0, 1, n_flow_steps + 1)):
        t = t.expand_as(flow_field_img_x_bin_centers).to(DEVICE)
        # Get the model's prediction for the velocity field at the bin centers
        flow_field_img_bins[i] = flow_matching_model(t=t, x_t=flow_field_img_x_bin_centers).cpu().numpy().squeeze()
        flow_field_for_path_bins[i] = (
            flow_matching_model(t=t, x_t=flow_field_for_path_x_bin_centers).cpu().numpy().squeeze()
        )

# Plot the flow field
max_abs_flow_field = np.nanmax(np.abs(flow_field_img_bins))
color_norm = matplotlib.colors.TwoSlopeNorm(vmin=-max_abs_flow_field, vcenter=0.0, vmax=max_abs_flow_field)
im = ax2.imshow(
    flow_field_img_bins.T,
    aspect="auto",
    origin="lower",
    extent=[0, 1, x_min, x_max],
    cmap="coolwarm",
    norm=color_norm,
)
ax2.set_ylim(x_min, x_max)
cbar = fig.colorbar(
    im,
    ax=[ax1, ax2, ax3],
    orientation="horizontal",
    fraction=0.08,
    aspect=40,
    pad=0.04,
)
cbar.set_label("Velocity field value (red pushes up, blue pulls down)")

# Create pathlines from the flow field
n_paths = 9
# Initialize tensor to store the flow process
paths = np.zeros((n_flow_steps + 1, n_paths))
x_t = torch.linspace(start=x_min_wide, end=x_max_wide, steps=n_paths, device=DEVICE).unsqueeze(-1)
paths[0] = x_t.cpu().squeeze().numpy()
# Generate the flow process
with torch.inference_mode():
    flow_matching_model.eval()
    ts = torch.linspace(0, 1, n_flow_steps + 1)
    for i in range(n_flow_steps):
        t = ts[i].expand_as(x_t).to(DEVICE)
        dt = ts[i + 1] - ts[i]
        x_t = x_t + flow_matching_model(t=t, x_t=x_t.to(DEVICE)) * dt
        paths[i + 1] = x_t.cpu().numpy().squeeze()

# Plot the pathlines with arrows using quiver, following best practices
t_coords = np.linspace(0, 1, n_flow_steps + 1)
arrow_stride = 35  # space out arrows for clarity
arrow_offset = 7  # offset the arrows to the right for clarity
for i in range(n_paths):
    y = paths[:, i]
    idx = np.arange(arrow_offset, len(t_coords) - 1, arrow_stride)
    x0 = t_coords[idx]
    y0 = y[idx]
    u = np.diff(t_coords)[idx]
    v = np.diff(y)[idx]
    ax2.quiver(
        x0,
        y0,
        u,
        v,
        angles="xy",
        scale_units="xy",
        scale=0.6,
        units="inches",
        width=0.015,
        headwidth=6,
        headlength=9,
        headaxislength=7,
        pivot="tail",
        color="dimgray",
        alpha=0.9,
        zorder=2,
    )
    # overlay original line on top of arrows for clarity
    ax2.plot(
        t_coords,
        y,
        linestyle="-",
        color="dimgray",
        linewidth=1.5,
        alpha=0.7,
        zorder=3,
        label=f"Sample Path {i + 1}",
    )
ax2.set_xlabel("$t$: flow step")
ax2.set_title(r"Predicted velocity field ${FM}_{\theta}(x_t, t)$ with pathlines")
ax2.grid(False)

# Plot target data distribution x1
# Sample target data distribution x1
x_t = torch.from_numpy(data_x_0).float().to(DEVICE).unsqueeze(-1)
dt: float = 0.01  # Step size for Euler integration
n_flow_steps = int(1 / dt)
# Generate the flow process
with torch.inference_mode():
    for i in range(n_flow_steps):
        t = torch.full_like(x_t, i * dt, device=DEVICE)
        x_t = x_t + flow_matching_model(t=t, x_t=x_t.to(DEVICE)) * dt
data_x_1 = x_t.cpu().numpy().squeeze()
ax3.set_title(r"Predicted Data $\hat{X}_1$")
ax3.hist(
    data_x_1.flatten(),
    bins=100,
    alpha=0.5,
    label=r"Predicted Data $\hat{X}_1$",
    color="tab:blue",
    density=True,
    orientation="horizontal",
)
ax3.set_xlabel("density")
ax3.yaxis.set_label_position("right")
ax3.set_ylabel("x")
# Also show y-axis values on the right side
ax3.yaxis.set_visible(True)
ax3.yaxis.set_tick_params(labelright=True)

plt.show()
del (
    fig, ax1, ax2, ax3, x_min, x_max, x_min_wide, x_max_wide, data_size,  # fmt: skip
    data_x_0, data_x_1, dt, n_flow_steps, img_hist_size, flow_field_img_bins,  # fmt: skip
    flow_field_img_x_bin_edges, flow_field_for_path_bins, flow_field_for_path_x_bin_edges,  # fmt: skip
    y, idx, x0, y0, u, i, t, x_t,  # fmt: skip
)

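[Figure: noise samples $X_0$ (left), predicted velocity field ${FM}_{\theta}(x_t, t)$ with pathlines over flow step $t$ (center), and predicted data $\hat{X}_1$ (right).]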

Sampling from the trained model

At inference time we can sample a starting point $x_0$ from the noise distribution $π_0$ and then use the predicted velocity field ${FM}_{\theta}(x_t, t)$ to iteratively move (integrate) the sample towards a sample $\hat{x}_1$ from the target distribution $π_1$.

The code below starts from a noise sample $x_0$ (pre-selected as 0.85 for the example; in general $x_0 \sim \mathcal{N}(0, 1)$) and integrates the learned ODE using the simple Euler method. At each step $t$ the Euler method takes the velocity field prediction ${FM}_{\theta}(x_t, t)$ at the current position $x_t$ and moves the sample a small step $dt$ in the direction of the velocity field.
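Written out, with $\Delta t$ for the step size (called `dt` in the code) and $N$ for the number of steps (a symbol introduced here only for illustration), the Euler update described above is:

$$
\frac{d x_t}{d t} = {FM}_{\theta}(x_t, t), \qquad x_{t + \Delta t} = x_t + {FM}_{\theta}(x_t, t)\,\Delta t, \qquad \Delta t = \frac{1}{N}.
$$

Starting from $x_0$ at $t = 0$ and applying the update $N$ times yields the sample $\hat{x}_1$ at $t = 1$.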

In [14]:

# Illustration on how to sample x_1 from x_0 using the learned velocity field
nb_steps = 15
path_x = np.zeros(nb_steps + 1)  # Array to store the full sampled path
t_steps = np.linspace(0, 1, nb_steps + 1)  # Steps $t$ in the range [0, 1]
# x_0 starting point (Pre-selected here for the example, but ideally x_0 ~ N(0, I))
x_0 = torch.Tensor([[0.85]]).to(DEVICE)

with torch.inference_mode():
    flow_matching_model.eval()
    x_t = x_0  # Initialize the sample at the starting point
    path_x[0] = x_t.squeeze().cpu().numpy()
    # Integrate the velocity field using Euler integration from t=0 to t=1
    for i in range(nb_steps):
        t = t_steps[i]  # Current step $t$
        dt = t_steps[i + 1] - t_steps[i]  # Step size
        t_batch = torch.Tensor([[t]]).to(DEVICE)  # Expand the step to a batch dimension
        # Get the velocity field prediction at the current position and time step
        # and move the sample a small step dt in the direction of the velocity field
        x_t = x_t + flow_matching_model(t=t_batch, x_t=x_t) * dt
        path_x[i + 1] = x_t.squeeze().cpu().numpy()

display(HTML(pd.DataFrame({"t": t_steps, "x": path_x}).transpose().to_html()))
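|   | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| t | 0.000 | 0.067 | 0.133 | 0.200 | 0.267 | 0.333 | 0.400 | 0.467 | 0.533 | 0.600 | 0.667 | 0.733 | 0.800 | 0.867 | 0.933 | 1.000 |
| x | 0.850 | 0.805 | 0.767 | 0.738 | 0.719 | 0.716 | 0.731 | 0.769 | 0.830 | 0.909 | 0.999 | 1.095 | 1.192 | 1.288 | 1.384 | 1.481 |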

We can illustrate this sampled path in the following animation, which shows the integration from the noise sample $x_0$ towards the generated sample $\hat{x}_1$ using the predicted velocity field ${FM}_{\theta}(x_t, t)$. The velocity field is visualized as a heatmap where the vertical axis represents the position of the sample $x_t$ and the horizontal axis represents the flow step $t$ going from 0 on the left to 1 on the right. Red means a positive velocity (sample pushed up towards higher $x$) and blue means a negative velocity (sample pulled down towards lower $x$).

Notice that while we trained on straight-line paths, the sampled path is not necessarily a straight line. This is because we don't learn the paths directly. Instead we learn the unconditional velocity field by training on a large set of straight-line reference paths, and wherever several reference paths pass through the same point $(x_t, t)$ the learned field averages their directions.
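To see why curved paths are expected, here is a minimal sketch under a simplifying assumption that is not the setup of this post: the target is taken to be two point masses at $x = -1$ and $x = +1$ with equal probability (instead of the bimodal mixture used here), and the noise is $\mathcal{N}(0, 1)$. For straight-line reference paths $x_t = (1 - t) x_0 + t x_1$, the averaged velocity $E[x_1 - x_0 \mid x_t = x]$ then has the closed form $(\tanh(t x / (1 - t)^2) - x) / (1 - t)$. The names `marginal_velocity` and `euler_path` are hypothetical helpers for this illustration.

```python
import math


def marginal_velocity(x: float, t: float) -> float:
    """Closed-form velocity E[x_1 - x_0 | x_t = x] for the assumed toy setup.

    Assumptions (not from the post): x_0 ~ N(0, 1), x_1 is -1 or +1 with equal
    probability, and x_t = (1 - t) * x_0 + t * x_1. Valid for 0 <= t < 1.
    """
    expected_x1 = math.tanh(t * x / (1.0 - t) ** 2)  # E[x_1 | x_t = x]
    return (expected_x1 - x) / (1.0 - t)


def euler_path(x_start: float, nb_steps: int) -> list[float]:
    """Euler-integrate the closed-form velocity field from t=0 to t=1."""
    dt = 1.0 / nb_steps
    path = [x_start]
    for i in range(nb_steps):
        t = i * dt  # Velocity is evaluated at the start of each step, so t < 1
        path.append(path[-1] + marginal_velocity(path[-1], t) * dt)
    return path


path = euler_path(x_start=0.85, nb_steps=15)
print(" ".join(f"{x:.3f}" for x in path))
```

Every reference path in this toy setup is a straight line, yet the integrated path first moves towards 0 and only later bends up towards $+1$, qualitatively like the sampled path in the table above. At small $t$ the position says little about which mode the sample will end up in, so the averaged velocity points towards the mean of the target.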

In [15]:

# Visualize animation of the flow matching integration (denoising) using Euler integration
# Set up the plot
fig, (ax1, ax2, ax3) = plt.subplots(
    1,
    3,
    figsize=(12, 5),
    gridspec_kw={"width_ratios": [1, 5, 1]},
    sharey=True,
    dpi=100,
    constrained_layout=True,
)
x_min, x_max = -2.2, 2.2  # Narrow view for flow field image
x_all_steps = np.linspace(-3, 3, 1000)

# Sample set of noise points
data_size: int = 100_000
np.random.seed(1)  # Set random seeds for reproducibility
data_x_0 = np.random.randn(data_size)

# Plot the velocity field ##########################################
# Plot Noise distribution π₀
ax1.set_title("Noise Distribution π₀")
pdf_noise = scipy.stats.norm.pdf(x_all_steps, loc=0, scale=1)
ax1.plot(pdf_noise, x_all_steps, label="PDF Noise π₀", color="tab:orange")
ax1.fill_betweenx(x_all_steps, pdf_noise, alpha=0.4, color="tab:orange")
ax1.invert_xaxis()
ax1.set_ylabel("x")
ax1.set_xlabel("")
# ax11.set_xlim(0, 1)

# Plot final distribution x1
ax3.set_title("Target Distribution π₁")
pdf_target = mixture_pdf(x=x_all_steps)
ax3.plot(pdf_target, x_all_steps, label="PDF Target π₁", color="tab:blue")
ax3.fill_betweenx(x_all_steps, pdf_target, alpha=0.4, color="tab:blue")
ax3.set_xlabel("")
ax3.yaxis.set_label_position("right")
ax3.set_ylabel("x")
# Also show y-axis values on the right side
ax3.yaxis.set_visible(True)
ax3.yaxis.set_tick_params(labelright=True)
# ax13.set_xlim(0, 1)

# Plot Flow Field
n_flow_steps = 100
# Set up the flow field grid parameters
img_hist_size = 200
# Narrow view for flow field image
flow_field_img_bins = np.zeros((n_flow_steps + 1, img_hist_size))
flow_field_img_x_bin_edges = np.linspace(x_min, x_max, img_hist_size + 1)
flow_field_img_x_bin_centers = (
    torch.from_numpy((flow_field_img_x_bin_edges[:-1] + flow_field_img_x_bin_edges[1:]) / 2)
    .float()
    .to(DEVICE)
    .unsqueeze(-1)
)

# Evaluate the model's velocity field on a (t, x) grid for the flow field image
with torch.inference_mode():
    ts = torch.linspace(0, 1, n_flow_steps + 1, device=DEVICE)
    for i in range(n_flow_steps + 1):
        t = ts[i].expand_as(flow_field_img_x_bin_centers).to(DEVICE)
        # Get the model's prediction for the velocity field at the bin centers
        flow_field_img_bins[i] = flow_matching_model(t=t, x_t=flow_field_img_x_bin_centers).cpu().numpy().squeeze()

# Plot the flow field
max_abs_flow_field = np.nanmax(np.abs(flow_field_img_bins))
color_norm = matplotlib.colors.TwoSlopeNorm(vmin=-max_abs_flow_field, vcenter=0.0, vmax=max_abs_flow_field)
im = ax2.imshow(
    flow_field_img_bins.T,
    aspect="auto",
    origin="lower",
    extent=[0, 1, x_min, x_max],
    cmap="coolwarm",
    norm=color_norm,
)
ax2.set_ylim(x_min, x_max)
cbar = fig.colorbar(
    im,
    ax=[ax1, ax2, ax3],
    orientation="horizontal",
    fraction=0.08,
    aspect=40,
    pad=0.04,
)
cbar.set_label("Velocity field value (red pushes up, blue pulls down)")

# Draw the starting point of the integration path; the animation updates these artists
step_draw = ax2.scatter(
    [0],
    [x_0.squeeze().cpu().numpy()],
    color="tab:blue",
    zorder=4,
    alpha=0.9,
)
(line_draw,) = ax2.plot(
    [0],
    [x_0.squeeze().cpu().numpy()],
    linestyle=":",
    color="tab:blue",
    linewidth=1,
    alpha=0.5,
    zorder=3,
)
text_draw = ax2.text(
    0.77,
    -1.58,
    f"   t = 0.00\nstep = 0\n   x = {path_x[0]:.2f}",
    fontsize=16,
    color="black",
    fontfamily="monospace",
    horizontalalignment="left",
    verticalalignment="center",
    bbox={"facecolor": "white", "alpha": 0.5, "pad": 8},
)
ax2.set_xlabel("$t$: flow step")
ax2.set_title(r"Euler integration of the predicted velocity field ${FM}_{\theta}(x_t, t)$")
ax2.grid(True)
ax2.legend(
    [step_draw],
    ["Euler Integration Step"],
    loc="lower left",
)


def update_animation(frame: int, step_draw, line_draw, text_draw, nb_steps):
    """
    Update the figure to show an animation of the integration of the velocity field.
    """
    xs = path_x[: frame + 1]
    ts = t_steps[: frame + 1]
    step_draw.set_offsets(np.stack([ts, xs]).T)
    line_draw.set_xdata(ts)
    line_draw.set_ydata(xs)
    text_draw.set_text(f"   t = {ts[-1]:.2f}\nstep = {min(frame, nb_steps):d}\n   x = {xs[-1]:.2f}")
    return (step_draw, line_draw)


# Create the animation
ani = FuncAnimation(
    fig=fig,
    func=functools.partial(
        update_animation,
        step_draw=step_draw,
        line_draw=line_draw,
        text_draw=text_draw,
        nb_steps=nb_steps,
    ),
    frames=nb_steps + 2,
    interval=1,
)
ani.save(str(ANIMATION_FILE), writer="ffmpeg", fps=3)
plt.close(fig)

# Display the animation in the notebook
# Embed the video directly in the notebook by encoding the bytes as base64, this way it should hopefully also be exported
with ANIMATION_FILE.open("rb") as f:
    gif_data = f.read()
display(
    HTML(f"""
<video alt="Flow matching path integration" loop="true" autoplay="autoplay" muted>
    <source type="video/mp4" src="data:video/mp4;base64,{base64.b64encode(gif_data).decode()}">
</video>
""")
)

del (fig, ax1, ax2, ax3, x_min, x_max, data_size, data_x_0, n_flow_steps, img_hist_size,  # fmt: skip
     flow_field_img_bins, flow_field_img_x_bin_edges, flow_field_img_x_bin_centers, max_abs_flow_field,  # fmt: skip
     color_norm, im, cbar, step_draw, line_draw, text_draw, update_animation, ani, gif_data)  # fmt: skip

We can also draw a large set of samples $\hat{X}_1$ from the model and compare it with the target distribution $\pi_1$. We'll define a `sample` function that generates samples by integrating the learned vector field using Euler integration. We'll then plot the target distribution together with the reconstructed samples.

In [16]:

@torch.inference_mode()
def sample(
    n_samples: int,  # Number of samples to generate
    model: FlowMatchingModel,  # The flow matching model
    nb_steps: int,  # Number of Euler integration steps
) -> torch.Tensor:
    """Generates samples by integrating the learned vector field using Euler integration."""
    ts = torch.linspace(0, 1, nb_steps + 1, device=DEVICE)
    x_t = torch.randn(n_samples, data_dim).to(DEVICE)  # Sample x_0 ~ N(0, I)
    for i in range(nb_steps):  # Euler integration from t=0 to t=1 (last step happens just before t=1)
        t = ts[i]  # Current step $t$
        dt = ts[i + 1] - ts[i]  # Step size
        t_batch = t.expand(n_samples).unsqueeze(-1)
        # Move the sample a small step dt in the direction of the velocity field
        x_t = x_t + model(t=t_batch, x_t=x_t) * dt
    return x_t  # Final sample x_1
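Euler integration is the simplest choice, not the only one. As a sketch of the trade-off between step count and accuracy, the snippet below compares Euler with Heun's method (a second-order scheme that averages the velocity at the start and at the predicted end of each step) on the test ODE $dx/dt = -x$, whose exact solution at $t = 1$ is $x_0 e^{-1}$. The `velocity` callable stands in for the trained model, and `integrate` and `decay` are hypothetical helpers, not part of the notebook.

```python
import math
from typing import Callable

Velocity = Callable[[float, float], float]  # (x, t) -> dx/dt


def integrate(velocity: Velocity, x_start: float, nb_steps: int, method: str = "euler") -> float:
    """Integrate dx/dt = velocity(x, t) from t=0 to t=1 with a fixed step size."""
    dt = 1.0 / nb_steps
    x = x_start
    for i in range(nb_steps):
        t = i * dt
        k1 = velocity(x, t)  # Velocity at the start of the step
        if method == "euler":
            x = x + k1 * dt
        elif method == "heun":
            k2 = velocity(x + k1 * dt, t + dt)  # Velocity at the Euler-predicted end point
            x = x + 0.5 * (k1 + k2) * dt
        else:
            raise ValueError(f"Unknown method: {method}")
    return x


def decay(x: float, t: float) -> float:
    """Stand-in velocity field with a known solution: x(1) = x(0) * exp(-1)."""
    return -x


exact = math.exp(-1.0)
for nb_steps in (5, 15, 50):
    err_euler = abs(integrate(decay, 1.0, nb_steps, "euler") - exact)
    err_heun = abs(integrate(decay, 1.0, nb_steps, "heun") - exact)
    print(f"steps={nb_steps:3d}  euler error={err_euler:.5f}  heun error={err_heun:.5f}")
```

Heun's method costs two velocity evaluations per step, so a fair comparison is against Euler with twice as many steps.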

In [17]:

# Plot data distribution. This is the TARGET distribution (π₁)
fig, ax = plt.subplots(1, 1, figsize=(8, 3), constrained_layout=True, dpi=100)

# Plot the target distribution π₁
x_all_steps = np.linspace(-2.5, 2.5, 1000)
pdf_target = mixture_pdf(x=x_all_steps)
ax.plot(x_all_steps, pdf_target, label=r"PDF Target $\pi_1$", color="tab:purple")
ax.fill_between(x_all_steps, pdf_target, alpha=0.4, color="tab:purple")

# Plot the samples
x1_samples = sample(n_samples=100_000, model=flow_matching_model, nb_steps=50).cpu().numpy().flatten()
sns.histplot(
    x=x1_samples,
    bins=100,
    color="tab:blue",
    kde=False,
    alpha=0.8,
    ax=ax,
    stat="density",
    label=r"Reconstructed samples $\hat{X}_1$",
)
ax.legend()
ax.set_title(r"Target Distribution ($\pi_1$) vs Reconstructed Samples ($\hat{X}_1$)")
ax.set_xlabel("x")
ax.set_ylabel("density")
plt.show()

del fig, ax, x_all_steps, pdf_target, x1_samples

[Figure: PDF of the target distribution $\pi_1$ overlaid with a histogram of the reconstructed samples $\hat{X}_1$; x-axis: x, y-axis: density.]
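A histogram overlay is a visual check. As a rough numerical complement, one option is the 1D Wasserstein-1 distance between two equally sized sample sets, which reduces to the mean absolute difference of the sorted samples. The sketch below is self-contained and therefore uses stand-ins: `sample_mixture` draws from an assumed bimodal mixture (modes at ±1 with standard deviation 0.3, which need not match the post's `mixture_pdf`). In the notebook, one array would instead come from `sample(...)` and the other from the training data.

```python
import numpy as np


def wasserstein_1d(a: np.ndarray, b: np.ndarray) -> float:
    """Wasserstein-1 distance between two equally sized 1D sample sets."""
    if a.shape != b.shape or a.ndim != 1:
        raise ValueError("Expected two 1D arrays of the same length")
    return float(np.mean(np.abs(np.sort(a) - np.sort(b))))


def sample_mixture(rng: np.random.Generator, n: int) -> np.ndarray:
    """Assumed stand-in target: equal-weight Gaussians at -1 and +1, std 0.3."""
    modes = rng.choice([-1.0, 1.0], size=n)
    return modes + 0.3 * rng.standard_normal(n)


rng = np.random.default_rng(0)
n = 20_000
target = sample_mixture(rng, n)  # Stand-in for samples of the target
reconstructed = sample_mixture(rng, n)  # Stand-in for samples from a well-trained model
noise = rng.standard_normal(n)  # Untransformed noise, for reference

d_same = wasserstein_1d(target, reconstructed)
d_noise = wasserstein_1d(target, noise)
print(f"W1(target, reconstructed) = {d_same:.3f}")
print(f"W1(target, noise)         = {d_noise:.3f}")
```

A distance close to the one obtained between two independent draws from the target suggests the reconstruction is about as good as the sample size can resolve, while the noise baseline gives a sense of scale.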

As a final illustration, let's look at the path density between the starting noise samples $X_0$ and the final reconstructed samples $\hat{X}_1$ by sampling a large number of paths from the noise distribution $\pi_0$ to the target distribution $\pi_1$.

In [18]:

# Path density
@torch.inference_mode()
def sample_paths(
    n_samples: int,
    model: FlowMatchingModel,
    nb_steps: int,  # Number of Euler integration steps
) -> torch.Tensor:
    """Generates samples by integrating the learned vector field, keeping track of the intermediate steps."""
    x_all_steps = torch.zeros(n_samples, nb_steps + 1).to(DEVICE)
    ts = torch.linspace(0, 1, nb_steps + 1, device=DEVICE)
    x_t = torch.randn(n_samples, 1).to(DEVICE)  # Sample x_0 ~ N(0, I)
    x_all_steps[:, 0] = x_t.squeeze()
    for i in range(nb_steps):  # Euler integration from t=0 to t=1
        t = ts[i]  # Current step $t$
        dt = ts[i + 1] - ts[i]  # Step size
        t_batch = t.expand(n_samples).unsqueeze(-1)  # Expand the step to a batch dimension
        # Move the sample a small step dt in the direction of the velocity field
        x_t = x_t + model(t=t_batch, x_t=x_t) * dt
        x_all_steps[:, i + 1] = x_t.squeeze()
    return x_all_steps


nb_paths = 100_000
nb_steps = 200
paths = sample_paths(n_samples=nb_paths, model=flow_matching_model, nb_steps=nb_steps).cpu().numpy()

# Illustration of the sampled reference paths
# Set up the plot
fig, (ax1, ax2, ax3) = plt.subplots(
    1,
    3,
    figsize=(12, 4),
    gridspec_kw={"width_ratios": [1, 5, 1]},
    sharey=True,
    dpi=100,
)
fig.subplots_adjust(wspace=0)
x_min, x_max = -2.5, 2.5

# Sample set of noise and target points
np.random.seed(1)  # Set random seeds for reproducibility
torch.manual_seed(1)

# Plot the full data distribution ##################################
# Plot Noise samples X0
ax1.set_title("Noise Samples X₀")
ax1.hist(
    paths[:, 0],
    bins=100,
    alpha=0.5,
    label="Noise π₀",
    color="tab:orange",
    density=True,
    orientation="horizontal",
)
ax1.invert_xaxis()
ax1.set_ylabel("x")
ax1.set_xlabel("density")

# Plot target data distribution x1
ax3.set_title(r"Reconstructed samples $\hat{X}_1$")
ax3.hist(
    paths[:, -1],
    bins=100,
    alpha=0.5,
    label=r"Reconstructed samples $\hat{X}_1$",
    color="tab:blue",
    density=True,
    orientation="horizontal",
)
ax3.set_xlabel("density")
ax3.yaxis.set_label_position("right")
ax3.set_ylabel("x")
# Also show y-axis values on the right side
ax3.yaxis.set_visible(True)
ax3.yaxis.set_tick_params(labelright=True)

# Plot path density
# Set up the path density histogram parameters
img_hist_x_size = 480
path_density_bins = np.zeros((nb_steps + 1, img_hist_x_size))
flow_field_img_x_bin_edges = np.linspace(x_min, x_max, img_hist_x_size + 1)
# Build up the histogram of the reference paths by going over the discretized t-bins
for i in range(nb_steps + 1):
    path_density_bins[i] = np.histogram(paths[:, i], bins=flow_field_img_x_bin_edges)[0]
im = ax2.imshow(path_density_bins.T, aspect="auto", origin="lower", extent=[0, 1, x_min, x_max], cmap="viridis")
ax2.set_xlabel("$t$: flow step")
ax2.set_title(r"Path density of paths sampled using Euler integration of ${FM}_{\theta}(x_t, t)$")
ax2.grid(False)
plt.tight_layout()
plt.show()

del (fig, ax1, ax2, ax3, x_min, x_max, nb_paths, nb_steps, paths, img_hist_x_size,  # fmt: skip
     path_density_bins, flow_field_img_x_bin_edges, i, im)  # fmt: skip

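[Figure: noise samples $X_0$ (left), path density of paths sampled using Euler integration of ${FM}_{\theta}(x_t, t)$ over flow step $t$ (center), and reconstructed samples $\hat{X}_1$ (right).]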

In [19]:

# Python package versions used
%load_ext watermark
%watermark --python
%watermark --iversions

Python implementation: CPython
Python version       : 3.12.10
IPython version      : 9.6.0

torch     : 2.8.0+cu128
IPython   : 9.6.0
numpy     : 2.3.3
scipy     : 1.16.2
tqdm      : 4.67.1
matplotlib: 3.10.6
pandas    : 2.3.3
seaborn   : 0.13.2

Originally published on November 1, 2025.
