ext/ep: apply clang-format and black to fix CI lint failures
Run `tools/lint.sh cpp` (clang-format 14) and `tools/lint.sh py` (black) over the EP extension files added by this PR. No functional changes; pure reformatting to satisfy the cpplint and pylint CI jobs.
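For the Python files, the reflow is reproducible with black alone. A minimal sketch of the rule at work, using one pre-format statement from the diff below (the 120-column line length is an assumption inferred from the reflowed lines, not confirmed; the repo's actual setting lives in the black config that `tools/lint.sh py` wraps):

    # Sketch: run one pre-format statement from this diff through black's API.
    # Assumes `pip install black`; line_length=120 is inferred, not confirmed.
    import black

    src = (
        'dist.init_process_group(backend="nccl", world_size=world_size, rank=rank,\n'
        '                        device_id=torch.device(f"cuda:{local_rank}"))\n'
    )
    # At 121 columns the joined call no longer fits, so black splits it at the
    # outer parentheses, matching the reformatted version in the diff.
    print(black.format_str(src, mode=black.Mode(line_length=120)))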
@@ -46,8 +46,9 @@ def init_dist():
     world_size = int(os.environ["WORLD_SIZE"])
     local_rank = int(os.environ.get("LOCAL_RANK", rank % 8))
     torch.cuda.set_device(local_rank)
-    dist.init_process_group(backend="nccl", world_size=world_size, rank=rank,
-                            device_id=torch.device(f"cuda:{local_rank}"))
+    dist.init_process_group(
+        backend="nccl", world_size=world_size, rank=rank, device_id=torch.device(f"cuda:{local_rank}")
+    )
     return rank, world_size, local_rank, dist.new_group(list(range(world_size)))


@@ -71,8 +72,9 @@ def main():
     from mscclpp.ext import ep

     NUM_MAX_NVL_PEERS = 8
-    assert num_ranks % NUM_MAX_NVL_PEERS == 0 and num_ranks > NUM_MAX_NVL_PEERS, \
-        f"expected >1 node with 8 GPUs each, got num_ranks={num_ranks}"
+    assert (
+        num_ranks % NUM_MAX_NVL_PEERS == 0 and num_ranks > NUM_MAX_NVL_PEERS
+    ), f"expected >1 node with 8 GPUs each, got num_ranks={num_ranks}"
     num_nodes = num_ranks // NUM_MAX_NVL_PEERS
     num_local_ranks = NUM_MAX_NVL_PEERS

@@ -80,7 +82,7 @@ def main():
     num_tokens = 128
     hidden = 1024
     num_topk = min(4, num_ranks)
-    num_experts = (num_ranks * 4)  # multiple of num_ranks
+    num_experts = num_ranks * 4  # multiple of num_ranks

     torch.manual_seed(0xA1B2 + rank)

@@ -125,19 +127,25 @@ def main():
     num_nvl_bytes = cfg.get_nvl_buffer_size_hint(_buf_hidden * x.element_size(), num_ranks)
     num_rdma_bytes = cfg.get_rdma_buffer_size_hint(_buf_hidden * x.element_size(), num_ranks)
     if rank == 0:
-        print(f"[cfg] num_nodes={num_nodes} num_ranks={num_ranks} num_tokens={num_tokens} "
-              f"hidden={hidden} num_experts={num_experts} num_topk={num_topk} "
-              f"num_nvl_bytes={num_nvl_bytes} num_rdma_bytes={num_rdma_bytes}",
-              flush=True)
+        print(
+            f"[cfg] num_nodes={num_nodes} num_ranks={num_ranks} num_tokens={num_tokens} "
+            f"hidden={hidden} num_experts={num_experts} num_topk={num_topk} "
+            f"num_nvl_bytes={num_nvl_bytes} num_rdma_bytes={num_rdma_bytes}",
+            flush=True,
+        )

     print(f"[rank {rank}] creating Buffer", flush=True)
     buf = ep.Buffer(group, num_nvl_bytes=num_nvl_bytes, num_rdma_bytes=num_rdma_bytes, low_latency_mode=False)
-    print(f"[rank {rank}] Buffer created is_available={buf.is_available()} "
-          f"is_internode={buf.is_internode_available()}", flush=True)
+    print(
+        f"[rank {rank}] Buffer created is_available={buf.is_available()} "
+        f"is_internode={buf.is_internode_available()}",
+        flush=True,
+    )
     assert buf.is_available() and buf.is_internode_available()

-    ref_rank, ref_rdma_rank, ref_exp, ref_in_rank, _ = \
-        buf.runtime.get_dispatch_layout(topk_idx, num_experts, None, False, False)
+    ref_rank, ref_rdma_rank, ref_exp, ref_in_rank, _ = buf.runtime.get_dispatch_layout(
+        topk_idx, num_experts, None, False, False
+    )
     assert torch.allclose(ref_rank, num_tokens_per_rank)
     assert torch.allclose(ref_rdma_rank, num_tokens_per_rdma_rank)
     assert torch.allclose(ref_exp, num_tokens_per_expert)
@@ -153,17 +161,42 @@ def main():
     # cached_rdma_channel_prefix_matrix=None, cached_recv_rdma_rank_prefix_sum=None,
     # cached_gbl_channel_prefix_matrix=None, cached_recv_gbl_rank_prefix_sum=None,
     # expert_alignment, config, previous_event, async, allocate_on_comm_stream)
-    (recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights,
-     num_recv_tokens_per_expert_list,
-     rdma_channel_prefix_matrix, gbl_channel_prefix_matrix,
-     recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum,
-     recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
-     recv_src_meta, send_rdma_head, send_nvl_head, _event) = buf.runtime.internode_dispatch(
-        x, None, topk_idx, topk_weights,
-        num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert,
-        0, 0,
-        None, None, None, None,
-        1, cfg, None, False, False,
+    (
+        recv_x,
+        recv_x_scales,
+        recv_topk_idx,
+        recv_topk_weights,
+        num_recv_tokens_per_expert_list,
+        rdma_channel_prefix_matrix,
+        gbl_channel_prefix_matrix,
+        recv_rdma_channel_prefix_matrix,
+        recv_rdma_rank_prefix_sum,
+        recv_gbl_channel_prefix_matrix,
+        recv_gbl_rank_prefix_sum,
+        recv_src_meta,
+        send_rdma_head,
+        send_nvl_head,
+        _event,
+    ) = buf.runtime.internode_dispatch(
+        x,
+        None,
+        topk_idx,
+        topk_weights,
+        num_tokens_per_rank,
+        num_tokens_per_rdma_rank,
+        is_token_in_rank,
+        num_tokens_per_expert,
+        0,
+        0,
+        None,
+        None,
+        None,
+        None,
+        1,
+        cfg,
+        None,
+        False,
+        False,
     )
     dist.barrier(group=group)

@@ -176,9 +209,9 @@ def main():
         if block.numel():
             lo = block.float().amin().item()
             hi = block.float().amax().item()
-            assert abs(lo - src) < 1e-3 and abs(hi - src) < 1e-3, (
-                f"rank{rank}: block from src={src} has range=[{lo}, {hi}], expected {src}"
-            )
+            assert (
+                abs(lo - src) < 1e-3 and abs(hi - src) < 1e-3
+            ), f"rank{rank}: block from src={src} has range=[{lo}, {hi}], expected {src}"
         start = end
     if rank == 0:
         print(f"[dispatch] OK (recv {recv_x.size(0)} tokens)", flush=True)
@@ -202,11 +235,19 @@ def main():
     # (`recv_rdma_channel_prefix_matrix`, `recv_rdma_rank_prefix_sum`,
     # `recv_gbl_channel_prefix_matrix`) — not the sender-side ones.
     combined_x, combined_topk_weights, _ = buf.runtime.internode_combine(
-        recv_x, recv_topk_weights,
-        recv_src_meta, is_token_in_rank,
-        recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix,
-        send_rdma_head, send_nvl_head,
-        cfg, None, False, False,
+        recv_x,
+        recv_topk_weights,
+        recv_src_meta,
+        is_token_in_rank,
+        recv_rdma_channel_prefix_matrix,
+        recv_rdma_rank_prefix_sum,
+        recv_gbl_channel_prefix_matrix,
+        send_rdma_head,
+        send_nvl_head,
+        cfg,
+        None,
+        False,
+        False,
     )

     num_dst = is_token_in_rank.sum(dim=1).to(torch.float32)
@@ -235,19 +276,17 @@ def main():
     # NCCL-EP's `ep_bench -a ht` defaults (256 experts, top-8). The functional
     # check above still uses the smaller (num_experts=num_ranks*4, topk=4)
     # configuration.
-    bench_num_experts = int(os.environ.get(
-        "MSCCLPP_EP_BENCH_EXPERTS", str(num_experts)))
-    bench_num_topk = int(os.environ.get(
-        "MSCCLPP_EP_BENCH_TOPK", str(num_topk)))
+    bench_num_experts = int(os.environ.get("MSCCLPP_EP_BENCH_EXPERTS", str(num_experts)))
+    bench_num_topk = int(os.environ.get("MSCCLPP_EP_BENCH_TOPK", str(num_topk)))
     if bench_num_experts % num_ranks != 0:
         if rank == 0:
-            print(f"[bench] skip: num_experts={bench_num_experts} not divisible "
-                  f"by num_ranks={num_ranks}", flush=True)
+            print(
+                f"[bench] skip: num_experts={bench_num_experts} not divisible " f"by num_ranks={num_ranks}", flush=True
+            )
         return
     if bench_num_topk > bench_num_experts:
         if rank == 0:
-            print(f"[bench] skip: topk={bench_num_topk} > experts={bench_num_experts}",
-                  flush=True)
+            print(f"[bench] skip: topk={bench_num_topk} > experts={bench_num_experts}", flush=True)
         return

     # Respect the Buffer's pre-sized num_nvl_bytes / num_rdma_bytes budget.
@@ -294,20 +333,43 @@ def main():

     def _dispatch():
         return buf.runtime.internode_dispatch(
-            x_b, None, topk_idx_b, topk_weights_b,
-            num_tokens_per_rank_b, num_tokens_per_rdma_rank_b, is_token_in_rank_b, num_tokens_per_expert_b,
-            0, 0, None, None, None, None,
-            1, cfg, None, False, False,
+            x_b,
+            None,
+            topk_idx_b,
+            topk_weights_b,
+            num_tokens_per_rank_b,
+            num_tokens_per_rdma_rank_b,
+            is_token_in_rank_b,
+            num_tokens_per_expert_b,
+            0,
+            0,
+            None,
+            None,
+            None,
+            None,
+            1,
+            cfg,
+            None,
+            False,
+            False,
         )

     def _combine(dout):
-        (rx, _rxs, _rti, rtw, _lst,
-         _rpm, _gpm, rrcpm, rrps, rgpm, _rgps,
-         rsm, sh_rdma, sh_nvl, _ev) = dout
+        rx, _rxs, _rti, rtw, _lst, _rpm, _gpm, rrcpm, rrps, rgpm, _rgps, rsm, sh_rdma, sh_nvl, _ev = dout
         buf.runtime.internode_combine(
-            rx, rtw, rsm, is_token_in_rank_b,
-            rrcpm, rrps, rgpm,
-            sh_rdma, sh_nvl, cfg, None, False, False,
+            rx,
+            rtw,
+            rsm,
+            is_token_in_rank_b,
+            rrcpm,
+            rrps,
+            rgpm,
+            sh_rdma,
+            sh_nvl,
+            cfg,
+            None,
+            False,
+            False,
         )

     # Warmup (full round-trip with the sync/barrier guard between phases,
@@ -369,16 +431,16 @@ def main():
         num_tokens_per_rank_b.to(torch.int64),
         group=group,
     )
-    src_node = (torch.arange(num_ranks, device="cuda") // num_local_ranks)
+    src_node = torch.arange(num_ranks, device="cuda") // num_local_ranks
     remote_mask = (src_node != local_node).to(torch.int64)
     total_recv_tokens_local = int(recv_from_src.sum().item())
     rdma_recv_tokens_local = int((recv_from_src * remote_mask).sum().item())

     # Average per-rank token counts across ranks (matches NCCL-EP `Byte counts (per rank avg)`).
     counts_t = torch.tensor(
-        [total_send_tokens_local, rdma_send_tokens_local,
-         total_recv_tokens_local, rdma_recv_tokens_local],
-        dtype=torch.float64, device="cuda",
+        [total_send_tokens_local, rdma_send_tokens_local, total_recv_tokens_local, rdma_recv_tokens_local],
+        dtype=torch.float64,
+        device="cuda",
     )
     dist.all_reduce(counts_t, op=dist.ReduceOp.SUM, group=group)
     counts_avg = (counts_t / num_ranks).tolist()
@@ -469,6 +531,7 @@ if __name__ == "__main__":
         main()
     except Exception:
         import traceback
+
         traceback.print_exc()
         sys.exit(1)
     finally:
@@ -111,9 +111,11 @@ def main():
     _buf_hidden = max(hidden, int(os.environ.get("MSCCLPP_EP_BENCH_HIDDEN", "0"))) if _bench_on else hidden
     num_nvl_bytes = cfg.get_nvl_buffer_size_hint(_buf_hidden * x.element_size(), num_ranks)
     if rank == 0:
-        print(f"[cfg] num_ranks={num_ranks} num_tokens={num_tokens} hidden={hidden} "
-              f"num_experts={num_experts} num_topk={num_topk} num_nvl_bytes={num_nvl_bytes}",
-              flush=True)
+        print(
+            f"[cfg] num_ranks={num_ranks} num_tokens={num_tokens} hidden={hidden} "
+            f"num_experts={num_experts} num_topk={num_topk} num_nvl_bytes={num_nvl_bytes}",
+            flush=True,
+        )

     print(f"[rank {rank}] creating Buffer", flush=True)
     buf = ep.Buffer(group, num_nvl_bytes=num_nvl_bytes, num_rdma_bytes=0, low_latency_mode=False)
@@ -129,14 +131,34 @@ def main():
     print("[layout] OK", flush=True)

     # Dispatch
-    (recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights,
-     num_recv_tokens_per_expert_list,
-     rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx,
-     send_head, _event) = buf.runtime.intranode_dispatch(
-        x, None, topk_idx, topk_weights,
-        num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert,
-        0, None, None,
-        1, cfg, None, False, False,
+    (
+        recv_x,
+        recv_x_scales,
+        recv_topk_idx,
+        recv_topk_weights,
+        num_recv_tokens_per_expert_list,
+        rank_prefix_matrix,
+        channel_prefix_matrix,
+        recv_channel_prefix_matrix,
+        recv_src_idx,
+        send_head,
+        _event,
+    ) = buf.runtime.intranode_dispatch(
+        x,
+        None,
+        topk_idx,
+        topk_weights,
+        num_tokens_per_rank,
+        is_token_in_rank,
+        num_tokens_per_expert,
+        0,
+        None,
+        None,
+        1,
+        cfg,
+        None,
+        False,
+        False,
     )
     dist.barrier(group=group)

@@ -149,9 +171,7 @@ def main():
         block = recv_x[start:end]
         if block.numel():
             actual = block.float().amin().item()
-            assert abs(actual - src) < 1e-3, (
-                f"rank{rank}: block from src={src} has min={actual}, expected {src}"
-            )
+            assert abs(actual - src) < 1e-3, f"rank{rank}: block from src={src} has min={actual}, expected {src}"
             assert abs(block.float().amax().item() - src) < 1e-3
         start = end
     if rank == 0:
@@ -165,9 +185,16 @@ def main():
     handle_channel_prefix_matrix = recv_channel_prefix_matrix

     combined_x, combined_topk_weights, _ = buf.runtime.intranode_combine(
-        recv_x, recv_topk_weights,
-        handle_recv_src_idx, handle_rank_prefix_matrix, handle_channel_prefix_matrix,
-        send_head, cfg, None, False, False,
+        recv_x,
+        recv_topk_weights,
+        handle_recv_src_idx,
+        handle_rank_prefix_matrix,
+        handle_channel_prefix_matrix,
+        send_head,
+        cfg,
+        None,
+        False,
+        False,
     )

     # Expected: we dispatched with x = rank * ones, so every destination r
@@ -201,19 +228,17 @@ def main():
     # NCCL-EP's `ep_bench -a ht` defaults (256 experts, top-8). The functional
     # check above still uses the smaller (num_experts=num_ranks*4, topk=4)
     # configuration.
-    bench_num_experts = int(os.environ.get(
-        "MSCCLPP_EP_BENCH_EXPERTS", str(num_experts)))
-    bench_num_topk = int(os.environ.get(
-        "MSCCLPP_EP_BENCH_TOPK", str(num_topk)))
+    bench_num_experts = int(os.environ.get("MSCCLPP_EP_BENCH_EXPERTS", str(num_experts)))
+    bench_num_topk = int(os.environ.get("MSCCLPP_EP_BENCH_TOPK", str(num_topk)))
     if bench_num_experts % num_ranks != 0:
         if rank == 0:
-            print(f"[bench] skip: num_experts={bench_num_experts} not divisible "
-                  f"by num_ranks={num_ranks}", flush=True)
+            print(
+                f"[bench] skip: num_experts={bench_num_experts} not divisible " f"by num_ranks={num_ranks}", flush=True
+            )
         return
     if bench_num_topk > bench_num_experts:
         if rank == 0:
-            print(f"[bench] skip: topk={bench_num_topk} > experts={bench_num_experts}",
-                  flush=True)
+            print(f"[bench] skip: topk={bench_num_topk} > experts={bench_num_experts}", flush=True)
         return

     # Rebuild inputs at bench size. Keep same layout recipe as above but at
@@ -253,15 +278,36 @@ def main():

     def _dispatch():
         return buf.runtime.intranode_dispatch(
-            x_b, None, topk_idx_b, topk_weights_b,
-            num_tokens_per_rank_b, is_token_in_rank_b, num_tokens_per_expert_b,
-            0, None, None, 1, cfg, None, False, False,
+            x_b,
+            None,
+            topk_idx_b,
+            topk_weights_b,
+            num_tokens_per_rank_b,
+            is_token_in_rank_b,
+            num_tokens_per_expert_b,
+            0,
+            None,
+            None,
+            1,
+            cfg,
+            None,
+            False,
+            False,
         )

     def _combine(dout):
-        (rx, _rxs, _rti, rtw, _lst, rpm, _cpm, rcpm, rsi, sh, _ev) = dout
+        rx, _rxs, _rti, rtw, _lst, rpm, _cpm, rcpm, rsi, sh, _ev = dout
         buf.runtime.intranode_combine(
-            rx, rtw, rsi, rpm, rcpm, sh, cfg, None, False, False,
+            rx,
+            rtw,
+            rsi,
+            rpm,
+            rcpm,
+            sh,
+            cfg,
+            None,
+            False,
+            False,
         )

     # Warmup (full round-trip).
@@ -319,9 +365,9 @@ def main():

     # Average per-rank token counts across ranks (matches NCCL-EP `Byte counts (per rank avg)`).
     counts_t = torch.tensor(
-        [total_send_tokens_local, rdma_send_tokens_local,
-         total_recv_tokens_local, rdma_recv_tokens_local],
-        dtype=torch.float64, device="cuda",
+        [total_send_tokens_local, rdma_send_tokens_local, total_recv_tokens_local, rdma_recv_tokens_local],
+        dtype=torch.float64,
+        device="cuda",
     )
     dist.all_reduce(counts_t, op=dist.ReduceOp.SUM, group=group)
     counts_avg = (counts_t / num_ranks).tolist()
@@ -410,6 +456,7 @@ if __name__ == "__main__":
        main()
     except Exception:
         import traceback
+
         traceback.print_exc()
         sys.exit(1)
     finally:
@@ -70,7 +70,9 @@ def main():
     assert num_ranks - rank_offset < 257, "too many ranks for bf16 precision anchor"

     num_tokens = int(os.environ.get("MSCCLPP_EP_LL_TOKENS", "64"))
-    hidden = int(os.environ.get("MSCCLPP_EP_LL_HIDDEN", "7168"))  # LL kernels are compiled for a fixed set; see SWITCH_HIDDEN
+    hidden = int(
+        os.environ.get("MSCCLPP_EP_LL_HIDDEN", "7168")
+    )  # LL kernels are compiled for a fixed set; see SWITCH_HIDDEN
     num_topk = int(os.environ.get("MSCCLPP_EP_LL_TOPK", "4"))
     num_experts_per_rank = int(os.environ.get("MSCCLPP_EP_LL_EXPERTS_PER_RANK", "4"))
     num_experts = num_ranks * num_experts_per_rank
@@ -83,9 +85,7 @@ def main():
     x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * (rank - rank_offset)
     # Encode the per-token index into the last 128 elements so the receiver
     # can verify which source token it is looking at.
-    x[:, -128:] = (
-        torch.arange(num_tokens, device="cuda").to(torch.bfloat16).view(-1, 1)
-    )
+    x[:, -128:] = torch.arange(num_tokens, device="cuda").to(torch.bfloat16).view(-1, 1)
     scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs() + 1
     topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
     topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device="cuda").abs()
@@ -94,9 +94,7 @@ def main():
     for _ in range(min(10, num_tokens)):
         topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1

-    num_rdma_bytes = ep.Buffer.get_low_latency_rdma_size_hint(
-        num_tokens, hidden, num_ranks, num_experts
-    )
+    num_rdma_bytes = ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
     if rank == 0:
         print(
             f"[cfg] num_ranks={num_ranks} num_tokens={num_tokens} hidden={hidden} "
@@ -129,12 +127,21 @@ def main():
     # packed_recv_count, packed_recv_src_info, packed_recv_layout_range,
     # event, hook
     (
-        packed_recv_x, _packed_recv_x_scales,
-        packed_recv_count, packed_recv_src_info, packed_recv_layout_range,
-        _event, recv_hook,
+        packed_recv_x,
+        _packed_recv_x_scales,
+        packed_recv_count,
+        packed_recv_src_info,
+        packed_recv_layout_range,
+        _event,
+        recv_hook,
     ) = buf.low_latency_dispatch(
-        x, topk_idx, num_tokens, num_experts,
-        False, False, True,  # use_fp8, async, return_recv_hook
+        x,
+        topk_idx,
+        num_tokens,
+        num_experts,
+        False,
+        False,
+        True,  # use_fp8, async, return_recv_hook
     )
     # Send phase launched on compute_stream; wait for local launch.
     torch.cuda.synchronize()
@@ -158,12 +165,12 @@ def main():
         expected_count = int((all_topk_idx == expert_id).sum().item())
         recv_layout_range = handle[1][i]
         layout_sum = int((recv_layout_range & int_mask).sum().item())
-        assert recv_count == expected_count, (
-            f"rank{rank} expert{expert_id}: recv_count={recv_count} != expected={expected_count}"
-        )
-        assert layout_sum == recv_count, (
-            f"rank{rank} expert{expert_id}: layout range sum {layout_sum} != recv_count {recv_count}"
-        )
+        assert (
+            recv_count == expected_count
+        ), f"rank{rank} expert{expert_id}: recv_count={recv_count} != expected={expected_count}"
+        assert (
+            layout_sum == recv_count
+        ), f"rank{rank} expert{expert_id}: layout range sum {layout_sum} != recv_count {recv_count}"

         if recv_count:
             recv_x = packed_recv_x[i, :recv_count]
@@ -186,10 +193,16 @@ def main():
     # zero_copy, async, return_recv_hook, out)
     src_info, layout_range = handle[0], handle[1]
     combined_x, _event, _hook = buf.low_latency_combine(
-        simulated_gemm_x, topk_idx, topk_weights,
-        src_info, layout_range,
-        num_tokens, num_experts,
-        False, False, False,  # zero_copy, async, return_recv_hook
+        simulated_gemm_x,
+        topk_idx,
+        topk_weights,
+        src_info,
+        layout_range,
+        num_tokens,
+        num_experts,
+        False,
+        False,
+        False,  # zero_copy, async, return_recv_hook
         out,
     )

@@ -230,23 +243,34 @@ def main():
     num_local_experts = num_experts // num_ranks
     bench_packed_recv_x = torch.empty(
         (num_local_experts, num_ranks * num_tokens, hidden),
-        dtype=torch.bfloat16, device="cuda",
+        dtype=torch.bfloat16,
+        device="cuda",
     )
     bench_packed_recv_src_info = torch.empty(
         (num_local_experts, num_ranks * num_tokens),
-        dtype=torch.int32, device="cuda",
+        dtype=torch.int32,
+        device="cuda",
     )
     bench_packed_recv_layout_range = torch.empty(
-        (num_local_experts, num_ranks), dtype=torch.int64, device="cuda",
+        (num_local_experts, num_ranks),
+        dtype=torch.int64,
+        device="cuda",
    )
     bench_packed_recv_count = torch.empty(
-        (num_local_experts,), dtype=torch.int32, device="cuda",
+        (num_local_experts,),
+        dtype=torch.int32,
+        device="cuda",
     )

     def _dispatch():
         return buf.low_latency_dispatch(
-            x, topk_idx, num_tokens, num_experts,
-            False, False, False,  # use_fp8, async, return_recv_hook
+            x,
+            topk_idx,
+            num_tokens,
+            num_experts,
+            False,
+            False,
+            False,  # use_fp8, async, return_recv_hook
             bench_packed_recv_x,
             None,  # x_scales (FP8 only)
             bench_packed_recv_src_info,
@@ -261,12 +285,18 @@ def main():
     bench_out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")

     def _combine(dout, out_):
-        (recv_x, _scales, _cnt, src_info_, layout_range_, _ev, _hk) = dout
+        recv_x, _scales, _cnt, src_info_, layout_range_, _ev, _hk = dout
         buf.low_latency_combine(
-            recv_x, topk_idx, topk_weights,
-            src_info_, layout_range_,
-            num_tokens, num_experts,
-            False, False, False,
+            recv_x,
+            topk_idx,
+            topk_weights,
+            src_info_,
+            layout_range_,
+            num_tokens,
+            num_experts,
+            False,
+            False,
+            False,
             out_,
         )
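
The "no functional changes" claim is mechanically checkable: black itself verifies after each reformat that the new source parses to an equivalent AST (unless run with `--fast`). A standalone sketch of the same check, where the two paths are hypothetical stand-ins for pre-format and post-format copies of any touched file:

    # Sketch: confirm a formatting-only change preserves the parse tree.
    # Usage (hypothetical): python ast_check.py old_copy.py new_copy.py
    import ast
    import sys

    old_src = open(sys.argv[1]).read()  # pre-format snapshot
    new_src = open(sys.argv[2]).read()  # post-format snapshot
    # Comments, line breaks, backslash continuations, and redundant parens
    # all vanish at parse time, so a pure reformat yields identical dumps.
    assert ast.dump(ast.parse(old_src)) == ast.dump(ast.parse(new_src)), "ASTs differ"
    print("OK: reformat is behavior-preserving at the AST level")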