mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
Add C++ executor test (#304)
- Add C++ executor test - Fix executor bugs for packet operation - Enhance executor_test.py --------- Co-authored-by: Binyang Li <binyli@microsoft.com>
This commit is contained in:
@@ -1,19 +1,18 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from os import path
|
||||
import argparse
|
||||
from mscclpp import (
|
||||
DataType,
|
||||
Executor,
|
||||
ExecutionPlan,
|
||||
PacketType,
|
||||
)
|
||||
import mscclpp.comm as mscclpp_comm
|
||||
|
||||
import cupy as cp
|
||||
from mpi4py import MPI
|
||||
|
||||
MSCCLPP_ROOT_PATH = "/root/mscclpp"
|
||||
|
||||
|
||||
def bench_time(niters: int, ngraphIters: int, func):
|
||||
# capture cuda graph for niters of the kernel launch
|
||||
@@ -40,36 +39,118 @@ def bench_time(niters: int, ngraphIters: int, func):
|
||||
return cp.cuda.get_elapsed_time(start, end) / niters * 1000.0 / ngraphIters
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mscclpp_group = mscclpp_comm.CommGroup(MPI.COMM_WORLD)
|
||||
cp.cuda.Device(MPI.COMM_WORLD.rank % mscclpp_group.nranks_per_node).use()
|
||||
executor = Executor(mscclpp_group.communicator)
|
||||
execution_plan = ExecutionPlan(
|
||||
"allreduce_pairs", path.join(MSCCLPP_ROOT_PATH, "test", "execution-files", "allreduce.json")
|
||||
)
|
||||
def parse_size(size_str):
|
||||
"""Convert a human-readable buffer size string to an integer."""
|
||||
size_str = size_str.strip()
|
||||
if not size_str:
|
||||
raise ValueError("Size string can not be empty")
|
||||
units = {"K": 1024, "M": 1024**2, "G": 1024**3}
|
||||
if size_str[-1].upper() in units:
|
||||
return int(size_str[:-1]) * units[size_str[-1].upper()]
|
||||
else:
|
||||
return int(size_str)
|
||||
|
||||
nelems = 1024 * 1024
|
||||
cp.random.seed(42)
|
||||
buffer = cp.random.random(nelems).astype(cp.float16)
|
||||
|
||||
def parse_dtype(dtype_str):
|
||||
"""Convert a human-readable data type string to a numpy data type."""
|
||||
dtype_str = dtype_str.strip().lower()
|
||||
if dtype_str == "float16":
|
||||
return cp.float16
|
||||
elif dtype_str == "float32":
|
||||
return cp.float32
|
||||
elif dtype_str == "int32":
|
||||
return cp.int32
|
||||
else:
|
||||
raise ValueError(f"Unknown data type: {dtype_str}")
|
||||
|
||||
|
||||
def dtype_to_mscclpp_dtype(dtype):
|
||||
if dtype == cp.float16:
|
||||
return DataType.float16
|
||||
elif dtype == cp.float32:
|
||||
return DataType.float32
|
||||
elif dtype == cp.int32:
|
||||
return DataType.int32
|
||||
else:
|
||||
raise ValueError(f"Unknown data type: {dtype}")
|
||||
|
||||
|
||||
def main(
|
||||
execution_paln_name: str,
|
||||
execution_plan_path: str,
|
||||
size: int,
|
||||
nthreads_per_block: int,
|
||||
dtype: cp.dtype = cp.float16,
|
||||
packet_type: PacketType = PacketType.LL16,
|
||||
seed: int = 42,
|
||||
):
|
||||
mscclpp_group = mscclpp_comm.CommGroup(MPI.COMM_WORLD)
|
||||
cp.cuda.Device(mscclpp_group.my_rank % mscclpp_group.nranks_per_node).use()
|
||||
executor = Executor(mscclpp_group.communicator)
|
||||
execution_plan = ExecutionPlan(execution_paln_name, execution_plan_path)
|
||||
|
||||
cp.random.seed(seed)
|
||||
nelems = size // cp.dtype(dtype).itemsize
|
||||
buffer = cp.random.random(nelems * mscclpp_group.nranks).astype(dtype)
|
||||
sub_arrays = cp.split(buffer, MPI.COMM_WORLD.size)
|
||||
sendbuf = sub_arrays[MPI.COMM_WORLD.rank]
|
||||
expected = cp.zeros_like(sendbuf)
|
||||
for i in range(mscclpp_group.nranks):
|
||||
expected += sub_arrays[i]
|
||||
mscclpp_group.barrier()
|
||||
|
||||
execution_time = bench_time(
|
||||
100,
|
||||
10,
|
||||
lambda stream: executor.execute(
|
||||
MPI.COMM_WORLD.rank,
|
||||
sendbuf.data.ptr,
|
||||
sendbuf.data.ptr,
|
||||
sendbuf.nbytes,
|
||||
sendbuf.nbytes,
|
||||
DataType.float16,
|
||||
512,
|
||||
execution_plan,
|
||||
stream.ptr,
|
||||
),
|
||||
executor_func = lambda stream: executor.execute(
|
||||
MPI.COMM_WORLD.rank,
|
||||
sendbuf.data.ptr,
|
||||
sendbuf.data.ptr,
|
||||
sendbuf.nbytes,
|
||||
sendbuf.nbytes,
|
||||
dtype_to_mscclpp_dtype(dtype),
|
||||
nthreads_per_block,
|
||||
execution_plan,
|
||||
stream.ptr,
|
||||
packet_type,
|
||||
)
|
||||
# check correctness
|
||||
stream = cp.cuda.Stream(non_blocking=True)
|
||||
executor_func(stream)
|
||||
stream.synchronize()
|
||||
assert cp.allclose(sendbuf, expected, atol=1e-2 * mscclpp_group.nranks)
|
||||
|
||||
mscclpp_group.barrier()
|
||||
execution_time = bench_time(100, 10, executor_func)
|
||||
print(
|
||||
f"Rank: {MPI.COMM_WORLD.rank} Execution time: {execution_time} us, "
|
||||
f"data size: {sendbuf.nbytes} bytes data type: {dtype().dtype.name} "
|
||||
f"packet type: {packet_type} nthreads_per_block: {nthreads_per_block}"
|
||||
)
|
||||
print(f"Rank: {MPI.COMM_WORLD.rank} Execution time: {execution_time} us, data size: {sendbuf.nbytes} bytes")
|
||||
executor = None
|
||||
mscclpp_group = None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-n", "--execution_plan_name", type=str, required=True)
|
||||
parser.add_argument("-path", "--execution_plan_path", type=str, required=True)
|
||||
parser.add_argument("--size", type=str, required=True)
|
||||
parser.add_argument("--nthreads_per_block", type=int, required=True)
|
||||
parser.add_argument("--dtype", type=str, default="float16", help="Choose from float16, float32, int32")
|
||||
parser.add_argument("--packet_type", type=str, default="LL16", help="Choose from LL8, LL16")
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
args = parser.parse_args()
|
||||
|
||||
packet_type = PacketType.LL16
|
||||
if args.packet_type == "LL8":
|
||||
packet_type = PacketType.LL8
|
||||
|
||||
buffer_size = parse_size(args.size)
|
||||
dtype = parse_dtype(args.dtype)
|
||||
main(
|
||||
args.execution_plan_name,
|
||||
args.execution_plan_path,
|
||||
buffer_size,
|
||||
args.nthreads_per_block,
|
||||
dtype,
|
||||
packet_type,
|
||||
args.seed,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user