diff --git a/python/mscclpp/default_algos/allreduce_multi_nodes.py b/python/mscclpp/default_algos/allreduce_multi_nodes.py index 20227c5c..8e00fc1d 100644 --- a/python/mscclpp/default_algos/allreduce_multi_nodes.py +++ b/python/mscclpp/default_algos/allreduce_multi_nodes.py @@ -56,6 +56,17 @@ def allreduce_multi_nodes(spec: AlgoSpec, thread_block_group_size: int) -> Colle ) ) + # Scratch buffer layout (3 contiguous regions): + # Region 1 [0, total_gpus): + # Intra-node reduce-scatter. Each GPU receives chunks from gpus_per_node peers, + # packets_per_gpu each → gpus_per_node * packets_per_gpu = total_gpus slots. + # Region 2 [total_gpus, total_gpus + num_nodes * packets_per_gpu): + # Inter-node exchange. Each GPU receives reduced chunks from num_nodes nodes, + # packets_per_gpu each → num_nodes * packets_per_gpu slots. + # Region 3 [total_gpus + num_nodes * packets_per_gpu, end): + # Intra-node broadcast. Each GPU receives final reduced data from gpus_per_node peers, + # packets_per_gpu each → gpus_per_node * packets_per_gpu = total_gpus slots. + # Total = 2 * total_gpus + num_nodes * packets_per_gpu scratch_buffer_size = 2 * total_gpus + packets_per_gpu * num_nodes for node_id in range(num_nodes): for local_gpu_id in range(gpus_per_node):