diff --git a/.azure-pipelines/templates/ut-no-ib-env.yaml b/.azure-pipelines/templates/ut-no-ib-env.yaml index e6576f6d..0d97f9fc 100644 --- a/.azure-pipelines/templates/ut-no-ib-env.yaml +++ b/.azure-pipelines/templates/ut-no-ib-env.yaml @@ -16,7 +16,7 @@ steps: targetType: 'inline' script: | mkdir build && cd build - cmake -DCMAKE_BUILD_TYPE=Release -DMSCCLPP_BYPASS_GPU_CHECK=ON -DMSCCLPP_USE_CUDA=ON -DMSCCLPP_BUILD_TESTS=ON -DMSCCLPP_GPU_ARCHS=${{ parameters.gpuArch }} .. + cmake -DCMAKE_BUILD_TYPE=Release -DMSCCLPP_BYPASS_GPU_CHECK=ON -DMSCCLPP_USE_CUDA=ON -DMSCCLPP_BUILD_TESTS=ON -DMSCCLPP_USE_IB=OFF -DMSCCLPP_GPU_ARCHS=${{ parameters.gpuArch }} .. make -j workingDirectory: '$(System.DefaultWorkingDirectory)' @@ -55,6 +55,51 @@ steps: arguments: single-node-test false workingDirectory: $(System.DefaultWorkingDirectory) +- task: Bash@3 + name: UnitTests + displayName: Run mscclpp unit tests + inputs: + targetType: inline + script: | + set -e + HOSTFILE=$(System.DefaultWorkingDirectory)/test/deploy/hostfile_ci + SSH_OPTION="StrictHostKeyChecking=no" + KeyFilePath=${SSHKEYFILE_SECUREFILEPATH} + : > azureuser@10.0.0.4 + tail -f azureuser@10.0.0.4 & + CHILD_PID=$! + parallel-ssh -t 0 -h ${HOSTFILE} -x "-i ${KeyFilePath}" -o . \ + -O $SSH_OPTION 'sudo docker exec -t mscclpp-test bash -c " \ + cd /root/mscclpp; \ + export LD_LIBRARY_PATH=/root/mscclpp/build/lib:\$LD_LIBRARY_PATH; \ + ./build/bin/unit_tests"' + kill $CHILD_PID + workingDirectory: '$(System.DefaultWorkingDirectory)' + +- task: Bash@3 + name: MpUnitTests + displayName: Run mscclpp multi-process unit tests + inputs: + targetType: 'inline' + script: | + set -e + HOSTFILE=$(System.DefaultWorkingDirectory)/test/deploy/hostfile_ci + SSH_OPTION="StrictHostKeyChecking=no" + KeyFilePath=${SSHKEYFILE_SECUREFILEPATH} + : > azureuser@10.0.0.4 + tail -f azureuser@10.0.0.4 & + CHILD_PID=$! + parallel-ssh -t 0 -h ${HOSTFILE} -x "-i ${KeyFilePath}" -o . 
\ + -O $SSH_OPTION 'sudo docker exec -t mscclpp-test bash -c " \ + export PATH=/usr/local/mpi/bin:\$PATH; \ + cd /root/mscclpp; \ + export LD_LIBRARY_PATH=/root/mscclpp/build/lib:\$LD_LIBRARY_PATH; \ + mpirun --allow-run-as-root -tag-output -np 2 ./build/bin/mp_unit_tests; \ + mpirun --allow-run-as-root -tag-output -np 4 ./build/bin/mp_unit_tests; \ + mpirun --allow-run-as-root -tag-output -np 8 ./build/bin/mp_unit_tests"' + kill $CHILD_PID + workingDirectory: '$(System.DefaultWorkingDirectory)' + - task: Bash@3 name: PyTests displayName: Run pytests @@ -73,7 +118,64 @@ steps: export PATH=/usr/local/mpi/bin:\$PATH \ export LD_LIBRARY_PATH=/root/mscclpp/build/lib:\$LD_LIBRARY_PATH; \ cd /root/mscclpp; \ - mpirun --allow-run-as-root -tag-output -x MSCCLPP_HOME=/root/mscclpp -np 8 python3 -m pytest ./python/test/test_mscclpp.py::test_executor -x"' + mpirun --allow-run-as-root -tag-output -x MSCCLPP_HOME=/root/mscclpp -x MSCCLPP_DISABLE_IB_TESTS=1 -np 8 python3 -m pytest ./python/test/test_mscclpp.py -x"' + kill $CHILD_PID + workingDirectory: '$(System.DefaultWorkingDirectory)' + +- task: Bash@3 + name: StopContainer + displayName: Stop existing container + inputs: + targetType: 'inline' + script: | + set -e + HOSTFILE=$(System.DefaultWorkingDirectory)/test/deploy/hostfile_ci + SSH_OPTION="StrictHostKeyChecking=no" + KeyFilePath=${SSHKEYFILE_SECUREFILEPATH} + parallel-ssh -i -t 0 -h ${HOSTFILE} -x "-i ${KeyFilePath}" -O $SSH_OPTION \ + "sudo docker stop mscclpp-test || true; sudo docker rm mscclpp-test || true" + rm -f $(System.DefaultWorkingDirectory)/sshkey $(System.DefaultWorkingDirectory)/sshkey.pub + workingDirectory: '$(System.DefaultWorkingDirectory)' + +- task: Bash@3 + name: BuildWithIb + displayName: Rebuild with IB + inputs: + targetType: 'inline' + script: | + rm -rf build && mkdir build && cd build + cmake -DCMAKE_BUILD_TYPE=Release -DMSCCLPP_BYPASS_GPU_CHECK=ON -DMSCCLPP_USE_CUDA=ON -DMSCCLPP_BUILD_TESTS=ON -DMSCCLPP_GPU_ARCHS=${{ parameters.gpuArch }} .. + make -j + workingDirectory: '$(System.DefaultWorkingDirectory)' + +- task: Bash@3 + name: DeployTestEnvWithIb + displayName: Deploy Test Env (with IB build) + inputs: + targetType: filePath + filePath: test/deploy/deploy.sh + arguments: single-node-test false + workingDirectory: $(System.DefaultWorkingDirectory) + +- task: Bash@3 + name: PyTestsWithIbBuildDisableIb + displayName: Run pytests (IB build, IB tests disabled) + inputs: + targetType: inline + script: | + set -e + HOSTFILE=$(System.DefaultWorkingDirectory)/test/deploy/hostfile_ci + SSH_OPTION="StrictHostKeyChecking=no" + KeyFilePath=${SSHKEYFILE_SECUREFILEPATH} + : > azureuser@10.0.0.4 + tail -f azureuser@10.0.0.4 & + CHILD_PID=$! + parallel-ssh -t 0 -h ${HOSTFILE} -x "-i ${KeyFilePath}" -o . 
\ + -O $SSH_OPTION 'sudo docker exec -t mscclpp-test bash -c " \ + export PATH=/usr/local/mpi/bin:\$PATH \ + export LD_LIBRARY_PATH=/root/mscclpp/build/lib:\$LD_LIBRARY_PATH; \ + cd /root/mscclpp; \ + mpirun --allow-run-as-root -tag-output -x MSCCLPP_HOME=/root/mscclpp -x MSCCLPP_DISABLE_IB_TESTS=1 -np 8 python3 -m pytest ./python/test/test_mscclpp.py -x"' kill $CHILD_PID workingDirectory: '$(System.DefaultWorkingDirectory)' diff --git a/.azure-pipelines/ut.yml b/.azure-pipelines/ut.yml index 960f3eae..4aac07e6 100644 --- a/.azure-pipelines/ut.yml +++ b/.azure-pipelines/ut.yml @@ -113,7 +113,7 @@ jobs: gpuArch: '90' - job: UnitTestNoIBEnv - timeoutInMinutes: 40 + timeoutInMinutes: 60 displayName: Test No IB Environment pool: name: msccl-ci-h100 diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 6982c69c..575c472b 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -51,7 +51,7 @@ jobs: df -h - name: Initialize CodeQL - uses: github/codeql-action/init@v3 + uses: github/codeql-action/init@v4 with: languages: ${{ matrix.language }} @@ -63,10 +63,10 @@ jobs: run: | rm -rf build && mkdir build && cd build cmake -DMSCCLPP_BYPASS_GPU_CHECK=ON -DMSCCLPP_USE_CUDA=ON -DMSCCLPP_BUILD_TESTS=OFF .. - make -j + make -j4 - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v3 + uses: github/codeql-action/analyze@v4 with: category: "/language:${{matrix.language}}/version:${{matrix.version}}" @@ -96,7 +96,7 @@ jobs: df -h - name: Initialize CodeQL - uses: github/codeql-action/init@v3 + uses: github/codeql-action/init@v4 with: languages: ${{ matrix.language }} @@ -108,9 +108,9 @@ jobs: run: | rm -rf build && mkdir build && cd build CXX=/opt/rocm/bin/hipcc cmake -DMSCCLPP_BYPASS_GPU_CHECK=ON -DMSCCLPP_USE_ROCM=ON -DMSCCLPP_BUILD_TESTS=OFF .. - make -j + make -j4 - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v3 + uses: github/codeql-action/analyze@v4 with: category: "/language:${{matrix.language}}/version:${{matrix.version}}" diff --git a/examples/torch-integration/customized_comm_with_tuning.py b/examples/torch-integration/customized_comm_with_tuning.py new file mode 100644 index 00000000..b618df5c --- /dev/null +++ b/examples/torch-integration/customized_comm_with_tuning.py @@ -0,0 +1,282 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
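+#
+# This example builds a customized torch communicator on top of MSCCLPP,
+# auto-tunes the bundled allreduce algorithms (NVLS-packet, packet, and
+# RSAG zero-copy) across message sizes at startup, and then benchmarks
+# allreduce under CUDA graph replay. Launch with torchrun, e.g.: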
+ +# MSCCLPP_MASTER_ADDR= MSCCLPP_MASTER_PORT= torchrun --nnodes=1 --nproc_per_node=8 customized_comm_with_tuning.py + +import os +import torch +import mscclpp.utils as mscclpp_utils +import mscclpp +import mscclpp.ext +import netifaces as ni +import ipaddress + + +def load_algorithms(scratch_buffer: torch.tensor, rank: int) -> mscclpp.AlgorithmCollection: + collection_builder = mscclpp.ext.AlgorithmCollectionBuilder() + return collection_builder.build_default_algorithms( + scratch_buffer=scratch_buffer.data_ptr(), scratch_buffer_size=scratch_buffer.nbytes, rank=rank + ) + + +def interfaces_for_ip_netifaces(ip: str): + target = ipaddress.ip_address(ip) + for interface in ni.interfaces(): + addresses = ni.ifaddresses(interface) + if ni.AF_INET in addresses: + for link in addresses[ni.AF_INET]: + if "addr" in link: + addr = ipaddress.ip_address(link["addr"]) + if addr == target: + return interface + return None + + +def to_mscclpp_reduce_op(op: torch.distributed.ReduceOp) -> mscclpp.ReduceOp: + if op == torch.distributed.ReduceOp.SUM: + return mscclpp.ReduceOp.SUM + elif op == torch.distributed.ReduceOp.MIN: + return mscclpp.ReduceOp.MIN + else: + raise ValueError(f"unsupported op: {op}") + + +class CustomizedComm: + def __init__(self, comm: mscclpp.CommGroup): + self.comm = comm + self.rank = comm.my_rank + self.world_size = comm.nranks + self.local_rank = comm.my_rank % comm.nranks_per_node + self.n_ranks_per_node = comm.nranks_per_node + dlpack = mscclpp.RawGpuBuffer(1 << 27).to_dlpack(data_type=str(torch.float16)) + self.scratch_buffer = torch.utils.dlpack.from_dlpack(dlpack) + algorithms = load_algorithms(scratch_buffer=self.scratch_buffer, rank=self.rank) + self._algorithm_nvls_packet = [ + algo + for algo in algorithms + if algo.collective == "allreduce" and algo.name == "default_allreduce_nvls_packet" + ][0] + self._algorithm_rsag_zero_copy = [ + algo + for algo in algorithms + if algo.collective == "allreduce" and algo.name == "default_allreduce_rsag_zero_copy" + ][0] + self._algorithm_packet = [ + algo for algo in algorithms if algo.collective == "allreduce" and algo.name == "default_allreduce_packet" + ][0] + self._tune(n_warmup=5, n_graph_launches=10, n_ops_per_graph=100) + + def _tune(self, n_warmup, n_graph_launches, n_ops_per_graph): + sizes = [1 << i for i in range(10, 28)] + # Pre-fill with defaults for barrier + self.best_configs = {1024: (self._algorithm_nvls_packet, 0, 0)} + + tune_tensor = torch.rand(1 << 27, dtype=torch.float16, device="cuda") + candidates_nblocks = [4, 8, 16, 24, 32, 48, 64, 128] + candidates_nthreads = [512, 768, 1024] + + for size in sizes: + algos = [] + if size <= 4 * 1024 * 1024: + algos.append(self._algorithm_nvls_packet) + algos.append(self._algorithm_packet) + if size >= 512 * 1024: + algos.append(self._algorithm_rsag_zero_copy) + + best_time = float("inf") + best_config = None + + for algo in algos: + for nb in candidates_nblocks: + if algo.name == "default_allreduce_nvls_packet" and nb > 16: + continue + if algo.name == "default_allreduce_packet" and nb > 56: + continue + for nt in candidates_nthreads: + if self._run_algo(algo, tune_tensor, size, nb, nt) != 0: + continue + + for _ in range(n_warmup): + self._run_algo(algo, tune_tensor, size, nb, nt) + self.barrier() + + capture_stream = torch.cuda.Stream() + capture_stream.wait_stream(torch.cuda.current_stream()) + + g = torch.cuda.CUDAGraph() + # Warmup on capture stream + with torch.cuda.stream(capture_stream): + self._run_algo(algo, tune_tensor, size, nb, nt) + capture_stream.synchronize() 
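+                        # Capturing on a dedicated side stream (required for CUDA
+                        # graph capture) after a synchronized warmup launch keeps
+                        # one-time lazy initialization out of the captured graph.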
+
+                        with torch.cuda.graph(g, stream=capture_stream):
+                            for _ in range(n_ops_per_graph):
+                                self._run_algo(algo, tune_tensor, size, nb, nt)
+
+                        start_event = torch.cuda.Event(enable_timing=True)
+                        end_event = torch.cuda.Event(enable_timing=True)
+                        start_event.record(capture_stream)
+                        with torch.cuda.stream(capture_stream):
+                            for _ in range(n_graph_launches):
+                                g.replay()
+                        end_event.record(capture_stream)
+                        end_event.synchronize()
+
+                        elapsed = start_event.elapsed_time(end_event)
+
+                        # Synchronize timing results across all ranks to ensure consistent algorithm selection.
+                        # The value is replicated world_size times due to algorithm size limitations.
+                        time_tensor = torch.full((self.world_size,), elapsed, dtype=torch.float64, device="cuda").to(
+                            dtype=torch.float32
+                        )
+                        torch.cuda.current_stream().wait_stream(capture_stream)
+                        # TODO: using all_reduce may cause problems if the elapsed times of different algos are
+                        # too close. May change to broadcast in the future if that becomes an issue.
+                        self.all_reduce(time_tensor, op=torch.distributed.ReduceOp.SUM)
+                        avg_time = time_tensor[self.rank].item() / self.world_size
+
+                        if avg_time < best_time:
+                            best_time = avg_time
+                            best_config = (algo, nb, nt)
+
+            if best_config:
+                self.best_configs[size] = best_config
+                if self.rank == 0:
+                    print(
+                        f"Size {size}: Best Algo {best_config[0].name} nblocks {best_config[1]} nthreads {best_config[2]} Time {(best_time/(n_graph_launches * n_ops_per_graph))*1000:.2f} us"
+                    )
+        # reset the algorithms after tuning
+        torch.cuda.synchronize()
+        for algo in algos:
+            algo.reset()
+
+    def _run_algo(self, algo, tensor, size, nblocks, nthreads):
+        return algo.execute(
+            comm=self.comm.communicator,
+            input_buffer=tensor.data_ptr(),
+            output_buffer=tensor.data_ptr(),
+            input_size=size,
+            output_size=size,
+            dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(tensor.dtype),
+            op=mscclpp.ReduceOp.SUM,
+            stream=torch.cuda.current_stream().cuda_stream,
+            nblocks=nblocks,
+            nthreads_per_block=nthreads,
+        )
+
+    def get_tuned_config(self, size):
+        if size < 1024:
+            target_size = 1024
+        elif size > 256 * 1024 * 1024:
+            target_size = 256 * 1024 * 1024
+        else:
+            target_size = 1 << (size - 1).bit_length()
+        return self.best_configs.get(target_size)
+
+    def all_reduce(self, tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM, stream: torch.cuda.Stream = None):
+        assert op == torch.distributed.ReduceOp.SUM
+        config = self.get_tuned_config(tensor.nbytes)
+        algo, nblocks, nthreads = config if config else (self._algorithm_nvls_packet, 0, 0)
+        ret = algo.execute(
+            comm=self.comm.communicator,
+            input_buffer=tensor.data_ptr(),
+            output_buffer=tensor.data_ptr(),
+            input_size=tensor.nbytes,
+            output_size=tensor.nbytes,
+            dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(tensor.dtype),
+            op=to_mscclpp_reduce_op(op),
+            stream=stream.cuda_stream if stream is not None else torch.cuda.current_stream().cuda_stream,
+            nblocks=nblocks,
+            nthreads_per_block=nthreads,
+        )
+        if ret != 0:
+            print(f"Rank {self.rank}: Algo {algo.name} failed with error {ret}")
+
+    def barrier(self):
+        tensor = torch.empty(self.world_size, dtype=torch.float, device=torch.device("cuda"))
+        self.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM, stream=torch.cuda.current_stream())
+
+    def benchmark(self, n_warmup=10, n_graph_launches=10, n_iter_per_graph=100):
+        low = 5 * 1024
+        high = 80 * 1024 * 1024
+        sizes = []
+        curr = low
+        while curr <= high:
+            sizes.append(curr)
+            curr *= 2
+
+        if self.rank == 0:
+            print(f"{'Size (Bytes)':<20} {'Time (us)':<20} {'AlgoBW (GB/s)':<20}")
+
+        dtype = torch.float16
+        capture_stream = torch.cuda.Stream()
+
+        for size in sizes:
+            tensor = torch.rand(size // 2, dtype=dtype, device="cuda")
+            capture_stream.wait_stream(torch.cuda.current_stream())
+            # Capture Graph
+            g = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(g, stream=capture_stream):
+                for _ in range(n_iter_per_graph):
+                    self.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
+
+            # warmup: Execute the graph once to prime the driver
+            with torch.cuda.stream(capture_stream):
+                for _ in range(n_warmup):
+                    g.replay()
+                self.barrier()
+            capture_stream.synchronize()
+
+            # Benchmark
+            start_event = torch.cuda.Event(enable_timing=True)
+            end_event = torch.cuda.Event(enable_timing=True)
+
+            start_event.record(capture_stream)
+            with torch.cuda.stream(capture_stream):
+                for _ in range(n_graph_launches):
+                    g.replay()
+            end_event.record(capture_stream)
+            end_event.synchronize()
+
+            # Get elapsed time in milliseconds
+            elapsed_ms = start_event.elapsed_time(end_event)
+            avg_time_ms = elapsed_ms / (n_graph_launches * n_iter_per_graph)
+            time_us = avg_time_ms * 1000
+
+            alg_bw = size / (avg_time_ms * 1e-3) if avg_time_ms > 0 else 0
+            if self.rank == 0:
+                print(f"{size:<20} {time_us:<20.2f} {alg_bw / 1e9:<20.2f}")
+
+    def destroy(self):
+        self._algorithm_nvls_packet = None
+        self._algorithm_rsag_zero_copy = None
+        self._algorithm_packet = None
+        self.scratch_buffer = None
+        self.comm = None
+
+
+def init_dist() -> CustomizedComm:
+    rank = int(os.environ["RANK"])
+    world = int(os.environ["WORLD_SIZE"])
+    master_addr = os.environ["MSCCLPP_MASTER_ADDR"]
+    master_port = os.environ["MSCCLPP_MASTER_PORT"]
+    interface = interfaces_for_ip_netifaces(master_addr)
+    if interface is None:
+        raise ValueError(f"Cannot find network interface for IP address {master_addr}")
+    interfaceIpPortTrio = f"{interface}:{master_addr}:{master_port}"
+    mscclpp_group = mscclpp.CommGroup(interfaceIpPortTrio=interfaceIpPortTrio, rank=rank, size=world)
+    return CustomizedComm(mscclpp_group)
+
+
+def main():
+    local = int(os.environ["LOCAL_RANK"])
+    torch.cuda.set_device(local)
+    comm = init_dist()
+    comm.benchmark(n_warmup=5, n_graph_launches=10, n_iter_per_graph=100)
+    comm.barrier()
+    torch.cuda.synchronize()
+    comm.destroy()
+    print(f"rank {local} All-reduce operation completed successfully.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/python/test/test_mscclpp.py b/python/test/test_mscclpp.py
index a6899642..6b3119cb 100644
--- a/python/test/test_mscclpp.py
+++ b/python/test/test_mscclpp.py
@@ -162,13 +162,10 @@ def create_connection(group: CommGroup, connection_type: str):
 def create_group_and_connection(mpi_group: MpiGroup, connection_type: str):
     if (connection_type == "NVLink" or connection_type == "NVLS") and all_ranks_on_the_same_node(mpi_group) is False:
         pytest.skip("cannot use nvlink/nvls for cross node")
+    if connection_type == "IB" and os.environ.get("MSCCLPP_DISABLE_IB_TESTS", "0") != "0":
+        pytest.skip("IB tests are disabled via MSCCLPP_DISABLE_IB_TESTS=1")
     group = CommGroup(mpi_group.comm)
-    try:
-        connection = create_connection(group, connection_type)
-    except Error as e:
-        if connection_type == "IB" and e.args[0] == ErrorCode.InvalidUsage:
-            pytest.skip("IB not supported on this node")
-        raise
+    connection = create_connection(group, connection_type)
     return group, connection
 
 
@@ -281,6 +278,8 @@ def test_connection_write_and_signal(mpi_group: MpiGroup, connection_type: str,
 
 @parametrize_mpi_groups(2, 4, 8, 16)
 def test_h2h_semaphores(mpi_group: MpiGroup):
+    if os.environ.get("MSCCLPP_DISABLE_IB_TESTS", "0") != "0":
+        pytest.skip("IB tests are disabled via MSCCLPP_DISABLE_IB_TESTS=1")
     group = CommGroup(mpi_group.comm)
     tran = group.my_ib_device(group.my_rank % 8)
     endpoint = EndpointConfig(tran, Device(DeviceType.CPU))
@@ -301,6 +300,8 @@ def test_h2h_semaphores(mpi_group: MpiGroup):
 
 @parametrize_mpi_groups(2, 4, 8, 16)
 def test_h2h_semaphores_gil_release(mpi_group: MpiGroup):
+    if os.environ.get("MSCCLPP_DISABLE_IB_TESTS", "0") != "0":
+        pytest.skip("IB tests are disabled via MSCCLPP_DISABLE_IB_TESTS=1")
     group = CommGroup(mpi_group.comm)
     tran = group.my_ib_device(group.my_rank % 8)
     endpoint = EndpointConfig(tran, Device(DeviceType.CPU))
diff --git a/src/core/ib.cc b/src/core/ib.cc
index 2e7b867d..b8854a6e 100644
--- a/src/core/ib.cc
+++ b/src/core/ib.cc
@@ -636,6 +636,34 @@ MSCCLPP_API_CPP std::string getIBDeviceName(Transport) { return ""; }
 
 MSCCLPP_API_CPP Transport getIBTransportByDeviceName(const std::string&) { return Transport::Unknown; }
 
+IbMr::~IbMr() {}
+IbMrInfo IbMr::getInfo() const { return IbMrInfo(); }
+const void* IbMr::getBuff() const { return nullptr; }
+uint32_t IbMr::getLkey() const { return 0; }
+
+IbQp::~IbQp() {}
+void IbQp::rtr(const IbQpInfo& /*info*/) {}
+void IbQp::rts() {}
+void IbQp::stageSendWrite(const IbMr* /*mr*/, const IbMrInfo& /*info*/, uint32_t /*size*/, uint64_t /*wrId*/,
+                          uint64_t /*srcOffset*/, uint64_t /*dstOffset*/, bool /*signaled*/) {}
+void IbQp::stageSendAtomicAdd(const IbMr* /*mr*/, const IbMrInfo& /*info*/, uint64_t /*wrId*/, uint64_t /*dstOffset*/,
+                              uint64_t /*addVal*/, bool /*signaled*/) {}
+void IbQp::stageSendWriteWithImm(const IbMr* /*mr*/, const IbMrInfo& /*info*/, uint32_t /*size*/, uint64_t /*wrId*/,
+                                 uint64_t /*srcOffset*/, uint64_t /*dstOffset*/, bool /*signaled*/,
+                                 unsigned int /*immData*/) {}
+void IbQp::postSend() {}
+void IbQp::stageRecv(uint64_t /*wrId*/) {}
+void IbQp::stageRecv(const IbMr* /*mr*/, uint64_t /*wrId*/, uint32_t /*size*/, uint64_t /*offset*/) {}
+void IbQp::postRecv() {}
+int IbQp::pollSendCq() { return 0; }
+int IbQp::pollRecvCq() { return 0; }
+int IbQp::getSendWcStatus(int /*idx*/) const { return 0; }
+std::string IbQp::getSendWcStatusString(int /*idx*/) const { return ""; }
+int IbQp::getNumSendCqItems() const { return 0; }
+int IbQp::getRecvWcStatus(int /*idx*/) const { return 0; }
+std::string IbQp::getRecvWcStatusString(int /*idx*/) const { return ""; }
+unsigned int IbQp::getRecvWcImmData(int /*idx*/) const { return 0; }
+
 #endif  // !defined(USE_IBVERBS)
 
 }  // namespace mscclpp
diff --git a/src/ext/collectives/algorithm_collection_builder.cc b/src/ext/collectives/algorithm_collection_builder.cc
index 1ede7519..2b3bec8d 100644
--- a/src/ext/collectives/algorithm_collection_builder.cc
+++ b/src/ext/collectives/algorithm_collection_builder.cc
@@ -13,6 +13,9 @@
 #include "allreduce/allreduce_nvls_with_copy.hpp"
 #include "allreduce/allreduce_nvls_with_copy_2.hpp"
 #include "allreduce/allreduce_packet.hpp"
+#include "allreduce/allreduce_rsag.hpp"
+#include "allreduce/allreduce_rsag_pipeline.hpp"
+#include "allreduce/allreduce_rsag_zero_copy.hpp"
 #include "logger.hpp"
 
 namespace mscclpp {
@@ -82,6 +85,14 @@ AlgorithmCollection AlgorithmCollectionBuilder::buildDefaultNativeAlgorithms(uin
   collection.registerAlgorithm(allreduceNvls->collective(), allreduceNvls->name(), allreduceNvls);
   auto allreduceFullmesh = std::make_shared(scratchBuffer, scratchBufferSize)->build();
   collection.registerAlgorithm(allreduceFullmesh->collective(), allreduceFullmesh->name(), allreduceFullmesh);
+  auto allreduceRsag = std::make_shared<collective::AllreduceRsAg>(scratchBuffer, scratchBufferSize)->build();
+  collection.registerAlgorithm(allreduceRsag->collective(), allreduceRsag->name(), allreduceRsag);
+  auto allreduceRsagPipeline = std::make_shared<collective::AllreduceRsAgPipeline>(scratchBuffer, scratchBufferSize)->build();
+  collection.registerAlgorithm(allreduceRsagPipeline->collective(), allreduceRsagPipeline->name(),
+                               allreduceRsagPipeline);
+  auto allreduceRsagZeroCopy = std::make_shared<collective::AllreduceRsAgZeroCopy>()->build();
+  collection.registerAlgorithm(allreduceRsagZeroCopy->collective(), allreduceRsagZeroCopy->name(),
+                               allreduceRsagZeroCopy);
   auto allgatherFullmesh = std::make_shared(scratchBuffer, scratchBufferSize)->build();
   collection.registerAlgorithm(allgatherFullmesh->collective(), allgatherFullmesh->name(), allgatherFullmesh);
 
diff --git a/src/ext/collectives/allreduce/allreduce_rsag.cu b/src/ext/collectives/allreduce/allreduce_rsag.cu
new file mode 100644
index 00000000..d5be2257
--- /dev/null
+++ b/src/ext/collectives/allreduce/allreduce_rsag.cu
@@ -0,0 +1,229 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#include "allreduce/allreduce_rsag.hpp"
+#include "allreduce/common.hpp"
+#include "collective_utils.hpp"
+#include "logger.hpp"
+
+namespace mscclpp {
+namespace collective {
+
+// Allreduce using the Reduce-Scatter + All-Gather (RSAG) pattern.
+//
+// This algorithm performs allreduce in three phases over intra-node peers
+// connected via CudaIpc memory channels:
+//
+// 1. Scatter: Each rank copies its input data into a scratch buffer, then
+//    signals peers and waits for all peers to do the same.
+//
+// 2. Reduce-Scatter: Each rank reduces its assigned chunk by reading the
+//    corresponding chunks from all peers' scratch buffers (via remote memory
+//    handles) and applying the reduction op. The reduced result is written
+//    back to both the local result buffer and peers' scratch buffers.
+//
+// 3. All-Gather: After a second signal/wait barrier, each rank copies the
+//    reduced chunks produced by other ranks from the scratch buffer into its
+//    result buffer, completing the allreduce.
+//
+// Data is processed in int4-sized (16-byte) units for coalesced memory access,
+// with special handling for any remainder elements at the tail.
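+//
+// Sizing example (hypothetical values): with T = half (8 elements per int4),
+// nelems = 1002 and nRanksPerNode = 8, the data is padded to
+// alignedNelems = 1024, giving nelemsPerRank = 128 and nInt4PerRank = 16.
+// The tail falls at lastInt4Index = 1002 / 8 = 125 with remainder = 2, so the
+// int4 at index 125 is copied element-wise rather than as a full vector.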
+template <ReduceOp OpType, typename T>
+__global__ void __launch_bounds__(1024, 1)
+    allreduceRsAg(T* buff, T* scratch, T* resultBuff, DeviceHandle<BaseMemoryChannel>* memoryChannels,
+                  DeviceHandle<SwitchChannel>* switchChannels, void* remoteMemories, int rank, int nRanksPerNode,
+                  int worldSize, size_t nelems) {
+  int blockId = blockIdx.x;
+  uint32_t nPeers = nRanksPerNode - 1;
+
+  assert((uintptr_t)buff % sizeof(int4) == 0);
+  assert((uintptr_t)resultBuff % sizeof(int4) == 0);
+
+  constexpr uint32_t nelemsPerInt4 = sizeof(int4) / sizeof(T);
+  uint32_t alignedNelems = ((nelems + nRanksPerNode - 1) / nRanksPerNode + nelemsPerInt4 - 1) / nelemsPerInt4 *
+                           nelemsPerInt4 * nRanksPerNode;
+  uint32_t nelemsPerRank = alignedNelems / nRanksPerNode;
+  uint32_t nInt4PerRank = nelemsPerRank / nelemsPerInt4;
+  uint32_t lastInt4Index = nelems / nelemsPerInt4;
+  uint32_t remainder = nelems % nelemsPerInt4;
+
+  int4* scratch4 = reinterpret_cast<int4*>((char*)scratch);
+  int4* resultBuff4 = reinterpret_cast<int4*>((char*)resultBuff);
+  int4* buff4 = reinterpret_cast<int4*>((char*)buff);
+  DeviceHandle<BaseMemoryChannel>* memoryChannelsLocal = memoryChannels + blockId * nPeers;
+
+  uint32_t nInt4PerBlock = nInt4PerRank / gridDim.x;
+  uint32_t remainderForBlock = nInt4PerRank % gridDim.x;
+  uint32_t offset4 = blockId * nInt4PerBlock;
+  if (blockId == (int)(gridDim.x - 1)) {
+    nInt4PerBlock += remainderForBlock;
+  }
+  if (nInt4PerBlock == 0) return;
+  uint32_t nInt4ForCopy = nInt4PerBlock * nRanksPerNode;
+
+  for (uint32_t idx = threadIdx.x; idx < nInt4ForCopy; idx += blockDim.x) {
+    int rankIdx = idx / nInt4PerBlock;
+    uint32_t offsetIdx = rankIdx * nInt4PerRank + offset4 + (idx % nInt4PerBlock);
+    if (offsetIdx > lastInt4Index) continue;
+    if (offsetIdx == lastInt4Index && remainder != 0) {
+      for (uint32_t i = 0; i < remainder; i++) {
+        ((T*)&scratch4[offsetIdx])[i] = ((T*)&buff4[offsetIdx])[i];
+      }
+      continue;
+    }
+    scratch4[offsetIdx] = buff4[offsetIdx];
+  }
+  __syncthreads();
+  if (threadIdx.x < nPeers) {
+    memoryChannelsLocal[threadIdx.x].signal();
+    memoryChannelsLocal[threadIdx.x].wait();
+  }
+  __syncthreads();
+  for (uint32_t idx = threadIdx.x; idx < nInt4PerBlock; idx += blockDim.x) {
+    uint32_t offset = idx + offset4 + rank * nInt4PerRank;
+    if (offset > lastInt4Index) continue;
+    int4 tmp = scratch4[offset];
+    for (uint32_t i = 0; i < nPeers; i++) {
+      int rankIdx = (rank + i + 1) % nRanksPerNode;
+      int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1;
+      int4 data = mscclpp::read<int4>(((void**)remoteMemories)[peerIdx], offset);
+      tmp = cal_vector<OpType, T>(data, tmp);
+    }
+    for (uint32_t i = 0; i < nPeers; i++) {
+      int rankIdx = (rank + i + 1) % nRanksPerNode;
+      int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1;
+      mscclpp::write<int4>(((void**)remoteMemories)[peerIdx], offset, tmp);
+    }
+    if (offset == lastInt4Index && remainder != 0) {
+      for (uint32_t i = 0; i < remainder; i++) {
+        ((T*)&resultBuff4[offset])[i] = ((T*)&tmp)[i];
+      }
+      continue;
+    }
+    resultBuff4[offset] = tmp;
+  }
+  __syncthreads();
+  if (threadIdx.x < nPeers) {
+    memoryChannelsLocal[threadIdx.x].signal();
+    memoryChannelsLocal[threadIdx.x].wait();
+  }
+  __syncthreads();
+  for (uint32_t idx = threadIdx.x; idx < nInt4ForCopy; idx += blockDim.x) {
+    int rankIdx = idx / nInt4PerBlock;
+    if (rankIdx == rank) continue;
+    uint32_t offsetIdx = rankIdx * nInt4PerRank + offset4 + (idx % nInt4PerBlock);
+    if (offsetIdx > lastInt4Index) continue;
+    if (offsetIdx == lastInt4Index && remainder != 0) {
+      for (uint32_t i = 0; i < remainder; i++) {
+        ((T*)&resultBuff4[offsetIdx])[i] = ((T*)&scratch4[offsetIdx])[i];
+      }
+      continue;
+    }
+    resultBuff4[offsetIdx] = scratch4[offsetIdx];
+  }
+}
+
+template <ReduceOp OpType, typename T>
+struct AllreduceRsAgAdapter {
+  static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void* remoteMemories,
+                          DeviceHandle<SwitchChannel>* switchChannel, DeviceHandle<SwitchChannel>*, size_t, size_t,
+                          size_t, int rank, int nRanksPerNode, int worldSize, size_t inputSize, cudaStream_t stream,
+                          void*, uint32_t, uint32_t, int nBlocks, int nThreadsPerBlock) {
+    using ChannelType = DeviceHandle<BaseMemoryChannel>;
+    size_t nelems = inputSize / sizeof(T);
+    if (nBlocks == 0 || nThreadsPerBlock == 0) {
+      nThreadsPerBlock = 1024;
+      nBlocks = 64;
+    }
+    allreduceRsAg<OpType, T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
+        (T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels, switchChannel, remoteMemories, rank,
+        nRanksPerNode, worldSize, nelems);
+    return cudaGetLastError();
+  }
+};
+
+void AllreduceRsAg::initialize(std::shared_ptr<Communicator> comm) {
+  this->conns_ = setupConnections(comm);
+  nChannelsPerConnection_ = 64;
+  comm_ = comm;
+  // setup semaphores
+  this->scratchSemaphores_ = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection_);
+  RegisteredMemory localMemory = comm->registerMemory(scratchBuffer_, scratchBufferSize_, Transport::CudaIpc);
+  this->remoteScratchMemories_ = setupRemoteMemories(comm, comm->bootstrap()->getRank(), localMemory);
+  localScratchMemory_ = std::move(localMemory);
+
+  this->baseChannels_ = setupBaseMemoryChannels(this->conns_, this->scratchSemaphores_, nChannelsPerConnection_);
+  this->baseMemoryChannelHandles_ = setupBaseMemoryChannelDeviceHandles(baseChannels_);
+  std::vector<void*> remoteMemoryHandles;
+  for (const auto& remoteMemory : this->remoteScratchMemories_) {
+    remoteMemoryHandles.push_back(remoteMemory.data());
+  }
+  this->remoteMemoryHandles_ = detail::gpuCallocShared<void*>(remoteMemoryHandles.size());
+  gpuMemcpy(this->remoteMemoryHandles_.get(), remoteMemoryHandles.data(), remoteMemoryHandles.size(),
+            cudaMemcpyHostToDevice);
+}
+
+CommResult AllreduceRsAg::allreduceKernelFunc(const std::shared_ptr<AlgorithmCtx> ctx, const void* input, void* output,
+                                              size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream,
+                                              int nBlocks, int nThreadsPerBlock,
+                                              const std::unordered_map<std::string, uintptr_t>&) {
+  auto algoCtx = std::static_pointer_cast<AlgorithmCtx>(ctx);
+  AllreduceFunc allreduce = dispatch<AllreduceRsAgAdapter>(op, dtype);
+  if (!allreduce) {
+    WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast<int>(op),
+         ", dtype=", static_cast<int>(dtype));
+    return CommResult::CommInvalidArgument;
+  }
+  if (inputSize > this->scratchBufferSize_) {
+    WARN(ALGO, "Input size ", inputSize, " exceeds scratch buffer size ", this->scratchBufferSize_);
+    return CommResult::CommInvalidArgument;
+  }
+  std::pair<int, int> numBlocksAndThreads = {nBlocks, nThreadsPerBlock};
+  cudaError_t error = allreduce(input, this->scratchBuffer_, output, this->baseMemoryChannelHandles_.get(),
+                                this->remoteMemoryHandles_.get(), nullptr, nullptr, 0, 0, 0, algoCtx->rank,
+                                algoCtx->nRanksPerNode, algoCtx->workSize, inputSize, stream, nullptr, 0, 0,
+                                numBlocksAndThreads.first, numBlocksAndThreads.second);
+  if (error != cudaSuccess) {
+    WARN(ALGO, "Allreduce kernel launch failed with error: ", cudaGetErrorString(error));
+    return CommResult::CommUnhandledCudaError;
+  }
+  return CommResult::CommSuccess;
+}
+
+AlgorithmCtxKey AllreduceRsAg::generateAllreduceContextKey(const void*, void*, size_t, DataType, bool) {
+  return AlgorithmCtxKey{nullptr, nullptr, 0, 0, 0};
+}
+
+std::shared_ptr<AlgorithmCtx> AllreduceRsAg::initAllreduceContext(std::shared_ptr<Communicator> comm, const void*,
+                                                                  void*, size_t, DataType) {
+  auto ctx = std::make_shared<AlgorithmCtx>();
+  ctx->rank = comm->bootstrap()->getRank();
+  ctx->workSize = comm->bootstrap()->getNranks();
+  ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
+
+  ctx->memorySemaphores = this->scratchSemaphores_;
+  ctx->registeredMemories = this->remoteScratchMemories_;
+  return ctx;
+}
+
+std::shared_ptr<Algorithm> AllreduceRsAg::build() {
+  auto self = std::make_shared<AllreduceRsAg>((uintptr_t)scratchBuffer_, scratchBufferSize_);
+  return std::make_shared<Algorithm>(
+      "default_allreduce_rsag", "allreduce",
+      [self](std::shared_ptr<Communicator> comm) { self->initialize(comm); },
+      [self](const std::shared_ptr<AlgorithmCtx> ctx, const void* input, void* output, size_t inputSize,
+             [[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
+             int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) -> CommResult {
+        return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock,
+                                         extras);
+      },
+      [self](std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
+             [[maybe_unused]] size_t outputSize,
+             DataType dtype) { return self->initAllreduceContext(comm, input, output, inputSize, dtype); },
+      [self](const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, DataType dtype,
+             bool symmetricMemory) {
+        return self->generateAllreduceContextKey(input, output, inputSize, dtype, symmetricMemory);
+      });
+}
+}  // namespace collective
+}  // namespace mscclpp
diff --git a/src/ext/collectives/allreduce/allreduce_rsag_pipeline.cu b/src/ext/collectives/allreduce/allreduce_rsag_pipeline.cu
new file mode 100644
index 00000000..a230d8cd
--- /dev/null
+++ b/src/ext/collectives/allreduce/allreduce_rsag_pipeline.cu
@@ -0,0 +1,336 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#include "allreduce/allreduce_rsag_pipeline.hpp"
+#include "allreduce/common.hpp"
+#include "collective_utils.hpp"
+#include "logger.hpp"
+
+namespace mscclpp {
+namespace collective {
+constexpr int MAX_NBLOCKS_FOR_PUT = 32;
+constexpr int MAX_NBLOCKS_FOR_RECV = 32;
+constexpr int MAX_NBLOCKS_FOR_REDUCE = 64;
+constexpr int REDUCE_COPY_RATIO = 2;
+__device__ DeviceSemaphore semaphoreForSend[MAX_NBLOCKS_FOR_REDUCE];
+__device__ DeviceSemaphore semaphoreForRecv[MAX_NBLOCKS_FOR_REDUCE];
+__device__ DeviceSemaphore semaphoreForReduce[MAX_NBLOCKS_FOR_REDUCE];
+
+// TODO: move it to a common header file
+template <typename T>
+__device__ __forceinline__ int4 loadVec(const T* buff, size_t i, size_t nelems) {
+  constexpr size_t ElemsPerInt4 = sizeof(int4) / sizeof(T);
+  size_t offset = i * ElemsPerInt4;
+  if (offset + ElemsPerInt4 <= nelems) {
+    return reinterpret_cast<const int4*>(buff)[i];
+  } else {
+    union {
+      int4 i;
+      T t[ElemsPerInt4];
+    } vec;
+    vec.i = make_int4(0, 0, 0, 0);
+    for (size_t j = 0; j < ElemsPerInt4 && offset + j < nelems; ++j) {
+      vec.t[j] = buff[offset + j];
+    }
+    return vec.i;
+  }
+}
+
+template <typename T>
+__device__ __forceinline__ void storeVec(T* buff, size_t i, int4 val, size_t nelems) {
+  constexpr size_t ElemsPerInt4 = sizeof(int4) / sizeof(T);
+  size_t offset = i * ElemsPerInt4;
+  if (offset + ElemsPerInt4 <= nelems) {
+    reinterpret_cast<int4*>(buff)[i] = val;
+  } else {
+    union {
+      int4 i;
+      T t[ElemsPerInt4];
+    } vec;
+    vec.i = val;
+    for (size_t j = 0; j < ElemsPerInt4 && offset + j < nelems; ++j) {
+      buff[offset + j] = vec.t[j];
+    }
+  }
+}
+
+// Pipelined Reduce-Scatter + All-Gather (RSAG) allreduce.
+//
+// This is a pipelined variant of the basic RSAG allreduce that overlaps
+// communication and computation by splitting the data into chunks processed
+// across multiple iterations. Three groups of thread blocks run concurrently
+// with different roles, synchronized via device semaphores:
+//
+//   PUT blocks    — Read local input chunks and write them into peers' scratch
+//                   buffers via remote memory handles (CudaIpc).
+//
+//   REDUCE blocks — After a signal/wait barrier confirming PUT completion,
+//                   reduce the local chunk with data received from all peers
+//                   in the scratch buffer. Write the reduced result to both
+//                   the local output and peers' scratch (for the AG phase).
+//
+//   RECV blocks   — After a signal/wait barrier confirming REDUCE completion,
+//                   copy other ranks' reduced chunks from scratch into the
+//                   local result buffer, completing the all-gather.
+//
+// Pipelining is achieved by using a circular scratch buffer (pipelineDepth
+// stages). PUT blocks wait on a semaphore before reusing a scratch slot,
+// allowing the next iteration's PUT to overlap with the current iteration's
+// REDUCE and RECV. Each REDUCE block handles a subset of the PUT block's
+// data (controlled by REDUCE_COPY_RATIO), enabling finer-grained overlap.
+//
+// Data is processed in int4-sized (16-byte) units with vectorized load/store
+// helpers that handle tail elements.
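+//
+// Sizing example (hypothetical defaults): with nblocksForReduce = 64,
+// blockDim.x = 1024 and nStepsPerIter = 4, each iteration moves
+// nInt4PerIter = 64 * 1024 * 4 = 262144 int4 (4 MiB) per rank, so on an
+// 8-GPU node chunkSize = 32 MiB per pipeline stage. A stage occupies
+// 2 * chunkSize of scratch (RS half plus AG half), so a 128 MiB scratch
+// buffer yields pipelineDepth = 2 overlapping iterations.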
+
+template <ReduceOp OpType, typename T>
+__global__ void __launch_bounds__(1024, 1)
+    allreduceRsAgPipeline(T* buff, T* scratch, T* resultBuff, DeviceHandle<BaseMemoryChannel>* memoryChannels,
+                          DeviceHandle<SwitchChannel>* switchChannels, void* remoteMemories, int rank,
+                          int nRanksPerNode, int worldSize, size_t nelems, size_t scratchSize, uint32_t nblocksForPut,
+                          uint32_t nblocksForReduce, uint32_t nblocksForRecv) {
+  uint32_t bid = blockIdx.x;
+  constexpr uint32_t nStepsPerIter = 4;
+  uint32_t nInt4 = (nelems * sizeof(T) + sizeof(int4) - 1) / sizeof(int4);
+  uint32_t nInt4PerIter = nblocksForReduce * blockDim.x * nStepsPerIter;
+  const uint32_t chunkSize = nInt4PerIter * worldSize;
+  uint32_t nIters = (nInt4 + chunkSize - 1) / chunkSize;
+  uint32_t nPeers = nRanksPerNode - 1;
+  int4* scratch4 = reinterpret_cast<int4*>((char*)scratch);
+  const uint32_t scratchIterStride = 2 * chunkSize;  // one for RS, one for AG
+  const uint32_t pipelineDepth = scratchSize / sizeof(int4) / scratchIterStride;
+  assert(pipelineDepth >= 1);
+
+  if (bid < nblocksForPut) {
+    if (threadIdx.x == 0) {
+      semaphoreForSend[bid].set(pipelineDepth);
+    }
+    for (uint32_t iter = 0; iter < nIters; iter++) {
+      if (threadIdx.x == 0) {
+        semaphoreForSend[bid].acquire();
+      }
+      __syncthreads();
+      uint32_t threadIdInPut = bid * blockDim.x + threadIdx.x;
+      for (uint32_t peer = 0; peer < nPeers; peer++) {
+        int remoteRankId = (rank + peer + 1) % nRanksPerNode;
+        int peerId = remoteRankId < rank ? remoteRankId : remoteRankId - 1;
+        // Read chunk[remoteRankId] from local buff, write to peer's scratch[rank] (sender's slot)
+        uint32_t srcOffset = iter * chunkSize + remoteRankId * nInt4PerIter;
+        uint32_t dstOffset = (iter % pipelineDepth) * scratchIterStride + rank * nInt4PerIter;
+        int4 tmp[nStepsPerIter * REDUCE_COPY_RATIO];
+#pragma unroll
+        for (uint32_t step = 0; step < nStepsPerIter * REDUCE_COPY_RATIO; step++) {
+          uint32_t offset = srcOffset + threadIdInPut + step * blockDim.x * nblocksForPut;
+          tmp[step] = loadVec(buff, offset, nelems);
+        }
+#pragma unroll
+        for (uint32_t step = 0; step < nStepsPerIter * REDUCE_COPY_RATIO; step++) {
+          uint32_t offset = dstOffset + threadIdInPut + step * blockDim.x * nblocksForPut;
+          mscclpp::write<int4>(((void**)remoteMemories)[peerId], offset, tmp[step]);
+        }
+      }
+      __syncthreads();
+      if (threadIdx.x < REDUCE_COPY_RATIO) {
+        semaphoreForReduce[bid * REDUCE_COPY_RATIO + threadIdx.x].release();
+      }
+    }
+  } else if (bid < nblocksForPut + nblocksForReduce) {
+    uint32_t bidInReduce = bid - nblocksForPut;
+    DeviceHandle<BaseMemoryChannel>* localMemoryChannels = memoryChannels + bidInReduce * nPeers;
+    // Map REDUCE blocks to PUT blocks: REDUCE blocks 0,1 handle PUT block 0's data
+    uint32_t putBlockId = bidInReduce / REDUCE_COPY_RATIO;
+    uint32_t subBlockId = bidInReduce % REDUCE_COPY_RATIO;
+    for (uint32_t iter = 0; iter < nIters; iter++) {
+      if (threadIdx.x == 0) {
+        semaphoreForReduce[bidInReduce].acquire();
+      }
+      uint32_t baseOffset = (iter % pipelineDepth) * scratchIterStride;
+      uint32_t baseSrcOffset = iter * chunkSize;
+
+      // Use same thread mapping as PUT: putBlockId * blockDim.x + threadIdx.x
+      uint32_t threadIdInPut = putBlockId * blockDim.x + threadIdx.x;
+      __syncthreads();
+      if (threadIdx.x < nPeers) {
+        localMemoryChannels[threadIdx.x].signal();
+        localMemoryChannels[threadIdx.x].wait();
+      }
+      __syncthreads();
+#pragma unroll nStepsPerIter
+      for (uint32_t step = 0; step < nStepsPerIter; step++) {
+        // Map to PUT's step pattern: each REDUCE block handles nStepsPerIter steps
+        // subBlockId determines which subset of the REDUCE_COPY_RATIO * nStepsPerIter steps
+        uint32_t putStep = subBlockId * nStepsPerIter + step;
+        uint32_t myChunkOffset =
+            baseSrcOffset + rank * nInt4PerIter + threadIdInPut + putStep * blockDim.x * nblocksForPut;
+        int4 tmp = loadVec(buff, myChunkOffset, nelems);
+        // Add data from each peer's slot in scratch (peer sent their chunk[rank] to our scratch[peer])
+        for (uint32_t peer = 0; peer < nPeers; peer++) {
+          int remoteRankId = (rank + peer + 1) % nRanksPerNode;
+          uint32_t peerSlotOffset =
+              baseOffset + remoteRankId * nInt4PerIter + threadIdInPut + putStep * blockDim.x * nblocksForPut;
+          int4 data = scratch4[peerSlotOffset];
+          tmp = cal_vector<OpType, T>(data, tmp);
+        }
+        storeVec(resultBuff, myChunkOffset, tmp, nelems);
+        // Broadcast reduced result to all peers' scratch at SCATTER_AG_OFFSET + rank * nInt4PerIter
+        uint32_t dstOffset =
+            baseOffset + chunkSize + rank * nInt4PerIter + threadIdInPut + putStep * blockDim.x * nblocksForPut;
+        for (uint32_t i = 0; i < nPeers; i++) {
+          int peerIdx = (rank + i + 1) % nRanksPerNode;
+          int index = peerIdx < rank ? peerIdx : peerIdx - 1;
+          mscclpp::write<int4>(((void**)remoteMemories)[index], dstOffset, tmp);
+        }
+      }
+      __syncthreads();
+      if (threadIdx.x == 0) {
+        semaphoreForRecv[bidInReduce].release();
+      }
+    }
+  } else if (bid < nblocksForPut + nblocksForReduce + nblocksForRecv) {
+    uint32_t bidInRecv = bid - nblocksForPut - nblocksForReduce;
+    DeviceHandle<BaseMemoryChannel>* localMemoryChannels = memoryChannels + (nblocksForReduce + bidInRecv) * nPeers;
+    for (uint32_t iter = 0; iter < nIters; iter++) {
+      if (threadIdx.x < REDUCE_COPY_RATIO) {
+        semaphoreForRecv[bidInRecv * REDUCE_COPY_RATIO + threadIdx.x].acquire();
+      }
+      uint32_t baseOffset = scratchIterStride * (iter % pipelineDepth);
+      uint32_t baseDstOffset = chunkSize * iter;
+      int threadIdInRecv = bidInRecv * blockDim.x + threadIdx.x;
+      __syncthreads();
+      if (threadIdx.x < nPeers) {
+        localMemoryChannels[threadIdx.x].signal();
+        localMemoryChannels[threadIdx.x].wait();
+      }
+      __syncthreads();
+      // Copy other ranks' reduced chunks from scratch to result
+      for (uint32_t peer = 0; peer < nPeers; peer++) {
+        int remoteRankId = (rank + peer + 1) % nRanksPerNode;
+        for (uint32_t step = 0; step < nStepsPerIter * REDUCE_COPY_RATIO; step++) {
+          uint32_t offset = baseOffset + chunkSize + remoteRankId * nInt4PerIter + threadIdInRecv +
+                            step * blockDim.x * nblocksForRecv;
+          uint32_t dstOffset =
+              baseDstOffset + remoteRankId * nInt4PerIter + threadIdInRecv + step * blockDim.x * nblocksForRecv;
+          storeVec(resultBuff, dstOffset, scratch4[offset], nelems);
+        }
+      }
+      __syncthreads();
+      if (threadIdx.x == 0) {
+        semaphoreForSend[bidInRecv].release();
+      }
+    }
+  }
+}
+
+template <ReduceOp OpType, typename T>
+struct AllreduceRsAgPipelineAdapter {
+  static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void* remoteMemories,
+                          DeviceHandle<SwitchChannel>* switchChannel, DeviceHandle<SwitchChannel>*, size_t, size_t,
+                          size_t scratchSize, int rank, int nRanksPerNode, int worldSize, size_t inputSize,
+                          cudaStream_t stream, void*, uint32_t, uint32_t, int nBlocks, int nThreadsPerBlock) {
+    using ChannelType = DeviceHandle<BaseMemoryChannel>;
+    size_t nelems = inputSize / sizeof(T);
+    uint32_t nblocksForPut = MAX_NBLOCKS_FOR_PUT;
+    uint32_t nblocksForReduce = MAX_NBLOCKS_FOR_REDUCE;
+    uint32_t nblocksForRecv = MAX_NBLOCKS_FOR_RECV;
+    int maxNblocks = nblocksForPut + nblocksForReduce + nblocksForRecv;
+    if (nBlocks == 0 || nThreadsPerBlock == 0) {
+      nThreadsPerBlock = 1024;
+      nBlocks = maxNblocks;
+    } else {
+      nBlocks = nBlocks / (REDUCE_COPY_RATIO + 2) * (REDUCE_COPY_RATIO + 2);
+      if (nBlocks > maxNblocks) {
+        WARN(ALGO, "The number of blocks is too large for the allreduce pipeline algorithm, reducing it to ",
+             maxNblocks);
+        nBlocks = maxNblocks;
+      }
+      nblocksForPut = nBlocks / (REDUCE_COPY_RATIO + 2);
+      nblocksForReduce = nblocksForPut * REDUCE_COPY_RATIO;
+      nblocksForRecv = nblocksForPut;
+    }
+    allreduceRsAgPipeline<OpType, T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
+        (T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels, switchChannel, remoteMemories, rank,
+        nRanksPerNode, worldSize, nelems, scratchSize, nblocksForPut, nblocksForReduce, nblocksForRecv);
+    return cudaGetLastError();
+  }
+};
+
+void AllreduceRsAgPipeline::initialize(std::shared_ptr<Communicator> comm) {
+  this->conns_ = setupConnections(comm);
+  nChannelsPerConnection_ = MAX_NBLOCKS_FOR_REDUCE + MAX_NBLOCKS_FOR_RECV;
+  comm_ = comm;
+  // setup semaphores
+  this->scratchSemaphores_ = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection_);
+  RegisteredMemory localMemory = comm->registerMemory(scratchBuffer_, scratchBufferSize_, Transport::CudaIpc);
+  this->remoteScratchMemories_ = setupRemoteMemories(comm, comm->bootstrap()->getRank(), localMemory);
+  localScratchMemory_ = std::move(localMemory);
+
+  this->baseChannels_ = setupBaseMemoryChannels(this->conns_, this->scratchSemaphores_, nChannelsPerConnection_);
+  this->baseMemoryChannelHandles_ = setupBaseMemoryChannelDeviceHandles(baseChannels_);
+  std::vector<void*> remoteMemoryHandles;
+  for (const auto& remoteMemory : this->remoteScratchMemories_) {
+    remoteMemoryHandles.push_back(remoteMemory.data());
+  }
+  this->remoteMemoryHandles_ = detail::gpuCallocShared<void*>(remoteMemoryHandles.size());
+  gpuMemcpy(this->remoteMemoryHandles_.get(), remoteMemoryHandles.data(), remoteMemoryHandles.size(),
+            cudaMemcpyHostToDevice);
+}
+
+CommResult AllreduceRsAgPipeline::allreduceKernelFunc(const std::shared_ptr<AlgorithmCtx> ctx, const void* input,
+                                                      void* output, size_t inputSize, DataType dtype, ReduceOp op,
+                                                      cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
+                                                      const std::unordered_map<std::string, uintptr_t>&) {
+  auto algoCtx = std::static_pointer_cast<AlgorithmCtx>(ctx);
+  AllreduceFunc allreduce = dispatch<AllreduceRsAgPipelineAdapter>(op, dtype);
+  if (!allreduce) {
+    WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast<int>(op),
+         ", dtype=", static_cast<int>(dtype));
+    return CommResult::CommInvalidArgument;
+  }
+  std::pair<int, int> numBlocksAndThreads = {nBlocks, nThreadsPerBlock};
+  cudaError_t error = allreduce(input, this->scratchBuffer_, output, this->baseMemoryChannelHandles_.get(),
+                                this->remoteMemoryHandles_.get(), nullptr, nullptr, 0, 0, this->scratchBufferSize_,
+                                algoCtx->rank, algoCtx->nRanksPerNode, algoCtx->workSize, inputSize, stream, nullptr, 0,
+                                0, numBlocksAndThreads.first, numBlocksAndThreads.second);
+  if (error != cudaSuccess) {
+    WARN(ALGO, "Allreduce kernel launch failed with error: ", cudaGetErrorString(error));
+    return CommResult::CommUnhandledCudaError;
+  }
+  return CommResult::CommSuccess;
+}
+
+AlgorithmCtxKey AllreduceRsAgPipeline::generateAllreduceContextKey(const void*, void*, size_t, DataType, bool) {
+  return AlgorithmCtxKey{nullptr, nullptr, 0, 0, 0};
+}
+
+std::shared_ptr<AlgorithmCtx> AllreduceRsAgPipeline::initAllreduceContext(std::shared_ptr<Communicator> comm,
+                                                                          const void*, void*, size_t, DataType) {
+  auto ctx = std::make_shared<AlgorithmCtx>();
+  ctx->rank = comm->bootstrap()->getRank();
+  ctx->workSize = comm->bootstrap()->getNranks();
+  ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
+
+  ctx->memorySemaphores = this->scratchSemaphores_;
+  ctx->registeredMemories = this->remoteScratchMemories_;
+  return ctx;
+}
+
+std::shared_ptr<Algorithm> AllreduceRsAgPipeline::build() {
+  auto self = std::make_shared<AllreduceRsAgPipeline>((uintptr_t)scratchBuffer_, scratchBufferSize_);
+  return std::make_shared<Algorithm>(
+      "default_allreduce_rsag_pipeline", "allreduce",
+      [self](std::shared_ptr<Communicator> comm) { self->initialize(comm); },
+      [self](const std::shared_ptr<AlgorithmCtx> ctx, const void* input, void* output, size_t inputSize,
+             [[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
+             int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) -> CommResult {
+        return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock,
+                                         extras);
+      },
+      [self](std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
+             [[maybe_unused]] size_t outputSize,
+             DataType dtype) { return self->initAllreduceContext(comm, input, output, inputSize, dtype); },
+      [self](const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, DataType dtype,
+             bool symmetricMemory) {
+        return self->generateAllreduceContextKey(input, output, inputSize, dtype, symmetricMemory);
+      });
+}
+}  // namespace collective
+}  // namespace mscclpp
diff --git a/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu b/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu
new file mode 100644
index 00000000..caac07ae
--- /dev/null
+++ b/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu
@@ -0,0 +1,236 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#include "allreduce/allreduce_rsag_zero_copy.hpp"
+#include "allreduce/common.hpp"
+#include "collective_utils.hpp"
+#include "logger.hpp"
+
+namespace mscclpp {
+namespace collective {
+
+__device__ mscclpp::DeviceSyncer globalSyncer;
+
+// Zero-copy Reduce-Scatter + All-Gather (RSAG) allreduce.
+//
+// Unlike the standard RSAG which copies input into a scratch buffer first,
+// this variant reads directly from peers' input buffers and writes reduced
+// results directly to peers' output buffers — eliminating the need for a
+// separate scratch buffer and reducing memory traffic.
+//
+// The algorithm runs in a single kernel with the following steps:
+//
+// 1. Barrier: Signal and wait on all peers to ensure input buffers are ready.
+//
+// 2. Reduce-Scatter: Each rank reads its assigned chunk from every peer's
+//    input buffer (via CudaIpc remote memory handles), reduces all values
+//    locally, then writes the reduced result to its own output buffer AND
+//    directly to every peer's output buffer at the same offset.
+//
+// 3. Global sync + Barrier: A device-wide sync ensures all writes complete,
+//    followed by a final signal/wait to guarantee all peers have finished
+//    writing, making the full output buffer valid on every rank.
+//
+// This approach requires registering both input and output buffers as remote
+// memories (2 * nPeers handles), but avoids scratch buffer allocation and
+// the extra copy steps of the standard RSAG. The NRanksPerNode template
+// parameter enables compile-time unrolling of peer loops (supports 4 or 8).
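+//
+// The remoteMemories array therefore holds 2 * nPeers handles: entries
+// [0, nPeers) point at the peers' input buffers (read during reduce-scatter)
+// and entries [nPeers, 2 * nPeers) at the peers' output buffers (written with
+// the reduced chunks), matching the registration order in initAllreduceContext.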
+
+template <int NRanksPerNode, ReduceOp OpType, typename T>
+__global__ void __launch_bounds__(1024, 1)
+    allreduceRsAgZeroCopy(T* buff, T* scratch, T* resultBuff, DeviceHandle<BaseMemoryChannel>* memoryChannels,
+                          DeviceHandle<SwitchChannel>* switchChannels, void* remoteMemories, int rank, int worldSize,
+                          size_t nelems) {
+  int blockId = blockIdx.x;
+
+  assert((uintptr_t)buff % sizeof(int4) == 0);
+  assert((uintptr_t)resultBuff % sizeof(int4) == 0);
+
+  constexpr int NPeers = NRanksPerNode - 1;
+  constexpr uint32_t nelemsPerInt4 = sizeof(int4) / sizeof(T);
+  const uint32_t outputRemoteBufferOffset = NRanksPerNode - 1;
+  uint32_t alignedNelems = ((nelems + NRanksPerNode - 1) / NRanksPerNode + nelemsPerInt4 - 1) / nelemsPerInt4 *
+                           nelemsPerInt4 * NRanksPerNode;
+  uint32_t nelemsPerRank = alignedNelems / NRanksPerNode;
+  uint32_t nInt4PerRank = nelemsPerRank / nelemsPerInt4;
+  uint32_t nInt4Total = (nelems + nelemsPerInt4 - 1) / nelemsPerInt4;
+
+  int4* resultBuff4 = reinterpret_cast<int4*>((char*)resultBuff);
+  int4* buff4 = reinterpret_cast<int4*>((char*)buff);
+  DeviceHandle<BaseMemoryChannel>* memoryChannelsLocal = memoryChannels + blockId * NPeers;
+
+  uint32_t nInt4PerBlock = nInt4PerRank / gridDim.x;
+  uint32_t remainderForBlock = nInt4PerRank % gridDim.x;
+  uint32_t offset4 = blockId * nInt4PerBlock;
+  if (blockId == (int)(gridDim.x - 1)) {
+    nInt4PerBlock += remainderForBlock;
+  }
+  if (nInt4PerBlock == 0) return;
+
+  if (threadIdx.x < NPeers) {
+    memoryChannelsLocal[threadIdx.x].relaxedSignal();
+    memoryChannelsLocal[threadIdx.x].relaxedWait();
+  }
+  __syncthreads();
+  int4 data[NPeers];
+  for (uint32_t idx = threadIdx.x; idx < nInt4PerBlock; idx += blockDim.x) {
+    uint32_t offset = idx + offset4 + rank * nInt4PerRank;
+    if (offset >= nInt4Total) continue;
+    int4 tmp = buff4[offset];
+#pragma unroll
+    for (int i = 0; i < NPeers; i++) {
+      int rankIdx = (rank + i + 1) % NRanksPerNode;
+      int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1;
+      data[i] = mscclpp::read<int4>(((void**)remoteMemories)[peerIdx], offset);
+    }
+    for (int i = 0; i < NPeers; i++) {
+      tmp = cal_vector<OpType, T>(data[i], tmp);
+    }
+#pragma unroll
+    for (int i = 0; i < NPeers; i++) {
+      int rankIdx = (rank + i + 1) % NRanksPerNode;
+      int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1;
+      mscclpp::write<int4>(((void**)remoteMemories)[outputRemoteBufferOffset + peerIdx], offset, tmp);
+    }
+    resultBuff4[offset] = tmp;
+  }
+  // Using a device-wide barrier gives better performance here.
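+  // All thread blocks must retire their remote writes before any peer may be
+  // signaled that this rank's output chunks are complete, so a grid-wide sync
+  // runs first and only block 0 performs the final cross-rank signal/wait.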
+  globalSyncer.sync(gridDim.x);
+  if (blockIdx.x == 0 && threadIdx.x < NPeers) {
+    memoryChannelsLocal[threadIdx.x].signal();
+    memoryChannelsLocal[threadIdx.x].wait();
+  }
+}
+
+template <ReduceOp OpType, typename T>
+struct AllreduceRsAgZeroCopyAdapter {
+  static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void* remoteMemories,
+                          DeviceHandle<SwitchChannel>* switchChannel, DeviceHandle<SwitchChannel>*, size_t, size_t,
+                          size_t, int rank, int nRanksPerNode, int worldSize, size_t inputSize, cudaStream_t stream,
+                          void*, uint32_t, uint32_t, int nBlocks, int nThreadsPerBlock) {
+    using ChannelType = DeviceHandle<BaseMemoryChannel>;
+    size_t nelems = inputSize / sizeof(T);
+    if (nBlocks == 0 || nThreadsPerBlock == 0) {
+      nThreadsPerBlock = 1024;
+      nBlocks = 64;
+      if (inputSize >= (1 << 26)) {
+        nBlocks = 128;
+      }
+    }
+    if (nRanksPerNode == 4) {
+      allreduceRsAgZeroCopy<4, OpType, T>
+          <<<nBlocks, nThreadsPerBlock, 0, stream>>>((T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels,
+                                                     switchChannel, remoteMemories, rank, worldSize, nelems);
+    } else if (nRanksPerNode == 8) {
+      allreduceRsAgZeroCopy<8, OpType, T>
+          <<<nBlocks, nThreadsPerBlock, 0, stream>>>((T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels,
+                                                     switchChannel, remoteMemories, rank, worldSize, nelems);
+    } else {
+      THROW(ALGO, Error, ErrorCode::InvalidUsage, "Unsupported number of ranks per node: ", nRanksPerNode);
+    }
+    return cudaGetLastError();
+  }
+};
+
+void AllreduceRsAgZeroCopy::initialize(std::shared_ptr<Communicator> comm) {
+  this->conns_ = setupConnections(comm);
+  nChannelsPerConnection_ = 128;
+  comm_ = comm;
+  // setup semaphores
+  this->semaphores_ = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection_);
+  this->baseChannels_ = setupBaseMemoryChannels(this->conns_, this->semaphores_, nChannelsPerConnection_);
+  this->baseMemoryChannelHandles_ = setupBaseMemoryChannelDeviceHandles(baseChannels_);
+}
+
+CommResult AllreduceRsAgZeroCopy::allreduceKernelFunc(const std::shared_ptr<AlgorithmCtx> ctx, const void* input,
+                                                      void* output, size_t inputSize, DataType dtype, ReduceOp op,
+                                                      cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
+                                                      const std::unordered_map<std::string, uintptr_t>&) {
+  auto algoCtx = std::static_pointer_cast<AlgorithmCtx>(ctx);
+  AllreduceFunc allreduce = dispatch<AllreduceRsAgZeroCopyAdapter>(op, dtype);
+  if (!allreduce) {
+    WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast<int>(op),
+         ", dtype=", static_cast<int>(dtype));
+    return CommResult::CommInvalidArgument;
+  }
+  std::pair<int, int> numBlocksAndThreads = {nBlocks, nThreadsPerBlock};
+  cudaError_t error =
+      allreduce(input, nullptr, output, this->baseMemoryChannelHandles_.get(), algoCtx->remoteMemoryHandles.get(),
+                nullptr, nullptr, 0, 0, 0, algoCtx->rank, algoCtx->nRanksPerNode, algoCtx->workSize, inputSize, stream,
+                nullptr, 0, 0, numBlocksAndThreads.first, numBlocksAndThreads.second);
+  if (error != cudaSuccess) {
+    WARN(ALGO, "Allreduce kernel launch failed with error: ", cudaGetErrorString(error));
+    return CommResult::CommUnhandledCudaError;
+  }
+  return CommResult::CommSuccess;
+}
+
+AlgorithmCtxKey AllreduceRsAgZeroCopy::generateAllreduceContextKey(const void* inputBuffer, void* outputBuffer,
+                                                                   size_t size, DataType, bool symmetricMemory) {
+  // For non-symmetric algorithms, we use both input and output buffer pointers in the key.
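+  // A monotonically increasing tag makes every non-symmetric key unique, so
+  // each call creates a fresh context and re-registers the input/output
+  // buffers; symmetric allocations are identified by their base address
+  // ranges instead, which lets an existing context be reused.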
+ static int tag = 0; + if (symmetricMemory) { + size_t inputBytes, outputBytes; + CUdeviceptr inputBasePtr, outputBasePtr; + MSCCLPP_CUTHROW(cuMemGetAddressRange(&inputBasePtr, &inputBytes, (CUdeviceptr)inputBuffer)); + MSCCLPP_CUTHROW(cuMemGetAddressRange(&outputBasePtr, &outputBytes, (CUdeviceptr)outputBuffer)); + return AlgorithmCtxKey{(void*)inputBasePtr, (void*)outputBasePtr, inputBytes, outputBytes, 0}; + } + return AlgorithmCtxKey{(void*)inputBuffer, outputBuffer, size, size, ++tag}; +} + +std::shared_ptr AllreduceRsAgZeroCopy::initAllreduceContext(std::shared_ptr comm, const void* input, + void* output, size_t size, DataType) { + auto ctx = std::make_shared(); + ctx->rank = comm->bootstrap()->getRank(); + ctx->workSize = comm->bootstrap()->getNranks(); + ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); + + ctx->memorySemaphores = this->semaphores_; + + // register input and output memories + RegisteredMemory inputMemory = comm->registerMemory((void*)input, size, Transport::CudaIpc); + RegisteredMemory outputMemory = comm->registerMemory(output, size, Transport::CudaIpc); + this->inputMemories_.push_back(inputMemory); + this->outputMemories_.push_back(outputMemory); + + auto remoteInputMemories = setupRemoteMemories(comm, ctx->rank, inputMemory); + auto remoteOutputMemories = setupRemoteMemories(comm, ctx->rank, outputMemory); + ctx->registeredMemories.insert(ctx->registeredMemories.end(), remoteInputMemories.begin(), remoteInputMemories.end()); + ctx->registeredMemories.insert(ctx->registeredMemories.end(), remoteOutputMemories.begin(), + remoteOutputMemories.end()); + std::vector remoteMemoryHandles; + for (const auto& remoteMemory : ctx->registeredMemories) { + remoteMemoryHandles.push_back(remoteMemory.data()); + } + ctx->remoteMemoryHandles = detail::gpuCallocShared(remoteMemoryHandles.size()); + gpuMemcpy(ctx->remoteMemoryHandles.get(), remoteMemoryHandles.data(), remoteMemoryHandles.size(), + cudaMemcpyHostToDevice); + + // store local registered memories to ctx for lifetime management + ctx->registeredMemories.push_back(inputMemory); + ctx->registeredMemories.push_back(outputMemory); + return ctx; +} + +std::shared_ptr AllreduceRsAgZeroCopy::build() { + auto self = std::make_shared(); + return std::make_shared( + "default_allreduce_rsag_zero_copy", "allreduce", + [self](std::shared_ptr comm) { self->initialize(comm); }, + [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, + [[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, + int nThreadsPerBlock, const std::unordered_map& extras) -> CommResult { + return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock, + extras); + }, + [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize, + [[maybe_unused]] size_t outputSize, + DataType dtype) { return self->initAllreduceContext(comm, input, output, inputSize, dtype); }, + [self](const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, DataType dtype, + bool symmetricMemory) { + return self->generateAllreduceContextKey(input, output, inputSize, dtype, symmetricMemory); + }); +} +} // namespace collective +} // namespace mscclpp diff --git a/src/ext/collectives/include/allreduce/allreduce_rsag.hpp b/src/ext/collectives/include/allreduce/allreduce_rsag.hpp new file mode 100644 index 00000000..6e033f67 --- /dev/null +++ b/src/ext/collectives/include/allreduce/allreduce_rsag.hpp @@ -0,0 +1,43 
diff --git a/src/ext/collectives/include/allreduce/allreduce_rsag.hpp b/src/ext/collectives/include/allreduce/allreduce_rsag.hpp
new file mode 100644
index 00000000..6e033f67
--- /dev/null
+++ b/src/ext/collectives/include/allreduce/allreduce_rsag.hpp
@@ -0,0 +1,43 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#ifndef MSCCLPP_EXT_ALLREDUCE_RSAG_HPP_
+#define MSCCLPP_EXT_ALLREDUCE_RSAG_HPP_
+
+#include <collective_utils.hpp>
+
+namespace mscclpp {
+namespace collective {
+
+class AllreduceRsAg : public mscclpp::AlgorithmBuilder {
+ public:
+  AllreduceRsAg(uintptr_t scratchBuffer, size_t scratchBufferSize)
+      : scratchBuffer_((void*)scratchBuffer), scratchBufferSize_(scratchBufferSize){};
+  std::shared_ptr<Algorithm> build() override;
+
+ private:
+  void initialize(std::shared_ptr<Communicator> comm);
+  CommResult allreduceKernelFunc(const std::shared_ptr<AlgorithmCtx> ctx, const void* input, void* output, size_t inputSize,
+                                 DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
+                                 const std::unordered_map<std::string, uintptr_t>& extras);
+
+  std::shared_ptr<AlgorithmCtx> initAllreduceContext(std::shared_ptr<Communicator> comm, const void*, void* output, size_t,
+                                                     DataType);
+  AlgorithmCtxKey generateAllreduceContextKey(const void*, void*, size_t, DataType, bool);
+  void* scratchBuffer_;
+  size_t scratchBufferSize_;
+  std::shared_ptr<Communicator> comm_;
+  int nChannelsPerConnection_;
+  std::vector<std::shared_ptr<Connection>> conns_;
+  std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> scratchSemaphores_;
+  std::vector<RegisteredMemory> remoteScratchMemories_;
+  RegisteredMemory localScratchMemory_;
+
+  std::vector<BaseMemoryChannel> baseChannels_;
+  std::shared_ptr<DeviceHandle<BaseMemoryChannel>> baseMemoryChannelHandles_;
+  std::shared_ptr<void*> remoteMemoryHandles_;
+};
+}  // namespace collective
+}  // namespace mscclpp
+
+#endif  // MSCCLPP_EXT_ALLREDUCE_RSAG_HPP_
\ No newline at end of file
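For orientation, a hedged usage sketch of the builder declared above: the scratch pointer travels as a uintptr_t and must outlive the built algorithm. Only the AllreduceRsAg constructor and build() come from this header; the include path, function name, and the "default_allreduce_rsag" registration name are assumptions for illustration:

#include <allreduce/allreduce_rsag.hpp>  // include path assumed, per this directory layout

#include <cuda_runtime.h>

#include <cstddef>
#include <cstdint>
#include <memory>

std::shared_ptr<mscclpp::Algorithm> makeRsAgAlgorithm(std::size_t scratchBytes) {
  void* scratch = nullptr;
  cudaMalloc(&scratch, scratchBytes);  // must outlive the algorithm; freed elsewhere in real code
  mscclpp::collective::AllreduceRsAg builder(reinterpret_cast<uintptr_t>(scratch), scratchBytes);
  return builder.build();  // presumably registered as "default_allreduce_rsag" in the algo map
}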
diff --git a/src/ext/collectives/include/allreduce/allreduce_rsag_pipeline.hpp b/src/ext/collectives/include/allreduce/allreduce_rsag_pipeline.hpp
new file mode 100644
index 00000000..2a740ac0
--- /dev/null
+++ b/src/ext/collectives/include/allreduce/allreduce_rsag_pipeline.hpp
@@ -0,0 +1,43 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#ifndef MSCCLPP_EXT_ALLREDUCE_RSAG_PIPELINE_HPP_
+#define MSCCLPP_EXT_ALLREDUCE_RSAG_PIPELINE_HPP_
+
+#include <collective_utils.hpp>
+
+namespace mscclpp {
+namespace collective {
+
+class AllreduceRsAgPipeline : public mscclpp::AlgorithmBuilder {
+ public:
+  AllreduceRsAgPipeline(uintptr_t scratchBuffer, size_t scratchBufferSize)
+      : scratchBuffer_((void*)scratchBuffer), scratchBufferSize_(scratchBufferSize){};
+  std::shared_ptr<Algorithm> build() override;
+
+ private:
+  void initialize(std::shared_ptr<Communicator> comm);
+  CommResult allreduceKernelFunc(const std::shared_ptr<AlgorithmCtx> ctx, const void* input, void* output, size_t inputSize,
+                                 DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
+                                 const std::unordered_map<std::string, uintptr_t>& extras);
+
+  std::shared_ptr<AlgorithmCtx> initAllreduceContext(std::shared_ptr<Communicator> comm, const void*, void* output, size_t,
+                                                     DataType);
+  AlgorithmCtxKey generateAllreduceContextKey(const void*, void*, size_t, DataType, bool);
+  void* scratchBuffer_;
+  size_t scratchBufferSize_;
+  std::shared_ptr<Communicator> comm_;
+  int nChannelsPerConnection_;
+  std::vector<std::shared_ptr<Connection>> conns_;
+  std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> scratchSemaphores_;
+  std::vector<RegisteredMemory> remoteScratchMemories_;
+  RegisteredMemory localScratchMemory_;
+
+  std::vector<BaseMemoryChannel> baseChannels_;
+  std::shared_ptr<DeviceHandle<BaseMemoryChannel>> baseMemoryChannelHandles_;
+  std::shared_ptr<void*> remoteMemoryHandles_;
+};
+}  // namespace collective
+}  // namespace mscclpp
+
+#endif  // MSCCLPP_EXT_ALLREDUCE_RSAG_PIPELINE_HPP_
\ No newline at end of file
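Both RS+AG builders above keep a GPU-resident table of remote buffer pointers (remoteMemoryHandles_). As a reference for that member, here is a condensed sketch of the upload pattern used by initAllreduceContext earlier in this diff; it mirrors those gpuCallocShared/gpuMemcpy calls, with the exact header names and namespace qualification assumed rather than taken from this patch:

#include <mscclpp/core.hpp>       // RegisteredMemory (header name assumed)
#include <mscclpp/gpu_utils.hpp>  // detail::gpuCallocShared, gpuMemcpy (header name assumed)

#include <memory>
#include <vector>

// Gather each peer's raw buffer pointer and copy the table to the GPU, so a
// kernel can index it by peer rank.
std::shared_ptr<void*> uploadRemoteHandles(const std::vector<mscclpp::RegisteredMemory>& remotes) {
  std::vector<void*> host;
  host.reserve(remotes.size());
  for (const auto& m : remotes) host.push_back(m.data());  // raw pointer into the peer's buffer
  auto dev = mscclpp::detail::gpuCallocShared<void*>(host.size());
  mscclpp::gpuMemcpy(dev.get(), host.data(), host.size(), cudaMemcpyHostToDevice);
  return dev;  // kernels read dev.get()[peerRank]
}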
diff --git a/src/ext/collectives/include/allreduce/allreduce_rsag_zero_copy.hpp b/src/ext/collectives/include/allreduce/allreduce_rsag_zero_copy.hpp
new file mode 100644
index 00000000..6153a0e4
--- /dev/null
+++ b/src/ext/collectives/include/allreduce/allreduce_rsag_zero_copy.hpp
@@ -0,0 +1,39 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#ifndef MSCCLPP_EXT_ALLREDUCE_RSAG_ZERO_COPY_HPP_
+#define MSCCLPP_EXT_ALLREDUCE_RSAG_ZERO_COPY_HPP_
+
+#include <collective_utils.hpp>
+
+namespace mscclpp {
+namespace collective {
+
+class AllreduceRsAgZeroCopy : public mscclpp::AlgorithmBuilder {
+ public:
+  AllreduceRsAgZeroCopy() = default;
+  std::shared_ptr<Algorithm> build() override;
+
+ private:
+  void initialize(std::shared_ptr<Communicator> comm);
+  CommResult allreduceKernelFunc(const std::shared_ptr<AlgorithmCtx> ctx, const void* input, void* output, size_t inputSize,
+                                 DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
+                                 const std::unordered_map<std::string, uintptr_t>& extras);
+
+  std::shared_ptr<AlgorithmCtx> initAllreduceContext(std::shared_ptr<Communicator> comm, const void*, void* output, size_t,
+                                                     DataType);
+  AlgorithmCtxKey generateAllreduceContextKey(const void*, void*, size_t, DataType, bool);
+  std::shared_ptr<Communicator> comm_;
+  int nChannelsPerConnection_;
+  std::vector<std::shared_ptr<Connection>> conns_;
+  std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> semaphores_;
+  std::vector<RegisteredMemory> inputMemories_;
+  std::vector<RegisteredMemory> outputMemories_;
+
+  std::vector<BaseMemoryChannel> baseChannels_;
+  std::shared_ptr<DeviceHandle<BaseMemoryChannel>> baseMemoryChannelHandles_;
+};
+}  // namespace collective
+}  // namespace mscclpp
+
+#endif  // MSCCLPP_EXT_ALLREDUCE_RSAG_ZERO_COPY_HPP_
\ No newline at end of file
diff --git a/src/ext/collectives/include/collective_utils.hpp b/src/ext/collectives/include/collective_utils.hpp
index 74cf83fd..bff2d5c9 100644
--- a/src/ext/collectives/include/collective_utils.hpp
+++ b/src/ext/collectives/include/collective_utils.hpp
@@ -84,6 +84,7 @@ class AlgorithmCtx {
   std::shared_ptr<DeviceHandle<PortChannel>> portChannelDeviceHandles;
   std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> memorySemaphores;
   std::vector<std::shared_ptr<Host2DeviceSemaphore>> hostSemaphores;
+  std::shared_ptr<void*> remoteMemoryHandles;
   std::unordered_map<std::string, std::shared_ptr<void>> extras;
 };
 
diff --git a/src/ext/nccl/algorithm_selector.cc b/src/ext/nccl/algorithm_selector.cc
index be3c58c7..d523320e 100644
--- a/src/ext/nccl/algorithm_selector.cc
+++ b/src/ext/nccl/algorithm_selector.cc
@@ -71,7 +71,15 @@ static std::shared_ptr<Algorithm> selectSingleNodeAllreduceBlackwell(
   if (messageSize <= (1 << 21)) {  // <= 2MB
     return algoMap.at("default_allreduce_packet");
   }
-  return nullptr;
+  if (config.inCaptureMode) {
+    // CUDA graph mode: setup new connections each time (zero-copy for graph)
+    return algoMap.at("default_allreduce_rsag_zero_copy");
+  }
+  // Non-graph mode: use non-zero-copy algorithms
+  if (messageSize <= (1 << 23)) {  // <= 8MB
+    return algoMap.at("default_allreduce_rsag");
+  }
+  return algoMap.at("default_allreduce_rsag_pipeline");
 }
 
 // Symmetric memory path: can use cached memory handles
@@ -83,8 +91,7 @@ static std::shared_ptr<Algorithm> selectSingleNodeAllreduceBlackwell(
     return algoMap.at("default_allreduce_nvls");
   }
 
-  INFO(MSCCLPP_NCCL, "No suitable kernel for Blackwell architecture, fallback to nccl/rccl");
-  return nullptr;
+  return algoMap.at("default_allreduce_rsag_zero_copy");
 }
 
 std::shared_ptr<Algorithm> selectSingleNodeAllreduce(
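Net effect of the selector change, condensed into one illustrative function. This is a summary of the two hunks above, not code from the patch: the symmetric-memory branch (NVLS, then the zero-copy fallback that replaces the old nullptr "fallback to nccl/rccl") is omitted, and the real code returns algorithms from algoMap rather than strings.

#include <cstddef>
#include <string>

// Non-symmetric-memory dispatch for single-node allreduce on Blackwell
// after this change.
std::string pickBlackwellAllreduce(std::size_t bytes, bool inCaptureMode) {
  if (bytes <= (1 << 21)) return "default_allreduce_packet";     // <= 2MB
  if (inCaptureMode) return "default_allreduce_rsag_zero_copy";  // CUDA graph capture
  if (bytes <= (1 << 23)) return "default_allreduce_rsag";       // <= 8MB
  return "default_allreduce_rsag_pipeline";                      // > 8MB
}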