mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
Fix in-place all-gather input buffer in executor_test (#372)
This commit is contained in:
@@ -77,6 +77,13 @@ def dtype_to_mscclpp_dtype(dtype):
|
||||
raise ValueError(f"Unknown data type: {dtype}")
|
||||
|
||||
|
||||
def determine_input_buf(sendbuf, recvbuf, in_place, execution_plan_name):
|
||||
if "allgather" in execution_plan_name and in_place:
|
||||
return recvbuf
|
||||
else:
|
||||
return sendbuf
|
||||
|
||||
|
||||
def determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name):
|
||||
if "allgather" in execution_plan_name:
|
||||
return recvbuf
|
||||
@@ -126,9 +133,9 @@ def main(
|
||||
|
||||
executor_func = lambda stream: executor.execute(
|
||||
MPI.COMM_WORLD.rank,
|
||||
sendbuf.data.ptr,
|
||||
determine_input_buf(sendbuf, recvbuf, in_place, execution_plan_name).data.ptr,
|
||||
determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name).data.ptr,
|
||||
sendbuf.nbytes,
|
||||
determine_input_buf(sendbuf, recvbuf, in_place, execution_plan_name).nbytes,
|
||||
determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name).nbytes,
|
||||
dtype_to_mscclpp_dtype(dtype),
|
||||
execution_plan,
|
||||
|
||||
Reference in New Issue
Block a user