ext/ep: apply clang-format and black to fix CI lint failures

Run `tools/lint.sh cpp` (clang-format 14) and `tools/lint.sh py`
(black) over the EP extension files added by this PR. No functional
changes; pure reformatting to satisfy the cpplint and pylint CI jobs.
This commit is contained in:
Qinghua Zhou
2026-05-06 04:12:20 +00:00
parent 01032fa167
commit e87c66a85d
18 changed files with 5521 additions and 5377 deletions

View File

@@ -46,8 +46,9 @@ def init_dist():
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ.get("LOCAL_RANK", rank % 8))
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl", world_size=world_size, rank=rank,
device_id=torch.device(f"cuda:{local_rank}"))
dist.init_process_group(
backend="nccl", world_size=world_size, rank=rank, device_id=torch.device(f"cuda:{local_rank}")
)
return rank, world_size, local_rank, dist.new_group(list(range(world_size)))
@@ -71,8 +72,9 @@ def main():
from mscclpp.ext import ep
NUM_MAX_NVL_PEERS = 8
assert num_ranks % NUM_MAX_NVL_PEERS == 0 and num_ranks > NUM_MAX_NVL_PEERS, \
f"expected >1 node with 8 GPUs each, got num_ranks={num_ranks}"
assert (
num_ranks % NUM_MAX_NVL_PEERS == 0 and num_ranks > NUM_MAX_NVL_PEERS
), f"expected >1 node with 8 GPUs each, got num_ranks={num_ranks}"
num_nodes = num_ranks // NUM_MAX_NVL_PEERS
num_local_ranks = NUM_MAX_NVL_PEERS
@@ -80,7 +82,7 @@ def main():
num_tokens = 128
hidden = 1024
num_topk = min(4, num_ranks)
num_experts = (num_ranks * 4) # multiple of num_ranks
num_experts = num_ranks * 4 # multiple of num_ranks
torch.manual_seed(0xA1B2 + rank)
@@ -125,19 +127,25 @@ def main():
num_nvl_bytes = cfg.get_nvl_buffer_size_hint(_buf_hidden * x.element_size(), num_ranks)
num_rdma_bytes = cfg.get_rdma_buffer_size_hint(_buf_hidden * x.element_size(), num_ranks)
if rank == 0:
print(f"[cfg] num_nodes={num_nodes} num_ranks={num_ranks} num_tokens={num_tokens} "
f"hidden={hidden} num_experts={num_experts} num_topk={num_topk} "
f"num_nvl_bytes={num_nvl_bytes} num_rdma_bytes={num_rdma_bytes}",
flush=True)
print(
f"[cfg] num_nodes={num_nodes} num_ranks={num_ranks} num_tokens={num_tokens} "
f"hidden={hidden} num_experts={num_experts} num_topk={num_topk} "
f"num_nvl_bytes={num_nvl_bytes} num_rdma_bytes={num_rdma_bytes}",
flush=True,
)
print(f"[rank {rank}] creating Buffer", flush=True)
buf = ep.Buffer(group, num_nvl_bytes=num_nvl_bytes, num_rdma_bytes=num_rdma_bytes, low_latency_mode=False)
print(f"[rank {rank}] Buffer created is_available={buf.is_available()} "
f"is_internode={buf.is_internode_available()}", flush=True)
print(
f"[rank {rank}] Buffer created is_available={buf.is_available()} "
f"is_internode={buf.is_internode_available()}",
flush=True,
)
assert buf.is_available() and buf.is_internode_available()
ref_rank, ref_rdma_rank, ref_exp, ref_in_rank, _ = \
buf.runtime.get_dispatch_layout(topk_idx, num_experts, None, False, False)
ref_rank, ref_rdma_rank, ref_exp, ref_in_rank, _ = buf.runtime.get_dispatch_layout(
topk_idx, num_experts, None, False, False
)
assert torch.allclose(ref_rank, num_tokens_per_rank)
assert torch.allclose(ref_rdma_rank, num_tokens_per_rdma_rank)
assert torch.allclose(ref_exp, num_tokens_per_expert)
@@ -153,17 +161,42 @@ def main():
# cached_rdma_channel_prefix_matrix=None, cached_recv_rdma_rank_prefix_sum=None,
# cached_gbl_channel_prefix_matrix=None, cached_recv_gbl_rank_prefix_sum=None,
# expert_alignment, config, previous_event, async, allocate_on_comm_stream)
(recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights,
num_recv_tokens_per_expert_list,
rdma_channel_prefix_matrix, gbl_channel_prefix_matrix,
recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum,
recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
recv_src_meta, send_rdma_head, send_nvl_head, _event) = buf.runtime.internode_dispatch(
x, None, topk_idx, topk_weights,
num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert,
0, 0,
None, None, None, None,
1, cfg, None, False, False,
(
recv_x,
recv_x_scales,
recv_topk_idx,
recv_topk_weights,
num_recv_tokens_per_expert_list,
rdma_channel_prefix_matrix,
gbl_channel_prefix_matrix,
recv_rdma_channel_prefix_matrix,
recv_rdma_rank_prefix_sum,
recv_gbl_channel_prefix_matrix,
recv_gbl_rank_prefix_sum,
recv_src_meta,
send_rdma_head,
send_nvl_head,
_event,
) = buf.runtime.internode_dispatch(
x,
None,
topk_idx,
topk_weights,
num_tokens_per_rank,
num_tokens_per_rdma_rank,
is_token_in_rank,
num_tokens_per_expert,
0,
0,
None,
None,
None,
None,
1,
cfg,
None,
False,
False,
)
dist.barrier(group=group)
@@ -176,9 +209,9 @@ def main():
if block.numel():
lo = block.float().amin().item()
hi = block.float().amax().item()
assert abs(lo - src) < 1e-3 and abs(hi - src) < 1e-3, (
f"rank{rank}: block from src={src} has range=[{lo}, {hi}], expected {src}"
)
assert (
abs(lo - src) < 1e-3 and abs(hi - src) < 1e-3
), f"rank{rank}: block from src={src} has range=[{lo}, {hi}], expected {src}"
start = end
if rank == 0:
print(f"[dispatch] OK (recv {recv_x.size(0)} tokens)", flush=True)
@@ -202,11 +235,19 @@ def main():
# (`recv_rdma_channel_prefix_matrix`, `recv_rdma_rank_prefix_sum`,
# `recv_gbl_channel_prefix_matrix`) — not the sender-side ones.
combined_x, combined_topk_weights, _ = buf.runtime.internode_combine(
recv_x, recv_topk_weights,
recv_src_meta, is_token_in_rank,
recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix,
send_rdma_head, send_nvl_head,
cfg, None, False, False,
recv_x,
recv_topk_weights,
recv_src_meta,
is_token_in_rank,
recv_rdma_channel_prefix_matrix,
recv_rdma_rank_prefix_sum,
recv_gbl_channel_prefix_matrix,
send_rdma_head,
send_nvl_head,
cfg,
None,
False,
False,
)
num_dst = is_token_in_rank.sum(dim=1).to(torch.float32)
@@ -235,19 +276,17 @@ def main():
# NCCL-EP's `ep_bench -a ht` defaults (256 experts, top-8). The functional
# check above still uses the smaller (num_experts=num_ranks*4, topk=4)
# configuration.
bench_num_experts = int(os.environ.get(
"MSCCLPP_EP_BENCH_EXPERTS", str(num_experts)))
bench_num_topk = int(os.environ.get(
"MSCCLPP_EP_BENCH_TOPK", str(num_topk)))
bench_num_experts = int(os.environ.get("MSCCLPP_EP_BENCH_EXPERTS", str(num_experts)))
bench_num_topk = int(os.environ.get("MSCCLPP_EP_BENCH_TOPK", str(num_topk)))
if bench_num_experts % num_ranks != 0:
if rank == 0:
print(f"[bench] skip: num_experts={bench_num_experts} not divisible "
f"by num_ranks={num_ranks}", flush=True)
print(
f"[bench] skip: num_experts={bench_num_experts} not divisible " f"by num_ranks={num_ranks}", flush=True
)
return
if bench_num_topk > bench_num_experts:
if rank == 0:
print(f"[bench] skip: topk={bench_num_topk} > experts={bench_num_experts}",
flush=True)
print(f"[bench] skip: topk={bench_num_topk} > experts={bench_num_experts}", flush=True)
return
# Respect the Buffer's pre-sized num_nvl_bytes / num_rdma_bytes budget.
@@ -294,20 +333,43 @@ def main():
def _dispatch():
return buf.runtime.internode_dispatch(
x_b, None, topk_idx_b, topk_weights_b,
num_tokens_per_rank_b, num_tokens_per_rdma_rank_b, is_token_in_rank_b, num_tokens_per_expert_b,
0, 0, None, None, None, None,
1, cfg, None, False, False,
x_b,
None,
topk_idx_b,
topk_weights_b,
num_tokens_per_rank_b,
num_tokens_per_rdma_rank_b,
is_token_in_rank_b,
num_tokens_per_expert_b,
0,
0,
None,
None,
None,
None,
1,
cfg,
None,
False,
False,
)
def _combine(dout):
(rx, _rxs, _rti, rtw, _lst,
_rpm, _gpm, rrcpm, rrps, rgpm, _rgps,
rsm, sh_rdma, sh_nvl, _ev) = dout
rx, _rxs, _rti, rtw, _lst, _rpm, _gpm, rrcpm, rrps, rgpm, _rgps, rsm, sh_rdma, sh_nvl, _ev = dout
buf.runtime.internode_combine(
rx, rtw, rsm, is_token_in_rank_b,
rrcpm, rrps, rgpm,
sh_rdma, sh_nvl, cfg, None, False, False,
rx,
rtw,
rsm,
is_token_in_rank_b,
rrcpm,
rrps,
rgpm,
sh_rdma,
sh_nvl,
cfg,
None,
False,
False,
)
# Warmup (full round-trip with the sync/barrier guard between phases,
@@ -369,16 +431,16 @@ def main():
num_tokens_per_rank_b.to(torch.int64),
group=group,
)
src_node = (torch.arange(num_ranks, device="cuda") // num_local_ranks)
src_node = torch.arange(num_ranks, device="cuda") // num_local_ranks
remote_mask = (src_node != local_node).to(torch.int64)
total_recv_tokens_local = int(recv_from_src.sum().item())
rdma_recv_tokens_local = int((recv_from_src * remote_mask).sum().item())
# Average per-rank token counts across ranks (matches NCCL-EP `Byte counts (per rank avg)`).
counts_t = torch.tensor(
[total_send_tokens_local, rdma_send_tokens_local,
total_recv_tokens_local, rdma_recv_tokens_local],
dtype=torch.float64, device="cuda",
[total_send_tokens_local, rdma_send_tokens_local, total_recv_tokens_local, rdma_recv_tokens_local],
dtype=torch.float64,
device="cuda",
)
dist.all_reduce(counts_t, op=dist.ReduceOp.SUM, group=group)
counts_avg = (counts_t / num_ranks).tolist()
@@ -469,6 +531,7 @@ if __name__ == "__main__":
main()
except Exception:
import traceback
traceback.print_exc()
sys.exit(1)
finally:

View File

@@ -111,9 +111,11 @@ def main():
_buf_hidden = max(hidden, int(os.environ.get("MSCCLPP_EP_BENCH_HIDDEN", "0"))) if _bench_on else hidden
num_nvl_bytes = cfg.get_nvl_buffer_size_hint(_buf_hidden * x.element_size(), num_ranks)
if rank == 0:
print(f"[cfg] num_ranks={num_ranks} num_tokens={num_tokens} hidden={hidden} "
f"num_experts={num_experts} num_topk={num_topk} num_nvl_bytes={num_nvl_bytes}",
flush=True)
print(
f"[cfg] num_ranks={num_ranks} num_tokens={num_tokens} hidden={hidden} "
f"num_experts={num_experts} num_topk={num_topk} num_nvl_bytes={num_nvl_bytes}",
flush=True,
)
print(f"[rank {rank}] creating Buffer", flush=True)
buf = ep.Buffer(group, num_nvl_bytes=num_nvl_bytes, num_rdma_bytes=0, low_latency_mode=False)
@@ -129,14 +131,34 @@ def main():
print("[layout] OK", flush=True)
# Dispatch
(recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights,
num_recv_tokens_per_expert_list,
rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx,
send_head, _event) = buf.runtime.intranode_dispatch(
x, None, topk_idx, topk_weights,
num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert,
0, None, None,
1, cfg, None, False, False,
(
recv_x,
recv_x_scales,
recv_topk_idx,
recv_topk_weights,
num_recv_tokens_per_expert_list,
rank_prefix_matrix,
channel_prefix_matrix,
recv_channel_prefix_matrix,
recv_src_idx,
send_head,
_event,
) = buf.runtime.intranode_dispatch(
x,
None,
topk_idx,
topk_weights,
num_tokens_per_rank,
is_token_in_rank,
num_tokens_per_expert,
0,
None,
None,
1,
cfg,
None,
False,
False,
)
dist.barrier(group=group)
@@ -149,9 +171,7 @@ def main():
block = recv_x[start:end]
if block.numel():
actual = block.float().amin().item()
assert abs(actual - src) < 1e-3, (
f"rank{rank}: block from src={src} has min={actual}, expected {src}"
)
assert abs(actual - src) < 1e-3, f"rank{rank}: block from src={src} has min={actual}, expected {src}"
assert abs(block.float().amax().item() - src) < 1e-3
start = end
if rank == 0:
@@ -165,9 +185,16 @@ def main():
handle_channel_prefix_matrix = recv_channel_prefix_matrix
combined_x, combined_topk_weights, _ = buf.runtime.intranode_combine(
recv_x, recv_topk_weights,
handle_recv_src_idx, handle_rank_prefix_matrix, handle_channel_prefix_matrix,
send_head, cfg, None, False, False,
recv_x,
recv_topk_weights,
handle_recv_src_idx,
handle_rank_prefix_matrix,
handle_channel_prefix_matrix,
send_head,
cfg,
None,
False,
False,
)
# Expected: we dispatched with x = rank * ones, so every destination r
@@ -201,19 +228,17 @@ def main():
# NCCL-EP's `ep_bench -a ht` defaults (256 experts, top-8). The functional
# check above still uses the smaller (num_experts=num_ranks*4, topk=4)
# configuration.
bench_num_experts = int(os.environ.get(
"MSCCLPP_EP_BENCH_EXPERTS", str(num_experts)))
bench_num_topk = int(os.environ.get(
"MSCCLPP_EP_BENCH_TOPK", str(num_topk)))
bench_num_experts = int(os.environ.get("MSCCLPP_EP_BENCH_EXPERTS", str(num_experts)))
bench_num_topk = int(os.environ.get("MSCCLPP_EP_BENCH_TOPK", str(num_topk)))
if bench_num_experts % num_ranks != 0:
if rank == 0:
print(f"[bench] skip: num_experts={bench_num_experts} not divisible "
f"by num_ranks={num_ranks}", flush=True)
print(
f"[bench] skip: num_experts={bench_num_experts} not divisible " f"by num_ranks={num_ranks}", flush=True
)
return
if bench_num_topk > bench_num_experts:
if rank == 0:
print(f"[bench] skip: topk={bench_num_topk} > experts={bench_num_experts}",
flush=True)
print(f"[bench] skip: topk={bench_num_topk} > experts={bench_num_experts}", flush=True)
return
# Rebuild inputs at bench size. Keep same layout recipe as above but at
@@ -253,15 +278,36 @@ def main():
def _dispatch():
return buf.runtime.intranode_dispatch(
x_b, None, topk_idx_b, topk_weights_b,
num_tokens_per_rank_b, is_token_in_rank_b, num_tokens_per_expert_b,
0, None, None, 1, cfg, None, False, False,
x_b,
None,
topk_idx_b,
topk_weights_b,
num_tokens_per_rank_b,
is_token_in_rank_b,
num_tokens_per_expert_b,
0,
None,
None,
1,
cfg,
None,
False,
False,
)
def _combine(dout):
(rx, _rxs, _rti, rtw, _lst, rpm, _cpm, rcpm, rsi, sh, _ev) = dout
rx, _rxs, _rti, rtw, _lst, rpm, _cpm, rcpm, rsi, sh, _ev = dout
buf.runtime.intranode_combine(
rx, rtw, rsi, rpm, rcpm, sh, cfg, None, False, False,
rx,
rtw,
rsi,
rpm,
rcpm,
sh,
cfg,
None,
False,
False,
)
# Warmup (full round-trip).
@@ -319,9 +365,9 @@ def main():
# Average per-rank token counts across ranks (matches NCCL-EP `Byte counts (per rank avg)`).
counts_t = torch.tensor(
[total_send_tokens_local, rdma_send_tokens_local,
total_recv_tokens_local, rdma_recv_tokens_local],
dtype=torch.float64, device="cuda",
[total_send_tokens_local, rdma_send_tokens_local, total_recv_tokens_local, rdma_recv_tokens_local],
dtype=torch.float64,
device="cuda",
)
dist.all_reduce(counts_t, op=dist.ReduceOp.SUM, group=group)
counts_avg = (counts_t / num_ranks).tolist()
@@ -410,6 +456,7 @@ if __name__ == "__main__":
main()
except Exception:
import traceback
traceback.print_exc()
sys.exit(1)
finally:

View File

@@ -70,7 +70,9 @@ def main():
assert num_ranks - rank_offset < 257, "too many ranks for bf16 precision anchor"
num_tokens = int(os.environ.get("MSCCLPP_EP_LL_TOKENS", "64"))
hidden = int(os.environ.get("MSCCLPP_EP_LL_HIDDEN", "7168")) # LL kernels are compiled for a fixed set; see SWITCH_HIDDEN
hidden = int(
os.environ.get("MSCCLPP_EP_LL_HIDDEN", "7168")
) # LL kernels are compiled for a fixed set; see SWITCH_HIDDEN
num_topk = int(os.environ.get("MSCCLPP_EP_LL_TOPK", "4"))
num_experts_per_rank = int(os.environ.get("MSCCLPP_EP_LL_EXPERTS_PER_RANK", "4"))
num_experts = num_ranks * num_experts_per_rank
@@ -83,9 +85,7 @@ def main():
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * (rank - rank_offset)
# Encode the per-token index into the last 128 elements so the receiver
# can verify which source token it is looking at.
x[:, -128:] = (
torch.arange(num_tokens, device="cuda").to(torch.bfloat16).view(-1, 1)
)
x[:, -128:] = torch.arange(num_tokens, device="cuda").to(torch.bfloat16).view(-1, 1)
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device="cuda").abs()
@@ -94,9 +94,7 @@ def main():
for _ in range(min(10, num_tokens)):
topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1
num_rdma_bytes = ep.Buffer.get_low_latency_rdma_size_hint(
num_tokens, hidden, num_ranks, num_experts
)
num_rdma_bytes = ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
if rank == 0:
print(
f"[cfg] num_ranks={num_ranks} num_tokens={num_tokens} hidden={hidden} "
@@ -129,12 +127,21 @@ def main():
# packed_recv_count, packed_recv_src_info, packed_recv_layout_range,
# event, hook
(
packed_recv_x, _packed_recv_x_scales,
packed_recv_count, packed_recv_src_info, packed_recv_layout_range,
_event, recv_hook,
packed_recv_x,
_packed_recv_x_scales,
packed_recv_count,
packed_recv_src_info,
packed_recv_layout_range,
_event,
recv_hook,
) = buf.low_latency_dispatch(
x, topk_idx, num_tokens, num_experts,
False, False, True, # use_fp8, async, return_recv_hook
x,
topk_idx,
num_tokens,
num_experts,
False,
False,
True, # use_fp8, async, return_recv_hook
)
# Send phase launched on compute_stream; wait for local launch.
torch.cuda.synchronize()
@@ -158,12 +165,12 @@ def main():
expected_count = int((all_topk_idx == expert_id).sum().item())
recv_layout_range = handle[1][i]
layout_sum = int((recv_layout_range & int_mask).sum().item())
assert recv_count == expected_count, (
f"rank{rank} expert{expert_id}: recv_count={recv_count} != expected={expected_count}"
)
assert layout_sum == recv_count, (
f"rank{rank} expert{expert_id}: layout range sum {layout_sum} != recv_count {recv_count}"
)
assert (
recv_count == expected_count
), f"rank{rank} expert{expert_id}: recv_count={recv_count} != expected={expected_count}"
assert (
layout_sum == recv_count
), f"rank{rank} expert{expert_id}: layout range sum {layout_sum} != recv_count {recv_count}"
if recv_count:
recv_x = packed_recv_x[i, :recv_count]
@@ -186,10 +193,16 @@ def main():
# zero_copy, async, return_recv_hook, out)
src_info, layout_range = handle[0], handle[1]
combined_x, _event, _hook = buf.low_latency_combine(
simulated_gemm_x, topk_idx, topk_weights,
src_info, layout_range,
num_tokens, num_experts,
False, False, False, # zero_copy, async, return_recv_hook
simulated_gemm_x,
topk_idx,
topk_weights,
src_info,
layout_range,
num_tokens,
num_experts,
False,
False,
False, # zero_copy, async, return_recv_hook
out,
)
@@ -230,23 +243,34 @@ def main():
num_local_experts = num_experts // num_ranks
bench_packed_recv_x = torch.empty(
(num_local_experts, num_ranks * num_tokens, hidden),
dtype=torch.bfloat16, device="cuda",
dtype=torch.bfloat16,
device="cuda",
)
bench_packed_recv_src_info = torch.empty(
(num_local_experts, num_ranks * num_tokens),
dtype=torch.int32, device="cuda",
dtype=torch.int32,
device="cuda",
)
bench_packed_recv_layout_range = torch.empty(
(num_local_experts, num_ranks), dtype=torch.int64, device="cuda",
(num_local_experts, num_ranks),
dtype=torch.int64,
device="cuda",
)
bench_packed_recv_count = torch.empty(
(num_local_experts,), dtype=torch.int32, device="cuda",
(num_local_experts,),
dtype=torch.int32,
device="cuda",
)
def _dispatch():
return buf.low_latency_dispatch(
x, topk_idx, num_tokens, num_experts,
False, False, False, # use_fp8, async, return_recv_hook
x,
topk_idx,
num_tokens,
num_experts,
False,
False,
False, # use_fp8, async, return_recv_hook
bench_packed_recv_x,
None, # x_scales (FP8 only)
bench_packed_recv_src_info,
@@ -261,12 +285,18 @@ def main():
bench_out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
def _combine(dout, out_):
(recv_x, _scales, _cnt, src_info_, layout_range_, _ev, _hk) = dout
recv_x, _scales, _cnt, src_info_, layout_range_, _ev, _hk = dout
buf.low_latency_combine(
recv_x, topk_idx, topk_weights,
src_info_, layout_range_,
num_tokens, num_experts,
False, False, False,
recv_x,
topk_idx,
topk_weights,
src_info_,
layout_range_,
num_tokens,
num_experts,
False,
False,
False,
out_,
)