This commit is contained in:
Caio Rocha
2026-05-13 21:21:01 +00:00
parent 719e9124af
commit 718e5b2897
4 changed files with 56 additions and 4 deletions

View File

@@ -54,10 +54,10 @@ default_algo_configs = [
"additional_kwargs": {"thread_block_group_size": 4},
},
{
"filename": "allgather_mrc.json",
"filename": "allgather_mrc_2_nodes.json",
"function": def_algo.allgather_mrc,
"spec": AlgoSpec(
name="allgather_mrc",
name="allgather_mrc_2_nodes",
collective=AllGather(2, 1, True),
nranks_per_node=1,
world_size=2,
@@ -72,6 +72,26 @@ default_algo_configs = [
max_message_size=8 << 30,
tags={"default": 1},
)
},
{
"filename": "allgather_mrc_4_nodes.json",
"function": def_algo.allgather_mrc,
"spec": AlgoSpec(
name="allgather_mrc_4_nodes",
collective=AllGather(4, 1, True),
nranks_per_node=1,
world_size=4,
in_place=True,
instances=1,
protocol="Simple",
auto_sync=True,
num_threads_per_block=1024,
reuse_resources=False,
use_double_scratch_buffer=False,
min_message_size=1 << 10,
max_message_size=8 << 30,
tags={"default": 1},
)
}
]

View File

@@ -56,3 +56,35 @@ def allgather_mrc(spec: AlgoSpec) -> CollectiveProgram:
ch_from_prev.wait(tb=0)
return prog
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--name", type=str, help="name of the program")
parser.add_argument("--num_gpus", type=int, help="total number of gpus")
parser.add_argument("--gpus_per_node", type=int, help="number of gpus per node")
parser.add_argument("--num_threads_per_block", type=int, default=1024, help="number of threads per block")
parser.add_argument("--min_message_size", type=int, default=0, help="minimum message size")
parser.add_argument("--max_message_size", type=int, default=2**64 - 1, help="maximum message size")
args = parser.parse_args()
spec = AlgoSpec(
name=args.name,
collective=AllGather(args.num_gpus, 1, True),
nranks_per_node=args.gpus_per_node,
world_size=args.num_gpus,
in_place=True,
instances=1,
protocol="Simple",
auto_sync=True,
num_threads_per_block=args.num_threads_per_block,
reuse_resources=False,
use_double_scratch_buffer=False,
min_message_size=args.min_message_size,
max_message_size=args.max_message_size,
)
prog = allgather_mrc(spec)
print(prog.to_json())