update algo

This commit is contained in:
Empyreus
2026-06-08 18:24:02 +00:00
parent 00668b4a41
commit 0d8efdb43d

View File

@@ -9,14 +9,16 @@ from mscclpp.language.program import *
from mscclpp.language.collectives import *
def allgather_example(name, gpu_size, num_threads_per_block, min_message_size, max_message_size):
def allgather_example(name, gpu_size, num_threads_per_block, min_message_size, max_message_size, instances):
# The command-line defaults (--instances=8, --num_threads_per_block=256) are tuned for 16-GPU (4x GB200) MNNVL NVLS:
# they give the best busbw across 1MB-1GB (instances saturate at 8; num_threads_per_block=256 beats 512/1024).
chunksperloop = 1
collective = AllGather(gpu_size, chunksperloop, True)
with CollectiveProgram(
name,
collective,
gpu_size,
instances=8,
instances=instances,
protocol="Simple",
num_threads_per_block=num_threads_per_block,
use_double_scratch_buffer=False,
@@ -74,10 +76,18 @@ parser = argparse.ArgumentParser()
parser.add_argument("--name", type=str, help="name of the program")
parser.add_argument("--num_gpus", type=int, help="number of gpus")
parser.add_argument("--num_threads_per_block", type=int, default=1024, help="number of threads per block")
parser.add_argument("--num_threads_per_block", type=int, default=256, help="number of threads per block")
parser.add_argument("--min_message_size", type=int, default=0, help="minimum message size")
parser.add_argument("--max_message_size", type=int, default=2**64 - 1, help="maximum message size")
parser.add_argument("--instances", type=int, default=8, help="number of instances (parallel threadblocks)")
args = parser.parse_args()
allgather_example(args.name, args.num_gpus, args.num_threads_per_block, args.min_message_size, args.max_message_size)
allgather_example(
args.name,
args.num_gpus,
args.num_threads_per_block,
args.min_message_size,
args.max_message_size,
args.instances,
)