Wassim Kabalan, François Lanusse, Alexandre Boucaud
\[p(\theta | x ) \propto \underbrace{p(x | \theta)}_{\mathrm{likelihood}} \ \underbrace{p(\theta)}_{\mathrm{prior}}\]
➢ Compute summary statistics based on the 2-point correlation function of the shear field
➢ Run an MCMC chain to recover the posterior distribution of the cosmological parameters, using an analytical likelihood (a minimal sketch follows)
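For reference, this classic pipeline can be written in a few lines with jax-cosmo; a minimal sketch, assuming the data vector and its inverse covariance are given (the redshift bin and binning choices here are illustrative, not the exact pipeline used):

import jax.numpy as jnp
import jax_cosmo as jc

# One toy source redshift bin and a weak-lensing probe
ell = jnp.logspace(1, 3, 50)
nz = jc.redshift.smail_nz(1., 2., 1.)
probes = [jc.probes.WeakLensing([nz])]

def log_likelihood(theta, data, inv_cov):
    # Gaussian analytical likelihood of the 2-point data vector
    cosmo = jc.Planck15(Omega_c=theta[0], sigma8=theta[1])
    model = jc.angular_cl(cosmo, ell, probes).flatten()
    r = data - model
    return -0.5 * r @ inv_cov @ r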
Limitations
➕ No longer need to compute the likelihood analytically
➖ We need to infer the joint posterior \(p(\theta, z | x)\) over the latent variables \(z\), and only then marginalize to get \(p(\theta | x) = \int p(\theta, z | x) \, dz\)
Possible solutions
All require a fast, differentiable forward model (see the sketch below)
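For instance, the joint posterior can be sampled with a gradient-based sampler such as NUTS, which is exactly why the forward model must be differentiable. A minimal sketch with NumPyro, where forward_model is a hypothetical differentiable simulator and x_obs stands in for the observed data:

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model(x_obs):
    # Cosmological parameter(s) theta and latent initial conditions z
    omega_c = numpyro.sample('omega_c', dist.Uniform(0.1, 0.5))
    z = numpyro.sample('z', dist.Normal(jnp.zeros(16**3), 1.))
    x = forward_model(omega_c, z)  # hypothetical differentiable simulator
    numpyro.sample('x', dist.Normal(x, 0.1), obs=x_obs)

# NUTS follows gradients of the joint log-density p(theta, z | x)
mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=500)
mcmc.run(jax.random.PRNGKey(0), x_obs)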
➢ We need fast, differentiable, and precise cosmological simulations
➢ Interpolate particles on a grid to estimate mass density
➢ Estimate gravitational force on grid points by FFT
➢ Interpolate forces back on particles
➢ Update particle velocities and positions, and iterate (a minimal sketch follows the equations below)
\[\nabla^{2}\phi = -4\pi G\rho \qquad\qquad f(\vec{k}) = i\vec{k}\,k^{-2}\,\rho(\vec{k})\]
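The four steps above condense into a single force evaluation; a minimal sketch in JAX, using nearest-grid-point painting instead of the higher-order interpolation a production PM code would use (constants dropped, all names illustrative):

import jax.numpy as jnp

def pm_forces(positions, mesh_shape):
    # 1. Paint particles onto the grid (NGP: each particle adds 1 to its cell)
    idx = jnp.mod(jnp.round(positions).astype(jnp.int32), jnp.array(mesh_shape))
    rho = jnp.zeros(mesh_shape).at[idx[:, 0], idx[:, 1], idx[:, 2]].add(1.0)

    # 2. Solve the Poisson equation in Fourier space and apply the i k / k^2 kernel
    kvec = [2 * jnp.pi * jnp.fft.fftfreq(s) for s in mesh_shape]
    kx, ky, kz = jnp.meshgrid(*kvec, indexing='ij')
    k2 = kx**2 + ky**2 + kz**2
    inv_k2 = jnp.where(k2 > 0, 1. / k2, 0.)  # zero out the k = 0 mode
    rho_k = jnp.fft.fftn(rho)
    forces_k = [1j * k * inv_k2 * rho_k for k in (kx, ky, kz)]

    # 3. Transform back and read the force at each particle's cell
    return jnp.stack([jnp.fft.ifftn(fk).real[idx[:, 0], idx[:, 1], idx[:, 2]]
                      for fk in forces_k], axis=-1)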
import numpy as np

def multiply_and_add(a, b, c):
    return np.dot(a, b) + c

a, b, c = np.random.normal(size=(3, 32, 32))
result = multiply_and_add(a, b, c)
import jax
import jax.numpy as jnp

@jax.jit
def multiply_and_add(a, b, c):
    return jnp.dot(a, b) + c

key = jax.random.PRNGKey(0)
a, b, c = jax.random.normal(key, (3, 32, 32))
result = multiply_and_add(a, b, c)
# jax.grad needs a scalar output, so reduce before differentiating
gradient = jax.grad(lambda a, b, c: multiply_and_add(a, b, c).sum())(a, b, c)
JAX: NumPy + Autograd + GPU
DifferentiableUniverseInitiative/JaxPM
# Imports assumed by this snippet: jax_cosmo, the JaxPM helpers, and the diffrax ODE solver
import jax_cosmo as jc
from jaxpm.pm import linear_field, lpt, make_ode_fn
from jaxpm.kernels import fftk
from diffrax import ODETerm, Dopri5, PIDController, SaveAt, diffeqsolve

mesh_shape = [64, 64, 64]
box_size = [64., 64., 64.]
kvec = fftk(mesh_shape)  # assumption: the Fourier frequency vectors used by the PM kernels
snapshots = jnp.linspace(0.1, 1., 2)

@jax.jit
def run_simulation(omega_c, sigma8):
    # Create a small function to generate the matter power spectrum
    k = jnp.logspace(-4, 1, 128)
    pk = jc.power.linear_matter_power(
        jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
    pk_fn = lambda x: jc.scipy.interpolate.interp(
        x.reshape([-1]), k, pk).reshape(x.shape)

    # Create initial conditions
    initial_conditions = linear_field(mesh_shape,
                                      box_size,
                                      pk_fn,
                                      seed=jax.random.PRNGKey(0))

    # Create particles
    particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),
                          axis=-1).reshape([-1, 3])

    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)

    # Initial displacement at scale factor a = 0.1
    dx, p, f = lpt(cosmo, initial_conditions, particles, 0.1)
    field = dx + particles

    # Evolve the simulation forward
    ode_fn = make_ode_fn(mesh_shape)
    term = ODETerm(lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
    solver = Dopri5()
    stepsize_controller = PIDController(rtol=1e-7, atol=1e-7)
    res = diffeqsolve(term, solver, t0=0.1, t1=1., dt0=0.01,
                      y0=jnp.stack([dx, p], axis=0),
                      args=(cosmo, kvec),
                      stepsize_controller=stepsize_controller,
                      saveat=SaveAt(ts=snapshots))

    # Return the simulation volume at the requested snapshots
    return field, res, initial_conditions

field, res, initial_conditions = run_simulation(0.25, 0.8)
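Because the whole simulation is a JAX function, gradients with respect to the cosmological parameters flow straight through the ODE solver; a minimal sketch (the scalar loss below is purely illustrative):

def loss(omega_c, sigma8):
    # res.ys stacks [displacements, momenta] at each requested snapshot
    _, res, _ = run_simulation(omega_c, sigma8)
    return jnp.mean(res.ys[-1] ** 2)

grad_omega_c = jax.grad(loss)(0.25, 0.8)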
Is everything solved?
➢ Porqueres et al. (2021): \(64^3\) mesh, on a 1000 Mpc/h box
➢ Li et al. (2022): \(512^3\) mesh, using pmwd
Scaling
We need a fast, differentiable, and scalable Particle-Mesh simulation
➢ The only operation that requires communication is the FFT
DifferentiableUniverseInitiative/jaxDecomp
import jax
import jax.numpy as jnp
import jaxdecomp
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec as P

# Global shape of the distributed field: the 2 x 2 device mesh below holds
# one 512 x 512 x 1024 shard per process (one GPU per process assumed)
mesh_shape = (1024, 1024, 1024)
rank = jax.process_index()

devices = mesh_utils.create_device_mesh((2, 2))
mesh = jax.sharding.Mesh(devices, axis_names=('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))

# Create gaussian field distributed across the mesh
field = jax.make_array_from_single_device_arrays(
    shape=mesh_shape,
    sharding=sharding,
    arrays=[jax.random.normal(jax.random.PRNGKey(rank), (512, 512, 1024))])

k_field = jaxdecomp.fft.pfft3d(field)
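The distributed transform can also be inverted while keeping the same sharded layout; a quick round-trip sketch, assuming the pifft3d inverse exposed by the package:

# Inverse distributed FFT: recover the real-space field from k_field
rec_field = jaxdecomp.fft.pifft3d(k_field).real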
jaxDecomp features
➢ jaxDecomp supports 2D and 1D decompositions
➢ Works for multi-node FFTs
➢ Is differentiable
➢ The package is also provided as a standalone library
Distributed FastPM simulations
➢ Multi-host version of JaxPM using jaxDecomp
➢ A \(2048^3\) LPT simulation on 16 GPUs runs in under a second on the Jean-Zay supercomputer
➢ We aim to run an \(8192^3\) Particle-Mesh simulation on 160 GPUs
Distributed Particle-Mesh simulations for cosmological inference
LSST France, 2024