From 0c9b9abfd512c51e5df4d9b27f39627c3693013b Mon Sep 17 00:00:00 2001 From: Caio Rocha <164253795+caiomcbr@users.noreply.github.com> Date: Tue, 12 May 2026 13:45:55 -0700 Subject: [PATCH] Adding Support 4 Nodes AllReduce Small Message Size (#794) Results on 4 Nodes H200: | Size | NCCL | MSCCL++ 57TB | MSCCL++ 29TB | |------|-------|--------------|--------------| | 8K | 45.75 | 17.74 | 18.18 | | 16K | 47.08 | 18.9 | 18.42 | | 32K | 47.29 | 19.48 | 19.12 | | 64K | 50.34 | 20.51 | 19.29 | | 128K | 59.65 | 21.37 | 20.25 | | 256K | 87.46 | 23.87 | 23.51 | | 512K | 106.55| 29.15 | 29.51 | | 1M | 115 | 40.64 | 41.83 | | 2M | 135.89| 63.73 | 70.45 | | 4M | 177.59| 121.76 | 128.79 | | 8M | 251.17| 228.5 | 251.36 | --------- Co-authored-by: Binyang Li Co-authored-by: Caio Rocha --- python/mscclpp/__main__.py | 46 ++++++- python/mscclpp/default_algos/__init__.py | 4 +- ...uce_2nodes.py => allreduce_multi_nodes.py} | 124 +++++++++++++----- .../algorithm_collection_builder.cc | 4 +- 4 files changed, 143 insertions(+), 35 deletions(-) rename python/mscclpp/default_algos/{allreduce_2nodes.py => allreduce_multi_nodes.py} (61%) diff --git a/python/mscclpp/__main__.py b/python/mscclpp/__main__.py index 6a6f5f28..450ec748 100644 --- a/python/mscclpp/__main__.py +++ b/python/mscclpp/__main__.py @@ -13,7 +13,7 @@ from mscclpp.language.utils import AlgoSpec default_algo_configs = [ { "filename": "allreduce_2nodes_1K_64K.json", - "function": def_algo.allreduce_2nodes, + "function": def_algo.allreduce_multi_nodes, "spec": AlgoSpec( name="allreduce_2nodes_1K_64K", collective=AllReduce(16, 1, True), @@ -34,7 +34,7 @@ default_algo_configs = [ }, { "filename": "allreduce_2nodes_128K_2M.json", - "function": def_algo.allreduce_2nodes, + "function": def_algo.allreduce_multi_nodes, "spec": AlgoSpec( name="allreduce_2nodes_128K_2M", collective=AllReduce(16, 1, True), @@ -53,6 +53,48 @@ default_algo_configs = [ ), "additional_kwargs": {"thread_block_group_size": 4}, }, + { + "filename": 
"allreduce_4nodes_1K_64K.json", + "function": def_algo.allreduce_multi_nodes, + "spec": AlgoSpec( + name="allreduce_4nodes_1K_64K", + collective=AllReduce(32, 1, True), + nranks_per_node=8, + world_size=32, + in_place=True, + instances=1, + protocol="LL", + auto_sync=False, + num_threads_per_block=1024, + reuse_resources=True, + use_double_scratch_buffer=True, + min_message_size=1 << 10, + max_message_size=64 << 10, + tags={"default": 1}, + ), + "additional_kwargs": {"thread_block_group_size": 1}, + }, + { + "filename": "allreduce_4nodes_128K_2M.json", + "function": def_algo.allreduce_multi_nodes, + "spec": AlgoSpec( + name="allreduce_4nodes_128K_2M", + collective=AllReduce(32, 1, True), + nranks_per_node=8, + world_size=32, + in_place=True, + instances=1, + protocol="LL", + auto_sync=False, + num_threads_per_block=1024, + reuse_resources=True, + use_double_scratch_buffer=True, + min_message_size=128 << 10, + max_message_size=2 << 20, + tags={"default": 1}, + ), + "additional_kwargs": {"thread_block_group_size": 4}, + }, ] diff --git a/python/mscclpp/default_algos/__init__.py b/python/mscclpp/default_algos/__init__.py index a5cfa882..1767aab6 100644 --- a/python/mscclpp/default_algos/__init__.py +++ b/python/mscclpp/default_algos/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-from mscclpp.default_algos.allreduce_2nodes import allreduce_2nodes +from mscclpp.default_algos.allreduce_multi_nodes import allreduce_multi_nodes -__all__ = ["allreduce_2nodes"] +__all__ = ["allreduce_multi_nodes"] diff --git a/python/mscclpp/default_algos/allreduce_2nodes.py b/python/mscclpp/default_algos/allreduce_multi_nodes.py similarity index 61% rename from python/mscclpp/default_algos/allreduce_2nodes.py rename to python/mscclpp/default_algos/allreduce_multi_nodes.py index 5a355887..5697a0e3 100644 --- a/python/mscclpp/default_algos/allreduce_2nodes.py +++ b/python/mscclpp/default_algos/allreduce_multi_nodes.py @@ -2,9 +2,11 @@ # Licensed under the MIT License. """ -Multi-node AllReduce implementation using packet-based communication. -This implements a hierarchical AllReduce: intra-node allreduce followed by -inter-node exchange and final intra-node allreduce. +Generalized multi-node AllReduce implementation using packet-based communication. +This implements a hierarchical AllReduce for N nodes: +1. Intra-node reduce-scatter (each GPU reduces its assigned chunk across the node) +2. Inter-node allreduce (exchange fully intra-reduced chunks across all nodes) +3. Intra-node broadcast (distribute the fully reduced chunks back to all GPUs in the node) """ from mscclpp.language.utils import AlgoSpec @@ -15,7 +17,7 @@ from mscclpp.language.program import * from mscclpp.language.collectives import * -def allreduce_2nodes(spec: AlgoSpec, thread_block_group_size: int) -> CollectiveProgram: +def allreduce_multi_nodes(spec: AlgoSpec, thread_block_group_size: int) -> CollectiveProgram: """ Implements a multi-node AllReduce using a hierarchical approach: 1. Intra-node allreduce @@ -23,10 +25,10 @@ def allreduce_2nodes(spec: AlgoSpec, thread_block_group_size: int) -> Collective 3. 
Intra-node allreduce """ # Configuration constants - num_nodes = 2 + num_nodes = spec.world_size // spec.nranks_per_node gpus_per_node = spec.nranks_per_node total_gpus = num_nodes * gpus_per_node - packets_per_gpu = 2 + packets_per_gpu = num_nodes with CollectiveProgram.from_spec(spec) as prog: # Initialize communication channels and buffers @@ -54,11 +56,21 @@ def allreduce_2nodes(spec: AlgoSpec, thread_block_group_size: int) -> Collective ) ) - scratch_buffer_size = packets_per_gpu * (total_gpus + 1) + # Scratch buffer layout (3 contiguous regions): + # Region 1 [0, total_gpus): + # Intra-node reduce-scatter. Each GPU receives chunks from gpus_per_node peers, + # packets_per_gpu each → gpus_per_node * packets_per_gpu = total_gpus slots. + # Region 2 [total_gpus, total_gpus + num_nodes * packets_per_gpu): + # Inter-node exchange. Each GPU receives reduced chunks from num_nodes nodes, + # packets_per_gpu each → num_nodes * packets_per_gpu slots. + # Region 3 [total_gpus + num_nodes * packets_per_gpu, end): + # Intra-node broadcast. Each GPU receives final reduced data from gpus_per_node peers, + # packets_per_gpu each → gpus_per_node * packets_per_gpu = total_gpus slots. 
+ # Total = 2 * total_gpus + num_nodes * packets_per_gpu + scratch_buffer_size = 2 * total_gpus + packets_per_gpu * num_nodes for node_id in range(num_nodes): for local_gpu_id in range(gpus_per_node): current_rank_id = local_gpu_id + gpus_per_node * node_id - next_node_rank_id = (local_gpu_id + gpus_per_node * (node_id + 1)) % total_gpus scratch_buffers.append(Buffer(current_rank_id, scratch_buffer_size)) for peer_gpu_id in range(gpus_per_node): if peer_gpu_id != local_gpu_id: @@ -66,7 +78,12 @@ def allreduce_2nodes(spec: AlgoSpec, thread_block_group_size: int) -> Collective intra_node_memory_channels[(peer_rank_id, current_rank_id)] = MemoryChannel( peer_rank_id, current_rank_id ) - inter_node_port_channels[current_rank_id] = PortChannel(next_node_rank_id, current_rank_id) + for peer_node_id in range(num_nodes): + if peer_node_id != node_id: + peer_node_rank_id = (local_gpu_id + gpus_per_node * peer_node_id) % total_gpus + inter_node_port_channels[(current_rank_id, peer_node_rank_id)] = PortChannel( + peer_node_rank_id, current_rank_id + ) # AllReduce for node_id in range(num_nodes): @@ -74,7 +91,6 @@ def allreduce_2nodes(spec: AlgoSpec, thread_block_group_size: int) -> Collective current_rank_id = local_gpu_id + gpus_per_node * node_id current_rank = Rank(current_rank_id) input_buffer = current_rank.get_input_buffer() - next_node_rank_id = (local_gpu_id + gpus_per_node * (node_id + 1)) % total_gpus # Intra Node Exchange Data for peer_gpu_id in range(gpus_per_node): @@ -118,27 +134,32 @@ def allreduce_2nodes(spec: AlgoSpec, thread_block_group_size: int) -> Collective ) inter_node_offset = total_gpus - inter_node_port_channels[current_rank_id].put_packets( - scratch_buffers[next_node_rank_id][ - inter_node_offset - + local_gpu_id * packets_per_gpu : inter_node_offset - + local_gpu_id * packets_per_gpu - + packets_per_gpu - ], - scratch_buffers[current_rank_id][ - local_gpu_id * packets_per_gpu : local_gpu_id * packets_per_gpu + packets_per_gpu - ], - tb=0, - ) + 
for peer_node_id in range(num_nodes): + if peer_node_id != node_id: + peer_node_rank_id = (local_gpu_id + gpus_per_node * peer_node_id) % total_gpus + inter_node_port_channels[(current_rank_id, peer_node_rank_id)].put_packets( + scratch_buffers[peer_node_rank_id][ + inter_node_offset + + node_id * packets_per_gpu : inter_node_offset + + node_id * packets_per_gpu + + packets_per_gpu + ], + scratch_buffers[current_rank_id][ + local_gpu_id * packets_per_gpu : local_gpu_id * packets_per_gpu + packets_per_gpu + ], + tb=0, + ) # Reduce Received Data from Remote Node inter_node_data = [ scratch_buffers[current_rank_id][ inter_node_offset - + local_gpu_id * packets_per_gpu : inter_node_offset - + local_gpu_id * packets_per_gpu + + peer_node_id * packets_per_gpu : inter_node_offset + + peer_node_id * packets_per_gpu + packets_per_gpu ] + for peer_node_id in range(num_nodes) + if peer_node_id != node_id ] current_rank.reduce( input_buffer[local_gpu_id * packets_per_gpu : local_gpu_id * packets_per_gpu + packets_per_gpu], @@ -148,12 +169,18 @@ def allreduce_2nodes(spec: AlgoSpec, thread_block_group_size: int) -> Collective ) current_rank.copy_packets( - scratch_buffers[current_rank_id][scratch_buffer_size - packets_per_gpu : scratch_buffer_size], + scratch_buffers[current_rank_id][ + inter_node_offset + + node_id * packets_per_gpu : inter_node_offset + + node_id * packets_per_gpu + + packets_per_gpu + ], input_buffer[local_gpu_id * packets_per_gpu : local_gpu_id * packets_per_gpu + packets_per_gpu], tb_group=global_intra_node_tbg, ) # Broadcast Reduced Data + broadcast_offset = total_gpus + packets_per_gpu * num_nodes for peer_gpu_id in range(gpus_per_node): peer_rank_id = peer_gpu_id + gpus_per_node * node_id @@ -161,13 +188,16 @@ def allreduce_2nodes(spec: AlgoSpec, thread_block_group_size: int) -> Collective tbg_id = peer_gpu_id if peer_gpu_id < local_gpu_id else peer_gpu_id - 1 intra_node_memory_channels[(peer_rank_id, current_rank_id)].read_put_packets( 
scratch_buffers[peer_rank_id][ - inter_node_offset - + local_gpu_id * packets_per_gpu : inter_node_offset + broadcast_offset + + local_gpu_id * packets_per_gpu : broadcast_offset + local_gpu_id * packets_per_gpu + packets_per_gpu ], scratch_buffers[current_rank_id][ - scratch_buffer_size - packets_per_gpu : scratch_buffer_size + inter_node_offset + + node_id * packets_per_gpu : inter_node_offset + + node_id * packets_per_gpu + + packets_per_gpu ], tb_group=thread_block_groups[tbg_id], ) @@ -181,8 +211,8 @@ def allreduce_2nodes(spec: AlgoSpec, thread_block_group_size: int) -> Collective peer_gpu_id * packets_per_gpu : peer_gpu_id * packets_per_gpu + packets_per_gpu ], scratch_buffers[current_rank_id][ - inter_node_offset - + peer_gpu_id * packets_per_gpu : inter_node_offset + broadcast_offset + + peer_gpu_id * packets_per_gpu : broadcast_offset + peer_gpu_id * packets_per_gpu + packets_per_gpu ], @@ -190,3 +220,37 @@ def allreduce_2nodes(spec: AlgoSpec, thread_block_group_size: int) -> Collective ) return prog + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--name", type=str, help="name of the program") + parser.add_argument("--num_gpus", type=int, help="total number of gpus") + parser.add_argument("--gpus_per_node", type=int, help="number of gpus per node") + parser.add_argument("--tbg", type=int, default=1, help="thread block group size") + parser.add_argument("--num_threads_per_block", type=int, default=1024, help="number of threads per block") + parser.add_argument("--min_message_size", type=int, default=0, help="minimum message size") + parser.add_argument("--max_message_size", type=int, default=2**64 - 1, help="maximum message size") + + args = parser.parse_args() + + spec = AlgoSpec( + name=args.name, + collective=AllReduce(args.num_gpus, 1, True), + nranks_per_node=args.gpus_per_node, + world_size=args.num_gpus, + in_place=True, + instances=1, + protocol="LL", + auto_sync=False, + 
num_threads_per_block=args.num_threads_per_block, + reuse_resources=True, + use_double_scratch_buffer=True, + min_message_size=args.min_message_size, + max_message_size=args.max_message_size, + ) + + prog = allreduce_multi_nodes(spec, args.tbg) + print(prog.to_json()) diff --git a/src/ext/collectives/algorithm_collection_builder.cc b/src/ext/collectives/algorithm_collection_builder.cc index 2a7e6e91..5d196d12 100644 --- a/src/ext/collectives/algorithm_collection_builder.cc +++ b/src/ext/collectives/algorithm_collection_builder.cc @@ -113,7 +113,9 @@ AlgorithmCollection AlgorithmCollectionBuilder::buildDefaultDslAlgorithms(int ra }; static const std::vector defaultAlgoConfigs = { {"allreduce_2nodes_1K_64K.json", "allreduce", 8, 16, {{"default", 1}}}, - {"allreduce_2nodes_64K_2M.json", "allreduce", 8, 16, {{"default", 1}}}}; + {"allreduce_2nodes_128K_2M.json", "allreduce", 8, 16, {{"default", 1}}}, + {"allreduce_4nodes_1K_64K.json", "allreduce", 8, 32, {{"default", 1}}}, + {"allreduce_4nodes_128K_2M.json", "allreduce", 8, 32, {{"default", 1}}}}; AlgorithmCollection collection; static auto generateFileId = [](const std::string& input) {