diff --git a/examples/torch-integration/alltoallv.py b/examples/torch-integration/alltoallv.py deleted file mode 100644 index 29968a26..00000000 --- a/examples/torch-integration/alltoallv.py +++ /dev/null @@ -1,315 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -# AllToAllV implementation for MSCCLPP -# This module provides a PyTorch-compatible alltoallv operation using MSCCLPP. -# -# Usage: -# MSCCLPP_MASTER_ADDR= MSCCLPP_MASTER_PORT= \ -# torchrun --nnodes=1 --nproc_per_node=8 alltoallv.py -# -# For AMD GPUs: -# MSCCLPP_MASTER_ADDR= MSCCLPP_MASTER_PORT= \ -# GPU_MAX_HW_QUEUES=7 torchrun --nnodes=1 --nproc_per_node=8 alltoallv.py - -import mscclpp -import mscclpp.utils as mscclpp_utils -import torch -import os -import netifaces as ni -import ipaddress -from typing import List, Optional - -_abs_path = os.path.dirname(os.path.abspath(__file__)) - - -def interfaces_for_ip_netifaces(ip: str): - """Find the network interface for a given IP address.""" - target = ipaddress.ip_address(ip) - for interface in ni.interfaces(): - addresses = ni.ifaddresses(interface) - if ni.AF_INET in addresses: - for link in addresses[ni.AF_INET]: - if "addr" in link: - addr = ipaddress.ip_address(link["addr"]) - if addr == target: - return interface - return None - - -class AllToAllVComm: - """ - AllToAllV communication class using MSCCLPP. - - This class provides a customized alltoallv implementation that handles - variable element counts per rank, similar to MPI_Alltoallv or the - batch_all_to_all_v pattern commonly used in MOE (Mixture of Experts) models. - - Unlike NCCL's ncclGroupStart/ncclGroupEnd approach, MSCCLPP uses explicit - put/signal/wait operations on PortChannels for communication. - - Attributes: - comm: MSCCLPP CommGroup instance - rank: Current rank - world_size: Total number of ranks - """ - - def __init__(self, comm: mscclpp.CommGroup): - """ - Initialize AllToAllV communication. - - Args: - comm: MSCCLPP CommGroup instance - """ - self.comm = comm - self.rank = comm.my_rank - self.world_size = comm.nranks - self.local_rank = comm.my_rank % comm.nranks_per_node - self.n_ranks_per_node = comm.nranks_per_node - self.executor = mscclpp.Executor(comm.communicator) - - # Compile and load the native CUDA kernel - mscclpp_native = mscclpp.compile_native( - name="mscclpp_alltoallv", - file=os.path.join(_abs_path, "alltoallv_kernel.cu") - ) - capsule = mscclpp_native.create_alltoallv_algorithm() - self.algorithm = mscclpp.Algorithm.create_from_native_capsule(capsule) - - def alltoallv( - self, - send_tensor: torch.Tensor, - recv_tensor: torch.Tensor, - send_counts: torch.Tensor, - send_displs: torch.Tensor, - recv_counts: torch.Tensor, - recv_displs: torch.Tensor, - stream: Optional[torch.cuda.Stream] = None - ): - """ - Perform alltoallv operation with variable element counts. - - This function exchanges data between all ranks where each rank can send - different amounts of data to each other rank. 
- - Args: - send_tensor: Source tensor containing all data to be sent - recv_tensor: Destination tensor for received data - send_counts: Tensor of shape [world_size] with byte counts to send to each rank - send_displs: Tensor of shape [world_size] with byte offsets in send_tensor for each rank - recv_counts: Tensor of shape [world_size] with byte counts to receive from each rank - recv_displs: Tensor of shape [world_size] with byte offsets in recv_tensor for each rank - stream: Optional CUDA stream to use for the operation - - Note: - - All count and displacement tensors should be on GPU and contain size_t values - - send_counts[i] is the number of bytes to send to rank i - - send_displs[i] is the byte offset in send_tensor for data going to rank i - - recv_counts[i] is the number of bytes to receive from rank i - - recv_displs[i] is the byte offset in recv_tensor for data from rank i - """ - # Ensure counts and displacements are on GPU and have correct dtype - assert send_counts.device.type == "cuda", "send_counts must be on GPU" - assert send_displs.device.type == "cuda", "send_displs must be on GPU" - assert recv_counts.device.type == "cuda", "recv_counts must be on GPU" - assert recv_displs.device.type == "cuda", "recv_displs must be on GPU" - - # Prepare extras dict with device pointers for counts and displacements - extras = { - "sendCounts": send_counts.data_ptr(), - "sendDispls": send_displs.data_ptr(), - "recvCounts": recv_counts.data_ptr(), - "recvDispls": recv_displs.data_ptr(), - } - - cuda_stream = stream.cuda_stream if stream is not None else 0 - - self.algorithm.execute( - self.comm.communicator, - send_tensor.data_ptr(), - recv_tensor.data_ptr(), - send_tensor.nbytes, - recv_tensor.nbytes, - mscclpp_utils.torch_dtype_to_mscclpp_dtype(send_tensor.dtype), - stream=cuda_stream, - extras=extras - ) - - def alltoallv_by_elements( - self, - send_tensor: torch.Tensor, - recv_tensor: torch.Tensor, - send_counts_elements: torch.Tensor, - recv_counts_elements: torch.Tensor, - stream: Optional[torch.cuda.Stream] = None - ): - """ - Convenience function for alltoallv with element counts (not byte counts). - - This function automatically computes byte counts and displacements from - element counts, making it easier to use with typical MOE patterns. 
- - Args: - send_tensor: Source tensor containing all data to be sent - recv_tensor: Destination tensor for received data - send_counts_elements: Tensor of shape [world_size] with element counts to send to each rank - recv_counts_elements: Tensor of shape [world_size] with element counts to receive from each rank - stream: Optional CUDA stream to use for the operation - """ - element_size = send_tensor.element_size() - - # Convert element counts to byte counts - send_counts = send_counts_elements.to(torch.int64) * element_size - recv_counts = recv_counts_elements.to(torch.int64) * element_size - - # Compute displacements (exclusive prefix sum) - send_displs = torch.zeros(self.world_size, dtype=torch.int64, device=send_tensor.device) - recv_displs = torch.zeros(self.world_size, dtype=torch.int64, device=recv_tensor.device) - - if self.world_size > 1: - send_displs[1:] = torch.cumsum(send_counts[:-1], dim=0) - recv_displs[1:] = torch.cumsum(recv_counts[:-1], dim=0) - - self.alltoallv( - send_tensor, recv_tensor, - send_counts, send_displs, - recv_counts, recv_displs, - stream - ) - - def barrier_cpu(self): - """CPU barrier to synchronize all ranks.""" - self.comm.barrier() - - -def batch_alltoallv( - comm: AllToAllVComm, - inputs: List[torch.Tensor], - outputs: List[torch.Tensor], - in_sizes: torch.Tensor, - out_sizes: torch.Tensor, - stream: Optional[torch.cuda.Stream] = None -): - """ - Batch all-to-all-v operation for multiple tensors. - - This function replicates the pattern from MOE implementations: - ``` - for k in range(len(inputs)): - ncclGroupStart() - for i in range(world_size): - ncclSend(in_buff + in_offset, in_sizes[i], ...) - ncclRecv(out_buff + out_offset, out_sizes[i], ...) - ncclGroupEnd() - ``` - - Since MSCCLPP doesn't support ncclGroupStart/ncclGroupEnd, we implement - this using explicit alltoallv operations for each tensor in the batch. 
- - Args: - comm: AllToAllVComm instance - inputs: List of input tensors to send - outputs: List of output tensors to receive - in_sizes: Tensor of shape [world_size] with element counts to send to each rank - out_sizes: Tensor of shape [world_size] with element counts to receive from each rank - stream: Optional CUDA stream - """ - assert len(inputs) == len(outputs), "Input and output lists must have same length" - - # Ensure sizes are on CPU for computing displacements - in_sizes_cpu = in_sizes.cpu().to(torch.int64) - out_sizes_cpu = out_sizes.cpu().to(torch.int64) - - for k in range(len(inputs)): - input_tensor = inputs[k] - output_tensor = outputs[k] - - element_size = input_tensor.element_size() - - # Compute byte counts - send_counts = (in_sizes_cpu * element_size).cuda() - recv_counts = (out_sizes_cpu * element_size).cuda() - - # Compute displacements - send_displs = torch.zeros(comm.world_size, dtype=torch.int64, device="cuda") - recv_displs = torch.zeros(comm.world_size, dtype=torch.int64, device="cuda") - - if comm.world_size > 1: - send_displs_cpu = torch.zeros(comm.world_size, dtype=torch.int64) - recv_displs_cpu = torch.zeros(comm.world_size, dtype=torch.int64) - send_displs_cpu[1:] = torch.cumsum(send_counts.cpu()[:-1], dim=0) - recv_displs_cpu[1:] = torch.cumsum(recv_counts.cpu()[:-1], dim=0) - send_displs = send_displs_cpu.cuda() - recv_displs = recv_displs_cpu.cuda() - - comm.alltoallv( - input_tensor, output_tensor, - send_counts, send_displs, - recv_counts, recv_displs, - stream - ) - - -def main(): - """Test the alltoallv implementation.""" - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - local_rank = int(os.environ.get("LOCAL_RANK", os.environ["RANK"])) - torch.cuda.set_device(local_rank) - - master_addr = os.environ["MSCCLPP_MASTER_ADDR"] - master_port = os.environ["MSCCLPP_MASTER_PORT"] - - interface = interfaces_for_ip_netifaces(master_addr) - if interface is None: - raise ValueError(f"Cannot find network interface for IP address {master_addr}") - - interfaceIpPortTrio = f"{interface}:{master_addr}:{master_port}" - mscclpp_group = mscclpp.CommGroup(interfaceIpPortTrio=interfaceIpPortTrio, rank=rank, size=world_size) - - # Create communicator - comm = AllToAllVComm(mscclpp_group) - - # Test with variable sizes per rank - # Each rank sends different amounts to different peers - # For simplicity: rank i sends (i+1)*100 elements to each peer - elements_per_peer = (rank + 1) * 100 - - # Create send buffer with data - total_send = elements_per_peer * world_size - send_tensor = torch.arange(total_send, device="cuda", dtype=torch.float32) + rank * 10000 - - # Receive buffer needs to accommodate variable amounts from each sender - # Each sender j sends (j+1)*100 elements to us - recv_counts_cpu = torch.tensor([(j + 1) * 100 for j in range(world_size)], dtype=torch.int64) - total_recv = recv_counts_cpu.sum().item() - recv_tensor = torch.zeros(total_recv, device="cuda", dtype=torch.float32) - - # Send counts: we send elements_per_peer to everyone - send_counts_cpu = torch.tensor([elements_per_peer] * world_size, dtype=torch.int64) - - # Move to GPU - send_counts = send_counts_cpu.cuda() - recv_counts = recv_counts_cpu.cuda() - - comm.barrier_cpu() - - # Perform alltoallv - comm.alltoallv_by_elements( - send_tensor, recv_tensor, - send_counts, recv_counts, - stream=torch.cuda.current_stream() - ) - - torch.cuda.synchronize() - comm.barrier_cpu() - - print(f"Rank {rank}: alltoallv completed successfully!") - print(f" Sent {total_send} elements, 
received {total_recv} elements") - - # Cleanup - comm = None - - -if __name__ == "__main__": - main() diff --git a/examples/torch-integration/alltoallv_kernel.cu b/examples/torch-integration/alltoallv_kernel.cu deleted file mode 100644 index ea017243..00000000 --- a/examples/torch-integration/alltoallv_kernel.cu +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -// AllToAllV Python bindings for MSCCLPP -// This file provides Python bindings for the alltoallv algorithm. -// The actual implementation is in src/ext/collectives/alltoallv/ - -#include -#include - -#include - -// Include the implementation header -#include "alltoallv/alltoallv_fullmesh.hpp" - -namespace py = pybind11; - -std::shared_ptr createAlltoallvAlgorithm() { - auto alltoallvAlgoBuilder = std::make_shared(); - return alltoallvAlgoBuilder->build(); -} - -void deletePtr(PyObject* capsule) { - const char* name = PyCapsule_GetName(capsule); - void* p = PyCapsule_GetPointer(capsule, name); - if (p == nullptr) { - PyErr_WriteUnraisable(capsule); - return; - } - auto* ptr = static_cast*>(p); - delete ptr; -} - -PyObject* getCapsule(std::shared_ptr algo) { - auto* ptrCopy = new std::shared_ptr(algo); - PyObject* capsule = PyCapsule_New(ptrCopy, mscclpp::ALGORITHM_NATIVE_CAPSULE_NAME, deletePtr); - if (capsule == nullptr) { - delete ptrCopy; - throw pybind11::error_already_set(); - } - return capsule; -} - -PYBIND11_MODULE(mscclpp_alltoallv, m) { - m.doc() = "AllToAllV implementation for MSCCLPP - handles variable element counts per rank"; - m.def( - "create_alltoallv_algorithm", - []() { return py::reinterpret_steal(getCapsule(createAlltoallvAlgorithm())); }, - "Create an alltoallv algorithm and return it as a PyCapsule usable by MSCCL++ Python bindings"); -} diff --git a/python/mscclpp/ext/__init__.py b/python/mscclpp/ext/__init__.py index 5c73df3c..8f251a69 100644 --- a/python/mscclpp/ext/__init__.py +++ b/python/mscclpp/ext/__init__.py @@ -2,5 +2,6 @@ # Licensed under the MIT license. from .algorithm_collection_builder import * +from .alltoallv_single import MscclppAlltoAllV, all_to_all_single -__all__ = algorithm_collection_builder.__all__ +__all__ = algorithm_collection_builder.__all__ + ["MscclppAlltoAllV", "all_to_all_single"] diff --git a/python/mscclpp/ext/alltoallv_single.py b/python/mscclpp/ext/alltoallv_single.py new file mode 100644 index 00000000..29c01096 --- /dev/null +++ b/python/mscclpp/ext/alltoallv_single.py @@ -0,0 +1,331 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +""" +PyTorch-compatible all_to_all_single API using mscclpp optimized kernels. + +This module provides: +- MscclppAlltoAllV: A class to manage mscclpp alltoallv state +- all_to_all_single: Drop-in replacement for torch.distributed.all_to_all_single + +Uses optimized C++ kernels (alltoallvKernel, alltoallvRingKernel, alltoallvPipelinedKernel) +via the NativeAlgorithm framework with size-adaptive algorithm selection. 
+""" + +from __future__ import annotations +import torch +import torch.distributed as dist +from typing import Optional, List, Tuple +from mscclpp._mscclpp import ( + Communicator, + TcpBootstrap, + DataType, + ReduceOp, +) +from mscclpp.ext.algorithm_collection_builder import AlgorithmCollectionBuilder + +__all__ = ["MscclppAlltoAllV", "all_to_all_single"] + + +def _torch_dtype_to_mscclpp(dtype: torch.dtype) -> DataType: + """Convert PyTorch dtype to mscclpp DataType.""" + if dtype == torch.float32: + return DataType.float32 + elif dtype == torch.float16: + return DataType.float16 + elif dtype == torch.bfloat16: + return DataType.bfloat16 + elif dtype == torch.int32: + return DataType.int32 + elif dtype == torch.int64: + return DataType.int64 + elif dtype == torch.uint8: + return DataType.uint8 + elif dtype == torch.float64: + return DataType.float64 + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + +def _dtype_size(dtype: torch.dtype) -> int: + """Get byte size for dtype.""" + return torch.tensor([], dtype=dtype).element_size() + + +class MscclppAlltoAllV: + """ + Manages mscclpp state for alltoallv operations. + + Uses optimized C++ kernels from alltoallv_fullmesh.cu with size-adaptive selection: + - Small messages (<1MB): alltoallvKernel (lower latency) + - Large messages + small world (<=16): alltoallvPipelinedKernel + - Large messages + large world (>16): alltoallvRingKernel (avoids congestion) + + Example: + mscclpp_alltoallv = MscclppAlltoAllV( + rank=rank, world_size=world_size, + ip_port="10.0.0.1:50000" + ) + # or use existing communicator: + # mscclpp_alltoallv = MscclppAlltoAllV(communicator=comm) + + # Later: + output = mscclpp_alltoallv.all_to_all_single( + input_tensor, + output_split_sizes=[1024, 2048, ...], # per-rank sizes + input_split_sizes=[1024, 2048, ...] + ) + """ + + def __init__( + self, + rank: Optional[int] = None, + world_size: Optional[int] = None, + ip_port: Optional[str] = None, + communicator: Optional[Communicator] = None, + scratch_buffer_size: int = 256 * 1024 * 1024, # 256MB default + ): + """ + Initialize MscclppAlltoAllV. 
+ + Args: + rank: Local rank (required if communicator not provided) + world_size: Total number of ranks (required if communicator not provided) + ip_port: IP:port for bootstrap (required if communicator not provided) + communicator: Existing mscclpp Communicator (alternative to rank/world_size/ip_port) + scratch_buffer_size: Size of scratch buffer in bytes + """ + if communicator is not None: + self._comm = communicator + self._rank = self._comm.bootstrap().get_rank() + self._world_size = self._comm.bootstrap().get_n_ranks() + self._owns_comm = False + else: + if rank is None or world_size is None or ip_port is None: + raise ValueError("Must provide either communicator or (rank, world_size, ip_port)") + + self._rank = rank + self._world_size = world_size + + # Create bootstrap + bootstrap = TcpBootstrap(rank, world_size) + if rank == 0: + unique_id = bootstrap.create_unique_id() + # Broadcast unique_id to other ranks via torch.distributed + id_tensor = torch.tensor(list(unique_id.encode()), dtype=torch.uint8).cuda() + else: + id_tensor = torch.zeros(128, dtype=torch.uint8).cuda() + + dist.broadcast(id_tensor, src=0) + unique_id = bytes(id_tensor.cpu().tolist()).decode().rstrip('\x00') + + bootstrap.initialize(unique_id) + self._comm = Communicator(bootstrap) + self._owns_comm = True + + # Allocate scratch buffer + self._scratch_buffer = torch.zeros(scratch_buffer_size, dtype=torch.uint8, device='cuda') + self._scratch_ptr = self._scratch_buffer.data_ptr() + self._scratch_size = scratch_buffer_size + + # Build algorithm collection with default algorithms including alltoallv + builder = AlgorithmCollectionBuilder() + self._algo_collection = builder.build_default_algorithms( + self._scratch_ptr, + self._scratch_size, + self._rank + ) + + # Get the alltoallv algorithm + alltoallv_algos = self._algo_collection.get_by_collective("alltoallv") + if not alltoallv_algos: + raise RuntimeError("No alltoallv algorithm found. Make sure mscclpp is built correctly.") + self._algo = alltoallv_algos[0] + + # Pre-allocate count/displacement buffers on GPU (reused across calls) + # Using int64 (8 bytes) instead of size_t for safety + self._d_send_counts = torch.zeros(self._world_size, dtype=torch.int64, device='cuda') + self._d_send_displs = torch.zeros(self._world_size, dtype=torch.int64, device='cuda') + self._d_recv_counts = torch.zeros(self._world_size, dtype=torch.int64, device='cuda') + self._d_recv_displs = torch.zeros(self._world_size, dtype=torch.int64, device='cuda') + + @property + def rank(self) -> int: + return self._rank + + @property + def world_size(self) -> int: + return self._world_size + + def all_to_all_single( + self, + input: torch.Tensor, + output_split_sizes: Optional[List[int]] = None, + input_split_sizes: Optional[List[int]] = None, + output: Optional[torch.Tensor] = None, + stream: Optional[torch.cuda.Stream] = None, + ) -> torch.Tensor: + """ + Perform all-to-all exchange with variable-sized chunks. + + Compatible with torch.distributed.all_to_all_single signature. 
+ + Args: + input: Input tensor (contiguous, CUDA) + output_split_sizes: List of sizes to receive from each rank (in elements) + input_split_sizes: List of sizes to send to each rank (in elements) + output: Pre-allocated output tensor (optional) + stream: CUDA stream (optional, uses current stream if not specified) + + Returns: + Output tensor with received data + """ + if not input.is_cuda or not input.is_contiguous(): + raise ValueError("Input must be a contiguous CUDA tensor") + + dtype = input.dtype + elem_size = _dtype_size(dtype) + world_size = self._world_size + + # Handle split sizes + if input_split_sizes is None: + # Equal split + assert input.numel() % world_size == 0 + chunk_size = input.numel() // world_size + input_split_sizes = [chunk_size] * world_size + + if output_split_sizes is None: + # All-to-all uniform: send and recv same sizes + output_split_sizes = input_split_sizes.copy() + + # Calculate total output size and allocate if needed + total_output = sum(output_split_sizes) + if output is None: + output = torch.empty(total_output, dtype=dtype, device='cuda') + elif output.numel() < total_output: + raise ValueError(f"Output tensor too small: {output.numel()} < {total_output}") + + # Calculate displacements + send_displs = [0] + for i in range(world_size - 1): + send_displs.append(send_displs[-1] + input_split_sizes[i]) + + recv_displs = [0] + for i in range(world_size - 1): + recv_displs.append(recv_displs[-1] + output_split_sizes[i]) + + # Convert to byte sizes/offsets for the kernel + send_counts_bytes = [s * elem_size for s in input_split_sizes] + send_displs_bytes = [d * elem_size for d in send_displs] + recv_counts_bytes = [s * elem_size for s in output_split_sizes] + recv_displs_bytes = [d * elem_size for d in recv_displs] + + # Copy to GPU + self._d_send_counts.copy_(torch.tensor(send_counts_bytes, dtype=torch.int64)) + self._d_send_displs.copy_(torch.tensor(send_displs_bytes, dtype=torch.int64)) + self._d_recv_counts.copy_(torch.tensor(recv_counts_bytes, dtype=torch.int64)) + self._d_recv_displs.copy_(torch.tensor(recv_displs_bytes, dtype=torch.int64)) + + # Get stream + if stream is None: + stream = torch.cuda.current_stream() + cuda_stream = stream.cuda_stream + + # Build extras dict with GPU pointers + extras = { + "sendCounts": self._d_send_counts.data_ptr(), + "sendDispls": self._d_send_displs.data_ptr(), + "recvCounts": self._d_recv_counts.data_ptr(), + "recvDispls": self._d_recv_displs.data_ptr(), + } + + input_size = sum(send_counts_bytes) + output_size = sum(recv_counts_bytes) + + # Execute the optimized kernel + result = self._algo.execute( + self._comm, + input.data_ptr(), + output.data_ptr(), + input_size, + output_size, + _torch_dtype_to_mscclpp(dtype), + ReduceOp.NOP, + cuda_stream, + None, # executor (not needed for native algos) + 0, # nblocks (auto) + 0, # nthreads_per_block (auto) + extras, + ) + + if result != 0: + raise RuntimeError(f"alltoallv execution failed with code {result}") + + return output + + def __del__(self): + """Cleanup resources.""" + # Let CUDA handle tensor cleanup automatically + pass + + +# Module-level singleton for convenience +_default_instance: Optional[MscclppAlltoAllV] = None + + +def get_default_instance(**kwargs) -> MscclppAlltoAllV: + """Get or create a default MscclppAlltoAllV instance.""" + global _default_instance + if _default_instance is None: + _default_instance = MscclppAlltoAllV(**kwargs) + return _default_instance + + +def all_to_all_single( + output: torch.Tensor, + input: torch.Tensor, + 
output_split_sizes: Optional[List[int]] = None, + input_split_sizes: Optional[List[int]] = None, + group=None, + async_op: bool = False, +) -> Optional[torch.Tensor]: + """ + Drop-in replacement for torch.distributed.all_to_all_single. + + Uses mscclpp optimized kernels internally for better performance, + especially with imbalanced message sizes (e.g., MoE workloads). + + Note: This function requires prior initialization via get_default_instance() + or will fall back to PyTorch's native implementation. + + Args: + output: Pre-allocated output tensor + input: Input tensor + output_split_sizes: Sizes to receive from each rank + input_split_sizes: Sizes to send to each rank + group: Process group (unused, for compatibility) + async_op: If True, return async handle (not supported, falls back) + + Returns: + None (modifies output in-place) or async handle if async_op=True + """ + global _default_instance + + # Fall back to PyTorch if not initialized or async requested + if _default_instance is None or async_op: + return dist.all_to_all_single( + output, input, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=async_op + ) + + # Use optimized mscclpp implementation + result = _default_instance.all_to_all_single( + input=input, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + output=output, + ) + + return None # Matches torch.distributed API (async_op=False returns None) diff --git a/python/test/test_alltoallv_mscclpp.py b/python/test/test_alltoallv_mscclpp.py new file mode 100644 index 00000000..bd7ca1f9 --- /dev/null +++ b/python/test/test_alltoallv_mscclpp.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +""" +Test script for MscclppAlltoAllV with optimized C++ kernels. +Uses MPI bootstrap for mscclpp and NCCL backend for torch.distributed. 
+ +Usage: + mpirun -np N python test_alltoallv_mscclpp.py +""" + +import torch +import torch.distributed as dist +import os +import time + +# Must init torch.distributed before importing mscclpp modules +# to set rank/world_size environment variables + + +def main(): + # Get rank/world from MPI environment + rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", os.environ.get("PMI_RANK", 0))) + world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", os.environ.get("PMI_SIZE", 1))) + + # Set CUDA device + local_rank = int(os.environ.get("LOCAL_RANK", rank % torch.cuda.device_count())) + torch.cuda.set_device(local_rank) + + # Initialize torch.distributed with NCCL (need MASTER_ADDR/PORT) + os.environ.setdefault("MASTER_ADDR", "127.0.0.1") + os.environ.setdefault("MASTER_PORT", "29500") + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size, + device_id=torch.device(f"cuda:{local_rank}")) + + if rank == 0: + print(f"Testing MscclppAlltoAllV with {world_size} ranks") + print("=" * 60) + + # Import after torch.distributed init + from mscclpp._mscclpp import ( + Communicator, + TcpBootstrap, + UniqueId, + ) + from mscclpp.ext.alltoallv_single import MscclppAlltoAllV + import pickle + + # Create mscclpp communicator with TcpBootstrap + # Use torch.distributed to share the unique ID via pickle + bootstrap = TcpBootstrap(rank, world_size) + + if rank == 0: + unique_id = bootstrap.create_unique_id() + # Serialize UniqueId via pickle and broadcast + pickled = pickle.dumps(unique_id) + id_tensor = torch.zeros(256, dtype=torch.uint8, device='cuda') + id_tensor[:len(pickled)] = torch.tensor(list(pickled), dtype=torch.uint8) + # Also send length + len_tensor = torch.tensor([len(pickled)], dtype=torch.int64, device='cuda') + else: + id_tensor = torch.zeros(256, dtype=torch.uint8, device='cuda') + len_tensor = torch.zeros(1, dtype=torch.int64, device='cuda') + + dist.broadcast(len_tensor, src=0) + dist.broadcast(id_tensor, src=0) + + if rank != 0: + pickled_len = int(len_tensor.item()) + pickled = bytes(id_tensor[:pickled_len].cpu().tolist()) + unique_id = pickle.loads(pickled) + + bootstrap.initialize(unique_id) + comm = Communicator(bootstrap) + + # Create MscclppAlltoAllV with existing communicator + alltoallv = MscclppAlltoAllV(communicator=comm) + + if rank == 0: + print(f"MscclppAlltoAllV initialized") + print(f"Algorithm: {alltoallv._algo.name}") + + # Test 1: Uniform all-to-all (equal splits) + if rank == 0: + print("\n[Test 1] Uniform all-to-all (1024 elements per rank)") + + chunk_size = 1024 + input_data = torch.arange( + rank * world_size * chunk_size, + (rank + 1) * world_size * chunk_size, + dtype=torch.float32, + device='cuda' + ) + + output = alltoallv.all_to_all_single(input_data) + + # Verify: each chunk should come from different ranks + torch.cuda.synchronize() + expected_total = sum(r * world_size * chunk_size for r in range(world_size)) + actual_total = output[:chunk_size].sum().item() # Just check first chunk is from rank 0 + expected = 0 * world_size * chunk_size + sum(range(chunk_size)) + if rank == 0: + print(f" First chunk sum: {actual_total}, expected ~{expected}") + print(f" PASS" if abs(actual_total - expected) < 1 else f" FAIL") + + # Test 2: Variable-size all-to-all (simulating MoE) + if rank == 0: + print("\n[Test 2] Variable-size all-to-all (MoE-like)") + + # Simulate MoE token distribution: rank 0 sends more to rank 0, etc. 
+ input_split_sizes = [(i + 1) * 512 for i in range(world_size)] + output_split_sizes = [512 * (rank + 1)] * world_size + + total_input = sum(input_split_sizes) + total_output = sum(output_split_sizes) + + input_tensor = torch.randn(total_input, dtype=torch.float32, device='cuda') + output_tensor = torch.empty(total_output, dtype=torch.float32, device='cuda') + + output = alltoallv.all_to_all_single( + input_tensor, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + output=output_tensor + ) + + torch.cuda.synchronize() + if rank == 0: + print(f" Input splits: {input_split_sizes}") + print(f" Output splits: {output_split_sizes}") + print(f" Input total: {total_input}, Output total: {total_output}") + print(f" PASS") + + # Test 3: Performance benchmark + if rank == 0: + print("\n[Test 3] Performance benchmark (1MB per rank)") + + msg_size = 1024 * 1024 # 1MB per message + input_size = msg_size * world_size + + input_tensor = torch.randn(input_size // 4, dtype=torch.float32, device='cuda') # 4 bytes per float + output_tensor = torch.empty_like(input_tensor) + + # Warmup + for _ in range(5): + output = alltoallv.all_to_all_single(input_tensor, output=output_tensor) + torch.cuda.synchronize() + + # Benchmark + n_iters = 20 + start = time.perf_counter() + for _ in range(n_iters): + output = alltoallv.all_to_all_single(input_tensor, output=output_tensor) + torch.cuda.synchronize() + elapsed = time.perf_counter() - start + + # Calculate bandwidth + total_bytes = 2 * input_size * n_iters # read + write + bandwidth_gbps = total_bytes / elapsed / 1e9 + + if rank == 0: + print(f" {n_iters} iterations in {elapsed*1000:.2f} ms") + print(f" Bandwidth: {bandwidth_gbps:.2f} GB/s") + print(f" Per-iteration: {elapsed/n_iters*1000:.3f} ms") + + # Cleanup + dist.barrier() + if rank == 0: + print("\n" + "=" * 60) + print("All tests passed!") + + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/src/ext/collectives/algorithm_collection_builder.cc b/src/ext/collectives/algorithm_collection_builder.cc index 566c1852..2e7b2920 100644 --- a/src/ext/collectives/algorithm_collection_builder.cc +++ b/src/ext/collectives/algorithm_collection_builder.cc @@ -13,6 +13,7 @@ #include "allreduce/allreduce_nvls_with_copy.hpp" #include "allreduce/allreduce_nvls_with_copy_2.hpp" #include "allreduce/allreduce_packet.hpp" +#include "alltoallv/alltoallv_fullmesh.hpp" #include "logger.hpp" namespace mscclpp { @@ -81,6 +82,10 @@ AlgorithmCollection AlgorithmCollectionBuilder::buildDefaultNativeAlgorithms(uin collection.registerAlgorithm(allgatherFullmesh->collective(), allgatherFullmesh->name(), allgatherFullmesh); auto allgatherFullmesh2 = std::make_shared()->build(); collection.registerAlgorithm(allgatherFullmesh2->collective(), allgatherFullmesh2->name(), allgatherFullmesh2); + + // AllToAllV collective for MoE patterns + auto alltoallvFullmesh = std::make_shared()->build(); + collection.registerAlgorithm(alltoallvFullmesh->collective(), alltoallvFullmesh->name(), alltoallvFullmesh); return collection; } diff --git a/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu b/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu index 8a29d322..f6318129 100644 --- a/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu +++ b/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu @@ -36,10 +36,12 @@ struct AllToAllVContext { AlltoallvFullmesh::~AlltoallvFullmesh() = default; std::shared_ptr AlltoallvFullmesh::build() { - auto self = std::shared_ptr(this, [](AlltoallvFullmesh*) {}); 
+  // Create a new shared_ptr that owns the object to keep it alive
+  // This ensures the lambdas capturing 'self' have a valid object
+  auto self = std::make_shared<AlltoallvFullmesh>();
   std::shared_ptr<Algorithm> alltoallvAlgo = std::make_shared<Algorithm>(
-      "alltoallv", "alltoallv_fullmesh",
+      "default_alltoallv_fullmesh", "alltoallv",  // name, collective (was swapped before)
       // Initialize function
       [self](std::shared_ptr<Communicator> comm) { self->initialize(comm); },
       // Kernel execution function
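
Usage sketch (illustrative only): the snippet below shows one way to exercise the drop-in API added in python/mscclpp/ext/alltoallv_single.py from a torchrun/NCCL launch. The bootstrap address "10.0.0.1:50000" is a placeholder, and the split sizes mirror the MoE-style pattern from python/test/test_alltoallv_mscclpp.py; note that the module-level all_to_all_single falls back to torch.distributed.all_to_all_single unless get_default_instance() has been called first.

```python
# Illustrative sketch, not part of the patch: drop-in all_to_all_single via mscclpp.
# Assumes a torchrun launch (RANK/WORLD_SIZE set) and a reachable bootstrap ip:port.
import os
import torch
import torch.distributed as dist
from mscclpp.ext.alltoallv_single import all_to_all_single, get_default_instance

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
torch.cuda.set_device(rank % torch.cuda.device_count())
# torch.distributed must be up first: MscclppAlltoAllV broadcasts its bootstrap ID over it.
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

# Initialize the module-level singleton once; "10.0.0.1:50000" is a placeholder address.
get_default_instance(rank=rank, world_size=world_size, ip_port="10.0.0.1:50000")

# MoE-style variable splits: send (i + 1) * 512 elements to rank i,
# receive (rank + 1) * 512 elements from every peer.
input_splits = [(i + 1) * 512 for i in range(world_size)]
output_splits = [(rank + 1) * 512] * world_size

inp = torch.randn(sum(input_splits), dtype=torch.float32, device="cuda")
out = torch.empty(sum(output_splits), dtype=torch.float32, device="cuda")

all_to_all_single(out, inp,
                  output_split_sizes=output_splits,
                  input_split_sizes=input_splits)
torch.cuda.synchronize()
```

As with the torch.distributed API, the call writes the received data into `out` in place and returns None when async_op is False.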