Integrate with MoE training flow

Qinghua Zhou
2026-03-03 15:17:20 +00:00
parent ee843d445f
commit d5743e2d6c
2 changed files with 37 additions and 5 deletions


@@ -64,7 +64,7 @@ void register_core(nb::module_& m) {
             self->recv(data, size, peer, tag);
           },
           nb::arg("data"), nb::arg("size"), nb::arg("peer"), nb::arg("tag"))
-      .def("all_gather", &Bootstrap::allGather, nb::arg("allData"), nb::arg("size"))
+      .def("all_gather", [](Bootstrap* self, uintptr_t ptr, int size) { void* data = reinterpret_cast<void*>(ptr); self->allGather(data, size); }, nb::arg("allData"), nb::arg("size"))
       .def("barrier", &Bootstrap::barrier)
       .def("send", static_cast<void (Bootstrap::*)(const std::vector<char>&, int, int)>(&Bootstrap::send),
            nb::arg("data"), nb::arg("peer"), nb::arg("tag"))


@@ -12,9 +12,17 @@ via the NativeAlgorithm framework with size-adaptive algorithm selection.
"""
from __future__ import annotations
import os
import sys
import torch
import torch.distributed as dist
from typing import Optional, List, Tuple
_DEBUG_A2AV = bool(int(os.environ.get("DEBUG_ALL2ALL_MSG_SIZE", "0")))
def _a2av_dbg(msg: str):
if _DEBUG_A2AV:
print(msg, file=sys.stderr, flush=True)
from mscclpp._mscclpp import (
Communicator,
TcpBootstrap,
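
One caveat with this debug hook: _DEBUG_A2AV is snapshot once at module import, so the environment variable must be set before the module is first imported. A small sketch (the launcher invocation is an assumption):

import os

# Must happen before the module defining _DEBUG_A2AV is first imported,
# because the flag is read exactly once, at import time.
os.environ["DEBUG_ALL2ALL_MSG_SIZE"] = "1"

# Or equivalently from the launcher (assumed torchrun-style entry point):
#   DEBUG_ALL2ALL_MSG_SIZE=1 torchrun --nproc-per-node=8 train.py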
@@ -164,6 +172,7 @@ class MscclppAlltoAllV:
"recvDispls": self._d_recv_displs.data_ptr(),
"remoteRecvDispls": self._d_remote_recv_displs.data_ptr(),
}
self._a2av_call_count = 0
@property
def rank(self) -> int:
@@ -259,9 +268,30 @@ class MscclppAlltoAllV:
         stream = torch.cuda.current_stream()
         cuda_stream = stream.cuda_stream
-        input_size = self._cached_input_size
-        output_size = self._cached_output_size
+        # Use full buffer sizes (not the actual data sizes) so the C++
+        # context key (input_ptr, output_ptr, inputSize, outputSize) is
+        # always the same when using persistent buffers. This ensures only
+        # ONE context is ever created, avoiding a bootstrap TCP handshake
+        # for every unique size combination. The kernel uses the per-peer
+        # sendCounts/recvCounts to bound the actual data.
+        input_size = input.numel() * elem_size
+        output_size = output.numel() * elem_size
+
+        self._a2av_call_count += 1
+        _cid = self._a2av_call_count
+
+        # Flush ALL GPU streams (including concurrent NCCL work from the
+        # async reducer) so the alltoallv kernel launches on a quiet GPU.
+        torch.cuda.synchronize()
+        _a2av_dbg(f"[A2AV R{self._rank}] #{_cid} pre-barrier in={input_size} out={output_size}")
+
+        # Barrier: ensure ALL ranks launch the alltoallv kernel simultaneously.
+        # The kernel uses inter-GPU flag-based signaling that requires every
+        # rank's kernel to be active at the same time.
+        self._comm.bootstrap().barrier()
+        _a2av_dbg(f"[A2AV R{self._rank}] #{_cid} post-barrier, launching kernel")
 
         # Execute the optimized kernel
         result = self._algo.execute(
             self._comm,
@@ -277,7 +307,9 @@ class MscclppAlltoAllV:
             0,  # nthreads_per_block (auto)
             self._extras,
         )
+        _a2av_dbg(f"[A2AV R{self._rank}] #{_cid} kernel returned rc={result}")
 
         if result != 0:
             raise RuntimeError(f"alltoallv execution failed with code {result}")
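
The context-key comment is the heart of this change: per the comment, the C++ side caches its communication context keyed on (input_ptr, output_ptr, inputSize, outputSize), and every miss pays a bootstrap TCP exchange. A self-contained toy illustration of that caching behavior (all names hypothetical; the cache shape is assumed from the comment, not read from the C++ source):

# Toy model of a context cache keyed the way the diff's comment describes.
_context_cache = {}  # (in_ptr, out_ptr, in_size, out_size) -> context

def get_context(in_ptr: int, out_ptr: int, in_size: int, out_size: int):
    key = (in_ptr, out_ptr, in_size, out_size)
    ctx = _context_cache.get(key)
    if ctx is None:
        ctx = object()  # stand-in for the expensive bootstrap TCP setup
        _context_cache[key] = ctx
    return ctx

# Persistent buffers keyed by their FULL sizes: every call hits one context.
a = get_context(0x1000, 0x2000, 1 << 20, 1 << 20)
b = get_context(0x1000, 0x2000, 1 << 20, 1 << 20)
assert a is b and len(_context_cache) == 1

# Keying by per-call data sizes would rebuild on each new size combination.
c = get_context(0x1000, 0x2000, 12_345, 67_890)
assert c is not a and len(_context_cache) == 2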