Distributed Training

mle, python

Distributed training study notes and algorithms.

Author: Theo POMIES
Published: September 2, 2025
Modified: September 5, 2025

Functions / Methods

Basics (Point-to-Point)

send and recv send or receive a tensor synchronously, to/from a single rank.

And their async counterparts, isend and irecv.
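
The snippets in these notes assume torch and torch.distributed as dist are imported and that the process group is already initialized. A minimal setup sketch, assuming the gloo backend (so everything runs on CPU tensors) and a torchrun launch:

import torch
import torch.distributed as dist

# torchrun sets RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT for every process it
# spawns, so init_process_group can pick everything up from the environment.
dist.init_process_group(backend="gloo")
rank = dist.get_rank()
world_size = dist.get_world_size()

Launch with something like torchrun --nproc_per_node=2 script.py; torchrun spawns one process per rank.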

rank = dist.get_rank()
tensor = torch.arange(3) if rank == 0 else torch.zeros(3)

print(f"Before: {tensor}")

if rank == 0:
    request = dist.isend(tensor, 1)
    ...
    # can do something else, like more sends for example!
    ...
    request.wait() # now block until it's been fulfilled
elif rank == 1:
    dist.recv(tensor, 0) # recv is synchronous, so it will block until tensor is fully received

print(f"After: {tensor}")
========== rank 0 ==========
Before: tensor([0, 1, 2])
After: tensor([0, 1, 2])

========== rank 1 ==========
Before: tensor([0., 0., 0.])
After: tensor([0, 1, 2])

Collective Operations

Collective operations allow communication (data-transfer) from All->Point, Point->All and All->All.

Point->All

Broadcast

Broadcast (torch.distributed.broadcast(tensor, src, ...)) allows a rank to broadcast a tensor to the whole group.

tensor = torch.arange(3) if rank == 0 else torch.zeros(3)

print(f"Before: {tensor}")

dist.broadcast(tensor, src=0)

print(f"After: {tensor}")
========== rank 0 ==========
Before: tensor([0, 1, 2])
After: tensor([0, 1, 2])

========== rank 1 ==========
Before: tensor([0., 0., 0.])
After: tensor([0, 1, 2])

========== rank 2 ==========
Before: tensor([0., 0., 0.])
After: tensor([0, 1, 2])
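
An aside, not shown in the output above: for small picklable Python objects (configs, seeds, ...) there is also torch.distributed.broadcast_object_list, which pickles the objects and broadcasts them. A quick sketch:

# Every rank passes a list of the same length; entries on non-src ranks are
# placeholders that get overwritten with rank 0's objects.
objects = [{"lr": 3e-4, "seed": 42}] if rank == 0 else [None]
dist.broadcast_object_list(objects, src=0)
# objects[0] is now rank 0's dict on every rank
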
Scatter

Scatter (torch.distributed.scatter(tensor, scatter_list, src, ...)) allows us to scatter — split and broadcast different chunks of — a tensor from a rank to the whole group.

tensor = torch.zeros(3)
scatter_list = [torch.arange(3 * i, 3 * i + 3) if rank == 0 else torch.zeros(3) for i in range(world_size)]

print(f"Scatter list: {scatter_list}")
print(f"Before: {tensor}")

dist.scatter(tensor, scatter_list if rank == 0 else None, src=0) # scatter_list may only be passed on the src rank

print(f"After: {tensor}")
========== rank 0 ==========
Scatter list: [tensor([0, 1, 2]), tensor([3, 4, 5])]
Before: tensor([0., 0., 0.])
After: tensor([0, 1, 2])

========== rank 1 ==========
Scatter list: [tensor([0., 0., 0.]), tensor([0., 0., 0.])]
Before: tensor([0., 0., 0.])
After: tensor([3, 4, 5])

All->Point

Reduce

Reduce (torch.distributed.reduce(tensor, dst, op, ...)) performs a reduction operation (N->1, e.g. sum, max, min, prod, …); the dst rank receives the result.

tensor = torch.arange(3) + rank * 3

print(f"Before: {tensor}")

dist.reduce(tensor, dst=0, op=dist.ReduceOp.SUM)

print(f"After: {tensor}")
========== rank 0 ==========
Before: tensor([0, 1, 2])
After: tensor([3, 5, 7])

========== rank 1 ==========
Before: tensor([3, 4, 5])
After: tensor([3, 4, 5])
Gather

Gather (torch.distributed.gather(tensor, gather_list, dst, ...)) gathers (pulls) a tensor of the same size from every rank and stores them in a list on a single rank; gather_list only needs to be provided on dst.

tensor = torch.arange(3) + rank * 3
gather_list = [torch.zeros(3) for _ in range(world_size)]

print(f"Before: {tensor}")
print(f"Before: {gather_list}")

dist.gather(tensor, gather_list if rank == 0 else None, dst=0) # gather_list may only be passed on the dst rank

print(f"After: {gather_list}")
========== rank 0 ==========
Before: tensor([0, 1, 2])
Before: [tensor([0., 0., 0.]), tensor([0., 0., 0.])]
After: [tensor([0, 1, 2]), tensor([3, 4, 5])]

========== rank 1 ==========
Before: tensor([3, 4, 5])
Before: [tensor([0., 0., 0.]), tensor([0., 0., 0.])]
After: [tensor([0., 0., 0.]), tensor([0., 0., 0.])]

All->All

All-Reduce

All-Reduce (torch.distributed.all_reduce(tensor, op, ...)) performs a reduction operation, like reduce, but every rank receives the result rather than a single dst rank. Think of it as reduce + broadcast, though real implementations use optimized schedules like ring all-reduce.

tensor = torch.arange(3) + rank * 3

print(f"Before: {tensor}")

dist.all_reduce(tensor, op=dist.ReduceOp.SUM)

print(f"After: {tensor}")
========== rank 0 ==========
Before: tensor([0, 1, 2])
After: tensor([3, 5, 7])

========== rank 1 ==========
Before: tensor([3, 4, 5])
After: tensor([3, 5, 7])
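
To make the reduce + broadcast intuition concrete, here is a naive composition of the primitives above (naive_all_reduce is just an illustrative helper, not how the backends implement it; they use bandwidth-optimal schedules such as ring all-reduce instead):

def naive_all_reduce(tensor, op=dist.ReduceOp.SUM):
    dist.reduce(tensor, dst=0, op=op)  # every rank contributes; rank 0 ends up with the result
    dist.broadcast(tensor, src=0)      # rank 0 shares the reduced tensor with everyone
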
All-Gather

All-Gather (torch.distributed.all_gather(tensor_list, tensor, ...)) gathers (pulls) a tensor of the same size from every rank and stores them in a list on every rank. Think of it as running gather on all ranks.

tensor = torch.arange(3) + rank * 3
gather_list = [torch.zeros(3) for _ in range(world_size)]

print(f"Before: {tensor}")
print(f"Before: {gather_list}")

dist.all_gather(gather_list, tensor) # note: the output list comes first

print(f"After: {gather_list}")
========== rank 0 ==========
Before: tensor([0, 1, 2])
Before: [tensor([0., 0., 0.]), tensor([0., 0., 0.])]
After: [tensor([0, 1, 2]), tensor([3, 4, 5])]

========== rank 1 ==========
Before: tensor([3, 4, 5])
Before: [tensor([0., 0., 0.]), tensor([0., 0., 0.])]
After: [tensor([0, 1, 2]), tensor([3, 4, 5])]
Reduce-Scatter

Reduce-Scatter (torch.distributed.reduce_scatter(output_tensor, input_list, op, ...)) performs a reduction operation — like the other reduce functions — and scatters the resulting tensor. Think of it as reduce + scatter. Note: it needs len(input_list) == world_size and every tensor in input_list to have the same shape as output_tensor.

tensor = torch.zeros(2)
scatter_list = [torch.tensor([(rank + 1) * i for i in range(1, 3)]) ** (j + 1) for j in range(world_size)] # len == world_size (3 here)

print(f"Before: {tensor}")
print(f"Before: {scatter_list}")

dist.reduce_scatter(tensor, scatter_list, op=dist.ReduceOp.SUM)

print(f"After: {tensor}")
========== rank 0 ==========
Before: tensor([0., 0.])
Before: [tensor([1, 2]), tensor([1, 4]), tensor([1, 8])]
After: tensor([ 6, 12])

========== rank 1 ==========
Before: tensor([0., 0.])
Before: [tensor([2, 4]), tensor([ 4, 16]), tensor([ 8, 64])]
After: tensor([14, 56])

========== rank 2 ==========
Before: tensor([0., 0.])
Before: [tensor([3, 6]), tensor([ 9, 36]), tensor([ 27, 216])]
After: tensor([ 36, 288])

Other

Barrier

Barrier (torch.distributed.barrier(...)) synchronizes all ranks (processes): each rank blocks until every rank has reached the barrier (think of .join() for threads or processes).
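
A common usage sketch (prepare_dataset is a hypothetical helper): let rank 0 do some one-time work while every other rank waits for it at the barrier.

if dist.get_rank() == 0:
    prepare_dataset()  # hypothetical: e.g. download / preprocess the data once
dist.barrier()  # every rank blocks here until all ranks (including rank 0) arrive
# past this point, all ranks can safely assume the data is ready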

Algorithms / Techniques

DDP — Distributed Data Parallelism

Distributed Data Parallelism is a type of parallelism where each rank loads a copy (replica) of the model; after each optimizer step all ranks hold exactly the same parameters. Each rank then trains on a different mini-batch (hence the importance of data sharding). We average the gradients across ranks (all_reduce with SUM, then divide by world_size), perform a step of gradient descent, rinse and repeat. If we can use this, we should: it has the least overhead, but it requires that the model plus optimizer state fit in a single device's VRAM.

Note: the difference between DDP and DP is that DDP uses processes (avoiding the GIL) while DP uses threads. Do not use DP; only use DDP.

class SimpleDataParallelism:
    def __init__(self, model):
        self.model = model

        # Sanity check: every rank must start from identical parameters (same replica).
        for param in self.model.parameters():
            rank_0_param = param.data.clone()
            dist.broadcast(rank_0_param, src=0)
            assert torch.equal(param.data, rank_0_param), "Parameters mismatch at initialization"

    def sync_grad(self):
        # Average gradients across ranks so every replica applies the same update.
        for param in self.model.parameters():
            if param.grad is None:
                continue
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= dist.get_world_size()
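
A sketch of how this wrapper fits into a training loop; model, optimizer, loss_fn and a sharded dataloader are assumed to already exist on this rank:

ddp = SimpleDataParallelism(model)

for batch, target in dataloader:
    loss = loss_fn(model(batch), target)
    loss.backward()        # each rank computes gradients on its own mini-batch
    ddp.sync_grad()        # average the gradients across ranks
    optimizer.step()       # identical update on every rank, so the replicas stay in sync
    optimizer.zero_grad()

In practice you would use torch.nn.parallel.DistributedDataParallel, which additionally overlaps gradient synchronization with the backward pass.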

Data Sharding

Data Sharding is the process of sharding (splitting) the dataset / dataloader so that each rank only pulls its own unique mini-batches of the training data. This avoids duplicates and is more communication / memory efficient than duplicating the same full dataset on every rank. To do this with torch, set up the DataLoader with sampler=[instance of DistributedSampler], as sketched below.
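
A minimal sketch with torch's DistributedSampler (dataset and num_epochs are assumed to exist; the batch size is arbitrary):

from torch.utils.data import DataLoader, DistributedSampler

# The sampler hands each rank a disjoint slice of the dataset's indices,
# so ranks don't pull duplicate mini-batches.
sampler = DistributedSampler(dataset, shuffle=True)
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  # reshuffle each epoch, consistently across ranks
    for batch in dataloader:
        ...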

Notes

Terminology

  • device: Hardware unit — GPU "cuda:0", CPU "cpu", etc.; that's where tensors and computations live
  • node: Physical machine/server (or VPS, whatever) that has 1+ devices
  • process: Python process/worker executing a copy of the code/script — often driving a single device (GPU)
  • rank: ID of a process — one that often maps to a single device. rank without qualifiers means the global rank
  • world: Set of all processes that are part of our current distributed job
  • global rank, world rank: rank across all processes/nodes. note: collective operations take the global rank (or just rank) as input for src/dst
  • local rank: rank within a single node (node, not group). note: device takes the local rank, "cuda:{local_rank}" (see the sketch after this list)
  • group: subset of processes (1+ nodes) that we’ve grouped for sub-communications. note: we still use the global rank for intra-group communication.
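
A sketch of the usual rank / local rank / device wiring, assuming a torchrun launch (torchrun sets RANK, LOCAL_RANK and WORLD_SIZE in each process's environment) and one GPU per process:

import os
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")     # one entry in the world per process
local_rank = int(os.environ["LOCAL_RANK"])  # this process's index on its own node
torch.cuda.set_device(local_rank)           # pin this process to "cuda:{local_rank}"
device = torch.device(f"cuda:{local_rank}")

print(f"global rank {dist.get_rank()} / world size {dist.get_world_size()} on {device}")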