
Conditional Densities

When constructing a transport map, the triangular structure of the Knothe-Rosenblatt rearrangement provides a systematic way to factorize the joint density into a product of conditional densities:

$$
\pi(x) = \pi(x_1)\,\pi(x_2 \mid x_1)\,\pi(x_3 \mid x_1, x_2) \cdots \pi(x_k \mid x_1, \dots, x_{k-1}).
$$
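For the banana density used later on this page, the two-dimensional version of this factorization can be checked numerically: integrating $x_2$ out of the joint gives the marginal $\pi(x_1)$, and dividing the joint by that marginal gives $\pi(x_2 \mid x_1)$. The sketch below uses only Julia's standard library and a hand-written normal PDF; the function names `joint`, `marginal_x1`, and `conditional` are illustrative, not part of TransportMaps.jl.

```julia
# Standard normal PDF, written out to avoid package dependencies
φ(u) = exp(-u^2 / 2) / sqrt(2π)

# Banana joint density π(x₁, x₂) = φ(x₁) φ(x₂ - x₁²)
joint(x1, x2) = φ(x1) * φ(x2 - x1^2)

# Marginal π(x₁): integrate x₂ out with a Riemann sum on a wide uniform grid
function marginal_x1(x1; grid = range(-20, 20, length = 4001))
    return step(grid) * sum(joint(x1, x2) for x2 in grid)
end

# Conditional π(x₂ | x₁) = π(x₁, x₂) / π(x₁)
conditional(x2, x1) = joint(x1, x2) / marginal_x1(x1)

x1, x2 = 0.5, 0.8
println("π(x₁)      = ", marginal_x1(x1), "  (exact φ(x₁) = ", φ(x1), ")")
println("π(x₂ | x₁) = ", conditional(x2, x1), "  (exact φ(x₂ - x₁²) = ", φ(x2 - x1^2), ")")
```

The marginal of $x_1$ comes out as $\phi(x_1)$ and the conditional as $\phi(x_2 - x_1^2)$, so the joint is exactly the product of the two factors.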

In Refs. [1] and [2], the authors show that these conditional densities can be obtained sequentially by inverting the map components one after another. We define the transport map as in Getting Started with TransportMaps.jl.
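The sequential inversion can be sketched generically: for a lower-triangular map whose $k$-th component is increasing in $z_k$, each $z_k$ solves a one-dimensional equation once $z_1, \dots, z_{k-1}$ are known. The sketch below assumes a hand-written exact banana map ($T_1 = z_1$, $T_2 = z_2 + z_1^2$) instead of a fitted TransportMaps.jl map, and uses plain bisection as the root finder; the names `T_components`, `bisect`, and `invert_triangular` are hypothetical.

```julia
# Hypothetical lower-triangular map: component k depends on z[1:k] and is
# strictly increasing in z[k]. Here: the exact banana map T₁ = z₁, T₂ = z₂ + z₁².
T_components = [
    z -> z[1],
    z -> z[2] + z[1]^2,
]

# Solve f(t) = target for t by bisection, assuming f is increasing on [lo, hi]
function bisect(f, target; lo = -50.0, hi = 50.0, tol = 1e-12)
    for _ in 1:200
        mid = (lo + hi) / 2
        f(mid) < target ? (lo = mid) : (hi = mid)
        hi - lo < tol && break
    end
    return (lo + hi) / 2
end

# Invert component by component: z₁ from x₁, then z₂ from (z₁, x₂), and so on
function invert_triangular(Tc, x)
    z = Float64[]
    for k in eachindex(x)
        # Only the unknown z[k] varies; z[1:k-1] are already fixed
        zk = bisect(t -> Tc[k](vcat(z, t)), x[k])
        push!(z, zk)
    end
    return z
end

z = invert_triangular(T_components, [0.5, 0.8])
println(z)  # approximately [0.5, 0.55]
```

Because each step only involves one scalar unknown, any monotone 1-D root finder would do; bisection is used here only for its simplicity and robustness.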

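The conditional density for $x_k$ given $x_1, \dots, x_{k-1}$ is given by

$$
\pi(x_k \mid x_1, \dots, x_{k-1}) = \rho(z_k) \left| \frac{\partial T_k(z_1, \dots, z_k)}{\partial z_k} \right|^{-1},
$$

where $z_1, \dots, z_k$ are obtained by inverting the first $k$ components of the map, i.e., $z_k = T_k^{-1}(x_1, \dots, x_k)$, with $T_k^{-1}$ denoting the $k$-th component of the inverse map (which is again lower-triangular). Here, $\rho$ is the one-dimensional reference density (the standard normal in this example), and $T_k$ is the $k$-th component of the triangular map.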
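For the banana example the formula can be checked by hand. Assuming the exact map $T_1(z_1) = z_1$, $T_2(z_1, z_2) = z_2 + z_1^2$ (not the fitted map from this page), inverting gives $z_1 = x_1$ and $z_2 = x_2 - x_1^2$, and $\partial T_2 / \partial z_2 = 1$, so $\pi(x_2 \mid x_1) = \rho(x_2 - x_1^2)$. The sketch below evaluates the formula literally, with a finite-difference derivative standing in for an analytic one; `conditional_density_sketch` is a hypothetical name.

```julia
# One-dimensional standard normal reference density ρ
ρ(z) = exp(-z^2 / 2) / sqrt(2π)

# Exact banana map component T₂(z₁, z₂) = z₂ + z₁² (assumption, not a fitted map)
T2(z1, z2) = z2 + z1^2

# ∂T₂/∂z₂ by a central finite difference (equals 1 for this map)
dT2_dz2(z1, z2; h = 1e-6) = (T2(z1, z2 + h) - T2(z1, z2 - h)) / (2h)

function conditional_density_sketch(x2, x1)
    z1 = x1              # invert T₁(z₁) = z₁
    z2 = x2 - z1^2       # invert T₂ for the fixed z₁
    return ρ(z2) / abs(dT2_dz2(z1, z2))
end

println(conditional_density_sketch(0.8, 0.5))  # approximately 0.3429
```

The result agrees closely with the value printed by `conditional_density(M, x₂, x₁)` for the fitted map in the Conditional Density Evaluation section.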

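Similarly, we can sample from the conditional density $\pi(x_k \mid x_1, \dots, x_{k-1})$: invert the first $k-1$ components of the map to obtain $z_1, \dots, z_{k-1}$ from the conditioning values, draw $z_k$ from the reference density, and evaluate $x_k = T_k(z_1, \dots, z_{k-1}, z_k)$.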
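The three sampling steps can be sketched for the same hand-written banana map (an assumption; the page itself uses `conditional_sample` on the fitted map). With $x_1 = 0.5$ the samples of $x_2$ should have mean $x_1^2 = 0.25$ and standard deviation 1.

```julia
using Random, Statistics

Random.seed!(1)

# Exact banana map component T₂(z₁, z₂) = z₂ + z₁² (assumed, not fitted)
T2(z1, z2) = z2 + z1^2

x1 = 0.5
z1 = x1                  # step 1: invert T₁ to get z₁ from the conditioning value
z2 = randn(100_000)      # step 2: draw z₂ from the standard normal reference
x2 = T2.(z1, z2)         # step 3: push through T₂ to get samples of x₂ | x₁

println(mean(x2), " ", std(x2))  # close to 0.25 and 1.0
```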

Setting up a Transport Map

For this example, we'll use the banana distribution as our target density, which is also used in Getting Started with TransportMaps.jl and Banana: Map from Density. The banana distribution is defined as:

$$
\pi(x_1, x_2) = \phi(x_1)\,\phi(x_2 - x_1^2),
$$

where $\phi$ is the standard normal PDF.

We load the packages, define the banana log-density function, and create a target density object:

```julia
using TransportMaps
using Plots
using Distributions

# Log-density of the banana distribution: log ϕ(x₁) + log ϕ(x₂ - x₁²)
banana_density(x) = logpdf(Normal(), x[1]) + logpdf(Normal(), x[2] - x[1]^2)
target = MapTargetDensity(banana_density)
```

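Define the map and the quadrature rule, then optimize the map:

```julia
# Two-dimensional polynomial map of degree 2 with a standard normal reference
M = PolynomialMap(2, 2, :normal, Softplus(), HermiteBasis())
# Gauss-Hermite quadrature rule in two dimensions
quadrature = GaussHermiteWeights(10, 2)

# Optimize the map:
optimize!(M, target, quadrature)
```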

Conditional Density Evaluation

Now we can compute conditional densities. Since this is a two-dimensional example, we are interested in the conditional density $\pi(x_2 \mid x_1 = 0.5)$.

Define the conditioning value $x_1$ and the point $x_2$ at which to evaluate the conditional density:

```julia
x₁ = 0.5
x₂ = 0.8
density = conditional_density(M, x₂, x₁)
println("Conditional density π(x₂=$x₂ | x₁=$x₁) = $density")
```

```
Conditional density π(x₂=0.8 | x₁=0.5) = 0.3429438549395321
```

Evaluate the conditional density at multiple values:

```julia
x₂_values = range(-3, 3, length=100)
densities = conditional_density(M, x₂_values, x₁)
```

Plot the conditional density:

```julia
plot(x₂_values, densities,
    xlabel="x₂", ylabel="π(x₂ | x₁=$x₁)",
    title="Conditional Density π(x₂ | x₁=$x₁)",
    linewidth=2, label="Conditional Density")
```

Conditional Sampling

We can also sample from the conditional density $\pi(x_2 \mid x_1)$. First, we draw $z_2$ from the standard normal reference, and then we evaluate the conditional map:

Single-value sampling:

```julia
z₂ = randn()
cond_sample = conditional_sample(M, x₁, z₂)
println("Conditional sample for z₂=$z₂: x₂=$cond_sample")
```

```
Conditional sample for z₂=0.8182912564716242: x₂=1.0682912560843039
```

Multiple samples:

```julia
z₂_values = randn(10_000)
cond_samples = conditional_sample(M, x₁, z₂_values)
```

Create a histogram of the conditional samples and overlay the conditional density computed from the map:

```julia
histogram(cond_samples, bins=50, normalize=:pdf, alpha=0.7,
    label="Conditional Samples", xlabel="x₂", ylabel="Density")
plot!(x₂_values, densities, linewidth=2,
    label="Conditional Density")
```

Comparison with True Conditional Density

Let's compare our transport map's conditional density with the true conditional density of the target distribution:

True conditional density for the banana distribution:

```julia
function true_banana_conditional_density(x₂, x₁)
    # For the banana distribution π(x₁, x₂) = N(x₁; 0, 1) * N(x₂ - x₁²; 0, 1)
    # The conditional density π(x₂|x₁) = N(x₂; x₁², 1)
    # This is a normal distribution centered at x₁² with variance 1
    μ_cond = x₁^2
    σ_cond = 1.0
    return pdf(Normal(μ_cond, σ_cond), x₂)
end
```

Compute true conditional densities:

```julia
true_densities = [true_banana_conditional_density(x₂, x₁) for x₂ in x₂_values]
```

Plot comparison:

```julia
plot(x₂_values, densities, linewidth=2, label="TM Conditional",
    xlabel="x₂", ylabel="π(x₂ | x₁=$x₁)")
plot!(x₂_values, true_densities, linewidth=2, linestyle=:dash,
    label="True Conditional")
title!("Transport Map vs True Conditional Density")
```
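To quantify the agreement rather than judging it by eye, one option is to report the maximum pointwise error and an approximate $L^1$ distance between the two curves. The helper `density_discrepancy` below is hypothetical (not part of TransportMaps.jl) and assumes both curves are evaluated on the same uniformly spaced Julia range; to keep it self-contained, the usage compares an exact $N(0.25, 1)$ density with a slightly shifted $N(0.30, 1)$.

```julia
# Hypothetical helper, not part of TransportMaps.jl: compare two density
# curves evaluated on the same uniformly spaced grid `xs` (a Julia range).
function density_discrepancy(xs::AbstractRange, approx, exact)
    diffs = abs.(approx .- exact)
    h = step(xs)
    # Trapezoidal rule for the L1 distance ∫ |approx - exact| dx over the grid
    l1 = h * (sum(diffs) - (diffs[1] + diffs[end]) / 2)
    return (max_abs = maximum(diffs), l1 = l1)
end

# Self-contained usage: exact N(0.25, 1) versus a slightly shifted N(0.30, 1)
φ(u) = exp(-u^2 / 2) / sqrt(2π)
xs = range(-3, 3, length = 100)
exact = φ.(xs .- 0.25)
approx = φ.(xs .- 0.30)
res = density_discrepancy(xs, approx, exact)
println("max |Δ| = $(res.max_abs), L1 ≈ $(res.l1)")
```

With the objects defined on this page, the corresponding call would be `density_discrepancy(x₂_values, densities, true_densities)`; the $L^1$ value then only covers the plotted interval $[-3, 3]$.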

Multiple Conditioning Scenarios

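Let's explore how the conditional density $\pi(x_2 \mid x_1)$ changes as we vary $x_1$. Because the conditional mean $x_1^2$ depends nonlinearly on $x_1$, the curves shift quadratically, which reveals the curved structure of the banana distribution: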

```julia
x₁_values = [-0.6, 0.0, 1.0, 2.0]

p = plot(xlabel="x₂", ylabel="π(x₂ | x₁)",
    title="Conditional Densities for Different x₁ Values")

for (i, x₁_val) in enumerate(x₁_values)
    densities_cond = conditional_density(M, x₂_values, x₁_val)
    true_densities_cond = [true_banana_conditional_density(x₂, x₁_val) for x₂ in x₂_values]

    plot!(p, x₂_values, densities_cond, linewidth=2,
        label="TM: x₁=$x₁_val", color=i)
    plot!(p, x₂_values, true_densities_cond, linewidth=2, linestyle=:dash,
        label="True: x₁=$x₁_val", color=i)
end

plot!(p)
```
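Since $\pi(x_2 \mid x_1) = N(x_1^2, 1)$, each curve should peak at $x_2 = x_1^2$, i.e. at 0.36, 0, 1, and 4 for the values above. The grid `x₂_values` spans only $[-3, 3]$, so for $x_1 = 2.0$ the plot shows just the left tail of the conditional. A minimal sketch that checks the peak locations on a wider grid, using the exact conditional in place of the fitted map and a hypothetical helper `grid_mode`, could look like this:

```julia
# Standard normal PDF, written out to avoid package dependencies
φ(u) = exp(-u^2 / 2) / sqrt(2π)

# Hypothetical helper: grid point at which a density curve is largest
grid_mode(xs, dens) = xs[argmax(dens)]

# Grid wide enough to contain x₁² = 4 (step 0.01)
xs_wide = range(-3, 6, length = 901)

for x1 in (-0.6, 0.0, 1.0, 2.0)
    dens = φ.(xs_wide .- x1^2)   # exact conditional N(x₁², 1)
    println("x₁ = $x1: mode ≈ $(grid_mode(xs_wide, dens)), x₁² = $(x1^2)")
end
```

Replacing `dens` with `conditional_density(M, xs_wide, x1)` would apply the same check to the fitted map.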

Note

For a more comprehensive example with real-world applications, see the Bayesian Inference: Biochemical Oxygen Demand (BOD) Example.


This page was generated using Literate.jl.