Wassim Kabalan
Alexandre Boucaud, François Lanusse
2025-05-01
Summary Statistics Based Inference
\[ p(\theta \mid x_0) \propto p(x_0 \mid \theta) \cdot p(\theta) \]
Treats the simulator as a black box — we only require the ability to simulate \((\theta, x)\) pairs.
No need for an explicit likelihood — instead, use simulation-based inference (SBI) techniques
Often relies on compression to summary statistics \(t = f_\phi(x)\), then approximates \(p(\theta \mid t)\).
Requires a differentiable forward model or simulator.
Treat the simulator as a probabilistic model and perform inference over the joint posterior \(p(\theta, z \mid x)\)
Computationally demanding — but provides exact control over the statistical model.
The goal is to reconstruct the entire latent structure of the Universe — not just compress it into summary statistics. To do this, we jointly infer:
\[ p(\theta, z \mid x) \propto p(x \mid \theta, z) \, p(z \mid \theta) \, p(\theta) \]
Where:
\(\theta\): cosmological parameters (e.g. matter density, dark energy, etc.)
\(z\): latent fields (e.g. initial conditions of the density field)
\(x\): observed data (e.g. convergence maps or galaxy fields)
The challenge of explicit inference
The latent variables \(z\) typically live in very high-dimensional spaces — with millions of degrees of freedom.
Sampling in this space is intractable using traditional inference techniques.
Born Approximation for Convergence
\[ \kappa(\boldsymbol{\theta}) = \int_0^{r_s} dr \, W(r, r_s) \, \delta(\boldsymbol{\theta}, r) \]
Where the lensing weight is:
\[ W(r, r_s) = \frac{3}{2} \, \Omega_m \, \left( \frac{H_0}{c} \right)^2 \, \frac{r}{a(r)} \left(1 - \frac{r}{r_s} \right) \]
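Discretized over lens planes, the Born integral becomes a weighted sum of density slices. A minimal NumPy sketch (the function name, plane layout, and default parameter values are illustrative, not JAXPM's API):

```python
import numpy as np

def born_convergence(delta, r, a, r_s, omega_m=0.3, h_over_c=1.0 / 2997.9):
    """Sum the Born integral over lens planes:
        kappa = sum_i dr_i * W(r_i, r_s) * delta_i,
    with W = (3/2) * Omega_m * (H0/c)^2 * (r/a) * (1 - r/r_s).

    delta: (n_planes, n_pix) density contrast on each plane
    r, a:  per-plane comoving distance and scale factor
    r_s:   comoving distance to the source plane
    """
    dr = np.gradient(r)  # plane thickness along the line of sight
    W = 1.5 * omega_m * h_over_c**2 * (r / a) * (1.0 - r / r_s)
    return np.sum(dr[:, None] * W[:, None] * delta, axis=0)
```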
Takeaway
Poisson’s equation:
\[
\nabla^2 \phi = 4\pi G \rho
\]
Gravitational force in Fourier space (in units where \(4\pi G = 1\)):
\[
\mathbf{f}(\mathbf{k}) = i\mathbf{k}k^{-2}\rho(\mathbf{k})
\]
Each Fourier mode \(\mathbf{k}\) can be computed independently using JAX
Perfect for large-scale, parallel GPU execution
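As a single-device illustration of the kernel above, applied mode-by-mode with an FFT. This is a NumPy sketch in units where \(4\pi G = 1\); the function name and normalization are illustrative, not JAXPM's API:

```python
import numpy as np

def gravitational_force(rho, box_size=1.0):
    """Apply f(k) = i k / k^2 * rho(k) to an n^3 density grid,
    then transform each force component back to real space."""
    n = rho.shape[0]
    k = 2 * np.pi * np.fft.fftfreq(n, d=box_size / n)
    kx, ky, kz = np.meshgrid(k, k, k, indexing="ij")
    k2 = kx**2 + ky**2 + kz**2
    k2[0, 0, 0] = 1.0                 # avoid division by zero at k = 0
    rho_k = np.fft.fftn(rho)
    rho_k[0, 0, 0] = 0.0              # drop the mean (DC) mode
    fx = np.fft.ifftn(1j * kx / k2 * rho_k).real
    fy = np.fft.ifftn(1j * ky / k2 * rho_k).real
    fz = np.fft.ifftn(1j * kz / k2 * rho_k).real
    return np.stack([fx, fy, fz])
```

Each mode is an independent multiply, which is what makes the kernel embarrassingly parallel; the FFTs are the only step that needs global communication.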
Fourier Transform requires global communication
jax.numpy.fft.fftn
pip install jaxdecomp
The Cloud-In-Cell (CIC) scheme spreads particle mass to nearby grid points using a linear kernel.
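A 1D sketch of the scheme (the 3D kernel is a product of three such linear weights; `cic_paint` is an illustrative name, not JAXPM's API):

```python
import numpy as np

def cic_paint(grid, positions):
    """Scatter unit-mass particles onto a periodic 1D grid with the
    Cloud-In-Cell linear kernel: each particle deposits mass into its
    two nearest cells, weighted by distance."""
    n = grid.shape[0]
    i = np.floor(positions).astype(int)
    frac = positions - i                  # distance past the left grid point
    np.add.at(grid, i % n, 1.0 - frac)    # weight to the left cell
    np.add.at(grid, (i + 1) % n, frac)    # weight to the right cell
    return grid
```

For example, a particle at 2.25 deposits 0.75 into cell 2 and 0.25 into cell 3; total mass is conserved by construction.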
In distributed simulations, each subdomain handles a portion of the global domain
Boundary conditions are crucial to ensure physical continuity across subdomain edges
CIC interpolation assigns and reads mass from nearby grid cells — potentially crossing subdomain boundaries
To avoid discontinuities or mass loss, we apply a halo exchange: each subdomain shares a thin layer of boundary cells with its neighbours.
Without halo exchange, subdomain boundaries introduce visible artifacts in the final field.
This breaks the smoothness of the result — even when each local computation is correct.
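A minimal 1D illustration of the idea, with periodic subdomains each padded by one layer of cells copied from their neighbours (names and layout are illustrative; a real implementation exchanges halos between devices):

```python
import numpy as np

def halo_exchange(subdomains, halo=1):
    """Pad each 1D subdomain with `halo` cells copied from its periodic
    neighbours, so interpolation near an edge sees the correct values
    from the adjacent subdomain."""
    padded = []
    for i, sub in enumerate(subdomains):
        left = subdomains[(i - 1) % len(subdomains)][-halo:]
        right = subdomains[(i + 1) % len(subdomains)][:halo]
        padded.append(np.concatenate([left, sub, right]))
    return padded
```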
To compute gradients through a simulation, we need to track every intermediate state (particle positions and velocities at each step).
Even though each update is simple, autodiff requires storing the full history.
A typical update step looks like:
\[ \begin{aligned} d_{i+1} &= d_i + v_i \, \Delta t \\ v_{i+1} &= v_i + F(d_{i+1}) \, \Delta t \end{aligned} \]
Over many steps, the memory demand scales linearly — which becomes a bottleneck for large simulations.
Why This Is a Problem in Practice
Storing intermediate states for autodiff causes memory to scale linearly with the number of steps.
Example costs at full resolution:
This severely limits how many time steps or how large a volume we can afford — even with many GPUs.
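A back-of-envelope estimate with illustrative numbers (2048³ particles, 50 steps, float32), not measured costs:

```python
# Each saved state stores position + velocity = 6 float32 per particle.
n_particles = 2048**3
bytes_per_state = n_particles * 6 * 4           # bytes for one time step
n_steps = 50
total_tb = bytes_per_state * n_steps / 1e12     # naive autodiff storage
print(f"{bytes_per_state / 1e9:.0f} GB per step, "
      f"{total_tb:.1f} TB for {n_steps} steps")
# → 206 GB per step, 10.3 TB for 50 steps
```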
Instead of storing the full trajectory…
We use the reverse adjoint method:
Forward Pass (Drift-Kick)
\[ \begin{aligned} d_{i+1} &= d_i + v_i \, \Delta t \\ v_{i+1} &= v_i + F(d_{i+1}) \, \Delta t \end{aligned} \]
Reverse Pass (Adjoint Method)
\[ \begin{aligned} v_i &= v_{i+1} - F(d_{i+1}) \, \Delta t \\ d_i &= d_{i+1} - v_i \, \Delta t \end{aligned} \]
Reverse Adjoint Method
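A NumPy sketch of the idea with a toy harmonic force: the reverse pass re-derives earlier states by inverting the update equations, so only the final state needs to be kept (in JAXPM this inversion is wired into the custom reverse adjoint solver so JAX's autodiff can use it):

```python
import numpy as np

def force(d):
    return -d                       # toy force: harmonic oscillator

def forward(d, v, dt, n_steps):
    """Drift-kick steps; no trajectory is stored."""
    for _ in range(n_steps):
        d = d + v * dt              # drift
        v = v + force(d) * dt       # kick
    return d, v

def backward(d, v, dt, n_steps):
    """Invert the updates step by step to recover earlier states,
    instead of storing the full history for autodiff."""
    for _ in range(n_steps):
        v = v - force(d) * dt       # un-kick
        d = d - v * dt              # un-drift
    return d, v
```

Running `backward` on the output of `forward` recovers the initial state to floating-point precision, which is what lets the gradient pass reconstruct states on the fly.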
What JAXPM v0.1.5 Supports
Multi-GPU and Multi-Node simulation with distributed domain decomposition (Successfully ran 2048³ on 256 GPUs)
End-to-end differentiability, including force computation and interpolation
Integrates with a custom JAX-compatible reverse adjoint solver for memory-efficient gradients
Supports full PM Lightcone Weak Lensing
Available on PyPI: pip install jaxpm
Built on top of jaxdecomp for distributed 3D FFT
What We’ve Achieved So Far
Built a scalable, differentiable N-body simulation pipeline (JAXPM)
Enables forward modeling and sampling in large cosmological volumes, paving the way toward full LSST-scale inference
Preliminary performance:
For discussion
Using the Scattering Transform to compress the data for SBI
Using explicit inference for CMB r estimation
Bayesian Deep Learning Workshop, 2025