Tune threads per block for mscclpp executor (#345)

This commit is contained in:
Binyang Li
2024-09-18 17:21:47 -07:00
committed by GitHub
parent 0c7311e83f
commit b30bb260e3
12 changed files with 43 additions and 46 deletions

View File

@@ -81,10 +81,9 @@ def main(
execution_paln_name: str,
execution_plan_path: str,
size: int,
nthreads_per_block: int,
dtype: cp.dtype = cp.float16,
packet_type: PacketType = PacketType.LL16,
seed: int = 42,
seed: int = 42 + MPI.COMM_WORLD.rank,
):
mscclpp_group = mscclpp_comm.CommGroup(MPI.COMM_WORLD)
cp.cuda.Device(mscclpp_group.my_rank % mscclpp_group.nranks_per_node).use()
@@ -96,12 +95,9 @@ def main(
cp.random.seed(seed)
nelems = size // cp.dtype(dtype).itemsize
buffer = cp.random.random(nelems * mscclpp_group.nranks).astype(dtype)
sub_arrays = cp.split(buffer, MPI.COMM_WORLD.size)
sendbuf = sub_arrays[MPI.COMM_WORLD.rank]
expected = cp.zeros_like(sendbuf)
for i in range(mscclpp_group.nranks):
expected += sub_arrays[i]
sendbuf = cp.random.random(nelems).astype(dtype)
expected = cp.asnumpy(sendbuf)
expected = MPI.COMM_WORLD.allreduce(expected, op=MPI.SUM)
mscclpp_group.barrier()
executor_func = lambda stream: executor.execute(
@@ -111,7 +107,6 @@ def main(
sendbuf.nbytes,
sendbuf.nbytes,
dtype_to_mscclpp_dtype(dtype),
nthreads_per_block,
execution_plan,
stream.ptr,
packet_type,
@@ -130,7 +125,7 @@ def main(
print(
f"Rank: {MPI.COMM_WORLD.rank} Execution time: {execution_time} us, "
f"data size: {sendbuf.nbytes} bytes data type: {dtype().dtype.name} "
f"packet type: {packet_type} nthreads_per_block: {nthreads_per_block}"
f"packet type: {packet_type}"
)
executor = None
mscclpp_group = None
@@ -141,7 +136,6 @@ if __name__ == "__main__":
parser.add_argument("-n", "--execution_plan_name", type=str, required=True)
parser.add_argument("-path", "--execution_plan_path", type=str, required=True)
parser.add_argument("--size", type=str, required=True)
parser.add_argument("--nthreads_per_block", type=int, required=True)
parser.add_argument("--dtype", type=str, default="float16", help="Choose from float16, float32, int32")
parser.add_argument("--packet_type", type=str, default="LL16", help="Choose from LL8, LL16")
parser.add_argument("--seed", type=int, default=42)
@@ -157,7 +151,6 @@ if __name__ == "__main__":
args.execution_plan_name,
args.execution_plan_path,
buffer_size,
args.nthreads_per_block,
dtype,
packet_type,
args.seed,