mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 08:50:21 +00:00
ext/ep: optional preallocated outputs for low_latency_dispatch
Add optional out_packed_recv_x / out_packed_recv_x_scales / out_packed_recv_src_info / out_packed_recv_layout_range / out_packed_recv_count parameters to Buffer::low_latency_dispatch so callers can hoist the recv-side allocations out of a hot loop, mirroring the existing out= path on low_latency_combine. The bench in test_low_latency_multirank.py preallocates these tensors once and passes them on every iteration so the timed loop reflects kernel cost, not torch.empty + caching-allocator overhead.
This commit is contained in:
@@ -72,7 +72,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
.def("internode_dispatch", &mscclpp::ep::Buffer::internode_dispatch)
|
||||
.def("internode_combine", &mscclpp::ep::Buffer::internode_combine)
|
||||
.def("clean_low_latency_buffer", &mscclpp::ep::Buffer::clean_low_latency_buffer)
|
||||
.def("low_latency_dispatch", &mscclpp::ep::Buffer::low_latency_dispatch)
|
||||
.def("low_latency_dispatch", &mscclpp::ep::Buffer::low_latency_dispatch,
|
||||
py::arg("x"), py::arg("topk_idx"),
|
||||
py::arg("num_max_dispatch_tokens_per_rank"), py::arg("num_experts"),
|
||||
py::arg("use_fp8"), py::arg("async"), py::arg("return_recv_hook"),
|
||||
py::arg("out_packed_recv_x") = py::none(),
|
||||
py::arg("out_packed_recv_x_scales") = py::none(),
|
||||
py::arg("out_packed_recv_src_info") = py::none(),
|
||||
py::arg("out_packed_recv_layout_range") = py::none(),
|
||||
py::arg("out_packed_recv_count") = py::none())
|
||||
.def("low_latency_combine", &mscclpp::ep::Buffer::low_latency_combine)
|
||||
.def("get_next_low_latency_combine_buffer", &mscclpp::ep::Buffer::get_next_low_latency_combine_buffer);
|
||||
}
|
||||
|
||||
@@ -1334,7 +1334,12 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
|
||||
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
|
||||
Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
|
||||
int num_max_dispatch_tokens_per_rank, int num_experts,
|
||||
bool use_fp8, bool async, bool return_recv_hook) {
|
||||
bool use_fp8, bool async, bool return_recv_hook,
|
||||
const std::optional<torch::Tensor>& out_packed_recv_x,
|
||||
const std::optional<torch::Tensor>& out_packed_recv_x_scales,
|
||||
const std::optional<torch::Tensor>& out_packed_recv_src_info,
|
||||
const std::optional<torch::Tensor>& out_packed_recv_layout_range,
|
||||
const std::optional<torch::Tensor>& out_packed_recv_count) {
|
||||
EP_HOST_ASSERT(low_latency_mode);
|
||||
|
||||
EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16);
|
||||
@@ -1359,18 +1364,77 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
|
||||
if (not return_recv_hook)
|
||||
stream_wait(launch_stream, compute_stream);
|
||||
|
||||
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
|
||||
x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16));
|
||||
auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
|
||||
auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
|
||||
auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
|
||||
// Reusable output tensors. The largest (`packed_recv_x` ~58 MB at 7K hidden)
|
||||
// is what motivates the reuse path: a fresh torch::empty per call adds
|
||||
// measurable host overhead (~10us cumulative for the 4 allocations) which
|
||||
// shows up against NCCL-EP's preallocated bench at small payloads.
|
||||
const auto recv_x_dtype = use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16;
|
||||
torch::Tensor packed_recv_x;
|
||||
if (out_packed_recv_x.has_value()) {
|
||||
EP_HOST_ASSERT(out_packed_recv_x->dim() == 3 and out_packed_recv_x->is_contiguous());
|
||||
EP_HOST_ASSERT(out_packed_recv_x->size(0) == num_local_experts);
|
||||
EP_HOST_ASSERT(out_packed_recv_x->size(1) == num_ranks * num_max_dispatch_tokens_per_rank);
|
||||
EP_HOST_ASSERT(out_packed_recv_x->size(2) == hidden);
|
||||
EP_HOST_ASSERT(out_packed_recv_x->scalar_type() == recv_x_dtype);
|
||||
packed_recv_x = out_packed_recv_x.value();
|
||||
} else {
|
||||
packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
|
||||
x.options().dtype(recv_x_dtype));
|
||||
}
|
||||
torch::Tensor packed_recv_src_info;
|
||||
if (out_packed_recv_src_info.has_value()) {
|
||||
EP_HOST_ASSERT(out_packed_recv_src_info->dim() == 2 and out_packed_recv_src_info->is_contiguous());
|
||||
EP_HOST_ASSERT(out_packed_recv_src_info->size(0) == num_local_experts);
|
||||
EP_HOST_ASSERT(out_packed_recv_src_info->size(1) == num_ranks * num_max_dispatch_tokens_per_rank);
|
||||
EP_HOST_ASSERT(out_packed_recv_src_info->scalar_type() == torch::kInt32);
|
||||
packed_recv_src_info = out_packed_recv_src_info.value();
|
||||
} else {
|
||||
packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank},
|
||||
torch::dtype(torch::kInt32).device(torch::kCUDA));
|
||||
}
|
||||
torch::Tensor packed_recv_layout_range;
|
||||
if (out_packed_recv_layout_range.has_value()) {
|
||||
EP_HOST_ASSERT(out_packed_recv_layout_range->dim() == 2 and out_packed_recv_layout_range->is_contiguous());
|
||||
EP_HOST_ASSERT(out_packed_recv_layout_range->size(0) == num_local_experts);
|
||||
EP_HOST_ASSERT(out_packed_recv_layout_range->size(1) == num_ranks);
|
||||
EP_HOST_ASSERT(out_packed_recv_layout_range->scalar_type() == torch::kInt64);
|
||||
packed_recv_layout_range = out_packed_recv_layout_range.value();
|
||||
} else {
|
||||
packed_recv_layout_range = torch::empty({num_local_experts, num_ranks},
|
||||
torch::dtype(torch::kInt64).device(torch::kCUDA));
|
||||
}
|
||||
torch::Tensor packed_recv_count;
|
||||
if (out_packed_recv_count.has_value()) {
|
||||
EP_HOST_ASSERT(out_packed_recv_count->dim() == 1 and out_packed_recv_count->is_contiguous());
|
||||
EP_HOST_ASSERT(out_packed_recv_count->size(0) == num_local_experts);
|
||||
EP_HOST_ASSERT(out_packed_recv_count->scalar_type() == torch::kInt32);
|
||||
packed_recv_count = out_packed_recv_count.value();
|
||||
} else {
|
||||
packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
|
||||
}
|
||||
|
||||
auto packed_recv_x_scales = std::optional<torch::Tensor>();
|
||||
float* packed_recv_x_scales_ptr = nullptr;
|
||||
if (use_fp8) {
|
||||
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
|
||||
packed_recv_x_scales = torch::empty({num_local_experts, num_scales, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
|
||||
packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
|
||||
if (out_packed_recv_x_scales.has_value()) {
|
||||
// Caller-provided scales tensor must already be in the kernel's
|
||||
// expected (transposed) layout: shape [num_local_experts,
|
||||
// num_ranks*max_tokens, num_scales], strides such that
|
||||
// size(1)=num_ranks*max_tokens with the actual storage
|
||||
// [num_local_experts, num_scales, num_ranks*max_tokens] (i.e.
|
||||
// produced by `torch.empty(...).transpose(1, 2)`).
|
||||
EP_HOST_ASSERT(out_packed_recv_x_scales->dim() == 3);
|
||||
EP_HOST_ASSERT(out_packed_recv_x_scales->size(0) == num_local_experts);
|
||||
EP_HOST_ASSERT(out_packed_recv_x_scales->size(1) == num_ranks * num_max_dispatch_tokens_per_rank);
|
||||
EP_HOST_ASSERT(out_packed_recv_x_scales->size(2) == num_scales);
|
||||
EP_HOST_ASSERT(out_packed_recv_x_scales->scalar_type() == torch::kFloat32);
|
||||
packed_recv_x_scales = out_packed_recv_x_scales.value();
|
||||
} else {
|
||||
packed_recv_x_scales = torch::empty({num_local_experts, num_scales, num_ranks * num_max_dispatch_tokens_per_rank},
|
||||
torch::dtype(torch::kFloat32).device(torch::kCUDA));
|
||||
packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
|
||||
}
|
||||
packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr<float>();
|
||||
}
|
||||
|
||||
|
||||
@@ -172,7 +172,12 @@ public:
|
||||
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
|
||||
low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
|
||||
int num_max_dispatch_tokens_per_rank, int num_experts,
|
||||
bool use_fp8, bool async, bool return_recv_hook);
|
||||
bool use_fp8, bool async, bool return_recv_hook,
|
||||
const std::optional<torch::Tensor>& out_packed_recv_x = std::nullopt,
|
||||
const std::optional<torch::Tensor>& out_packed_recv_x_scales = std::nullopt,
|
||||
const std::optional<torch::Tensor>& out_packed_recv_src_info = std::nullopt,
|
||||
const std::optional<torch::Tensor>& out_packed_recv_layout_range = std::nullopt,
|
||||
const std::optional<torch::Tensor>& out_packed_recv_count = std::nullopt);
|
||||
|
||||
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
|
||||
low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
|
||||
|
||||
@@ -221,10 +221,35 @@ def main():
|
||||
warmup = int(os.environ.get("MSCCLPP_EP_BENCH_WARMUP", "5"))
|
||||
iters = int(os.environ.get("MSCCLPP_EP_BENCH_ITERS", "20"))
|
||||
|
||||
# Hoist dispatch's output tensors out of the timed loop. The largest
|
||||
# (`packed_recv_x`, ~58 MB at 7K hidden) costs ~10us cumulative across
|
||||
# the four torch::empty calls per iter; reusing them brings the bench
|
||||
# in line with NCCL-EP `ep_bench` which preallocates output buffers.
|
||||
num_local_experts = num_experts // num_ranks
|
||||
bench_packed_recv_x = torch.empty(
|
||||
(num_local_experts, num_ranks * num_tokens, hidden),
|
||||
dtype=torch.bfloat16, device="cuda",
|
||||
)
|
||||
bench_packed_recv_src_info = torch.empty(
|
||||
(num_local_experts, num_ranks * num_tokens),
|
||||
dtype=torch.int32, device="cuda",
|
||||
)
|
||||
bench_packed_recv_layout_range = torch.empty(
|
||||
(num_local_experts, num_ranks), dtype=torch.int64, device="cuda",
|
||||
)
|
||||
bench_packed_recv_count = torch.empty(
|
||||
(num_local_experts,), dtype=torch.int32, device="cuda",
|
||||
)
|
||||
|
||||
def _dispatch():
|
||||
return buf.low_latency_dispatch(
|
||||
x, topk_idx, num_tokens, num_experts,
|
||||
False, False, False, # use_fp8, async, return_recv_hook
|
||||
bench_packed_recv_x,
|
||||
None, # x_scales (FP8 only)
|
||||
bench_packed_recv_src_info,
|
||||
bench_packed_recv_layout_range,
|
||||
bench_packed_recv_count,
|
||||
)
|
||||
|
||||
# Hoist combine's output-tensor allocation out of the timed loop so the
|
||||
|
||||
Reference in New Issue
Block a user