mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
Executor AllGather In-Place Support (#365)
This commit is contained in:
@@ -77,6 +77,15 @@ def dtype_to_mscclpp_dtype(dtype):
|
||||
raise ValueError(f"Unknown data type: {dtype}")
|
||||
|
||||
|
||||
def determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name):
    """Pick the buffer to treat as the result buffer for an execution plan.

    Args:
        sendbuf: The send buffer object.
        recvbuf: The receive buffer object.
        in_place: Truthy when the operation is run in place.
        execution_plan_name: Name of the execution plan; checked for the
            substring ``"allgather"``.

    Returns:
        ``recvbuf`` when the plan name contains ``"allgather"`` (regardless of
        ``in_place``) or when ``in_place`` is falsy; otherwise ``sendbuf``.
    """
    # Plans named "allgather" always use recvbuf as the result buffer, even
    # when in_place is set. The name check is evaluated first, as in the
    # original branch order.
    if "allgather" in execution_plan_name or not in_place:
        return recvbuf
    return sendbuf
|
||||
|
||||
|
||||
def main(
|
||||
execution_plan_name: str,
|
||||
execution_plan_path: str,
|
||||
@@ -104,9 +113,11 @@ def main(
|
||||
|
||||
if "allgather" in execution_plan_name:
|
||||
recvbuf = cp.zeros(nelems * mscclpp_group.nranks, dtype=dtype)
|
||||
if in_place:
|
||||
for i in range(nelems):
|
||||
recvbuf[mscclpp_group.my_rank * nelems + i] = sendbuf[i]
|
||||
expected = buffer
|
||||
else:
|
||||
cp.random.seed(seed)
|
||||
recvbuf = cp.zeros(nelems, dtype=dtype)
|
||||
expected = cp.zeros_like(sendbuf, dtype=dtype)
|
||||
for i in range(mscclpp_group.nranks):
|
||||
@@ -116,9 +127,9 @@ def main(
|
||||
executor_func = lambda stream: executor.execute(
|
||||
MPI.COMM_WORLD.rank,
|
||||
sendbuf.data.ptr,
|
||||
sendbuf.data.ptr if in_place else recvbuf.data.ptr,
|
||||
determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name).data.ptr,
|
||||
sendbuf.nbytes,
|
||||
sendbuf.nbytes if in_place else recvbuf.nbytes,
|
||||
determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name).nbytes,
|
||||
dtype_to_mscclpp_dtype(dtype),
|
||||
execution_plan,
|
||||
stream.ptr,
|
||||
@@ -129,10 +140,14 @@ def main(
|
||||
executor_func(stream)
|
||||
stream.synchronize()
|
||||
|
||||
assert cp.allclose(sendbuf if in_place else recvbuf, expected, atol=1e-2 * mscclpp_group.nranks)
|
||||
assert cp.allclose(
|
||||
determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name),
|
||||
expected,
|
||||
atol=1e-2 * mscclpp_group.nranks,
|
||||
)
|
||||
|
||||
mscclpp_group.barrier()
|
||||
execution_time = bench_time(100, 10, executor_func)
|
||||
execution_time = bench_time(10, 10, executor_func)
|
||||
if npkit_dump_dir is not None:
|
||||
npkit.dump(npkit_dump_dir)
|
||||
npkit.shutdown()
|
||||
|
||||
Reference in New Issue
Block a user