mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 09:17:06 +00:00
Integrate with MoE training flow
This commit is contained in:
@@ -64,7 +64,7 @@ void register_core(nb::module_& m) {
|
||||
self->recv(data, size, peer, tag);
|
||||
},
|
||||
nb::arg("data"), nb::arg("size"), nb::arg("peer"), nb::arg("tag"))
|
||||
.def("all_gather", &Bootstrap::allGather, nb::arg("allData"), nb::arg("size"))
|
||||
.def("all_gather", [](Bootstrap* self, uintptr_t ptr, int size) { void* data = reinterpret_cast<void*>(ptr); self->allGather(data, size); }, nb::arg("allData"), nb::arg("size"))
|
||||
.def("barrier", &Bootstrap::barrier)
|
||||
.def("send", static_cast<void (Bootstrap::*)(const std::vector<char>&, int, int)>(&Bootstrap::send),
|
||||
nb::arg("data"), nb::arg("peer"), nb::arg("tag"))
|
||||
|
||||
@@ -12,9 +12,17 @@ via the NativeAlgorithm framework with size-adaptive algorithm selection.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
def _parse_debug_flag(raw: str) -> bool:
    """Interpret an environment-variable string as a boolean switch.

    Integer strings keep their original meaning: zero is off, any other
    integer is on. Values that are not integers no longer raise at import
    time; "true", "yes" and "on" (case-insensitive) enable the switch and
    everything else, including the empty string, disables it.
    """
    text = raw.strip()
    try:
        return bool(int(text))
    except ValueError:
        # A malformed debug variable must not make the module unimportable.
        return text.lower() in {"true", "yes", "on"}


# Opt-in tracing for alltoallv calls, read once from DEBUG_ALL2ALL_MSG_SIZE.
_DEBUG_A2AV = _parse_debug_flag(os.environ.get("DEBUG_ALL2ALL_MSG_SIZE", "0"))


def _a2av_dbg(msg: str) -> None:
    """Print *msg* to stderr, flushed immediately, when tracing is enabled.

    Does nothing when ``_DEBUG_A2AV`` is false.
    """
    if _DEBUG_A2AV:
        print(msg, file=sys.stderr, flush=True)
|
||||
from mscclpp._mscclpp import (
|
||||
Communicator,
|
||||
TcpBootstrap,
|
||||
@@ -164,6 +172,7 @@ class MscclppAlltoAllV:
|
||||
"recvDispls": self._d_recv_displs.data_ptr(),
|
||||
"remoteRecvDispls": self._d_remote_recv_displs.data_ptr(),
|
||||
}
|
||||
self._a2av_call_count = 0
|
||||
|
||||
@property
|
||||
def rank(self) -> int:
|
||||
@@ -259,9 +268,30 @@ class MscclppAlltoAllV:
|
||||
stream = torch.cuda.current_stream()
|
||||
cuda_stream = stream.cuda_stream
|
||||
|
||||
input_size = self._cached_input_size
|
||||
output_size = self._cached_output_size
|
||||
|
||||
# Use full buffer sizes (not actual data sizes) so the C++ context
|
||||
# key (input_ptr, output_ptr, inputSize, outputSize) is always the
|
||||
# same when using persistent buffers. This ensures only ONE context
|
||||
# is ever created, avoiding bootstrap TCP on every unique size combo.
|
||||
# The kernel uses per-peer sendCounts/recvCounts for actual data bounds.
|
||||
input_size = input.numel() * elem_size
|
||||
output_size = output.numel() * elem_size
|
||||
|
||||
self._a2av_call_count += 1
|
||||
_cid = self._a2av_call_count
|
||||
|
||||
# Flush ALL GPU streams (including concurrent NCCL from async reducer)
|
||||
# so the alltoallv kernel launches on a quiet GPU.
|
||||
torch.cuda.synchronize()
|
||||
|
||||
_a2av_dbg(f"[A2AV R{self._rank}] #{_cid} pre-barrier in={input_size} out={output_size}")
|
||||
|
||||
# Barrier: ensure ALL ranks launch the alltoallv kernel simultaneously.
|
||||
# The kernel uses inter-GPU flag-based signaling that requires every
|
||||
# rank's kernel to be active at the same time.
|
||||
self._comm.bootstrap().barrier()
|
||||
|
||||
_a2av_dbg(f"[A2AV R{self._rank}] #{_cid} post-barrier, launching kernel")
|
||||
|
||||
# Execute the optimized kernel
|
||||
result = self._algo.execute(
|
||||
self._comm,
|
||||
@@ -277,7 +307,9 @@ class MscclppAlltoAllV:
|
||||
0, # nthreads_per_block (auto)
|
||||
self._extras,
|
||||
)
|
||||
|
||||
|
||||
_a2av_dbg(f"[A2AV R{self._rank}] #{_cid} kernel returned rc={result}")
|
||||
|
||||
if result != 0:
|
||||
raise RuntimeError(f"alltoallv execution failed with code {result}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user