Mirror of https://github.com/microsoft/mscclpp.git (synced 2026-05-11 17:00:22 +00:00)
1. Fix pinned buffer race condition (alltoallv_single.py):
   - The shared pinned CPU buffer was reused for 4 sequential non_blocking
     H2D copies. The GPU DMA read stale data after the CPU overwrote the
     buffer with the next field, corrupting sendCounts/recvCounts and causing
     the kernel to write to wrong addresses. Fixed by using 5 dedicated pinned
     buffers, one per field (send_counts, send_displs, recv_counts,
     recv_displs, remote_recv_displs); see the sketch after this list.
2. Remove C++ periodic reset (alltoallv_fullmesh.cu):
   - A hardcoded static-counter reset destroyed MemoryChannels and
     semaphores every 1000 kernel calls while inter-GPU signaling was
     still in progress, causing semaphore epoch mismatches and illegal
     memory accesses.
3. Fix semaphore wait (alltoallv_kernel.hpp):
   - Make wait() unconditional after signal(). Skipping wait() when
     recvCounts==0 desynced the semaphore epoch counter, so in subsequent
     calls wait() returned immediately before the peer finished writing.
4. Add memory fence (alltoallv_kernel.hpp):
   - Add __threadfence_system() after wait(), outside the primary-block
     guard, so ALL thread blocks execute it before kernel exit. This ensures
     NVLink remote writes from put() are globally visible to subsequent
     kernels on the receiving GPU.
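The pinned-buffer race in item 1 is easiest to see in isolation. Below is a minimal sketch, not the repository code: it assumes a CUDA device is available, and the buffer names and sizes are illustrative. It contrasts the old shared-staging-buffer pattern with the dedicated per-field pinned buffers now used in alltoallv_single.py.

```python
# Illustrative sketch only; names, sizes, and values are hypothetical.
import torch

world_size = 8
fields = ["send_counts", "send_displs", "recv_counts", "recv_displs", "remote_recv_displs"]
host_values = {f: list(range(world_size)) for f in fields}  # placeholder contents
device_bufs = {f: torch.zeros(world_size, dtype=torch.int64, device="cuda") for f in fields}

# Racy pattern: one shared pinned staging buffer reused for every field.
# copy_(..., non_blocking=True) only enqueues the DMA read of the pinned
# buffer; the next loop iteration can overwrite it on the CPU before the
# previous field's DMA has finished, so the GPU may read stale data.
shared_host = torch.zeros(world_size, dtype=torch.int64, pin_memory=True)
for f in fields:
    for i, v in enumerate(host_values[f]):
        shared_host[i] = v
    device_bufs[f].copy_(shared_host, non_blocking=True)

# Fixed pattern: one dedicated pinned buffer per field, so the CPU never
# overwrites memory that an in-flight DMA might still be reading.
host_bufs = {f: torch.zeros(world_size, dtype=torch.int64, pin_memory=True) for f in fields}
for f in fields:
    for i, v in enumerate(host_values[f]):
        host_bufs[f][i] = v
    device_bufs[f].copy_(host_bufs[f], non_blocking=True)
torch.cuda.current_stream().synchronize()  # or rely on later stream-ordered work
```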
alltoallv_single.py · 511 lines · 21 KiB · Python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
PyTorch-compatible all_to_all_single API using mscclpp optimized kernels.

This module provides:
- MscclppAlltoAllV: A class to manage mscclpp alltoallv state
- all_to_all_single: Drop-in replacement for torch.distributed.all_to_all_single

Uses optimized C++ kernels (alltoallvKernel, alltoallvRingKernel, alltoallvPipelinedKernel)
via the NativeAlgorithm framework with size-adaptive algorithm selection.
"""

from __future__ import annotations

import os
import sys
import torch
import torch.distributed as dist
from typing import Optional, List, Tuple

from mscclpp._mscclpp import (
    Communicator,
    TcpBootstrap,
    DataType,
    ReduceOp,
    CommResult,
)
from mscclpp.ext.algorithm_collection_builder import AlgorithmCollectionBuilder

import ctypes as _ctypes

try:
    _cudart = _ctypes.CDLL("libcudart.so")
except Exception:
    _cudart = None

_DEBUG = os.environ.get("MSCCLPP_DEBUG_ALLTOALLV", "0") == "1"
_DEBUG_A2AV = bool(int(os.environ.get("DEBUG_ALL2ALL_MSG_SIZE", "0")))

__all__ = ["MscclppAlltoAllV", "all_to_all_single"]


def _a2av_dbg(msg: str):
    if _DEBUG_A2AV:
        print(msg, file=sys.stderr, flush=True)


def _torch_dtype_to_mscclpp(dtype: torch.dtype) -> DataType:
    """Convert PyTorch dtype to mscclpp DataType."""
    if dtype == torch.float32:
        return DataType.float32
    elif dtype == torch.float16:
        return DataType.float16
    elif dtype == torch.bfloat16:
        return DataType.bfloat16
    elif dtype == torch.int32:
        return DataType.int32
    elif dtype == torch.int64:
        return DataType.int64
    elif dtype == torch.uint8:
        return DataType.uint8
    elif dtype == torch.float64:
        return DataType.float64
    else:
        raise ValueError(f"Unsupported dtype: {dtype}")


def _dtype_size(dtype: torch.dtype) -> int:
    """Get byte size for dtype."""
    return torch.tensor([], dtype=dtype).element_size()


class MscclppAlltoAllV:
    """
    Manages mscclpp state for alltoallv operations.

    Uses optimized C++ kernels from alltoallv_fullmesh.cu with size-adaptive selection:
    - Small messages (<1MB): alltoallvKernel (lower latency)
    - Large messages + small world (<=16): alltoallvPipelinedKernel
    - Large messages + large world (>16): alltoallvRingKernel (avoids congestion)

    Example:
        mscclpp_alltoallv = MscclppAlltoAllV(
            rank=rank, world_size=world_size,
            ip_port="10.0.0.1:50000"
        )
        # or use existing communicator:
        # mscclpp_alltoallv = MscclppAlltoAllV(communicator=comm)

        # Later:
        output = mscclpp_alltoallv.all_to_all_single(
            input_tensor,
            output_split_sizes=[1024, 2048, ...],  # per-rank sizes
            input_split_sizes=[1024, 2048, ...]
        )
    """

    def __init__(
        self,
        rank: Optional[int] = None,
        world_size: Optional[int] = None,
        ip_port: Optional[str] = None,
        communicator: Optional[Communicator] = None,
        scratch_buffer_size: int = 256 * 1024 * 1024,  # 256MB default
    ):
        """
        Initialize MscclppAlltoAllV.

        Args:
            rank: Local rank (required if communicator not provided)
            world_size: Total number of ranks (required if communicator not provided)
            ip_port: IP:port for bootstrap (required if communicator not provided)
            communicator: Existing mscclpp Communicator (alternative to rank/world_size/ip_port)
            scratch_buffer_size: Size of scratch buffer in bytes
        """
        if communicator is not None:
            self._comm = communicator
            self._rank = self._comm.bootstrap().get_rank()
            self._world_size = self._comm.bootstrap().get_n_ranks()
            self._owns_comm = False
        else:
            if rank is None or world_size is None or ip_port is None:
                raise ValueError("Must provide either communicator or (rank, world_size, ip_port)")

            self._rank = rank
            self._world_size = world_size

            # Create bootstrap
            bootstrap = TcpBootstrap(rank, world_size)
            if rank == 0:
                unique_id = bootstrap.create_unique_id()
                # Broadcast unique_id to other ranks via torch.distributed
                id_tensor = torch.tensor(list(unique_id.encode()), dtype=torch.uint8).cuda()
            else:
                id_tensor = torch.zeros(128, dtype=torch.uint8).cuda()

            dist.broadcast(id_tensor, src=0)
            unique_id = bytes(id_tensor.cpu().tolist()).decode().rstrip('\x00')

            bootstrap.initialize(unique_id)
            self._comm = Communicator(bootstrap)
            self._owns_comm = True

        # Allocate scratch buffer
        self._scratch_buffer = torch.zeros(scratch_buffer_size, dtype=torch.uint8, device='cuda')
        self._scratch_ptr = self._scratch_buffer.data_ptr()
        self._scratch_size = scratch_buffer_size

        # Build algorithm collection with default algorithms including alltoallv
        builder = AlgorithmCollectionBuilder()
        self._algo_collection = builder.build_default_algorithms(
            self._scratch_ptr,
            self._scratch_size,
            self._rank
        )

        # Get the alltoallv algorithm
        alltoallv_algos = self._algo_collection.get_by_collective("alltoallv")
        if not alltoallv_algos:
            raise RuntimeError("No alltoallv algorithm found. Make sure mscclpp is built correctly.")
        self._algo = alltoallv_algos[0]

        # Pre-allocate count/displacement buffers on GPU (reused across calls)
        # Using int64 (8 bytes) instead of size_t for safety
        self._d_send_counts = torch.zeros(self._world_size, dtype=torch.int64, device='cuda')
        self._d_send_displs = torch.zeros(self._world_size, dtype=torch.int64, device='cuda')
        self._d_recv_counts = torch.zeros(self._world_size, dtype=torch.int64, device='cuda')
        self._d_recv_displs = torch.zeros(self._world_size, dtype=torch.int64, device='cuda')
        self._d_remote_recv_displs = torch.zeros(self._world_size, dtype=torch.int64, device='cuda')

        # Cache for split sizes to avoid redundant bootstrap exchanges and GPU copies.
        # Key: (tuple(send_counts_bytes), tuple(recv_counts_bytes))
        self._cached_splits_key = None
        self._cached_input_size = 0
        self._cached_output_size = 0
        self._cached_total_output_elems = 0
        self._cached_dtype = None
        # One-time check for untyped_storage (available since PyTorch 1.13)
        self._has_untyped_storage = hasattr(torch.Tensor, 'untyped_storage')
        # Record a CUDA event after execute() to order later work on-stream;
        # unlike torch.cuda.synchronize(), which stalls the host (+20GB OOM),
        # the event keeps the call asynchronous.
        self._exec_event = torch.cuda.Event()
        # Pre-built extras dict (GPU pointers don't change)
        self._extras = {
            "sendCounts": self._d_send_counts.data_ptr(),
            "sendDispls": self._d_send_displs.data_ptr(),
            "recvCounts": self._d_recv_counts.data_ptr(),
            "recvDispls": self._d_recv_displs.data_ptr(),
            "remoteRecvDispls": self._d_remote_recv_displs.data_ptr(),
        }
        self._a2av_call_count = 0

    @property
    def rank(self) -> int:
        return self._rank

    @property
    def world_size(self) -> int:
        return self._world_size

    def all_to_all_single(
        self,
        input: torch.Tensor,
        output_split_sizes: Optional[List[int]] = None,
        input_split_sizes: Optional[List[int]] = None,
        output: Optional[torch.Tensor] = None,
        stream: Optional[torch.cuda.Stream] = None,
    ) -> torch.Tensor:
        """
        Perform all-to-all exchange with variable-sized chunks.

        Compatible with torch.distributed.all_to_all_single signature.

        Args:
            input: Input tensor (contiguous, CUDA)
            output_split_sizes: List of sizes to receive from each rank (in elements)
            input_split_sizes: List of sizes to send to each rank (in elements)
            output: Pre-allocated output tensor (optional)
            stream: CUDA stream (optional, uses current stream if not specified)

        Returns:
            Output tensor with received data
        """
        if not input.is_cuda or not input.is_contiguous():
            raise ValueError("Input must be a contiguous CUDA tensor")

        dtype = input.dtype
        elem_size = _dtype_size(dtype)
        world_size = self._world_size

        # Handle split sizes
        if input_split_sizes is None:
            # Equal split
            assert input.numel() % world_size == 0
            chunk_size = input.numel() // world_size
            input_split_sizes = [chunk_size] * world_size

        if output_split_sizes is None:
            # All-to-all uniform: send and recv same sizes
            output_split_sizes = input_split_sizes.copy()

        # Calculate total output size and allocate if needed
        total_output = sum(output_split_sizes)
        if output is None:
            output = torch.empty(total_output, dtype=dtype, device='cuda')
        elif output.numel() < total_output:
            raise ValueError(f"Output tensor too small: {output.numel()} < {total_output}")

        # Calculate displacements
        send_displs = [0]
        for i in range(world_size - 1):
            send_displs.append(send_displs[-1] + input_split_sizes[i])

        recv_displs = [0]
        for i in range(world_size - 1):
            recv_displs.append(recv_displs[-1] + output_split_sizes[i])

        # Convert to byte sizes/offsets for the kernel
        send_counts_bytes = [s * elem_size for s in input_split_sizes]
        send_displs_bytes = [d * elem_size for d in send_displs]
        recv_counts_bytes = [s * elem_size for s in output_split_sizes]
        recv_displs_bytes = [d * elem_size for d in recv_displs]

        # Fast path: skip GPU copies + bootstrap exchange if split sizes unchanged
        splits_key = (tuple(send_counts_bytes), tuple(recv_counts_bytes))
        if splits_key != self._cached_splits_key:
            if _DEBUG:
                print(f" [rank {self._rank}] alltoallv: splits changed, doing bootstrap exchange", flush=True)
            # NOTE: Do NOT call self._algo.reset() here.
            # With persistent fixed-size buffers, the C++ context key is stable
            # (same ptr + same untyped_storage size). The illegal memory access
            # bug was caused by the shared pinned buffer race (now fixed with
            # 5 separate pinned buffers), NOT by stale contexts.
            # Calling reset() on every split change causes ~20 GiB memory growth
            # on GPU0 over 60k+ calls due to CudaIpc driver resource leaks.

            # Copy counts/displacements to GPU using separate pinned CPU buffers.
            # Each field has its own buffer so non_blocking DMA won't race with
            # CPU overwrites (the old 2-buffer approach aliased send/recv).
            if not hasattr(self, '_h_send_counts'):
                ws = self._world_size
                self._h_send_counts = torch.zeros(ws, dtype=torch.int64, pin_memory=True)
                self._h_send_displs = torch.zeros(ws, dtype=torch.int64, pin_memory=True)
                self._h_recv_counts = torch.zeros(ws, dtype=torch.int64, pin_memory=True)
                self._h_recv_displs = torch.zeros(ws, dtype=torch.int64, pin_memory=True)
                self._h_remote_displs = torch.zeros(ws, dtype=torch.int64, pin_memory=True)
            # Write directly to pinned buffers — no torch.tensor() temporaries.
            # Each torch.tensor() call creates a temporary CPU tensor that
            # accumulates PyTorch allocator overhead over 60k+ split changes.
            for _i in range(self._world_size):
                self._h_send_counts[_i] = send_counts_bytes[_i]
                self._h_send_displs[_i] = send_displs_bytes[_i]
                self._h_recv_counts[_i] = recv_counts_bytes[_i]
                self._h_recv_displs[_i] = recv_displs_bytes[_i]
            self._d_send_counts.copy_(self._h_send_counts, non_blocking=True)
            self._d_send_displs.copy_(self._h_send_displs, non_blocking=True)
            self._d_recv_counts.copy_(self._h_recv_counts, non_blocking=True)
            self._d_recv_displs.copy_(self._h_recv_displs, non_blocking=True)

            # Exchange recv displacements with peers via bootstrap
            if _DEBUG:
                print(f" [rank {self._rank}] alltoallv: starting _exchange_recv_displs", flush=True)
            remote_recv_displs = self._exchange_recv_displs(recv_displs_bytes)
            if _DEBUG:
                print(f" [rank {self._rank}] alltoallv: _exchange_recv_displs done", flush=True)
            for _i in range(self._world_size):
                self._h_remote_displs[_i] = remote_recv_displs[_i]
            self._d_remote_recv_displs.copy_(self._h_remote_displs, non_blocking=True)

            # Cache for subsequent calls
            self._cached_splits_key = splits_key
            self._cached_input_size = sum(send_counts_bytes)
            self._cached_output_size = sum(recv_counts_bytes)

            # Barrier: all ranks must finish the displacement exchange before any
            # rank enters algo.execute(), which on the first call does its own
            # bootstrap operations (comm->connect, setupRemoteMemories).
            # Without this barrier, fast ranks' bootstrap messages from
            # initialize() can collide with slow ranks still in _exchange_recv_displs.
            if _DEBUG:
                print(f" [rank {self._rank}] alltoallv: waiting on bootstrap barrier", flush=True)
            self._comm.bootstrap().barrier()
            if _DEBUG:
                print(f" [rank {self._rank}] alltoallv: bootstrap barrier done", flush=True)

        # Get stream
        if stream is None:
            stream = torch.cuda.current_stream()
        cuda_stream = stream.cuda_stream

        # Use the full underlying storage size for context key stability.
        # When the test reuses the same large tensor with different split sizes,
        # storage size stays constant → same context key → reuses channels.
        if self._has_untyped_storage:
            input_alloc_size = input.untyped_storage().size()
            output_alloc_size = output.untyped_storage().size()
        else:
            input_alloc_size = input.nelement() * input.element_size()
            output_alloc_size = output.nelement() * output.element_size()

        self._a2av_call_count += 1
        _cid = self._a2av_call_count
        # NOTE: Pre-execute sync removed to reduce peak GPU memory pressure.
        # The post-execute sync is sufficient for correctness. With two syncs
        # per call, the PyTorch caching allocator could not overlap memory
        # reclamation with the collective, which added ~20GB of extra peak usage.

_a2av_dbg(f"[A2AV R{self._rank}] #{_cid} pre-barrier in={input_alloc_size} out={output_alloc_size}")
|
|
|
|
# No per-call barrier: the kernel's semaphore wait() blocks on-GPU
|
|
# until the peer signals. A host-side TCP barrier stalls the pipeline
|
|
# and causes ~20GB peak memory overhead vs NCCL's async model.
|
|
|
|
_a2av_dbg(f"[A2AV R{self._rank}] #{_cid} post-barrier, launching kernel")
|
|
|
|
# Execute the optimized kernel
|
|
|
|
if _DEBUG:
|
|
# Clear stale CUDA errors (the C++ code checks cudaGetLastError
|
|
# after the kernel and returns INTERNAL_ERROR if any was pending).
|
|
if _cudart is not None:
|
|
_last_err = _cudart.cudaGetLastError()
|
|
if _last_err != 0:
|
|
print(f" [rank {self._rank}] WARNING: cleared stale CUDA error code {_last_err} before execute", flush=True)
|
|
print(f" [rank {self._rank}] alltoallv: calling algo.execute(input_alloc={input_alloc_size}, output_alloc={output_alloc_size})", flush=True)
|
|
|
|
result = self._algo.execute(
|
|
self._comm,
|
|
input.data_ptr(),
|
|
output.data_ptr(),
|
|
input_alloc_size,
|
|
output_alloc_size,
|
|
_torch_dtype_to_mscclpp(dtype),
|
|
ReduceOp.NOP,
|
|
cuda_stream,
|
|
None, # executor (not needed for native algos)
|
|
0, # nblocks (auto)
|
|
0, # nthreads_per_block (auto)
|
|
self._extras,
|
|
)
|
|
|
|
if _DEBUG:
|
|
print(f" [rank {self._rank}] alltoallv: algo.execute returned {result}", flush=True)
|
|
|
|
self._exec_event.record()
|
|
torch.cuda.current_stream().wait_event(self._exec_event)
|
|
|
|
|
|
if result != CommResult.COMM_SUCCESS:
|
|
# Get detailed CUDA error before raising
|
|
try:
|
|
torch.cuda.synchronize()
|
|
except Exception as cuda_err:
|
|
raise RuntimeError(f"alltoallv execution failed with code {result}; CUDA error: {cuda_err}")
|
|
raise RuntimeError(f"alltoallv execution failed with code {result}")
|
|
|
|
# Lightweight async error probe: check for CUDA errors accumulated
|
|
# from the kernel we just launched (no full synchronize — just peek).
|
|
# This catches illegal-address faults at the source call instead of
|
|
# letting them propagate to the next unrelated CUDA API call.
|
|
err = torch.cuda.last_status() if hasattr(torch.cuda, 'last_status') else None
|
|
if err is not None and err != 0:
|
|
torch.cuda.synchronize() # force the full error to surface
|
|
raise RuntimeError(
|
|
f"[alltoallv #{_cid}] CUDA error detected after execute: "
|
|
f"err={err}, send_counts={send_counts_bytes}, "
|
|
f"recv_counts={recv_counts_bytes}"
|
|
)
|
|
|
|
return output
|
|
|
|
def _exchange_recv_displs(self, recv_displs_bytes: list) -> list:
|
|
"""
|
|
Exchange recv displacement arrays between all ranks via bootstrap allGather.
|
|
|
|
Each rank needs to know where to write in each peer's output buffer.
|
|
remoteRecvDispls[peer] = peer's recvDispls[rank] — the byte offset in
|
|
peer's output buffer where data from this rank should be placed.
|
|
|
|
Uses bootstrap.all_gather() (ring sockets, pre-established during
|
|
initialize()) instead of pairwise TCP send/recv to avoid deadlocks.
|
|
|
|
Args:
|
|
recv_displs_bytes: This rank's recv displacement array (in bytes)
|
|
|
|
Returns:
|
|
List of remote recv displacements (one per rank, in bytes).
|
|
remoteRecvDispls[rank] == recv_displs_bytes[rank] (self, unused by kernel)
|
|
"""
|
|
import numpy as np
|
|
rank = self._rank
|
|
world_size = self._world_size
|
|
bootstrap = self._comm.bootstrap()
|
|
|
|
# All-gather: each rank contributes world_size int64 values
|
|
all_data = np.zeros((world_size, world_size), dtype=np.int64)
|
|
all_data[rank, :] = recv_displs_bytes
|
|
per_rank_bytes = world_size * 8 # world_size x sizeof(int64)
|
|
bootstrap.all_gather(all_data.ctypes.data, per_rank_bytes)
|
|
|
|
# remoteRecvDispls[peer] = peer's recv_displs[rank]
|
|
remote_recv_displs = [int(all_data[peer, rank]) for peer in range(world_size)]
|
|
return remote_recv_displs
|
|
|
|
def __del__(self):
|
|
"""Cleanup resources."""
|
|
# Let CUDA handle tensor cleanup automatically
|
|
pass
|
|
|
|
|
|
# Module-level singleton for convenience
|
|
_default_instance: Optional[MscclppAlltoAllV] = None
|
|
|
|
|
|
def get_default_instance(**kwargs) -> MscclppAlltoAllV:
|
|
"""Get or create a default MscclppAlltoAllV instance."""
|
|
global _default_instance
|
|
if _default_instance is None:
|
|
_default_instance = MscclppAlltoAllV(**kwargs)
|
|
return _default_instance
|
|
|
|
|
|
def all_to_all_single(
|
|
output: torch.Tensor,
|
|
input: torch.Tensor,
|
|
output_split_sizes: Optional[List[int]] = None,
|
|
input_split_sizes: Optional[List[int]] = None,
|
|
group=None,
|
|
async_op: bool = False,
|
|
) -> Optional[torch.Tensor]:
|
|
"""
|
|
Drop-in replacement for torch.distributed.all_to_all_single.
|
|
|
|
Uses mscclpp optimized kernels internally for better performance,
|
|
especially with imbalanced message sizes (e.g., MoE workloads).
|
|
|
|
Note: This function requires prior initialization via get_default_instance()
|
|
or will fall back to PyTorch's native implementation.
|
|
|
|
Args:
|
|
output: Pre-allocated output tensor
|
|
input: Input tensor
|
|
output_split_sizes: Sizes to receive from each rank
|
|
input_split_sizes: Sizes to send to each rank
|
|
group: Process group (unused, for compatibility)
|
|
async_op: If True, return async handle (not supported, falls back)
|
|
|
|
Returns:
|
|
None (modifies output in-place) or async handle if async_op=True
|
|
"""
|
|
global _default_instance
|
|
|
|
# Fall back to PyTorch if not initialized or async requested
|
|
if _default_instance is None or async_op:
|
|
return dist.all_to_all_single(
|
|
output, input,
|
|
output_split_sizes=output_split_sizes,
|
|
input_split_sizes=input_split_sizes,
|
|
group=group,
|
|
async_op=async_op
|
|
)
|
|
|
|
# Use optimized mscclpp implementation
|
|
result = _default_instance.all_to_all_single(
|
|
input=input,
|
|
output_split_sizes=output_split_sizes,
|
|
input_split_sizes=input_split_sizes,
|
|
output=output,
|
|
)
|
|
|
|
return None # Matches torch.distributed API (async_op=False returns None)
|
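Usage note: a minimal driver sketch for the drop-in API, not from the repository. It assumes torch.distributed is initialized with the NCCL backend (e.g., launched via torchrun), that mscclpp was built with the alltoallv algorithms, and that this module is importable as mscclpp.ext.alltoallv_single; the rank, device, and ip_port values are illustrative.

```python
# Hypothetical driver; module path and the 10.0.0.1:50000 endpoint are assumptions.
import torch
import torch.distributed as dist
from mscclpp.ext.alltoallv_single import get_default_instance, all_to_all_single

dist.init_process_group("nccl")  # run under torchrun so env vars are set
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())

# One-time setup; later calls to all_to_all_single() reuse this instance.
get_default_instance(rank=rank, world_size=world_size, ip_port="10.0.0.1:50000")

input_split_sizes = [1024] * world_size   # elements sent to each rank
output_split_sizes = [1024] * world_size  # elements received from each rank
inp = torch.randn(sum(input_split_sizes), device="cuda")
out = torch.empty(sum(output_split_sizes), device="cuda")

# Same call shape as torch.distributed.all_to_all_single; returns None.
all_to_all_single(out, inp, output_split_sizes, input_split_sizes)
```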