diff --git a/python/mscclpp/language/tests/single_node/allgather_nvls_zero_copy.py b/python/mscclpp/language/tests/single_node/allgather_nvls_zero_copy.py
index e6058960..1a8787dd 100644
--- a/python/mscclpp/language/tests/single_node/allgather_nvls_zero_copy.py
+++ b/python/mscclpp/language/tests/single_node/allgather_nvls_zero_copy.py
@@ -9,14 +9,16 @@ from mscclpp.language.program import *
 from mscclpp.language.collectives import *
 
 
-def allgather_example(name, gpu_size, num_threads_per_block, min_message_size, max_message_size):
+def allgather_example(name, gpu_size, num_threads_per_block, min_message_size, max_message_size, instances):
+    # Defaults instances=8, num_threads_per_block=256 are tuned for 16-GPU (4x GB200) MNNVL NVLS:
+    # they give the best busbw across 1MB-1GB (instances saturate at 8; tpb=256 beats 512/1024).
     chunksperloop = 1
     collective = AllGather(gpu_size, chunksperloop, True)
     with CollectiveProgram(
         name,
         collective,
         gpu_size,
-        instances=8,
+        instances=instances,
         protocol="Simple",
         num_threads_per_block=num_threads_per_block,
         use_double_scratch_buffer=False,
@@ -74,10 +76,18 @@ parser = argparse.ArgumentParser()
 
 parser.add_argument("--name", type=str, help="name of the program")
 parser.add_argument("--num_gpus", type=int, help="number of gpus")
-parser.add_argument("--num_threads_per_block", type=int, default=1024, help="number of threads per block")
+parser.add_argument("--num_threads_per_block", type=int, default=256, help="number of threads per block")
 parser.add_argument("--min_message_size", type=int, default=0, help="minimum message size")
 parser.add_argument("--max_message_size", type=int, default=2**64 - 1, help="maximum message size")
+parser.add_argument("--instances", type=int, default=8, help="number of instances (parallel threadblocks)")
 
 args = parser.parse_args()
 
-allgather_example(args.name, args.num_gpus, args.num_threads_per_block, args.min_message_size, args.max_message_size)
+allgather_example(
+    args.name,
+    args.num_gpus,
+    args.num_threads_per_block,
+    args.min_message_size,
+    args.max_message_size,
+    args.instances,
+)