
Transport Maps

Transport maps (TMs) construct a deterministic transformation between a simple reference distribution (typically standard normal) and a complex target distribution [14], [15]. Once such a map is constructed, generating samples from the target becomes trivial: simply draw samples from the reference and apply the transformation. Moreover, transport maps enable efficient computation of conditional distributions, marginals, and other probabilistic quantities through the geometric structure they impose.

Mathematical Formulation

Mathematically, a transport map T: Z → Θ is defined as a deterministic coupling between a reference space Z, equipped with the density ρ(z), and a target space Θ, equipped with the density π(θ). Hence, the inverse map T⁻¹: Θ → Z maps from the target space back to the reference space.

The target distribution is approximated by the so-called pull-back density:

π(θ) ≈ T♯ρ = ρ(T⁻¹(a, θ)) |det ∇T⁻¹(a, θ)|,

where a denotes the coefficients that parameterize the map.
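As a sketch of the pull-back construction (a hypothetical 1D toy example, not the package API): for an affine map T(z) = μ + σz with standard normal reference, the pull-back density ρ(T⁻¹(θ)) |det ∇T⁻¹(θ)| is exactly the Normal(μ, σ) density, and therefore integrates to one.

```julia
# Sketch (hypothetical 1D example, not the package API): pull-back density
# of an affine map with standard normal reference.
ϕ(z) = exp(-z^2 / 2) / sqrt(2π)   # reference density ρ = N(0, 1)

μ, σ = 1.5, 0.8
Tinv(θ) = (θ - μ) / σ             # inverse map T⁻¹
jac = 1 / σ                       # |det ∇T⁻¹|, constant for an affine map

pullback(θ) = ϕ(Tinv(θ)) * jac    # ρ(T⁻¹(θ)) |det ∇T⁻¹(θ)|

# the pull-back density integrates to ≈ 1 (Riemann sum on a fine grid)
θs = -3:0.01:6
mass = sum(pullback, θs) * 0.01
```

Here the pull-back coincides with the Normal(μ, σ) density, which is why sampling reduces to drawing z from the reference and evaluating T(z).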

Triangular Structure

The construction and inversion of the transport map can be greatly simplified by using a triangular structure following the Knothe-Rosenblatt rearrangement [16], [17]. This triangular structure guarantees invertibility and makes the Jacobian determinant straightforward to compute. A triangular map in n dimensions has the form:

       ⎛ T₁(z₁)              ⎞
T(z) = ⎜ T₂(z₁, z₂)          ⎟
       ⎜ T₃(z₁, z₂, z₃)      ⎟
       ⎜ ⋮                   ⎟
       ⎝ Tₙ(z₁, z₂, …, zₙ)   ⎠

A key requirement for the transport map is that each component Tₖ must be strictly monotonically increasing in its last argument zₖ (while possibly depending non-monotonically on the earlier arguments z₁, …, zₖ₋₁). This monotonicity constraint ensures invertibility and is enforced through a specialized parameterization.
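To illustrate why the triangular, monotone structure makes inversion cheap, here is a minimal sketch (a hypothetical toy map, not the package API) that inverts a 2D triangular map one coordinate at a time, using scalar bisection for each component:

```julia
# Sketch (hypothetical toy map, not the package API): sequential inversion
# of a 2D triangular map whose components are increasing in their last argument.
T1(z1) = z1 + z1^3                 # strictly increasing in z1
T2(z1, z2) = sin(z1) + z2 + z2^3   # strictly increasing in z2 for fixed z1

# invert a scalar increasing function g on [lo, hi] by bisection
function invert(g, y; lo=-10.0, hi=10.0)
    for _ in 1:200
        mid = (lo + hi) / 2
        g(mid) < y ? (lo = mid) : (hi = mid)
    end
    return (lo + hi) / 2
end

x = (1.2, -0.7)                    # point in the target space
z1 = invert(T1, x[1])              # first solve T1(z1) = x1
z2 = invert(z -> T2(z1, z), x[2])  # then solve T2(z1, z2) = x2
```

Because each component only adds one new variable, the full inverse reduces to n one-dimensional root-finding problems, and the Jacobian determinant is just the product of the diagonal derivatives ∂ₖTₖ.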

Integrated Rectifier Parameterization

This monotonicity requirement is commonly achieved using an integrated rectifier parameterization, where each component takes the form:

Tₖ(z₁, …, zₖ; a) = f(z₁, …, zₖ₋₁, 0; a) + ∫₀^{zₖ} g(∂ₖ f(z₁, …, zₖ₋₁, ξ; a)) dξ.

Here, f(z₁, …, zₖ; a) is a multivariate polynomial

f(z₁, …, zₖ; a) = ∑_{α ∈ Aₖ} a_α Ψ_α(z₁, …, zₖ)

where:

  • Aₖ is a multi-index set defining which basis functions are included

  • a_α are the optimization coefficients

  • Ψ_α are multivariate basis functions

  • g: ℝ → ℝ₊ is a rectifier function that maps the derivative ∂ₖf to a strictly positive value, ensuring monotonicity
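A minimal sketch of this parameterization (hypothetical names, not the package API): even when the polynomial f is non-monotone, passing its derivative through a softplus rectifier and integrating yields a strictly increasing map component.

```julia
# Sketch (hypothetical, not the package API): one 1D map component in the
# integrated-rectifier form T(z) = f(0) + ∫₀ᶻ g(∂f(ξ)) dξ.
softplus(x) = log1p(exp(x))    # rectifier g: ℝ → ℝ₊
f(ξ) = 0.5 - 2ξ + ξ^2          # non-monotone polynomial
df(ξ) = -2 + 2ξ                # its derivative ∂f changes sign at ξ = 1

# trapezoidal approximation of ∫₀ᶻ softplus(df(ξ)) dξ
function T(z; n=1000)
    ξ = range(0, z; length=n + 1)
    vals = softplus.(df.(ξ))
    h = z / n
    return f(0) + h * (sum(vals) - (vals[1] + vals[end]) / 2)
end

# the integrand softplus(df) is always positive, so T is strictly increasing
zs = -2:0.5:2
Ts = T.(zs)
```

In the actual package the integral is evaluated with quadrature rather than a fixed trapezoidal grid, but the monotonicity argument is the same: g > 0 implies T' > 0.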

Implementation

In UncertaintyQuantification.jl, transport maps are implemented using the TransportMaps.jl package. This package provides the backend implementation, including the construction of basis functions, rectifier functions, and the optimization procedures to determine the map coefficients.

There are two main approaches to constructing transport maps, depending on the available information about the target distribution:

  1. From target density: When an analytical expression for the log-density is available

  2. From target samples: When only samples from the target distribution are available

Map Construction from Target Density

When an analytical expression for the target log-density is available, we determine the map coefficients a by solving an optimization problem that minimizes the Kullback-Leibler (KL) divergence between the target density and the transport map approximation:

minₐ D_KL(T♯ρ ‖ π)

The KL divergence is computed as an expected value with respect to the reference measure and approximated using numerical quadrature:

∑_{i=1}^{N} w_{q,i} [ −log π(T(a, z_{q,i})) − log |det ∇T(a, z_{q,i})| ].

Here, w_{q,i} are quadrature weights and z_{q,i} are quadrature points. We use the quadrature schemes defined in TransportMaps.jl (see its Quadrature methods documentation). Currently available schemes are:

  • Monte Carlo

  • Latin Hypercube

  • Gauss-Hermite

  • Sparse Smolyak
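As a sketch of how such quadrature approximates expectations under the standard normal reference (a hand-coded 3-point Gauss-Hermite rule, not the TransportMaps.jl API): the objective above is a weighted sum ∑ᵢ wᵢ h(zᵢ) that replaces the expectation E_ρ[h(z)].

```julia
# Sketch (hand-coded rule, not the TransportMaps.jl API): 3-point
# Gauss-Hermite quadrature for the standard normal weight. This rule is
# exact for polynomial integrands up to degree 5.
nodes = (-sqrt(3), 0.0, sqrt(3))
weights = (1 / 6, 2 / 3, 1 / 6)

# quadrature approximation of E_ρ[h(z)] for ρ = N(0, 1)
expect(h) = sum(w * h(z) for (w, z) in zip(weights, nodes))

expect(z -> z^2)   # ≈ 1, matching E[z²] for N(0, 1)
expect(z -> z^4)   # ≈ 3, matching E[z⁴]
```

For a d-dimensional map the same idea is applied as a tensor product (or a sparse Smolyak combination) of such 1D rules.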

Usage

Transport maps are implemented with the TransportMap struct, which is a custom MultivariateDistribution. Hence, a TransportMap is fully compatible with the JointDistribution. To construct a TM from a density, call mapfromdensity, which requires the following inputs (cf. the example below):

  • transportmap: the map structure to be optimized (e.g., a TransportMaps.PolynomialMap)

  • target: the target density, wrapped as a MapTargetDensity

  • quadrature: the quadrature scheme used to approximate the KL objective

  • names: the names of the variables as a vector of Symbols

A JointDistribution with the optimized TransportMap and the names of the variables is returned.

The following example demonstrates how to construct a transport map from a given log-density function:

julia
# Define the log-density of the target (banana-shaped distribution)
logtarget(x) = logpdf(Normal(), x[1]) + logpdf(Normal(), x[2] - x[1]^2)

# Create a 2D polynomial map with degree 2
# Defaults: reference=Normal(), rectifier=Softplus(), basis=LinearizedHermiteBasis()
pm = PolynomialMap(2, 2)

# Define 2D quadrature using a Gauss-Hermite tensor product with 3 points per dimension
quad = GaussHermiteWeights(3, 2)

# Construct the target density and fit the map
target = MapTargetDensity(logtarget) # Requires a log-density function
tm_opt = mapfromdensity(pm, target, quad, [:x1, :x2])
JointDistribution{MultivariateDistribution, Symbol}(TransportMap(map=PolynomialMap(2-dimensional, degree=2, basis=LinearizedHermiteBasis, reference=Normal{Float64}(μ=0.0, σ=1.0), rectifier=Softplus, 9 total coefficients), target=MapTargetDensity(backend=AutoForwardDiff()), names=[:x1, :x2]), [:x1, :x2])

Once the map is constructed, we can generate samples from the target distribution and evaluate its probability density function:

julia
# Generate 1000 samples from the target distribution
samples = sample(tm_opt, 1000)

# Evaluate the pdf on a grid for visualization
x1_range = -4:0.1:4
x2_range = -3:0.1:7
pdf_vals = [pdf(tm_opt, [x1, x2]) for x2 in x2_range, x1 in x1_range]

scatter(samples.x1, samples.x2; alpha=0.8, label="TM Samples")
contour!(x1_range, x2_range, pdf_vals)

Bayesian Updating with Transport Maps

Because transport maps provide both an analytical expression for the approximated density and a cheap way to generate samples, they are appealing for Bayesian inference applications. For usage with bayesianupdating, see Variational Inference with Transport Maps.

Map Construction from Target Samples

When only samples from the target distribution are available (without an analytical density), the KL divergence is formulated in reverse, i.e., as an expectation with respect to the target measure rather than the reference measure. In this case, the transport map performs density estimation from samples and can serve as an alternative to Gaussian mixture models.
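A minimal sketch of this idea (hypothetical 1D example, not the package API): minimizing the reverse KL over samples amounts to maximizing the average log pull-back density of the samples, and for an affine 1D map T(z) = μ + σz the maximizer is essentially the sample mean and standard deviation.

```julia
# Sketch (hypothetical 1D example, not the package API): sample-based map
# fitting as maximum likelihood under the pull-back density.
using Statistics, Random

Random.seed!(1)
x = 2.0 .+ 0.5 .* randn(10_000)        # samples from the "target" N(2, 0.5)

logρ(z) = -z^2 / 2 - log(sqrt(2π))     # standard normal reference log-density

# average log pull-back density log ρ(T⁻¹(x)) + log |dT⁻¹/dx| for map (μ, σ)
loglik(μ, σ) = mean(logρ.((x .- μ) ./ σ)) - log(σ)

μ̂, σ̂ = mean(x), std(x)                # (approximate) closed-form maximizer
```

The fitted map reproduces the sample statistics, so `loglik(μ̂, σ̂)` exceeds the value for any mismatched map such as `loglik(2.0, 1.0)`; richer polynomial maps generalize this beyond Gaussian shapes.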

Usage

This approach is implemented as a TransportMapFromSamples constructed using the mapfromsamples function, which requires the following inputs:

  • transportmap: The map structure to be optimized (e.g., a TransportMaps.PolynomialMap)

  • samples: A DataFrame containing the samples from the target distribution

Similarly, the function returns a JointDistribution with the optimized TransportMapFromSamples and names of the variables.

We consider the same banana-shaped distribution as before. However, we now start with samples generated using a simple acceptance-rejection method:

julia
function generate_samples(n_samples::Int)
    x1_samples = Vector{Float64}(undef, n_samples)
    x2_samples = Vector{Float64}(undef, n_samples)

    count = 0
    while count < n_samples
        # proposal: x1 ~ N(0, 2), x2 ~ N(x1², 3)
        x1 = randn() * 2
        x2 = randn() * 3 + x1^2

        # accept with probability π(x) / (M q(x)), where q is the proposal
        # density and M = 6 bounds the ratio π/q
        logq = logpdf(Normal(0, 2), x1) + logpdf(Normal(x1^2, 3), x2)
        if rand() < exp(logtarget([x1, x2]) - logq) / 6
            count += 1
            x1_samples[count] = x1
            x2_samples[count] = x2
        end
    end

    return DataFrame(x1 = x1_samples, x2 = x2_samples)
end

target_samples = generate_samples(1000) # Returns a DataFrame with samples

We define the transport map structure as a PolynomialMap and fit it using the available samples:

julia
pm_samples = PolynomialMap(2, 2)
tm = mapfromsamples(pm_samples, target_samples)
JointDistribution{MultivariateDistribution, Symbol}(TransportMapFromSamples(map=ComposedMap{LinearMap}(linearmap=LinearMap(2-dimensional, μ: [0.0018022845082078314, 0.6939200873177283], σ: [0.8725796400255229, 1.4788185281454143]), polynomialmap=PolynomialMap(2-dimensional, degree=2, basis=LinearizedHermiteBasis, reference=Normal{Float64}(μ=0.0, σ=1.0), rectifier=Softplus, 9 total coefficients)), names=[:x1, :x2] number_samples=1000), [:x1, :x2])

As in the density-based approach, the fitted map enables sampling and probability density evaluation through the mapping from the reference space.

julia
# Generate 1000 new samples from the fitted transport map
tm_samples = sample(tm, 1000)

# Evaluate the pdf on a grid
pdf_vals_tm = [pdf(tm, [x1, x2]) for x2 in x2_range, x1 in x1_range]

# Visualize both the original and generated samples with the fitted density
scatter(target_samples.x1, target_samples.x2; alpha=0.8, label="Original Samples")
scatter!(tm_samples.x1, tm_samples.x2; alpha=0.8, label="TM Samples")
contour!(x1_range, x2_range, pdf_vals_tm)