mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-24 23:06:17 +00:00
wip
This commit is contained in:
@@ -54,10 +54,10 @@ default_algo_configs = [
|
||||
"additional_kwargs": {"thread_block_group_size": 4},
|
||||
},
|
||||
{
|
||||
"filename": "allgather_mrc.json",
|
||||
"filename": "allgather_mrc_2_nodes.json",
|
||||
"function": def_algo.allgather_mrc,
|
||||
"spec": AlgoSpec(
|
||||
name="allgather_mrc",
|
||||
name="allgather_mrc_2_nodes",
|
||||
collective=AllGather(2, 1, True),
|
||||
nranks_per_node=1,
|
||||
world_size=2,
|
||||
@@ -72,6 +72,26 @@ default_algo_configs = [
|
||||
max_message_size=8 << 30,
|
||||
tags={"default": 1},
|
||||
)
|
||||
},
|
||||
{
|
||||
"filename": "allgather_mrc_4_nodes.json",
|
||||
"function": def_algo.allgather_mrc,
|
||||
"spec": AlgoSpec(
|
||||
name="allgather_mrc_4_nodes",
|
||||
collective=AllGather(4, 1, True),
|
||||
nranks_per_node=1,
|
||||
world_size=4,
|
||||
in_place=True,
|
||||
instances=1,
|
||||
protocol="Simple",
|
||||
auto_sync=True,
|
||||
num_threads_per_block=1024,
|
||||
reuse_resources=False,
|
||||
use_double_scratch_buffer=False,
|
||||
min_message_size=1 << 10,
|
||||
max_message_size=8 << 30,
|
||||
tags={"default": 1},
|
||||
)
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@@ -56,3 +56,35 @@ def allgather_mrc(spec: AlgoSpec) -> CollectiveProgram:
|
||||
ch_from_prev.wait(tb=0)
|
||||
|
||||
return prog
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--name", type=str, help="name of the program")
|
||||
parser.add_argument("--num_gpus", type=int, help="total number of gpus")
|
||||
parser.add_argument("--gpus_per_node", type=int, help="number of gpus per node")
|
||||
parser.add_argument("--num_threads_per_block", type=int, default=1024, help="number of threads per block")
|
||||
parser.add_argument("--min_message_size", type=int, default=0, help="minimum message size")
|
||||
parser.add_argument("--max_message_size", type=int, default=2**64 - 1, help="maximum message size")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
spec = AlgoSpec(
|
||||
name=args.name,
|
||||
collective=AllGather(args.num_gpus, 1, True),
|
||||
nranks_per_node=args.gpus_per_node,
|
||||
world_size=args.num_gpus,
|
||||
in_place=True,
|
||||
instances=1,
|
||||
protocol="Simple",
|
||||
auto_sync=True,
|
||||
num_threads_per_block=args.num_threads_per_block,
|
||||
reuse_resources=False,
|
||||
use_double_scratch_buffer=False,
|
||||
min_message_size=args.min_message_size,
|
||||
max_message_size=args.max_message_size,
|
||||
)
|
||||
|
||||
prog = allgather_mrc(spec)
|
||||
print(prog.to_json())
|
||||
@@ -114,7 +114,8 @@ AlgorithmCollection AlgorithmCollectionBuilder::buildDefaultDslAlgorithms(int ra
|
||||
static const std::vector<DslAlgoConfig> defaultAlgoConfigs = {
|
||||
{"allreduce_2nodes_1K_64K.json", "allreduce", 8, 16, {{"default", 1}}},
|
||||
{"allreduce_2nodes_64K_2M.json", "allreduce", 8, 16, {{"default", 1}}},
|
||||
{"allgather_mrc.json", "allgather", 1, 2, {{"default", 1}}}};
|
||||
{"allgather_mrc_2_nodes.json", "allgather", 1, 2, {{"default", 1}}},
|
||||
{"allgather_mrc_4_nodes.json", "allgather", 1, 4, {{"default", 1}}}};
|
||||
AlgorithmCollection collection;
|
||||
|
||||
static auto generateFileId = [](const std::string& input) {
|
||||
|
||||
@@ -53,7 +53,6 @@ bool matchExecutionPlan(std::shared_ptr<DslAlgorithm> algo, const CollectiveRequ
|
||||
bool maxSizeMatch = effectiveSize <= algo->messageRange().second;
|
||||
bool result =
|
||||
worldSizeMatch && ranksPerNodeMatch && collectiveMatch && bufferModeMatch && minSizeMatch && maxSizeMatch;
|
||||
return true;
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user