Executor AllGather In-Place Support (#365)

This commit is contained in:
Caio Rocha
2024-10-21 05:45:56 -07:00
committed by GitHub
parent 4136153a76
commit c6e06cfad7
4 changed files with 95 additions and 41 deletions

View File

@@ -77,6 +77,15 @@ def dtype_to_mscclpp_dtype(dtype):
raise ValueError(f"Unknown data type: {dtype}")
def determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name):
if "allgather" in execution_plan_name:
return recvbuf
elif in_place:
return sendbuf
else:
return recvbuf
def main(
execution_plan_name: str,
execution_plan_path: str,
@@ -104,9 +113,11 @@ def main(
if "allgather" in execution_plan_name:
recvbuf = cp.zeros(nelems * mscclpp_group.nranks, dtype=dtype)
if in_place:
for i in range(nelems):
recvbuf[mscclpp_group.my_rank * nelems + i] = sendbuf[i]
expected = buffer
else:
cp.random.seed(seed)
recvbuf = cp.zeros(nelems, dtype=dtype)
expected = cp.zeros_like(sendbuf, dtype=dtype)
for i in range(mscclpp_group.nranks):
@@ -116,9 +127,9 @@ def main(
executor_func = lambda stream: executor.execute(
MPI.COMM_WORLD.rank,
sendbuf.data.ptr,
sendbuf.data.ptr if in_place else recvbuf.data.ptr,
determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name).data.ptr,
sendbuf.nbytes,
sendbuf.nbytes if in_place else recvbuf.nbytes,
determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name).nbytes,
dtype_to_mscclpp_dtype(dtype),
execution_plan,
stream.ptr,
@@ -129,10 +140,14 @@ def main(
executor_func(stream)
stream.synchronize()
assert cp.allclose(sendbuf if in_place else recvbuf, expected, atol=1e-2 * mscclpp_group.nranks)
assert cp.allclose(
determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name),
expected,
atol=1e-2 * mscclpp_group.nranks,
)
mscclpp_group.barrier()
execution_time = bench_time(100, 10, executor_func)
execution_time = bench_time(10, 10, executor_func)
if npkit_dump_dir is not None:
npkit.dump(npkit_dump_dir)
npkit.shutdown()