mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-06-07 08:14:51 +00:00
Tune threads per block for mscclpp executor (#345)
This commit is contained in:
@@ -29,11 +29,10 @@ void register_executor(nb::module_& m) {
|
||||
.def(
|
||||
"execute",
|
||||
[](Executor* self, int rank, uintptr_t sendbuff, uintptr_t recvBuff, size_t sendBuffSize, size_t recvBuffSize,
|
||||
DataType dataType, int nthreads, const ExecutionPlan& plan, uintptr_t stream, PacketType packetType) {
|
||||
DataType dataType, const ExecutionPlan& plan, uintptr_t stream, PacketType packetType) {
|
||||
self->execute(rank, reinterpret_cast<void*>(sendbuff), reinterpret_cast<void*>(recvBuff), sendBuffSize,
|
||||
recvBuffSize, dataType, nthreads, plan, (cudaStream_t)stream, packetType);
|
||||
recvBuffSize, dataType, plan, (cudaStream_t)stream, packetType);
|
||||
},
|
||||
nb::arg("rank"), nb::arg("sendbuff"), nb::arg("recvBuff"), nb::arg("sendBuffSize"), nb::arg("recvBuffSize"),
|
||||
nb::arg("dataType"), nb::arg("nthreads"), nb::arg("plan"), nb::arg("stream"),
|
||||
nb::arg("packetType") = PacketType::LL16);
|
||||
nb::arg("dataType"), nb::arg("plan"), nb::arg("stream"), nb::arg("packetType") = PacketType::LL16);
|
||||
}
|
||||
|
||||
@@ -81,10 +81,9 @@ def main(
|
||||
execution_paln_name: str,
|
||||
execution_plan_path: str,
|
||||
size: int,
|
||||
nthreads_per_block: int,
|
||||
dtype: cp.dtype = cp.float16,
|
||||
packet_type: PacketType = PacketType.LL16,
|
||||
seed: int = 42,
|
||||
seed: int = 42 + MPI.COMM_WORLD.rank,
|
||||
):
|
||||
mscclpp_group = mscclpp_comm.CommGroup(MPI.COMM_WORLD)
|
||||
cp.cuda.Device(mscclpp_group.my_rank % mscclpp_group.nranks_per_node).use()
|
||||
@@ -96,12 +95,9 @@ def main(
|
||||
|
||||
cp.random.seed(seed)
|
||||
nelems = size // cp.dtype(dtype).itemsize
|
||||
buffer = cp.random.random(nelems * mscclpp_group.nranks).astype(dtype)
|
||||
sub_arrays = cp.split(buffer, MPI.COMM_WORLD.size)
|
||||
sendbuf = sub_arrays[MPI.COMM_WORLD.rank]
|
||||
expected = cp.zeros_like(sendbuf)
|
||||
for i in range(mscclpp_group.nranks):
|
||||
expected += sub_arrays[i]
|
||||
sendbuf = cp.random.random(nelems).astype(dtype)
|
||||
expected = cp.asnumpy(sendbuf)
|
||||
expected = MPI.COMM_WORLD.allreduce(expected, op=MPI.SUM)
|
||||
mscclpp_group.barrier()
|
||||
|
||||
executor_func = lambda stream: executor.execute(
|
||||
@@ -111,7 +107,6 @@ def main(
|
||||
sendbuf.nbytes,
|
||||
sendbuf.nbytes,
|
||||
dtype_to_mscclpp_dtype(dtype),
|
||||
nthreads_per_block,
|
||||
execution_plan,
|
||||
stream.ptr,
|
||||
packet_type,
|
||||
@@ -130,7 +125,7 @@ def main(
|
||||
print(
|
||||
f"Rank: {MPI.COMM_WORLD.rank} Execution time: {execution_time} us, "
|
||||
f"data size: {sendbuf.nbytes} bytes data type: {dtype().dtype.name} "
|
||||
f"packet type: {packet_type} nthreads_per_block: {nthreads_per_block}"
|
||||
f"packet type: {packet_type}"
|
||||
)
|
||||
executor = None
|
||||
mscclpp_group = None
|
||||
@@ -141,7 +136,6 @@ if __name__ == "__main__":
|
||||
parser.add_argument("-n", "--execution_plan_name", type=str, required=True)
|
||||
parser.add_argument("-path", "--execution_plan_path", type=str, required=True)
|
||||
parser.add_argument("--size", type=str, required=True)
|
||||
parser.add_argument("--nthreads_per_block", type=int, required=True)
|
||||
parser.add_argument("--dtype", type=str, default="float16", help="Choose from float16, float32, int32")
|
||||
parser.add_argument("--packet_type", type=str, default="LL16", help="Choose from LL8, LL16")
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
@@ -157,7 +151,6 @@ if __name__ == "__main__":
|
||||
args.execution_plan_name,
|
||||
args.execution_plan_path,
|
||||
buffer_size,
|
||||
args.nthreads_per_block,
|
||||
dtype,
|
||||
packet_type,
|
||||
args.seed,
|
||||
|
||||
@@ -630,7 +630,6 @@ def test_executor(mpi_group: MpiGroup, filename: str):
|
||||
sendbuf.nbytes,
|
||||
sendbuf.nbytes,
|
||||
DataType.float16,
|
||||
512,
|
||||
execution_plan,
|
||||
stream.ptr,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user