Banana: Map from Samples

This example demonstrates how to use TransportMaps.jl to approximate a "banana" distribution using polynomial transport maps when only samples from the target distribution are available.

Unlike the density-based approach, this method learns the transport map directly from sample data using optimization techniques. This is particularly useful when the target density is unknown or difficult to evaluate [1].

We start with the necessary packages:

julia
using TransportMaps
using Distributions
using LinearAlgebra
using Plots

Generating Target Samples

The banana distribution has the density:

p(x) = ϕ(x₁) ϕ(x₂ − x₁²)

where ϕ is the standard normal PDF.

julia
banana_density(x) = pdf(Normal(), x[1]) * pdf(Normal(), x[2] - x[1]^2)
banana_density (generic function with 1 method)

Set the number of samples to generate:

julia
num_samples = 1000

Generate samples using a simple accept/reject scheme (no external dependencies):

julia
function generate_banana_samples(n_samples::Int)
    samples = Matrix{Float64}(undef, n_samples, 2)

    count = 0
    while count < n_samples
        # Propose from a Gaussian that roughly covers the banana:
        # x₁ ~ N(0, 2²) and x₂ ~ N(x₁², 3²) follow the curved ridge
        x1 = randn() * 2
        x2 = randn() * 3 + x1^2

        # Accept with probability proportional to the target density
        # (0.4 is an upper bound on the banana density)
        if rand() < banana_density([x1, x2]) / 0.4
            count += 1
            samples[count, :] = [x1, x2]
        end
    end

    return samples
end

println("Generating samples from banana distribution...")
target_samples = generate_banana_samples(num_samples)
println("Generated $(size(target_samples, 1)) samples")
Generating samples from banana distribution...
Generated 1000 samples

Creating the Transport Map

First, we create a linear map to standardize the samples:

julia
L = LinearMap(target_samples)
LinearMap with 2 dimensions
  μ: [-0.02208366172707738, 0.7072273395172769]
  σ: [0.8616822748295156, 1.39902730576217]
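
Judging from the μ and σ fields above, the linear map performs componentwise standardization. A minimal sketch of the equivalent transformation, assuming mean/std scaling (the exact internals of LinearMap may differ):

julia
using Statistics

# Componentwise standardization: z = (x - μ) / σ
μ̂ = mean(target_samples, dims=1)
σ̂ = std(target_samples, dims=1)
standardized = (target_samples .- μ̂) ./ σ̂

# The standardized samples have mean ≈ 0 and std ≈ 1 in each column
println(mean(standardized, dims=1))
println(std(standardized, dims=1))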

We create a 2-dimensional polynomial transport map with degree 2. For sample-based optimization, we typically start with lower degrees and can increase complexity as needed.

julia
M = PolynomialMap(2, 2, :normal, Softplus())
PolynomialMap:
  Dimensions: 2
  Total coefficients: 9
  Reference density: Distributions.Normal{Float64}(μ=0.0, σ=1.0)
  Maximum degree: 2
  Basis: LinearizedHermiteBasis
  Rectifier: Softplus
  Components:
    Component 1: 3 basis functions
    Component 2: 6 basis functions
  Coefficients: min=0.0, max=0.0, mean=0.0
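
The coefficient counts follow from a total-degree polynomial basis: component k of a triangular map depends on the first k variables, giving binomial(d + k, k) basis functions for degree d. This is an assumption about the basis, but it matches the counts shown above:

julia
# Total-degree-≤2 basis sizes: binomial(d + k, k) with d = 2
binomial(2 + 1, 1)  # component 1 (x₁ only): 3 basis functions
binomial(2 + 2, 2)  # component 2 (x₁ and x₂): 6 basis functions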

Optimizing from Samples

The key difference from density-based optimization is that we optimize directly on the sample data, without requiring the density function. Inside the optimization, the map is arranged such that the "forward" direction runs from the (unknown) target distribution to the standard normal reference. We also pass the linear map so that the samples are standardized before optimization.
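
For intuition, the sample-based objective is typically the negative log-likelihood of the (standardized) samples under the pullback of the reference density, which for a triangular map separates across components. A minimal sketch under that assumption, with hypothetical stand-ins Sk and dSk for a map component and its diagonal derivative (not the package API):

julia
# Per-component sample objective (conceptual sketch):
#   J(Sₖ) = Σᵢ [ ½ Sₖ(zᵢ)² − log ∂ₖSₖ(zᵢ) ]
# i.e. the negative log-likelihood of the samples under the pullback
# of the standard normal reference, up to an additive constant.
function component_objective(Sk, dSk, Z)
    total = 0.0
    for i in 1:size(Z, 1)
        z = Z[i, :]
        total += 0.5 * Sk(z)^2 - log(dSk(z))
    end
    return total
end

# A near-identity first component on 1000 standardized samples gives
# J ≈ Σᵢ ½ zᵢ² ≈ 500, consistent with the objective value reported below:
component_objective(z -> z[1], z -> 1.0, randn(1000, 2))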

julia
res = optimize!(M, target_samples, L)
Optimizing component 1 / 2
Optimizing component 2 / 2

We can check the optimization results of the first component:

julia
res.optimization_results[1]
 * Status: success

 * Candidate solution
    Final objective value:     4.994997e+02

 * Found with
    Algorithm:     L-BFGS

 * Convergence measures
    |x - x'|               = 7.65e-08 ≰ 0.0e+00
    |x - x'|/|x'|          = 9.98e-08 ≰ 0.0e+00
    |f(x) - f(x')|         = 5.23e-12 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 1.05e-14 ≰ 0.0e+00
    |g(x)|                 = 3.55e-09 ≤ 1.0e-08

 * Work counters
    Seconds run:   0  (vs limit Inf)
    Iterations:    6
    f(x) calls:    16
    ∇f(x) calls:   16
    ∇f(x)ᵀv calls: 0

We can also check the optimization results of the second component:

julia
res.optimization_results[2]
 * Status: success

 * Candidate solution
    Final objective value:     1.190970e+02

 * Found with
    Algorithm:     L-BFGS

 * Convergence measures
    |x - x'|               = 1.96e-11 ≰ 0.0e+00
    |x - x'|/|x'|          = 1.09e-11 ≰ 0.0e+00
    |f(x) - f(x')|         = 1.42e-14 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 1.19e-16 ≰ 0.0e+00
    |g(x)|                 = 1.26e-12 ≤ 1.0e-08

 * Work counters
    Seconds run:   1  (vs limit Inf)
    Iterations:    12
    f(x) calls:    35
    ∇f(x) calls:   35
    ∇f(x)ᵀv calls: 0

Finally, we construct a composed map that combines the linear and polynomial maps:

julia
C = ComposedMap(L, M)
ComposedMap{LinearMap} with 2 dimensions:
 linearmap: LinearMap(2-dimensional, μ: [-0.02208366172707738, 0.7072273395172769], σ: [0.8616822748295156, 1.39902730576217])
 polynomialmap: PolynomialMap(2-dimensional, degree=2, basis=LinearizedHermiteBasis, reference=Distributions.Normal{Float64}(μ=0.0, σ=1.0), rectifier=Softplus, 9 total coefficients)
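
Conceptually, the composed map applies the standardizing linear map first and the polynomial map second, with the forward direction running from target to reference. A generic sketch of this composition pattern, with plain functions standing in for the two maps (not the package types):

julia
# Generic composition: forward direction is target -> reference
struct ComposedSketch{F,G}
    outer::F  # stands in for the polynomial map
    inner::G  # stands in for the standardizing linear map
end
(c::ComposedSketch)(x) = c.outer(c.inner(x))

# μ and σ rounded from the LinearMap output above; the identity is a
# placeholder for the learned polynomial map
lin = x -> (x .- [-0.022, 0.707]) ./ [0.862, 1.399]
poly = z -> z
C_sketch = ComposedSketch(poly, lin)
C_sketch([0.5, 1.0])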

Testing the Map

Let's generate a fresh set of banana samples as well as standard normal samples, then map each through our optimized transport map to verify the learned distribution:

julia
new_samples = generate_banana_samples(1000)
norm_samples = randn(1000, 2)

Map the banana samples through our transport map. Note that evaluate transports from the target to the reference, i.e. mapped_samples should be approximately standard normal:

julia
mapped_samples = evaluate(C, new_samples)
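
If the map has been learned well, mapped_samples should be close to standard normal. A quick check of the first two moments (using the Statistics standard library):

julia
using Statistics

# Pushed-forward samples should have mean ≈ 0 and std ≈ 1 per dimension
println("Mapped mean: ", mean(mapped_samples, dims=1))
println("Mapped std:  ", std(mapped_samples, dims=1))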

Conversely, mapping standard normal samples through the inverse generates new samples from the banana distribution:

julia
mapped_banana_samples = inverse(C, norm_samples)
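
Since evaluate and inverse are numerical inverses of each other, a round trip through the composed map should recover the input up to inversion tolerance:

julia
# Round trip: reference -> target -> reference
roundtrip = evaluate(C, inverse(C, norm_samples))
println("Max round-trip error: ", maximum(abs.(roundtrip .- norm_samples)))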

Visualizing Results

Let's create a scatter plot comparing the original samples with the mapped samples to see how well our transport map learned the distribution:

julia
p11 = scatter(new_samples[:, 1], new_samples[:, 2],
    label="Original Samples", alpha=0.5, color=1,
    title="Banana Distribution: Original vs. Mapped Samples",
    xlabel="x₁", ylabel="x₂")

scatter!(p11, mapped_banana_samples[:, 1], mapped_banana_samples[:, 2],
    label="Mapped Samples", alpha=0.5, color=2)

plot(p11, size=(600, 400))

and the resulting samples in standard normal space:

julia
p12 = scatter(norm_samples[:, 1], norm_samples[:, 2],
    label="Reference Samples", alpha=0.5, color=1,
    title="Standard Normal Space: Reference vs. Mapped Samples",
    xlabel="x₁", ylabel="x₂")

scatter!(p12, mapped_samples[:, 1], mapped_samples[:, 2],
    label="Mapped Samples", alpha=0.5, color=2)

plot(p12, size=(600, 400), aspect_ratio=1)

Density Comparison

We can also compare the learned density (via pullback) with the true density:

julia
x₁ = range(-3, 3, length=100)
x₂ = range(-2.5, 4.0, length=100)
-2.5:0.06565656565656566:4.0

True banana density values:

julia
true_density = [banana_density([x1, x2]) for x2 in x₂, x1 in x₁]

Learned density via pullback through the transport map. Note that pullback evaluates the standard normal reference density pulled back through the map, yielding the learned approximation of the target density:

julia
learned_density = [pullback(C, [x1, x2]) for x2 in x₂, x1 in x₁]
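
Concretely, the pullback density is the reference density evaluated at the mapped point, weighted by the absolute Jacobian determinant of the map. A conceptual sketch of the formula, with hypothetical stand-ins S and detJ for the target-to-reference map and its Jacobian determinant (not the package internals):

julia
# Pullback density: p̂(x) = ρ(S(x)) * |det ∇S(x)|, where ρ is the standard
# normal reference density. For a lower-triangular S the determinant is
# the product of the diagonal derivatives ∂₁S₁ and ∂₂S₂.
function pullback_density(S, detJ, x)
    z = S(x)
    ρ = pdf(Normal(), z[1]) * pdf(Normal(), z[2])
    return ρ * abs(detJ(x))
end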

Create contour plots for comparison:

julia
p3 = contour(x₁, x₂, true_density,
    title="True Banana Density",
    xlabel="x₁", ylabel="x₂",
    colormap=:viridis, levels=10)

p4 = contour(x₁, x₂, learned_density,
    title="Learned Density (Pullback)",
    xlabel="x₁", ylabel="x₂",
    colormap=:viridis, levels=10)

plot(p3, p4, layout=(1, 2), size=(800, 400))
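
Beyond the visual comparison, a simple quantitative check is the discrete L¹ difference between the two density grids scaled by the cell area, a rough Riemann-sum estimate of the total variation error (assuming the grid covers most of the probability mass):

julia
# Approximate ∫|p_true − p_learned| dx by a Riemann sum over the grid
Δx₁ = step(x₁)
Δx₂ = step(x₂)
l1_error = sum(abs.(true_density .- learned_density)) * Δx₁ * Δx₂
println("Approximate L¹ density error: ", l1_error)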

Combined Visualization

Finally, let's create a combined plot showing both the original samples and the density contours:

julia
scatter(target_samples[:, 1], target_samples[:, 2],
    label="Original Samples", alpha=0.3, color=1,
    xlabel="x₁", ylabel="x₂",
    title="Banana Distribution: Samples and Learned Density")

contour!(x₁, x₂, learned_density ./ maximum(learned_density),
    levels=5, colormap=:viridis, alpha=0.8,
    label="Learned Density Contours")

xlims!(-3, 3)
ylims!(-2.5, 4.0)

Quality Assessment

We can assess the quality of our sample-based approximation by comparing statistics of the original and mapped samples:

julia
println("Sample Statistics Comparison:")
println("Original samples - Mean: ", Distributions.mean(target_samples, dims=1))
println("Original samples - Std:  ", Distributions.std(target_samples, dims=1))
println("Mapped samples - Mean:   ", Distributions.mean(mapped_banana_samples, dims=1))
println("Mapped samples - Std:    ", Distributions.std(mapped_banana_samples, dims=1))
Sample Statistics Comparison:
Original samples - Mean: [-0.02208366172707738 0.7072273395172769]
Original samples - Std:  [0.8616822748295156 1.39902730576217]
Mapped samples - Mean:   [-0.01943795214124371 0.7475603280473804]
Mapped samples - Std:    [0.8696295012990757 1.388788772341478]
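
The first two moments match closely. For a slightly stronger check, the full sample covariance matrices can be compared as well (cov from the Statistics standard library):

julia
using Statistics

println("Original covariance:\n", cov(target_samples))
println("Mapped covariance:\n", cov(mapped_banana_samples))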

Interpretation

The sample-based approach learns the transport map by fitting to the empirical distribution of the samples. This method is particularly useful when:

  • The target density is unknown or expensive to evaluate

  • Only sample data is available from experiments or simulations

  • The distribution is complex and difficult to express analytically

The quality of the approximation depends on:

  • The number and quality of the original samples

  • The polynomial degree of the transport map

  • The optimization algorithm and convergence criteria

Further Experiments

You can experiment with:

  • Different polynomial degrees for more complex distributions

  • Different rectifier functions (Softplus(), ShiftedELU()); see the sketch after this list

  • More sophisticated MCMC sampling strategies

  • Cross-validation techniques to assess generalization

  • Different sample sizes to study convergence behavior
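
For instance, a higher-degree map with a different rectifier can be set up the same way as above (degree 3 is an arbitrary choice; this assumes the same constructor and optimizer signatures used earlier):

julia
# A degree-3 map with the ShiftedELU rectifier: more basis functions,
# hence a more flexible but harder-to-fit map
M3 = PolynomialMap(2, 3, :normal, ShiftedELU())
res3 = optimize!(M3, target_samples, L)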


This page was generated using Literate.jl.