Distributed Training

mle
python
Distributed training study notes and algorithms.
Author

Theo POMIES

Published

September 2, 2025

Modified

September 17, 2025

Functions / Methods

Basics (Point-Point)

send and recv: send or receive a tensor synchronously (blocking), to/from a single rank.

And their async counterparts, isend and irecv.

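import torch
import torch.distributed as dist

# assumes the process group is already initialized (dist.init_process_group), e.g. when launched with torchrun
rank = dist.get_rank()
# recv writes in place, so the receive buffer must match the sender's shape AND dtype (int64 here)
tensor = torch.arange(3) if rank == 0 else torch.zeros(3, dtype=torch.int64)

print(f"Before: {tensor}")

if rank == 0:
    request = dist.isend(tensor, dst=1)
    ...
    # can do something else, like more sends for example!
    ...
    request.wait()  # now block until the send has completed
elif rank == 1:
    dist.recv(tensor, src=0)  # recv is synchronous, so it blocks until the tensor is fully received

print(f"After: {tensor}")
========== rank 0 ==========
Before: tensor([0, 1, 2])
After: tensor([0, 1, 2])

========== rank 1 ==========
Before: tensor([0, 0, 0])
After: tensor([0, 1, 2])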

Collective Operations

Collective operations allow communication (data-transfer) from All->Point, Point->All and All->All.

Point->All

Broadcast

Broadcast (torch.distributed.broadcast(tensor, src, ...)) allows a rank to broadcast a tensor to the whole group. It works in place, so tensor must have the same shape and dtype on every rank.

# the receive buffers must have the same dtype as the broadcast tensor (int64)
tensor = torch.arange(3) if rank == 0 else torch.zeros(3, dtype=torch.int64)

print(f"Before: {tensor}")

dist.broadcast(tensor, src=0)

print(f"After: {tensor}")
========== rank 0 ==========
Before: tensor([0, 1, 2])
After: tensor([0, 1, 2])

========== rank 1 ==========
Before: tensor([0, 0, 0])
After: tensor([0, 1, 2])

========== rank 2 ==========
Before: tensor([0, 0, 0])
After: tensor([0, 1, 2])
Scatter

Scatter (torch.distributed.scatter(tensor, scatter_list, src, ...)) allows us to scatter a list of chunks from one rank to the whole group: rank i receives scatter_list[i] from src into its tensor. scatter_list is only passed on the src rank (None on the others).

world_size = dist.get_world_size()
tensor = torch.zeros(3, dtype=torch.int64)  # must match the chunks' shape and dtype
# scatter_list must only be specified on the src rank (None on the other ranks)
scatter_list = [torch.arange(3 * i, 3 * i + 3) for i in range(world_size)] if rank == 0 else None

print(f"Scatter list: {scatter_list}")
print(f"Before: {tensor}")

dist.scatter(tensor, scatter_list, src=0)

print(f"After: {tensor}")
========== rank 0 ==========
Scatter list: [tensor([0, 1, 2]), tensor([3, 4, 5])]
Before: tensor([0, 0, 0])
After: tensor([0, 1, 2])

========== rank 1 ==========
Scatter list: None
Before: tensor([0, 0, 0])
After: tensor([3, 4, 5])

All->Point

Reduce

Reduce (torch.distributed.reduce(tensor, dst, op, ...)) performs a reduction operation (N->1, e.g. sum, max, min, prod, …) and the dst rank receives the result.

tensor = torch.arange(3) + rank * 3

print(f"Before: {tensor}")

dist.reduce(tensor, dst=0, op=dist.ReduceOp.SUM)

print(f"After: {tensor}")
========== rank 0 ==========
Before: tensor([0, 1, 2])
After: tensor([3, 5, 7])

========== rank 1 ==========
Before: tensor([3, 4, 5])
After: tensor([3, 4, 5])
Gather

Gather (torch.distributed.gather(tensor, gather_list, dst, ...)) gathers (pulls) a tensor of the same size from every rank and stores them, in rank order, in gather_list on the dst rank only; gather_list must be None on the other ranks.

world_size = dist.get_world_size()
tensor = torch.arange(3) + rank * 3
# gather_list must only be specified on the dst rank (None on the other ranks), with matching dtype
gather_list = [torch.zeros(3, dtype=torch.int64) for _ in range(world_size)] if rank == 0 else None

print(f"Before: {tensor}")
print(f"Before: {gather_list}")

dist.gather(tensor, gather_list, dst=0)

print(f"After: {gather_list}")
========== rank 0 ==========
Before: tensor([0, 1, 2])
Before: [tensor([0, 0, 0]), tensor([0, 0, 0])]
After: [tensor([0, 1, 2]), tensor([3, 4, 5])]

========== rank 1 ==========
Before: tensor([3, 4, 5])
Before: None
After: None

All->All

All-Reduce

All-Reduce (torch.distributed.all_reduce(tensor, op, ...)) performs a reduction operation like reduce, but every rank receives the result rather than a single one. Think of it as reduce + broadcast, although in practice it is implemented with more efficient algorithms such as ring all-reduce.

tensor = torch.arange(3) + rank * 3

print(f"Before: {tensor}")

dist.all_reduce(tensor, op=dist.ReduceOp.SUM)

print(f"After: {tensor}")
========== rank 0 ==========
Before: tensor([0, 1, 2])
After: tensor([3, 5, 7])

========== rank 1 ==========
Before: tensor([3, 4, 5])
After: tensor([3, 5, 7])
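To make the ring all-reduce idea concrete, here is a minimal single-process simulation in plain Python (no torch; ring_all_reduce is a hypothetical helper, not a torch API). It assumes a SUM reduction and a vector length divisible by the number of ranks. The first n-1 steps form a reduce-scatter and the last n-1 steps an all-gather; each rank only talks to its right neighbour and sends about 2(n-1)/n of the vector in total, which is why the ring stays bandwidth-efficient as the world grows.

```python
def ring_all_reduce(data: list) -> list:
    """Simulate ring all-reduce (SUM) over n = len(data) ranks.

    data[r] is rank r's vector; its length must be divisible by n.
    Rank r only ever sends to rank (r + 1) % n. Returns every rank's final buffer.
    """
    n = len(data)
    chunk = len(data[0]) // n
    buf = [list(v) for v in data]  # per-rank working copies

    def span(c: int) -> range:
        return range(c * chunk, (c + 1) * chunk)

    # Phase 1, reduce-scatter: at step s, rank r sends chunk (r - s) % n and the receiver adds it.
    # After n - 1 steps, rank r holds the fully reduced chunk (r + 1) % n.
    for step in range(n - 1):
        sends = [(r, (r - step) % n) for r in range(n)]
        payloads = [[buf[r][i] for i in span(c)] for r, c in sends]  # all sends happen "at once"
        for (r, c), payload in zip(sends, payloads):
            for i, val in zip(span(c), payload):
                buf[(r + 1) % n][i] += val

    # Phase 2, all-gather: at step s, rank r forwards the complete chunk (r + 1 - s) % n
    # and the receiver overwrites its copy. After n - 1 steps, every rank has every chunk.
    for step in range(n - 1):
        sends = [(r, (r + 1 - step) % n) for r in range(n)]
        payloads = [[buf[r][i] for i in span(c)] for r, c in sends]
        for (r, c), payload in zip(sends, payloads):
            for i, val in zip(span(c), payload):
                buf[(r + 1) % n][i] = val
    return buf


print(ring_all_reduce([[1, 2, 3, 4, 5, 6], [10, 20, 30, 40, 50, 60], [100, 200, 300, 400, 500, 600]]))
```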
All-Gather

All-Gather (torch.distributed.all_gather(tensor_list, tensor, ...)) gathers (pulls) a tensor of the same size from every rank and stores them in a list on every rank; note that the output list comes first. Think of it as running gather on all ranks.

world_size = dist.get_world_size()
tensor = torch.arange(3) + rank * 3
gather_list = [torch.zeros(3, dtype=torch.int64) for _ in range(world_size)]  # same dtype as tensor

print(f"Before: {tensor}")
print(f"Before: {gather_list}")

dist.all_gather(gather_list, tensor)  # output list first, then the local input tensor

print(f"After: {gather_list}")
========== rank 0 ==========
Before: tensor([0, 1, 2])
Before: [tensor([0, 0, 0]), tensor([0, 0, 0])]
After: [tensor([0, 1, 2]), tensor([3, 4, 5])]

========== rank 1 ==========
Before: tensor([3, 4, 5])
Before: [tensor([0, 0, 0]), tensor([0, 0, 0])]
After: [tensor([0, 1, 2]), tensor([3, 4, 5])]
Reduce-Scatter

Reduce-Scatter (torch.distributed.reduce_scatter(output_tensor, input_list, op, ...)) performs a reduction operation, like the other reduce functions, and scatters the result: rank i receives the reduction, across all ranks, of input_list[i]. Think of it like reduce + scatter.

Note

It needs len(input_list) == world_size, and every tensor in input_list must have the same shape (and dtype) as output_tensor.

world_size = dist.get_world_size()  # this example uses 3 ranks
tensor = torch.zeros(2, dtype=torch.int64)  # must match the dtype of the input tensors
scatter_list = [torch.tensor([(rank + 1) * i for i in range(1, 3)]) ** (j + 1) for j in range(world_size)]

print(f"Before: {tensor}")
print(f"Before: {scatter_list}")

dist.reduce_scatter(tensor, scatter_list, op=dist.ReduceOp.SUM)

print(f"After: {tensor}")
========== rank 0 ==========
Before: tensor([0, 0])
Before: [tensor([1, 2]), tensor([1, 4]), tensor([1, 8])]
After: tensor([ 6, 12])

========== rank 1 ==========
Before: tensor([0, 0])
Before: [tensor([2, 4]), tensor([ 4, 16]), tensor([ 8, 64])]
After: tensor([14, 56])

========== rank 2 ==========
Before: tensor([0, 0])
Before: [tensor([3, 6]), tensor([ 9, 36]), tensor([ 27, 216])]
After: tensor([ 36, 288])
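A plain-Python sketch of the reduce_scatter semantics on the same inputs as the example (no torch; reduce_scatter_sum and all_gather here are hypothetical helpers, not the torch functions). It also checks the identity that ring all-reduce relies on: a reduce_scatter followed by an all_gather gives every rank the same result as an all_reduce of the concatenated inputs.

```python
def reduce_scatter_sum(input_lists: list) -> list:
    """input_lists[src][dst] is rank src's tensor destined for rank dst.
    Rank dst receives the element-wise SUM of input_lists[src][dst] over all src."""
    world_size = len(input_lists)
    return [
        [sum(vals) for vals in zip(*(input_lists[src][dst] for src in range(world_size)))]
        for dst in range(world_size)
    ]


def all_gather(chunks: list) -> list:
    """Every rank receives the list of all ranks' chunks."""
    return [list(chunks) for _ in chunks]


# Same inputs as above: rank r holds [((r + 1) * i) ** (j + 1) for i in (1, 2)] for j in 0..2
world_size = 3
inputs = [[[((r + 1) * i) ** (j + 1) for i in range(1, 3)] for j in range(world_size)]
          for r in range(world_size)]
outputs = reduce_scatter_sum(inputs)
print(outputs)
print(all_gather(outputs)[0])
```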

Other

Barrier

Barrier (torch.distributed.barrier(...)) synchronizes all ranks (processes): each rank blocks until every rank has reached the barrier. (Think of threading.Barrier for threads rather than .join(), which waits for a thread to finish.)
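A thread-based analogy, assuming plain Python threads stand in for ranks (this uses threading.Barrier from the standard library, not torch.distributed; run_with_barrier is a hypothetical helper). No thread gets past wait() until all of them have arrived, so every "before" event is logged before any "after" event.

```python
import threading


def run_with_barrier(world_size: int) -> list:
    """Simulate `world_size` ranks as threads: each logs "before", waits at
    the barrier, then logs "after". Returns the shared event log."""
    barrier = threading.Barrier(world_size)
    lock = threading.Lock()
    log = []

    def worker(rank: int) -> None:
        with lock:
            log.append(f"before:{rank}")
        barrier.wait()  # blocks until all world_size threads have called wait()
        with lock:
            log.append(f"after:{rank}")

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()  # .join() waits for a thread to finish, unlike the barrier
    return log


print(run_with_barrier(4))
```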

Algorithms / Techniques

Data Sharding

Data Sharding is the process of sharding (splitting) the dataset / dataloader so that each rank only pulls its own unique mini-batches of the training data. This avoids duplicated samples and is more communication / memory efficient than duplicating the same full dataset on every rank. To do this with torch, set up the DataLoader with sampler=[instance of DistributedSampler].
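In torch this looks like DataLoader(dataset, batch_size=..., sampler=DistributedSampler(dataset)), with sampler.set_epoch(epoch) called at the start of each epoch so the shuffle changes between epochs. To show roughly what the sampler does, here is a plain-Python sketch; shard_indices is a hypothetical helper that, as far as I understand DistributedSampler's default behaviour (drop_last=False), pads by repeating indices until the total is divisible by the number of ranks, then gives each rank a strided slice of a permutation that every rank computes identically. The RNG differs: torch uses its own generator, this sketch uses random.Random.

```python
import math
import random


def shard_indices(dataset_len: int, num_replicas: int, rank: int,
                  shuffle: bool = True, seed: int = 0, epoch: int = 0) -> list:
    """Hypothetical sketch of DistributedSampler-style sharding (not the torch code)."""
    indices = list(range(dataset_len))
    if shuffle:
        # Same seed + epoch on every rank -> every rank computes the same permutation
        random.Random(seed + epoch).shuffle(indices)
    # Pad by repeating indices until the total is divisible by num_replicas
    total_size = math.ceil(dataset_len / num_replicas) * num_replicas
    while len(indices) < total_size:
        indices += indices[: total_size - len(indices)]
    # Strided split: rank r takes positions r, r + num_replicas, r + 2 * num_replicas, ...
    return indices[rank:total_size:num_replicas]


for r in range(4):
    print(r, shard_indices(10, num_replicas=4, rank=r, shuffle=False))
```

With 10 samples on 4 ranks, two samples get repeated so that every rank sees the same number of mini-batches (which keeps the collectives in lockstep).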

Types of parallelism

The goal of parallelism is to maximize throughput and cluster utilization.

  • Data Parallelism (DP): Each rank has a replica of the model (they’re replicants) and receives a different mini-batch. After optional Gradient Accumulation, gradients are averaged across ranks (all_reduce).
  • Pipeline Parallelism (PP): The model is split along the layers. Each rank has 1+ consecutive layers of the model, and we orchestrate sequential forward/backward passes along the ranks. This is inter-layer parallelism.
  • Tensor Parallelism (TP): The model’s layers themselves are split across ranks. We need more complex orchestration since a single tensor’s values are scattered across different ranks. This is intra-layer parallelism.
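  • Expert Parallelism (EP): A specific type of model parallelism for MoE models: the experts are distributed across ranks (each rank holds whole experts), and tokens are routed to the rank holding their expert (all-to-all).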
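A toy illustration of TP for a single linear layer y = x @ W, simulated in one process with plain Python lists (the function names are hypothetical, and the split dimension is assumed divisible by world_size). Splitting W by columns ("column-parallel" in Megatron-LM's terminology) lets each rank compute a slice of the output, reassembled with an all_gather; splitting W by rows ("row-parallel") gives each rank a partial sum over its slice of the input features, combined with an all_reduce(SUM).

```python
def matmul(a: list, b: list) -> list:
    """Naive (n x k) @ (k x m) product on lists of lists."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def column_parallel(x: list, w: list, world_size: int) -> list:
    """Rank r holds columns [r*c, (r+1)*c) of W and computes x @ W_r.
    Concatenating the partial outputs along the columns plays the role of all_gather."""
    c = len(w[0]) // world_size
    partial = [matmul(x, [row[r * c:(r + 1) * c] for row in w]) for r in range(world_size)]
    return [sum((p[i] for p in partial), []) for i in range(len(x))]


def row_parallel(x: list, w: list, world_size: int) -> list:
    """Rank r holds rows [r*k, (r+1)*k) of W and the matching columns of x.
    Summing the partial outputs element-wise plays the role of all_reduce(SUM)."""
    k = len(w) // world_size
    partial = [matmul([xi[r * k:(r + 1) * k] for xi in x], w[r * k:(r + 1) * k])
               for r in range(world_size)]
    return [[sum(p[i][j] for p in partial) for j in range(len(w[0]))] for i in range(len(x))]


x = [[1, 2, 3, 4], [5, 6, 7, 8]]
w = [[4 * i + j for j in range(4)] for i in range(4)]
print(matmul(x, w), column_parallel(x, w, 2), row_parallel(x, w, 2))
```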
Important

ZeRO/FSDP is not a parallelism strategy in the strict sense, but a memory-optimization strategy. It’s a highly memory-efficient form of DP.

  • Parallelism = distributing computation to increase throughput.
  • Memory optimization (e.g. ZeRO/FSDP) = sharding model states (parameters, gradients, optimizer states) across ranks so the model fits in memory, while each rank still computes the full forward and backward pass.

Thus:

  • With ZeRO/FSDP, every rank executes the full network computation but stores only a shard of the model states.
  • With TP/EP/PP, computation itself is partitioned across ranks, and the combined work reconstructs the whole forward/backward pass.

These approaches are complementary and usually combined in large-scale training.

DDP — Distributed Data Parallelism

Distributed Data Parallelism is a type of parallelism where each rank loads a copy (replica) of the model; since every rank applies the same update, they all hold the same parameters after each optimizer step: they are replicants. Each rank then trains on a different mini-batch (hence the importance of data sharding). We then average the gradients (all_reduce with SUM followed by a division by world_size, or the AVG op if the backend supports it), perform a step of gradient descent, rinse and repeat. If we can use this, we should: it has the least amount of overhead, but it requires that the model parameters, gradients and optimizer states (plus activations) all fit in each device’s VRAM.
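Why averaging the gradients is "correct": with a mean loss and equally sized shards, the average of the per-rank gradients equals the full-batch gradient, so every replica applies the same update and stays identical. A plain-Python check on a hypothetical 1-D linear model (no torch; grad_mse and ddp_step are illustrative names):

```python
def grad_mse(w: float, xs: list, ys: list) -> float:
    """d/dw of mean((w * x - y) ** 2) for the 1-D linear model y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def ddp_step(w: float, xs: list, ys: list, world_size: int, lr: float) -> list:
    """One simulated DDP step: each rank computes a local gradient on its own shard
    (strided split, equal sizes assumed), the gradients are averaged as
    all_reduce(SUM) / world_size would do, and every replica applies the same SGD update."""
    local_grads = [grad_mse(w, xs[r::world_size], ys[r::world_size]) for r in range(world_size)]
    avg_grad = sum(local_grads) / world_size
    return [w - lr * avg_grad for _ in range(world_size)]  # one weight per replica


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 8.1, 9.8, 12.2, 13.9, 16.1]
replicas = ddp_step(0.5, xs, ys, world_size=4, lr=0.01)
full_batch = 0.5 - 0.01 * grad_mse(0.5, xs, ys)  # single-device SGD step on the full batch
print(replicas, full_batch)
```

If the shards had different sizes, a plain average of local means would weight samples unevenly, which is one more reason to shard the data evenly.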

Note

The difference between DDP and DP (torch.nn.DataParallel) is that DDP uses one process per device, which avoids the GIL, whereas DP uses threads within a single process. Do not use DP, only DDP.

import torch
import torch.distributed as dist

class SimpleDataParallelism():
    def __init__(self, model):
        self.model = model

        # Check that every replica starts from rank 0's parameters
        # (assumes every rank seeded its RNG identically before building the model)
        for param in model.parameters():
            rank_0_params = param.data.clone()
            dist.broadcast(rank_0_params, src=0)
            assert torch.equal(param.data, rank_0_params), "Parameters mismatch at initialization"

    def sync_grad(self):
        # Call after loss.backward() and before optimizer.step()
        for param in self.model.parameters():
            dist.all_reduce(param.grad, op=dist.ReduceOp.AVG) # only available on NCCL backend
            # eq.
            # dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            # param.grad /= dist.get_world_size()
Note

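The above is a toy implementation. In practice you do not wait for the backward pass to finish and then all_reduce every gradient, since the GPUs sit idle during that communication. Instead, computation and communication are interleaved: gradients are grouped into buckets, and each bucket is all-reduced asynchronously as soon as all of its gradients are ready, while the rest of the backward pass keeps running (this is what PyTorch’s DistributedDataParallel does).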
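A plain-Python simulation of the bucketing idea, loosely modelled on DistributedDataParallel (which exposes the bucket size as bucket_cap_mb); assign_buckets and simulate_backward are hypothetical helpers, and real DDP handles many more details. Buckets are built in reverse parameter order because the last layers' gradients are ready first during backward; each bucket's all_reduce is launched as soon as its last gradient is ready, while earlier gradients are still being computed.

```python
def assign_buckets(param_sizes: list, bucket_cap: int) -> list:
    """Group parameter indices into buckets in *reverse* registration order
    (the last layers' gradients are ready first in backward), closing a bucket
    once it holds at least bucket_cap elements."""
    buckets, current, current_size = [], [], 0
    for idx in reversed(range(len(param_sizes))):
        current.append(idx)
        current_size += param_sizes[idx]
        if current_size >= bucket_cap:
            buckets.append(current)
            current, current_size = [], 0
    if current:
        buckets.append(current)
    return buckets


def simulate_backward(param_sizes: list, bucket_cap: int) -> list:
    """Event log of a backward pass where each bucket's (async) all_reduce is launched
    as soon as its last gradient is ready, overlapping with the remaining backward."""
    pending = [set(b) for b in assign_buckets(param_sizes, bucket_cap)]
    events = []
    for idx in reversed(range(len(param_sizes))):  # grads become ready last-to-first
        events.append(f"grad {idx} ready")
        for b, remaining in enumerate(pending):
            if idx in remaining:
                remaining.discard(idx)
                if not remaining:
                    events.append(f"launch all_reduce(bucket {b})")
    events.append("wait on all buckets, then optimizer.step()")
    return events


for event in simulate_backward([4, 4, 4, 4], bucket_cap=8):
    print(event)
```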

ZeRO / FSDP

Zero Redundancy Optimizer (ZeRO) by DeepSpeed is a memory-optimization strategy that shards model states during training as a means of reducing peak memory. The core idea is that the optimizer states, gradients and/or model parameters are sharded, retrieved only when necessary for some computation, and then anything we do not use anymore is discarded.

Fully Sharded Data Parallelism (FSDP) is PyTorch’s implementation of ZeRO.

Links: Paper, Article, FSDP Paper, FSDP Doc

ZeRO-1

ZeRO stage 1 (aka. \(P_{os}\)) is the sharding/partitioning of optimizer states only: up to a 4x memory reduction for model states (approached as the number of ranks grows), with a communication volume of the same order as DP (the gradient all-reduce dominates).

Forward pass

  • Same as DP: each rank stores the full model parameters and runs the full forward pass.

Backward pass

  • Same as DP: each rank computes all gradients locally.
  • Same as DP: gradients are averaged across ranks via all_reduce.
Note

This can be a reduce_scatter too, since each rank only needs the averaged gradients for its own shard.

Optimizer step

  • Each rank holds the full parameters and full averaged gradients.
  • Each rank updates only the parameter shard corresponding to its shard of the optimizer state.
  • Updated parameter shards are then exchanged (all_gather) so all ranks end up with the full updated model.
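A single-process simulation of the ZeRO-1 step above, assuming SGD with momentum as the optimizer (the momentum buffer plays the role of the sharded optimizer state) and a flat parameter list whose length is divisible by world_size. shard_range and zero1_step are hypothetical helpers; the result should match a regular, unsharded momentum-SGD step.

```python
def shard_range(numel: int, world_size: int, rank: int) -> tuple:
    """[start, end) of the contiguous slice owned by `rank` (numel % world_size == 0 assumed)."""
    size = numel // world_size
    return rank * size, (rank + 1) * size


def zero1_step(params: list, grads_per_rank: list, momentum_shards: list,
               lr: float = 0.1, beta: float = 0.9) -> list:
    """One simulated ZeRO-1 step with SGD + momentum.

    params:          full parameters, replicated on every rank
    grads_per_rank:  grads_per_rank[r] = rank r's local (un-averaged) full gradient
    momentum_shards: momentum_shards[r] = the only optimizer state rank r stores (updated in place)
    """
    world_size, numel = len(grads_per_rank), len(params)
    # 1) all_reduce(AVG): every rank gets the full averaged gradient
    avg = [sum(g[i] for g in grads_per_rank) / world_size for i in range(numel)]
    updated_shards = []
    for r in range(world_size):
        start, end = shard_range(numel, world_size, r)
        # 2) rank r updates only its momentum shard and the matching parameter shard
        m = [beta * m_i + g_i for m_i, g_i in zip(momentum_shards[r], avg[start:end])]
        momentum_shards[r] = m
        updated_shards.append([p - lr * m_i for p, m_i in zip(params[start:end], m)])
    # 3) all_gather: concatenate the shards so every rank has the full updated parameters
    return sum(updated_shards, [])


momentum = [[0.0, 0.0], [0.0, 0.0]]  # 4 params, 2 ranks -> 2 momentum values per rank
print(zero1_step([1.0, 2.0, 3.0, 4.0], [[0.1, 0.2, 0.3, 0.4], [0.3, 0.2, 0.1, 0.0]], momentum))
```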

ZeRO-2

ZeRO stage 2 (aka. \(P_{os} + P_g\)) is the sharding/partitioning of optimizer states and gradients: up to an 8x memory reduction for model states (again approached as the number of ranks grows), with a communication volume of the same order as DP and ZeRO-1.

Forward pass

  • Same as DP: each rank stores the full model parameters and runs the full forward pass.

Backward pass

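  • Each rank computes gradients locally. Done naively (one reduce_scatter after the whole backward pass), the full gradients are temporarily materialized on every rank, so peak memory stays close to ZeRO-1; in practice gradients are reduce-scattered bucket by bucket as soon as they are ready, so only part of the full gradients exists at any time, and only the gradient shard is kept persistently.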
  • Gradients are averaged and sharded across ranks (reduce_scatter) — think averaging + sending to each rank the shard of the gradients that corresponds exactly to its optimizer state

Optimizer step

  • Each rank holds the full parameters.
  • Each rank holds only the averaged gradients corresponding to its shard of the optimizer state.
  • Each rank updates only the parameter shard corresponding to its shard of the optimizer state.
  • Updated parameter shards are then exchanged (all_gather) so all ranks end up with the full updated model.
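A sketch of the gradient reduce_scatter in ZeRO-2, assuming each rank's gradients live in one flat buffer that is zero-padded to a multiple of world_size and split into contiguous shards (an assumed layout for illustration; padded_shards and reduce_scatter_avg are hypothetical helpers). Rank r ends up holding only the averaged gradients of shard r, exactly the slice its optimizer state covers.

```python
def padded_shards(flat: list, world_size: int) -> list:
    """Zero-pad a flat buffer to a multiple of world_size and split it into equal contiguous shards."""
    shard = -(-len(flat) // world_size)  # ceil division
    padded = flat + [0.0] * (shard * world_size - len(flat))
    return [padded[r * shard:(r + 1) * shard] for r in range(world_size)]


def reduce_scatter_avg(local_grads: list) -> list:
    """Simulated reduce_scatter(AVG): rank dst receives the average over all ranks
    of shard dst of their flat gradient buffers, and nothing else."""
    world_size = len(local_grads)
    shards = [padded_shards(g, world_size) for g in local_grads]  # shards[src][dst]
    return [
        [sum(shards[src][dst][i] for src in range(world_size)) / world_size
         for i in range(len(shards[0][dst]))]
        for dst in range(world_size)
    ]


# 5 gradient values on 2 ranks -> padded to 6, 3 values per shard
print(reduce_scatter_avg([[1.0, 2.0, 3.0, 4.0, 5.0], [3.0, 2.0, 1.0, 0.0, -1.0]]))
```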

ZeRO-3

ZeRO stage 3 (aka. \(P_{os} + P_g + P_p\)) is the sharding/partitioning of optimizer states, gradients and model parameters. The memory reduction scales linearly with our parallelism degree, at the cost of a larger communication volume (≈50% more than DP/ZeRO-1/2): parameters must be all-gathered before every computation that needs them (in both the forward and backward passes), and gradients reduce-scattered afterwards.

(Assuming FP16 params and FP32 optimizer states)
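The memory figures for all three stages can be checked with the accounting from the ZeRO paper, assuming mixed-precision Adam: per parameter, 2 bytes of FP16 params, 2 bytes of FP16 gradients and K = 12 bytes of FP32 optimizer states (master params, momentum, variance). model_state_bytes_per_rank is a hypothetical helper and counts model states only (no activations, buffers or fragmentation). With 7.5B parameters on 64 ranks (the example used in the ZeRO paper) it gives 120 GB per rank for plain DP versus about 31.4, 16.6 and 1.9 GB for stages 1, 2 and 3.

```python
def model_state_bytes_per_rank(num_params: float, world_size: int, stage: int, k: int = 12) -> float:
    """Per-rank bytes of model states (params + grads + optimizer states).

    stage 0 = plain DP, 1 = P_os, 2 = P_os + P_g, 3 = P_os + P_g + P_p.
    Assumes FP16 params/grads (2 bytes each) and k bytes of optimizer state per param.
    """
    params = 2 * num_params
    grads = 2 * num_params
    optim = k * num_params
    if stage >= 1:
        optim /= world_size   # ZeRO-1: shard optimizer states
    if stage >= 2:
        grads /= world_size   # ZeRO-2: also shard gradients
    if stage >= 3:
        params /= world_size  # ZeRO-3: also shard parameters
    return params + grads + optim


for stage in range(4):
    gb = model_state_bytes_per_rank(7.5e9, world_size=64, stage=stage) / 1e9
    print(f"stage {stage}: {gb:.1f} GB per rank")
```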

Forward pass

  • Each rank stores its shard of the model parameters.
  • Whenever a parameter is needed for computation, it is materialized (all_gather of the shards held by every rank)
  • The computation is done
  • The gathered full parameter is released/flushed (del / = None) on every rank, so each rank is left with only its own shard
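A memory trace for the forward pass above on one rank, counted in parameter elements and assuming each layer's size is divisible by world_size (zero3_forward_memory is a hypothetical helper). Only one layer is fully materialized at a time, so the peak is the resident shards plus the largest layer rather than the whole model.

```python
def zero3_forward_memory(layer_sizes: list, world_size: int) -> list:
    """Per-rank memory trace (in parameter elements) during a ZeRO-3 style forward pass.

    Each rank permanently stores 1/world_size of every layer (sizes assumed divisible);
    a layer's full parameters only exist between its all_gather and its release.
    """
    resident = sum(size // world_size for size in layer_sizes)  # own shards, always kept
    trace = []
    for size in layer_sizes:
        trace.append(resident + size)  # after all_gather: own shards + this full layer
        # ... run this layer's forward with the full parameters ...
        trace.append(resident)         # after release (del / = None): shards only
    return trace


trace = zero3_forward_memory([400, 800, 400], world_size=4)
print(trace, "peak:", max(trace), "vs. full replica:", 400 + 800 + 400)
```

Real implementations typically prefetch the next layer's all_gather to overlap it with compute, which raises the peak somewhat.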

Backward pass

  • Each rank runs the full backward pass, but parameters must again be all-gathered on demand (and released afterwards).
  • Gradients are produced during backprop, then immediately reduce-scattered so each rank keeps only the averaged gradient shard matching its parameter shard.

Optimizer step

  • Each rank holds only the parameters’ shard corresponding to its shard of the gradients and optimizer state.
  • Each rank holds only the averaged gradients’ shard corresponding to its shard of the parameters and optimizer state.
  • Each rank updates only the parameter shard corresponding to its shard of the optimizer state.
  • No all_gather is needed after the step: each rank keeps only its updated parameter shard, and the full parameters are gathered on demand again during the next forward pass.

Terminology

  • device: Hardware unit — GPU "cuda:0", CPU "cpu" etc. that’s where tensors and computations live
  • node: Physical machine/server (or VPS, whatever) that has 1+ devices
  • process: Python process/worker, executing a copy of the code/script — often on a single device (GPU)
  • rank: ID of a process — often that maps to a single device. rank without qualifiers is global rank
  • world: Set of all processes part of our current distributed job
  • global rank, world rank: rank across all processes/nodes. note: collective operations take the global rank (or just rank) as input for src/dst
  • local rank: rank within a single node (node not group). note: device takes the local rank "cuda:{local_rank}"
  • group: subset of processes (1+ nodes) that we’ve grouped for sub-communications. note: we still use global rank for intra-group communication.
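A small sketch of how these ids relate, assuming a launcher like torchrun, which exports RANK, LOCAL_RANK and WORLD_SIZE as environment variables for each process it starts. The helpers are hypothetical, and global_rank assumes homogeneous nodes (the same number of processes per node).

```python
import os


def ranks_from_env(env=os.environ) -> dict:
    """Read the per-process ids a launcher like torchrun exports as env vars."""
    return {
        "rank": int(env["RANK"]),              # global rank, used for src/dst
        "local_rank": int(env["LOCAL_RANK"]),  # rank within the node, used for the device
        "world_size": int(env["WORLD_SIZE"]),  # number of processes in the job
    }


def global_rank(node_rank: int, local_rank: int, procs_per_node: int) -> int:
    """Global rank on homogeneous nodes (hypothetical helper)."""
    return node_rank * procs_per_node + local_rank


# 2 nodes x 4 GPUs: print which global rank drives which local device
for node in range(2):
    for local in range(4):
        print(f"node {node}, cuda:{local} -> rank {global_rank(node, local, 4)}")
```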

Resources / References / Bookmarks