Massively Parallel Computing in Cosmology with JAX

Wassim Kabalan

François Lanusse, Alexandre Boucaud, Josquin Errard

Goals for This Presentation

  • Understand the Basics of Parallelism: Learn how parallelism works and its importance for high-performance computing.

  • Know When (and When Not) to Parallelize: Discover when parallelizing your code pays off, when it is better avoided, and what the benefits and limitations are.

  • Scale Code Using JAX: Explore techniques to scale your computations using JAX for large-scale tasks.

  • Hands-On Tutorials: Apply the concepts discussed with interactive code examples and tutorials.

Background on Parallel Computing with GPUs

How GPUs Work


Massive Thread Count

  • GPUs are designed to run thousands of threads concurrently.
  • Each core can handle many data elements simultaneously.


The main bottleneck is memory throughput

  • Computation is often only a fraction of total processing time.
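
A quick way to see the memory bottleneck in practice: the rough micro-benchmark below (a sketch; the array size and the two kernels are illustrative) times a memory-bound elementwise kernel against a compute-bound matrix multiply on whatever accelerator JAX finds.

import time
import jax
import jax.numpy as jnp

# An elementwise op does only a few FLOPs per element it reads and writes,
# so it is limited by memory bandwidth; a matmul reuses each element many
# times, so it is limited by compute.
x = jnp.ones((4096, 4096), dtype=jnp.float32)

elementwise = jax.jit(lambda a: a * 2.0 + 1.0)   # memory-bound
matmul = jax.jit(lambda a: a @ a)                # compute-bound

for name, fn in [("elementwise", elementwise), ("matmul", matmul)]:
    fn(x).block_until_ready()                    # compile and warm up
    t0 = time.perf_counter()
    fn(x).block_until_ready()
    print(f"{name}: {time.perf_counter() - t0:.4f} s")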

Optimizing Throughput with Multiple GPUs:

  • Using multiple GPUs increases overall data throughput, enhancing performance and reducing idle time.

GPU threads

Single GPU throughput

Saturated GPU

Multi-GPU throughput

Types of Parallelism

Data Parallelism

  • Simple Parallelism: Each device processes a different subset of data independently.
  • Data Parallelism with Collective Communication:
    • Devices process data in parallel but periodically share results (e.g., for gradient averaging in training).

Task Parallelism

  • Each device handles a different part of the computation.
  • The computation itself is divided between devices.
  • Generally more complex than data parallelism (a minimal sketch follows below).
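
Since task parallelism is not revisited later in these slides, here is a minimal, illustrative sketch (the two-stage split and function names are hypothetical) of pinning two pipeline stages to two different devices with jax.device_put; it assumes at least two devices are visible.

import jax
import jax.numpy as jnp

devices = jax.devices()

def stage_a(x):
    return jnp.sin(x) ** 2        # first part of the computation

def stage_b(y):
    return jnp.cumsum(y)          # second part of the computation

# Commit the input to device 0; jit then runs stage_a there.
x = jax.device_put(jnp.linspace(0.0, 1.0, 1024), devices[0])
y = jax.jit(stage_a)(x)

# Explicitly move the intermediate result to the next device for stage_b.
y = jax.device_put(y, devices[1 % len(devices)])
z = jax.jit(stage_b)(y)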

Simple Data Parallelism

Data Parallelism with Communication

Task Parallelism

When Should You Use Parallelism?


Simple cases

  • Data Parallelism (Simple)
    • If your pipeline resembles simple data parallelism, then parallelism is a good idea.
  • Data Parallelism with Simple Collectives
    • Simple collectives (e.g., gradient averaging) can be easily expressed in JAX, allowing devices to share intermediate results.

Complex cases

  • Non-splittable Input (e.g., N-body Simulation Fields) ⚠️
    • When the input is not easily batchable, like a field in an N-body simulation.
  • Task Parallelism ⚠️

    • Useful for long sequential cosmological pipelines where each device handles a unique task in the sequence.
    • More common in training complex models (e.g., LLMs like Gemini or ChatGPT).

When NOT to Use Parallelism

To Keep in Mind

  • Data Fits on a Single GPU
    • If everything already fits and runs fast enough on one device, extra GPUs mostly add communication overhead.
  • Need for Complex Collectives
    • Additional GPUs can add complexity and may not yield enough performance improvement.
  • Task Parallel Model
    • Changing the pipeline or adapting to new devices often requires significant rewrites.

Underutilized GPU

Consider scaling to multiple GPUs if:

  • You have a working single-GPU prototype whose runtime is a real bottleneck, and reducing it would have a significant impact on your results.
    • Using multiple GPUs can significantly decrease execution time.
  • OR you have non-splittable input (e.g., fields in a cosmological simulation) that is crucial for your results.

How to Measure Scaling for Parallel Codes


Strong Scaling

  • Keep the total problem size fixed and add GPUs to reduce the runtime.

Assesses how runtime drops as more GPUs are applied to a fixed dataset. Danger Zone⚠️: the region where the speedup flattens, indicating the distributed code is not scaling efficiently.

Weak Scaling

  • Grow the data size together with the number of GPUs, keeping the workload per GPU fixed.

Tests whether runtime stays roughly constant as the problem and the machine grow together. Danger Zone⚠️: the region where runtime climbs, suggesting underlying scaling issues (e.g., communication overhead) in the code itself.
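
For reference, the standard way to quantify both regimes (textbook definitions, not specific to JAX): writing \(T(N)\) for the runtime on \(N\) GPUs,

\[
S(N) = \frac{T(1)}{T(N)}, \qquad
E_{\mathrm{strong}}(N) = \frac{T(1)}{N\,T(N)}, \qquad
E_{\mathrm{weak}}(N) = \frac{T(1)}{T(N)} \quad \text{(per-GPU problem size held fixed)},
\]

with efficiencies close to 1 indicating good scaling and a steady drop marking the danger zone.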

Environmental Impact of High-Performance Computing

Perlmutter Supercomputer (NERSC)

  • Location: NERSC, Berkeley Lab, California, USA
  • Compute Power: ~170 PFlops
  • GPUs: 7,208 NVIDIA A100 GPUs
  • Power Draw: ~ 3-4 MW


Jean Zay Supercomputer (IDRIS)

  • Location: IDRIS, France
  • Compute Power: ~126 PFlops (FP64), 2.88 EFlops (BF/FP16)
  • GPUs: 3,704 GPUs, including V100, A100, and H100
  • Power Draw: ~1.4 MW on average (as of September, without full H100 usage), leveraging France’s renewable energy grid.

Perlmutter Supercomputer

Jean Zay Supercomputer

Environmental Benefits of Efficient Parallel Computing


  • Higher throughput moves computations closer to peak FLOPS.
  • Operating near peak FLOPS ensures more effective use of computational resources.
  • More computations are achieved per unit of energy, improving energy efficiency.
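
As a rough illustration using the Perlmutter figures quoted above (peak numbers, so this is a lower bound on the energy per useful operation), taking roughly 3.5 MW and 170 PFlops:

\[
\frac{\sim 3.5\ \text{MW}}{\sim 170\ \text{PFlops}} \approx 20\ \text{pJ per floating-point operation at peak.}
\]

Real workloads reach only a fraction of peak, so the same megawatts buy fewer useful operations; pushing throughput closer to peak therefore directly reduces the energy spent per scientific result.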


How to Scale in JAX

Why JAX for Distributed Computing?

  • Distributed Computing Isn’t New:
    • Tools like MPI and OpenMP are used extensively.
    • ML frameworks like TensorFlow and PyTorch offer distributed training.
    • So do libraries such as DiffEqFlux.jl, Horovod, and Ray.
  • Familiar and Accessible API:
    • JAX offers a NumPy-like API that is both accessible and intuitive.
    • Python users can leverage parallelism without needing in-depth knowledge of low-level parallel frameworks like MPI.

Key Points

  • Pythonic Scalability: JAX lets you write scalable, Pythonic code that is compiled by XLA for performance.
  • Automatic Differentiation: JAX offers a trivial way to write differentiable distributed code.
  • The same code runs on anything from a laptop to a multi-node supercomputer (see the sketch below).
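
A minimal sketch of those last two points (the toy loss function is illustrative): the same jitted, differentiable code runs unchanged whether x lives on a CPU, a single GPU, or is sharded across many devices.

import jax
import jax.numpy as jnp

def loss(theta, x):
    return jnp.mean((jnp.sin(theta * x) - x) ** 2)

# Compile and differentiate with two composable transformations.
grad_loss = jax.jit(jax.grad(loss))

x = jnp.linspace(-1.0, 1.0, 1024)   # could equally be a sharded array
print(grad_loss(0.3, x))            # identical call on a laptop or a cluster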

Expressing Parallelism in JAX (Simple parallelism)

Example: computing a Gaussian from data points

import jax
import jax.numpy as jnp
from jax.debug import visualize_array_sharding


def gaussian(x, mean, variance):
  coefficient = 1.0 / jnp.sqrt(2 * jnp.pi * variance)
  exponent = -((x - mean) ** 2) / (2 * variance)
  return coefficient * jnp.exp(exponent)

mean = 0.0
variance = 1.0
x = jnp.linspace(-5, 5, 128)
result = gaussian(x, mean, variance)
visualize_array_sharding(x)
visualize_array_sharding(result)


[Output of visualize_array_sharding: both x and result live entirely on GPU 0]

Expressing Parallelism in JAX (Simple parallelism)

Example: computing a Gaussian from data points

assert jax.device_count() == 8

from jax.sharding import PartitionSpec as P, NamedSharding


def gaussian(x, mean, variance):
  coefficient = 1.0 / jnp.sqrt(2 * jnp.pi * variance)
  exponent = -((x - mean) ** 2) / (2 * variance)
  return coefficient * jnp.exp(exponent)

mesh = jax.make_mesh((8,), ('x',))
sharding = NamedSharding(mesh, P('x'))

mean = 0.0
variance = 1.0
x = jnp.linspace(-5, 5, 128)
x = jax.device_put(x, sharding)
result = gaussian(x, mean, variance)
visualize_array_sharding(x)
visualize_array_sharding(result)


[Output of visualize_array_sharding: x and result are each split evenly across GPUs 0-7]
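
A small follow-up check (a sketch reusing the names from the snippet above): wrapping gaussian in jax.jit requires no code change; XLA compiles one SPMD program in which each GPU computes only its own shard, and the output keeps the input's sharding.

jitted_gaussian = jax.jit(gaussian)
result = jitted_gaussian(x, mean, variance)   # x is sharded over the 8 GPUs
print(result.sharding)                        # same NamedSharding, spec P('x')
visualize_array_sharding(result)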
                                                                        

Expressing Parallelism in JAX (Using collectives)

Example of SGD with Gradient averaging (from Jean-Eric’s tutorial)



@jax.jit  
def gradient_descent_step(p, xi, yi, lr=0.1):
  gradients = jax.grad(loss_fun)(p, xi, yi)
  return p - lr * gradients

def minimzer(loss_fun, x_data, y_data, par_init, method, verbose=True):
  ...
# Example usage
par_mini_GD = minimzer(
  loss_fun, 
  x_data=xin, 
  y_data=yin, 
  par_init=jnp.array([0., 0.5]), 
  method=partial(gradient_descent_step, lr=0.5), 
  verbose=True
)

Expressing Parallelism in JAX (Using collectives)

Example of SGD with Gradient averaging (from Jean-Eric’s tutorial)

from jax.experimental.shard_map import shard_map

@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=(P(), P('x'), P('x')),   # parameters replicated, data sharded over 'x'
         out_specs=P())                    # the updated parameters are replicated
def gradient_descent_step(p, xi, yi, lr=0.1):
    per_device_gradients = jax.grad(loss_fun)(p, xi, yi)                 # gradient on the local shard
    avg_gradients = jax.lax.pmean(per_device_gradients, axis_name='x')   # average over all devices
    return p - lr * avg_gradients

def minimzer(loss_fun, x_data, y_data, par_init, method, verbose=True):
    ...

# Example usage: shard the training data before calling the minimizer
xin = jax.device_put(xin, sharding)
yin = jax.device_put(yin, sharding)
par_mini_GD = minimzer(
    loss_fun,
    x_data=xin,
    y_data=yin,
    par_init=jnp.array([0., 0.5]),
    method=partial(gradient_descent_step, lr=0.5),
    verbose=True
)

JAX Collective Operations for Parallel Computing

Overview of the collective operations in jax.lax

  • lax.pmean: computes the mean of arrays across devices. Useful for averaging gradients in distributed training.
  • lax.ppermute: permutes data across devices in a specified order. Very useful in cosmological simulations (e.g., halo exchange).
  • lax.all_to_all: exchanges data between devices in a controlled manner. Useful for custom data-exchange patterns in distributed computing.
  • lax.pmax / lax.pmin: computes the element-wise maximum/minimum across devices, e.g., to find the global extremum of a distributed dataset.
  • lax.psum: sums arrays across devices. Commonly used for aggregating gradients or other values in distributed settings.
  • lax.all_gather: gathers the shards from all devices so that every device ends up with the full array.
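
As a concrete illustration of lax.ppermute (a toy sketch, not the jaxDecomp implementation), the snippet below performs a 1D "halo exchange": inside shard_map, every device sends the last element of its shard to its right-hand neighbour in a ring. It assumes the global array length is divisible by the number of devices and is best run on several devices.

import jax
import jax.numpy as jnp
from functools import partial
from jax import lax
from jax.sharding import PartitionSpec as P, NamedSharding
from jax.experimental.shard_map import shard_map

n_dev = jax.device_count()
mesh = jax.make_mesh((n_dev,), ('x',))

@partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
def receive_left_halo(field):
    # Ring permutation: device i sends to device (i + 1) % n_dev.
    halo = lax.ppermute(field[-1:], axis_name='x',
                        perm=[(i, (i + 1) % n_dev) for i in range(n_dev)])
    # Prepend the neighbour's edge and drop our last element to keep the shard size.
    return jnp.concatenate([halo, field[:-1]])

field = jax.device_put(jnp.arange(16.0), NamedSharding(mesh, P('x')))
shifted = jax.jit(receive_left_halo)(field)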

Towards Infinite Scalability with JAX

A Node vs a Supercomputer

Differences in Scale

  • Single GPU:
    • Maximum memory: 80 GB
  • Single Node (Octocore):
    • Maximum memory: 640 GB
    • Contains multiple GPUs (e.g., 8 A100 GPUs) connected via high-speed interconnects.
  • Multi-Node Cluster:
    • Effectively unlimited memory 🎉 (it grows with the number of nodes)
    • Connects multiple nodes, allowing scaling across potentially thousands of GPUs.

Multi-Node scalability with Jean Zay

  • Up to 30TB of memory using all 48 nodes of Jean Zay
  • Enough to run a 15-billion-particle simulation.
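
A quick consistency check using the per-node figures quoted above (8 GPUs with 80 GB each, i.e., 640 GB per node):

\[
48 \ \text{nodes} \times 8 \ \text{GPUs} \times 80\ \text{GB} \approx 30\ \text{TB}.
\]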

@credit: NVIDIA

@credit: servethehome.com

Scaling JAX on a Single GPU vs. Multi-Host Setup

Single GPU Code

x = jnp.linspace(-5, 5, 128)
mean = 0.0
variance = 1.0
result = gaussian(x, mean, variance)


Multi-GPU Code

mesh = jax.make_mesh((8,), ('x',))
sharding = NamedSharding(mesh, P('x'))
x = jnp.linspace(-5, 5, 128)
x = jax.device_put(x, sharding)
mean = 0.0
variance = 1.0
result = gaussian(x, mean, variance)


Multi-Host Code

Scaling JAX on a Single GPU vs. Multi-Host Setup

A single JAX process driving all the GPUs of a node


Requesting a slurm job

$ salloc --gres=gpu:8 --ntasks-per-node=1 --nodes=1


multi-host-jax.py
import jax

mesh = jax.make_mesh((8,), ('x',))
sharding = NamedSharding(mesh, P('x'))

def gaussian(x, mean, variance):
    ...
mean = 0.0
variance = 1.0
x = jnp.linspace(-5, 5, 128)
x = jax.device_put(x, sharding)
result = gaussian(x, mean, variance)
visualize_array_sharding(x)
visualize_array_sharding(result)


Running with srun

$ srun python multi-host-jax.py

Scaling JAX on a Single GPU vs. Multi-Host Setup

A JAX process per GPU


Requesting a slurm job

$ salloc --gres=gpu:8 --ntasks-per-node=8 --nodes=2


multi-host-jax.py
import jax
jax.distributed.initialize()
mesh = jax.make_mesh((16,), ('x',))
sharding = NamedSharding(mesh, P('x'))

def gaussian(x, mean, variance):
    ...
mean = 0.0
variance = 1.0
x = jnp.linspace(-5, 5, 128)
x = jax.device_put(x, sharding) ❌ # DOES NOT WORK
result = gaussian(x, mean, variance)
visualize_array_sharding(x)
visualize_array_sharding(result)


Running with srun

$ srun -n 16 python multi-host-jax.py


CAUTION ⚠️

  • jax.device_put does not work in multi-host setups: a process cannot place data on devices that belong to other hosts.
  • Creating a jax.numpy array (e.g., with jnp.linspace) does not behave as it does on a single node: each process gets its own local copy instead of one globally sharded array.

Loading Data in JAX in a Multi-Host Setup

A JAX process per GPU


import jax
jax.distributed.initialize()

assert jax.device_count() == 16

x = jnp.linspace(-5, 5, 128)
visualize_array_sharding(x)


[Output of visualize_array_sharding from the 16 processes: each process reports a full copy of x on its own single GPU (GPU 0, GPU 1, …, GPU 15) rather than one array sharded across all devices]

Loading Data in JAX in a Multi-Host Setup

A JAX process per GPU


multi-host-jax.py
import jax
import jax.numpy as jnp
import numpy as np
jax.distributed.initialize()

mesh = jax.make_mesh((16,), ('x',))
sharding = NamedSharding(mesh, P('x'))

def distributed_linspace(start, stop, num):
    # Each process builds only the shard(s) it owns: `indx` is the slice of
    # the global array assigned to one of its local devices.
    def local_linspace(indx):
        return np.linspace(start, stop, num)[indx]
    return jax.make_array_from_callback(shape=(num,), sharding=sharding,
                                        data_callback=local_linspace)

x = distributed_linspace(-5, 5, 128)
if jax.process_index() == 0:
  visualize_array_sharding(x)


[Output of visualize_array_sharding: x is split evenly across GPUs 0-15]

Loading Data in JAX in a Multi-Host Setup

A JAX process per GPU


multi-host-jax.py
import jax
import jax.numpy as jnp
import numpy as np
jax.distributed.initialize()

mesh = jax.make_mesh((16,), ('x',))
sharding = NamedSharding(mesh, P('x'))

def distributed_linspace(start, stop, num):
    # Each process builds only the shard(s) it owns: `indx` is the slice of
    # the global array assigned to one of its local devices.
    def local_linspace(indx):
        return np.linspace(start, stop, num)[indx]
    return jax.make_array_from_callback(shape=(num,), sharding=sharding,
                                        data_callback=local_linspace)

x = distributed_linspace(-5, 5, 128)
if jax.process_index() == 0:
  visualize_array_sharding(x)
mean = 0.0
variance = 1.0
result = gaussian(x, mean, variance)
if jax.process_index() == 0:
  visualize_array_sharding(result)


[Output of visualize_array_sharding: x and result are each split evenly across GPUs 0-15]
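
If a (small) sharded result has to be inspected or saved from one place, jax.experimental.multihost_utils can gather it back to host memory; a minimal sketch reusing the names above (only sensible for arrays that fit on a single host):

from jax.experimental import multihost_utils

# Returns host-local NumPy data gathered from all processes; the exact
# stacking/concatenation behaviour is controlled by its `tiled` argument.
full_result = multihost_utils.process_allgather(result)
if jax.process_index() == 0:
    print(full_result.shape)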
                                                                                

Multi-node packages in JAX for Cosmology

Forward Modeling in Cosmology

Weak Lensing Model

  • Prediction:
    • A simulator generates observations from initial conditions and cosmological parameters.
  • Inference:
    • The simulated results are compared with actual observations.
    • Optimal initial conditions and parameters are inferred to closely match the observed data.

Scaling Challenges

  • Software: differentiable simulators such as JaxPM and PMWD already exist.
  • Resolution today: these simulators currently scale to about 130 million particles \(512^3\).
  • Ideal resolution: billion-particle simulations \(1024^3\) and above are needed for high accuracy.
  • (See Hugo’s and Justine’s talks for more details.)
  • We therefore need to scale to multiple GPUs and nodes to reach the required resolution; the rough memory estimate below shows why a single GPU is not enough.
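
A rough float32 estimate (order of magnitude only) makes the gap concrete:

\[
1024^3 \times 4\ \text{B} \approx 4.3\ \text{GB per mesh field}, \qquad
1024^3 \times 3 \times 4\ \text{B} \approx 12.9\ \text{GB for particle positions alone},
\]

and a particle-mesh step keeps several such arrays alive at once (positions, velocities, density, potential, forces, FFT buffers), so the working set quickly exceeds the 80 GB of a single GPU.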

Forward Modeling (Prediction)

Forward Modeling (Inference)

jaxDecomp: Components for Distributed Particle-Mesh Simulations

Key Features

  • Distributed 3D FFT
    • Essential for force calculations in large-scale simulations.
  • Halo Exchange for Boundary Conditions
    • Manages boundary conditions or particles leaving the simulation domain.
  • Fully Differentiable
    • Can be used with differentiable simulations.
  • Multi-Node Support
    • Works seamlessly across multiple nodes.
  • Supports different sharding strategies

  • Open-source and available on PyPI

Performance benchmarks of PFFT3D



Strong Scaling

Weak Scaling

Halo exchange in distributed simulations

[Figures: the initial field, its four distributed slices before and after the halo exchange, and the resulting LPT field]

JaxPM 2.0: Distributed Particle-Mesh Simulation


Box size: 1 Gpc/h
Resolution: \(1024^3\)
Number of particles: 1 billion
Number of snapshots: 10
Halo size: 128
Number of GPUs used: 32
Time taken: 45 s







Key Features of JaxPM

  • Multi-Node Performance: Optimized for efficient scaling across nodes.
  • High Resolution: Capable of handling billions of particles for accurate simulations.
  • Differentiable: Compatible with JAX’s automatic differentiation (HMC, NUTS compatible).
  • Open Source: available on GitHub.

Conclusion

Conclusion: Enabling Scalable Cosmology with Distributed JAX

Distributed JAX: A Game-Changer for Cosmology

  • The future is bright for JAX in cosmology 🎉🎉!!

  • JAX has transformed the landscape for scientific computing, enabling large-scale, distributed workflows in a Pythonic environment.

  • Recent advancements (JAX 0.4.3x+) make it straightforward to scale computations across multiple GPUs and nodes.

  • Key Advantages

    • Simplicity: JAX makes it easier than ever to write high-performance code, allowing researchers to focus on science rather than infrastructure.
    • Differentiability: JAX allows seamless differentiation of code running across hundreds of GPUs, enabling advanced inference techniques.
  • The Future Ahead

    • Scaling Inference Models with Distributed jaxPM: By integrating the new distributed jaxPM into existing cosmological inference models, we can achieve unprecedented levels of detail and complexity.
    • Paving the way to fully leverage large-scale survey data for deeper insights into the universe.

Tutorials and Exercises

https://github.com/ASKabalan/Tutorials/blob/main/Cophy2024/Exercises/01_MultiDevice_With_JAX.ipynb

Extra slides

Using shard_map for Advanced Parallelism in JAX

Why shard_map instead of pmap?

  • Limitations of pmap :
    • pmap is effective for simple data parallelism but lacks flexibility in more complex cases.
    • Nested Parallelism: pmap does not handle nested parallelism well.
    • Data Layout Control: pmap does not offer fine-grained control over data layout.
  • Advantages of shard_map:
    • Greater Flexibility: shard_map allows custom parallelism patterns and fine control over data sharding.
    • Nested Parallelism Support: Suitable for complex workloads that require hierarchical parallelism.
    • Direct Device Control: Allows fine-grained control over data distribution and parallel operations.

(Figure: the JAX documentation explaining the weaknesses of pmap)

Example: Nested Parallelism with shard_map


import jax
import jax.numpy as jnp
from functools import partial
from jax import lax
from jax.sharding import PartitionSpec as P, NamedSharding
from jax.experimental.shard_map import shard_map

mesh = jax.make_mesh((2, 2), ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
data = jnp.arange(16).reshape(4, 4)
sharded_data = lax.with_sharding_constraint(data, sharding)

# Nested pmap: device placement along each mesh axis must be managed by hand.
@partial(jax.pmap, axis_name='x', devices=mesh.devices[0])
@partial(jax.pmap, axis_name='y', devices=mesh.devices[1])
def sum_and_avg_nested_pmap(x):
    sum_across_x = lax.psum(x, axis_name='x')
    avg_across_y = lax.pmean(sum_across_x, axis_name='y')
    return avg_across_y

# Two separate pmaps: requires manual reshaping between the two parallel axes.
def sum_and_avg_pmap(x):
    sum_across_x = jax.pmap(lambda a: lax.psum(a, axis_name='x'),
                            axis_name='x',
                            devices=mesh.devices[0])(x.reshape(2, 2, 4))
    avg_across_y = jax.pmap(lambda a: lax.pmean(a, axis_name='y'),
                            axis_name='y',
                            devices=mesh.devices[1])(sum_across_x.reshape(2, 4, 2))
    return avg_across_y.reshape(4, 4)

# shard_map: the 2D mesh and the collectives express the same computation directly.
@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'),), out_specs=P('x'))
def sum_and_avg_shardmap(x):
    sum_across_x = lax.psum(x, axis_name='x')
    avg_across_y = lax.pmean(sum_across_x, axis_name='y')
    return avg_across_y
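
For comparison, the shard_map version is simply called on the sharded array and can be jitted directly; a short usage sketch (the expected output shape follows from out_specs: per-device blocks are concatenated along 'x' and must be replicated along 'y', which the pmean guarantees):

out = jax.jit(sum_and_avg_shardmap)(sharded_data)
print(out.shape)   # (4, 2): per-device (2, 2) blocks assembled according to P('x')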


Motivation: Cosmology in the Exascale Era


Upcoming Surveys and Massive Data in Cosmology

  • Massive Data Volume: LSST will generate 20 TB of raw data per night over 10 years, totaling 60 PB.
  • Catalog Size: The processed LSST catalog database will reach 15 PB.

Cosmological Models and Pipelines

  • Cosmological simulations and forward modeling can easily reach multiple terabytes in size.
  • We need to scale up cosmological pipelines to handle these data volumes effectively.