This commit is contained in:
Caio Rocha
2026-04-14 22:52:27 +00:00
parent e6602b4a8b
commit 17774b5f83
4 changed files with 30 additions and 32 deletions

View File

@@ -53,6 +53,26 @@ default_algo_configs = [
),
"additional_kwargs": {"thread_block_group_size": 4},
},
{
"filename": "allgather_mrc.json",
"function": def_algo.allgather_mrc,
"spec": AlgoSpec(
name="allgather_mrc",
collective=AllGather(2, 1, True),
nranks_per_node=1,
world_size=2,
in_place=True,
instances=1,
protocol="Simple",
auto_sync=True,
num_threads_per_block=1024,
reuse_resources=False,
use_double_scratch_buffer=False,
min_message_size=1 << 10,
max_message_size=8 << 30,
tags={"default": 1},
)
}
]

View File

@@ -2,5 +2,6 @@
# Licensed under the MIT License.
from mscclpp.default_algos.allreduce_2nodes import allreduce_2nodes
from mscclpp.default_algos.allgather_mrc import allgather_mrc
__all__ = ["allreduce_2nodes"]
__all__ = ["allreduce_2nodes", "allgather_mrc"]

View File

@@ -8,23 +8,13 @@ from mscclpp.language.general import *
from mscclpp.language.program import *
from mscclpp.language.collectives import *
from mscclpp.language.loop import LoopIterationContext
from mscclpp.language.utils import AlgoSpec
def allgather_hierarchical(name, gpus, num_threads_per_block, min_message_size, max_message_size):
size = gpus
chunksperloop = 1
collective = AllGather(size, chunksperloop, True)
with CollectiveProgram(
name,
collective,
size,
protocol="Simple",
num_threads_per_block=num_threads_per_block,
instances=1,
use_double_scratch_buffer=False,
min_message_size=min_message_size,
max_message_size=max_message_size,
):
def allgather_mrc(spec: AlgoSpec) -> CollectiveProgram:
size = spec.world_size
with CollectiveProgram.from_spec(spec) as prog:
# Port channels for inter-node communication
port_channels = {}
for n in range(size):
@@ -65,18 +55,4 @@ def allgather_hierarchical(name, gpus, num_threads_per_block, min_message_size,
recv_src_chunk = Rank(src_rank).get_output_buffer()[recv_offset:recv_offset + 1]
ch_from_prev.wait(tb=0)
print(JSON())
parser = argparse.ArgumentParser()
parser.add_argument("--name", type=str, help="name of the program")
parser.add_argument("--num_gpus", type=int, help="number of gpus")
parser.add_argument("--num_threads_per_block", type=int, default=1024, help="number of threads per block")
parser.add_argument("--min_message_size", type=int, default=0, help="minimum message size")
parser.add_argument("--max_message_size", type=int, default=2**64 - 1, help="maximum message size")
args = parser.parse_args()
allgather_hierarchical(
args.name, args.num_gpus, args.num_threads_per_block, args.min_message_size, args.max_message_size
)
return prog