mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
Support ReduceScatter in the NCCL interface (#460)
Co-authored-by: root <root@mscclpp-000002.tn3ujtlnlkjehmmeegdavazkfg.jx.internal.cloudapp.net> Co-authored-by: Caio Rocha <aiorocha@microsoft.com> Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
This commit is contained in:
@@ -145,17 +145,26 @@ def build_bufs(
|
||||
nelems_input = nelems if in_place else nelems // num_ranks
|
||||
else:
|
||||
nelems_input = nelems
|
||||
nelems_output = nelems
|
||||
|
||||
if "reducescatter" in collective:
|
||||
assert (nelems % num_ranks) == 0, "nelems %d not multiple of num_ranks %d" % (nelems, num_ranks)
|
||||
nelems_output = nelems // num_ranks
|
||||
else:
|
||||
nelems_output = nelems
|
||||
|
||||
result_buf = GpuBuffer(nelems_output, dtype=dtype)
|
||||
if in_place:
|
||||
if "allgather" in collective:
|
||||
input_buf = cp.split(result_buf, num_ranks)[rank]
|
||||
elif "reducescatter" in collective:
|
||||
input_buf = GpuBuffer(nelems_input, dtype=dtype)
|
||||
result_buf = cp.split(input_buf, num_ranks)[rank]
|
||||
else:
|
||||
input_buf = result_buf
|
||||
else:
|
||||
input_buf = GpuBuffer(nelems_input, dtype=dtype)
|
||||
test_buf = cp.zeros(nelems_output, dtype=dtype)
|
||||
|
||||
test_buf = cp.zeros(nelems, dtype=dtype)
|
||||
|
||||
return input_buf, result_buf, test_buf
|
||||
|
||||
|
||||
Reference in New Issue
Block a user