Wassim Kabalan
François Lanusse, Alexandre Boucaud, Josquin Errard
Task Parallelism ⚠️
Consider scaling to multiple GPUs if:
Assesses performance as more GPUs are added for a fixed problem size. Danger Zone ⚠️: indicates the distributed code is not scaling efficiently.
Tests how the code handles increasing data sizes with a fixed number of GPUs. Danger Zone ⚠️: suggests underlying scaling issues in the code itself.
Environmental Benefits of Efficient Parallel Computing
Key Points
import jax.numpy as jnp
from jax.debug import visualize_array_sharding

def gaussian(x, mean, variance):
    coefficient = 1.0 / jnp.sqrt(2 * jnp.pi * variance)
    exponent = -((x - mean) ** 2) / (2 * variance)
    return coefficient * jnp.exp(exponent)

mean = 0.0
variance = 1.0
x = jnp.linspace(-5, 5, 128)
result = gaussian(x, mean, variance)
visualize_array_sharding(x)
visualize_array_sharding(result)
[Output: x and result both live entirely on GPU 0]
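Before any sharding, the array lives on a single device; a small check using standard jax.Array attributes (a sketch, not from the slides):

import jax

print(jax.device_count())   # number of GPUs JAX can see
print(x.devices())          # a single device by default, e.g. {CudaDevice(id=0)}
print(x.sharding)           # SingleDeviceSharding describing that placement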
import jax
from jax.sharding import NamedSharding, PartitionSpec as P

def gaussian(x, mean, variance):
    coefficient = 1.0 / jnp.sqrt(2 * jnp.pi * variance)
    exponent = -((x - mean) ** 2) / (2 * variance)
    return coefficient * jnp.exp(exponent)

mesh = jax.make_mesh((8,), ('x',))
sharding = NamedSharding(mesh, P('x'))

mean = 0.0
variance = 1.0
x = jnp.linspace(-5, 5, 128)
x = jax.device_put(x, sharding)     # shard x across the 8 GPUs of the mesh
result = gaussian(x, mean, variance)
visualize_array_sharding(x)
visualize_array_sharding(result)
[Output: x and result are each split evenly across GPUs 0–7]
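The element-wise computation follows the data: the output inherits the input's sharding without any change to gaussian. A quick check with standard jax.Array attributes (a sketch, not from the slides):

print(x.sharding)                       # NamedSharding over the 8-device 'x' mesh
assert result.sharding == x.sharding    # the output keeps the same layout as the input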
from functools import partial

@jax.jit
def gradient_descent_step(p, xi, yi, lr=0.1):
    gradients = jax.grad(loss_fun)(p, xi, yi)
    return p - lr * gradients

def minimizer(loss_fun, x_data, y_data, par_init, method, verbose=True):
    ...

# Example usage
par_mini_GD = minimizer(
    loss_fun,
    x_data=xin,
    y_data=yin,
    par_init=jnp.array([0., 0.5]),
    method=partial(gradient_descent_step, lr=0.5),
    verbose=True,
)
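The body of minimizer is elided on the slide; below is a minimal hypothetical sketch of such a loop (the n_steps argument and the loss printout are illustrative additions, not from the original):

def minimizer(loss_fun, x_data, y_data, par_init, method, verbose=True, n_steps=100):
    p = par_init
    for step in range(n_steps):
        p = method(p, x_data, y_data)              # one update, e.g. gradient_descent_step
        if verbose and step % 10 == 0:
            print(f"step {step}: loss = {loss_fun(p, x_data, y_data):.4f}")
    return p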
from jax.experimental.shard_map import shard_map

@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=(P(), P('x'), P('x')),   # p is replicated, the data is sharded along 'x'
         out_specs=P())                    # the updated p is identical on every device
def gradient_descent_step(p, xi, yi, lr=0.1):
    per_device_gradients = jax.grad(loss_fun)(p, xi, yi)
    avg_gradients = jax.lax.pmean(per_device_gradients, axis_name='x')
    return p - lr * avg_gradients

def minimizer(loss_fun, x_data, y_data, par_init, method, verbose=True):
    ...

# Example usage: the only change is sharding the data before calling the minimizer
xin = jax.device_put(xin, sharding)
yin = jax.device_put(yin, sharding)
par_mini_GD = minimizer(
    loss_fun,
    x_data=xin,
    y_data=yin,
    par_init=jnp.array([0., 0.5]),
    method=partial(gradient_descent_step, lr=0.5),
    verbose=True,
)
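Why averaging per-device gradients is enough: if the loss is a mean over the data points and each of the \(D\) devices holds an equally sized shard \(\mathcal{S}_d\) of the \(N\) points, then

\[
\nabla_p\!\left[\frac{1}{N}\sum_{i=1}^{N}\ell(p, x_i, y_i)\right]
= \frac{1}{D}\sum_{d=1}^{D}\nabla_p\!\left[\frac{D}{N}\sum_{i\in\mathcal{S}_d}\ell(p, x_i, y_i)\right],
\]

which is exactly the average that lax.pmean computes, so the sharded update matches the single-device one.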
jax.lax.p* Functions

| Function | Description |
|---|---|
| lax.pmean | Computes the mean of arrays across devices. Useful for averaging gradients in distributed training. |
| lax.ppermute | Permutes data across devices in a specified order. Very useful in cosmological simulations. |
| lax.all_to_all | Exchanges data between devices in a controlled manner. Useful for custom data exchange patterns in distributed computing. |
| lax.pmax / lax.pmin | Computes the element-wise maximum/minimum across devices. Often used to find the maximum or minimum of a distributed dataset. |
| lax.psum | Sums arrays across devices. Commonly used for aggregating gradients or other values in distributed settings. |
| lax.pall | Checks whether all values across devices are True. Often used for collective boolean checks across distributed data. |
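A minimal sketch (not from the slides) showing two of these collectives inside shard_map, assuming the 8-GPU mesh used earlier; the cyclic lax.ppermute is the kind of neighbour exchange (halo exchange) that distributed simulations rely on:

from functools import partial
import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((8,), ('x',))

@partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=(P(), P('x')))
def stats_and_shift(x):
    total = lax.psum(jnp.sum(x), axis_name='x')   # global sum, identical on every device
    # send each device's shard to its right neighbour (cyclic shift across the mesh)
    shifted = lax.ppermute(x, axis_name='x',
                           perm=[(i, (i + 1) % 8) for i in range(8)])
    return total, shifted

data = jax.device_put(jnp.arange(128.0), NamedSharding(mesh, P('x')))
total, shifted = stats_and_shift(data)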
Multi-Node scalability with Jean Zay
multi-host-jax.py
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P
from jax.debug import visualize_array_sharding

mesh = jax.make_mesh((4,), ('x',))
sharding = NamedSharding(mesh, P('x'))

def gaussian(x, mean, variance):
    ...

mean = 0.0
variance = 1.0
x = jnp.linspace(-5, 5, 128)
x = jax.device_put(x, sharding)
result = gaussian(x, mean, variance)
visualize_array_sharding(x)
visualize_array_sharding(result)
multi-host-jax.py
import jax

jax.distributed.initialize()
mesh = jax.make_mesh((16,), ('x',))
sharding = NamedSharding(mesh, P('x'))

def gaussian(x, mean, variance):
    ...

mean = 0.0
variance = 1.0
x = jnp.linspace(-5, 5, 128)
x = jax.device_put(x, sharding)  # ❌ DOES NOT WORK
result = gaussian(x, mean, variance)
visualize_array_sharding(x)
visualize_array_sharding(result)
CAUTION ⚠️
jax.device_put does not work with multi-host setups.

import jax

jax.distributed.initialize()
assert jax.device_count() == 16
x = jnp.linspace(-5, 5, 128)
visualize_array_sharding(x)
[Output: visualize_array_sharding lists the 16 GPUs in a scrambled, non-contiguous order rather than the intended even layout]
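The reason: in a multi-host setup each process only controls its own local GPUs, so a plain jax.device_put issued by one process cannot place data on devices owned by other processes. A quick way to inspect the topology (standard JAX calls; the example counts assume the 16-GPU setup above with 4 GPUs per node):

import jax

jax.distributed.initialize()
print(jax.process_count())        # number of host processes, e.g. 4
print(jax.local_device_count())   # GPUs addressable by this process, e.g. 4
print(jax.device_count())         # global number of GPUs across all hosts, e.g. 16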
multi-host-jax.py
import jax
import numpy as np
from jax.sharding import NamedSharding, PartitionSpec as P

jax.distributed.initialize()
mesh = jax.make_mesh((16,), ('x',))
sharding = NamedSharding(mesh, P('x'))

def distributed_linspace(start, stop, num):
    def local_linspace(indx):
        # indx is the global index of one locally addressable shard
        return np.linspace(start, stop, num)[indx]
    return jax.make_array_from_callback(shape=(num,), sharding=sharding,
                                        data_callback=local_linspace)

x = distributed_linspace(-5, 5, 128)
if jax.process_index() == 0:
    visualize_array_sharding(x)
[Output: x is split evenly across GPUs 0–15]
mean = 0.0
variance = 1.0
result = gaussian(x, mean, variance)
if jax.process_index() == 0:
    visualize_array_sharding(result)
[Output: result is split evenly across GPUs 0–15, matching the sharding of x]
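To inspect or save the distributed result, one common pattern (a hedged sketch; multihost_utils.process_allgather is standard JAX but not shown on the slides) is to gather the global array back into ordinary NumPy on every process:

import numpy as np
from jax.experimental import multihost_utils

# Gather the sharded global array into a plain NumPy copy on every process
local_result = multihost_utils.process_allgather(result)
if jax.process_index() == 0:
    print(np.asarray(local_result).shape)   # expected: (128,)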
Scaling Challenges
Multi-Node Support
Supports Different Sharding Strategies
Open-source and available on PyPI
PFFT3D
Box size: 1 Gpc/h
Resolution: \(1024^3\)
Number of particles: 1 billion
Number of snapshots: 10
Halo size: 128
Number of GPUs used: 32
Time taken: 45 s
Distributed JAX: A Game-Changer for Cosmology
The future is bright for JAX in cosmology 🎉🎉!!
JAX has transformed the landscape for scientific computing, enabling large-scale, distributed workflows in a Pythonic environment.
Recent advancements (JAX 0.4.3x+) make it straightforward to scale computations across multiple GPUs and nodes.
Key Advantages
The Future Ahead
jaxPM: By integrating the new distributed jaxPM into existing cosmological inference models, we can achieve unprecedented levels of detail and complexity.
Tutorials and Exercises
https://github.com/ASKabalan/Tutorials/blob/main/Cophy2024/Exercises/01_MultiDevice_With_JAX.ipynb
shard_map for Advanced Parallelism in JAX
shard_map instead of pmap?
pmap:
pmap is effective for simple data parallelism but lacks flexibility in more complex cases.
pmap does not handle nested parallelism well.
pmap does not offer fine-grained control over data layout.
shard_map:
shard_map allows custom parallelism patterns and fine control over data sharding.
shard_map
from functools import partial
import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = jax.make_mesh((2, 2), ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
data = jnp.arange(16).reshape(4, 4)
sharded_data = lax.with_sharding_constraint(data, sharding)

# pmap version 1: nested pmap (awkward, limited control over the device layout)
@partial(jax.pmap, axis_name='x', devices=mesh.devices[0])
@partial(jax.pmap, axis_name='y', devices=mesh.devices[1])
def sum_and_avg_nested_pmap(x):
    sum_across_x = lax.psum(x, axis_name='x')
    avg_across_y = lax.pmean(sum_across_x, axis_name='y')
    return avg_across_y

# pmap version 2: manual reshapes and two separate pmap calls
def sum_and_avg_pmap(x):
    sum_across_x = jax.pmap(lambda a: lax.psum(a, axis_name='x'),
                            axis_name='x',
                            devices=mesh.devices[0])(x.reshape(2, 2, 4))
    avg_across_y = jax.pmap(lambda a: lax.pmean(a, axis_name='y'),
                            axis_name='y',
                            devices=mesh.devices[1])(sum_across_x.reshape(2, 4, 2))
    return avg_across_y.reshape(4, 4)

# shard_map version: one decorator, explicit in/out specs on the 2D mesh
@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'),), out_specs=P('x'))
def sum_and_avg_shardmap(x):
    sum_across_x = lax.psum(x, axis_name='x')
    avg_across_y = lax.pmean(sum_across_x, axis_name='y')
    return avg_across_y
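A possible way to call the shard_map version eagerly, assuming 4 local devices arranged in the 2×2 mesh above (jax.device_put is used here to place the input; the expected output shape follows from out_specs=P('x')):

inp = jax.device_put(data, sharding)   # shard the (4, 4) array over the 2x2 mesh
out = sum_and_avg_shardmap(inp)
print(out.shape)   # expected (4, 2): the (2, 2) per-device blocks concatenated along 'x'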
Upcoming Surveys and Massive Data in Cosmology
Cosmological Models and Pipelines
GDR CoPhy, IAP, 2024