
Optimization of the Map Coefficients

A crucial step in constructing transport maps is the optimization of the map coefficients, which determine how well the map represents the target distribution. Depending on the available information about the target distribution, this can be approached in two distinct ways [1]: from the (unnormalized) target density or from samples of the target.

Map-from-density

One way to construct a transport map is to directly optimize its parameters based on the (unnormalized) target density, as shown in Banana: Map from Density. This approach requires access to the target density function and uses quadrature schemes to approximate integrals, as introduced in Quadrature Methods.

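Formally, we determine the coefficients $\mathbf{a}$ of the parameterized map $T$, which pushes forward the reference distribution to the target distribution, by solving the following optimization problem:

$$
\min_{\mathbf{a}} \sum_{i=1}^{N} w_{q,i} \left[ -\log \pi\big(T(\mathbf{a}, \mathbf{z}_{q,i})\big) - \log \left| \det \nabla T(\mathbf{a}, \mathbf{z}_{q,i}) \right| \right]
$$

where $\mathbf{z}_{q,i}$ and $w_{q,i}$, $i = 1, \dots, N$, are the points and weights of the quadrature scheme with respect to the reference distribution. Up to terms that do not depend on $\mathbf{a}$, this is a quadrature approximation of the Kullback-Leibler divergence between the pushforward of the reference distribution and the target distribution.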
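To make the quadrature objective concrete, here is a minimal sketch in plain Julia (standard library only) for a hypothetical one-dimensional affine map $T(\mathbf{a}, z) = a_1 + e^{a_2} z$ and a Gaussian target. It is not the package's implementation: the helper names (`gauss_hermite_prob`, `density_objective`) are illustrative, and the Gauss-Hermite rule for the standard normal reference is computed with the Golub-Welsch algorithm.

```julia
using LinearAlgebra

# Gauss-Hermite rule for the standard normal reference (probabilists' Hermite polynomials),
# computed with the Golub-Welsch algorithm: the nodes are the eigenvalues of the Jacobi matrix
# and the weights are the squared first components of its normalized eigenvectors.
function gauss_hermite_prob(n::Int)
    J = SymTridiagonal(zeros(n), sqrt.(1.0:n-1))
    F = eigen(J)
    return F.values, F.vectors[1, :] .^ 2   # weights sum to 1 (N(0, 1) is a probability measure)
end

# Hypothetical monotone 1-D map T(a, z) = a[1] + exp(a[2]) * z, so log|dT/dz| = a[2].
T(a, z) = a[1] + exp(a[2]) * z
logdetT(a, z) = a[2]

# Quadrature approximation of sum_i w_i [-log π(T(a, z_i)) - log|det ∇T(a, z_i)|]
density_objective(a, logtarget, nodes, weights) =
    sum(w * (-logtarget(T(a, z)) - logdetT(a, z)) for (z, w) in zip(nodes, weights))

# Unnormalized log-density of the target N(1, 0.5^2); the normalizing constant is not needed.
logtarget(x) = -(x - 1)^2 / (2 * 0.5^2)

nodes, weights = gauss_hermite_prob(10)
a_exact = [1.0, log(0.5)]    # maps N(0, 1) exactly onto N(1, 0.5^2)
a_off = [0.5, log(0.8)]
println(density_objective(a_exact, logtarget, nodes, weights))
println(density_objective(a_off, logtarget, nodes, weights))
```

For this target the objective equals $2[(a_1 - 1)^2 + e^{2a_2}] - a_2$, which is minimized at $a_1 = 1$, $e^{a_2} = 0.5$, i.e. at `a_exact`, where its value is $0.5 + \log 2$.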

As noted by [1], this optimization problem is generally non-convex; it is convex when the target density π(x) is log-concave. In Bayesian inference, however, the target density is a posterior density, which is often not log-concave, so the resulting optimization problem is typically non-convex.
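The loss of convexity can already be seen in one dimension. The sketch below (plain Julia; the bimodal target and the fixed-scale affine map are hypothetical choices for illustration) evaluates the quadrature objective along the shift coefficient $a_1$ for an equal mixture of two Gaussians, which is not log-concave. Convexity would require the objective at the midpoint $a_1 = 0$ to be no larger than the average of its values at $a_1 = \pm 2$; here it is much larger, so the objective has at least two separate local minima.

```julia
# Hypothetical bimodal target 0.5 N(-2, 0.5^2) + 0.5 N(2, 0.5^2), unnormalized; it is not log-concave.
function logtarget_bimodal(x)
    l1 = -(x + 2)^2 / (2 * 0.5^2)
    l2 = -(x - 2)^2 / (2 * 0.5^2)
    m = max(l1, l2)
    return m + log(exp(l1 - m) + exp(l2 - m))    # numerically stable log-sum-exp
end

# 3-point Gauss-Hermite rule for N(0, 1); exact for polynomials up to degree 5
nodes = [-sqrt(3.0), 0.0, sqrt(3.0)]
weights = [1 / 6, 2 / 3, 1 / 6]

# Quadrature objective for the affine map T(a, z) = a1 + 0.5 * z as a function of the shift a1
# (the scale is held fixed at 0.5, so the log-determinant term is the constant log(0.5)).
objective(a1) = sum(w * (-logtarget_bimodal(a1 + 0.5 * z) - log(0.5)) for (z, w) in zip(nodes, weights))

f_left, f_mid, f_right = objective(-2.0), objective(0.0), objective(2.0)
# Convexity would require f_mid <= (f_left + f_right) / 2; here the inequality fails.
println((f_left, f_mid, f_right))
```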

In this package, map optimization is performed with Optim.jl, which supports a wide range of optimizers and options (such as convergence criteria and printing preferences). The desired optimizer and options are passed directly to our optimize! function. For a full overview of available options, see the Optim.jl configuration documentation.

To perform the optimization of the map coefficients, we call:

julia
optimize!(M::PolynomialMap, target_density::Function, quadrature::AbstractQuadratureWeights;
  optimizer::Optim.AbstractOptimizer = LBFGS(), options::Optim.Options = Optim.Options())

We must provide the polynomial map M, the target density, and a quadrature scheme. Optionally, we can specify the optimizer (default: LBFGS()) and the options (default: Optim.Options()).

Set initial coefficients

To set the starting point of the optimization, assign the map coefficients with setcoefficients!(M, coeffs), where coeffs is a vector of coefficients.

Usage

First we load the packages:

julia
using TransportMaps
using Optim
using Distributions
using Plots

Then, define the target density and quadrature scheme. Here, we use the same banana-shaped density as in Banana: Map from Density:

julia
banana_density(x) = logpdf(Normal(), x[1]) + logpdf(Normal(), x[2] - x[1]^2)
target = MapTargetDensity(banana_density)
quadrature = GaussHermiteWeights(10, 2)

Next, we set the optimization options: at most 200 iterations, with the trace printed every 20 iterations and stored for later plotting:

julia
opts_trace = Optim.Options(iterations=200, show_trace=true, show_every=20, store_trace=true)
                x_abstol = 0.0
                x_reltol = 0.0
                f_abstol = 0.0
                f_reltol = 0.0
                g_abstol = 1.0e-8
          outer_x_abstol = 0.0
          outer_x_reltol = 0.0
          outer_f_abstol = 0.0
          outer_f_reltol = 0.0
          outer_g_abstol = 1.0e-8
           f_calls_limit = 0
           g_calls_limit = 0
           h_calls_limit = 0
       allow_f_increases = true
 allow_outer_f_increases = true
        successive_f_tol = 1
              iterations = 200
        outer_iterations = 1000
             store_trace = true
           trace_simplex = false
              show_trace = true
          extended_trace = false
           show_warnings = true
              show_every = 20
                callback = nothing
              time_limit = NaN

We will try the following optimizers from Optim.jl, ordered from simplest to most sophisticated:

Gradient Descent

The most basic optimization algorithm, Gradient Descent iteratively moves in the direction of the negative gradient. It is simple and robust, but can be slow to converge, especially for ill-conditioned problems.

julia
M_gd = PolynomialMap(2, 2)
res_gd = optimize!(M_gd, target, quadrature; optimizer=GradientDescent(), options=opts_trace)
println(res_gd)
Iter     Function value   Gradient norm
     0     3.397609e+00     4.804530e-01
 * time: 5.0067901611328125e-5
    20     2.846750e+00     3.244379e-02
 * time: 0.9385058879852295
    40     2.841985e+00     1.426403e-02
 * time: 1.559156894683838
    60     2.840998e+00     1.066274e-02
 * time: 2.2023160457611084
    80     2.840454e+00     8.088299e-03
 * time: 2.8295600414276123
   100     2.840145e+00     6.198896e-03
 * time: 3.4470598697662354
   120     2.839965e+00     4.786295e-03
 * time: 4.100550889968872
   140     2.839858e+00     3.715940e-03
 * time: 4.742078065872192
   160     2.839794e+00     2.896858e-03
 * time: 5.360431909561157
   180     2.839755e+00     2.265401e-03
 * time: 6.013247966766357
   200     2.839731e+00     1.775845e-03
 * time: 6.653687000274658
 * Status: failure (reached maximum number of iterations)

 * Candidate solution
    Final objective value:     2.839731e+00

 * Found with
    Algorithm:     Gradient Descent

 * Convergence measures
    |x - x'|               = 3.70e-04 ≰ 0.0e+00
    |x - x'|/|x'|          = 1.52e-04 ≰ 0.0e+00
    |f(x) - f(x')|         = 9.28e-07 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 3.27e-07 ≰ 0.0e+00
    |g(x)|                 = 1.78e-03 ≰ 1.0e-08

 * Work counters
    Seconds run:   7  (vs limit Inf)
    Iterations:    200
    f(x) calls:    681
    ∇f(x) calls:   681
    ∇f(x)ᵀv calls: 0
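The slowdown on ill-conditioned problems mentioned above can be reproduced with a textbook sketch: fixed-step gradient descent on the quadratic $f(\mathbf{x}) = \tfrac12 \mathbf{x}^\top A \mathbf{x}$ with step $1/\lambda_{\max}(A)$. This is not Optim.jl's `GradientDescent`, which chooses its step with a line search; the function and parameter names are illustrative.

```julia
using LinearAlgebra

# Fixed-step gradient descent on f(x) = 0.5 * x' * A * x, whose gradient is A * x.
# Textbook sketch only: Optim.jl's GradientDescent chooses its step with a line search.
function gradient_descent(A, x0; step, gtol=1e-8, maxiter=100_000)
    x = copy(x0)
    for k in 1:maxiter
        g = A * x
        norm(g) < gtol && return x, k - 1    # converged after k - 1 steps
        x -= step * g                        # move against the gradient
    end
    return x, maxiter
end

x0 = [1.0, 1.0]
iterations = Dict{Float64,Int}()
for κ in (2.0, 100.0)                        # condition number of A
    A = Diagonal([1.0, κ])
    _, iters = gradient_descent(A, x0; step = 1 / κ)   # step 1/L with L the largest eigenvalue
    iterations[κ] = iters
end
println(iterations)
```

With step $1/\kappa$, the error along the flat direction shrinks only by the factor $1 - 1/\kappa$ per iteration, so the iteration count grows roughly linearly with the condition number: 27 iterations for $\kappa = 2$ versus more than 1800 for $\kappa = 100$.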

Conjugate Gradient

Conjugate Gradient improves upon basic gradient descent by using conjugate directions, which can accelerate convergence for large-scale or quadratic problems. It requires gradient information but not the Hessian.

julia
M_cg = PolynomialMap(2, 2)
res_cg = optimize!(M_cg, target, quadrature; optimizer=ConjugateGradient(), options=opts_trace)
println(res_cg)
Iter     Function value   Gradient norm
     0     3.397609e+00     4.804530e-01
 * time: 4.1961669921875e-5
    20     2.839693e+00     8.763686e-05
 * time: 0.4612147808074951
    40     2.839693e+00     2.084568e-07
 * time: 0.6847829818725586
 * Status: success

 * Candidate solution
    Final objective value:     2.839693e+00

 * Found with
    Algorithm:     Conjugate Gradient

 * Convergence measures
    |x - x'|               = 9.92e-09 ≰ 0.0e+00
    |x - x'|/|x'|          = 4.01e-09 ≰ 0.0e+00
    |f(x) - f(x')|         = 0.00e+00 ≤ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 0.00e+00 ≤ 0.0e+00
    |g(x)|                 = 1.40e-08 ≰ 1.0e-08

 * Work counters
    Seconds run:   1  (vs limit Inf)
    Iterations:    59
    f(x) calls:    141
    ∇f(x) calls:   84
    ∇f(x)ᵀv calls: 0
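For a quadratic objective, the conjugate-direction idea described above has a classic consequence: the linear conjugate gradient method with exact line searches reaches the minimizer of an $n$-dimensional problem in at most $n$ iterations (in exact arithmetic). The sketch below shows this for a 2×2 example. Optim.jl's `ConjugateGradient` is a nonlinear variant for general smooth objectives; this linear version and its names are illustrative only.

```julia
using LinearAlgebra

# Linear conjugate gradient: minimizes f(x) = 0.5 x'Ax - b'x (equivalently solves Ax = b)
# for symmetric positive definite A. Optim.jl's ConjugateGradient is a nonlinear variant.
function conjugate_gradient(A, b, x0; tol=1e-10, maxiter=100)
    x = copy(x0)
    r = b - A * x                # residual = negative gradient of f
    p = copy(r)                  # first direction: steepest descent
    iters = 0
    while norm(r) > tol && iters < maxiter
        Ap = A * p
        α = dot(r, r) / dot(p, Ap)          # exact line search along p
        x += α * p
        r_new = r - α * Ap
        β = dot(r_new, r_new) / dot(r, r)   # keeps the next direction A-conjugate to p
        p = r_new + β * p
        r = r_new
        iters += 1
    end
    return x, iters
end

A = [4.0 1.0; 1.0 3.0]          # symmetric positive definite, n = 2
b = [1.0, 2.0]
x, iters = conjugate_gradient(A, b, zeros(2))
println("x = $x after $iters iterations")
```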

Nelder-Mead

Nelder-Mead is a derivative-free optimizer that uses a simplex of points to search for the minimum. It is useful when gradients are unavailable or unreliable, but may be less efficient for high-dimensional or smooth problems.

julia
M_nm = PolynomialMap(2, 2)
res_nm = optimize!(M_nm, target, quadrature; optimizer=NelderMead(), options=opts_trace)
println(res_nm)
Iter     Function value    √(Σ(yᵢ-ȳ)²)/n
------   --------------    --------------
     0     3.385910e+00     6.623032e-03
 * time: 6.389617919921875e-5
    20     3.325296e+00     1.183941e-02
 * time: 0.05990791320800781
    40     3.217172e+00     1.770408e-02
 * time: 0.11934995651245117
    60     3.145470e+00     4.181887e-03
 * time: 0.1869068145751953
    80     3.117131e+00     5.446265e-03
 * time: 0.23119592666625977
   100     3.093431e+00     3.313369e-03
 * time: 0.283214807510376
   120     3.069929e+00     3.715059e-03
 * time: 0.3152339458465576
   140     3.051885e+00     3.568445e-03
 * time: 0.37299084663391113
   160     3.040074e+00     2.744730e-03
 * time: 0.41147899627685547
   180     3.030321e+00     2.994942e-03
 * time: 0.4544689655303955
   200     3.014241e+00     3.253376e-03
 * time: 0.4889087677001953
 * Status: failure (reached maximum number of iterations)

 * Candidate solution
    Final objective value:     3.008308e+00

 * Found with
    Algorithm:     Nelder-Mead

 * Convergence measures
    √(Σ(yᵢ-ȳ)²)/n ≰ 1.0e-08

 * Work counters
    Seconds run:   0  (vs limit Inf)
    Iterations:    200
    f(x) calls:    304
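To illustrate how a simplex method makes progress using function values only, here is a compact textbook Nelder-Mead (reflection, expansion, contraction, shrink) in plain Julia. It is a sketch with the standard coefficients, not Optim.jl's `NelderMead`, which uses adaptive parameters and its own initial simplex; the test function is a hypothetical anisotropic quadratic with minimum at $(1, -2)$.

```julia
using LinearAlgebra

# Textbook Nelder-Mead with standard coefficients: reflection 1, expansion 2,
# contraction 0.5, shrink 0.5. Only function values are used, never gradients.
function nelder_mead(f, x0; step=0.5, tol=1e-10, maxiter=5_000)
    n = length(x0)
    simplex = [copy(x0)]
    for i in 1:n                          # initial simplex: x0 plus one step along each axis
        v = copy(x0)
        v[i] += step
        push!(simplex, v)
    end
    fvals = map(f, simplex)
    for iter in 1:maxiter
        order = sortperm(fvals)           # best vertex first, worst last
        simplex, fvals = simplex[order], fvals[order]
        # stop when the function values across the simplex are nearly equal
        sqrt(sum(abs2, fvals .- sum(fvals) / (n + 1)) / (n + 1)) < tol && return simplex[1], iter
        centroid = sum(simplex[1:n]) / n  # centroid of all vertices except the worst
        xr = centroid + (centroid - simplex[end])
        fr = f(xr)                        # reflection
        if fr < fvals[1]
            xe = centroid + 2 * (centroid - simplex[end])
            fe = f(xe)                    # expansion
            simplex[end], fvals[end] = fe < fr ? (xe, fe) : (xr, fr)
        elseif fr < fvals[n]
            simplex[end], fvals[end] = xr, fr
        else
            # contraction: outside if the reflected point beats the worst vertex, else inside
            xc = fr < fvals[end] ? centroid + 0.5 * (xr - centroid) : centroid + 0.5 * (simplex[end] - centroid)
            fc = f(xc)
            if fc < min(fr, fvals[end])
                simplex[end], fvals[end] = xc, fc
            else                          # shrink all vertices toward the best one
                for i in 2:n+1
                    simplex[i] = simplex[1] + 0.5 * (simplex[i] - simplex[1])
                    fvals[i] = f(simplex[i])
                end
            end
        end
    end
    return simplex[argmin(fvals)], maxiter
end

# Hypothetical test function: anisotropic quadratic with minimum at [1, -2]
f_quad(x) = (x[1] - 1)^2 + 10 * (x[2] + 2)^2
xmin, iters = nelder_mead(f_quad, [0.0, 0.0])
println("minimizer ≈ $xmin after $iters iterations")
```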

BFGS

BFGS is a quasi-Newton method that builds up an approximation to the Hessian matrix from gradient evaluations. For smooth problems of moderate size, it typically converges in far fewer iterations than gradient descent or conjugate gradient.

julia
M_bfgs = PolynomialMap(2, 2)
res_bfgs = optimize!(M_bfgs, target, quadrature; optimizer=BFGS(), options=opts_trace)
println(res_bfgs)
Iter     Function value   Gradient norm
     0     3.397609e+00     4.804530e-01
 * time: 4.100799560546875e-5
 * Status: success

 * Candidate solution
    Final objective value:     2.839693e+00

 * Found with
    Algorithm:     BFGS

 * Convergence measures
    |x - x'|               = 8.52e-09 ≰ 0.0e+00
    |x - x'|/|x'|          = 3.45e-09 ≰ 0.0e+00
    |f(x) - f(x')|         = 4.44e-16 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 1.56e-16 ≰ 0.0e+00
    |g(x)|                 = 3.25e-10 ≤ 1.0e-08

 * Work counters
    Seconds run:   0  (vs limit Inf)
    Iterations:    12
    f(x) calls:    45
    ∇f(x) calls:   45
    ∇f(x)ᵀv calls: 0
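The core of BFGS is its rank-two update of the inverse-Hessian approximation, built only from the step $s$ and the gradient change $y$. The sketch below (plain Julia; names illustrative, not Optim.jl's internals) implements the standard update and checks the secant condition $H_{k+1} y = s$ on a quadratic, for which $y = A s$ holds exactly.

```julia
using LinearAlgebra

# BFGS update of the inverse-Hessian approximation H from the step s = x_{k+1} - x_k
# and the gradient change y = ∇f(x_{k+1}) - ∇f(x_k); requires the curvature condition y's > 0.
function bfgs_update(H, s, y)
    ρ = 1 / dot(y, s)
    V = I - ρ * y * s'
    return V' * H * V + ρ * s * s'
end

# On a quadratic f(x) = 0.5 x'Ax with SPD A, the gradient change is exactly y = A * s.
A = [3.0 1.0; 1.0 2.0]
H = Matrix(1.0I, 2, 2)    # initial approximation: identity
s = [0.3, -0.2]
y = A * s
H_new = bfgs_update(H, s, y)
println("secant condition holds: ", H_new * y ≈ s)
```

The update keeps the approximation symmetric positive definite whenever $y^\top s > 0$, which is why BFGS is paired with a line search that enforces this curvature condition.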

LBFGS

LBFGS is a limited-memory version of BFGS: instead of storing a dense approximation of the (inverse) Hessian, it keeps only a small number of recent step and gradient-difference pairs. This makes it suitable for large-scale problems where storing the full matrix is impractical. It is also the default optimizer of optimize! in this package.

julia
M_lbfgs = PolynomialMap(2, 2)
res_lbfgs = optimize!(M_lbfgs, target, quadrature; optimizer=LBFGS(), options=opts_trace)
println(res_lbfgs)
Iter     Function value   Gradient norm
     0     3.397609e+00     4.804530e-01
 * time: 9.083747863769531e-5
 * Status: success

 * Candidate solution
    Final objective value:     2.839693e+00

 * Found with
    Algorithm:     L-BFGS

 * Convergence measures
    |x - x'|               = 1.25e-08 ≰ 0.0e+00
    |x - x'|/|x'|          = 5.04e-09 ≰ 0.0e+00
    |f(x) - f(x')|         = 1.33e-15 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 4.69e-16 ≰ 0.0e+00
    |g(x)|                 = 2.84e-09 ≤ 1.0e-08

 * Work counters
    Seconds run:   0  (vs limit Inf)
    Iterations:    12
    f(x) calls:    45
    ∇f(x) calls:   45
    ∇f(x)ᵀv calls: 0
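LBFGS never forms the matrix: the two-loop recursion applies the implicit inverse-Hessian approximation, defined by the last $m$ pairs $(s, y)$ and an initial scaling $\gamma I$, directly to a gradient. The sketch below (plain Julia; names illustrative, not Optim.jl's internals) implements the recursion and checks it against dense BFGS updates applied to the same pairs, which it reproduces while storing only the $m$ vector pairs.

```julia
using LinearAlgebra

# Two-loop recursion: returns H * g, where H is the L-BFGS inverse-Hessian approximation
# built from the stored pairs (S[i], Y[i]) (oldest first) and the initial matrix γ * I.
function lbfgs_direction(g, S, Y)
    m = length(S)
    ρ = [1 / dot(Y[i], S[i]) for i in 1:m]
    α = zeros(m)
    q = copy(g)
    for i in m:-1:1                          # newest to oldest
        α[i] = ρ[i] * dot(S[i], q)
        q -= α[i] * Y[i]
    end
    γ = dot(S[m], Y[m]) / dot(Y[m], Y[m])    # common scaling choice for the initial matrix
    r = γ * q
    for i in 1:m                             # oldest to newest
        β = ρ[i] * dot(Y[i], r)
        r += (α[i] - β) * S[i]
    end
    return r                                 # the search direction is -r
end

# Dense reference: explicit BFGS updates of the inverse-Hessian approximation
function bfgs_update(H, s, y)
    ρ = 1 / dot(y, s)
    V = I - ρ * y * s'
    return V' * H * V + ρ * s * s'
end

# Three stored pairs from a quadratic with SPD Hessian A, so that y = A * s
A = [4.0 1.0 0.0; 1.0 3.0 0.5; 0.0 0.5 2.0]
S = [[1.0, 0.0, 0.5], [0.2, -0.3, 0.1], [0.0, 0.4, -0.2]]
Y = [A * s for s in S]
g = [0.5, -1.0, 2.0]

γ = dot(S[end], Y[end]) / dot(Y[end], Y[end])
H_dense = foldl((H, (s, y)) -> bfgs_update(H, s, y), zip(S, Y); init = γ * Matrix(1.0I, 3, 3))
println(lbfgs_direction(g, S, Y) ≈ H_dense * g)
```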

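Finally, we compare the results by means of the variance diagnostic, evaluated with 1000 samples from the standard normal reference. Smaller values indicate that the map approximates the target better: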

julia
samples_z = randn(1000, 2)
v_gd = variance_diagnostic(M_gd, target, samples_z)
v_cg = variance_diagnostic(M_cg, target, samples_z)
v_nm = variance_diagnostic(M_nm, target, samples_z)
v_bfgs = variance_diagnostic(M_bfgs, target, samples_z)
v_lbfgs = variance_diagnostic(M_lbfgs, target, samples_z)

println("Variance diagnostic GradientDescent:   ", v_gd)
println("Variance diagnostic ConjugateGradient: ", v_cg)
println("Variance diagnostic NelderMead:        ", v_nm)
println("Variance diagnostic BFGS:              ", v_bfgs)
println("Variance diagnostic LBFGS:             ", v_lbfgs)
Variance diagnostic GradientDescent:   0.0004373121237130718
Variance diagnostic ConjugateGradient: 0.0003790237201569168
Variance diagnostic NelderMead:        0.10626452172420726
Variance diagnostic BFGS:              0.0003790236351507774
Variance diagnostic LBFGS:             0.0003790236432620713
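The idea behind the variance diagnostic can be sketched in one dimension. The block below assumes the definition common in the transport-map literature, $\varepsilon = \tfrac12 \operatorname{Var}_\rho\left[\log \rho(z) - \log \bar\pi(T(z)) - \log|\det \nabla T(z)|\right]$, estimated from reference samples; the package's `variance_diagnostic` may differ in normalization. The function `variance_diagnostic_1d`, the target, and both maps are hypothetical illustrations.

```julia
using Random, Statistics

# Sketch of the variance diagnostic, assuming the definition common in the transport-map
# literature: ε = 0.5 * Var_ρ[ log ρ(z) - log π̄(T(z)) - log|∂T/∂z(z)| ] with z ~ ρ = N(0, 1).
# An unnormalized target π̄ only shifts the bracket by a constant, so ε is unaffected.
function variance_diagnostic_1d(T, dT, logtarget, z)
    logref = @. -z^2 / 2 - log(2π) / 2        # log-density of the standard normal reference
    r = logref .- logtarget.(T.(z)) .- log.(abs.(dT.(z)))
    return 0.5 * var(r)
end

# Hypothetical unnormalized target N(1, 0.5^2) and two candidate maps
logtarget(x) = -(x - 1)^2 / (2 * 0.5^2)
T_exact(z) = 1 + 0.5 * z        # pushes N(0, 1) exactly onto the target
T_off(z) = 0.5 + 0.8 * z        # an imperfect map

Random.seed!(1)
z = randn(1000)
ε_exact = variance_diagnostic_1d(T_exact, z -> 0.5, logtarget, z)
ε_off = variance_diagnostic_1d(T_off, z -> 0.8, logtarget, z)
println((ε_exact, ε_off))
```

The diagnostic vanishes (up to rounding) for the exact map, because the bracketed log-ratio is then constant, and is clearly positive for the imperfect one.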

We can visualize the convergence of all optimizers:

julia
results = [("GradientDescent", res_gd), ("ConjugateGradient", res_cg),
           ("NelderMead", res_nm), ("BFGS", res_bfgs), ("LBFGS", res_lbfgs)]
p = plot(xaxis=:log, yaxis=:log, xlabel="Iteration", ylabel="Gradient norm",
    title="Convergence of different optimizers", xlims=(1, 200),
    legend=:bottomleft)
for (name, res) in results
    # each stored trace entry holds the iteration number and the gradient norm
    plot!(p, [t.iteration for t in res.trace], [t.g_norm for t in res.trace];
        lw=2, label=name)
end
p

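The plot shows that BFGS and LBFGS are the most efficient optimizers in this case: both reach a gradient norm below 1e-8 after 12 iterations. Conjugate Gradient reaches the same objective value but needs more iterations, Gradient Descent stops at the iteration limit with a gradient norm of about 1.8e-3, and Nelder-Mead struggles the most, which is also reflected in its much larger variance diagnostic. For Nelder-Mead, the trace holds the simplex convergence measure √(Σ(yᵢ-ȳ)²)/n rather than a gradient norm.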

Map-from-samples

Another strategy for constructing a transport map is to use samples from the target distribution, as seen in Banana: Map from Samples. Formulating transport map estimation this way has the benefit that the resulting optimization problem is convex when the reference density is log-concave [1]. Since we choose the reference density (e.g., a standard Gaussian, which is log-concave), we can leverage this property to simplify the optimization process.

When the map is constructed from samples, we denote the transport map by $S$, which pushes forward the target distribution to the reference distribution. The coefficients are found by minimizing the Kullback-Leibler divergence between the empirical distribution of the samples and the pullback of the reference density through $S$ (equivalently, the pushforward of the reference density under $S^{-1}$). Dropping the term that does not depend on the coefficients, this leads to the following optimization problem:

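$$
\min_{\mathbf{a}} \frac{1}{M} \sum_{i=1}^{M} \left[ -\log \rho\big(S(\mathbf{a}, \mathbf{x}_i)\big) - \log \left| \det \nabla S(\mathbf{a}, \mathbf{x}_i) \right| \right]
$$

where $\{\mathbf{x}_i\}_{i=1}^{M}$ are samples from the target distribution, and $\rho(\cdot)$ is the density of the reference distribution.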
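In one dimension the convexity of the sample-based objective is easy to see. The sketch below (plain Julia) uses a hypothetical map $S(\mathbf{c}, x) = c_1 + c_2 x$ with $c_2 > 0$, which is linear in its coefficients, and a standard normal reference: $-\log\rho(S)$ is then a convex quadratic in the coefficients and $-\log c_2$ is convex, so the objective is convex. Its minimizer is available in closed form and coincides with the Gaussian maximum-likelihood fit of the samples. The map, the sample distribution, and the name `sample_objective` are illustrative assumptions.

```julia
using Random, Statistics

# Hypothetical 1-D map S(c, x) = c[1] + c[2] * x with c[2] > 0, pushing the target to N(0, 1).
# Sample objective (1/M) Σ [-log ρ(S(c, x_i)) - log|∂S/∂x(c, x_i)|], convex in c.
function sample_objective(c, x)
    s = c[1] .+ c[2] .* x
    return mean(@. s^2 / 2 + log(2π) / 2) - log(c[2])
end

Random.seed!(42)
x = 3.0 .+ 2.0 .* randn(5_000)             # samples from N(3, 2^2)

# Setting the gradient to zero gives c[2] = 1/σ̂ and c[1] = -μ̂/σ̂,
# with μ̂ and σ̂ the sample mean and the (1/M-normalized) sample standard deviation.
μ̂, σ̂ = mean(x), std(x; corrected=false)
c_opt = [-μ̂ / σ̂, 1 / σ̂]
println("c_opt = $c_opt, objective = $(sample_objective(c_opt, x))")
```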

To perform the optimization, we can use the same optimize! function as before, but now we pass samples instead of a target density and quadrature scheme. Similarly, we can specify the optimizer and options:

julia
optimize!(M::PolynomialMap, samples::AbstractArray{<:Real};
  optimizer::Optim.AbstractOptimizer = LBFGS(), options::Optim.Options = Optim.Options())

This page was generated using Literate.jl.