Async Tensor Parallelism with Less Wright + Efficient Strategies for Distributed Inference with Marc Sun
Async TP
Yesterday we learned a bit more about TP, and today, thanks to Less Wright, we learned about Async TP. The idea is to decompose our operations (communication and computation) into finer-grained ones: for example, instead of one big matmul that has to wait for the full gathered tensor, we run smaller matmuls and receive the sharded input slice by slice. We also use two streams, a compute stream running the matmuls (and other kernels) and a communication stream, so that both can run in parallel and no cycles are wasted.
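A minimal sketch of that decomposition, written in plain Python rather than PyTorch so it runs anywhere. A background thread stands in for the communication stream and hands over input shards one at a time in ring all-gather order, while the main thread (the "compute stream") multiplies each shard by the rank-local weight as soon as it arrives. The function names are hypothetical, and Python threads only interleave under the GIL, so this illustrates the data flow and ordering, not real GPU overlap (which relies on separate CUDA streams).

```python
import queue
import threading


def matmul(a, b):
    """Plain-Python matrix multiply: (m x k) @ (k x n)."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def comm_stream(shards, rank, out_q):
    """Simulated communication stream (hypothetical).

    Delivers one input shard per step in ring all-gather order: at step s,
    this rank receives the shard that originated on rank (rank - s) mod n.
    Step 0 is the rank's own, already-local shard.
    """
    n = len(shards)
    for step in range(n):
        src = (rank - step) % n
        out_q.put((src, shards[src]))
    out_q.put(None)  # sentinel: all shards delivered


def async_allgather_matmul(shards, weight, rank):
    """Multiply each shard as soon as it arrives instead of waiting for the full gather."""
    q = queue.Queue()
    comm = threading.Thread(target=comm_stream, args=(shards, rank, q))
    comm.start()
    partial = [None] * len(shards)
    while (item := q.get()) is not None:
        src, shard = item
        partial[src] = matmul(shard, weight)  # small matmul on one slice
    comm.join()
    # Reassemble rows in global shard order so the result matches the unsplit matmul.
    return [row for chunk in partial for row in chunk]
```

Whatever rank we start from, the result equals gathering all shards first and doing one big matmul; only the schedule changes.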
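However, matmuls suffer from "wave quantization" (see this article): a GPU processes output tiles in waves of one tile per SM, so when the tile count is not a multiple of the SM count, the last wave leaves some SMs idle. Because we split the work to interleave compute and comms, we end up with a lot of small matmuls and pay this partial-wave tail for each of them. One solution is to swap the roles of the streams after the first computations, so that the compute stream becomes the communication stream and vice versa.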
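A back-of-the-envelope model of why splitting hurts, under simplifying assumptions: 128×128 output tiles, one tile per SM per wave, every wave costing the same time, and 132 SMs (an H100 SXM). Real kernels pick tile shapes and scheduling differently, so the numbers are illustrative only. The example splits an M=2048, N=1280 matmul into 4 chunks along M, as a TP degree of 4 might.

```python
import math


def num_tiles(m, n, tile_m=128, tile_n=128):
    """Output tiles for an (m x n) result; partial tiles still count as a full tile."""
    return math.ceil(m / tile_m) * math.ceil(n / tile_n)


def num_waves(m, n, sms=132, tile_m=128, tile_n=128):
    """Waves needed if each SM computes one output tile per wave (simplified model)."""
    return math.ceil(num_tiles(m, n, tile_m, tile_n) / sms)


def wave_efficiency(m, n, sms=132, tile_m=128, tile_n=128):
    """Fraction of SM-wave slots doing useful work."""
    return num_tiles(m, n, tile_m, tile_n) / (num_waves(m, n, sms, tile_m, tile_n) * sms)


M, N, TP = 2048, 1280, 4
full = num_waves(M, N)                  # 160 tiles fit in 2 waves
chunked = TP * num_waves(M // TP, N)    # each chunk: 40 tiles, still a whole wave
print(full, chunked)  # prints "2 4"
```

In this model the chunked version needs twice as many waves for the same total work, because each small matmul ends in its own mostly idle wave.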
Efficient Strategies for Distributed Inference
Then Marc Sun gave us a very thorough talk on optimizing inference. This Inside vLLM article and this Accelerating PyTorch inference article should cover most of the topics. We discussed the prefill vs. decode phases, KV caching, PagedAttention, torch.compile, quantization, speculative decoding, continuous batching, prefix caching, and TP/PP/DP.
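To make the prefill vs. decode split and KV caching concrete, here is a toy single-head attention layer in plain Python. The random projection matrices and input vectors are hypothetical stand-ins for a real model, and there is no token sampling. Prefill projects the whole prompt at once and fills the cache; each decode step projects only the new token and reuses the cached keys and values instead of recomputing them.

```python
import math
import random


def matvec(w, x):
    """Multiply a (d x d) matrix, stored as a list of rows, by a length-d vector."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def attend(q, keys, values):
    """Scaled dot-product attention of one query over the given keys/values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    peak = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


class ToyAttention:
    """Hypothetical single-head attention layer with random projections."""

    def __init__(self, d, seed=0):
        rng = random.Random(seed)

        def rand_matrix():
            return [[rng.gauss(0, 1) for _ in range(d)] for _ in range(d)]

        self.wq, self.wk, self.wv = rand_matrix(), rand_matrix(), rand_matrix()
        self.k_cache, self.v_cache = [], []

    def prefill(self, prompt):
        """Prefill: project every prompt token at once and fill the KV cache."""
        self.k_cache = [matvec(self.wk, x) for x in prompt]
        self.v_cache = [matvec(self.wv, x) for x in prompt]
        # Output for the last prompt position, which attends to the whole prompt.
        return attend(matvec(self.wq, prompt[-1]), self.k_cache, self.v_cache)

    def decode_step(self, x):
        """Decode: project only the new token, append its K/V, attend over the cache."""
        self.k_cache.append(matvec(self.wk, x))
        self.v_cache.append(matvec(self.wv, x))
        return attend(matvec(self.wq, x), self.k_cache, self.v_cache)

    def no_cache(self, tokens):
        """Reference without caching: recompute K/V for every token seen so far."""
        keys = [matvec(self.wk, t) for t in tokens]
        values = [matvec(self.wv, t) for t in tokens]
        return attend(matvec(self.wq, tokens[-1]), keys, values)
```

The cached path gives the same outputs as recomputing from scratch, but each decode step does O(1) new projections instead of O(sequence length), which is why decode is dominated by reading the growing cache rather than by compute.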