diff --git a/python/mscclpp/language/types.py b/python/mscclpp/language/types.py index ef2cb9ca..3cde2a14 100644 --- a/python/mscclpp/language/types.py +++ b/python/mscclpp/language/types.py @@ -80,7 +80,7 @@ class Instruction(Enum): read_reduce_copy_send = "rrcs" reduce_send = "rs" copy = "copy" - reduce = "reduce" + reduce = "re" copy_packet = "cpkt" transform_to_packet = "tpkt" reduce_send_packet = "rspkt" diff --git a/src/include/execution_kernel.hpp b/src/include/execution_kernel.hpp index 0e7d5f29..93eb4578 100644 --- a/src/include/execution_kernel.hpp +++ b/src/include/execution_kernel.hpp @@ -428,9 +428,9 @@ MSCCLPP_DEVICE_INLINE void handleTransformToPacket(void* dst, void* src, size_t mscclpp::putPackets(dst, dstOffset, src, srcOffset, size, threadIdx.x, blockDim.x, flag); } -template +template MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T* src, uint32_t srcOffsetByBytes, - T* input, uint32_t* inputOffsets, + T* input, uint32_t* inputOffsets, int nSrcs, DeviceHandle* memoryChannels, uint8_t* outputChannelIndexes, uint32_t* outputOffsets, int nOutChannels, uint32_t size) { const size_t nInt4 = size / sizeof(int4); @@ -441,15 +441,17 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T int4* input4 = (int4*)input; for (size_t idx = threadIdx.x; idx < nInt4; idx += blockDim.x) { int4 tmp = src4[srcOffset4 + idx]; - for (int index = 0; index < nOutChannels; ++index) { + for (int index = 0; index < nSrcs; ++index) { size_t offset = inputOffsets[index] / sizeof(int4); int4 val = input4[offset + idx]; tmp = add_vectors(tmp, val); } dst4[dstOffset4 + idx] = tmp; - for (int index = 0; index < nOutChannels; ++index) { - size_t offset = outputOffsets[index] / sizeof(int4); - memoryChannels[outputChannelIndexes[index]].write(offset + idx, tmp); + if (SendToRemote) { + for (int index = 0; index < nOutChannels; ++index) { + size_t offset = outputOffsets[index] / sizeof(int4); + memoryChannels[outputChannelIndexes[index]].write(offset + idx, tmp); + } } } // handle rest of data @@ -458,14 +460,16 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T const size_t endIdx = (srcOffsetByBytes + size) / sizeof(T); for (size_t idx = threadIdx.x + startIdx; idx < endIdx; idx += blockDim.x) { T tmp = src[idx]; - for (int index = 0; index < nOutChannels; ++index) { + for (int index = 0; index < nSrcs; ++index) { size_t offset = inputOffsets[index] / sizeof(T); tmp = add_elements(tmp, input[offset + idx]); } dst[idx] = tmp; - for (int index = 0; index < nOutChannels; ++index) { - size_t offset = outputOffsets[index] / sizeof(T); - memoryChannels[outputChannelIndexes[index]].write(offset + idx, tmp); + if (SendToRemote) { + for (int index = 0; index < nOutChannels; ++index) { + size_t offset = outputOffsets[index] / sizeof(T); + memoryChannels[outputChannelIndexes[index]].write(offset + idx, tmp); + } } } } @@ -624,6 +628,12 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu handleReduceSendPacket(dst, op.dstOffset, src, op.srcOffset, scratch, scratchSize, op.inputOffsets, op.nInputs, memoryChannels, op.outputChannelIndexes, op.outputOffsets, op.nOutputs, op.size, flag); + } else if (op.type == OperationType::REDUCE) { + T* dst = getBuffer(input, output, scratch, op.dstBufferType); + T* src = getBuffer(input, output, scratch, op.srcBufferType); + T* tmp = getBuffer(input, output, scratch, op.inputBufferType); + handleReduceSend(dst, op.dstOffset, src, op.srcOffset, tmp, op.inputOffsets, op.nInputs, memoryChannels, + op.outputChannelIndexes, op.outputOffsets, op.nOutputs, op.size); } else if (op.type == OperationType::REDUCE_PACKET) { T* dst = getBuffer(input, output, scratch, op.dstBufferType); T* src = getBuffer(input, output, scratch, op.srcBufferType); @@ -642,7 +652,7 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu T* dst = getBuffer(input, output, scratch, op.dstBufferType); T* src = getBuffer(input, output, scratch, op.srcBufferType); T* tmp = getBuffer(input, output, scratch, op.inputBufferType); - handleReduceSend(dst, op.dstOffset, src, op.srcOffset, tmp, op.inputOffsets, memoryChannels, + handleReduceSend(dst, op.dstOffset, src, op.srcOffset, tmp, op.inputOffsets, op.nInputs, memoryChannels, op.outputChannelIndexes, op.outputOffsets, op.nOutputs, op.size); } #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900