diff --git a/src/ext/ep/bindings.cpp b/src/ext/ep/bindings.cpp
index e721e3d2..e4f1d9fd 100644
--- a/src/ext/ep/bindings.cpp
+++ b/src/ext/ep/bindings.cpp
@@ -72,7 +72,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
       .def("internode_dispatch", &mscclpp::ep::Buffer::internode_dispatch)
       .def("internode_combine", &mscclpp::ep::Buffer::internode_combine)
       .def("clean_low_latency_buffer", &mscclpp::ep::Buffer::clean_low_latency_buffer)
-      .def("low_latency_dispatch", &mscclpp::ep::Buffer::low_latency_dispatch)
+      .def("low_latency_dispatch", &mscclpp::ep::Buffer::low_latency_dispatch,
+           py::arg("x"), py::arg("topk_idx"),
+           py::arg("num_max_dispatch_tokens_per_rank"), py::arg("num_experts"),
+           py::arg("use_fp8"), py::arg("async"), py::arg("return_recv_hook"),
+           py::arg("out_packed_recv_x") = py::none(),
+           py::arg("out_packed_recv_x_scales") = py::none(),
+           py::arg("out_packed_recv_src_info") = py::none(),
+           py::arg("out_packed_recv_layout_range") = py::none(),
+           py::arg("out_packed_recv_count") = py::none())
       .def("low_latency_combine", &mscclpp::ep::Buffer::low_latency_combine)
       .def("get_next_low_latency_combine_buffer", &mscclpp::ep::Buffer::get_next_low_latency_combine_buffer);
 }
diff --git a/src/ext/ep/buffer.cc b/src/ext/ep/buffer.cc
index d8151bac..1fd95169 100644
--- a/src/ext/ep/buffer.cc
+++ b/src/ext/ep/buffer.cc
@@ -1334,7 +1334,12 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
 std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
 Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, int num_max_dispatch_tokens_per_rank, int num_experts,
-                             bool use_fp8, bool async, bool return_recv_hook) {
+                             bool use_fp8, bool async, bool return_recv_hook,
+                             const std::optional<torch::Tensor>& out_packed_recv_x,
+                             const std::optional<torch::Tensor>& out_packed_recv_x_scales,
+                             const std::optional<torch::Tensor>& out_packed_recv_src_info,
+                             const std::optional<torch::Tensor>& out_packed_recv_layout_range,
+                             const std::optional<torch::Tensor>& out_packed_recv_count) {
   EP_HOST_ASSERT(low_latency_mode);
   EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16);
@@ -1359,18 +1364,77 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
   if (not return_recv_hook) stream_wait(launch_stream, compute_stream);
 
-  auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
-                                    x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16));
-  auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
-  auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
-  auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
+  // Reusable output tensors. The largest (`packed_recv_x`, ~58 MB at 7K hidden)
+  // is what motivates the reuse path: a fresh torch::empty per call adds
+  // measurable host overhead (~10us cumulative for the 4 allocations), which
+  // shows up against NCCL-EP's preallocated bench at small payloads.
+  const auto recv_x_dtype = use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16;
+  torch::Tensor packed_recv_x;
+  if (out_packed_recv_x.has_value()) {
+    EP_HOST_ASSERT(out_packed_recv_x->dim() == 3 and out_packed_recv_x->is_contiguous());
+    EP_HOST_ASSERT(out_packed_recv_x->size(0) == num_local_experts);
+    EP_HOST_ASSERT(out_packed_recv_x->size(1) == num_ranks * num_max_dispatch_tokens_per_rank);
+    EP_HOST_ASSERT(out_packed_recv_x->size(2) == hidden);
+    EP_HOST_ASSERT(out_packed_recv_x->scalar_type() == recv_x_dtype);
+    packed_recv_x = out_packed_recv_x.value();
+  } else {
+    packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
+                                 x.options().dtype(recv_x_dtype));
+  }
+  torch::Tensor packed_recv_src_info;
+  if (out_packed_recv_src_info.has_value()) {
+    EP_HOST_ASSERT(out_packed_recv_src_info->dim() == 2 and out_packed_recv_src_info->is_contiguous());
+    EP_HOST_ASSERT(out_packed_recv_src_info->size(0) == num_local_experts);
+    EP_HOST_ASSERT(out_packed_recv_src_info->size(1) == num_ranks * num_max_dispatch_tokens_per_rank);
+    EP_HOST_ASSERT(out_packed_recv_src_info->scalar_type() == torch::kInt32);
+    packed_recv_src_info = out_packed_recv_src_info.value();
+  } else {
+    packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank},
+                                        torch::dtype(torch::kInt32).device(torch::kCUDA));
+  }
+  torch::Tensor packed_recv_layout_range;
+  if (out_packed_recv_layout_range.has_value()) {
+    EP_HOST_ASSERT(out_packed_recv_layout_range->dim() == 2 and out_packed_recv_layout_range->is_contiguous());
+    EP_HOST_ASSERT(out_packed_recv_layout_range->size(0) == num_local_experts);
+    EP_HOST_ASSERT(out_packed_recv_layout_range->size(1) == num_ranks);
+    EP_HOST_ASSERT(out_packed_recv_layout_range->scalar_type() == torch::kInt64);
+    packed_recv_layout_range = out_packed_recv_layout_range.value();
+  } else {
+    packed_recv_layout_range = torch::empty({num_local_experts, num_ranks},
+                                            torch::dtype(torch::kInt64).device(torch::kCUDA));
+  }
+  torch::Tensor packed_recv_count;
+  if (out_packed_recv_count.has_value()) {
+    EP_HOST_ASSERT(out_packed_recv_count->dim() == 1 and out_packed_recv_count->is_contiguous());
+    EP_HOST_ASSERT(out_packed_recv_count->size(0) == num_local_experts);
+    EP_HOST_ASSERT(out_packed_recv_count->scalar_type() == torch::kInt32);
+    packed_recv_count = out_packed_recv_count.value();
+  } else {
+    packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
+  }
   auto packed_recv_x_scales = std::optional<torch::Tensor>();
   float* packed_recv_x_scales_ptr = nullptr;
   if (use_fp8) {
     EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
-    packed_recv_x_scales = torch::empty({num_local_experts, num_scales, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
-    packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
+    if (out_packed_recv_x_scales.has_value()) {
+      // A caller-provided scales tensor must already be in the kernel's
+      // expected (transposed) layout: logical shape
+      // [num_local_experts, num_ranks*max_tokens, num_scales] over
+      // storage laid out as
+      // [num_local_experts, num_scales, num_ranks*max_tokens], i.e.
+      // exactly what `torch.empty(...).transpose(1, 2)` produces.
+      EP_HOST_ASSERT(out_packed_recv_x_scales->dim() == 3);
+      EP_HOST_ASSERT(out_packed_recv_x_scales->size(0) == num_local_experts);
+      EP_HOST_ASSERT(out_packed_recv_x_scales->size(1) == num_ranks * num_max_dispatch_tokens_per_rank);
+      EP_HOST_ASSERT(out_packed_recv_x_scales->size(2) == num_scales);
+      EP_HOST_ASSERT(out_packed_recv_x_scales->scalar_type() == torch::kFloat32);
+      packed_recv_x_scales = out_packed_recv_x_scales.value();
+    } else {
+      packed_recv_x_scales = torch::empty({num_local_experts, num_scales, num_ranks * num_max_dispatch_tokens_per_rank},
+                                          torch::dtype(torch::kFloat32).device(torch::kCUDA));
+      packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
+    }
     packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr<float>();
   }
diff --git a/src/ext/ep/buffer.hpp b/src/ext/ep/buffer.hpp
index 28c8c34d..894fd9a7 100644
--- a/src/ext/ep/buffer.hpp
+++ b/src/ext/ep/buffer.hpp
@@ -172,7 +172,12 @@ public:
  std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
  low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, int num_max_dispatch_tokens_per_rank, int num_experts,
-                      bool use_fp8, bool async, bool return_recv_hook);
+                      bool use_fp8, bool async, bool return_recv_hook,
+                      const std::optional<torch::Tensor>& out_packed_recv_x = std::nullopt,
+                      const std::optional<torch::Tensor>& out_packed_recv_x_scales = std::nullopt,
+                      const std::optional<torch::Tensor>& out_packed_recv_src_info = std::nullopt,
+                      const std::optional<torch::Tensor>& out_packed_recv_layout_range = std::nullopt,
+                      const std::optional<torch::Tensor>& out_packed_recv_count = std::nullopt);
 
  std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
  low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
diff --git a/test/python/ext/ep/test_low_latency_multirank.py b/test/python/ext/ep/test_low_latency_multirank.py
index f6206136..bb87763f 100644
--- a/test/python/ext/ep/test_low_latency_multirank.py
+++ b/test/python/ext/ep/test_low_latency_multirank.py
@@ -221,10 +221,35 @@ def main():
     warmup = int(os.environ.get("MSCCLPP_EP_BENCH_WARMUP", "5"))
     iters = int(os.environ.get("MSCCLPP_EP_BENCH_ITERS", "20"))
 
+    # Hoist dispatch's output tensors out of the timed loop: allocating them fresh
+    # each iteration costs ~10us of host time across the four torch.empty calls
+    # (the largest, `packed_recv_x`, is ~58 MB at 7K hidden); reusing them brings
+    # the bench in line with NCCL-EP's `ep_bench`, which preallocates output buffers.
+    num_local_experts = num_experts // num_ranks
+    bench_packed_recv_x = torch.empty(
+        (num_local_experts, num_ranks * num_tokens, hidden),
+        dtype=torch.bfloat16, device="cuda",
+    )
+    bench_packed_recv_src_info = torch.empty(
+        (num_local_experts, num_ranks * num_tokens),
+        dtype=torch.int32, device="cuda",
+    )
+    bench_packed_recv_layout_range = torch.empty(
+        (num_local_experts, num_ranks), dtype=torch.int64, device="cuda",
+    )
+    bench_packed_recv_count = torch.empty(
+        (num_local_experts,), dtype=torch.int32, device="cuda",
+    )
+
     def _dispatch():
         return buf.low_latency_dispatch(
             x, topk_idx, num_tokens, num_experts,
             False, False, False,  # use_fp8, async, return_recv_hook
+            bench_packed_recv_x,
+            None,  # x_scales (FP8 only)
+            bench_packed_recv_src_info,
+            bench_packed_recv_layout_range,
+            bench_packed_recv_count,
         )
 
     # Hoist combine's output-tensor allocation out of the timed loop so the
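For reference, a minimal caller-side sketch of the reuse path (illustration only, not part of the diff). It preallocates the dispatch outputs once and passes them on every call, mirroring what the benchmark above now does, and additionally shows the FP8 scales tensor in the transposed layout described by the comment in buffer.cc. The names `buf`, `x`, and `topk_idx` are assumed to come from an existing setup such as test_low_latency_multirank.py; the concrete sizes and the `num_scales = hidden // 128` granularity are assumptions, not values fixed by this change.

    # Caller-side reuse sketch (illustrative; buf/x/topk_idx and the sizes are assumed).
    import torch

    num_ranks = 8
    num_experts = 256
    num_local_experts = num_experts // num_ranks
    num_max_dispatch_tokens_per_rank = 128
    hidden = 7168
    max_recv = num_ranks * num_max_dispatch_tokens_per_rank  # must be a multiple of 4 (TMA)
    use_fp8 = True

    # Preallocate once; shapes/dtypes must satisfy the EP_HOST_ASSERT checks in buffer.cc.
    out_x = torch.empty(
        (num_local_experts, max_recv, hidden),
        dtype=torch.float8_e4m3fn if use_fp8 else torch.bfloat16,
        device="cuda",
    )
    out_src_info = torch.empty((num_local_experts, max_recv), dtype=torch.int32, device="cuda")
    out_layout_range = torch.empty((num_local_experts, num_ranks), dtype=torch.int64, device="cuda")
    out_count = torch.empty((num_local_experts,), dtype=torch.int32, device="cuda")

    # FP8 only: storage is [E, num_scales, max_recv]; the tensor handed to dispatch
    # is its transpose(1, 2) view, matching the layout comment in buffer.cc.
    out_x_scales = None
    if use_fp8:
        num_scales = hidden // 128  # assumed per-128-channel FP8 scale granularity
        out_x_scales = torch.empty(
            (num_local_experts, num_scales, max_recv), dtype=torch.float32, device="cuda"
        ).transpose(1, 2)

    # Reuse the same buffers on every call instead of letting dispatch allocate them.
    result = buf.low_latency_dispatch(
        x, topk_idx, num_max_dispatch_tokens_per_rank, num_experts,
        use_fp8, False, False,  # use_fp8, async, return_recv_hook
        out_packed_recv_x=out_x,
        out_packed_recv_x_scales=out_x_scales,
        out_packed_recv_src_info=out_src_info,
        out_packed_recv_layout_range=out_layout_range,
        out_packed_recv_count=out_count,
    )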