Reduce Operation Support to the Executor (#484)

Co-authored-by: Binyang Li <binyli@microsoft.com>
This commit is contained in:
Caio Rocha
2025-03-25 13:58:12 -07:00
committed by GitHub
parent b4062462fd
commit ac5cc647e0
2 changed files with 22 additions and 12 deletions

View File

@@ -428,9 +428,9 @@ MSCCLPP_DEVICE_INLINE void handleTransformToPacket(void* dst, void* src, size_t
mscclpp::putPackets<PacketType>(dst, dstOffset, src, srcOffset, size, threadIdx.x, blockDim.x, flag);
}
template <typename T>
template <typename T, bool SendToRemote = true>
MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T* src, uint32_t srcOffsetByBytes,
T* input, uint32_t* inputOffsets,
T* input, uint32_t* inputOffsets, int nSrcs,
DeviceHandle<MemoryChannel>* memoryChannels, uint8_t* outputChannelIndexes,
uint32_t* outputOffsets, int nOutChannels, uint32_t size) {
const size_t nInt4 = size / sizeof(int4);
@@ -441,15 +441,17 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T
int4* input4 = (int4*)input;
for (size_t idx = threadIdx.x; idx < nInt4; idx += blockDim.x) {
int4 tmp = src4[srcOffset4 + idx];
for (int index = 0; index < nOutChannels; ++index) {
for (int index = 0; index < nSrcs; ++index) {
size_t offset = inputOffsets[index] / sizeof(int4);
int4 val = input4[offset + idx];
tmp = add_vectors<T>(tmp, val);
}
dst4[dstOffset4 + idx] = tmp;
for (int index = 0; index < nOutChannels; ++index) {
size_t offset = outputOffsets[index] / sizeof(int4);
memoryChannels[outputChannelIndexes[index]].write<int4>(offset + idx, tmp);
if (SendToRemote) {
for (int index = 0; index < nOutChannels; ++index) {
size_t offset = outputOffsets[index] / sizeof(int4);
memoryChannels[outputChannelIndexes[index]].write<int4>(offset + idx, tmp);
}
}
}
// handle rest of data
@@ -458,14 +460,16 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T
const size_t endIdx = (srcOffsetByBytes + size) / sizeof(T);
for (size_t idx = threadIdx.x + startIdx; idx < endIdx; idx += blockDim.x) {
T tmp = src[idx];
for (int index = 0; index < nOutChannels; ++index) {
for (int index = 0; index < nSrcs; ++index) {
size_t offset = inputOffsets[index] / sizeof(T);
tmp = add_elements(tmp, input[offset + idx]);
}
dst[idx] = tmp;
for (int index = 0; index < nOutChannels; ++index) {
size_t offset = outputOffsets[index] / sizeof(T);
memoryChannels[outputChannelIndexes[index]].write<T>(offset + idx, tmp);
if (SendToRemote) {
for (int index = 0; index < nOutChannels; ++index) {
size_t offset = outputOffsets[index] / sizeof(T);
memoryChannels[outputChannelIndexes[index]].write<T>(offset + idx, tmp);
}
}
}
}
@@ -624,6 +628,12 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu
handleReduceSendPacket<T, PacketType>(dst, op.dstOffset, src, op.srcOffset, scratch, scratchSize, op.inputOffsets,
op.nInputs, memoryChannels, op.outputChannelIndexes, op.outputOffsets,
op.nOutputs, op.size, flag);
} else if (op.type == OperationType::REDUCE) {
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
T* src = getBuffer(input, output, scratch, op.srcBufferType);
T* tmp = getBuffer(input, output, scratch, op.inputBufferType);
handleReduceSend<T, false>(dst, op.dstOffset, src, op.srcOffset, tmp, op.inputOffsets, op.nInputs, memoryChannels,
op.outputChannelIndexes, op.outputOffsets, op.nOutputs, op.size);
} else if (op.type == OperationType::REDUCE_PACKET) {
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
T* src = getBuffer(input, output, scratch, op.srcBufferType);
@@ -642,7 +652,7 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
T* src = getBuffer(input, output, scratch, op.srcBufferType);
T* tmp = getBuffer(input, output, scratch, op.inputBufferType);
handleReduceSend(dst, op.dstOffset, src, op.srcOffset, tmp, op.inputOffsets, memoryChannels,
handleReduceSend(dst, op.dstOffset, src, op.srcOffset, tmp, op.inputOffsets, op.nInputs, memoryChannels,
op.outputChannelIndexes, op.outputOffsets, op.nOutputs, op.size);
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900