Wassim Kabalan
2025-10-01



Key Idea: Generative AI


\[\min_G \max_D \mathbb{E}_{x\sim p_\text{data}}[\log D(x)] + \mathbb{E}_{z\sim p(z)}[\log(1 - D(G(z)))]\]
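A minimal JAX sketch of these objectives (the callables `D` and `G` are illustrative assumptions: `D(x)` returns a real-valued logit, `G(z)` returns a generated sample; the generator uses the common non-saturating loss rather than the raw minimax form):

```python
import jax.numpy as jnp
from jax.nn import log_sigmoid

def d_loss(D, G, x_real, z):
    # Discriminator maximises E[log D(x)] + E[log(1 - D(G(z)))]; we minimise the negative.
    return -jnp.mean(log_sigmoid(D(x_real)) + log_sigmoid(-D(G(z))))

def g_loss(D, G, z):
    # Non-saturating generator loss: maximise E[log D(G(z))].
    return -jnp.mean(log_sigmoid(D(G(z))))
```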
Advantages and Limitations



Some possible solutions


VAE takeaways

Evidence Lower Bound (ELBO):
\[\mathcal{L} = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - \text{KL}(q_\phi(z|x)\|p(z))\]
Reparameterization trick: \[z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0,I)\]
Without KL regularization:

Latent space is unstructured
With KL regularization:

Organized, continuous latent space
β-VAE objective:
\[\mathcal{L}_\beta = \mathbb{E}[\log p_\theta(x|z)] - \beta \cdot \text{KL}(q_\phi(z|x)\|p(z))\]
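A compact sketch of how these pieces fit together in JAX (names are illustrative; `recon_logp` stands for the per-example reconstruction log-likelihood \(\log p_\theta(x|z)\), and the KL term is the closed form for a diagonal Gaussian posterior against a standard normal prior):

```python
import jax, jax.numpy as jnp

def reparameterize(rng, mu, logvar):
    # z = mu + sigma * eps, eps ~ N(0, I)
    eps = jax.random.normal(rng, mu.shape)
    return mu + jnp.exp(0.5 * logvar) * eps

def beta_vae_loss(recon_logp, mu, logvar, beta=1.0):
    # Negative ELBO with the KL term weighted by beta (beta = 1 recovers the standard VAE).
    kl = 0.5 * jnp.sum(jnp.exp(logvar) + mu**2 - 1.0 - logvar, axis=-1)
    return -(recon_logp - beta * kl).mean()
```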





Mappings: \(f(x) \to y\), \(g(y) = z\); \(z = g_{\theta}(x) = g_K \circ \cdots \circ g_1(x)\)
Log-likelihood
\[\log p_{\theta}(x) = \log p_Z\big(g_{\theta}(x)\big) + \sum_{k=1}^{K} \log \left|\det J_{g_k}\right|\]
Rezende & Mohamed (2015)
In short
Coupling layer computation graph
Forward (x → y)
\[ y_a = x_a, \qquad y_b = x_b \odot e^{s(x_a)} + t(x_a). \]
Inverse (y → x)
Given \((y_a, y_b)\):
\[ x_a = y_a, \qquad x_b = \big(y_b - t(y_a)\big) \odot e^{-s(y_a)}. \]
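A direct transcription of these two maps as JAX functions (the scale and shift networks `s` and `t` are assumed to be given callables, e.g. small MLPs acting on the untouched half):

```python
import jax.numpy as jnp

def coupling_forward(x_a, x_b, s, t):
    # y_a = x_a,  y_b = x_b * exp(s(x_a)) + t(x_a)
    y_b = x_b * jnp.exp(s(x_a)) + t(x_a)
    log_det = jnp.sum(s(x_a), axis=-1)  # log|det J| of the layer is just the sum of s(x_a)
    return x_a, y_b, log_det

def coupling_inverse(y_a, y_b, s, t):
    # x_a = y_a,  x_b = (y_b - t(y_a)) * exp(-s(y_a))
    x_b = (y_b - t(y_a)) * jnp.exp(-s(y_a))
    return y_a, x_b
```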
MAF vs IAF comparison
MAF (Masked Autoregressive Flow)
Autoregressive flow where density evaluation is parallel; sampling is sequential.
\[ z_i=\frac{x_i-\mu_i(x_{<i})}{\sigma_i(x_{<i})} \]
IAF (Inverse Autoregressive Flow) — fast sampling
Same masked structure but reversed: sampling is parallel (given \(z\)); density evaluation is sequential.
\[ x_i=\mu_i(z_{<i})+\sigma_i(z_{<i})\cdot z_i \]
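To make the parallel-vs-sequential distinction concrete, here is a sketch of the MAF direction, assuming a hypothetical masked conditioner `ar_net(x)` whose i-th outputs \(\mu_i, \log\sigma_i\) depend only on \(x_{<i}\):

```python
import jax.numpy as jnp

def maf_density_transform(ar_net, x):
    # Density evaluation is parallel: one pass gives all mu_i, sigma_i from x.
    mu, log_sigma = ar_net(x)
    z = (x - mu) * jnp.exp(-log_sigma)
    log_det = -jnp.sum(log_sigma, axis=-1)
    return z, log_det

def maf_sample(ar_net, z):
    # Sampling is sequential: x_i needs the already-generated x_<i.
    x = jnp.zeros_like(z)
    for i in range(z.shape[-1]):
        mu, log_sigma = ar_net(x)
        x = x.at[..., i].set(mu[..., i] + jnp.exp(log_sigma[..., i]) * z[..., i])
    return x
```

IAF uses the same masked transform read in the opposite direction, so the roles swap: sampling becomes the single parallel pass and density evaluation becomes sequential.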



Characteristics of Diffusion Models
Pros:
Cons:
| Model | Pros | Cons |
|---|---|---|
| VAE | Explicit likelihood; fast sampling; good global structure | Blurry outputs; prior bias (Gaussian); limited expressiveness |
| GAN | Sharp, high-quality images; fast sampling | Training instability; mode collapse; no likelihood |
| Flow | Exact likelihood; invertible; flexible distributions | Computationally expensive; architecture constraints; difficult to scale |
| Diffusion | SOTA quality; excellent mode coverage; stable training | Slow sampling (many steps); computationally intensive; careful tuning required |
JAX and OpenXLA
```python
import numpy as np
import jax
import jax.numpy as jnp

# 1D mesh over all available devices, with a single 'data' axis
mesh = jax.sharding.Mesh(np.array(jax.devices()), ('data',))

def double(x):
    return x * 2

# run double on each device's shard of the input
double_sharded = jax.shard_map(double, mesh=mesh,
                               in_specs=jax.sharding.PartitionSpec('data', None),
                               out_specs=jax.sharding.PartitionSpec('data', None))

# shard over axis 0
x = jnp.arange(8).reshape(len(jax.devices()), -1)
y = double_sharded(x)
```
In general
Arrays are immutable: use `.at[].set()` or similar to create modified copies.

PyTorch:

```python
import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.c1 = nn.Conv2d(3, 32, 3); self.c2 = nn.Conv2d(32, 64, 3)
        self.fc = nn.Linear(64 * 6 * 6, 10)

    def forward(self, x):
        x = F.relu(self.c1(x)); x = F.relu(self.c2(x))
        x = x.view(x.size(0), -1); return self.fc(x)

model = CNN()
```

Flax (Linen):

```python
import jax, jax.numpy as jnp
import flax.linen as nn

class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Conv(32, (3, 3))(x))
        x = nn.relu(nn.Conv(64, (3, 3))(x))
        x = x.reshape((x.shape[0], -1))
        return nn.Dense(10)(x)

model = CNN()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8, 8, 3)))  # NHWC dummy input
```
Training step with Optax and Flax `train_state`:

```python
import jax, jax.numpy as jnp
import optax
from flax.training import train_state

def loss_fn(params, batch):
    logits = model.apply(params, batch["x"])  # forward
    return optax.softmax_cross_entropy_with_integer_labels(
        logits, batch["y"]
    ).mean()

@jax.jit
def train_step(state, batch):
    grads = jax.grad(loss_fn)(state.params, batch)
    return state.apply_gradients(grads=grads)

state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optax.adam(1e-3),
)

state = train_step(state, batch)  # one update step
```

Flax: Neural network library for JAX with modules, layers, optimizers, and training loops.
Optax: Gradient processing and optimization library for JAX.
BlackJAX: MCMC sampling library for JAX with HMC, NUTS, and SGLD algorithms.
NumPyro: Probabilistic programming library for JAX with Bayesian modeling and inference.
Diffrax: Differential equation solver library for JAX with ODE, SDE, and DDE solvers.
Strengths
- `jit` compilation and composable transforms, e.g. `jit(vmap(grad(...)))`
- `shard_map` for multi-device execution

Weaknesses

Galaxy Zoo 10 from MultimodalUniverse:
Dataset statistics:
Fields: gz10_label, redshift, object_id
For this workshop: downsize images to 32×32 or 64×64 for faster training.
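A possible way to do the downsizing in JAX, assuming the images arrive as an NHWC float batch (the function name and shapes are illustrative):

```python
import jax, jax.numpy as jnp

def downsize(images, size=64):
    # images: (N, H, W, C) float array -> (N, size, size, C)
    n, _, _, c = images.shape
    return jax.image.resize(images, (n, size, size, c), method="linear")
```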
Explore the dataset
Notebook link : Generative_AI_JAX_GZ10_VAE.ipynb
Variational Autoencoder
Train a VAE on GZ10 galaxies and analyze latent space
Notebook link : Generative_AI_JAX_GZ10_VAE.ipynb
VAE latent space to redshift
Classify galaxy morphology from VAE latent space
Notebook link : Generative_AI_JAX_GZ10_VAE_Classifier.ipynb
AISSAI School 2025