Merge qinghuazhou/expert_parallel_gb200

This commit is contained in:
Qinghua Zhou
2026-05-20 01:56:34 +00:00
56 changed files with 1465 additions and 530 deletions

View File

@@ -41,10 +41,26 @@ import torch
import torch.distributed as dist
def _detect_local_world_size():
    """Return the number of GPUs per node (4 on GB200, 8 on H100/A100, etc.).

    Resolution order:
      1. `MSCCLPP_EP_LOCAL_WORLD_SIZE` env var (matches the C++ side).
      2. `LOCAL_WORLD_SIZE` (torchrun) or `OMPI_COMM_WORLD_LOCAL_SIZE` (mpirun).
      3. `torch.cuda.device_count()` on the current host.

    A variable that is unset, empty, or holds a non-positive integer is
    skipped and the next source in the list is consulted.

    Returns:
        The detected local world size; always at least 1.

    Raises:
        ValueError: if one of the variables is set to something that is not
            an integer. The message names the offending variable.
    """
    for var in ("MSCCLPP_EP_LOCAL_WORLD_SIZE", "LOCAL_WORLD_SIZE", "OMPI_COMM_WORLD_LOCAL_SIZE"):
        raw = os.environ.get(var)
        if not raw:
            continue
        try:
            size = int(raw)
        except ValueError:
            # Same exception type as a bare int() failure, but say which
            # variable is misconfigured so the launcher setup can be fixed.
            raise ValueError(f"{var} must be an integer, got {raw!r}") from None
        if size > 0:
            return size
    # No usable env var: fall back to the visible device count, clamped to 1
    # so callers can safely use the result as a divisor/modulus.
    return max(1, torch.cuda.device_count())
def init_dist():
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ.get("LOCAL_RANK", rank % 4))
local_world_size = _detect_local_world_size()
local_rank = int(os.environ.get("LOCAL_RANK", rank % local_world_size))
torch.cuda.set_device(local_rank)
dist.init_process_group(
backend="nccl", world_size=world_size, rank=rank, device_id=torch.device(f"cuda:{local_rank}")
@@ -71,10 +87,10 @@ def main():
rank, num_ranks, local_rank, group = init_dist()
from mscclpp.ext import ep
NUM_MAX_NVL_PEERS = 4
NUM_MAX_NVL_PEERS = _detect_local_world_size()
assert (
num_ranks % NUM_MAX_NVL_PEERS == 0 and num_ranks > NUM_MAX_NVL_PEERS
), f"expected >1 node with 8 GPUs each, got num_ranks={num_ranks}"
), f"expected >1 node with {NUM_MAX_NVL_PEERS} GPUs each, got num_ranks={num_ranks}"
num_nodes = num_ranks // NUM_MAX_NVL_PEERS
num_local_ranks = NUM_MAX_NVL_PEERS
@@ -124,7 +140,13 @@ def main():
# Buffer config for internode HT: needs num_rdma_bytes > 0. Size buffers
# using max(hidden, bench_hidden) so the optional bench phase fits.
cfg = ep.Config(int(os.environ.get("MSCCLPP_EP_NSM","152")), int(os.environ.get("MSCCLPP_EP_NVL_SEND","8")), int(os.environ.get("MSCCLPP_EP_NVL_RECV","256")), int(os.environ.get("MSCCLPP_EP_RDMA_SEND","16")), int(os.environ.get("MSCCLPP_EP_RDMA_RECV","128")))
cfg = ep.Config(
int(os.environ.get("MSCCLPP_EP_NSM", "152")),
int(os.environ.get("MSCCLPP_EP_NVL_SEND", "8")),
int(os.environ.get("MSCCLPP_EP_NVL_RECV", "256")),
int(os.environ.get("MSCCLPP_EP_RDMA_SEND", "16")),
int(os.environ.get("MSCCLPP_EP_RDMA_RECV", "128")),
)
_bench_on = os.environ.get("MSCCLPP_EP_BENCH", "0") == "1"
_buf_hidden = max(hidden, int(os.environ.get("MSCCLPP_EP_BENCH_HIDDEN", "0"))) if _bench_on else hidden
num_nvl_bytes = cfg.get_nvl_buffer_size_hint(_buf_hidden * x.element_size(), num_ranks)

View File

@@ -106,9 +106,11 @@ def main():
# Allocate Buffer (intranode only: num_rdma_bytes=0). Size the NVL buffer
# using max(hidden, bench_hidden) so the optional bench phase fits.
cfg = ep.Config(int(os.environ.get("MSCCLPP_EP_NUM_SMS", "20")),
int(os.environ.get("MSCCLPP_EP_NVL_SEND", "8")),
int(os.environ.get("MSCCLPP_EP_NVL_RECV", "256")))
cfg = ep.Config(
int(os.environ.get("MSCCLPP_EP_NUM_SMS", "20")),
int(os.environ.get("MSCCLPP_EP_NVL_SEND", "8")),
int(os.environ.get("MSCCLPP_EP_NVL_RECV", "256")),
)
_bench_on = os.environ.get("MSCCLPP_EP_BENCH", "0") == "1"
_buf_hidden = max(hidden, int(os.environ.get("MSCCLPP_EP_BENCH_HIDDEN", "0"))) if _bench_on else hidden
num_nvl_bytes = cfg.get_nvl_buffer_size_hint(_buf_hidden * x.element_size(), num_ranks)
@@ -304,8 +306,8 @@ def main():
# This matches NCCL-EP's `ep_bench` convention and isolates the on-GPU
# dispatch kernel cost from one-time setup overhead.
_layout = _dispatch()
_cached_rpm = _layout[5] # rank_prefix_matrix
_cached_cpm = _layout[6] # channel_prefix_matrix
_cached_rpm = _layout[5] # rank_prefix_matrix
_cached_cpm = _layout[6] # channel_prefix_matrix
_cached_n = int(_layout[0].size(0)) # num_recv_tokens on this rank
def _dispatch_cached():