mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-20 06:49:29 +00:00
Support ReduceScatter in the NCCL interface (#460)
Co-authored-by: root <root@mscclpp-000002.tn3ujtlnlkjehmmeegdavazkfg.jx.internal.cloudapp.net>
Co-authored-by: Caio Rocha <aiorocha@microsoft.com>
Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
This commit is contained in:
@@ -718,11 +718,60 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
NCCL_API ncclResult_t ncclReduceScatter(const void*, void*, size_t, ncclDataType_t, ncclRedOp_t, ncclComm_t,
|
||||
cudaStream_t) {
|
||||
// TODO: implement this function
|
||||
WARN("ncclReduceScatter is currently unavailable");
|
||||
return ncclInternalError;
|
||||
NCCL_API ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, size_t recvcount, ncclDataType_t datatype,
                                        ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream) {
  // Reduce-scatter over the communicator: each rank ends up with the reduced
  // recvcount-element slice of the concatenated input. Executes via a
  // pre-registered mscclpp execution plan selected by message size and
  // in-place-ness; there is no fallback path yet.
  // NOTE(review): `op` is not consulted here — presumably the selected plan
  // fixes the reduction operator; confirm against plan registration.

  // Bytes delivered to each rank; used both for validation and plan lookup.
  const size_t recvBytes = recvcount * ncclTypeSize(datatype);
  const bool invalidArgs = (sendbuff == nullptr) || (recvbuff == nullptr) || (recvBytes == 0) || (comm == nullptr);
  if (invalidArgs) {
    WARN(
        "One or more of the following conditions is met: sendbuff or recvbuff pointer is nullptr, bytes is 0, "
        "or comm is nullptr.");
    return ncclInvalidArgument;
  }

  const int selfRank = comm->comm->bootstrap()->getRank();
  const int worldSize = comm->comm->bootstrap()->getNranks();

  // The call is "in place" when recvbuff aliases this rank's slice of sendbuff.
  const bool isInPlace = ((char*)sendbuff + selfRank * recvBytes) == recvbuff;
  const size_t totalBytes = recvBytes * worldSize;

  // Pick the first registered plan whose [min, max) message-size window and
  // in-place flag match this invocation.
  std::shared_ptr<mscclpp::ExecutionPlan> selectedPlan;
  std::vector<executionPlanInstance>& candidates = comm->executionPlans["reducescatter"];
  for (const auto& candidate : candidates) {
    const bool sizeMatches = totalBytes >= candidate.key.minMessageSize && totalBytes < candidate.key.maxMessageSize;
    if (sizeMatches && isInPlace == candidate.key.isInPlace) {
      selectedPlan = candidate.plan;
      break;
    }
  }
  // TODO: Fallback code for ReduceScatter
  if (selectedPlan == nullptr) {
    WARN("No FallBack code for ReduceScatter");
    return ncclInternalError;
  }

  // Launch the executor with the element type matching the NCCL datatype.
  // ncclInt32/ncclUint32 share the 32-bit integer path.
  switch (datatype) {
    case ncclFloat16:
      comm->executor->execute(selfRank, (half*)sendbuff, (half*)recvbuff, totalBytes, recvBytes,
                              mscclpp::DataType::FLOAT16, *selectedPlan, stream);
      break;
    case ncclFloat32:
      comm->executor->execute(selfRank, (float*)sendbuff, (float*)recvbuff, totalBytes, recvBytes,
                              mscclpp::DataType::FLOAT32, *selectedPlan, stream);
      break;
    case ncclBfloat16:
      comm->executor->execute(selfRank, (__bfloat16*)sendbuff, (__bfloat16*)recvbuff, totalBytes, recvBytes,
                              mscclpp::DataType::BFLOAT16, *selectedPlan, stream);
      break;
    case ncclInt32:
    case ncclUint32:
      comm->executor->execute(selfRank, (int*)sendbuff, (int*)recvbuff, totalBytes, recvBytes,
                              mscclpp::DataType::UINT32, *selectedPlan, stream);
      break;
    default:
      WARN("datatype is invalid");
      return ncclInvalidArgument;
  }

  return ncclSuccess;
}
|
||||
|
||||
NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount, ncclDataType_t datatype,
|
||||
|
||||
@@ -202,7 +202,10 @@ class ReduceScatter(Collective):
|
||||
for i in range(self.num_ranks):
|
||||
for c in range(self.chunk_factor):
|
||||
input_buffer.append(Chunk(r, i * self.chunk_factor + c, i, c))
|
||||
buffers = {Buffer.input: input_buffer}
|
||||
buffers = {
|
||||
Buffer.input: input_buffer,
|
||||
Buffer.output: input_buffer[r * self.chunk_factor : (r + 1) * self.chunk_factor],
|
||||
}
|
||||
rank_buffers.append(buffers)
|
||||
else:
|
||||
input_buffer = []
|
||||
|
||||
@@ -145,17 +145,26 @@ def build_bufs(
|
||||
nelems_input = nelems if in_place else nelems // num_ranks
|
||||
else:
|
||||
nelems_input = nelems
|
||||
nelems_output = nelems
|
||||
|
||||
if "reducescatter" in collective:
|
||||
assert (nelems % num_ranks) == 0, "nelems %d not multiple of num_ranks %d" % (nelems, num_ranks)
|
||||
nelems_output = nelems // num_ranks
|
||||
else:
|
||||
nelems_output = nelems
|
||||
|
||||
result_buf = GpuBuffer(nelems_output, dtype=dtype)
|
||||
if in_place:
|
||||
if "allgather" in collective:
|
||||
input_buf = cp.split(result_buf, num_ranks)[rank]
|
||||
elif "reducescatter" in collective:
|
||||
input_buf = GpuBuffer(nelems_input, dtype=dtype)
|
||||
result_buf = cp.split(input_buf, num_ranks)[rank]
|
||||
else:
|
||||
input_buf = result_buf
|
||||
else:
|
||||
input_buf = GpuBuffer(nelems_input, dtype=dtype)
|
||||
test_buf = cp.zeros(nelems_output, dtype=dtype)
|
||||
|
||||
test_buf = cp.zeros(nelems, dtype=dtype)
|
||||
|
||||
return input_buf, result_buf, test_buf
|
||||
|
||||
|
||||
@@ -91,7 +91,7 @@ TEST_DATA_ALL_REDUCE(int32, int)
|
||||
} \
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \
|
||||
if (i >= offset && i < offset + nem_elems_per_rank) { \
|
||||
assert(abs(float(result_buf[i]) - float(test_buf[i])) < 1e-3 * num_ranks); \
|
||||
assert(abs(float(result_buf[i - offset]) - float(test_buf[i])) < 1e-3 * num_ranks); \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user