Conditional Densities
When constructing a transport map, the triangular structure of the Knothe-Rosenblatt rearrangement provides a systematic way to factorize the joint density into a product of conditional densities:
$$\pi(x_1, \ldots, x_d) = \pi(x_1)\,\pi(x_2 \mid x_1)\cdots\pi(x_d \mid x_1, \ldots, x_{d-1}).$$
In Refs. [1] and [2], the authors show that these conditional densities can be obtained sequentially by inverting the map components one after another. We use the definition of a transport map given in Getting Started with TransportMaps.jl.
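The sequential inversion relies on each map component being strictly increasing in its last variable. The `Softplus()` argument in the map constructor used later suggests a rectifier enforces this, but the exact parameterization is the one in Getting Started with TransportMaps.jl. As an illustration only, here is a minimal plain-Julia sketch of one common construction of this kind, the integrated-rectifier form $M^2(z_1, z_2) = f(z_1, 0) + \int_0^{z_2} \mathrm{softplus}(\partial_2 f(z_1, t))\,\mathrm{d}t$. We assume, without checking the package source, that its monotone components follow a similar idea; the function `f` and all helper names are hypothetical.

```julia
# Illustrative sketch only (assumed construction, not the TransportMaps.jl source):
# an integrated-rectifier component that is strictly increasing in z₂.

softplus(s) = log1p(exp(s))   # smooth, strictly positive rectifier

# Hypothetical unconstrained function f(z₁, z₂) and its partial derivative in z₂
f(z₁, z₂) = 0.5 * z₁^2 + z₁ * z₂ - 0.3 * z₂^2
df_dz2(z₁, z₂) = z₁ - 0.6 * z₂

# Composite trapezoidal rule for ∫₀ᵇ g(t) dt (b may be negative or zero)
function integrate(g, b; n = 200)
    iszero(b) && return 0.0
    t = range(0.0, b, length = n + 1)
    h = b / n
    return h * (sum(g, t) - (g(t[1]) + g(t[end])) / 2)
end

# M²(z₁, z₂) = f(z₁, 0) + ∫₀^{z₂} softplus(∂f/∂z₂(z₁, t)) dt
M2(z₁, z₂) = f(z₁, 0.0) + integrate(t -> softplus(df_dz2(z₁, t)), z₂)

# ∂M²/∂z₂ is the positive integrand evaluated at t = z₂
dM2_dz2(z₁, z₂) = softplus(df_dz2(z₁, z₂))

zs = range(-3, 3, length = 61)
vals = [M2(0.7, z) for z in zs]
println(all(diff(vals) .> 0))   # true: M² increases in z₂
```

Because the integrand is always positive, $\partial M^2/\partial z_2 > 0$ for any choice of `f`, so the component can be inverted in $z_2$ by a one-dimensional root search.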
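The conditional density of $x_k$ given $x_1, \ldots, x_{k-1}$ is
$$\pi(x_k \mid x_1, \ldots, x_{k-1}) = \rho_k(z_k)\left|\frac{\partial M^k}{\partial z_k}(z_1, \ldots, z_k)\right|^{-1},$$
where $\rho_k$ is the one-dimensional standard normal reference density and $z_1, \ldots, z_k$ are found by inverting the map components one after another, i.e. $z_j$ solves $M^j(z_1, \ldots, z_{j-1}, z_j) = x_j$ for $j = 1, \ldots, k$.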
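To make the sequential inversion concrete, here is a minimal plain-Julia sketch for $d = 2$. It uses a hand-written, hypothetical triangular map (not a fitted `PolynomialMap`) whose components are strictly increasing in their last argument, inverts each component by bisection, and applies the change-of-variables formula. The helper names are illustrative, not TransportMaps.jl functions.

```julia
# Hypothetical triangular map x = M(z), strictly increasing in its last argument
M1(z₁) = z₁ + 0.1 * z₁^3
M2(z₁, z₂) = z₁^2 + z₂ + 0.2 * z₂^3
dM2_dz2(z₁, z₂) = 1 + 0.6 * z₂^2

φ(t) = exp(-t^2 / 2) / sqrt(2π)   # standard normal reference density

# Bisection: find t with g(t) = y for a strictly increasing scalar function g
function invert_increasing(g, y; lo = -50.0, hi = 50.0, tol = 1e-12)
    for _ in 1:200
        mid = (lo + hi) / 2
        if g(mid) < y
            lo = mid
        else
            hi = mid
        end
        hi - lo < tol && break
    end
    return (lo + hi) / 2
end

# π(x₂ | x₁): invert the components one after another, then change variables
function conditional_density_sketch(x₂, x₁)
    z₁ = invert_increasing(M1, x₁)              # solve M¹(z₁) = x₁
    z₂ = invert_increasing(t -> M2(z₁, t), x₂)  # solve M²(z₁, z₂) = x₂
    return φ(z₂) / dM2_dz2(z₁, z₂)              # ρ₂(z₂) |∂M²/∂z₂|⁻¹
end

# Trapezoidal rule on a grid
trapz(x, y) = sum((x[i+1] - x[i]) * (y[i] + y[i+1]) / 2 for i in 1:length(x)-1)

xs = range(-20, 20, length = 4001)
ys = [conditional_density_sketch(x, 0.5) for x in xs]
println(trapz(xs, ys))   # close to 1
```

Because the Jacobian factor $|\partial M^2/\partial z_2|^{-1}$ is included, the result integrates to one in $x_2$; evaluating only $\rho_2(z_2)$ would not be normalized.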
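Similarly, this allows us to sample from a conditional density $\pi(x_k \mid x_1, \ldots, x_{k-1})$: with $z_1, \ldots, z_{k-1}$ obtained by inverting the first $k-1$ map components, draw $z_k \sim \rho_k$ and set $x_k = M^k(z_1, \ldots, z_{k-1}, z_k)$.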
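The sampling procedure can be sketched in plain Julia as well. The map and helper names below are hypothetical stand-ins for a fitted map: the first component is inverted once to recover $z_1$ from the conditioning value, and each standard normal draw $z_2$ is then pushed through the second component.

```julia
using Random, Statistics

# Hypothetical triangular map x = M(z), strictly increasing in its last argument
M1(z₁) = z₁ + 0.1 * z₁^3
M2(z₁, z₂) = z₁^2 + z₂ + 0.2 * z₂^3

# Bisection: find t with g(t) = y for a strictly increasing scalar function g
function invert_increasing(g, y; lo = -50.0, hi = 50.0, tol = 1e-12)
    for _ in 1:200
        mid = (lo + hi) / 2
        if g(mid) < y
            lo = mid
        else
            hi = mid
        end
        hi - lo < tol && break
    end
    return (lo + hi) / 2
end

# Sample x₂ | x₁: recover z₁ from the conditioning value, then push z₂ through M²
function conditional_sample_sketch(x₁, z₂)
    z₁ = invert_increasing(M1, x₁)
    return M2(z₁, z₂)
end

Random.seed!(1)
x₁ = 0.5
samples = [conditional_sample_sketch(x₁, randn()) for _ in 1:10_000]
println(median(samples))   # close to M²(z₁, 0) = z₁²
```

Since $M^2$ is increasing in $z_2$, the median of the conditional samples is $M^2(z_1, 0)$, which gives a simple sanity check.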
Setting up a Transport Map
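For this example, we use the banana distribution as our target density; it also appears in Getting Started with TransportMaps.jl and Banana: Map from Density. The banana distribution is defined as
$$\pi(x_1, x_2) = \phi(x_1)\,\phi(x_2 - x_1^2),$$
where $\phi$ denotes the standard normal probability density function.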
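Because $x_2 - x_1^2$ is standard normal and independent of $x_1$, this target has an exact lower-triangular map from a standard normal reference, $x_1 = z_1$ and $x_2 = z_1^2 + z_2$, with unit Jacobian determinant; we would expect a fitted map with this variable ordering to approximate it. A small Monte Carlo check in plain Julia (the names `banana_map` and `banana_logdensity` are our own, not part of TransportMaps.jl):

```julia
using Random, Statistics

# Exact Knothe-Rosenblatt map for the banana target (hand-derived):
# x₁ = z₁, x₂ = z₁² + z₂, with z ~ N(0, I); its Jacobian determinant is 1
banana_map(z) = [z[1], z[1]^2 + z[2]]

# Banana log-density written without Distributions.jl
lognormpdf(t) = -0.5 * log(2π) - t^2 / 2
banana_logdensity(x) = lognormpdf(x[1]) + lognormpdf(x[2] - x[1]^2)

Random.seed!(42)
X = [banana_map(randn(2)) for _ in 1:100_000]
x₂s = [x[2] for x in X]
println(mean(x₂s))   # ≈ E[z₁²] + E[z₂] = 1
println(var(x₂s))    # ≈ Var(z₁²) + Var(z₂) = 2 + 1 = 3
```

The moments follow from $\mathbb{E}[z_1^2] = 1$ and $\operatorname{Var}(z_1^2) = \mathbb{E}[z_1^4] - 1 = 2$ for a standard normal $z_1$.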
We load the packages, define the banana log-density function, and create a target density object:
using TransportMaps
using Plots
using Distributions
# Log-density of the banana distribution: log N(x₁; 0, 1) + log N(x₂ - x₁²; 0, 1)
banana_density(x) = logpdf(Normal(), x[1]) + logpdf(Normal(), x[2] - x[1]^2)
target = MapTargetDensity(banana_density)
Define the map and the quadrature rule, then optimize the map:
M = PolynomialMap(2, 2, :normal, Softplus(), HermiteBasis())
quadrature = GaussHermiteWeights(10, 2)
# Optimize the map:
optimize!(M, target, quadrature)
Conditional Density Evaluation
Now we can compute conditional densities. For simplicity, we stay with the two-dimensional example, so we are interested in the conditional density $\pi(x_2 \mid x_1)$.
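In two dimensions this conditional can also be computed by brute force, normalizing the joint density along $x_2$ with a quadrature rule: $\pi(x_2 \mid x_1) = \pi(x_1, x_2) / \int \pi(x_1, t)\,\mathrm{d}t$. The sketch below does this in plain Julia as an independent reference for the banana target; the helper names are hypothetical. This grid-based approach scales poorly with dimension, whereas the map gives the conditional directly.

```julia
# Brute-force reference for π(x₂ | x₁) in 2-D (hypothetical helper names)
lognormpdf(t) = -0.5 * log(2π) - t^2 / 2
banana_logdensity(x) = lognormpdf(x[1]) + lognormpdf(x[2] - x[1]^2)

# Trapezoidal rule on a grid
trapz(x, y) = sum((x[i+1] - x[i]) * (y[i] + y[i+1]) / 2 for i in 1:length(x)-1)

function bruteforce_conditional(x₂, x₁; grid = range(-15, 15, length = 3001))
    joint = [exp(banana_logdensity([x₁, t])) for t in grid]   # π(x₁, t) along the grid
    marginal = trapz(grid, joint)                              # ≈ π(x₁)
    return exp(banana_logdensity([x₁, x₂])) / marginal
end

println(bruteforce_conditional(0.8, 0.5))   # ≈ 0.34294
```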
Define the value $x_1$ of the conditioning variable and the point $x_2$ at which to evaluate the conditional density:
x₁ = 0.5
x₂ = 0.8
density = conditional_density(M, x₂, x₁)
println("Conditional density π(x₂=$x₂ | x₁=$x₁) = $density")
Conditional density π(x₂=0.8 | x₁=0.5) = 0.3429438549395321
Evaluate the conditional density for multiple values of x₂:
x₂_values = range(-3, 3, length=100)
densities = conditional_density(M, x₂_values, x₁)
Plot the conditional density:
plot(x₂_values, densities,
xlabel="x₂", ylabel="π(x₂ | x₁=$x₁)",
title="Conditional Density π(x₂ | x₁=$x₁)",
linewidth=2, label="Conditional Density")
Conditional Sampling
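We can also sample from the conditional density $\pi(x_2 \mid x_1)$ by drawing reference samples $z_2 \sim \mathcal{N}(0, 1)$ and mapping them through the second map component with $x_1$ held fixed.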
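For the banana target the exact second map component is $x_2 = x_1^2 + z_2$, so conditional samples should follow $\mathcal{N}(x_1^2, 1)$. A minimal stand-alone check that uses this exact component in place of the fitted map (the helper `exact_conditional_sample` is hypothetical); the same `mean` and `std` checks can be applied to the samples drawn from `M`.

```julia
using Random, Statistics

# Exact second component of the banana map, used here in place of a fitted map
# (hypothetical helper, not the TransportMaps.jl conditional_sample)
exact_conditional_sample(x₁, z₂) = x₁^2 + z₂

Random.seed!(7)
x₁ = 0.5
z₂_draws = randn(10_000)
x₂_draws = exact_conditional_sample.(x₁, z₂_draws)

println(mean(x₂_draws))   # ≈ x₁² = 0.25
println(std(x₂_draws))    # ≈ 1
```

With the reference draw $z_2 = 0.8182912564716242$ from the single-sample example, this exact component gives $x_2 \approx 1.0682913$.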
Draw a single sample:
z₂ = randn()
cond_sample = conditional_sample(M, x₁, z₂)
println("Conditional sample for z₂=$z₂: x₂=$cond_sample")
Conditional sample for z₂=0.8182912564716242: x₂=1.0682912560843039
Multiple samples:
z₂_values = randn(10_000)
cond_samples = conditional_sample(M, x₁, z₂_values)
Create a histogram of the conditional samples and overlay the conditional density computed from the map:
histogram(cond_samples, bins=50, normalize=:pdf, alpha=0.7,
label="Conditional Samples", xlabel="x₂", ylabel="Density")
plot!(x₂_values, densities, linewidth=2,
label="Conditional Density")
Comparison with True Conditional Density
Let's compare our transport map's conditional density with the true conditional density of the target distribution:
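Integrating the joint density over $x_2$ gives $\pi(x_1) = \phi(x_1)$, because $\int \phi(t - x_1^2)\,\mathrm{d}t = 1$. The true conditional therefore follows directly from the definition of a conditional density, and it can be evaluated by hand at the point used above:

```math
\begin{aligned}
\pi(x_2 \mid x_1) &= \frac{\pi(x_1, x_2)}{\pi(x_1)} = \frac{\phi(x_1)\,\phi(x_2 - x_1^2)}{\phi(x_1)} = \phi(x_2 - x_1^2) = \mathcal{N}(x_2;\, x_1^2,\, 1),\\
\pi(0.8 \mid 0.5) &= \phi(0.55) = \frac{e^{-0.55^2/2}}{\sqrt{2\pi}} = \frac{e^{-0.15125}}{\sqrt{2\pi}} \approx 0.34294.
\end{aligned}
```

The second line agrees with the transport-map value printed in the Conditional Density Evaluation section to the digits shown.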
True conditional density for the banana distribution:
function true_banana_conditional_density(x₂, x₁)
# For the banana distribution π(x₁, x₂) = N(x₁; 0, 1) * N(x₂ - x₁²; 0, 1)
# The conditional density π(x₂|x₁) = N(x₂; x₁², 1)
# This is a normal distribution centered at x₁² with variance 1
μ_cond = x₁^2
σ_cond = 1.0
return pdf(Normal(μ_cond, σ_cond), x₂)
end
Compute the true conditional densities:
true_densities = [true_banana_conditional_density(x₂, x₁) for x₂ in x₂_values]
Plot the comparison:
plot(x₂_values, densities, linewidth=2, label="TM Conditional",
xlabel="x₂", ylabel="π(x₂ | x₁=$x₁)")
plot!(x₂_values, true_densities, linewidth=2, linestyle=:dash,
label="True Conditional")
title!("Transport Map vs True Conditional Density")
Multiple Conditioning Scenarios
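Let's explore how the conditional density π(x₂ | x₁) changes as we vary x₁. Because the conditional mean is x₁², the density shifts quadratically with x₁, which reflects the curved structure of the banana distribution (for x₁ = 2.0 the peak at x₂ = 4 lies outside the plotted range of x₂ from -3 to 3):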
x₁_values = [-0.6, 0.0, 1.0, 2.0]
p = plot(xlabel="x₂", ylabel="π(x₂ | x₁)",
title="Conditional Densities for Different x₁ Values")
for (i, x₁_val) in enumerate(x₁_values)
densities_cond = conditional_density(M, x₂_values, x₁_val)
true_densities_cond = [true_banana_conditional_density(x₂, x₁_val) for x₂ in x₂_values]
plot!(p, x₂_values, densities_cond, linewidth=2,
label="TM: x₁=$x₁_val", color=i)
plot!(p, x₂_values, true_densities_cond, linewidth=2, linestyle=:dash,
label="True: x₁=$x₁_val", color=i)
end
plot!(p)
Note
For a more comprehensive example with real-world applications, see the Bayesian Inference: Biochemical Oxygen Demand (BOD) Example.
This page was generated using Literate.jl.