From 9cbe78c1885b3df32775f9843add4d3e68c779fd Mon Sep 17 00:00:00 2001 From: Caio Rocha Date: Mon, 6 Apr 2026 23:18:16 +0000 Subject: [PATCH] wip --- .../tests/executor_tests/reduce_pack.py | 67 +++++++ .../tests/executor_tests/reduce_pack_tbg.py | 70 +++++++ .../tests/executor_tests/transfer_pack.py | 60 ++++++ .../tests/executor_tests/transfer_pack_tbg.py | 63 +++++++ src/core/executor/execution_plan.cc | 31 +-- src/core/include/execution_common.hpp | 3 + src/core/include/execution_kernel.hpp | 178 +++++++++++++----- 7 files changed, 409 insertions(+), 63 deletions(-) create mode 100644 python/mscclpp/language/tests/executor_tests/reduce_pack.py create mode 100644 python/mscclpp/language/tests/executor_tests/reduce_pack_tbg.py create mode 100644 python/mscclpp/language/tests/executor_tests/transfer_pack.py create mode 100644 python/mscclpp/language/tests/executor_tests/transfer_pack_tbg.py diff --git a/python/mscclpp/language/tests/executor_tests/reduce_pack.py b/python/mscclpp/language/tests/executor_tests/reduce_pack.py new file mode 100644 index 00000000..bf584a0f --- /dev/null +++ b/python/mscclpp/language/tests/executor_tests/reduce_pack.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +from mscclpp.language.channel import * +from mscclpp.language.rank import * +from mscclpp.language.general import * +from mscclpp.language.program import * +from mscclpp.language.collectives import * +from enum import Enum + +def allreduce_example(name, num_threads_per_block, min_message_size, max_message_size): + chunksperloop = 1 + gpu_size = 2 + collective = AllReduce(gpu_size, chunksperloop, True) + with CollectiveProgram( + name, + collective, + gpu_size, + protocol="LL", + num_threads_per_block=num_threads_per_block, + use_double_scratch_buffer=True, + min_message_size=min_message_size, + max_message_size=max_message_size, + ): + # Declaring Ranks, Channels, Buffers for 2 GPU allgather example + first_rank = Rank(0) + second_rank = Rank(1) + first_ch = MemoryChannel(1, 0) + second_ch = MemoryChannel(0, 1) + first_input_buffer = first_rank.get_input_buffer() + second_input_buffer = second_rank.get_input_buffer() + first_scratch_buffer = Buffer(0, 3) + second_scratch_buffer = Buffer(1, 3) + + # First rank puts packets in the remote scratch buffer of the second rank + first_ch.put_packets(second_scratch_buffer[1: 2], first_input_buffer[1 : 2], tb=0) + second_ch.put_packets(first_scratch_buffer[0 : 1], second_input_buffer[0 : 1], tb=0) + + # Second rank copy packets to scratch buffer and then read put packets to first rank output buffer + first_rank.reduce(first_input_buffer[0 : 1], [first_scratch_buffer[0 : 1]], tb=1, packet=True) + first_ch.put_packets(second_scratch_buffer[0 : 1], first_input_buffer[0 : 1], tb=1) + + # First rank copy packets to scratch buffer and then read put packets to second rank output buffer + second_rank.reduce(second_input_buffer[1 : 2], [second_scratch_buffer[1 : 2]], tb=1, packet=True) + second_rank.copy_packets(second_scratch_buffer[2: 3], second_input_buffer[1 : 2], tb=1) + second_ch.read_put_packets(first_scratch_buffer[1 : 2], second_scratch_buffer[2: 3], tb=1) + + # First rank copy packets to scratch buffer and then read put packets to second rank output buffer + first_rank.unpack_packets(first_input_buffer[1 : 2], first_scratch_buffer[1 : 2], tb=2) + second_rank.unpack_packets(second_input_buffer[0 : 1], second_scratch_buffer[0 : 1], tb=2) + + + print(JSON()) + + +parser = argparse.ArgumentParser() + +parser.add_argument("--name", type=str, help="name of the program") +parser.add_argument("--num_gpus", type=int, help="number of gpus") +parser.add_argument("--num_threads_per_block", type=int, default=1024, help="number of threads per block") +parser.add_argument("--min_message_size", type=int, default=0, help="minimum message size") +parser.add_argument("--max_message_size", type=int, default=2**64 - 1, help="maximum message size") + +args = parser.parse_args() + +allreduce_example(args.name, args.num_threads_per_block, args.min_message_size, args.max_message_size) diff --git a/python/mscclpp/language/tests/executor_tests/reduce_pack_tbg.py b/python/mscclpp/language/tests/executor_tests/reduce_pack_tbg.py new file mode 100644 index 00000000..9f7f9df8 --- /dev/null +++ b/python/mscclpp/language/tests/executor_tests/reduce_pack_tbg.py @@ -0,0 +1,70 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +from mscclpp.language.channel import * +from mscclpp.language.rank import * +from mscclpp.language.general import * +from mscclpp.language.program import * +from mscclpp.language.collectives import * +from enum import Enum + +def allreduce_example(name, num_threads_per_block, min_message_size, max_message_size): + chunksperloop = 1 + gpu_size = 2 + collective = AllReduce(gpu_size, chunksperloop, True) + with CollectiveProgram( + name, + collective, + gpu_size, + protocol="LL", + num_threads_per_block=num_threads_per_block, + use_double_scratch_buffer=True, + min_message_size=min_message_size, + max_message_size=max_message_size, + ): + # Declaring Ranks, Channels, Buffers for 2 GPU allgather example + first_rank = Rank(0) + second_rank = Rank(1) + first_ch = MemoryChannel(1, 0) + second_ch = MemoryChannel(0, 1) + first_input_buffer = first_rank.get_input_buffer() + second_input_buffer = second_rank.get_input_buffer() + first_scratch_buffer = Buffer(0, 3) + second_scratch_buffer = Buffer(1, 3) + tbg = [] + for i in range(3): + tbg.append(ThreadBlockGroup(tb_list = [2 * i, 2 * i + 1])) + + # First rank puts packets in the remote scratch buffer of the second rank + first_ch.put_packets(second_scratch_buffer[1: 2], first_input_buffer[1 : 2], tb_group=tbg[0]) + second_ch.put_packets(first_scratch_buffer[0 : 1], second_input_buffer[0 : 1], tb_group=tbg[0]) + + # Second rank copy packets to scratch buffer and then read put packets to first rank output buffer + first_rank.reduce(first_input_buffer[0 : 1], [first_scratch_buffer[0 : 1]], tb_group=tbg[1], packet=True) + first_ch.put_packets(second_scratch_buffer[0 : 1], first_input_buffer[0 : 1], tb_group=tbg[1]) + + # First rank copy packets to scratch buffer and then read put packets to second rank output buffer + second_rank.reduce(second_input_buffer[1 : 2], [second_scratch_buffer[1 : 2]], tb_group=tbg[1], packet=True) + second_rank.copy_packets(second_scratch_buffer[2: 3], second_input_buffer[1 : 2], tb_group=tbg[1]) + second_ch.read_put_packets(first_scratch_buffer[1 : 2], second_scratch_buffer[2: 3], tb_group=tbg[1]) + + # First rank copy packets to scratch buffer and then read put packets to second rank output buffer + first_rank.unpack_packets(first_input_buffer[1 : 2], first_scratch_buffer[1 : 2], tb_group=tbg[2]) + second_rank.unpack_packets(second_input_buffer[0 : 1], second_scratch_buffer[0 : 1], tb_group=tbg[2]) + + + print(JSON()) + + +parser = argparse.ArgumentParser() + +parser.add_argument("--name", type=str, help="name of the program") +parser.add_argument("--num_gpus", type=int, help="number of gpus") +parser.add_argument("--num_threads_per_block", type=int, default=1024, help="number of threads per block") +parser.add_argument("--min_message_size", type=int, default=0, help="minimum message size") +parser.add_argument("--max_message_size", type=int, default=2**64 - 1, help="maximum message size") + +args = parser.parse_args() + +allreduce_example(args.name, args.num_threads_per_block, args.min_message_size, args.max_message_size) diff --git a/python/mscclpp/language/tests/executor_tests/transfer_pack.py b/python/mscclpp/language/tests/executor_tests/transfer_pack.py new file mode 100644 index 00000000..bb3b5592 --- /dev/null +++ b/python/mscclpp/language/tests/executor_tests/transfer_pack.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +from mscclpp.language.channel import * +from mscclpp.language.rank import * +from mscclpp.language.general import * +from mscclpp.language.program import * +from mscclpp.language.collectives import * +from enum import Enum + +def allgather_example(name, num_threads_per_block, min_message_size, max_message_size): + chunksperloop = 1 + gpu_size = 2 + collective = AllGather(gpu_size, chunksperloop, True) + with CollectiveProgram( + name, + collective, + gpu_size, + protocol="LL", + num_threads_per_block=num_threads_per_block, + use_double_scratch_buffer=True, + min_message_size=min_message_size, + max_message_size=max_message_size, + ): + # Declaring Ranks, Channels, Buffers for 2 GPU allgather example + first_rank = Rank(0) + second_rank = Rank(1) + first_ch = MemoryChannel(1, 0) + second_ch = MemoryChannel(0, 1) + first_output_buffer = first_rank.get_output_buffer() + second_output_buffer = second_rank.get_output_buffer() + first_scratch_buffer = Buffer(0, 2) + second_scratch_buffer = Buffer(1, 2) + + # First rank puts packets in the remote scratch buffer of the second rank + first_ch.put_packets(second_scratch_buffer[0: 1], first_output_buffer[0 : 1], tb=0) + + # Second rank copy packets to scratch buffer and then read put packets to first rank output buffer + second_rank.copy_packets(second_scratch_buffer[1 : 2], second_output_buffer[1 : 2], tb=0) + second_ch.read_put_packets(first_scratch_buffer[1 : 2], second_scratch_buffer[1 : 2], tb=1) + + # Copying packets from local scratch buffer to local output buffer + first_rank.unpack_packets(first_output_buffer[1 : 2], first_scratch_buffer[1 : 2], tb=1) + second_rank.unpack_packets(second_output_buffer[0 : 1], second_scratch_buffer[0 : 1], tb=2) + + print(JSON()) + + +parser = argparse.ArgumentParser() + +parser.add_argument("--name", type=str, help="name of the program") +parser.add_argument("--num_gpus", type=int, help="number of gpus") +parser.add_argument("--num_threads_per_block", type=int, default=1024, help="number of threads per block") +parser.add_argument("--min_message_size", type=int, default=0, help="minimum message size") +parser.add_argument("--max_message_size", type=int, default=2**64 - 1, help="maximum message size") + +args = parser.parse_args() + +allgather_example(args.name, args.num_threads_per_block, args.min_message_size, args.max_message_size) diff --git a/python/mscclpp/language/tests/executor_tests/transfer_pack_tbg.py b/python/mscclpp/language/tests/executor_tests/transfer_pack_tbg.py new file mode 100644 index 00000000..41c81a6e --- /dev/null +++ b/python/mscclpp/language/tests/executor_tests/transfer_pack_tbg.py @@ -0,0 +1,63 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +from mscclpp.language.channel import * +from mscclpp.language.rank import * +from mscclpp.language.general import * +from mscclpp.language.program import * +from mscclpp.language.collectives import * +from enum import Enum + +def allgather_example(name, num_threads_per_block, min_message_size, max_message_size): + chunksperloop = 1 + gpu_size = 2 + collective = AllGather(gpu_size, chunksperloop, True) + with CollectiveProgram( + name, + collective, + gpu_size, + protocol="LL", + num_threads_per_block=num_threads_per_block, + use_double_scratch_buffer=True, + min_message_size=min_message_size, + max_message_size=max_message_size, + ): + # Declaring Ranks, Channels, Buffers for 2 GPU allgather example + first_rank = Rank(0) + second_rank = Rank(1) + first_ch = MemoryChannel(1, 0) + second_ch = MemoryChannel(0, 1) + first_output_buffer = first_rank.get_output_buffer() + second_output_buffer = second_rank.get_output_buffer() + first_scratch_buffer = Buffer(0, 2) + second_scratch_buffer = Buffer(1, 2) + tbg = [] + for i in range(3): + tbg.append(ThreadBlockGroup(tb_list = [2 * i, 2 * i + 1])) + + # First rank puts packets in the remote scratch buffer of the second rank + first_ch.put_packets(second_scratch_buffer[0: 1], first_output_buffer[0 : 1], tb_group=tbg[0]) + + # Second rank copy packets to scratch buffer and then read put packets to first rank output buffer + second_rank.copy_packets(second_scratch_buffer[1 : 2], second_output_buffer[1 : 2], tb_group=tbg[0]) + second_ch.read_put_packets(first_scratch_buffer[1 : 2], second_scratch_buffer[1 : 2], tb_group=tbg[1]) + + # Copying packets from local scratch buffer to local output buffer + first_rank.unpack_packets(first_output_buffer[1 : 2], first_scratch_buffer[1 : 2], tb_group=tbg[1]) + second_rank.unpack_packets(second_output_buffer[0 : 1], second_scratch_buffer[0 : 1], tb_group=tbg[2]) + + print(JSON()) + + +parser = argparse.ArgumentParser() + +parser.add_argument("--name", type=str, help="name of the program") +parser.add_argument("--num_gpus", type=int, help="number of gpus") +parser.add_argument("--num_threads_per_block", type=int, default=1024, help="number of threads per block") +parser.add_argument("--min_message_size", type=int, default=0, help="minimum message size") +parser.add_argument("--max_message_size", type=int, default=2**64 - 1, help="maximum message size") + +args = parser.parse_args() + +allgather_example(args.name, args.num_threads_per_block, args.min_message_size, args.max_message_size) diff --git a/src/core/executor/execution_plan.cc b/src/core/executor/execution_plan.cc index 98ec3ab6..0437f9d1 100644 --- a/src/core/executor/execution_plan.cc +++ b/src/core/executor/execution_plan.cc @@ -496,9 +496,6 @@ void ExecutionPlan::Impl::setupOperation(const nlohmann::json& op, Operation& op throw Error("Invalid channel type", ErrorCode::ExecutorError); }; - uint32_t tbId = 0; - uint32_t tbgSize = 1; - operation.type = static_cast(getOpType(op["name"])); if (op.contains("channel_type")) { operation.channelType = convertToChannelType(op["channel_type"]); @@ -509,12 +506,16 @@ void ExecutionPlan::Impl::setupOperation(const nlohmann::json& op, Operation& op operation.channelIndexes[i] = op["channel_ids"][i]; } } - if (op.contains("tbg_info")) { - tbId = op["tbg_info"]["tb_id"]; - tbgSize = op["tbg_info"]["tbg_size"]; - } if (op.contains("src_buff")) { operation.nInputs = op["src_buff"].size(); + if (op.contains("tbg_info")) { + operation.tbId = op["tbg_info"]["tb_id"]; + operation.tbgSize = op["tbg_info"]["tbg_size"]; + } + else{ + operation.tbId = 0; + operation.tbgSize = 1; + } for (int i = 0; i < operation.nInputs; i++) { auto& buff = op["src_buff"][i]; size_t constOffset = 0; @@ -537,14 +538,22 @@ void ExecutionPlan::Impl::setupOperation(const nlohmann::json& op, Operation& op } size_t inputOffset = this->getOffset(this->inputSize, this->outputSize, buff["index"], bufferType) + constOffset; size_t inputBufferSize = this->getBufferSize(this->inputSize, this->outputSize, buff["index"], buff["size"]); - inputOffset += calcOffset(inputBufferSize, tbId, tbgSize); - inputBufferSize = calcSize(inputBufferSize, tbId, tbgSize); + inputOffset += calcOffset(inputBufferSize, 0, 1); + inputBufferSize = calcSize(inputBufferSize, 0, 1); operation.inputOffsets[i] = inputOffset; operation.inputBufferSizes[i] = inputBufferSize; } } if (op.contains("dst_buff")) { operation.nOutputs = op["dst_buff"].size(); + if (op.contains("tbg_info")) { + operation.tbId = op["tbg_info"]["tb_id"]; + operation.tbgSize = op["tbg_info"]["tbg_size"]; + } + else{ + operation.tbId = 0; + operation.tbgSize = 1; + } for (int i = 0; i < operation.nOutputs; i++) { auto& buff = op["dst_buff"][i]; size_t constOffset = 0; @@ -567,8 +576,8 @@ void ExecutionPlan::Impl::setupOperation(const nlohmann::json& op, Operation& op } size_t outputOffset = this->getOffset(this->inputSize, this->outputSize, buff["index"], bufferType) + constOffset; size_t outputBufferSize = this->getBufferSize(this->inputSize, this->outputSize, buff["index"], buff["size"]); - outputOffset += calcOffset(outputBufferSize, tbId, tbgSize); - outputBufferSize = calcSize(outputBufferSize, tbId, tbgSize); + outputOffset += calcOffset(outputBufferSize, 0, 1); + outputBufferSize = calcSize(outputBufferSize, 0, 1); operation.outputOffsets[i] = outputOffset; operation.outputBufferSizes[i] = outputBufferSize; } diff --git a/src/core/include/execution_common.hpp b/src/core/include/execution_common.hpp index d071ce7d..3272c4ef 100644 --- a/src/core/include/execution_common.hpp +++ b/src/core/include/execution_common.hpp @@ -118,6 +118,9 @@ struct Operation { uint8_t nChannels; uint8_t nInputs; uint8_t nOutputs; + + uint8_t tbId; + uint8_t tbgSize; }; struct { uint32_t unitSize; diff --git a/src/core/include/execution_kernel.hpp b/src/core/include/execution_kernel.hpp index 20147c30..01f25435 100644 --- a/src/core/include/execution_kernel.hpp +++ b/src/core/include/execution_kernel.hpp @@ -66,6 +66,24 @@ MSCCLPP_DEVICE_INLINE uint32_t getOffset(BufferType bufferType, uint32_t offset) } } +// Mirrors ExecutionPlan::Impl::calcOffset from execution_plan.cc. +// Computes the byte offset for the index-th slice when splitting size bytes into slices parts, +// aligned to 16 bytes (matching the default bufferAlignment). +MSCCLPP_DEVICE_INLINE uint32_t calcOffset(uint32_t size, uint32_t index, uint32_t slices) { + constexpr uint32_t alignment = 16; + uint32_t nelems = size / alignment; + uint32_t minNelems = nelems / slices; + uint32_t remainder = nelems % slices; + uint32_t off = index * minNelems + (index < remainder ? index : remainder); + return off * alignment; +} + +// Mirrors ExecutionPlan::Impl::calcSize from execution_plan.cc. +// Computes the byte size of the index-th slice when splitting size bytes into slices parts. +MSCCLPP_DEVICE_INLINE uint32_t calcSize(uint32_t size, uint32_t index, uint32_t slices) { + return calcOffset(size, index + 1, slices) - calcOffset(size, index, slices); +} + template MSCCLPP_DEVICE_INLINE void executeDeviceFunction(const Operation& op, T* input, T* output, T* scratch, uint8_t* nSteps = nullptr, uint32_t offset = 0, @@ -129,14 +147,18 @@ template MSCCLPP_DEVICE_INLINE void handleGet(const Operation& op, void* input, void* output, void* scratch, uint32_t offset, uint32_t unitSize) { const uint32_t count = op.nInputs; + const uint8_t tbId = op.tbId; + const uint8_t tbgSize = op.tbgSize; const uint32_t* sizes = op.inputBufferSizes; const uint32_t* srcOffsets = op.inputOffsets; const uint32_t* dstOffsets = op.outputOffsets; for (uint32_t i = 0; i < count; i++) { - uint32_t dstOffset = dstOffsets[i] + getOffset(op.outputBufferRefs[i].type, offset); - uint32_t srcOffset = - srcOffsets[i] + getOffset(memoryChannelBufferTypes_[op.inputBufferRefs[i].id], offset); uint32_t size = min(sizes[i] - offset, unitSize); + uint32_t dstOffset = + dstOffsets[i] + getOffset(op.outputBufferRefs[i].type, offset + calcOffset(size, tbId, tbgSize)); + uint32_t srcOffset = srcOffsets[i] + getOffset(memoryChannelBufferTypes_[op.inputBufferRefs[i].id], + offset + calcOffset(size, tbId, tbgSize)); + size = calcSize(size, tbId, tbgSize); char* remoteMemory = static_cast(memoryChannelBufferPtrs_[op.inputBufferRefs[i].id]); mscclpp::copy(static_cast(getBuffer(input, output, scratch, op.outputBufferRefs[i].type)) + srcOffset, remoteMemory + dstOffset, size, threadIdx.x, blockDim.x); @@ -148,6 +170,8 @@ MSCCLPP_DEVICE_INLINE void handlePut(const Operation& op, void* input, void* out uint32_t unitSize) { ChannelType chType = op.channelType; uint32_t count = op.nOutputs; + const uint8_t tbId = op.tbId; + const uint8_t tbgSize = op.tbgSize; const uint8_t* channelIndexes = op.channelIndexes; const uint32_t* dstOffsets = op.outputOffsets; const uint32_t* srcOffsets = op.inputOffsets; @@ -155,10 +179,12 @@ MSCCLPP_DEVICE_INLINE void handlePut(const Operation& op, void* input, void* out char* src = static_cast(getBuffer(input, output, scratch, op.inputBufferRefs[0].type)); if (chType == ChannelType::MEMORY) { for (uint32_t i = 0; i < count; i++) { - uint32_t dstOffset = - dstOffsets[i] + getOffset(memoryChannelBufferTypes_[op.outputBufferRefs[i].id], offset); - uint32_t srcOffset = srcOffsets[i] + getOffset(op.inputBufferRefs[i].type, offset); uint32_t size = min(outputSizes[i] - offset, unitSize); + uint32_t dstOffset = dstOffsets[i] + getOffset(memoryChannelBufferTypes_[op.outputBufferRefs[i].id], + offset + calcOffset(size, tbId, tbgSize)); + uint32_t srcOffset = + srcOffsets[i] + getOffset(op.inputBufferRefs[i].type, offset + calcOffset(size, tbId, tbgSize)); + size = calcSize(size, tbId, tbgSize); char* remoteMemory = static_cast(memoryChannelBufferPtrs_[op.outputBufferRefs[i].id]); mscclpp::copy(remoteMemory + dstOffset, src + srcOffset, size, threadIdx.x, blockDim.x); } @@ -188,12 +214,16 @@ MSCCLPP_DEVICE_INLINE void handlePut(const Operation& op, void* input, void* out template MSCCLPP_DEVICE_INLINE void handleReadReduceSend(const Operation& op, void* input, void* output, void* scratch, uint32_t offset, uint32_t unitSize) { - const uint32_t size = min(op.inputBufferSizes[0] - offset, unitSize); + uint32_t size = min(op.inputBufferSizes[0] - offset, unitSize); + const uint8_t tbId = op.tbId; + const uint8_t tbgSize = op.tbgSize; + const uint32_t tbgOffset = calcOffset(size, tbId, tbgSize); + size = calcSize(size, tbId, tbgSize); const uint32_t nInt4 = size / sizeof(int4); const uint32_t inputOffset4 = - (op.inputOffsets[0] + getOffset(op.inputBufferRefs[0].type, offset)) / sizeof(int4); + (op.inputOffsets[0] + getOffset(op.inputBufferRefs[0].type, offset + tbgOffset)) / sizeof(int4); const uint32_t outputOffset4 = - (op.outputOffsets[0] + getOffset(op.outputBufferRefs[0].type, offset)) / sizeof(int4); + (op.outputOffsets[0] + getOffset(op.outputBufferRefs[0].type, offset + tbgOffset)) / sizeof(int4); const uint8_t nRemoteInputs = op.nInputs - 1; const uint8_t nRemoteOutputs = op.nOutputs - 1; const uint32_t* srcOffsets = op.inputOffsets + 1; @@ -206,7 +236,7 @@ MSCCLPP_DEVICE_INLINE void handleReadReduceSend(const Operation& op, void* input int4 val; uint32_t srcOffset = (srcOffsets[index] + - getOffset(memoryChannelBufferTypes_[op.inputBufferRefs[index + 1].id], offset)) / + getOffset(memoryChannelBufferTypes_[op.inputBufferRefs[index + 1].id], offset + tbgOffset)) / sizeof(int4); void* remoteMemory = static_cast(memoryChannelBufferPtrs_[op.inputBufferRefs[index + 1].id]); val = mscclpp::read(remoteMemory, srcOffset + idx); @@ -216,8 +246,8 @@ MSCCLPP_DEVICE_INLINE void handleReadReduceSend(const Operation& op, void* input if constexpr (SendToRemote) { for (int index = 0; index < nRemoteOutputs; ++index) { uint32_t dstOffset = - (dstOffsets[index] + - getOffset(memoryChannelBufferTypes_[op.outputBufferRefs[index + 1].id], offset)) / + (dstOffsets[index] + getOffset(memoryChannelBufferTypes_[op.outputBufferRefs[index + 1].id], + offset + tbgOffset)) / sizeof(int4); void* remoteMemory = static_cast(memoryChannelBufferPtrs_[op.outputBufferRefs[index + 1].id]); mscclpp::write(remoteMemory, dstOffset + idx, tmp); @@ -227,15 +257,16 @@ MSCCLPP_DEVICE_INLINE void handleReadReduceSend(const Operation& op, void* input // handle rest of data uint32_t processed = nInt4 * sizeof(int4); const uint32_t startIdx = - (op.inputOffsets[0] + getOffset(op.inputBufferRefs[0].type, offset) + processed) / sizeof(T); + (op.inputOffsets[0] + getOffset(op.inputBufferRefs[0].type, offset + tbgOffset) + processed) / + sizeof(T); const uint32_t endIdx = - (op.inputOffsets[0] + getOffset(op.inputBufferRefs[0].type, offset) + size) / sizeof(T); + (op.inputOffsets[0] + getOffset(op.inputBufferRefs[0].type, offset + tbgOffset) + size) / sizeof(T); for (uint32_t idx = threadIdx.x + startIdx; idx < endIdx; idx += blockDim.x) { T tmp = static_cast(input)[idx]; for (int index = 0; index < nRemoteInputs; ++index) { uint32_t srcOffset = (srcOffsets[index] + - getOffset(memoryChannelBufferTypes_[op.inputBufferRefs[index + 1].id], offset)) / + getOffset(memoryChannelBufferTypes_[op.inputBufferRefs[index + 1].id], offset + tbgOffset)) / sizeof(T); void* remoteMemory = static_cast(memoryChannelBufferPtrs_[op.inputBufferRefs[index + 1].id]); tmp = tmp + mscclpp::read(remoteMemory, srcOffset + idx); @@ -244,8 +275,8 @@ MSCCLPP_DEVICE_INLINE void handleReadReduceSend(const Operation& op, void* input if constexpr (SendToRemote) { for (int index = 0; index < nRemoteOutputs; ++index) { uint32_t dstOffset = - (dstOffsets[index] + - getOffset(memoryChannelBufferTypes_[op.outputBufferRefs[index + 1].id], offset)) / + (dstOffsets[index] + getOffset(memoryChannelBufferTypes_[op.outputBufferRefs[index + 1].id], + offset + tbgOffset)) / sizeof(T); void* remoteMemory = static_cast(memoryChannelBufferPtrs_[op.outputBufferRefs[index + 1].id]); mscclpp::write(remoteMemory, dstOffset + idx, tmp); @@ -258,6 +289,8 @@ template MSCCLPP_DEVICE_INLINE void handlePutPackets(const Operation& op, void* input, void* output, void* scratch) { ChannelType chType = op.channelType; uint16_t nDstChannels = op.nOutputs; + const uint8_t tbId = op.tbId; + const uint8_t tbgSize = op.tbgSize; const uint32_t* dstOffsets = op.outputOffsets; const uint32_t* srcOffsets = op.inputOffsets; const uint32_t* sizes = op.inputBufferSizes; @@ -266,9 +299,11 @@ MSCCLPP_DEVICE_INLINE void handlePutPackets(const Operation& op, void* input, vo if (chType == ChannelType::MEMORY) { for (int index = 0; index < nDstChannels; ++index) { uint32_t size = sizes[index]; + uint32_t tbgOff = calcOffset(size, tbId, tbgSize); + size = calcSize(size, tbId, tbgSize); mscclpp::copyToPackets( - (char*)memoryChannelBufferPtrs_[op.outputBufferRefs[index].id] + (dstOffsets[index] << 1) + scratchOffset_, - (char*)inputBuff + srcOffsets[index], size, threadIdx.x, blockDim.x, flag_); + (char*)memoryChannelBufferPtrs_[op.outputBufferRefs[index].id] + ((dstOffsets[index] + tbgOff) << 1) + scratchOffset_, + (char*)inputBuff + srcOffsets[index] + tbgOff, size, threadIdx.x, blockDim.x, flag_); } } if (chType == ChannelType::PORT) { @@ -291,19 +326,23 @@ MSCCLPP_DEVICE_INLINE void handlePutPackets(const Operation& op, void* input, vo template MSCCLPP_DEVICE_INLINE void handleReadPutPackets(const Operation& op, void* scratch) { uint32_t nOutput = op.nOutputs; + const uint8_t tbId = op.tbId; + const uint8_t tbgSize = op.tbgSize; const uint32_t* dstOffsets = op.outputOffsets; const uint32_t* srcOffsets = op.inputOffsets; const uint8_t* channelIndexes = op.channelIndexes; uint32_t size = op.inputBufferSizes[0]; + uint32_t tbgOff = calcOffset(size, tbId, tbgSize); + size = calcSize(size, tbId, tbgSize); ChannelType chType = op.channelType; if (chType == ChannelType::MEMORY) { size_t nPackets = size / sizeof(PacketPayload); - PacketType* pkts = (PacketType*)((char*)scratch + scratchOffset_ + (srcOffsets[0] << 1)); + PacketType* pkts = (PacketType*)((char*)scratch + scratchOffset_ + ((srcOffsets[0] + tbgOff) << 1)); for (size_t pktIdx = threadIdx.x; pktIdx < nPackets; pktIdx += blockDim.x) { PacketPayload data = pkts[pktIdx].read(flag_); PacketType pkt(data, flag_); for (uint32_t idx = 0; idx < nOutput; ++idx) { - size_t offset = (scratchOffset_ + (dstOffsets[idx] << 1)) / sizeof(PacketType); + size_t offset = (scratchOffset_ + ((dstOffsets[idx] + tbgOff) << 1)) / sizeof(PacketType); void* remoteMemory = static_cast(memoryChannelBufferPtrs_[op.outputBufferRefs[idx].id]); mscclpp::write(remoteMemory, offset + pktIdx, pkt); } @@ -333,10 +372,14 @@ MSCCLPP_DEVICE_INLINE void handleReadPutPackets(const Operation& op, void* scrat template MSCCLPP_DEVICE_INLINE void handleReduceSendPackets(const Operation& op, void* input, void* output, void* scratch) { uint32_t size = op.inputBufferSizes[0]; + const uint8_t tbId = op.tbId; + const uint8_t tbgSize = op.tbgSize; + uint32_t tbgOff = calcOffset(size, tbId, tbgSize); + size = calcSize(size, tbId, tbgSize); const uint32_t nSrcs = op.nInputs - 1; const uint32_t nDstChannels = op.nOutputs - 1; - const uint32_t srcOffsetByBytes = op.inputOffsets[0]; - const uint32_t dstOffsetByBytes = op.outputOffsets[0]; + const uint32_t srcOffsetByBytes = op.inputOffsets[0] + tbgOff; + const uint32_t dstOffsetByBytes = op.outputOffsets[0] + tbgOff; const uint32_t* inputOffsets = op.inputOffsets + 1; const uint32_t* outputOffsets = op.outputOffsets + 1; const BufferRef* outputBufferRefs = op.outputBufferRefs + 1; @@ -351,7 +394,7 @@ MSCCLPP_DEVICE_INLINE void handleReduceSendPackets(const Operation& op, void* in for (uint32_t idx = threadIdx.x; idx < nPackets; idx += blockDim.x) { PacketPayload data = {}; for (uint32_t index = 0; index < nSrcs; ++index) { - PacketType* pkt = (PacketType*)((char*)scratch + scratchOffset_ + 2 * inputOffsets[index]); + PacketType* pkt = (PacketType*)((char*)scratch + scratchOffset_ + 2 * (inputOffsets[index] + tbgOff)); PacketPayload val = pkt[idx].read(flag_); data = cal_vector(data, val); } @@ -361,7 +404,7 @@ MSCCLPP_DEVICE_INLINE void handleReduceSendPackets(const Operation& op, void* in if constexpr (SendToRemote) { PacketType pkt(data, flag_); for (uint32_t index = 0; index < nDstChannels; ++index) { - uint32_t offset = (scratchOffset_ + outputOffsets[index] * 2) / sizeof(PacketType); + uint32_t offset = (scratchOffset_ + (outputOffsets[index] + tbgOff) * 2) / sizeof(PacketType); void* remoteMemory = static_cast(memoryChannelBufferPtrs_[outputBufferRefs[index].id]); mscclpp::write(remoteMemory, offset + idx, pkt); } @@ -372,16 +415,20 @@ MSCCLPP_DEVICE_INLINE void handleReduceSendPackets(const Operation& op, void* in template MSCCLPP_DEVICE_INLINE void handleReduceCopySendPackets(const Operation& op, void* input, void* output, void* scratch) { uint32_t size = op.inputBufferSizes[0]; + const uint8_t tbId = op.tbId; + const uint8_t tbgSize = op.tbgSize; + const uint32_t tbgOff = calcOffset(size, tbId, tbgSize); const uint32_t nSrcs = op.nInputs - 1; const uint32_t nDstChannels = op.nOutputs - 2; - const uint32_t srcOffsetByBytes = op.inputOffsets[0]; - const uint32_t dstOffsetByBytes = op.outputOffsets[0]; + const uint32_t srcOffsetByBytes = op.inputOffsets[0] + tbgOff; + const uint32_t dstOffsetByBytes = op.outputOffsets[0] + tbgOff; const uint32_t* inputOffsets = op.inputOffsets + 1; const uint32_t* outputOffsets = op.outputOffsets + 2; const BufferRef* outputBufferRefs = op.outputBufferRefs + 2; - PacketType* dstPkt = - (PacketType*)((char*)getBuffer(input, output, scratch, op.outputBufferRefs[1].type) + 2 * op.outputOffsets[1]); + size = calcSize(size, tbId, tbgSize); + PacketType* dstPkt = (PacketType*)((char*)getBuffer(input, output, scratch, op.outputBufferRefs[1].type) + + 2 * (op.outputOffsets[1] + tbgOff)); uint32_t nPackets = size / sizeof(PacketPayload); const uint32_t srcOffset = srcOffsetByBytes / sizeof(PacketPayload); const uint32_t dstOffset = dstOffsetByBytes / sizeof(PacketPayload); @@ -392,7 +439,7 @@ MSCCLPP_DEVICE_INLINE void handleReduceCopySendPackets(const Operation& op, void for (uint32_t idx = threadIdx.x; idx < nPackets; idx += blockDim.x) { PacketPayload data = {}; for (uint32_t index = 0; index < nSrcs; ++index) { - PacketType* pkt = (PacketType*)((char*)scratch + scratchOffset_ + 2 * inputOffsets[index]); + PacketType* pkt = (PacketType*)((char*)scratch + scratchOffset_ + 2 * (inputOffsets[index] + tbgOff)); PacketPayload val = pkt[idx].read(flag_); data = cal_vector(data, val); } @@ -404,7 +451,7 @@ MSCCLPP_DEVICE_INLINE void handleReduceCopySendPackets(const Operation& op, void if constexpr (SendToRemote) { PacketType pkt(data, flag_); for (uint32_t index = 0; index < nDstChannels; ++index) { - uint32_t offset = (scratchOffset_ + outputOffsets[index] * 2) / sizeof(PacketType); + uint32_t offset = (scratchOffset_ + (outputOffsets[index] + tbgOff) * 2) / sizeof(PacketType); void* remoteMemory = static_cast(memoryChannelBufferPtrs_[outputBufferRefs[index].id]); mscclpp::write(remoteMemory, offset + idx, pkt); } @@ -414,9 +461,14 @@ MSCCLPP_DEVICE_INLINE void handleReduceCopySendPackets(const Operation& op, void template MSCCLPP_DEVICE_INLINE void handleUnpackPackets(const Operation& op, void* input, void* output, void* scratch) { - const uint32_t size = op.inputBufferSizes[0]; - const uint32_t dstOffset = op.outputOffsets[0]; - const uint32_t srcOffset = op.inputOffsets[0]; + uint32_t size = op.inputBufferSizes[0]; + const uint8_t tbId = op.tbId; + const uint8_t tbgSize = op.tbgSize; + const uint32_t tbgOff = calcOffset(size, tbId, tbgSize); + const uint32_t dstOffset = op.outputOffsets[0] + tbgOff; + const uint32_t srcOffset = op.inputOffsets[0] + tbgOff; + + size = calcSize(size, tbId, tbgSize); PacketType* srcPackets = (PacketType*)(static_cast(scratch) + scratchOffset_ + (srcOffset << 1)); PacketPayload* result = (PacketPayload*)(static_cast(getBuffer(input, output, scratch, op.outputBufferRefs[0].type)) + @@ -431,8 +483,12 @@ MSCCLPP_DEVICE_INLINE void handleUnpackPackets(const Operation& op, void* input, template MSCCLPP_DEVICE_INLINE void handleCopyPackets(const Operation& op, void* input, void* output, void* scratch) { uint32_t size = op.inputBufferSizes[0]; - uint32_t dstOffset = op.outputOffsets[0]; - uint32_t srcOffset = op.inputOffsets[0]; + const uint8_t tbId = op.tbId; + const uint8_t tbgSize = op.tbgSize; + const uint32_t tbgOff = calcOffset(size, tbId, tbgSize); + uint32_t dstOffset = op.outputOffsets[0] + tbgOff; + uint32_t srcOffset = op.inputOffsets[0] + tbgOff; + size = calcSize(size, tbId, tbgSize); dstOffset = dstOffset << 1; char* dst = static_cast(getBuffer(input, output, scratch, op.outputBufferRefs[0].type)) + dstOffset; char* src = static_cast(getBuffer(input, output, scratch, op.inputBufferRefs[0].type)) + srcOffset; @@ -442,7 +498,11 @@ MSCCLPP_DEVICE_INLINE void handleCopyPackets(const Operation& op, void* input, v template MSCCLPP_DEVICE_INLINE void handleReduceSend(const Operation& op, void* input, void* output, void* scratch, uint32_t offset, uint32_t unitSize) { - const uint32_t size = min(op.inputBufferSizes[0] - offset, unitSize); + uint32_t size = min(op.inputBufferSizes[0] - offset, unitSize); + const uint8_t tbId = op.tbId; + const uint8_t tbgSize = op.tbgSize; + const uint32_t tbgOffset = calcOffset(size, tbId, tbgSize); + size = calcSize(size, tbId, tbgSize); const uint32_t nInt4 = size / sizeof(int4); int nInput = op.nInputs - 1; int nOutput = op.nOutputs - 1; @@ -450,8 +510,10 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(const Operation& op, void* input, vo const uint32_t* outputOffsets = op.outputOffsets + 1; const BufferRef* inputBufferRefs = op.inputBufferRefs + 1; const BufferRef* outputBufferRefs = op.outputBufferRefs + 1; - uint32_t srcOffsetByBytes = op.inputOffsets[0] + getOffset(op.inputBufferRefs[0].type, offset); - uint32_t dstOffsetByBytes = op.outputOffsets[0] + getOffset(op.outputBufferRefs[0].type, offset); + uint32_t srcOffsetByBytes = + op.inputOffsets[0] + getOffset(op.inputBufferRefs[0].type, offset + tbgOffset); + uint32_t dstOffsetByBytes = + op.outputOffsets[0] + getOffset(op.outputBufferRefs[0].type, offset + tbgOffset); const uint32_t srcOffset4 = srcOffsetByBytes / sizeof(int4); const uint32_t dstOffset4 = dstOffsetByBytes / sizeof(int4); @@ -462,16 +524,18 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(const Operation& op, void* input, vo for (int index = 0; index < nInput; ++index) { int4* buff4 = static_cast(getBuffer(input, output, scratch, inputBufferRefs[index].type)); size_t buffOffset = - (inputOffsets[index] + getOffset(outputBufferRefs[index].type, offset)) / sizeof(int4); + (inputOffsets[index] + getOffset(outputBufferRefs[index].type, offset + tbgOffset)) / + sizeof(int4); int4 val = buff4[buffOffset + idx]; tmp = cal_vector(tmp, val); } dst4[dstOffset4 + idx] = tmp; if constexpr (SendToRemote) { for (int index = 0; index < nOutput; ++index) { - size_t outOffset = (outputOffsets[index] + - getOffset(memoryChannelBufferTypes_[outputBufferRefs[index].id], offset)) / - sizeof(int4); + size_t outOffset = + (outputOffsets[index] + + getOffset(memoryChannelBufferTypes_[outputBufferRefs[index].id], offset + tbgOffset)) / + sizeof(int4); void* remoteMemory = memoryChannelBufferPtrs_[outputBufferRefs[index].id]; mscclpp::write(remoteMemory, outOffset + idx, tmp); } @@ -488,15 +552,16 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(const Operation& op, void* input, vo for (int index = 0; index < nInput; ++index) { T* buff = static_cast(getBuffer(input, output, scratch, inputBufferRefs[index].type)); uint32_t buffOffset = - (inputOffsets[index] + getOffset(inputBufferRefs[index].type, offset)) / sizeof(T); + (inputOffsets[index] + getOffset(inputBufferRefs[index].type, offset + tbgOffset)) / sizeof(T); tmp = tmp + buff[buffOffset + idx]; } dst[idx] = tmp; if constexpr (SendToRemote) { for (int index = 0; index < nOutput; ++index) { - uint32_t outOffset = (outputOffsets[index] + - getOffset(memoryChannelBufferTypes_[outputBufferRefs[index].id], offset)) / - sizeof(T); + uint32_t outOffset = + (outputOffsets[index] + + getOffset(memoryChannelBufferTypes_[outputBufferRefs[index].id], offset + tbgOffset)) / + sizeof(T); void* remoteMemory = memoryChannelBufferPtrs_[outputBufferRefs[index].id]; mscclpp::write(remoteMemory, outOffset + idx, tmp); } @@ -508,11 +573,15 @@ template MSCCLPP_DEVICE_INLINE void handleCopy(const Operation& op, void* input, void* output, void* scratch, uint32_t offset, uint32_t unitSize) { uint32_t size = min(op.inputBufferSizes[0] - offset, unitSize); + const uint8_t tbId = op.tbId; + const uint8_t tbgSize = op.tbgSize; + const uint32_t tbgOffset = calcOffset(size, tbId, tbgSize); + size = calcSize(size, tbId, tbgSize); if (size <= 0) { return; } - uint32_t dstOffset = op.outputOffsets[0] + getOffset(op.outputBufferRefs[0].type, offset); - uint32_t srcOffset = op.inputOffsets[0] + getOffset(op.inputBufferRefs[0].type, offset); + uint32_t dstOffset = op.outputOffsets[0] + getOffset(op.outputBufferRefs[0].type, offset + tbgOffset); + uint32_t srcOffset = op.inputOffsets[0] + getOffset(op.inputBufferRefs[0].type, offset + tbgOffset); char* srcData = static_cast(getBuffer(input, output, scratch, op.inputBufferRefs[0].type)) + srcOffset; char* dstData = static_cast(getBuffer(input, output, scratch, op.outputBufferRefs[0].type)) + dstOffset; mscclpp::copy(dstData, srcData, size, threadIdx.x, blockDim.x); @@ -526,12 +595,17 @@ MSCCLPP_DEVICE_INLINE void handleMultiLoadReduceStore(const Operation& op, uint3 return; } else { static_assert(sizeof(T) <= 8, "Only support type with size <= 8 bytes"); - const uint32_t size = min(op.inputBufferSizes[0] - offset, unitSize); + uint32_t size = min(op.inputBufferSizes[0] - offset, unitSize); + const uint8_t tbId = op.tbId; + const uint8_t tbgSize = op.tbgSize; + const uint32_t tbgOffset = calcOffset(size, tbId, tbgSize); + size = calcSize(size, tbId, tbgSize); if (size <= 0) { return; } - const uint32_t srcOffset = op.inputOffsets[0] + getOffset(op.nvlsInputBufferType, offset); - const uint32_t dstOffset = op.outputOffsets[0] + getOffset(op.nvlsOutputBufferType, offset); + const uint32_t srcOffset = op.inputOffsets[0] + getOffset(op.nvlsInputBufferType, offset + tbgOffset); + const uint32_t dstOffset = + op.outputOffsets[0] + getOffset(op.nvlsOutputBufferType, offset + tbgOffset); assert(size % sizeof(T) == 0); assert(srcOffset % sizeof(T) == 0); assert(dstOffset % sizeof(T) == 0);