
Adaptive Transport Maps

A key challenge in constructing transport maps is choosing the appropriate parameterization, specifically the multi-index set that defines which polynomial terms to include in the expansion. While fixed parameterizations like total order, no-mixed terms, or diagonal maps (see Choosing a Map Parameterization) can work well in many cases, they may not be optimal for all target distributions.

Adaptive transport maps address this limitation by automatically selecting the most relevant polynomial terms through a greedy enrichment strategy. This approach is particularly useful when:

  • The structure of the target distribution is unknown a priori

  • Computational resources are limited and a sparse representation is desired

  • High-dimensional problems require careful selection of interaction terms

Theory

The adaptive transport map (ATM) algorithm was introduced by [3] and provides a principled approach to construct sparse, triangular transport maps by adaptively enriching the multi-index set based on gradient information.

Greedy Multi-Index Selection

Given a triangular transport map with components $T_k : \mathbb{R}^k \to \mathbb{R}$ for $k = 1, \dots, d$, each component is parameterized by a polynomial expansion:

$$T_k(z_1, \dots, z_k; \mathbf{a}) = f(z_1, \dots, z_{k-1}, 0; \mathbf{a}) + \int_0^{z_k} g\big(\partial_k f(z_1, \dots, z_{k-1}, \xi; \mathbf{a})\big) \,\mathrm{d}\xi$$

where $f$ is a multivariate polynomial:

$$f(z_1, \dots, z_k; \mathbf{a}) = \sum_{\alpha \in \mathcal{A}_k} a_\alpha \Psi_\alpha(z_1, \dots, z_k)$$

Here, $\mathcal{A}_k$ is the multi-index set that determines which terms are included.
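To make the parameterization concrete, here is a self-contained sketch in plain Julia. The helper names (`hermite`, `psi`, `Tk`) are invented for illustration and are not the TransportMaps.jl API; the rectifier is a Softplus, and the partial derivative and integral are approximated numerically rather than computed exactly as a real implementation would.

```julia
# Standalone sketch (invented helper names, not the TransportMaps.jl API):
# evaluate f(z; a) = Σ_{α ∈ A} a_α Ψ_α(z) with a probabilists' Hermite basis,
# then the monotone component T_k with a Softplus rectifier g.

# Probabilists' Hermite polynomial He_n(x) via the three-term recurrence.
function hermite(n, x)
    n == 0 && return 1.0
    p, q = 1.0, x
    for k in 1:n-1
        p, q = q, x * q - k * p
    end
    return q
end

# Tensorized basis Ψ_α(z) = ∏_i He_{α_i}(z_i) and the expansion f(z; a).
psi(alpha, z) = prod(hermite(alpha[i], z[i]) for i in eachindex(alpha))
f(A, a, z) = sum(a[j] * psi(A[j], z) for j in eachindex(A))

softplus(x) = log1p(exp(x))  # g > 0, which makes T_k increasing in z_k

# T_k(z; a) = f(z_{1:k-1}, 0; a) + ∫_0^{z_k} g(∂_k f(z_{1:k-1}, ξ; a)) dξ.
# The integral uses the trapezoidal rule and ∂_k f a central finite
# difference; a real implementation differentiates the expansion exactly.
function Tk(A, a, z; n = 201, h = 1e-6)
    k = length(z)
    at(xi) = f(A, a, vcat(z[1:k-1], xi))
    dkf(xi) = (at(xi + h) - at(xi - h)) / (2h)
    xs = range(0.0, z[end], length = n)
    vals = softplus.(dkf.(xs))
    integral = step(xs) * (sum(vals) - 0.5 * (vals[1] + vals[end]))
    return at(0.0) + integral
end

# Example: multi-index set {(0,0), (1,0), (0,1), (1,1)} in two dimensions.
A = [[0, 0], [1, 0], [0, 1], [1, 1]]
a = [1.0, 0.5, -0.25, 2.0]
Tk(A, a, [0.3, -1.2])  # increasing in the last coordinate by construction
```

Because the rectifier output is strictly positive, the integrand is positive and each component is monotone in its last argument regardless of the coefficient values.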

The ATM algorithm starts with a minimal multi-index set (typically containing only the constant term) and iteratively adds terms that maximize the improvement in the objective function. At each iteration $t$, given the current multi-index set $\mathcal{A}_t$, the algorithm:

  1. Identifies candidate terms from the reduced margin of $\mathcal{A}_t$:

$$\mathcal{A}_{\mathrm{RM}}(\mathcal{A}_t) = \{\alpha \notin \mathcal{A}_t : \alpha - e_i \in \mathcal{A}_t \text{ for all } i \text{ with } \alpha_i > 0\}$$

where $e_i$ is the $i$-th standard basis vector.

  2. For each candidate $\alpha \in \mathcal{A}_{\mathrm{RM}}(\mathcal{A}_t)$, evaluates the gradient of the objective with respect to the coefficient $a_\alpha$ (initialized to zero).

  3. Selects the candidate with the largest absolute gradient value:

$$\alpha^+ = \operatorname{arg\,max}_{\alpha \in \mathcal{A}_{\mathrm{RM}}(\mathcal{A}_t)} \left| \frac{\partial J}{\partial a_\alpha} \right|$$

  4. Updates the multi-index set, $\mathcal{A}_{t+1} = \mathcal{A}_t \cup \{\alpha^+\}$, and re-optimizes all coefficients.

This greedy selection strategy ensures that at each iteration, the term most likely to improve the objective function is added, leading to sparse and efficient representations.
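The greedy loop can be sketched in plain Julia. The helper names below are invented for illustration (not the package API), and `grad` stands in for evaluating $\partial J / \partial a_\alpha$ at a zero-initialized coefficient:

```julia
# Standalone sketch (invented helper names, not the TransportMaps.jl API):
# one greedy enrichment step on a downward-closed multi-index set.

# Reduced margin: α ∉ A whose backward neighbours α - e_i (for every active
# coordinate α_i > 0) all lie in A.
function reduced_margin(A::Set{Vector{Int}}, d::Int)
    rm = Set{Vector{Int}}()
    for alpha in A, i in 1:d
        beta = copy(alpha)
        beta[i] += 1                      # one step outside A
        beta in A && continue
        downward_closed = all(j -> (g = copy(beta); g[j] -= 1; g in A),
                              (j for j in 1:d if beta[j] > 0))
        downward_closed && push!(rm, beta)
    end
    return rm
end

# Greedy step: score each candidate by |∂J/∂a_α| (`grad` is a stand-in for
# the real gradient computation) and add the best one to A.
function enrich!(A::Set{Vector{Int}}, d::Int, grad)
    best = argmax(alpha -> abs(grad(alpha)), collect(reduced_margin(A, d)))
    push!(A, best)
    return best
end

# Example in d = 2 starting from {(0,0), (1,0)}:
A = Set([[0, 0], [1, 0]])
reduced_margin(A, 2)                # Set([[2, 0], [0, 1]]); (1,1) is excluded
                                    # because its neighbour (0,1) is not in A
enrich!(A, 2, alpha -> sum(alpha))  # adds [2, 0] (score 2 beats 1)
```

Restricting candidates to the reduced margin keeps the multi-index set downward closed, so every included term has all of its lower-order neighbors present.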

Cross-Validation for Model Selection

A critical question when using adaptive transport maps is: how many terms should be included? Including too few terms may result in underfitting, while including too many can lead to overfitting.

To address this, the ATM implementation supports k-fold cross-validation. The algorithm:

  1. Splits the data into k folds

  2. For each fold, trains the map on $k-1$ folds and validates on the remaining fold

  3. Tracks both training and validation objectives at each iteration

  4. Selects the number of terms that minimizes the average validation objective across folds

This approach provides a data-driven way to balance model complexity and generalization performance.
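The selection rule can be sketched as follows; this is an illustrative sketch with invented names, where `fit_and_validate` stands in for running the adaptive enrichment on the training folds while recording the validation objective after each added term:

```julia
# Standalone sketch (invented helper names, not the TransportMaps.jl API):
# pick the number of terms that minimizes the average validation objective
# across k folds.

function kfold_select(n_samples, k, maxterms, fit_and_validate)
    folds = [collect(f:k:n_samples) for f in 1:k]        # interleaved folds
    val = zeros(maxterms, k)
    for (j, test_idx) in enumerate(folds)
        train_idx = setdiff(1:n_samples, test_idx)
        # validation objective after 1, 2, …, maxterms terms on this fold
        val[:, j] = fit_and_validate(train_idx, test_idx, maxterms)
    end
    avg = vec(sum(val, dims = 2)) ./ k                   # mean over folds
    return argmin(avg)                                   # selected #terms
end

# Toy stand-in: validation error that first drops, then overfits past 4 terms.
toy(train_idx, test_idx, m) = [(t - 4)^2 + 0.1t for t in 1:m]
kfold_select(100, 5, 10, toy)  # → 4
```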

Usage in TransportMaps.jl

The optimize_adaptive_transportmap function provides interfaces for constructing adaptive transport maps from either samples or a known density function.

Adaptive Maps from Samples

When working with sample data, the simplest approach uses a fixed train-test split to monitor overfitting:

```julia
M, histories = optimize_adaptive_transportmap(
    samples,             # Matrix of samples (n_samples × d)
    maxterms,            # Vector of maximum terms per component
    lm,                  # Linear map for standardization (default: LinearMap(samples))
    rectifier,           # Rectifier function (default: Softplus())
    basis;               # Polynomial basis (default: LinearizedHermiteBasis())
    optimizer = LBFGS(),
    options = Optim.Options(),
    test_fraction = 0.2  # Fraction of data for validation
)
```

For automatic model selection, use k-fold cross-validation. This implementation is based on the original algorithm proposed in [3]:

```julia
M, fold_histories, selected_terms, selected_folds = optimize_adaptive_transportmap(
    samples,             # Matrix of samples (n_samples × d)
    maxterms,            # Vector of maximum terms per component
    k_folds,             # Number of folds for cross-validation
    lm,                  # Linear map for standardization (default: LinearMap(samples))
    rectifier,           # Rectifier function (default: Softplus())
    basis;               # Polynomial basis (default: LinearizedHermiteBasis())
    optimizer = LBFGS(),
    options = Optim.Options()
)
```

The k-fold version returns:

  • M: The final composed transport map trained on all data with the selected number of terms

  • fold_histories: Optimization histories for each component and fold

  • selected_terms: Number of terms selected for each component based on cross-validation

  • selected_folds: Which fold had the best performance for each component

Example from Samples

The usage is demonstrated in the example Banana: Adaptive Transport Map from Samples.

Adaptive Maps from Density

When the target density function is known analytically, adaptive maps can be constructed directly without requiring samples. This approach uses quadrature methods for integration and adaptively enriches the multi-index set across all components simultaneously:

```julia
M, history = optimize_adaptive_transportmap(
    target,              # AbstractMapDensity: Target density to approximate
    quadrature,          # AbstractQuadratureWeights: Quadrature points and weights
    maxterms;            # Maximum total number of terms to add across all components
    rectifier = Softplus(),
    basis = LinearizedHermiteBasis(),
    reference_density = Normal(),
    optimizer = LBFGS(),
    options = Optim.Options(),
    validation = nothing  # Optional: validation quadrature for model selection
)
```

Key differences from the sample-based approach:

  • Uses a single global budget of terms (maxterms) shared across all components

  • Selects which component to enrich at each iteration based on gradient information

  • Supports optional validation using a separate quadrature rule

  • Returns the map with the best validation KL divergence (if validation is provided)
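The global-budget selection in the second bullet can be sketched as follows (invented helper names, not the package API); the gradient magnitudes here are hypothetical placeholders:

```julia
# Standalone sketch (invented helper names, not the TransportMaps.jl API):
# with one global budget, each iteration scores every reduced-margin
# candidate of every component and enriches the component that owns the
# candidate with the largest absolute gradient.

# candidate_scores[k] maps a candidate multi-index of component k to ∂J/∂a_α.
function next_enrichment(candidate_scores)
    best = (0, Int[], -Inf)                  # (component, candidate, |grad|)
    for (k, scores) in enumerate(candidate_scores)
        for (alpha, g) in scores
            abs(g) > best[3] && (best = (k, alpha, abs(g)))
        end
    end
    return best[1], best[2]
end

# Hypothetical gradient magnitudes for a 2-component map:
scores = [Dict([1, 0] => 0.3),
          Dict([0, 2] => -0.9, [1, 1] => 0.1)]
next_enrichment(scores)  # → (2, [0, 2]): component 2 gets the new term
```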

The returned history contains:

  • maps: Array of maps at each iteration

  • train_objectives: Training KL divergence values

  • test_objectives: Validation KL divergence values (if validation provided)

  • gradients: Gradient metrics for all candidates at each iteration

Example from Density

The usage is demonstrated in the example Cubic: Adaptive Transport Map from Density.

References

The implementation of adaptive transport maps is based on the work by Baptista et al.:

[3] R. Baptista, Y. Marzouk, and O. Zahm. On the representation and learning of monotone triangular transport maps. Foundations of Computational Mathematics, 2023.
This page was generated using Literate.jl.