mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-06-29 10:57:27 +00:00
Merge qinghuazhou/expert_parallel_gb200
This commit is contained in:
@@ -41,10 +41,26 @@ import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def _detect_local_world_size() -> int:
    """Return the number of GPUs per node (4 on GB200, 8 on H100/A100, etc.).

    Resolution order:
      1. `MSCCLPP_EP_LOCAL_WORLD_SIZE` env var (matches the C++ side).
      2. `LOCAL_WORLD_SIZE` (torchrun) or `OMPI_COMM_WORLD_LOCAL_SIZE` (mpirun).
      3. `torch.cuda.device_count()` on the current host.

    A variable that is unset, empty, or holds a non-positive integer is
    skipped and the next source is tried. The result is always >= 1.

    Raises:
        ValueError: if one of the variables is set to a non-integer value.
            The message names the offending variable.
    """
    for var in ("MSCCLPP_EP_LOCAL_WORLD_SIZE", "LOCAL_WORLD_SIZE", "OMPI_COMM_WORLD_LOCAL_SIZE"):
        raw = os.environ.get(var)
        if not raw:
            # Unset or empty: fall through to the next source.
            continue
        try:
            value = int(raw)
        except ValueError as err:
            # Same exception type as a bare int() failure, but say which
            # variable was malformed so the misconfiguration is diagnosable.
            raise ValueError(f"{var} must be an integer, got {raw!r}") from err
        if value > 0:
            return value
    # No usable override: use the visible device count, clamped to at least 1
    # so callers can safely use the result as a modulus.
    return max(1, torch.cuda.device_count())
|
||||
|
||||
|
||||
def init_dist():
|
||||
rank = int(os.environ["RANK"])
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", rank % 4))
|
||||
local_world_size = _detect_local_world_size()
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", rank % local_world_size))
|
||||
torch.cuda.set_device(local_rank)
|
||||
dist.init_process_group(
|
||||
backend="nccl", world_size=world_size, rank=rank, device_id=torch.device(f"cuda:{local_rank}")
|
||||
@@ -71,10 +87,10 @@ def main():
|
||||
rank, num_ranks, local_rank, group = init_dist()
|
||||
from mscclpp.ext import ep
|
||||
|
||||
NUM_MAX_NVL_PEERS = 4
|
||||
NUM_MAX_NVL_PEERS = _detect_local_world_size()
|
||||
assert (
|
||||
num_ranks % NUM_MAX_NVL_PEERS == 0 and num_ranks > NUM_MAX_NVL_PEERS
|
||||
), f"expected >1 node with 8 GPUs each, got num_ranks={num_ranks}"
|
||||
), f"expected >1 node with {NUM_MAX_NVL_PEERS} GPUs each, got num_ranks={num_ranks}"
|
||||
num_nodes = num_ranks // NUM_MAX_NVL_PEERS
|
||||
num_local_ranks = NUM_MAX_NVL_PEERS
|
||||
|
||||
@@ -124,7 +140,13 @@ def main():
|
||||
|
||||
# Buffer config for internode HT: needs num_rdma_bytes > 0. Size buffers
|
||||
# using max(hidden, bench_hidden) so the optional bench phase fits.
|
||||
cfg = ep.Config(int(os.environ.get("MSCCLPP_EP_NSM","152")), int(os.environ.get("MSCCLPP_EP_NVL_SEND","8")), int(os.environ.get("MSCCLPP_EP_NVL_RECV","256")), int(os.environ.get("MSCCLPP_EP_RDMA_SEND","16")), int(os.environ.get("MSCCLPP_EP_RDMA_RECV","128")))
|
||||
cfg = ep.Config(
|
||||
int(os.environ.get("MSCCLPP_EP_NSM", "152")),
|
||||
int(os.environ.get("MSCCLPP_EP_NVL_SEND", "8")),
|
||||
int(os.environ.get("MSCCLPP_EP_NVL_RECV", "256")),
|
||||
int(os.environ.get("MSCCLPP_EP_RDMA_SEND", "16")),
|
||||
int(os.environ.get("MSCCLPP_EP_RDMA_RECV", "128")),
|
||||
)
|
||||
_bench_on = os.environ.get("MSCCLPP_EP_BENCH", "0") == "1"
|
||||
_buf_hidden = max(hidden, int(os.environ.get("MSCCLPP_EP_BENCH_HIDDEN", "0"))) if _bench_on else hidden
|
||||
num_nvl_bytes = cfg.get_nvl_buffer_size_hint(_buf_hidden * x.element_size(), num_ranks)
|
||||
|
||||
@@ -106,9 +106,11 @@ def main():
|
||||
|
||||
# Allocate Buffer (intranode only: num_rdma_bytes=0). Size the NVL buffer
|
||||
# using max(hidden, bench_hidden) so the optional bench phase fits.
|
||||
cfg = ep.Config(int(os.environ.get("MSCCLPP_EP_NUM_SMS", "20")),
|
||||
int(os.environ.get("MSCCLPP_EP_NVL_SEND", "8")),
|
||||
int(os.environ.get("MSCCLPP_EP_NVL_RECV", "256")))
|
||||
cfg = ep.Config(
|
||||
int(os.environ.get("MSCCLPP_EP_NUM_SMS", "20")),
|
||||
int(os.environ.get("MSCCLPP_EP_NVL_SEND", "8")),
|
||||
int(os.environ.get("MSCCLPP_EP_NVL_RECV", "256")),
|
||||
)
|
||||
_bench_on = os.environ.get("MSCCLPP_EP_BENCH", "0") == "1"
|
||||
_buf_hidden = max(hidden, int(os.environ.get("MSCCLPP_EP_BENCH_HIDDEN", "0"))) if _bench_on else hidden
|
||||
num_nvl_bytes = cfg.get_nvl_buffer_size_hint(_buf_hidden * x.element_size(), num_ranks)
|
||||
@@ -304,8 +306,8 @@ def main():
|
||||
# This matches NCCL-EP's `ep_bench` convention and isolates the on-GPU
|
||||
# dispatch kernel cost from one-time setup overhead.
|
||||
_layout = _dispatch()
|
||||
_cached_rpm = _layout[5] # rank_prefix_matrix
|
||||
_cached_cpm = _layout[6] # channel_prefix_matrix
|
||||
_cached_rpm = _layout[5] # rank_prefix_matrix
|
||||
_cached_cpm = _layout[6] # channel_prefix_matrix
|
||||
_cached_n = int(_layout[0].size(0)) # num_recv_tokens on this rank
|
||||
|
||||
def _dispatch_cached():
|
||||
|
||||
Reference in New Issue
Block a user