Adaptive Transport Maps
A key challenge in constructing transport maps is choosing the appropriate parameterization, specifically the multi-index set that defines which polynomial terms to include in the expansion. While fixed parameterizations like total order, no-mixed terms, or diagonal maps (see Choosing a Map Parameterization) can work well in many cases, they may not be optimal for all target distributions.
Adaptive transport maps address this limitation by automatically selecting the most relevant polynomial terms through a greedy enrichment strategy. This approach is particularly useful when:
- The structure of the target distribution is unknown a priori
- Computational resources are limited and a sparse representation is desired
- High-dimensional problems require careful selection of interaction terms
Theory
The adaptive transport map (ATM) algorithm was introduced by [3] and provides a principled approach to construct sparse, triangular transport maps by adaptively enriching the multi-index set based on gradient information.
Greedy Multi-Index Selection
Given a triangular transport map with components

$$S^k(x_1, \dots, x_k) = f_k(x_1, \dots, x_{k-1}, 0) + \int_0^{x_k} g\big(\partial_k f_k(x_1, \dots, x_{k-1}, t)\big)\, dt,$$

where

$$f_k(\mathbf{x}) = \sum_{\boldsymbol{\alpha} \in \Lambda_k} c_{\boldsymbol{\alpha}}\, \psi_{\boldsymbol{\alpha}}(\mathbf{x}).$$

Here, $\Lambda_k$ is the multi-index set of the $k$-th component, $c_{\boldsymbol{\alpha}}$ are the expansion coefficients, $\psi_{\boldsymbol{\alpha}}$ are multivariate polynomial basis functions, and $g$ is a rectifier (e.g., Softplus) that enforces monotonicity in $x_k$.

The ATM algorithm starts with a minimal multi-index set (typically containing only the constant term) and iteratively adds terms that maximize the improvement in the objective function. At each iteration $t$, the algorithm:

- Identifies candidate terms from the reduced margin of the current set $\Lambda_t$:

$$\mathrm{RM}(\Lambda_t) = \{\boldsymbol{\alpha} \notin \Lambda_t : \boldsymbol{\alpha} - \mathbf{e}_j \in \Lambda_t \text{ for all } j \text{ with } \alpha_j > 0\},$$

where $\mathbf{e}_j$ is the $j$-th unit multi-index.
- For each candidate $\boldsymbol{\alpha} \in \mathrm{RM}(\Lambda_t)$, evaluates the gradient of the objective with respect to the coefficient $c_{\boldsymbol{\alpha}}$ (initialized to zero), and selects the candidate with the largest absolute gradient value:

$$\boldsymbol{\alpha}^\star = \arg\max_{\boldsymbol{\alpha} \in \mathrm{RM}(\Lambda_t)} \big|\nabla_{c_{\boldsymbol{\alpha}}} \mathcal{J}\big|.$$

- Updates the multi-index set, $\Lambda_{t+1} = \Lambda_t \cup \{\boldsymbol{\alpha}^\star\}$, and optimizes all coefficients.
This greedy selection strategy ensures that at each iteration, the term most likely to improve the objective function is added, leading to sparse and efficient representations.
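The enrichment step can be illustrated with a small, self-contained sketch. This is not the TransportMaps.jl internal implementation: multi-indices are represented here as plain integer vectors, and `objective_gradient` stands for a hypothetical user-supplied callback that returns the gradient of the training objective with respect to a single candidate coefficient held at zero.

```julia
# Sketch of one greedy enrichment step (illustrative only; not the
# TransportMaps.jl internals). Multi-indices are plain vectors of Ints.

# Reduced margin of Λ: multi-indices α ∉ Λ whose backward neighbours all lie in Λ.
function reduced_margin(Λ::Vector{Vector{Int}}, d::Int)
    candidates = Vector{Vector{Int}}()
    for α in Λ, j in 1:d
        cand = copy(α); cand[j] += 1
        (cand in Λ || cand in candidates) && continue
        admissible = all(1:d) do i
            cand[i] == 0 && return true
            back = copy(cand); back[i] -= 1
            back in Λ
        end
        admissible && push!(candidates, cand)
    end
    return candidates
end

# One greedy step: evaluate the objective gradient of each candidate coefficient
# (held at zero) and return the candidate with the largest absolute gradient.
# `objective_gradient(α)` is a hypothetical callback.
function greedy_step(Λ, d, objective_gradient)
    candidates = reduced_margin(Λ, d)
    grads = [abs(objective_gradient(α)) for α in candidates]
    return candidates[argmax(grads)]
end
```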
Cross-Validation for Model Selection
A critical question when using adaptive transport maps is: how many terms should be included? Including too few terms may result in underfitting, while including too many can lead to overfitting.
To address this, the ATM implementation supports k-fold cross-validation. The algorithm:

- Splits the data into k folds
- For each fold, trains the map on the remaining k-1 folds and validates on the held-out fold
- Tracks both training and validation objectives at each iteration
- Selects the number of terms that minimizes the average validation objective across folds
This approach provides a data-driven way to balance model complexity and generalization performance.
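As a rough illustration of the selection rule (not the library code), the following sketch assumes that each fold records a vector of validation objectives indexed by the number of terms added:

```julia
# Sketch of the k-fold selection rule (illustrative only; not the library code).
# `val_objectives[f][t]` is assumed to hold the validation objective of fold f
# after t terms have been added to a component.
function select_nterms(val_objectives::Vector{Vector{Float64}})
    nterms = minimum(length, val_objectives)   # iterations available in every fold
    avg = [sum(v[t] for v in val_objectives) / length(val_objectives) for t in 1:nterms]
    return argmin(avg)   # number of terms with the lowest average validation objective
end
```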
Usage in TransportMaps.jl
The optimize_adaptive_transportmap function provides interfaces for constructing adaptive transport maps from either samples or a known density function.
Adaptive Maps from Samples
When working with sample data, the simplest approach uses a fixed train-test split to monitor overfitting:
M, histories = optimize_adaptive_transportmap(
samples, # Matrix of samples (n_samples × d)
maxterms, # Vector of maximum terms per component
lm, # Linear map for standardization (default: LinearMap(samples))
rectifier, # Rectifier function (default: Softplus())
basis; # Polynomial basis (default: LinearizedHermiteBasis())
optimizer = LBFGS(),
options = Optim.Options(),
test_fraction = 0.2 # Fraction of data for validation
)

For automatic model selection, use k-fold cross-validation. This implementation is based on the original algorithm proposed in [3]:
M, fold_histories, selected_terms, selected_folds = optimize_adaptive_transportmap(
samples, # Matrix of samples (n_samples × d)
maxterms, # Vector of maximum terms per component
k_folds, # Number of folds for cross-validation
lm, # Linear map for standardization (default: LinearMap(samples))
rectifier, # Rectifier function (default: Softplus())
basis; # Polynomial basis (default: LinearizedHermiteBasis())
optimizer = LBFGS(),
options = Optim.Options()
)

The k-fold version returns:
- M: The final composed transport map trained on all data with the selected number of terms
- fold_histories: Optimization histories for each component and fold
- selected_terms: Number of terms selected for each component based on cross-validation
- selected_folds: Which fold had the best performance for each component
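A hypothetical call might look as follows. The data and term budget are illustrative, and it is assumed that the optional positional arguments (lm, rectifier, basis) fall back to the defaults listed above when omitted:

```julia
using TransportMaps

samples  = randn(1000, 2)   # illustrative: 1000 samples of a 2-dimensional target
maxterms = [10, 10]         # per-component budget of candidate terms
k_folds  = 5

# Optional positional arguments are omitted here, assuming they take the
# defaults listed above (LinearMap(samples), Softplus(), LinearizedHermiteBasis()).
M, fold_histories, selected_terms, selected_folds =
    optimize_adaptive_transportmap(samples, maxterms, k_folds)

@show selected_terms   # number of terms retained for each component
```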
Example from Samples
The usage is demonstrated in the example Banana: Adaptive Transport Map from Samples.
Adaptive Maps from Density
When the target density function is known analytically, adaptive maps can be constructed directly without requiring samples. This approach uses quadrature methods for integration and adaptively enriches the multi-index set across all components simultaneously:
M, history = optimize_adaptive_transportmap(
target, # AbstractMapDensity: Target density to approximate
quadrature, # AbstractQuadratureWeights: Quadrature points and weights
maxterms; # Maximum total number of terms to add across all components
rectifier = Softplus(),
basis = LinearizedHermiteBasis(),
reference_density = Normal(),
optimizer = LBFGS(),
options = Optim.Options(),
validation = nothing # Optional: validation quadrature for model selection
)

Key differences from the sample-based approach:
- Uses a single global budget of terms (maxterms) shared across all components
- Selects which component to enrich at each iteration based on gradient information
- Supports optional validation using a separate quadrature rule
- Returns the map with the best validation KL divergence (if validation is provided)
The returned history contains:
- maps: Array of maps at each iteration
- train_objectives: Training KL divergence values
- test_objectives: Validation KL divergence values (if validation provided)
- gradients: Gradient metrics for all candidates at each iteration
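For example, the iteration with the best validation performance could be recovered from the history roughly as follows; this is a sketch that assumes the entries listed above are accessible as fields of history:

```julia
# Sketch: pick the iteration with the lowest validation KL divergence.
# Assumes `history` exposes the entries listed above as fields.
test_kl   = history.test_objectives        # only available when `validation` was given
best_iter = argmin(test_kl)
best_map  = history.maps[best_iter]

println("best iteration: $best_iter, validation KL: $(test_kl[best_iter])")
```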
Example from Density
The usage is demonstrated in the example Cubic: Adaptive Transport Map from Density.
References
The implementation of adaptive transport maps is based on the work by Baptista et al.:
Baptista, R., Marzouk, Y., & Zahm, O. (2023). On the Representation and Learning of Monotone Triangular Transport Maps. Foundations of Computational Mathematics. https://doi.org/10.1007/s10208-023-09630-x
Matlab implementation of the original ATM algorithm: https://github.com/baptistar/ATM
This page was generated using Literate.jl.