Distributed Training
Functions / Methods
Basics (Point-Point)
Use send and recv to send or receive a tensor synchronously from/to a single rank. Their async counterparts are isend and irecv.
rank = dist.get_rank()
tensor = torch.arange(3) if rank == 0 else torch.zeros(3)
print(f"Before: {tensor}")
if rank == 0:
    request = dist.isend(tensor, 1)
    # can do something else, like more sends for example!
    # now block until it's been fulfilled
    request.wait()
elif rank == 1:
    dist.recv(tensor, 0)  # recv is synchronous, so it will block until tensor is fully received
print(f"After: {tensor}")
========== rank 0 ==========
Before: tensor([0, 1, 2])
After: tensor([0, 1, 2])
========== rank 1 ==========
Before: tensor([0., 0., 0.])
After: tensor([0, 1, 2])
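For the fully blocking pair, a minimal sketch under the same two-rank setup as above, using send/recv instead of isend/recv:
if rank == 0:
    dist.send(tensor, dst=1)  # blocks until the send is matched by a recv
elif rank == 1:
    dist.recv(tensor, src=0)  # blocks until the tensor has arrived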
Collective Operations
Collective operations allow communication (data-transfer) from All->Point, Point->All and All->All.
Point->All
Broadcast
Broadcast (torch.distributed.broadcast(tensor, src, ...)) allows a rank to broadcast a tensor to the whole group.
tensor = torch.arange(3) if rank == 0 else torch.zeros(3)
print(f"Before: {tensor}")
dist.broadcast(tensor, src=0)
print(f"After: {tensor}")
========== rank 0 ==========
Before: tensor([0, 1, 2])
After: tensor([0, 1, 2])
========== rank 1 ==========
Before: tensor([0., 0., 0.])
After: tensor([0, 1, 2])
========== rank 2 ==========
Before: tensor([0., 0., 0.])
After: tensor([0, 1, 2])
Scatter
Scatter (torch.distributed.scatter(tensor, scatter_list, src, ...)) allows us to scatter — split and broadcast different chunks of — a tensor from a rank to the whole group.
tensor = torch.zeros(3)
scatter_list = [torch.arange(3 * i, 3 * i + 3) if rank == 0 else torch.zeros(3) for i in range(world_size)]
print(f"Scatter list: {scatter_list}")
print(f"Before: {tensor}")
dist.scatter(tensor, scatter_list, src=0)
print(f"After: {tensor}")
========== rank 0 ==========
Scatter list: [tensor([0, 1, 2]), tensor([3, 4, 5])]
Before: tensor([0., 0., 0.])
After: tensor([0, 1, 2])
========== rank 1 ==========
Scatter list: [tensor([0., 0., 0.]), tensor([0., 0., 0.])]
Before: tensor([0., 0., 0.])
After: tensor([3, 4, 5])
All->Point
Reduce
Reduce (torch.distributed.reduce(tensor, dst, op, ...)) performs a reduction operation (N->1, e.g. sum, max, min, prod, …) and the dst rank receives the result.
tensor = torch.arange(3) + rank * 3
print(f"Before: {tensor}")
dist.reduce(tensor, dst=0, op=dist.ReduceOp.SUM)
print(f"After: {tensor}")
========== rank 0 ==========
Before: tensor([0, 1, 2])
After: tensor([3, 5, 7])
========== rank 1 ==========
Before: tensor([3, 4, 5])
After: tensor([3, 4, 5])
Gather
Gather (torch.distributed.gather(tensor, gather_list, dst, ...)) gathers — pulls — a tensor of the same size from every rank and stores them in a list on a single rank.
tensor = torch.arange(3) + rank * 3
gather_list = [torch.zeros(3) for _ in range(world_size)]
print(f"Before: {tensor}")
print(f"Before: {gather_list}")
dist.gather(tensor, gather_list, dst=0)
print(f"After: {gather_list}")
========== rank 0 ==========
Before: tensor([0, 1, 2])
Before: [tensor([0., 0., 0.]), tensor([0., 0., 0.])]
After: [tensor([0, 1, 2]), tensor([3, 4, 5])]
========== rank 1 ==========
Before: tensor([3, 4, 5])
Before: [tensor([0., 0., 0.]), tensor([0., 0., 0.])]
After: [tensor([0., 0., 0.]), tensor([0., 0., 0.])]
All->All
All-Reduce
All-Reduce (torch.distributed.all_reduce(tensor, op, ...)) performs a reduction operation, like reduce, but every rank receives the result — rather than a single one as with reduce. Think of it as reduce + broadcast — though it is optimized by techniques like ring-reduce.
tensor = torch.arange(3) + rank * 3
print(f"Before: {tensor}")
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
print(f"After: {tensor}")
========== rank 0 ==========
Before: tensor([0, 1, 2])
After: tensor([3, 5, 7])
========== rank 1 ==========
Before: tensor([3, 4, 5])
After: tensor([3, 5, 7])
All-Gather
All-Gather (torch.distributed.all_gather(gather_list, tensor, ...)) gathers — pulls — a tensor of the same size from every rank and stores them in a list on every rank. Think of it as running gather on all ranks.
tensor = torch.arange(3) + rank * 3
gather_list = [torch.zeros(3) for _ in range(world_size)]
print(f"Before: {tensor}")
print(f"Before: {gather_list}")
dist.all_gather(gather_list, tensor)
print(f"After: {gather_list}")
========== rank 0 ==========
Before: tensor([0, 1, 2])
Before: [tensor([0., 0., 0.]), tensor([0., 0., 0.])]
After: [tensor([0, 1, 2]), tensor([3, 4, 5])]
========== rank 1 ==========
Before: tensor([3, 4, 5])
Before: [tensor([0., 0., 0.]), tensor([0., 0., 0.])]
After: [tensor([0, 1, 2]), tensor([3, 4, 5])]
Reduce-Scatter
Reduce-Scatter (torch.distributed.reduce_scatter(output_tensor, input_list, op, ...)) performs a reduction operation — like the other reduce functions — and scatters the resulting tensor. Think of it like reduce + scatter. Note: it needs len(input_list) == world_size and every tensor in input_list to have the same shape as output_tensor.
tensor = torch.zeros(2)
scatter_list = [torch.tensor([(rank + 1) * i for i in range(1, 3)]) ** (j + 1) for j in range(3)]
print(f"Before: {tensor}")
print(f"Before: {scatter_list}")
dist.reduce_scatter(tensor, scatter_list, op=dist.ReduceOp.SUM)
print(f"After: {tensor}")
========== rank 0 ==========
Before: tensor([0., 0.])
Before: [tensor([1, 2]), tensor([1, 4]), tensor([1, 8])]
After: tensor([ 6, 12])
========== rank 1 ==========
Before: tensor([0., 0.])
Before: [tensor([2, 4]), tensor([ 4, 16]), tensor([ 8, 64])]
After: tensor([14, 56])
========== rank 2 ==========
Before: tensor([0., 0.])
Before: [tensor([3, 6]), tensor([ 9, 36]), tensor([ 27, 216])]
After: tensor([ 36, 288])
Other
Barrier
Barrier (torch.distributed.barrier(...)) synchronizes all ranks (processes), waiting for all of them to reach the barrier. (Think of .join() for threads or processes.)
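A common use is letting one rank do slow, one-off work (e.g. downloading data) while the others wait. A minimal sketch, where download_data and load_data are hypothetical helpers:
if dist.get_rank() == 0:
    download_data()   # hypothetical helper: only rank 0 does the slow work
dist.barrier()        # every rank waits here until rank 0 arrives too
data = load_data()    # hypothetical helper: now safe to read on every rank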
Algorithms / Techniques
DDP — Distributed Data Parallelism
Distributed Data Parallelism is a type of parallelism where each rank loads a copy — a replica — of the model; after every optimizer step all ranks hold exactly the same parameters. Each rank then trains on a different mini-batch (hence the importance of data sharding). We then average the gradients (all_reduce sum + division by world_size), perform a step of gradient descent, rinse and repeat. If we can use this, we should: it has the least overhead, but it requires that the model + optimizer states fit in a single device's VRAM.
Note: the difference between DDP and DP is that DDP uses processes to avoid the GIL and DP uses threads. Do not use DP, only DDP.
class SimpleDataParallelism():
    def __init__(self, model):
        self.model = model
        # check that every rank starts from the same parameters as rank 0
        for param in model.parameters():
            rank_0_params = param.data.clone()
            dist.broadcast(rank_0_params, src=0)
            assert torch.equal(param.data, rank_0_params), "Parameters mismatch at initialization"

    def sync_grad(self):
        # average gradients across all ranks
        for param in self.model.parameters():
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= dist.get_world_size()
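A minimal sketch of how this helper might sit in a training loop (model, optimizer, loss_fn and dataloader are assumed to exist and the process group to be initialized):
ddp = SimpleDataParallelism(model)
for batch, target in dataloader:      # each rank sees its own shard (see Data Sharding)
    loss = loss_fn(model(batch), target)
    loss.backward()                   # local gradients
    ddp.sync_grad()                   # average gradients across ranks
    optimizer.step()                  # identical update on every rank
    optimizer.zero_grad()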
Data Sharding is the process of sharding — splitting — the dataset / dataloader so that each rank only pulls its own unique mini-batches of the training data. This avoids duplicates and is more communication / memory efficient than duplicating the full dataset on every rank. To do this with torch, set up the DataLoader with sampler=[instance of DistributedSampler].
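A minimal sketch, assuming my_dataset is any map-style Dataset and the process group is already initialized:
from torch.utils.data import DataLoader, DistributedSampler

sampler = DistributedSampler(my_dataset)                         # splits indices across ranks
loader = DataLoader(my_dataset, batch_size=32, sampler=sampler)
for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  # use a different shuffling order each epoch
    for batch in loader:
        ...                   # train on this rank's shard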
Notes
Terminology
- device: Hardware unit — GPU "cuda:0", CPU "cpu", etc. That's where tensors and computations live
- node: Physical machine/server (or VPS, whatever) that has 1+ devices
- process: Python process/worker executing a copy of the code/script — often on a single device (GPU)
- rank: ID of a process — often maps to a single device. rank without qualifiers is global rank
- world: Set of all processes that are part of our current distributed job
- global rank, world rank: rank across all processes/nodes. note: collective operations take the global rank (or just rank) as input for src/dst
- local rank: rank within a single node (node, not group). note: device takes the local rank, e.g. "cuda:{local_rank}"
- group: subset of processes (1+ nodes) that we’ve grouped for sub-communications. note: we still use global rank for intra-group communication.
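A minimal sketch of how these values are typically read, assuming the job was launched with torchrun (which sets the LOCAL_RANK environment variable):
import os
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")      # or "gloo" for CPU-only
global_rank = dist.get_rank()                # rank across the whole world
world_size = dist.get_world_size()           # number of processes in the job
local_rank = int(os.environ["LOCAL_RANK"])   # rank within this node, set by torchrun
device = torch.device(f"cuda:{local_rank}")  # each process drives its own GPU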