Differentiable and distributed Particle-Mesh n-body simulations

Wassim Kabalan, François Lanusse, Alexandre Boucaud

Traditional cosmological inference


Bayesian inference in cosmology

  • We need to infer the cosmological parameters \(\theta\) that generated an observation \(x\)

\[p(\theta | x ) \propto \underbrace{p(x | \theta)}_{\mathrm{likelihood}} \ \underbrace{p(\theta)}_{\mathrm{prior}}\]


➢  Compute summary statistics based on the 2-point correlation function of the shear field


➢  Run an MCMC chain to recover the posterior distribution of the cosmological parameters, using an analytical likelihood
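
A minimal sketch of that likelihood step, assuming a Gaussian likelihood over measured band powers (the data vector, covariance and model below are placeholders, not a real analysis):

import jax.numpy as jnp

def log_likelihood(theta, data_cl, model_cl_fn, cov):
    # Gaussian likelihood over the band powers of the 2-point function
    r = data_cl - model_cl_fn(theta)
    return -0.5 * r @ jnp.linalg.solve(cov, r)

def log_posterior(theta, data_cl, model_cl_fn, cov, log_prior_fn):
    return log_likelihood(theta, data_cl, model_cl_fn, cov) + log_prior_fn(theta)

# placeholder inputs: 10 band powers, identity covariance, one-parameter model
data_cl = jnp.ones(10)
cov = jnp.eye(10)
model_cl_fn = lambda theta: theta[0] * jnp.linspace(0.5, 1.5, 10)
log_prior_fn = lambda theta: 0.0  # flat prior

lp = log_posterior(jnp.array([1.0]), data_cl, model_cl_fn, cov, log_prior_fn)
# an MCMC sampler repeatedly evaluates log_posterior while exploring theta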


Limitations

  • Simple summary statistics assume Gaussianity
  • An analytical likelihood needs to be computed

Beyond 2-point statistics : Full-field inference


➕  No longer need to compute the likelihood analytically

➖  We need to infer the joint posterior \(p(\theta, z | x)\) and then marginalize over the latent variables to get \(p(\theta | x) = \int p(\theta, z | x) \, dz\)
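
Here \(z\) denotes the latent variables of the forward model (e.g. the initial density field), and the joint posterior factorizes as

\[p(\theta, z | x) \propto \underbrace{p(x | z, \theta)}_{\mathrm{data\ model}} \ \underbrace{p(z | \theta)}_{\mathrm{latent\ prior}} \ \underbrace{p(\theta)}_{\mathrm{prior}}\]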

Possible solutions

  • Hamiltonian Monte Carlo
  • Variational Inference
  • Dimensionality reduction using Fisher Information Matrix

All of these require a fast, differentiable forward model
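
Concretely, all three consume gradients of a log-posterior that contains the simulator, which JAX obtains by automatic differentiation; a minimal sketch with a stand-in forward model (nothing below is a real simulator):

import jax
import jax.numpy as jnp

def forward_model(theta, z):
    # stand-in for a differentiable simulation mapping parameters and
    # latent initial conditions to a mock observation
    return theta[0] * z + theta[1] * z ** 2

def log_posterior(theta, z, x_obs, noise_sigma=0.1):
    log_like = -0.5 * jnp.sum((x_obs - forward_model(theta, z)) ** 2) / noise_sigma ** 2
    log_prior = -0.5 * jnp.sum(theta ** 2) - 0.5 * jnp.sum(z ** 2)
    return log_like + log_prior

key = jax.random.PRNGKey(0)
z_true = jax.random.normal(key, (128,))
x_obs = forward_model(jnp.array([0.3, 0.8]), z_true)

# HMC steps, VI updates and Fisher forecasts all reuse these gradients
grad_theta, grad_z = jax.grad(log_posterior, argnums=(0, 1))(
    jnp.array([0.25, 0.75]), jnp.zeros(128), x_obs)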

➢   We need fast, differentiable and precise cosmological simulations

Fast Particle-mesh simulations



Numerical scheme

➢  Interpolate particles on a grid to estimate mass density

➢  Estimate gravitational force on grid points by FFT

➢  Interpolate forces back on particles

➢  Update particle velocities and positions, and iterate

\[\nabla^{2}\phi = -4\pi G\rho \qquad \Longrightarrow \qquad f(\vec{k}) = i\,\vec{k}\,k^{-2}\,\rho(\vec{k})\]







  • Fast and simple, at the cost of approximating short-range interactions.
  • It is essentially a series of FFTs and interpolations
  • It is differentiable and can run on GPUs
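
A hedged sketch of one force evaluation following this scheme, using nearest-grid-point painting instead of the usual cloud-in-cell interpolation and dropping the \(4\pi G\) normalisation (all names are illustrative, not the JaxPM API):

import jax
import jax.numpy as jnp

def pm_forces(positions, mesh_shape):
    # 1. paint particles onto the grid (nearest grid point for brevity)
    idx = jnp.mod(jnp.round(positions).astype(jnp.int32), jnp.array(mesh_shape))
    rho = jnp.zeros(mesh_shape).at[idx[:, 0], idx[:, 1], idx[:, 2]].add(1.0)

    # 2. density in Fourier space
    rho_k = jnp.fft.fftn(rho)
    kvec = [2 * jnp.pi * jnp.fft.fftfreq(s) for s in mesh_shape]
    kx, ky, kz = jnp.meshgrid(*kvec, indexing='ij')
    k2 = kx ** 2 + ky ** 2 + kz ** 2

    # 3. force kernel f(k) = i k k^-2 rho(k) on each axis, back to real space
    forces = jnp.stack([
        jnp.fft.ifftn(jnp.where(k2 > 0, 1j * k * rho_k / k2, 0.0)).real
        for k in (kx, ky, kz)], axis=-1)

    # 4. interpolate forces back onto the particles (nearest grid point)
    return forces[idx[:, 0], idx[:, 1], idx[:, 2]]

positions = jax.random.uniform(jax.random.PRNGKey(0), (1000, 3)) * 32
forces = pm_forces(positions, (32, 32, 32))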

JAX : Automatic differentiation and Hardware acceleration







import numpy as np


def multiply_and_add(a, b, c):
    return np.dot(a, b) + c


a, b, c = np.random.normal(size=(3, 32, 32))
result = multiply_and_add(a, b, c)


import jax
import jax.numpy as jnp


def multiply_and_add(a, b, c):
    return jnp.dot(a, b) + c


key = jax.random.PRNGKey(0)
a, b, c = jax.random.normal(key, (3, 32, 32))

result = multiply_and_add(a, b, c) 


JAX : Automatic differentiation and Hardware acceleration







import jax
import jax.numpy as jnp

@jax.jit
def multiply_and_add(a, b, c):
    return jnp.dot(a, b) + c


key = jax.random.PRNGKey(0)
a, b, c = jax.random.normal(key, (3, 32, 32))

result = multiply_and_add(a, b, c) 
# jax.grad needs a scalar output, so differentiate a scalar summary of the result
gradient = jax.grad(lambda a, b, c: multiply_and_add(a, b, c).sum())(a, b, c)



JAX : Numpy + Autograd + GPU

  • jax.grad uses automatic differentiation to compute the gradient of the function
  • jax.jit just-in-time compiles the function with XLA to run efficiently on CPUs, GPUs and TPUs

JaxPM : A differentiable Particle-Mesh simulation

FastPM simulation in a few lines of code

GitHub : DifferentiableUniverseInitiative/JaxPM

import jax
import jax.numpy as jnp
import jax_cosmo as jc
from diffrax import ODETerm, Dopri5, PIDController, SaveAt, diffeqsolve
from jaxpm.pm import linear_field, lpt, make_ode_fn
from jaxpm.kernels import fftk

mesh_shape = [64, 64, 64]
box_size = [64., 64., 64.]
kvec = fftk(mesh_shape)  # Fourier-space wave vectors used by the force kernel
snapshots = jnp.linspace(0.1, 1., 2)

@jax.jit
def run_simulation(omega_c, sigma8):
    # Create a small function to generate the matter power spectrum
    k = jnp.logspace(-4, 1, 128)
    pk = jc.power.linear_matter_power(
        jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
    pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk
                                                  ).reshape(x.shape)

    # Create initial conditions
    initial_conditions = linear_field(mesh_shape,
                                      box_size,
                                      pk_fn,
                                      seed=jax.random.PRNGKey(0))

    # Create particles
    particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),
                          axis=-1).reshape([-1, 3])

    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)

    # Initial displacement
    dx, p, f = lpt(cosmo, initial_conditions, particles, 0.1)
    field = dx + particles

    # Evolve the simulation forward
    ode_fn = make_ode_fn(mesh_shape)
    term = ODETerm(lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
    solver = Dopri5()

    stepsize_controller = PIDController(rtol=1e-7, atol=1e-7)
    res = diffeqsolve(term, solver, t0=0.1, t1=1., dt0=0.01, y0=jnp.stack([dx, p],axis=0), 
                       args=(cosmo , kvec),
                       stepsize_controller=stepsize_controller,
                       saveat=SaveAt(ts=snapshots))

    # Return the simulation volume at the requested snapshots
    return field, res, initial_conditions

field, res, initial_conditions = run_simulation(0.25, 0.8)
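
Because the whole pipeline is written in JAX, gradients with respect to the cosmological parameters flow straight through the ODE solver and the PM steps; a small sketch (the mean-square summary is an arbitrary stand-in for a real statistic):

# differentiate a scalar summary of the last saved snapshot w.r.t. (Omega_c, sigma8)
def summary(omega_c, sigma8):
    field, res, _ = run_simulation(omega_c, sigma8)
    return jnp.mean(res.ys[-1] ** 2)

grad_omega_c, grad_sigma8 = jax.grad(summary, argnums=(0, 1))(0.25, 0.8)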

Is everything solved ?

Fast Particle-mesh scaling

Current FastPM implementations

➢  (Porqueres et al. 2021) : \(64^3\) mesh size, on a 1000 Mpc/h box

➢  (Li et al. 2022) : \(512^3\) mesh size, using pmwd

[Figures: initial conditions and final density fields at mesh resolutions from \(64^3\) to \(1024^3\), with a power spectrum comparison]

Scaling

We need a fast, differentiable and scalable Particle-Mesh simulation

Scaling on modern hardware

Size of a FastPM simulation

Scaling on multiple GPUs

Single GPU (80GB)

Single Node (8x80GB)

Multi Node ( \(\infty\) )

Distributed Fast Fourier Transform

➢  The only operation that requires communication is the FFT


jaxDecomp

import jax
import jax.numpy as jnp

field = jax.random.normal(jax.random.PRNGKey(0), (1024, 1024, 1024))
k_field = jnp.fft.fftn(field)

Distributed Fast Fourier Transform

➢  The only operation that requires communication is the FFT


jaxDecomp

GitHub : DifferentiableUniverseInitiative/jaxDecomp

import jax
import jax.numpy as jnp
import jaxdecomp
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec as P

mesh_shape = (1024, 1024, 1024)  # global shape of the distributed field
rank = jax.process_index()       # assuming one process per GPU

devices = mesh_utils.create_device_mesh((2, 2))
mesh = jax.sharding.Mesh(devices, axis_names=('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))

# Create a Gaussian field distributed across the device mesh
field = jax.make_array_from_single_device_arrays(
    shape=mesh_shape,
    sharding=sharding,
    arrays=[jax.random.normal(jax.random.PRNGKey(rank), (512, 512, 1024))])

k_field = jaxdecomp.fft.pfft3d(field)

JaxDecomp features

➢  jaxDecomp supports 2D and 1D decompositions

➢  It works for multi-node FFTs

➢  It is differentiable (see the sketch after this list)

➢  The package is also provided as a standalone library
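
Continuing the example above, a small check of the differentiability claim: gradients can be taken straight through the distributed transform (the quadratic loss is just an arbitrary scalar summary):

# gradients propagate through the distributed FFT like any other JAX op
def spectrum_power(x):
    return jnp.sum(jnp.abs(jaxdecomp.fft.pfft3d(x)) ** 2)

grad_field = jax.grad(spectrum_power)(field)  # same global shape as `field`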

Scaling of Distributed FFT operations

Halo exchange in distributed simulations

[Figures: the initial field split into four slices across devices, and the resulting LPT field]

from jaxdecomp import halo_exchange

halo_size = 128
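# exchange the overlapping boundary regions with neighbouring devices so that
# particles near a slab edge are painted and interpolated consistently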
field = halo_exchange(field, halo_extent=halo_size)

Distributed JaxPM Particle-Mesh simulations

Distributed FastPM simulations

➢  Multi-host version of JaxPM using jaxDecomp

➢  A \(2048^3\) LPT simulation on 16 GPUs runs in under a second on the Jean Zay supercomputer

➢  We aim to run an \(8192^3\) Particle-Mesh simulation on 160 GPUs

Conclusion





Distributed Particle-Mesh simulations for cosmological inference

  • A shift from analytical likelihoods to full-field inference
    • The need for fast, differentiable simulators
    • Particle-Mesh simulations as forward models for full-field inference
    • Distributed Fourier transforms that work on multi-node HPC systems, using jaxDecomp
    • Highly scalable LPT simulations using JaxPM
  • Still subject to some challenges
    • Some issues with the ODE solving step
    • So far only the Euler solver gives decent results