Optimization of the Map Coefficients

A crucial step in constructing transport maps is the optimization of the map coefficients, which determine how well the map represents the target distribution. This process can be approached in two distinct ways, depending on the available information about the target distribution [1].

Map-from-density

One way to construct a transport map is to directly optimize its parameters based on the (unnormalized) target density, as shown in Banana: Map from Density. This approach requires access to the target density function and uses quadrature schemes to approximate integrals, as introduced in Quadrature Methods.

Formally, we define the following optimization problem to determine the coefficients a of the parameterized map T:

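$$
\min_{a} \sum_{i=1}^{N} w_{q,i} \left[ -\log \pi\bigl(T(a, z_{q,i})\bigr) - \log \left| \det \nabla T(a, z_{q,i}) \right| \right]
$$

where $z_{q,i}$ and $w_{q,i}$ are the quadrature points and weights in the reference space, and $\nabla T$ is the Jacobian of the map with respect to $z$.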
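The objective above can be traced back to the Kullback-Leibler divergence between the pushforward of the reference density under $T$ and the target. The following is a sketch of that reasoning, not a statement about the package internals; the symbol $\rho$ for the reference density and the term "const" (which collects $\mathbb{E}_{\rho}[\log\rho]$ and, for an unnormalized target, the log normalizing constant) are introduced here for illustration.

```latex
\begin{aligned}
D_{\mathrm{KL}}\bigl(T_{\sharp}\rho \,\|\, \pi\bigr)
  &= \int \rho(z)\,\Bigl[\log\rho(z) - \log\pi\bigl(T(a,z)\bigr)
       - \log\bigl|\det\nabla T(a,z)\bigr|\Bigr]\,\mathrm{d}z \\
  &= \mathbb{E}_{\rho}\Bigl[-\log\pi\bigl(T(a,z)\bigr)
       - \log\bigl|\det\nabla T(a,z)\bigr|\Bigr] + \text{const} \\
  &\approx \sum_{i=1}^{N} w_{q,i}\Bigl[-\log\pi\bigl(T(a,z_{q,i})\bigr)
       - \log\bigl|\det\nabla T(a,z_{q,i})\bigr|\Bigr] + \text{const}
\end{aligned}
```

The first line uses the change of variables $x = T(a,z)$, and the last line replaces the expectation over the reference by a quadrature rule. Because the constant does not depend on $a$, it can be dropped from the minimization, which is why an unnormalized density is sufficient.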

As noted by [1], this optimization problem is generally non-convex; it is convex only when the target density π(x) is log-concave. In Bayesian inference, where the target density is the posterior density, log-concavity often does not hold, so the resulting optimization problem is typically non-convex.
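To make the role of log-concavity concrete, the following sketch evaluates a quadrature approximation of the objective for a one-dimensional affine map, using only the Julia standard library. It is not the package's implementation: the names `gauss_hermite`, `objective`, `log_gauss`, and `log_mix`, as well as the map $T(a, z) = a_1 + e^{a_2} z$, are assumptions made for illustration.

```julia
using LinearAlgebra

# Gauss-Hermite rule for the standard normal weight (Golub-Welsch):
# nodes are the eigenvalues of the Jacobi matrix, weights are the squared
# first components of the normalized eigenvectors (they sum to 1).
function gauss_hermite(n::Int)
    J = SymTridiagonal(zeros(n), sqrt.(1.0:(n - 1)))
    F = eigen(J)
    return F.values, F.vectors[1, :] .^ 2
end

# Quadrature approximation of E_ρ[-log π(T(a, z)) - log |∂T/∂z|]
# for the 1D affine map T(a, z) = a[1] + exp(a[2]) * z (hypothetical map family).
function objective(logpi, a, z, w)
    scale = exp(a[2])
    return sum(w .* (-logpi.(a[1] .+ scale .* z) .- a[2]))
end

z, w = gauss_hermite(20)

# Log-concave target N(1, 2^2): the coefficients a = (1, log 2) give the exact map,
# and the objective equals the entropy of the standard normal, 0.5 + 0.5*log(2π).
log_gauss(x) = -0.5 * ((x - 1) / 2)^2 - log(2) - 0.5 * log(2π)
J_opt = objective(log_gauss, [1.0, log(2.0)], z, w)

# Target that is not log-concave: equal mixture of N(-3, 1) and N(3, 1).
log_mix(x) = log(0.5 * exp(-0.5 * (x + 3)^2) + 0.5 * exp(-0.5 * (x - 3)^2)) - 0.5 * log(2π)
J_left  = objective(log_mix, [-3.0, 0.0], z, w)
J_mid   = objective(log_mix, [0.0, 0.0], z, w)
J_right = objective(log_mix, [3.0, 0.0], z, w)

println("Gaussian target, exact map:      ", J_opt)
println("Mixture target, shifts -3, 0, 3: ", (J_left, J_mid, J_right))
```

For the mixture target, the objective at the midpoint shift is larger than at both outer shifts, which a convex function cannot do. A gradient-based optimizer started near one mode would therefore settle in that mode's local minimum.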

In this package, map optimization is performed with Optim.jl, which supports a wide range of optimizers and options (such as convergence criteria and printing preferences). We can pass the desired optimizer and options to our optimize! function. For a full overview of available options, see the Optim.jl configuration documentation.
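Independently of transport maps, the way an optimizer and an options object are handed to Optim.jl can be sketched on a toy problem. The example below calls Optim.jl directly on the Rosenbrock function; the names `rosenbrock`, `rosenbrock_grad!`, and the starting point are illustrative choices and have nothing to do with the map objective.

```julia
using Optim

# Toy objective (Rosenbrock) with an analytic, in-place gradient.
rosenbrock(x) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
function rosenbrock_grad!(G, x)
    G[1] = -2 * (1 - x[1]) - 400 * x[1] * (x[2] - x[1]^2)
    G[2] = 200 * (x[2] - x[1]^2)
    return G
end

x0 = [-1.2, 1.0]
opts = Optim.Options(iterations = 200, store_trace = true)

# The optimizer and the options are independent arguments,
# so the algorithm can be swapped while the options stay the same.
res_lbfgs = optimize(rosenbrock, rosenbrock_grad!, x0, LBFGS(), opts)
res_gd = optimize(rosenbrock, rosenbrock_grad!, x0, GradientDescent(), opts)

println("LBFGS:           converged = ", Optim.converged(res_lbfgs),
        ", iterations = ", Optim.iterations(res_lbfgs))
println("GradientDescent: converged = ", Optim.converged(res_gd),
        ", iterations = ", Optim.iterations(res_gd))
```

The same separation of algorithm and options is what the `optimizer` and `options` keyword arguments of optimize! expose.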

To perform the optimization of the map coefficients, we call:

julia
optimize!(M::PolynomialMap, target_density::Function, quadrature::AbstractQuadratureWeights;
  optimizer::Optim.AbstractOptimizer = LBFGS(), options::Optim.Options = Optim.Options())

We have to provide the polynomial map M, the target density function, and a quadrature scheme. Optionally, we can specify the optimizer (default is LBFGS()) and options.

Set initial coefficients

As the starting point of the optimization, the map coefficients can be set using setcoefficients!(M, coeffs), where coeffs is a vector of coefficients.

Usage

First we load the packages:

julia
using TransportMaps
using Optim
using Distributions
using Plots

Then, define the target density and quadrature scheme. Here, we use the same banana-shaped density as in Banana: Map from Density:

julia
banana_density(x) = logpdf(Normal(), x[1]) + logpdf(Normal(), x[2] - x[1]^2)
target = MapTargetDensity(banana_density)
quadrature = GaussHermiteWeights(10, 2)

Set optimization options to print the trace every 20 iterations:

julia
opts_trace = Optim.Options(iterations=200, show_trace=true, show_every=20, store_trace=true)
                x_abstol = 0.0
                x_reltol = 0.0
                f_abstol = 0.0
                f_reltol = 0.0
                g_abstol = 1.0e-8
          outer_x_abstol = 0.0
          outer_x_reltol = 0.0
          outer_f_abstol = 0.0
          outer_f_reltol = 0.0
          outer_g_abstol = 1.0e-8
           f_calls_limit = 0
           g_calls_limit = 0
           h_calls_limit = 0
       allow_f_increases = true
 allow_outer_f_increases = true
        successive_f_tol = 1
              iterations = 200
        outer_iterations = 1000
             store_trace = true
           trace_simplex = false
              show_trace = true
          extended_trace = false
           show_warnings = true
              show_every = 20
                callback = nothing
              time_limit = NaN

We will try the following optimizers from Optim.jl, ordered roughly from simplest to most sophisticated:

Gradient Descent

The most basic optimization algorithm, Gradient Descent iteratively moves in the direction of the negative gradient. It is simple and robust, but can be slow to converge, especially for ill-conditioned problems.

julia
M_gd = PolynomialMap(2, 2)
res_gd = optimize!(M_gd, target, quadrature; optimizer=GradientDescent(), options=opts_trace)
println(res_gd)
Iter     Function value   Gradient norm
     0     3.397609e+00     4.804530e-01
 * time: 4.00543212890625e-5
    20     2.846750e+00     3.244379e-02
 * time: 0.9143080711364746
    40     2.841985e+00     1.426403e-02
 * time: 1.5014338493347168
    60     2.840998e+00     1.066274e-02
 * time: 2.09391188621521
    80     2.840454e+00     8.088299e-03
 * time: 2.6986188888549805
   100     2.840145e+00     6.198896e-03
 * time: 3.2951438426971436
   120     2.839965e+00     4.786295e-03
 * time: 3.9086079597473145
   140     2.839858e+00     3.715940e-03
 * time: 4.499942064285278
   160     2.839794e+00     2.896858e-03
 * time: 5.08022403717041
   180     2.839755e+00     2.265401e-03
 * time: 5.690901041030884
   200     2.839731e+00     1.775845e-03
 * time: 6.29883885383606
 * Status: failure (reached maximum number of iterations)

 * Candidate solution
    Final objective value:     2.839731e+00

 * Found with
    Algorithm:     Gradient Descent

 * Convergence measures
    |x - x'|               = 3.70e-04 ≰ 0.0e+00
    |x - x'|/|x'|          = 1.52e-04 ≰ 0.0e+00
    |f(x) - f(x')|         = 9.28e-07 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 3.27e-07 ≰ 0.0e+00
    |g(x)|                 = 1.78e-03 ≰ 1.0e-08

 * Work counters
    Seconds run:   6  (vs limit Inf)
    Iterations:    200
    f(x) calls:    681
    ∇f(x) calls:   681
    ∇f(x)ᵀv calls: 0

Conjugate Gradient

Conjugate Gradient improves upon basic gradient descent by using conjugate directions, which can accelerate convergence for large-scale or quadratic problems. It requires gradient information but not the Hessian.

julia
M_cg = PolynomialMap(2, 2)
res_cg = optimize!(M_cg, target, quadrature; optimizer=ConjugateGradient(), options=opts_trace)
println(res_cg)
Iter     Function value   Gradient norm
     0     3.397609e+00     4.804530e-01
 * time: 3.981590270996094e-5
    20     2.839693e+00     8.763686e-05
 * time: 0.4423239231109619
    40     2.839693e+00     2.084568e-07
 * time: 0.6498677730560303
 * Status: success

 * Candidate solution
    Final objective value:     2.839693e+00

 * Found with
    Algorithm:     Conjugate Gradient

 * Convergence measures
    |x - x'|               = 9.92e-09 ≰ 0.0e+00
    |x - x'|/|x'|          = 4.01e-09 ≰ 0.0e+00
    |f(x) - f(x')|         = 0.00e+00 ≤ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 0.00e+00 ≤ 0.0e+00
    |g(x)|                 = 1.40e-08 ≰ 1.0e-08

 * Work counters
    Seconds run:   1  (vs limit Inf)
    Iterations:    59
    f(x) calls:    141
    ∇f(x) calls:   84
    ∇f(x)ᵀv calls: 0

Nelder-Mead

Nelder-Mead is a derivative-free optimizer that uses a simplex of points to search for the minimum. It is useful when gradients are unavailable or unreliable, but may be less efficient for high-dimensional or smooth problems.

julia
M_nm = PolynomialMap(2, 2)
res_nm = optimize!(M_nm, target, quadrature; optimizer=NelderMead(), options=opts_trace)
println(res_nm)
Iter     Function value    √(Σ(yᵢ-ȳ)²)/n
------   --------------    --------------
     0     3.385910e+00     6.623032e-03
 * time: 5.412101745605469e-5
    20     3.325296e+00     1.183941e-02
 * time: 0.05845189094543457
    40     3.217172e+00     1.770408e-02
 * time: 0.1181650161743164
    60     3.145470e+00     4.181887e-03
 * time: 0.1720290184020996
    80     3.117131e+00     5.446265e-03
 * time: 0.25172901153564453
   100     3.093431e+00     3.313369e-03
 * time: 0.2879819869995117
   120     3.069929e+00     3.715059e-03
 * time: 0.3351569175720215
   140     3.051885e+00     3.568445e-03
 * time: 0.37532591819763184
   160     3.040074e+00     2.744730e-03
 * time: 0.42562294006347656
   180     3.030321e+00     2.994942e-03
 * time: 0.4643130302429199
   200     3.014241e+00     3.253376e-03
 * time: 0.5006310939788818
 * Status: failure (reached maximum number of iterations)

 * Candidate solution
    Final objective value:     3.008308e+00

 * Found with
    Algorithm:     Nelder-Mead

 * Convergence measures
    √(Σ(yᵢ-ȳ)²)/n ≰ 1.0e-08

 * Work counters
    Seconds run:   1  (vs limit Inf)
    Iterations:    200
    f(x) calls:    304

BFGS

BFGS is a quasi-Newton method that builds up an approximation to the Hessian matrix using gradient evaluations. It is generally faster and more robust than gradient descent and conjugate gradient for smooth problems.

julia
M_bfgs = PolynomialMap(2, 2)
res_bfgs = optimize!(M_bfgs, target, quadrature; optimizer=BFGS(), options=opts_trace)
println(res_bfgs)
Iter     Function value   Gradient norm
     0     3.397609e+00     4.804530e-01
 * time: 3.981590270996094e-5
 * Status: success

 * Candidate solution
    Final objective value:     2.839693e+00

 * Found with
    Algorithm:     BFGS

 * Convergence measures
    |x - x'|               = 8.52e-09 ≰ 0.0e+00
    |x - x'|/|x'|          = 3.45e-09 ≰ 0.0e+00
    |f(x) - f(x')|         = 4.44e-16 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 1.56e-16 ≰ 0.0e+00
    |g(x)|                 = 3.25e-10 ≤ 1.0e-08

 * Work counters
    Seconds run:   0  (vs limit Inf)
    Iterations:    12
    f(x) calls:    45
    ∇f(x) calls:   45
    ∇f(x)ᵀv calls: 0

LBFGS

LBFGS is a limited-memory version of BFGS, making it suitable for large-scale problems where storing the full Hessian approximation is impractical. It is the default optimizer in many scientific computing packages due to its efficiency and reliability.

julia
M_lbfgs = PolynomialMap(2, 2)
res_lbfgs = optimize!(M_lbfgs, target, quadrature; optimizer=LBFGS(), options=opts_trace)
println(res_lbfgs)
Iter     Function value   Gradient norm
     0     3.397609e+00     4.804530e-01
 * time: 3.504753112792969e-5
 * Status: success

 * Candidate solution
    Final objective value:     2.839693e+00

 * Found with
    Algorithm:     L-BFGS

 * Convergence measures
    |x - x'|               = 1.25e-08 ≰ 0.0e+00
    |x - x'|/|x'|          = 5.04e-09 ≰ 0.0e+00
    |f(x) - f(x')|         = 1.33e-15 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 4.69e-16 ≰ 0.0e+00
    |g(x)|                 = 2.84e-09 ≤ 1.0e-08

 * Work counters
    Seconds run:   0  (vs limit Inf)
    Iterations:    12
    f(x) calls:    45
    ∇f(x) calls:   45
    ∇f(x)ᵀv calls: 0

Finally, we can compare the results using the variance diagnostic:

julia
samples_z = randn(1000, 2)
v_gd = variance_diagnostic(M_gd, target, samples_z)
v_cg = variance_diagnostic(M_cg, target, samples_z)
v_nm = variance_diagnostic(M_nm, target, samples_z)
v_bfgs = variance_diagnostic(M_bfgs, target, samples_z)
v_lbfgs = variance_diagnostic(M_lbfgs, target, samples_z)

println("Variance diagnostic GradientDescent:   ", v_gd)
println("Variance diagnostic ConjugateGradient: ", v_cg)
println("Variance diagnostic NelderMead:        ", v_nm)
println("Variance diagnostic BFGS:              ", v_bfgs)
println("Variance diagnostic LBFGS:             ", v_lbfgs)
Variance diagnostic GradientDescent:   0.0004373121237130718
Variance diagnostic ConjugateGradient: 0.0003790237201569168
Variance diagnostic NelderMead:        0.10626452172420726
Variance diagnostic BFGS:              0.0003790236351507774
Variance diagnostic LBFGS:             0.0003790236432620713
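For orientation, the variance diagnostic is commonly defined in the transport-map literature as shown below. This is a sketch of that definition; whether variance_diagnostic implements exactly this expression is an assumption, and the symbol $\rho$ for the reference density is introduced here.

```latex
\varepsilon_{\sigma}
  = \frac{1}{2}\,\operatorname{Var}_{\rho}\!\left[
      \log\frac{\rho(z)}{T^{\sharp}\pi(z)}
    \right],
\qquad
T^{\sharp}\pi(z) = \pi\bigl(T(a,z)\bigr)\,\bigl|\det\nabla T(a,z)\bigr|
```

Under this definition the diagnostic is zero when the pullback of the target matches the reference up to a constant factor, so smaller values indicate a better map. This agrees with the numbers above, where the three optimizers that converged report nearly identical values and Nelder-Mead reports a much larger one.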

We can visualize the convergence of all optimizers:

julia
# Helpers that extract the iteration numbers and the stored convergence measure from a trace
trace_iters(res) = [state.iteration for state in res.trace]
trace_gnorm(res) = [state.g_norm for state in res.trace]

plot(trace_iters(res_gd), trace_gnorm(res_gd), lw=2, label="GradientDescent")
plot!(trace_iters(res_cg), trace_gnorm(res_cg), lw=2, label="ConjugateGradient")
# For Nelder-Mead, the trace printed above reports √(Σ(yᵢ-ȳ)²)/n in place of a gradient norm
plot!(trace_iters(res_nm), trace_gnorm(res_nm), lw=2, label="NelderMead")
plot!(trace_iters(res_bfgs), trace_gnorm(res_bfgs), lw=2, label="BFGS")
plot!(trace_iters(res_lbfgs), trace_gnorm(res_lbfgs), lw=2, label="LBFGS")
plot!(xaxis=:log, yaxis=:log, xlabel="Iteration", ylabel="Gradient norm",
    title="Convergence of different optimizers", xlims=(1, 200),
    legend=:bottomleft)

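The comparison shows that LBFGS and BFGS are the most efficient optimizers in this case: both converge in 12 iterations, Conjugate Gradient needs 59, and Gradient Descent and Nelder-Mead stop at the 200-iteration limit. Nelder-Mead also ends at a clearly higher objective value and variance diagnostic than the gradient-based methods.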

Map-from-samples

Another strategy for constructing a transport map is to use samples from the target density, as seen in Banana: Map from Samples. Formulating the estimation problem in this way has the benefit of turning it into a convex optimization problem when the reference density is log-concave [1]. Since we can choose the reference density, we can leverage this property to simplify the optimization process.

When the map is constructed from samples, the optimization problem is formulated by minimizing the Kullback-Leibler divergence between the pushforward of the reference density and the empirical distribution of the samples. We denote the transport map by S, which pushes forward the target distribution to the reference distribution. This leads to the following optimization problem:

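$$
\min_{a} \frac{1}{M} \sum_{i=1}^{M} \left[ -\log \rho\bigl(S(a, x_i)\bigr) - \log \left| \det \nabla S(a, x_i) \right| \right]
$$

where $\{x_i\}_{i=1}^{M}$ are samples from the target distribution, and $\rho(\cdot)$ is the density of the reference distribution.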
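A small one-dimensional case shows the sample-based objective at work. The sketch below uses only the Julia standard library and assumes a standard normal reference and the affine map $S(a, x) = a_1 + a_2 x$ with $a_2 > 0$; the function name `objective` and this map family are illustrative and are not the package's polynomial map. The additive constant $\tfrac{1}{2}\log(2\pi)$ from the reference density is dropped.

```julia
using Random
using Statistics

# Sample-based objective for S(a, x) = a[1] + a[2] * x with a[2] > 0:
# mean of -log ρ(S(a, x_i)) - log |∂S/∂x|, up to the constant 0.5*log(2π).
function objective(a, x)
    s = a[1] .+ a[2] .* x
    return mean(0.5 .* s .^ 2) - log(a[2])
end

Random.seed!(1)
x = 2.0 .+ 0.5 .* randn(5_000)   # samples from the target N(2, 0.5^2)

# For this map family the minimizer is available in closed form:
# S standardizes the samples with their empirical mean and standard deviation.
m, sd = mean(x), std(x; corrected = false)
a_star = [-m / sd, 1 / sd]
J_star = objective(a_star, x)

# Perturbing the coefficients should not decrease the objective.
perturbations = ([0.1, 0.0], [-0.1, 0.0], [0.0, 0.1], [0.0, -0.1])
J_perturbed = [objective(a_star .+ d, x) for d in perturbations]

println("closed-form coefficients: ", a_star)
println("objective at minimizer:   ", J_star)
println("objective nearby:         ", J_perturbed)
```

With this parameterization the objective is a convex quadratic in the coefficients plus the convex term $-\log a_2$, so the stationary point is the global minimum. This is the property that makes the sample-based formulation attractive.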

To perform the optimization, we can use the same optimize! function as before, but now we pass samples instead of a target density and quadrature scheme. Similarly, we can specify the optimizer and options:

julia
optimize!(M::PolynomialMap, samples::AbstractArray{<:Real};
  optimizer::Optim.AbstractOptimizer = LBFGS(), options::Optim.Options = Optim.Options())
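A possible usage for the banana example is sketched below. It relies only on the calls shown on this page (`PolynomialMap(2, 2)` and the sample-based optimize! signature above) and assumes that samples are stored one per row, matching the layout of `samples_z` in the variance diagnostic example. The variable names and the way the samples are drawn are illustrative.

```julia
using TransportMaps
using Optim

# Draw samples from the banana-shaped target: x1 ~ N(0, 1), x2 | x1 ~ N(x1^2, 1).
# Layout assumption: one sample per row, one dimension per column.
n_samples = 1000
x1 = randn(n_samples)
x2 = x1 .^ 2 .+ randn(n_samples)
samples = hcat(x1, x2)

# Same map type and options style as in the map-from-density example.
M_samples = PolynomialMap(2, 2)
opts = Optim.Options(iterations = 200, store_trace = true)
res_samples = optimize!(M_samples, samples; optimizer = LBFGS(), options = opts)
println(res_samples)
```

Here the optimized map plays the role of S, so it transports the target samples toward the reference rather than the other way around.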

This page was generated using Literate.jl.