Distributed Training
Functions / Methods
Basics (Point-Point)
Use send and recv to send or receive a tensor synchronously — from/to a single rank.
And their async counterparts, isend and irecv.
rank = dist.get_rank()
tensor = torch.arange(3) if rank == 0 else torch.zeros(3)
print(f"Before: {tensor}")
if rank == 0:
    request = dist.isend(tensor, 1)
    ...
    # can do something else, like more sends for example!
    ...
    request.wait()  # now block until it's been fulfilled
elif rank == 1:
    dist.recv(tensor, 0)  # recv is synchronous, so it will block until tensor is fully received
print(f"After: {tensor}")
========== rank 0 ==========
Before: tensor([0, 1, 2])
After: tensor([0, 1, 2])
========== rank 1 ==========
Before: tensor([0., 0., 0.])
After: tensor([0, 1, 2])
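For a fully asynchronous exchange, irecv works the same way on the receiving side. A minimal sketch, assuming the same two-rank setup as above (the explicit int64 dtype is an assumption, to match torch.arange):
if rank == 0:
    tensor = torch.arange(3)
    request = dist.isend(tensor, dst=1)        # returns immediately
elif rank == 1:
    tensor = torch.zeros(3, dtype=torch.int64)
    request = dist.irecv(tensor, src=0)        # returns immediately
# ... both ranks can overlap other work here ...
request.wait()                                 # block until the transfer has completed
print(f"rank {rank}: {tensor}")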
Collective Operations
Collective operations allow communication (data-transfer) from All->Point, Point->All and All->All.
Point->All
Broadcast
Broadcast (torch.distributed.broadcast(tensor, src, ...)) allows a rank to broadcast a tensor to the whole group.
tensor = torch.arange(3) if rank == 0 else torch.zeros(3)
print(f"Before: {tensor}")
dist.broadcast(tensor, src=0)
print(f"After: {tensor})========== rank 0 ==========
Before: tensor([0, 1, 2])
After: tensor([0, 1, 2])
========== rank 1 ==========
Before: tensor([0., 0., 0.])
After: tensor([0, 1, 2])
========== rank 2 ==========
Before: tensor([0., 0., 0.])
After: tensor([0, 1, 2])
Scatter
Scatter (torch.distributed.scatter(tensor, scatter_list, src, ...)) allows us to scatter — split and broadcast different chunks of — a tensor from a rank to the whole group.
tensor = torch.zeros(3)
scatter_list = [torch.arange(3 * i, 3 * i + 3) if rank == 0 else torch.zeros(3) for i in range(world_size)]
print(f"Scatter list: {scatter_list}")
print(f"Before: {tensor}")
dist.scatter(tensor, scatter_list, src=0)
print(f"After: {tensor}")========== rank 0 ==========
Scatter list: [tensor([0, 1, 2]), tensor([3, 4, 5])]
Before: tensor([0., 0., 0.])
After: tensor([0, 1, 2])
========== rank 1 ==========
Scatter list: [tensor([0., 0., 0.]), tensor([0., 0., 0.])]
Before: tensor([0., 0., 0.])
After: tensor([3, 4, 5])
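Note: depending on the PyTorch version, non-source ranks may be required to pass scatter_list=None — only the src rank provides the list. A more portable way to write the call above (an assumption about version behavior, not part of the original run):
dist.scatter(tensor, scatter_list if rank == 0 else None, src=0)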
All->Point
Reduce
Reduce (torch.distributed.reduce(tensor, dst, op, ...)) performs a reduction operation (N->1, e.g. sum, max, min, prod, …) and the dst rank receives the result.
tensor = torch.arange(3) + rank * 3
print(f"Before: {tensor}")
dist.reduce(tensor, dst=0, op=dist.ReduceOp.SUM)
print(f"After: {tensor}")========== rank 0 ==========
Before: tensor([0, 1, 2])
After: tensor([3, 5, 7])
========== rank 1 ==========
Before: tensor([3, 4, 5])
After: tensor([3, 4, 5])
Gather
Gather (torch.distributed.gather(tensor, gather_list, dst, ...)) gathers — pulls — a tensor of the same size from every rank and stores them in a list on a single rank (the dst).
tensor = torch.arange(3) + rank * 3
gather_list = [torch.zeros(3) for _ in range(world_size)]
print(f"Before: {tensor}")
print(f"Before: {gather_list}")
dist.gather(tensor, gather_list, dst=0)
print(f"After: {gather_list}")========== rank 0 ==========
Before: tensor([0, 1, 2])
Before: [tensor([0., 0., 0.]), tensor([0., 0., 0.])]
Before: tensor([3, 4, 5])
========== rank 1 ==========
Before: [tensor([0., 0., 0.]), tensor([0., 0., 0.])]
After: [tensor([0, 1, 2]), tensor([3, 4, 5])]
After: [tensor([0., 0., 0.]), tensor([0., 0., 0.])]
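As with scatter, non-destination ranks may be required to pass gather_list=None depending on the PyTorch version; a more portable version of the call above (again an assumption, not part of the original run):
dist.gather(tensor, gather_list if rank == 0 else None, dst=0)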
All->All
All-Reduce
All-Reduce (torch.distributed.all_reduce(tensor, op, ...)) performs a reduction operation, like reduce, but every rank receives the result — rather than only the dst rank as with reduce. Think of it as reduce + broadcast — though in practice it is optimized with techniques like ring all-reduce.
tensor = torch.arange(3) + rank * 3
print(f"Before: {tensor}")
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
print(f"After: {tensor}")========== rank 0 ==========
Before: tensor([0, 1, 2])
After: tensor([3, 5, 7])
========== rank 1 ==========
Before: tensor([3, 4, 5])
After: tensor([3, 5, 7])
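To make the reduce + broadcast mental model concrete, here is a naive, hypothetical equivalent built from the two primitives above (real backends use ring/tree algorithms instead of anything like this):
def naive_all_reduce(tensor, op=dist.ReduceOp.SUM, root=0):
    dist.reduce(tensor, dst=root, op=op)  # the reduced result lands on root
    dist.broadcast(tensor, src=root)      # root sends the result back to every rank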
All-Gather
All-Gather (torch.distributed.all_gather(tensor_list, tensor, ...)) gathers — pulls — a tensor of the same size from every rank and stores them in a list on every rank. Think of it as running gather on all ranks.
tensor = torch.arange(3) + rank * 3
gather_list = [torch.zeros(3) for _ in range(world_size)]
print(f"Before: {tensor}")
print(f"Before: {gather_list}")
dist.all_gather(gather_list, tensor)
print(f"After: {gather_list}")========== rank 2 ==========
Before: tensor([0, 1, 2])
Before: [tensor([0., 0., 0.]), tensor([0., 0., 0.])]
After: [tensor([0, 1, 2]), tensor([3, 4, 5])]
========== rank 1 ==========
Before: tensor([3, 4, 5])
Before: [tensor([0., 0., 0.]), tensor([0., 0., 0.])]
After: [tensor([0, 1, 2]), tensor([3, 4, 5])]
Reduce-Scatter
Reduce-Scatter (torch.distributed.reduce_scatter(output_tensor, input_list, op, ...)) performs a reduction operation — like other reduce functions — and scatters the resulting tensor. Think of it like reduce + scatter.
It needs len(input_list) == world_size and every tensor in input_list to have the same shape as output_tensor.
tensor = torch.zeros(2)
scatter_list = [torch.tensor([(rank + 1) * i for i in range(1, 3)]) ** (j + 1) for j in range(3)]
print(f"Before: {tensor}")
print(f"Before: {scatter_list}")
dist.reduce_scatter(tensor, scatter_list, op=dist.ReduceOp.SUM)
print(f"After: {tensor}")========== rank 0 ==========
Before: tensor([0., 0.])
Before: [tensor([1, 2]), tensor([1, 4]), tensor([1, 8])]
After: tensor([ 6, 12])
========== rank 1 ==========
Before: tensor([0., 0.])
Before: [tensor([2, 4]), tensor([ 4, 16]), tensor([ 8, 64])]
After: tensor([14, 56])
========== rank 2 ==========
Before: tensor([0., 0.])
Before: [tensor([3, 6]), tensor([ 9, 36]), tensor([ 27, 216])]
After: tensor([ 36, 288])
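Again, to make reduce + scatter concrete: a naive, hypothetical equivalent built from all_reduce plus local slicing — the real collective avoids materializing the full reduced tensor on every rank:
def naive_reduce_scatter(output, input_list, op=dist.ReduceOp.SUM):
    full = torch.cat(input_list)                  # world_size * shard_size elements
    dist.all_reduce(full, op=op)                  # every rank now holds the full reduction
    shard = full.chunk(dist.get_world_size())[dist.get_rank()]
    output.copy_(shard)                           # keep only this rank's shard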
Other
Barrier
Barrier (torch.distributed.barrier(...)) synchronizes all ranks (processes): every rank blocks until all of them have reached the barrier (think of .join() for threads or processes).
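A typical use, sketched with hypothetical prepare_dataset / load_dataset helpers: rank 0 does some one-off work while every other rank waits at the barrier before reading the result.
if dist.get_rank() == 0:
    prepare_dataset("/tmp/data")      # hypothetical helper, runs on rank 0 only
dist.barrier()                        # every rank blocks here until all ranks arrive
dataset = load_dataset("/tmp/data")   # hypothetical helper; safe to read on all ranks now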
Algorithms / Techniques
Data Sharding
Data Sharding is the process of sharding — splitting — the dataset / dataloader so that each rank only pulls its own unique mini-batches of the training data. This avoids duplicates and is more communication- and memory-efficient than duplicating the same full dataset on every rank. To do this with torch, set up the DataLoader with sampler=[instance of DistributedSampler], as sketched below.
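A minimal sketch, assuming the process group is already initialized; MyDataset and num_epochs are placeholders:
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

dataset = MyDataset()                                # placeholder for your own Dataset
sampler = DistributedSampler(dataset, shuffle=True)  # num_replicas/rank default to world size / global rank
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)                         # different shuffle each epoch
    for batch in loader:
        ...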
Types of parallelism
The goal of parallelism is to maximize throughput and cluster utilization.
- Data Parallelism (DP): Each rank has a replica of the model — they're replicants — and receives a different mini-batch. After optional gradient accumulation, gradients are averaged across ranks (all_reduce).
- Pipeline Parallelism (PP): The model is split along the layers. Each rank has 1+ consecutive layers of the model, and we orchestrate sequential forward/backward passes along the ranks. This is inter-layer parallelism.
- Tensor Parallelism (TP): The model’s layers themselves are split across ranks. We need more complex orchestration since a single tensor’s values are scattered across different ranks. This is intra-layer parallelism.
- Expert Parallelism (EP): A specific type of TP where we only split the experts of an MoE across ranks.
ZeRO/FSDP is not a parallelism strategy in the strict sense, but a memory-optimization strategy. It's a highly memory-efficient form of DP.
- Parallelism = distributing computation to increase throughput.
- Memory optimization (e.g. ZeRO/FSDP) = sharding model states (parameters, gradients, optimizer states) across ranks so the model fits in memory, while each rank still computes the full forward and backward pass.
Thus:
- With ZeRO/FSDP, every rank executes the full network computation but stores only a shard of the model states.
- With TP/EP/PP, computation itself is partitioned across ranks, and the combined work reconstructs the whole forward/backward pass.
These approaches are complementary and usually combined in large-scale training.
DDP — Distributed Data Parallelism
Distributed Data Parallelism is a type of parallelism where each rank loads a copy — replica — of the model; after each optimizer step they all hold exactly the same parameters (they are replicants). Each rank then trains on a different mini-batch (hence the importance of data sharding). We then average the gradients (all_reduce with SUM + division by world_size, or the AVG operation if available), perform a step of gradient descent, rinse and repeat. If we can use this, we should: it has the least overhead. But it requires that the model + optimizer states all fit in a single device's VRAM.
The difference between DDP and DP is that DDP uses processes to avoid the GIL and DP uses threads. Do not use DP, only DDP.
class SimpleDataParallelism:
    def __init__(self, model):
        self.model = model
        # check that every rank starts from the same parameters as rank 0
        for param in model.parameters():
            rank_0_params = param.data.clone()
            dist.broadcast(rank_0_params, src=0)
            assert torch.equal(param.data, rank_0_params), "Parameters mismatch at initialization"

    def sync_grad(self):
        for param in self.model.parameters():
            dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)  # AVG is only available on the NCCL backend
            # eq.
            # dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            # param.grad /= dist.get_world_size()
The above is a toy implementation. In reality you do not wait until the whole backward pass is done to launch the gradient all_reduce calls — that wastes time and leaves the GPUs idle. You interleave computation and communication instead, as sketched below.
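A minimal sketch of that overlap, assuming PyTorch >= 2.1 for register_post_accumulate_grad_hook: each parameter launches an async all_reduce as soon as its gradient is ready, and we only wait on the handles right before the optimizer step. (Real DDP additionally buckets small gradients together to amortize communication latency.)
class OverlappedGradSync:
    def __init__(self, model):
        self.model = model
        self.handles = []
        for p in model.parameters():
            if p.requires_grad:
                # fires once p.grad has been fully accumulated during backward
                p.register_post_accumulate_grad_hook(self._launch_all_reduce)

    def _launch_all_reduce(self, param):
        handle = dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, async_op=True)
        self.handles.append((handle, param))

    def finalize(self):  # call after loss.backward(), before optimizer.step()
        world_size = dist.get_world_size()
        for handle, param in self.handles:
            handle.wait()
            param.grad /= world_size  # SUM then divide == average
        self.handles.clear()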
ZeRO / FSDP
Zero Redundancy Optimizer (ZeRO) by DeepSpeed is a memory-optimization strategy that shards states and parameters during training as a means of reducing peak memory. The core idea is that the optimizer states, gradients and/or model parameters are sharded, retrieved only when necessary for some computation, and anything we no longer need is discarded.
Fully Sharded Data Parallelism (FSDP) is PyTorch's implementation of ZeRO.
References: Paper · Article · FSDP Paper · FSDP Doc
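As a rough orientation for the numbers quoted in the stages below, here is a back-of-the-envelope memory model assuming the ZeRO paper's mixed-precision Adam setup (2 bytes/param for fp16 weights, 2 for fp16 gradients, 12 for the fp32 master weights + Adam moments). The model size and GPU count are made-up examples.
def per_gpu_gb(n_params, world_size, shard_os=False, shard_grads=False, shard_params=False):
    params = 2 * n_params / (world_size if shard_params else 1)  # fp16 weights
    grads = 2 * n_params / (world_size if shard_grads else 1)    # fp16 gradients
    opt = 12 * n_params / (world_size if shard_os else 1)        # fp32 weights + Adam moments
    return (params + grads + opt) / 1e9

N, WS = 7e9, 64  # e.g. a 7B-parameter model on 64 GPUs
print(per_gpu_gb(N, WS))                                                      # plain DP: 112 GB
print(per_gpu_gb(N, WS, shard_os=True))                                       # ZeRO-1: ~29 GB
print(per_gpu_gb(N, WS, shard_os=True, shard_grads=True))                     # ZeRO-2: ~15.5 GB
print(per_gpu_gb(N, WS, shard_os=True, shard_grads=True, shard_params=True))  # ZeRO-3: 1.75 GB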
ZeRO-1
ZeRO stage 1 (aka. \(P_{os}\)) is the sharding/partitioning of optimizer states only. 4x memory reduction, communication volume of the same order as DP (gradient all-reduce dominates).
Forward pass
- Same as DP: each rank stores the full model parameters and runs the full forward pass.
Backward pass
- Same as DP: each rank computes all gradients locally.
- Same as DP: gradients are averaged across ranks via all_reduce (this can be a reduce_scatter too).
Optimizer step
- Each rank holds the full parameters and full averaged gradients.
- Each rank updates only the parameter shard corresponding to its shard of the optimizer state.
- Updated parameter shards are then exchanged (all_gather) so all ranks end up with the full updated model.
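A sketch of the data flow above, under heavy simplifying assumptions: plain SGD instead of a stateful optimizer (so "owning an optimizer-state shard" degenerates to just owning a slice of the parameters), with parameters and gradients flattened into single tensors whose length divides evenly by world_size.
def zero1_style_sgd_step(flat_params, flat_grads, lr):
    ws, rank = dist.get_world_size(), dist.get_rank()
    shard = flat_params.numel() // ws
    lo, hi = rank * shard, (rank + 1) * shard

    # Same as DP: every rank ends up with the full averaged gradient
    dist.all_reduce(flat_grads, op=dist.ReduceOp.SUM)
    flat_grads /= ws

    # Each rank only updates the parameter slice it owns
    flat_params[lo:hi] -= lr * flat_grads[lo:hi]

    # Exchange the updated shards so every rank has the full updated model again
    gathered = [torch.empty(shard, dtype=flat_params.dtype, device=flat_params.device) for _ in range(ws)]
    dist.all_gather(gathered, flat_params[lo:hi].contiguous())
    flat_params.copy_(torch.cat(gathered))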
ZeRO-2
ZeRO stage 2 (aka. \(P_{os} + P_g\)) is the sharding/partitioning of optimizer states and gradients. 8x memory reduction, communication volume of the same order as DP and ZeRO-1.
Forward pass
- Same as DP: each rank stores the full model parameters and runs the full forward pass.
Backward pass
- Each rank computes gradients locally, so gradients are temporarily materialized on every rank. This means ZeRO-2 has the same peak memory as ZeRO-1, but 8x lower persistent memory.
- Gradients are averaged and sharded across ranks (reduce_scatter) — think averaging + sending to each rank the shard of the gradients that corresponds exactly to its optimizer state.
Optimizer step
- Each rank holds the full parameters.
- Each rank holds only the averaged gradients corresponding to its shard of the optimizer state.
- Each rank updates only the parameter shard corresponding to its shard of the optimizer state.
- Updated parameter shards are then exchanged (all_gather) so all ranks end up with the full updated model.
ZeRO-3
ZeRO stage 3 (aka. \(P_{os} + P_g + P_p\)) is the sharding/partitioning of optimizer states, gradients and model parameters. Memory reduction scales linearly with the parallelism degree, at the cost of larger communication overhead (≈50% more than DP/ZeRO-1/2), since parameters have to be all_gathered before, and released after, every computation that needs them (and gradients reduce_scattered during the backward pass).
(Assuming FP16 params and FP32 optimizer states)
Forward pass
- Each rank stores only its shard of the model parameters.
- Whenever a parameter is needed for computation, it is materialized (all_gather from its shards).
- The computation is done.
- The materialized full parameter is then released/flushed (del / set to None); each rank keeps only the shard it owns.
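A sketch of this gather/compute/release cycle, under simplifying assumptions: a single linear layer whose weight is flattened and evenly sharded across ranks; biases, mixed precision and prefetching are ignored.
def sharded_linear_forward(x, weight_shard, out_features, in_features):
    ws = dist.get_world_size()
    # 1. materialize the full weight from every rank's shard
    shards = [torch.empty_like(weight_shard) for _ in range(ws)]
    dist.all_gather(shards, weight_shard)
    full_weight = torch.cat(shards).view(out_features, in_features)
    # 2. compute with the full weight
    out = x @ full_weight.t()
    # 3. release the materialized full weight; each rank keeps only its own shard
    del shards, full_weight
    return out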
Backward pass
- Each rank runs the full backward computation, but parameters must again be all-gathered on demand, as in the forward pass.
- Gradients are produced during backprop, then immediately reduce-scattered so only the owning rank keeps the shard.
Optimizer step
- Each rank holds only the parameters’ shard corresponding to its shard of the gradients and optimizer state.
- Each rank holds only the averaged gradients’ shard corresponding to its shard of the parameters and optimizer state.
- Each rank updates only the parameter shard corresponding to its shard of the optimizer state.
- Unlike ZeRO-1/2, the updated parameter shards are not all_gathered back into a full model after the step: parameters stay sharded, and are all_gathered again on demand during the next forward/backward pass.
Terminology
- device: Hardware unit — GPU ("cuda:0"), CPU ("cpu"), etc. That's where tensors and computations live.
- node: Physical machine/server (or VPS, whatever) that has 1+ devices.
- process: Python process/worker, executing a copy of the code/script — often on a single device (GPU).
- rank: ID of a process — often maps to a single device. "rank" without qualifiers means global rank.
- world: Set of all processes that are part of our current distributed job.
- global rank, world rank: rank across all processes/nodes. Note: collective operations take the global rank (or just rank) as input for src/dst.
- local rank: rank within a single node (node, not group). Note: device takes the local rank, "cuda:{local_rank}".
- group: subset of processes (on 1+ nodes) that we've grouped for sub-communications. Note: we still use the global rank for intra-group communication.