ext/ep: optional preallocated outputs for low_latency_dispatch

Add optional out_packed_recv_x / out_packed_recv_x_scales /
out_packed_recv_src_info / out_packed_recv_layout_range /
out_packed_recv_count parameters to Buffer::low_latency_dispatch so
callers can hoist the recv-side allocations out of a hot loop,
mirroring the existing out= path on low_latency_combine.

The bench in test_low_latency_multirank.py preallocates these tensors
once and passes them on every iteration, so the timed loop reflects
kernel cost rather than torch.empty + caching-allocator overhead.
Author: Qinghua Zhou
Date:   2026-04-30 18:45:44 +00:00
Parent: 2529774868
Commit: fdf7d579dc
4 changed files with 112 additions and 10 deletions
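
A minimal caller-side sketch of the pattern this commit enables (assuming `buf` is a constructed `Buffer`, and that `x`, `topk_idx`, and the size variables are set up as in the bench below; BF16 path, so no scales tensor):

import torch

num_local_experts = num_experts // num_ranks
total_recv = num_ranks * num_max_dispatch_tokens_per_rank

# Allocate the four recv-side outputs once, outside the hot loop.
recv_x = torch.empty((num_local_experts, total_recv, hidden),
                     dtype=torch.bfloat16, device="cuda")
src_info = torch.empty((num_local_experts, total_recv),
                       dtype=torch.int32, device="cuda")
layout_range = torch.empty((num_local_experts, num_ranks),
                           dtype=torch.int64, device="cuda")
count = torch.empty((num_local_experts,), dtype=torch.int32, device="cuda")

for _ in range(iters):
    # The first seven arguments go positionally (`async` is a Python
    # keyword, so it cannot be spelled `async=...`); the out_* tensors
    # default to None and can be passed by keyword.
    buf.low_latency_dispatch(
        x, topk_idx, num_max_dispatch_tokens_per_rank, num_experts,
        False, False, False,  # use_fp8, async, return_recv_hook
        out_packed_recv_x=recv_x,
        out_packed_recv_src_info=src_info,
        out_packed_recv_layout_range=layout_range,
        out_packed_recv_count=count,
    )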

View File

@@ -72,7 +72,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("internode_dispatch", &mscclpp::ep::Buffer::internode_dispatch)
.def("internode_combine", &mscclpp::ep::Buffer::internode_combine)
.def("clean_low_latency_buffer", &mscclpp::ep::Buffer::clean_low_latency_buffer)
.def("low_latency_dispatch", &mscclpp::ep::Buffer::low_latency_dispatch)
.def("low_latency_dispatch", &mscclpp::ep::Buffer::low_latency_dispatch,
py::arg("x"), py::arg("topk_idx"),
py::arg("num_max_dispatch_tokens_per_rank"), py::arg("num_experts"),
py::arg("use_fp8"), py::arg("async"), py::arg("return_recv_hook"),
py::arg("out_packed_recv_x") = py::none(),
py::arg("out_packed_recv_x_scales") = py::none(),
py::arg("out_packed_recv_src_info") = py::none(),
py::arg("out_packed_recv_layout_range") = py::none(),
py::arg("out_packed_recv_count") = py::none())
.def("low_latency_combine", &mscclpp::ep::Buffer::low_latency_combine)
.def("get_next_low_latency_combine_buffer", &mscclpp::ep::Buffer::get_next_low_latency_combine_buffer);
}
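
The binding keeps the seven-element return tuple (per the C++ signature in the next file). A sketch of unpacking it (`count_buf` is a hypothetical preallocated tensor); any subset of the out_* tensors may be supplied, since each is validated independently:

packed_recv_x, x_scales, src_info, layout_range, recv_count, event, hook = \
    buf.low_latency_dispatch(
        x, topk_idx, num_max_dispatch_tokens_per_rank, num_experts,
        False, False, False,  # use_fp8, async, return_recv_hook
        out_packed_recv_count=count_buf,  # supply any subset of out_*
    )
# When an out_* tensor is supplied, the corresponding returned tensor
# should alias the caller's storage rather than a fresh allocation.
# x_scales is None here because use_fp8 is False.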

View File

@@ -1334,7 +1334,12 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool async, bool return_recv_hook) {
bool use_fp8, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out_packed_recv_x,
const std::optional<torch::Tensor>& out_packed_recv_x_scales,
const std::optional<torch::Tensor>& out_packed_recv_src_info,
const std::optional<torch::Tensor>& out_packed_recv_layout_range,
const std::optional<torch::Tensor>& out_packed_recv_count) {
EP_HOST_ASSERT(low_latency_mode);
EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16);
@@ -1359,18 +1364,77 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
     if (not return_recv_hook)
         stream_wait(launch_stream, compute_stream);
-    auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
-                                      x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16));
-    auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
-    auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
-    auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
+    // Reusable output tensors. The largest (`packed_recv_x`, ~58 MB at 7K
+    // hidden) is what motivates the reuse path: a fresh torch::empty per call
+    // adds measurable host overhead (~10us cumulative for the 4 allocations),
+    // which shows up against NCCL-EP's preallocated bench at small payloads.
+    const auto recv_x_dtype = use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16;
+    torch::Tensor packed_recv_x;
+    if (out_packed_recv_x.has_value()) {
+        EP_HOST_ASSERT(out_packed_recv_x->dim() == 3 and out_packed_recv_x->is_contiguous());
+        EP_HOST_ASSERT(out_packed_recv_x->size(0) == num_local_experts);
+        EP_HOST_ASSERT(out_packed_recv_x->size(1) == num_ranks * num_max_dispatch_tokens_per_rank);
+        EP_HOST_ASSERT(out_packed_recv_x->size(2) == hidden);
+        EP_HOST_ASSERT(out_packed_recv_x->scalar_type() == recv_x_dtype);
+        packed_recv_x = out_packed_recv_x.value();
+    } else {
+        packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
+                                     x.options().dtype(recv_x_dtype));
+    }
+    torch::Tensor packed_recv_src_info;
+    if (out_packed_recv_src_info.has_value()) {
+        EP_HOST_ASSERT(out_packed_recv_src_info->dim() == 2 and out_packed_recv_src_info->is_contiguous());
+        EP_HOST_ASSERT(out_packed_recv_src_info->size(0) == num_local_experts);
+        EP_HOST_ASSERT(out_packed_recv_src_info->size(1) == num_ranks * num_max_dispatch_tokens_per_rank);
+        EP_HOST_ASSERT(out_packed_recv_src_info->scalar_type() == torch::kInt32);
+        packed_recv_src_info = out_packed_recv_src_info.value();
+    } else {
+        packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank},
+                                            torch::dtype(torch::kInt32).device(torch::kCUDA));
+    }
+    torch::Tensor packed_recv_layout_range;
+    if (out_packed_recv_layout_range.has_value()) {
+        EP_HOST_ASSERT(out_packed_recv_layout_range->dim() == 2 and out_packed_recv_layout_range->is_contiguous());
+        EP_HOST_ASSERT(out_packed_recv_layout_range->size(0) == num_local_experts);
+        EP_HOST_ASSERT(out_packed_recv_layout_range->size(1) == num_ranks);
+        EP_HOST_ASSERT(out_packed_recv_layout_range->scalar_type() == torch::kInt64);
+        packed_recv_layout_range = out_packed_recv_layout_range.value();
+    } else {
+        packed_recv_layout_range = torch::empty({num_local_experts, num_ranks},
+                                                torch::dtype(torch::kInt64).device(torch::kCUDA));
+    }
+    torch::Tensor packed_recv_count;
+    if (out_packed_recv_count.has_value()) {
+        EP_HOST_ASSERT(out_packed_recv_count->dim() == 1 and out_packed_recv_count->is_contiguous());
+        EP_HOST_ASSERT(out_packed_recv_count->size(0) == num_local_experts);
+        EP_HOST_ASSERT(out_packed_recv_count->scalar_type() == torch::kInt32);
+        packed_recv_count = out_packed_recv_count.value();
+    } else {
+        packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
+    }
     auto packed_recv_x_scales = std::optional<torch::Tensor>();
     float* packed_recv_x_scales_ptr = nullptr;
     if (use_fp8) {
         EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
-        packed_recv_x_scales = torch::empty({num_local_experts, num_scales, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
-        packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
+        if (out_packed_recv_x_scales.has_value()) {
+            // A caller-provided scales tensor must already be in the kernel's
+            // expected transposed layout: logical shape [num_local_experts,
+            // num_ranks * max_tokens, num_scales] viewed over storage laid out
+            // as [num_local_experts, num_scales, num_ranks * max_tokens], i.e.
+            // produced by `torch.empty(...).transpose(1, 2)`.
+            EP_HOST_ASSERT(out_packed_recv_x_scales->dim() == 3);
+            EP_HOST_ASSERT(out_packed_recv_x_scales->size(0) == num_local_experts);
+            EP_HOST_ASSERT(out_packed_recv_x_scales->size(1) == num_ranks * num_max_dispatch_tokens_per_rank);
+            EP_HOST_ASSERT(out_packed_recv_x_scales->size(2) == num_scales);
+            EP_HOST_ASSERT(out_packed_recv_x_scales->scalar_type() == torch::kFloat32);
+            packed_recv_x_scales = out_packed_recv_x_scales.value();
+        } else {
+            packed_recv_x_scales = torch::empty({num_local_experts, num_scales, num_ranks * num_max_dispatch_tokens_per_rank},
+                                                torch::dtype(torch::kFloat32).device(torch::kCUDA));
+            packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
+        }
         packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr<float>();
     }
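
To satisfy the transposed-layout requirement documented above for out_packed_recv_x_scales, a caller-side sketch in Python (variable names assumed from the surrounding code):

# Storage laid out as [num_local_experts, num_scales, tokens], then viewed
# as [num_local_experts, tokens, num_scales] -- exactly what the non-reuse
# path builds with torch::empty + torch::transpose(1, 2).
tokens = num_ranks * num_max_dispatch_tokens_per_rank
scales_storage = torch.empty(
    (num_local_experts, num_scales, tokens),
    dtype=torch.float32, device="cuda",
)
out_scales = scales_storage.transpose(1, 2)  # a view; no copy
assert out_scales.shape == (num_local_experts, tokens, num_scales)
assert not out_scales.is_contiguous()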

View File

@@ -172,7 +172,12 @@ public:
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool async, bool return_recv_hook);
bool use_fp8, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out_packed_recv_x = std::nullopt,
const std::optional<torch::Tensor>& out_packed_recv_x_scales = std::nullopt,
const std::optional<torch::Tensor>& out_packed_recv_src_info = std::nullopt,
const std::optional<torch::Tensor>& out_packed_recv_layout_range = std::nullopt,
const std::optional<torch::Tensor>& out_packed_recv_count = std::nullopt);
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,

View File

@@ -221,10 +221,35 @@ def main():
     warmup = int(os.environ.get("MSCCLPP_EP_BENCH_WARMUP", "5"))
     iters = int(os.environ.get("MSCCLPP_EP_BENCH_ITERS", "20"))
+    # Hoist dispatch's output tensors out of the timed loop. The four
+    # torch.empty calls cost ~10us of host time per iteration (the largest
+    # tensor, `packed_recv_x`, is ~58 MB at 7K hidden); reusing them brings
+    # the bench in line with NCCL-EP's `ep_bench`, which preallocates its
+    # output buffers.
+    num_local_experts = num_experts // num_ranks
+    bench_packed_recv_x = torch.empty(
+        (num_local_experts, num_ranks * num_tokens, hidden),
+        dtype=torch.bfloat16, device="cuda",
+    )
+    bench_packed_recv_src_info = torch.empty(
+        (num_local_experts, num_ranks * num_tokens),
+        dtype=torch.int32, device="cuda",
+    )
+    bench_packed_recv_layout_range = torch.empty(
+        (num_local_experts, num_ranks), dtype=torch.int64, device="cuda",
+    )
+    bench_packed_recv_count = torch.empty(
+        (num_local_experts,), dtype=torch.int32, device="cuda",
+    )
+    def _dispatch():
+        return buf.low_latency_dispatch(
+            x, topk_idx, num_tokens, num_experts,
+            False, False, False,  # use_fp8, async, return_recv_hook
+            bench_packed_recv_x,
+            None,  # x_scales (FP8 only)
+            bench_packed_recv_src_info,
+            bench_packed_recv_layout_range,
+            bench_packed_recv_count,
+        )
     # Hoist combine's output-tensor allocation out of the timed loop so the