diff --git a/python/mscclpp/ext/ep/buffer.py b/python/mscclpp/ext/ep/buffer.py
index c747f750..1c2f9604 100644
--- a/python/mscclpp/ext/ep/buffer.py
+++ b/python/mscclpp/ext/ep/buffer.py
@@ -177,7 +177,9 @@ class Buffer:
     def get_next_low_latency_combine_buffer(self, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int):
         return self.runtime.get_next_low_latency_combine_buffer(num_max_dispatch_tokens_per_rank, hidden, num_experts)
 
-    def get_local_buffer_tensor(self, dtype: torch.dtype, offset: int = 0, use_rdma_buffer: bool = False) -> torch.Tensor:
+    def get_local_buffer_tensor(
+        self, dtype: torch.dtype, offset: int = 0, use_rdma_buffer: bool = False
+    ) -> torch.Tensor:
         return self.runtime.get_local_buffer_tensor(dtype, offset, use_rdma_buffer)
 
     # ------------------------------------------------------------------
diff --git a/src/ext/ep/bindings.cpp b/src/ext/ep/bindings.cpp
index e4f1d9fd..bf3143a5 100644
--- a/src/ext/ep/bindings.cpp
+++ b/src/ext/ep/bindings.cpp
@@ -22,9 +22,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.doc() = "MSCCL++ Expert-Parallel (MoE dispatch/combine) extension";
 
   py::class_<mscclpp::ep::Config>(m, "Config")
-      .def(py::init<int, int, int, int, int>(), py::arg("num_sms") = 20,
-           py::arg("num_max_nvl_chunked_send_tokens") = 6, py::arg("num_max_nvl_chunked_recv_tokens") = 256,
-           py::arg("num_max_rdma_chunked_send_tokens") = 6, py::arg("num_max_rdma_chunked_recv_tokens") = 256)
+      .def(py::init<int, int, int, int, int>(), py::arg("num_sms") = 20, py::arg("num_max_nvl_chunked_send_tokens") = 6,
+           py::arg("num_max_nvl_chunked_recv_tokens") = 256, py::arg("num_max_rdma_chunked_send_tokens") = 6,
+           py::arg("num_max_rdma_chunked_recv_tokens") = 256)
       .def("get_nvl_buffer_size_hint", &mscclpp::ep::Config::get_nvl_buffer_size_hint)
       .def("get_rdma_buffer_size_hint", &mscclpp::ep::Config::get_rdma_buffer_size_hint);
 
@@ -72,15 +72,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
       .def("internode_dispatch", &mscclpp::ep::Buffer::internode_dispatch)
       .def("internode_combine", &mscclpp::ep::Buffer::internode_combine)
       .def("clean_low_latency_buffer", &mscclpp::ep::Buffer::clean_low_latency_buffer)
-      .def("low_latency_dispatch", &mscclpp::ep::Buffer::low_latency_dispatch,
-           py::arg("x"), py::arg("topk_idx"),
-           py::arg("num_max_dispatch_tokens_per_rank"), py::arg("num_experts"),
-           py::arg("use_fp8"), py::arg("async"), py::arg("return_recv_hook"),
-           py::arg("out_packed_recv_x") = py::none(),
-           py::arg("out_packed_recv_x_scales") = py::none(),
-           py::arg("out_packed_recv_src_info") = py::none(),
-           py::arg("out_packed_recv_layout_range") = py::none(),
-           py::arg("out_packed_recv_count") = py::none())
+      .def("low_latency_dispatch", &mscclpp::ep::Buffer::low_latency_dispatch, py::arg("x"), py::arg("topk_idx"),
+           py::arg("num_max_dispatch_tokens_per_rank"), py::arg("num_experts"), py::arg("use_fp8"), py::arg("async"),
+           py::arg("return_recv_hook"), py::arg("out_packed_recv_x") = py::none(),
+           py::arg("out_packed_recv_x_scales") = py::none(), py::arg("out_packed_recv_src_info") = py::none(),
+           py::arg("out_packed_recv_layout_range") = py::none(), py::arg("out_packed_recv_count") = py::none())
       .def("low_latency_combine", &mscclpp::ep::Buffer::low_latency_combine)
       .def("get_next_low_latency_combine_buffer", &mscclpp::ep::Buffer::get_next_low_latency_combine_buffer);
 }
diff --git a/src/ext/ep/buffer.cc b/src/ext/ep/buffer.cc
index 1fd95169..35702f8c 100644
--- a/src/ext/ep/buffer.cc
+++ b/src/ext/ep/buffer.cc
@@ -1,19 +1,22 @@
+#include "buffer.hpp"
+
 #include
 #include
-#include
-#include
 #include
-#include
-#include
 #include
 #include
+
+#include
+#include
+#include
+#include
 
 #include
-#include "buffer.hpp"
 #include "kernels/api.cuh"
 #include "kernels/configs.cuh"
 
-namespace mscclpp { namespace ep {
+namespace mscclpp {
+namespace ep {
 
 // Upstream MSCCL++ now exposes `Connection::atomicAdd` and
 // `PortChannelDeviceHandle::atomicAdd` natively (see commit "atomic add"
@@ -42,880 +45,881 @@ using EPProxyService = mscclpp::ProxyService;
 // N=8: D 445 us / C 469 us  <-- knee
 // N=12: collapses (CPU oversubscription with 8 GPUs/node).
 static int resolve_num_proxy_services() {
-    if (const char* env = std::getenv("MSCCLPP_EP_NUM_PROXIES")) {
-        int v = std::atoi(env);
-        return v > 0 ? v : 1;
-    }
-    int dev = 0;
-    if (cudaGetDevice(&dev) != cudaSuccess) return 8;
-    cudaDeviceProp prop{};
-    if (cudaGetDeviceProperties(&prop, dev) != cudaSuccess) return 8;
-    // sm_100+ = Blackwell (GB200 etc.) -- NVSwitch fabric, host proxy not the
-    // bottleneck; sm_90 (Hopper) and earlier benefit from sharding for IB.
-    if (prop.major >= 10) return 1;
-    return 8;
+  if (const char* env = std::getenv("MSCCLPP_EP_NUM_PROXIES")) {
+    int v = std::atoi(env);
+    return v > 0 ? v : 1;
+  }
+  int dev = 0;
+  if (cudaGetDevice(&dev) != cudaSuccess) return 8;
+  cudaDeviceProp prop{};
+  if (cudaGetDeviceProperties(&prop, dev) != cudaSuccess) return 8;
+  // sm_100+ = Blackwell (GB200 etc.) -- NVSwitch fabric, host proxy not the
+  // bottleneck; sm_90 (Hopper) and earlier benefit from sharding for IB.
+  if (prop.major >= 10) return 1;
+  return 8;
 }
 
-Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode):
-        rank(rank), num_ranks(num_ranks),
-        num_nvl_bytes(num_nvl_bytes), num_rdma_bytes(num_rdma_bytes),
-        low_latency_mode(low_latency_mode),
-        comm_stream(at::cuda::getStreamFromPool(true)),
-        bootstrap(std::make_shared(rank, num_ranks)) {
-    num_proxy_services = resolve_num_proxy_services();
-    proxy_services.reserve(num_proxy_services);
-    for (int i = 0; i < num_proxy_services; ++i) {
-        proxy_services.emplace_back(std::make_shared<EPProxyService>());
-    }
-    if (rank == 0) {
-        printf("[mscclpp_ep] num_proxy_services=%d (set MSCCLPP_EP_NUM_PROXIES to override)\n",
-               num_proxy_services);
-        fflush(stdout);
-    }
-    // Task fifo memory
-    int64_t fifo_bytes = sizeof(int) * NUM_MAX_FIFO_SLOTS;
-    int64_t buffer_ptr_bytes = sizeof(void*) * NUM_MAX_NVL_PEERS;
-    int64_t task_ptr_bytes = sizeof(int*) * NUM_MAX_NVL_PEERS;
+Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode)
+    : rank(rank),
+      num_ranks(num_ranks),
+      num_nvl_bytes(num_nvl_bytes),
+      num_rdma_bytes(num_rdma_bytes),
+      low_latency_mode(low_latency_mode),
+      comm_stream(at::cuda::getStreamFromPool(true)),
+      bootstrap(std::make_shared(rank, num_ranks)) {
+  num_proxy_services = resolve_num_proxy_services();
+  proxy_services.reserve(num_proxy_services);
+  for (int i = 0; i < num_proxy_services; ++i) {
+    proxy_services.emplace_back(std::make_shared<EPProxyService>());
+  }
+  if (rank == 0) {
+    printf("[mscclpp_ep] num_proxy_services=%d (set MSCCLPP_EP_NUM_PROXIES to override)\n", num_proxy_services);
+    fflush(stdout);
+  }
+  // Task fifo memory
+  int64_t fifo_bytes = sizeof(int) * NUM_MAX_FIFO_SLOTS;
+  int64_t buffer_ptr_bytes = sizeof(void*) * NUM_MAX_NVL_PEERS;
+  int64_t task_ptr_bytes = sizeof(int*) * NUM_MAX_NVL_PEERS;
 
-    // Common checks
-    EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and (num_nvl_bytes <= std::numeric_limits::max() or num_rdma_bytes == 0));
-    EP_HOST_ASSERT(num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and
(low_latency_mode or num_rdma_bytes <= std::numeric_limits::max())); - EP_HOST_ASSERT(0 <= rank and rank < num_ranks and (num_ranks <= NUM_MAX_NVL_PEERS * NUM_MAX_RDMA_PEERS or low_latency_mode)); - EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0); - if (num_rdma_bytes > 0) - EP_HOST_ASSERT(num_ranks > NUM_MAX_NVL_PEERS or low_latency_mode); + // Common checks + EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and + (num_nvl_bytes <= std::numeric_limits::max() or num_rdma_bytes == 0)); + EP_HOST_ASSERT(num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and + (low_latency_mode or num_rdma_bytes <= std::numeric_limits::max())); + EP_HOST_ASSERT(0 <= rank and rank < num_ranks and + (num_ranks <= NUM_MAX_NVL_PEERS * NUM_MAX_RDMA_PEERS or low_latency_mode)); + EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0); + if (num_rdma_bytes > 0) EP_HOST_ASSERT(num_ranks > NUM_MAX_NVL_PEERS or low_latency_mode); - // Get ranks - CUDA_CHECK(cudaGetDevice(&device_id)); - rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; - num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS), num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS); + // Get ranks + CUDA_CHECK(cudaGetDevice(&device_id)); + rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; + num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS), num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS); - // Get device info - cudaDeviceProp device_prop = {}; - CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id)); + // Get device info + cudaDeviceProp device_prop = {}; + CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id)); - if (num_nvl_bytes > 0) { - // Local IPC: alloc local memory and set local IPC handle - CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + fifo_bytes + buffer_ptr_bytes + task_ptr_bytes)); - CUDA_CHECK(cudaIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank])); - buffer_ptrs_gpu = reinterpret_cast(reinterpret_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + fifo_bytes); + if (num_nvl_bytes > 0) { + // Local IPC: alloc local memory and set local IPC handle + CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + fifo_bytes + buffer_ptr_bytes + task_ptr_bytes)); + CUDA_CHECK(cudaIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank])); + buffer_ptrs_gpu = + reinterpret_cast(reinterpret_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + fifo_bytes); - // Set task fifo - EP_HOST_ASSERT(NUM_MAX_FIFO_SLOTS % num_nvl_ranks == 0); - task_fifo_ptrs[nvl_rank] = reinterpret_cast(reinterpret_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes); - task_fifo_ptrs_gpu = reinterpret_cast(reinterpret_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + fifo_bytes + buffer_ptr_bytes); + // Set task fifo + EP_HOST_ASSERT(NUM_MAX_FIFO_SLOTS % num_nvl_ranks == 0); + task_fifo_ptrs[nvl_rank] = + reinterpret_cast(reinterpret_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes); + task_fifo_ptrs_gpu = reinterpret_cast(reinterpret_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + + fifo_bytes + buffer_ptr_bytes); - // No need to synchronize, will do a full device sync during `sync` - CUDA_CHECK(cudaMemsetAsync(task_fifo_ptrs[nvl_rank], 0, fifo_bytes, comm_stream)); - } + // No need to synchronize, will do a full device sync during `sync` + CUDA_CHECK(cudaMemsetAsync(task_fifo_ptrs[nvl_rank], 0, fifo_bytes, comm_stream)); + } - // Create 32 MiB workspace - CUDA_CHECK(cudaMalloc(&workspace, NUM_WORKSPACE_BYTES)); - 
CUDA_CHECK(cudaMemsetAsync(workspace, 0, NUM_WORKSPACE_BYTES, comm_stream)); + // Create 32 MiB workspace + CUDA_CHECK(cudaMalloc(&workspace, NUM_WORKSPACE_BYTES)); + CUDA_CHECK(cudaMemsetAsync(workspace, 0, NUM_WORKSPACE_BYTES, comm_stream)); - // MoE counter - CUDA_CHECK(cudaMallocHost(&moe_recv_counter, sizeof(int64_t), cudaHostAllocMapped)); - CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_counter_mapped, const_cast(moe_recv_counter), 0)); - *moe_recv_counter = -1; + // MoE counter + CUDA_CHECK(cudaMallocHost(&moe_recv_counter, sizeof(int64_t), cudaHostAllocMapped)); + CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_counter_mapped, const_cast(moe_recv_counter), 0)); + *moe_recv_counter = -1; - // MoE expert-level counter - CUDA_CHECK(cudaMallocHost(&moe_recv_expert_counter, sizeof(int) * NUM_MAX_LOCAL_EXPERTS, cudaHostAllocMapped)); - CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_expert_counter_mapped, const_cast(moe_recv_expert_counter), 0)); - for (int i = 0; i < NUM_MAX_LOCAL_EXPERTS; ++ i) - moe_recv_expert_counter[i] = -1; + // MoE expert-level counter + CUDA_CHECK(cudaMallocHost(&moe_recv_expert_counter, sizeof(int) * NUM_MAX_LOCAL_EXPERTS, cudaHostAllocMapped)); + CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_expert_counter_mapped, const_cast(moe_recv_expert_counter), 0)); + for (int i = 0; i < NUM_MAX_LOCAL_EXPERTS; ++i) moe_recv_expert_counter[i] = -1; - // MoE RDMA-level counter - if (num_rdma_ranks > 0) { - CUDA_CHECK(cudaMallocHost(&moe_recv_rdma_counter, sizeof(int), cudaHostAllocMapped)); - CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_rdma_counter_mapped, const_cast(moe_recv_rdma_counter), 0)); - *moe_recv_rdma_counter = -1; - } + // MoE RDMA-level counter + if (num_rdma_ranks > 0) { + CUDA_CHECK(cudaMallocHost(&moe_recv_rdma_counter, sizeof(int), cudaHostAllocMapped)); + CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_rdma_counter_mapped, const_cast(moe_recv_rdma_counter), 0)); + *moe_recv_rdma_counter = -1; + } - for (auto& ps : proxy_services) ps->startProxy(); + for (auto& ps : proxy_services) ps->startProxy(); } Buffer::~Buffer() noexcept(false) { - // Synchronize + // Synchronize + CUDA_CHECK(cudaDeviceSynchronize()); + + if (num_nvl_bytes > 0) { + // Barrier + intranode::barrier(task_fifo_ptrs_gpu, head, nvl_rank, num_nvl_ranks, comm_stream); + move_fifo_slots(); CUDA_CHECK(cudaDeviceSynchronize()); - if (num_nvl_bytes > 0) { - // Barrier - intranode::barrier(task_fifo_ptrs_gpu, head, nvl_rank, num_nvl_ranks, comm_stream); - move_fifo_slots(); - CUDA_CHECK(cudaDeviceSynchronize()); - - // Close remote IPC - if (is_available()) { - for (int i = 0; i < num_nvl_ranks; ++ i) if (i != nvl_rank) - CUDA_CHECK(cudaIpcCloseMemHandle(buffer_ptrs[i])); - } - - // Free local buffer and error flag - CUDA_CHECK(cudaFree(buffer_ptrs[nvl_rank])); + // Close remote IPC + if (is_available()) { + for (int i = 0; i < num_nvl_ranks; ++i) + if (i != nvl_rank) CUDA_CHECK(cudaIpcCloseMemHandle(buffer_ptrs[i])); } - // Free NVSHMEM - if (num_rdma_bytes > 0) { - // NVSHMEM support is not yet ported; if we got here with - // num_rdma_bytes > 0 the construction or sync call would already have - // failed, so there is nothing to tear down. + // Free local buffer and error flag + CUDA_CHECK(cudaFree(buffer_ptrs[nvl_rank])); + } + + // Free NVSHMEM + if (num_rdma_bytes > 0) { + // NVSHMEM support is not yet ported; if we got here with + // num_rdma_bytes > 0 the construction or sync call would already have + // failed, so there is nothing to tear down. 
+ } + + // Intra-node LL IPC fast-path teardown. + if (ll_ipc_ready) { + for (int i = 0; i < num_ranks; ++i) { + if (i == rank or peer_rdma_bases[i] == nullptr) continue; + CUDA_CHECK(cudaIpcCloseMemHandle(peer_rdma_bases[i])); } - - // Intra-node LL IPC fast-path teardown. - if (ll_ipc_ready) { - for (int i = 0; i < num_ranks; ++i) { - if (i == rank or peer_rdma_bases[i] == nullptr) continue; - CUDA_CHECK(cudaIpcCloseMemHandle(peer_rdma_bases[i])); - } - if (peer_rdma_bases_gpu != nullptr) { - CUDA_CHECK(cudaFree(peer_rdma_bases_gpu)); - peer_rdma_bases_gpu = nullptr; - } + if (peer_rdma_bases_gpu != nullptr) { + CUDA_CHECK(cudaFree(peer_rdma_bases_gpu)); + peer_rdma_bases_gpu = nullptr; } + } - for (auto& ps : proxy_services) ps->stopProxy(); + for (auto& ps : proxy_services) ps->stopProxy(); - // Free cuBLAS handle, workspace and MoE counter - CUDA_CHECK(cudaFree(workspace)); - CUDA_CHECK(cudaFreeHost(const_cast(moe_recv_counter))); + // Free cuBLAS handle, workspace and MoE counter + CUDA_CHECK(cudaFree(workspace)); + CUDA_CHECK(cudaFreeHost(const_cast(moe_recv_counter))); - // Free chunked mode staffs - CUDA_CHECK(cudaFreeHost(const_cast(moe_recv_expert_counter))); + // Free chunked mode staffs + CUDA_CHECK(cudaFreeHost(const_cast(moe_recv_expert_counter))); } -void Buffer::move_fifo_slots(int num_slots) { - head = (head + num_ranks * num_slots) % NUM_MAX_FIFO_SLOTS; -} +void Buffer::move_fifo_slots(int num_slots) { head = (head + num_ranks * num_slots) % NUM_MAX_FIFO_SLOTS; } -bool Buffer::is_available() const { - return available; -} +bool Buffer::is_available() const { return available; } -bool Buffer::is_internode_available() const { - return is_available() and num_ranks > NUM_MAX_NVL_PEERS; -} +bool Buffer::is_internode_available() const { return is_available() and num_ranks > NUM_MAX_NVL_PEERS; } -int Buffer::get_num_rdma_ranks() const { - return num_rdma_ranks; -} +int Buffer::get_num_rdma_ranks() const { return num_rdma_ranks; } -int Buffer::get_rdma_rank() const { - return rdma_rank; -} +int Buffer::get_rdma_rank() const { return rdma_rank; } -int Buffer::get_root_rdma_rank(bool global) const { - return global ? nvl_rank : 0; -} +int Buffer::get_root_rdma_rank(bool global) const { return global ? nvl_rank : 0; } -int Buffer::get_local_device_id() const { - return device_id; -} +int Buffer::get_local_device_id() const { return device_id; } pybind11::bytearray Buffer::get_local_ipc_handle() const { - return {ipc_handles[nvl_rank].reserved, CUDA_IPC_HANDLE_SIZE}; + return {ipc_handles[nvl_rank].reserved, CUDA_IPC_HANDLE_SIZE}; } pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const { - // The MSCCL++ EP port replaces NVSHMEM with PortChannel/MemoryChannel, - // so there is no NVSHMEM unique id to expose. Kept for ABI parity with - // DeepEP's Python frontend; callers should use the MSCCL++ bootstrap. - throw std::runtime_error("mscclpp::ep::Buffer::get_local_nvshmem_unique_id: not applicable (NVSHMEM is not used in mscclpp_ep)"); + // The MSCCL++ EP port replaces NVSHMEM with PortChannel/MemoryChannel, + // so there is no NVSHMEM unique id to expose. Kept for ABI parity with + // DeepEP's Python frontend; callers should use the MSCCL++ bootstrap. 
+ throw std::runtime_error( + "mscclpp::ep::Buffer::get_local_nvshmem_unique_id: not applicable (NVSHMEM is not used in mscclpp_ep)"); } -torch::Tensor Buffer::get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const { - torch::ScalarType casted_dtype = torch::python::detail::py_object_to_dtype(dtype); - auto element_bytes = static_cast(elementSize(casted_dtype)); - auto base_ptr = reinterpret_cast(use_rdma_buffer ? rdma_buffer_ptr : buffer_ptrs[nvl_rank]) + offset; - auto num_bytes = use_rdma_buffer ? num_rdma_bytes : num_nvl_bytes; - return torch::from_blob(base_ptr, num_bytes / element_bytes, torch::TensorOptions().dtype(casted_dtype).device(at::kCUDA)); +torch::Tensor Buffer::get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, + bool use_rdma_buffer) const { + torch::ScalarType casted_dtype = torch::python::detail::py_object_to_dtype(dtype); + auto element_bytes = static_cast(elementSize(casted_dtype)); + auto base_ptr = reinterpret_cast(use_rdma_buffer ? rdma_buffer_ptr : buffer_ptrs[nvl_rank]) + offset; + auto num_bytes = use_rdma_buffer ? num_rdma_bytes : num_nvl_bytes; + return torch::from_blob(base_ptr, num_bytes / element_bytes, + torch::TensorOptions().dtype(casted_dtype).device(at::kCUDA)); } -mscclpp::UniqueId Buffer::create_unique_id() const { - return bootstrap->createUniqueId(); -} +mscclpp::UniqueId Buffer::create_unique_id() const { return bootstrap->createUniqueId(); } void Buffer::connect(mscclpp::UniqueId root_id) { - bootstrap->initialize(root_id); - communicator = std::make_shared(bootstrap); + bootstrap->initialize(root_id); + communicator = std::make_shared(bootstrap); } -void Buffer::sync(const std::vector &device_ids, - const std::vector> &all_gathered_handles, +void Buffer::sync(const std::vector& device_ids, + const std::vector>& all_gathered_handles, const std::optional& root_unique_id_opt) { - EP_HOST_ASSERT(not is_available()); + EP_HOST_ASSERT(not is_available()); - const std::vector ib_transports = {mscclpp::Transport::IB0, mscclpp::Transport::IB1, - mscclpp::Transport::IB2, mscclpp::Transport::IB3, mscclpp::Transport::IB4, - mscclpp::Transport::IB5, mscclpp::Transport::IB6, mscclpp::Transport::IB7}; - const auto ipc_transport = mscclpp::Transport::CudaIpc; - const auto ib_transport = ib_transports[device_id]; - const mscclpp::TransportFlags all_transport = ipc_transport | ib_transport; + const std::vector ib_transports = { + mscclpp::Transport::IB0, mscclpp::Transport::IB1, mscclpp::Transport::IB2, mscclpp::Transport::IB3, + mscclpp::Transport::IB4, mscclpp::Transport::IB5, mscclpp::Transport::IB6, mscclpp::Transport::IB7}; + const auto ipc_transport = mscclpp::Transport::CudaIpc; + const auto ib_transport = ib_transports[device_id]; + const mscclpp::TransportFlags all_transport = ipc_transport | ib_transport; - // Sync IPC handles - if (num_nvl_bytes > 0) { - EP_HOST_ASSERT(num_ranks == device_ids.size()); - EP_HOST_ASSERT(device_ids.size() == all_gathered_handles.size()); - for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks; ++ i) { - EP_HOST_ASSERT(all_gathered_handles[offset + i].has_value()); - auto handle_str = std::string(all_gathered_handles[offset + i].value()); - EP_HOST_ASSERT(handle_str.size() == CUDA_IPC_HANDLE_SIZE); - if (offset + i != rank) { - std::memcpy(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE); - CUDA_CHECK(cudaIpcOpenMemHandle(&buffer_ptrs[i], ipc_handles[i], cudaIpcMemLazyEnablePeerAccess)); - task_fifo_ptrs[i] = 
reinterpret_cast(reinterpret_cast(buffer_ptrs[i]) + num_nvl_bytes); - } else { - EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE) == 0); - } - } - - // Copy all buffer and task pointers to GPU - CUDA_CHECK(cudaMemcpy(buffer_ptrs_gpu, buffer_ptrs, sizeof(void*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice)); - CUDA_CHECK(cudaMemcpy(task_fifo_ptrs_gpu, task_fifo_ptrs, sizeof(int*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice)); - CUDA_CHECK(cudaDeviceSynchronize()); - - // create connections - std::vector connections; - { - std::vector> connection_futures; - mscclpp::EndpointConfig local_config(ipc_transport); - for (int i = 0; i < num_nvl_ranks; ++i) { - auto r = i + rdma_rank * num_nvl_ranks; - connection_futures.emplace_back(communicator->connect(local_config, r, 0)); - } - for (auto& future : connection_futures) { - connections.emplace_back(future.get()); - } - } - - auto buffer_mem = communicator->registerMemory(buffer_ptrs[nvl_rank], num_nvl_bytes, ipc_transport); - - std::vector> remote_mem_futures(num_nvl_ranks); - for (int i = 0; i < num_nvl_ranks; ++i) { - if (i == nvl_rank) continue; - auto r = i + rdma_rank * num_nvl_ranks; - communicator->sendMemory(buffer_mem, r, 0); - remote_mem_futures[i] = communicator->recvMemory(r, 0); - } - for (int i = 0; i < num_nvl_ranks; ++i) { - if (i == nvl_rank) continue; - auto sema = std::make_shared(*communicator, connections[i]); - memory_channels.emplace_back(sema, remote_mem_futures[i].get(), buffer_mem); - } - std::vector memory_channel_handles(num_nvl_ranks); - for (int i = 0; i < num_nvl_ranks; ++i) { - if (i == nvl_rank) continue; - memory_channel_handles[i] = memory_channels.rbegin()->deviceHandle(); - } - - memory_channel_handles_device_ptr = mscclpp::detail::gpuCallocShared(num_nvl_ranks); - mscclpp::gpuMemcpy( - memory_channel_handles_device_ptr.get(), memory_channel_handles.data(), num_nvl_ranks, - cudaMemcpyHostToDevice); + // Sync IPC handles + if (num_nvl_bytes > 0) { + EP_HOST_ASSERT(num_ranks == device_ids.size()); + EP_HOST_ASSERT(device_ids.size() == all_gathered_handles.size()); + for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks; ++i) { + EP_HOST_ASSERT(all_gathered_handles[offset + i].has_value()); + auto handle_str = std::string(all_gathered_handles[offset + i].value()); + EP_HOST_ASSERT(handle_str.size() == CUDA_IPC_HANDLE_SIZE); + if (offset + i != rank) { + std::memcpy(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE); + CUDA_CHECK(cudaIpcOpenMemHandle(&buffer_ptrs[i], ipc_handles[i], cudaIpcMemLazyEnablePeerAccess)); + task_fifo_ptrs[i] = reinterpret_cast(reinterpret_cast(buffer_ptrs[i]) + num_nvl_bytes); + } else { + EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE) == 0); + } } - // RDMA buffer setup (replaces DeepEP's NVSHMEM symmetric-heap allocation). 
+ // Copy all buffer and task pointers to GPU + CUDA_CHECK(cudaMemcpy(buffer_ptrs_gpu, buffer_ptrs, sizeof(void*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice)); + CUDA_CHECK( + cudaMemcpy(task_fifo_ptrs_gpu, task_fifo_ptrs, sizeof(int*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaDeviceSynchronize()); + + // create connections + std::vector connections; + { + std::vector> connection_futures; + mscclpp::EndpointConfig local_config(ipc_transport); + for (int i = 0; i < num_nvl_ranks; ++i) { + auto r = i + rdma_rank * num_nvl_ranks; + connection_futures.emplace_back(communicator->connect(local_config, r, 0)); + } + for (auto& future : connection_futures) { + connections.emplace_back(future.get()); + } + } + + auto buffer_mem = communicator->registerMemory(buffer_ptrs[nvl_rank], num_nvl_bytes, ipc_transport); + + std::vector> remote_mem_futures(num_nvl_ranks); + for (int i = 0; i < num_nvl_ranks; ++i) { + if (i == nvl_rank) continue; + auto r = i + rdma_rank * num_nvl_ranks; + communicator->sendMemory(buffer_mem, r, 0); + remote_mem_futures[i] = communicator->recvMemory(r, 0); + } + for (int i = 0; i < num_nvl_ranks; ++i) { + if (i == nvl_rank) continue; + auto sema = std::make_shared(*communicator, connections[i]); + memory_channels.emplace_back(sema, remote_mem_futures[i].get(), buffer_mem); + } + std::vector memory_channel_handles(num_nvl_ranks); + for (int i = 0; i < num_nvl_ranks; ++i) { + if (i == nvl_rank) continue; + memory_channel_handles[i] = memory_channels.rbegin()->deviceHandle(); + } + + memory_channel_handles_device_ptr = + mscclpp::detail::gpuCallocShared(num_nvl_ranks); + mscclpp::gpuMemcpy( + memory_channel_handles_device_ptr.get(), memory_channel_handles.data(), num_nvl_ranks, cudaMemcpyHostToDevice); + } + + // RDMA buffer setup (replaces DeepEP's NVSHMEM symmetric-heap allocation). + // + // Unlike DeepEP which used `nvshmem_align` to place the RDMA buffer on the + // symmetric heap, all our internode communication goes through MSCCL++ + // `PortChannel` (proxy-based RDMA), so a plain `cudaMalloc` + IB memory + // registration is sufficient. The bootstrap barrier replaces + // `nvshmem_barrier_all`. + if (num_rdma_bytes > 0) { + EP_HOST_ASSERT(communicator != nullptr); + EP_HOST_ASSERT(bootstrap != nullptr); + + // Allocate the RDMA buffer + CUDA_CHECK(cudaMalloc(&rdma_buffer_ptr, num_rdma_bytes)); + CUDA_CHECK(cudaMemset(rdma_buffer_ptr, 0, num_rdma_bytes)); + bootstrap->barrier(); + CUDA_CHECK(cudaDeviceSynchronize()); + + // Rank -> RDMA buffer IDs. MemoryIds are local to each ProxyService; + // we register every memory in every proxy in the same global order so + // a single int identifies the memory across all of them. + std::map memory_ids; + + auto add_memory_to_all = [&](mscclpp::RegisteredMemory mem) -> mscclpp::MemoryId { + mscclpp::MemoryId id = static_cast(-1); + for (auto& ps : proxy_services) { + auto cur = ps->addMemory(mem); + if (id == static_cast(-1)) id = cur; + EP_HOST_ASSERT(cur == id && "MemoryIds drifted across proxy services"); + } + return id; + }; + + // Register local memory + auto local_rdma_buffer_mem = communicator->registerMemory(rdma_buffer_ptr, num_rdma_bytes, all_transport); + memory_ids[rank] = add_memory_to_all(local_rdma_buffer_mem); + + // Send local memory to other ranks. 
// - // Unlike DeepEP which used `nvshmem_align` to place the RDMA buffer on the - // symmetric heap, all our internode communication goes through MSCCL++ - // `PortChannel` (proxy-based RDMA), so a plain `cudaMalloc` + IB memory - // registration is sufficient. The bootstrap barrier replaces - // `nvshmem_barrier_all`. - if (num_rdma_bytes > 0) { - EP_HOST_ASSERT(communicator != nullptr); - EP_HOST_ASSERT(bootstrap != nullptr); - - // Allocate the RDMA buffer - CUDA_CHECK(cudaMalloc(&rdma_buffer_ptr, num_rdma_bytes)); - CUDA_CHECK(cudaMemset(rdma_buffer_ptr, 0, num_rdma_bytes)); - bootstrap->barrier(); - CUDA_CHECK(cudaDeviceSynchronize()); - - // Rank -> RDMA buffer IDs. MemoryIds are local to each ProxyService; - // we register every memory in every proxy in the same global order so - // a single int identifies the memory across all of them. - std::map memory_ids; - - auto add_memory_to_all = [&](mscclpp::RegisteredMemory mem) -> mscclpp::MemoryId { - mscclpp::MemoryId id = static_cast(-1); - for (auto& ps : proxy_services) { - auto cur = ps->addMemory(mem); - if (id == static_cast(-1)) id = cur; - EP_HOST_ASSERT(cur == id && "MemoryIds drifted across proxy services"); - } - return id; - }; - - // Register local memory - auto local_rdma_buffer_mem = communicator->registerMemory(rdma_buffer_ptr, num_rdma_bytes, all_transport); - memory_ids[rank] = add_memory_to_all(local_rdma_buffer_mem); - - // Send local memory to other ranks. - // - // NOTE: DeepEP filters this to same-GPU-ID peers in low_latency_mode - // because LL there uses NVSHMEM, not port channels. This port drives - // LL kernels through PortChannel, so every peer must have a real - // memory/connection/semaphore/port channel entry. Treat LL and HT - // sync identically: always connect all peers. - // - // Caveat: for a pure intra-node LL launch (``num_nvl_bytes == 0`` with - // every peer on the same host) the resulting port channels go through - // the CPU proxy over IB loopback between different HCAs, which on - // this platform does not deliver atomics reliably and currently - // deadlocks LL dispatch. See `src/ext/ep/README.md` for the full - // discussion. Cross-node LL (DeepEP's recommended 1-GPU-per-node - // topology) is unaffected. - // Use tag=1 to disambiguate from the NVL phase's tag=0 traffic with same-node peers. - constexpr int kRdmaTag = 1; - for (int r = 0; r < num_ranks; ++r) { - if (r == rank) continue; - communicator->sendMemory(local_rdma_buffer_mem, r, kRdmaTag); - } - - // Receive remote memory from other ranks. - for (int r = 0; r < num_ranks; ++r) { - if (r == rank) continue; - auto f = communicator->recvMemory(r, kRdmaTag); - auto mem = f.get(); - memory_ids[r] = add_memory_to_all(std::move(mem)); - } - - // Rank -> vector of connections - std::unordered_map> connections; - const mscclpp::EndpointConfig ipc_cfg(ipc_transport); - const mscclpp::EndpointConfig ib_cfg(ib_transport); - - // Self connection for local memory (CUDA IPC). - connections[rank].emplace_back(communicator->connect(ipc_cfg, rank, kRdmaTag).get()); - - // Remote IB connections (multi-QP per peer). - const int num_ib_connections_per_rank = 12; // #QPs per rank (mirrors DeepEP). 
- for (int r = 0; r < num_ranks; ++r) { - if (r == rank) continue; - std::vector> futures; - futures.reserve(num_ib_connections_per_rank); - for (int i = 0; i < num_ib_connections_per_rank; ++i) { - futures.emplace_back(communicator->connect(ib_cfg, r, kRdmaTag)); - } - for (auto& f : futures) connections[r].emplace_back(f.get()); - } - - // Rank -> vector of (proxy_idx, semaphore_id_within_proxy). Iterate - // peers in sorted rank order so semaphore pairings between nodes line - // up deterministically. Channels — and therefore their backing - // semaphores — are sharded across `proxy_services`: channel at flat - // index `i*num_ranks + r` lives on proxy `(i*num_ranks + r) % - // num_proxy_services`. SemaphoreIds are local to each proxy, so we - // record (proxy_idx, sid) pairs. - std::unordered_map>> sema_ids; - const int num_semaphores_per_rank = 16; - for (int i = 0; i < num_semaphores_per_rank; ++i) { - for (int r = 0; r < num_ranks; ++r) { - auto conn_it = connections.find(r); - EP_HOST_ASSERT(conn_it != connections.end()); - auto& conns = conn_it->second; - auto& conn = conns[i % conns.size()]; - int proxy_idx = (i * num_ranks + r) % num_proxy_services; - auto sema_id = proxy_services[proxy_idx]->buildAndAddSemaphore(*communicator, conn); - sema_ids[r].emplace_back(proxy_idx, sema_id); - } - } - - // Create port channels + device handles. - // - // The kernels index `port_channel_handles[channel_id * num_ranks + peer_rank]` - // where peer_rank is a GLOBAL rank in [0..num_ranks). So the outer stride must - // be num_ranks with peers in ascending rank order. Iterating `memory_ids` (an - // `unordered_map`) yields hash order and would misroute signals, deadlocking. - // Each channel inherits the proxy of the semaphore it was built on, so the - // resulting `PortChannelDeviceHandle` routes its FIFO pushes to the correct - // proxy thread. - const int num_port_channels_per_rank = num_semaphores_per_rank; - std::vector port_channel_handles; - for (int i = 0; i < num_port_channels_per_rank; ++i) { - for (int r = 0; r < num_ranks; ++r) { - auto mem_it = memory_ids.find(r); - EP_HOST_ASSERT(mem_it != memory_ids.end()); - auto memory_id = mem_it->second; - auto [proxy_idx, sema_id] = sema_ids[r][i % sema_ids[r].size()]; - auto port_channel = proxy_services[proxy_idx]->portChannel(sema_id, memory_id, memory_ids[rank]); - port_channels.emplace_back(std::move(port_channel)); - port_channel_handles.emplace_back(port_channels.rbegin()->deviceHandle()); - } - } - - port_channel_handles_device_ptr = mscclpp::detail::gpuCallocShared( - port_channel_handles.size()); - mscclpp::gpuMemcpy( - port_channel_handles_device_ptr.get(), port_channel_handles.data(), port_channel_handles.size(), - cudaMemcpyHostToDevice); - - // ------------------------------------------------------------------ - // Intra-node LL fast path setup. - // - // When all ranks sit on the same host (num_rdma_ranks == 1), LL dispatch - // and combine still go through `PortChannel` above — which internally - // uses the proxy service over IB loopback between different HCAs on - // this platform. That path is correct but slow (caps at ~170 GB/s vs. - // NVLink's multi-TB/s). We additionally set up CUDA-IPC peer pointers - // to each peer's `rdma_buffer_ptr` plus a set of per-peer MemoryChannels - // for a barrier ring. The LL kernels select this path at launch time. - // Cross-node LL is unaffected: this block is a no-op there. 
- // ------------------------------------------------------------------ - if (low_latency_mode and num_rdma_ranks == 1) { - EP_HOST_ASSERT(num_ranks == num_nvl_ranks); - EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS); - - // 1. Exchange CUDA IPC handles for rdma_buffer_ptr via bootstrap. - CUDA_CHECK(cudaIpcGetMemHandle(&rdma_ipc_handles[rank], rdma_buffer_ptr)); - std::vector all_rdma_handles(num_ranks); - all_rdma_handles[rank] = rdma_ipc_handles[rank]; - bootstrap->allGather(all_rdma_handles.data(), sizeof(cudaIpcMemHandle_t)); - - peer_rdma_bases[rank] = rdma_buffer_ptr; - for (int r = 0; r < num_ranks; ++r) { - if (r == rank) continue; - rdma_ipc_handles[r] = all_rdma_handles[r]; - CUDA_CHECK(cudaIpcOpenMemHandle(&peer_rdma_bases[r], rdma_ipc_handles[r], - cudaIpcMemLazyEnablePeerAccess)); - } - CUDA_CHECK(cudaMalloc(&peer_rdma_bases_gpu, sizeof(void*) * NUM_MAX_NVL_PEERS)); - CUDA_CHECK(cudaMemcpy(peer_rdma_bases_gpu, peer_rdma_bases, - sizeof(void*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice)); - - // 2. Build MemoryChannels for the per-peer barrier ring. These use - // CUDA IPC connections (distinct tag from the existing port-channel - // machinery) so setup does not interfere with cross-node fallback. - constexpr int kLlIpcTag = 2; - auto rdma_mem_ipc = communicator->registerMemory(rdma_buffer_ptr, num_rdma_bytes, ipc_transport); - std::vector> remote_futures(num_ranks); - for (int r = 0; r < num_ranks; ++r) { - if (r == rank) continue; - communicator->sendMemory(rdma_mem_ipc, r, kLlIpcTag); - remote_futures[r] = communicator->recvMemory(r, kLlIpcTag); - } - std::vector ll_ipc_conns(num_ranks); - { - std::vector> conn_futures(num_ranks); - mscclpp::EndpointConfig cfg(ipc_transport); - for (int r = 0; r < num_ranks; ++r) { - if (r == rank) continue; - conn_futures[r] = communicator->connect(cfg, r, kLlIpcTag); - } - for (int r = 0; r < num_ranks; ++r) { - if (r == rank) continue; - ll_ipc_conns[r] = conn_futures[r].get(); - } - } - - std::vector ll_handles(num_ranks); - for (int r = 0; r < num_ranks; ++r) { - if (r == rank) continue; - auto sema = std::make_shared(*communicator, ll_ipc_conns[r]); - ll_memory_channels.emplace_back(sema, remote_futures[r].get(), rdma_mem_ipc); - ll_handles[r] = ll_memory_channels.rbegin()->deviceHandle(); - } - ll_memory_channel_handles_device_ptr = - mscclpp::detail::gpuCallocShared(num_ranks); - mscclpp::gpuMemcpy( - ll_memory_channel_handles_device_ptr.get(), ll_handles.data(), num_ranks, - cudaMemcpyHostToDevice); - - ll_ipc_ready = true; - } + // NOTE: DeepEP filters this to same-GPU-ID peers in low_latency_mode + // because LL there uses NVSHMEM, not port channels. This port drives + // LL kernels through PortChannel, so every peer must have a real + // memory/connection/semaphore/port channel entry. Treat LL and HT + // sync identically: always connect all peers. + // + // Caveat: for a pure intra-node LL launch (``num_nvl_bytes == 0`` with + // every peer on the same host) the resulting port channels go through + // the CPU proxy over IB loopback between different HCAs, which on + // this platform does not deliver atomics reliably and currently + // deadlocks LL dispatch. See `src/ext/ep/README.md` for the full + // discussion. Cross-node LL (DeepEP's recommended 1-GPU-per-node + // topology) is unaffected. + // Use tag=1 to disambiguate from the NVL phase's tag=0 traffic with same-node peers. 
+ constexpr int kRdmaTag = 1; + for (int r = 0; r < num_ranks; ++r) { + if (r == rank) continue; + communicator->sendMemory(local_rdma_buffer_mem, r, kRdmaTag); } - // Ready to use - available = true; + // Receive remote memory from other ranks. + for (int r = 0; r < num_ranks; ++r) { + if (r == rank) continue; + auto f = communicator->recvMemory(r, kRdmaTag); + auto mem = f.get(); + memory_ids[r] = add_memory_to_all(std::move(mem)); + } + + // Rank -> vector of connections + std::unordered_map> connections; + const mscclpp::EndpointConfig ipc_cfg(ipc_transport); + const mscclpp::EndpointConfig ib_cfg(ib_transport); + + // Self connection for local memory (CUDA IPC). + connections[rank].emplace_back(communicator->connect(ipc_cfg, rank, kRdmaTag).get()); + + // Remote IB connections (multi-QP per peer). + const int num_ib_connections_per_rank = 12; // #QPs per rank (mirrors DeepEP). + for (int r = 0; r < num_ranks; ++r) { + if (r == rank) continue; + std::vector> futures; + futures.reserve(num_ib_connections_per_rank); + for (int i = 0; i < num_ib_connections_per_rank; ++i) { + futures.emplace_back(communicator->connect(ib_cfg, r, kRdmaTag)); + } + for (auto& f : futures) connections[r].emplace_back(f.get()); + } + + // Rank -> vector of (proxy_idx, semaphore_id_within_proxy). Iterate + // peers in sorted rank order so semaphore pairings between nodes line + // up deterministically. Channels — and therefore their backing + // semaphores — are sharded across `proxy_services`: channel at flat + // index `i*num_ranks + r` lives on proxy `(i*num_ranks + r) % + // num_proxy_services`. SemaphoreIds are local to each proxy, so we + // record (proxy_idx, sid) pairs. + std::unordered_map>> sema_ids; + const int num_semaphores_per_rank = 16; + for (int i = 0; i < num_semaphores_per_rank; ++i) { + for (int r = 0; r < num_ranks; ++r) { + auto conn_it = connections.find(r); + EP_HOST_ASSERT(conn_it != connections.end()); + auto& conns = conn_it->second; + auto& conn = conns[i % conns.size()]; + int proxy_idx = (i * num_ranks + r) % num_proxy_services; + auto sema_id = proxy_services[proxy_idx]->buildAndAddSemaphore(*communicator, conn); + sema_ids[r].emplace_back(proxy_idx, sema_id); + } + } + + // Create port channels + device handles. + // + // The kernels index `port_channel_handles[channel_id * num_ranks + peer_rank]` + // where peer_rank is a GLOBAL rank in [0..num_ranks). So the outer stride must + // be num_ranks with peers in ascending rank order. Iterating `memory_ids` (an + // `unordered_map`) yields hash order and would misroute signals, deadlocking. + // Each channel inherits the proxy of the semaphore it was built on, so the + // resulting `PortChannelDeviceHandle` routes its FIFO pushes to the correct + // proxy thread. 
+ const int num_port_channels_per_rank = num_semaphores_per_rank; + std::vector port_channel_handles; + for (int i = 0; i < num_port_channels_per_rank; ++i) { + for (int r = 0; r < num_ranks; ++r) { + auto mem_it = memory_ids.find(r); + EP_HOST_ASSERT(mem_it != memory_ids.end()); + auto memory_id = mem_it->second; + auto [proxy_idx, sema_id] = sema_ids[r][i % sema_ids[r].size()]; + auto port_channel = proxy_services[proxy_idx]->portChannel(sema_id, memory_id, memory_ids[rank]); + port_channels.emplace_back(std::move(port_channel)); + port_channel_handles.emplace_back(port_channels.rbegin()->deviceHandle()); + } + } + + port_channel_handles_device_ptr = + mscclpp::detail::gpuCallocShared(port_channel_handles.size()); + mscclpp::gpuMemcpy(port_channel_handles_device_ptr.get(), + port_channel_handles.data(), port_channel_handles.size(), + cudaMemcpyHostToDevice); + + // ------------------------------------------------------------------ + // Intra-node LL fast path setup. + // + // When all ranks sit on the same host (num_rdma_ranks == 1), LL dispatch + // and combine still go through `PortChannel` above — which internally + // uses the proxy service over IB loopback between different HCAs on + // this platform. That path is correct but slow (caps at ~170 GB/s vs. + // NVLink's multi-TB/s). We additionally set up CUDA-IPC peer pointers + // to each peer's `rdma_buffer_ptr` plus a set of per-peer MemoryChannels + // for a barrier ring. The LL kernels select this path at launch time. + // Cross-node LL is unaffected: this block is a no-op there. + // ------------------------------------------------------------------ + if (low_latency_mode and num_rdma_ranks == 1) { + EP_HOST_ASSERT(num_ranks == num_nvl_ranks); + EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS); + + // 1. Exchange CUDA IPC handles for rdma_buffer_ptr via bootstrap. + CUDA_CHECK(cudaIpcGetMemHandle(&rdma_ipc_handles[rank], rdma_buffer_ptr)); + std::vector all_rdma_handles(num_ranks); + all_rdma_handles[rank] = rdma_ipc_handles[rank]; + bootstrap->allGather(all_rdma_handles.data(), sizeof(cudaIpcMemHandle_t)); + + peer_rdma_bases[rank] = rdma_buffer_ptr; + for (int r = 0; r < num_ranks; ++r) { + if (r == rank) continue; + rdma_ipc_handles[r] = all_rdma_handles[r]; + CUDA_CHECK(cudaIpcOpenMemHandle(&peer_rdma_bases[r], rdma_ipc_handles[r], cudaIpcMemLazyEnablePeerAccess)); + } + CUDA_CHECK(cudaMalloc(&peer_rdma_bases_gpu, sizeof(void*) * NUM_MAX_NVL_PEERS)); + CUDA_CHECK( + cudaMemcpy(peer_rdma_bases_gpu, peer_rdma_bases, sizeof(void*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice)); + + // 2. Build MemoryChannels for the per-peer barrier ring. These use + // CUDA IPC connections (distinct tag from the existing port-channel + // machinery) so setup does not interfere with cross-node fallback. 
+ constexpr int kLlIpcTag = 2; + auto rdma_mem_ipc = communicator->registerMemory(rdma_buffer_ptr, num_rdma_bytes, ipc_transport); + std::vector> remote_futures(num_ranks); + for (int r = 0; r < num_ranks; ++r) { + if (r == rank) continue; + communicator->sendMemory(rdma_mem_ipc, r, kLlIpcTag); + remote_futures[r] = communicator->recvMemory(r, kLlIpcTag); + } + std::vector ll_ipc_conns(num_ranks); + { + std::vector> conn_futures(num_ranks); + mscclpp::EndpointConfig cfg(ipc_transport); + for (int r = 0; r < num_ranks; ++r) { + if (r == rank) continue; + conn_futures[r] = communicator->connect(cfg, r, kLlIpcTag); + } + for (int r = 0; r < num_ranks; ++r) { + if (r == rank) continue; + ll_ipc_conns[r] = conn_futures[r].get(); + } + } + + std::vector ll_handles(num_ranks); + for (int r = 0; r < num_ranks; ++r) { + if (r == rank) continue; + auto sema = std::make_shared(*communicator, ll_ipc_conns[r]); + ll_memory_channels.emplace_back(sema, remote_futures[r].get(), rdma_mem_ipc); + ll_handles[r] = ll_memory_channels.rbegin()->deviceHandle(); + } + ll_memory_channel_handles_device_ptr = + mscclpp::detail::gpuCallocShared(num_ranks); + mscclpp::gpuMemcpy(ll_memory_channel_handles_device_ptr.get(), + ll_handles.data(), num_ranks, cudaMemcpyHostToDevice); + + ll_ipc_ready = true; + } + } + + // Ready to use + available = true; } std::tuple, torch::Tensor, torch::Tensor, std::optional> -Buffer::get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, - std::optional& previous_event, bool async, bool allocate_on_comm_stream) { - EP_HOST_ASSERT(topk_idx.dim() == 2); - EP_HOST_ASSERT(topk_idx.is_contiguous()); +Buffer::get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, std::optional& previous_event, + bool async, bool allocate_on_comm_stream) { + EP_HOST_ASSERT(topk_idx.dim() == 2); + EP_HOST_ASSERT(topk_idx.is_contiguous()); + EP_HOST_ASSERT(num_experts > 0); + + // Allocate all tensors on comm stream if set + // NOTES: do not allocate tensors upfront! + auto compute_stream = at::cuda::getCurrentCUDAStream(); + if (allocate_on_comm_stream) { + EP_HOST_ASSERT(previous_event.has_value() and async); + at::cuda::setCurrentCUDAStream(comm_stream); + } + + // Wait previous tasks to be finished + if (previous_event.has_value()) { + stream_wait(comm_stream, previous_event.value()); + } else { + stream_wait(comm_stream, compute_stream); + } + + auto num_tokens = static_cast(topk_idx.size(0)), num_topk = static_cast(topk_idx.size(1)); + auto num_tokens_per_rank = torch::empty({num_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); + auto num_tokens_per_rdma_rank = std::optional(); + auto num_tokens_per_expert = torch::empty({num_experts}, dtype(torch::kInt32).device(torch::kCUDA)); + auto is_token_in_rank = torch::empty({num_tokens, num_ranks}, dtype(torch::kBool).device(torch::kCUDA)); + if (is_internode_available()) + num_tokens_per_rdma_rank = torch::empty({num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); + + internode::get_dispatch_layout( + topk_idx.data_ptr(), num_tokens_per_rank.data_ptr(), + num_tokens_per_rdma_rank.has_value() ? 
num_tokens_per_rdma_rank.value().data_ptr() : nullptr, + num_tokens_per_expert.data_ptr(), is_token_in_rank.data_ptr(), num_tokens, num_topk, num_ranks, + num_experts, comm_stream); + + // Wait streams + std::optional event; + if (async) { + event = EventHandle(comm_stream); + for (auto& t : {topk_idx, num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank}) { + t.record_stream(comm_stream); + if (allocate_on_comm_stream) t.record_stream(compute_stream); + } + for (auto& to : {num_tokens_per_rdma_rank}) { + to.has_value() ? to->record_stream(comm_stream) : void(); + if (allocate_on_comm_stream) to.has_value() ? to->record_stream(compute_stream) : void(); + } + } else { + stream_wait(compute_stream, comm_stream); + } + + // Switch back compute stream + if (allocate_on_comm_stream) at::cuda::setCurrentCUDAStream(compute_stream); + + return {num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event}; +} + +std::tuple, std::optional, std::optional, + std::vector, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, + std::optional> +Buffer::intranode_dispatch( + const torch::Tensor& x, const std::optional& x_scales, const std::optional& topk_idx, + const std::optional& topk_weights, const std::optional& num_tokens_per_rank, + const torch::Tensor& is_token_in_rank, const std::optional& num_tokens_per_expert, + int cached_num_recv_tokens, const std::optional& cached_rank_prefix_matrix, + const std::optional& cached_channel_prefix_matrix, int expert_alignment, const Config& config, + std::optional& previous_event, bool async, bool allocate_on_comm_stream) { + bool cached_mode = cached_rank_prefix_matrix.has_value(); + + // One channel use two blocks, even-numbered blocks for sending, odd-numbered blocks for receiving. 
+ EP_HOST_ASSERT(config.num_sms % 2 == 0); + int num_channels = config.num_sms / 2; + if (cached_mode) { + EP_HOST_ASSERT(cached_rank_prefix_matrix.has_value()); + EP_HOST_ASSERT(cached_channel_prefix_matrix.has_value()); + } else { + EP_HOST_ASSERT(num_tokens_per_rank.has_value()); + EP_HOST_ASSERT(num_tokens_per_expert.has_value()); + } + + // Type checks + EP_HOST_ASSERT(is_token_in_rank.scalar_type() == torch::kBool); + if (cached_mode) { + EP_HOST_ASSERT(cached_rank_prefix_matrix->scalar_type() == torch::kInt32); + EP_HOST_ASSERT(cached_channel_prefix_matrix->scalar_type() == torch::kInt32); + } else { + EP_HOST_ASSERT(num_tokens_per_expert->scalar_type() == torch::kInt32); + EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == torch::kInt32); + } + + // Shape and contiguous checks + EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); + EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0); + EP_HOST_ASSERT(is_token_in_rank.dim() == 2 and is_token_in_rank.is_contiguous()); + EP_HOST_ASSERT(is_token_in_rank.size(0) == x.size(0) and is_token_in_rank.size(1) == num_ranks); + if (cached_mode) { + EP_HOST_ASSERT(cached_rank_prefix_matrix->dim() == 2 and cached_rank_prefix_matrix->is_contiguous()); + EP_HOST_ASSERT(cached_rank_prefix_matrix->size(0) == num_ranks and cached_rank_prefix_matrix->size(1) == num_ranks); + EP_HOST_ASSERT(cached_channel_prefix_matrix->dim() == 2 and cached_channel_prefix_matrix->is_contiguous()); + EP_HOST_ASSERT(cached_channel_prefix_matrix->size(0) == num_ranks and + cached_channel_prefix_matrix->size(1) == num_channels); + } else { + EP_HOST_ASSERT(num_tokens_per_expert->dim() == 1 and num_tokens_per_expert->is_contiguous()); + EP_HOST_ASSERT(num_tokens_per_expert->size(0) % num_ranks == 0); + EP_HOST_ASSERT(num_tokens_per_expert->size(0) / num_ranks <= NUM_MAX_LOCAL_EXPERTS); + EP_HOST_ASSERT(num_tokens_per_rank->dim() == 1 and num_tokens_per_rank->is_contiguous()); + EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks); + } + + auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)); + auto num_experts = cached_mode ? 0 : static_cast(num_tokens_per_expert->size(0)), + num_local_experts = num_experts / num_ranks; + + // Top-k checks + int num_topk = 0; + int64_t* topk_idx_ptr = nullptr; + float* topk_weights_ptr = nullptr; + EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value()); + if (topk_idx.has_value()) { + num_topk = static_cast(topk_idx->size(1)); EP_HOST_ASSERT(num_experts > 0); + EP_HOST_ASSERT(topk_idx->dim() == 2 and topk_idx->is_contiguous()); + EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous()); + EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0)); + EP_HOST_ASSERT(num_topk == topk_weights->size(1)); + EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32); + topk_idx_ptr = topk_idx->data_ptr(); + topk_weights_ptr = topk_weights->data_ptr(); + } - // Allocate all tensors on comm stream if set - // NOTES: do not allocate tensors upfront! 
- auto compute_stream = at::cuda::getCurrentCUDAStream(); - if (allocate_on_comm_stream) { - EP_HOST_ASSERT(previous_event.has_value() and async); - at::cuda::setCurrentCUDAStream(comm_stream); + // FP8 scales checks + float* x_scales_ptr = nullptr; + int num_scales = 0; + if (x_scales.has_value()) { + EP_HOST_ASSERT(x.element_size() == 1); + EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32); + EP_HOST_ASSERT(x_scales->dim() > 0 and x_scales->dim() < 3 and x_scales->is_contiguous()); + EP_HOST_ASSERT(x_scales->size(0) == num_tokens); + num_scales = x_scales->dim() == 1 ? 1 : static_cast(x_scales->size(1)); + x_scales_ptr = x_scales->data_ptr(); + } + + // Allocate all tensors on comm stream if set + // NOTES: do not allocate tensors upfront! + auto compute_stream = at::cuda::getCurrentCUDAStream(); + if (allocate_on_comm_stream) { + EP_HOST_ASSERT(previous_event.has_value() and async); + at::cuda::setCurrentCUDAStream(comm_stream); + } + + // Wait previous tasks to be finished + if (previous_event.has_value()) { + stream_wait(comm_stream, previous_event.value()); + } else { + stream_wait(comm_stream, compute_stream); + } + + // Create handles (only return for non-cached mode) + int num_recv_tokens = -1; + auto rank_prefix_matrix = torch::Tensor(); + auto channel_prefix_matrix = torch::Tensor(); + std::vector num_recv_tokens_per_expert_list; + + // Barrier or send sizes + // To clean: channel start/end offset, head and tail + int num_memset_int = num_channels * num_ranks * 4; + if (cached_mode) { + num_recv_tokens = cached_num_recv_tokens; + rank_prefix_matrix = cached_rank_prefix_matrix.value(); + channel_prefix_matrix = cached_channel_prefix_matrix.value(); + + // Copy rank prefix matrix and clean flags + intranode::cached_notify_dispatch(rank_prefix_matrix.data_ptr(), num_memset_int, buffer_ptrs_gpu, + task_fifo_ptrs_gpu, head, rank, num_ranks, comm_stream); + move_fifo_slots(2); + } else { + rank_prefix_matrix = torch::empty({num_ranks, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); + channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); + + // Send sizes + // Meta information: + // - Size prefix by ranks, shaped as `[num_ranks, num_ranks]` + // - Size prefix by experts (not used later), shaped as `[num_ranks, num_local_experts]` + // NOTES: no more token dropping in this version + *moe_recv_counter = -1; + for (int i = 0; i < num_local_experts; ++i) moe_recv_expert_counter[i] = -1; + EP_HOST_ASSERT(num_ranks * (num_ranks + num_local_experts) * sizeof(int) <= num_nvl_bytes); + intranode::notify_dispatch(num_tokens_per_rank->data_ptr(), moe_recv_counter_mapped, num_ranks, + num_tokens_per_expert->data_ptr(), moe_recv_expert_counter_mapped, num_experts, + num_tokens, is_token_in_rank.data_ptr(), channel_prefix_matrix.data_ptr(), + rank_prefix_matrix.data_ptr(), num_memset_int, expert_alignment, buffer_ptrs_gpu, + task_fifo_ptrs_gpu, head, rank, comm_stream, num_channels); + move_fifo_slots(3); + + // Synchronize total received tokens and tokens per expert + auto start_time = std::chrono::high_resolution_clock::now(); + while (true) { + // Read total count + num_recv_tokens = static_cast(*moe_recv_counter); + + // Read per-expert count + bool ready = (num_recv_tokens >= 0); + for (int i = 0; i < num_local_experts and ready; ++i) ready &= moe_recv_expert_counter[i] >= 0; + + if (ready) break; + + // Timeout check + if (std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - start_time) + .count() > 
NUM_CPU_TIMEOUT_SECS) + throw std::runtime_error("DeepEP error: CPU recv timeout"); } + num_recv_tokens_per_expert_list = + std::vector(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts); + } - // Wait previous tasks to be finished - if (previous_event.has_value()) { - stream_wait(comm_stream, previous_event.value()); - } else { - stream_wait(comm_stream, compute_stream); + // Allocate new tensors + auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options()); + auto recv_src_idx = torch::empty({num_recv_tokens}, dtype(torch::kInt32).device(torch::kCUDA)); + auto recv_topk_idx = std::optional(), recv_topk_weights = std::optional(), + recv_x_scales = std::optional(); + auto recv_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); + auto send_head = torch::empty({num_tokens, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); + + // Assign pointers + int64_t* recv_topk_idx_ptr = nullptr; + float* recv_topk_weights_ptr = nullptr; + float* recv_x_scales_ptr = nullptr; + if (topk_idx.has_value()) { + recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options()); + recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options()); + recv_topk_idx_ptr = recv_topk_idx->data_ptr(); + recv_topk_weights_ptr = recv_topk_weights->data_ptr(); + } + if (x_scales.has_value()) { + recv_x_scales = x_scales->dim() == 1 ? torch::empty({num_recv_tokens}, x_scales->options()) + : torch::empty({num_recv_tokens, num_scales}, x_scales->options()); + recv_x_scales_ptr = recv_x_scales->data_ptr(); + } + + // Dispatch + EP_HOST_ASSERT(num_ranks * num_ranks * sizeof(int) + // Size prefix matrix + num_channels * num_ranks * sizeof(int) + // Channel start offset + num_channels * num_ranks * sizeof(int) + // Channel end offset + num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * + recv_x.element_size() + // Data buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * + sizeof(int) + // Source index buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * + sizeof(int64_t) + // Top-k index buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * + sizeof(float) + // Top-k weight buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(float) * + num_scales // FP8 scale buffer + <= num_nvl_bytes); + intranode::dispatch(recv_x.data_ptr(), recv_x_scales_ptr, recv_src_idx.data_ptr(), recv_topk_idx_ptr, + recv_topk_weights_ptr, recv_channel_prefix_matrix.data_ptr(), send_head.data_ptr(), + x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr, is_token_in_rank.data_ptr(), + channel_prefix_matrix.data_ptr(), num_tokens, + static_cast(hidden * recv_x.element_size() / sizeof(int4)), num_topk, num_experts, + num_scales, buffer_ptrs_gpu, rank, num_ranks, comm_stream, config.num_sms, + config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens); + + // Wait streams + std::optional event; + if (async) { + event = EventHandle(comm_stream); + for (auto& t : {x, is_token_in_rank, rank_prefix_matrix, channel_prefix_matrix, recv_x, recv_src_idx, + recv_channel_prefix_matrix, send_head}) { + t.record_stream(comm_stream); + if (allocate_on_comm_stream) t.record_stream(compute_stream); } + for (auto& to : + {x_scales, topk_idx, topk_weights, num_tokens_per_rank, num_tokens_per_expert, 
cached_channel_prefix_matrix, + cached_rank_prefix_matrix, recv_topk_idx, recv_topk_weights, recv_x_scales}) { + to.has_value() ? to->record_stream(comm_stream) : void(); + if (allocate_on_comm_stream) to.has_value() ? to->record_stream(compute_stream) : void(); + } + } else { + stream_wait(compute_stream, comm_stream); + } - auto num_tokens = static_cast(topk_idx.size(0)), num_topk = static_cast(topk_idx.size(1)); - auto num_tokens_per_rank = torch::empty({num_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); - auto num_tokens_per_rdma_rank = std::optional(); - auto num_tokens_per_expert = torch::empty({num_experts}, dtype(torch::kInt32).device(torch::kCUDA)); - auto is_token_in_rank = torch::empty({num_tokens, num_ranks}, dtype(torch::kBool).device(torch::kCUDA)); - if (is_internode_available()) - num_tokens_per_rdma_rank = torch::empty({num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); + // Switch back compute stream + if (allocate_on_comm_stream) at::cuda::setCurrentCUDAStream(compute_stream); - internode::get_dispatch_layout(topk_idx.data_ptr(), - num_tokens_per_rank.data_ptr(), - num_tokens_per_rdma_rank.has_value() ? num_tokens_per_rdma_rank.value().data_ptr() : nullptr, - num_tokens_per_expert.data_ptr(), - is_token_in_rank.data_ptr(), - num_tokens, num_topk, num_ranks, num_experts, + // Return values + return {recv_x, + recv_x_scales, + recv_topk_idx, + recv_topk_weights, + num_recv_tokens_per_expert_list, + rank_prefix_matrix, + channel_prefix_matrix, + recv_channel_prefix_matrix, + recv_src_idx, + send_head, + event}; +} + +std::tuple, std::optional> Buffer::intranode_combine( + const torch::Tensor& x, const std::optional& topk_weights, const torch::Tensor& src_idx, + const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix, const torch::Tensor& send_head, + const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream) { + EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); + EP_HOST_ASSERT(src_idx.dim() == 1 and src_idx.is_contiguous() and src_idx.scalar_type() == torch::kInt32); + EP_HOST_ASSERT(send_head.dim() == 2 and send_head.is_contiguous() and send_head.scalar_type() == torch::kInt32); + EP_HOST_ASSERT(rank_prefix_matrix.dim() == 2 and rank_prefix_matrix.is_contiguous() and + rank_prefix_matrix.scalar_type() == torch::kInt32); + EP_HOST_ASSERT(channel_prefix_matrix.dim() == 2 and channel_prefix_matrix.is_contiguous() and + channel_prefix_matrix.scalar_type() == torch::kInt32); + + // One channel use two blocks, even-numbered blocks for sending, odd-numbered blocks for receiving. + EP_HOST_ASSERT(config.num_sms % 2 == 0); + int num_channels = config.num_sms / 2; + + auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)); + auto num_recv_tokens = static_cast(send_head.size(0)); + EP_HOST_ASSERT(src_idx.size(0) == num_tokens); + EP_HOST_ASSERT(send_head.size(1) == num_ranks); + EP_HOST_ASSERT(rank_prefix_matrix.size(0) == num_ranks and rank_prefix_matrix.size(1) == num_ranks); + EP_HOST_ASSERT(channel_prefix_matrix.size(0) == num_ranks and channel_prefix_matrix.size(1) == num_channels); + EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0); + + // Allocate all tensors on comm stream if set + // NOTES: do not allocate tensors upfront! 
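// ----------------------------------------------------------------------------
// Editorial aside (illustrative sketch, not part of this patch): the two comment
// lines above describe the comm-stream allocation pattern used by every
// dispatch/combine entry point in this file. A minimal sketch of that pattern
// follows, using only calls that already appear in this file; the function name
// `alloc_on_comm_stream_example` and the tensor shape are invented for
// illustration, and the ATen/torch headers are assumed to be the ones already
// included at the top of buffer.cc.
torch::Tensor alloc_on_comm_stream_example(at::cuda::CUDAStream comm_stream, bool async) {
  // Remember the caller's (compute) stream, then make the comm stream current so the
  // CUDA caching allocator tags any tensor allocated below with `comm_stream`.
  auto compute_stream = at::cuda::getCurrentCUDAStream();
  at::cuda::setCurrentCUDAStream(comm_stream);
  auto buf = torch::empty({16, 16}, torch::dtype(torch::kInt32).device(torch::kCUDA));
  // ... enqueue communication kernels on `comm_stream` that read/write `buf` ...
  if (async) {
    // In the async path the caller keeps using `buf` on its own stream, so both streams
    // are registered with the allocator to keep the block from being recycled too early.
    buf.record_stream(comm_stream);
    buf.record_stream(compute_stream);
  }
  // Always switch back to the caller's stream before returning.
  at::cuda::setCurrentCUDAStream(compute_stream);
  return buf;
}
// ----------------------------------------------------------------------------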
+ auto compute_stream = at::cuda::getCurrentCUDAStream(); + if (allocate_on_comm_stream) { + EP_HOST_ASSERT(previous_event.has_value() and async); + at::cuda::setCurrentCUDAStream(comm_stream); + } + + // Wait previous tasks to be finished + if (previous_event.has_value()) { + stream_wait(comm_stream, previous_event.value()); + } else { + stream_wait(comm_stream, compute_stream); + } + + int num_topk = 0; + auto recv_topk_weights = std::optional(); + float* topk_weights_ptr = nullptr; + float* recv_topk_weights_ptr = nullptr; + if (topk_weights.has_value()) { + EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous()); + EP_HOST_ASSERT(topk_weights->size(0) == num_tokens); + EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32); + num_topk = static_cast(topk_weights->size(1)); + topk_weights_ptr = topk_weights->data_ptr(); + recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options()); + recv_topk_weights_ptr = recv_topk_weights->data_ptr(); + } + + // Launch barrier and reset queue head and tail + EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 <= num_nvl_bytes); + intranode::cached_notify_combine(buffer_ptrs_gpu, send_head.data_ptr(), num_channels, num_recv_tokens, + num_channels * num_ranks * 2, task_fifo_ptrs_gpu, head, rank, num_ranks, comm_stream); - // Wait streams - std::optional event; - if (async) { - event = EventHandle(comm_stream); - for (auto& t: {topk_idx, num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank}) { - t.record_stream(comm_stream); - if (allocate_on_comm_stream) - t.record_stream(compute_stream); - } - for (auto& to: {num_tokens_per_rdma_rank}) { - to.has_value() ? to->record_stream(comm_stream) : void(); - if (allocate_on_comm_stream) - to.has_value() ? 
to->record_stream(compute_stream) : void(); - } - } else { - stream_wait(compute_stream, comm_stream); + // NOTES: this function uses two FIFO slots (barrier before and after) + move_fifo_slots(2); + + // Combine data + auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options()); + EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * + x.element_size() + // Data buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * + sizeof(int) + // Source index buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * + sizeof(float) // Top-k weight buffer + <= num_nvl_bytes); + intranode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()), recv_x.data_ptr(), recv_topk_weights_ptr, + x.data_ptr(), topk_weights_ptr, src_idx.data_ptr(), rank_prefix_matrix.data_ptr(), + channel_prefix_matrix.data_ptr(), send_head.data_ptr(), num_tokens, num_recv_tokens, + hidden, num_topk, buffer_ptrs_gpu, rank, num_ranks, comm_stream, config.num_sms, + config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens); + + // Wait streams + std::optional event; + if (async) { + event = EventHandle(comm_stream); + for (auto& t : {x, src_idx, send_head, rank_prefix_matrix, channel_prefix_matrix, recv_x}) { + t.record_stream(comm_stream); + if (allocate_on_comm_stream) t.record_stream(compute_stream); } - - // Switch back compute stream - if (allocate_on_comm_stream) - at::cuda::setCurrentCUDAStream(compute_stream); - - return {num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event}; -} - -std::tuple, std::optional, std::optional, std::vector, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional> -Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional& x_scales, - const std::optional& topk_idx, const std::optional& topk_weights, - const std::optional& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional& num_tokens_per_expert, - int cached_num_recv_tokens, const std::optional& cached_rank_prefix_matrix, const std::optional& cached_channel_prefix_matrix, - int expert_alignment, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream) { - bool cached_mode = cached_rank_prefix_matrix.has_value(); - - // One channel use two blocks, even-numbered blocks for sending, odd-numbered blocks for receiving. - EP_HOST_ASSERT(config.num_sms % 2 == 0); - int num_channels = config.num_sms / 2; - if (cached_mode) { - EP_HOST_ASSERT(cached_rank_prefix_matrix.has_value()); - EP_HOST_ASSERT(cached_channel_prefix_matrix.has_value()); - } else { - EP_HOST_ASSERT(num_tokens_per_rank.has_value()); - EP_HOST_ASSERT(num_tokens_per_expert.has_value()); + for (auto& to : {topk_weights, recv_topk_weights}) { + to.has_value() ? to->record_stream(comm_stream) : void(); + if (allocate_on_comm_stream) to.has_value() ? 
to->record_stream(compute_stream) : void(); } + } else { + stream_wait(compute_stream, comm_stream); + } - // Type checks - EP_HOST_ASSERT(is_token_in_rank.scalar_type() == torch::kBool); - if (cached_mode) { - EP_HOST_ASSERT(cached_rank_prefix_matrix->scalar_type() == torch::kInt32); - EP_HOST_ASSERT(cached_channel_prefix_matrix->scalar_type() == torch::kInt32); - } else { - EP_HOST_ASSERT(num_tokens_per_expert->scalar_type() == torch::kInt32); - EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == torch::kInt32); - } + // Switch back compute stream + if (allocate_on_comm_stream) at::cuda::setCurrentCUDAStream(compute_stream); - // Shape and contiguous checks - EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); - EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0); - EP_HOST_ASSERT(is_token_in_rank.dim() == 2 and is_token_in_rank.is_contiguous()); - EP_HOST_ASSERT(is_token_in_rank.size(0) == x.size(0) and is_token_in_rank.size(1) == num_ranks); - if (cached_mode) { - EP_HOST_ASSERT(cached_rank_prefix_matrix->dim() == 2 and cached_rank_prefix_matrix->is_contiguous()); - EP_HOST_ASSERT(cached_rank_prefix_matrix->size(0) == num_ranks and cached_rank_prefix_matrix->size(1) == num_ranks); - EP_HOST_ASSERT(cached_channel_prefix_matrix->dim() == 2 and cached_channel_prefix_matrix->is_contiguous()); - EP_HOST_ASSERT(cached_channel_prefix_matrix->size(0) == num_ranks and cached_channel_prefix_matrix->size(1) == num_channels); - } else { - EP_HOST_ASSERT(num_tokens_per_expert->dim() == 1 and num_tokens_per_expert->is_contiguous()); - EP_HOST_ASSERT(num_tokens_per_expert->size(0) % num_ranks == 0); - EP_HOST_ASSERT(num_tokens_per_expert->size(0) / num_ranks <= NUM_MAX_LOCAL_EXPERTS); - EP_HOST_ASSERT(num_tokens_per_rank->dim() == 1 and num_tokens_per_rank->is_contiguous()); - EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks); - } - - auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)); - auto num_experts = cached_mode ? 0 : static_cast(num_tokens_per_expert->size(0)), num_local_experts = num_experts / num_ranks; - - // Top-k checks - int num_topk = 0; - int64_t* topk_idx_ptr = nullptr; - float* topk_weights_ptr = nullptr; - EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value()); - if (topk_idx.has_value()) { - num_topk = static_cast(topk_idx->size(1)); - EP_HOST_ASSERT(num_experts > 0); - EP_HOST_ASSERT(topk_idx->dim() == 2 and topk_idx->is_contiguous()); - EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous()); - EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0)); - EP_HOST_ASSERT(num_topk == topk_weights->size(1)); - EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32); - topk_idx_ptr = topk_idx->data_ptr(); - topk_weights_ptr = topk_weights->data_ptr(); - } - - // FP8 scales checks - float* x_scales_ptr = nullptr; - int num_scales = 0; - if (x_scales.has_value()) { - EP_HOST_ASSERT(x.element_size() == 1); - EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32); - EP_HOST_ASSERT(x_scales->dim() > 0 and x_scales->dim() < 3 and x_scales->is_contiguous()); - EP_HOST_ASSERT(x_scales->size(0) == num_tokens); - num_scales = x_scales->dim() == 1 ? 1 : static_cast(x_scales->size(1)); - x_scales_ptr = x_scales->data_ptr(); - } - - // Allocate all tensors on comm stream if set - // NOTES: do not allocate tensors upfront! 
- auto compute_stream = at::cuda::getCurrentCUDAStream(); - if (allocate_on_comm_stream) { - EP_HOST_ASSERT(previous_event.has_value() and async); - at::cuda::setCurrentCUDAStream(comm_stream); - } - - // Wait previous tasks to be finished - if (previous_event.has_value()) { - stream_wait(comm_stream, previous_event.value()); - } else { - stream_wait(comm_stream, compute_stream); - } - - // Create handles (only return for non-cached mode) - int num_recv_tokens = -1; - auto rank_prefix_matrix = torch::Tensor(); - auto channel_prefix_matrix = torch::Tensor(); - std::vector num_recv_tokens_per_expert_list; - - // Barrier or send sizes - // To clean: channel start/end offset, head and tail - int num_memset_int = num_channels * num_ranks * 4; - if (cached_mode) { - num_recv_tokens = cached_num_recv_tokens; - rank_prefix_matrix = cached_rank_prefix_matrix.value(); - channel_prefix_matrix = cached_channel_prefix_matrix.value(); - - // Copy rank prefix matrix and clean flags - intranode::cached_notify_dispatch(rank_prefix_matrix.data_ptr(), num_memset_int, - buffer_ptrs_gpu, task_fifo_ptrs_gpu, head, rank, num_ranks, - comm_stream); - move_fifo_slots(2); - } else { - rank_prefix_matrix = torch::empty({num_ranks, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); - channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); - - // Send sizes - // Meta information: - // - Size prefix by ranks, shaped as `[num_ranks, num_ranks]` - // - Size prefix by experts (not used later), shaped as `[num_ranks, num_local_experts]` - // NOTES: no more token dropping in this version - *moe_recv_counter = -1; - for (int i = 0; i < num_local_experts; ++ i) - moe_recv_expert_counter[i] = -1; - EP_HOST_ASSERT(num_ranks * (num_ranks + num_local_experts) * sizeof(int) <= num_nvl_bytes); - intranode::notify_dispatch(num_tokens_per_rank->data_ptr(), moe_recv_counter_mapped, num_ranks, - num_tokens_per_expert->data_ptr(), moe_recv_expert_counter_mapped, num_experts, - num_tokens, is_token_in_rank.data_ptr(), channel_prefix_matrix.data_ptr(), - rank_prefix_matrix.data_ptr(), - num_memset_int, expert_alignment, - buffer_ptrs_gpu, task_fifo_ptrs_gpu, head, rank, - comm_stream, num_channels); - move_fifo_slots(3); - - // Synchronize total received tokens and tokens per expert - auto start_time = std::chrono::high_resolution_clock::now(); - while (true) { - // Read total count - num_recv_tokens = static_cast(*moe_recv_counter); - - // Read per-expert count - bool ready = (num_recv_tokens >= 0); - for (int i = 0; i < num_local_experts and ready; ++i) - ready &= moe_recv_expert_counter[i] >= 0; - - if (ready) - break; - - // Timeout check - if (std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - start_time).count() > NUM_CPU_TIMEOUT_SECS) - throw std::runtime_error("DeepEP error: CPU recv timeout"); - } - num_recv_tokens_per_expert_list = std::vector(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts); - } - - // Allocate new tensors - auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options()); - auto recv_src_idx = torch::empty({num_recv_tokens}, dtype(torch::kInt32).device(torch::kCUDA)); - auto recv_topk_idx = std::optional(), recv_topk_weights = std::optional(), recv_x_scales = std::optional(); - auto recv_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); - auto send_head = torch::empty({num_tokens, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); - - // Assign 
pointers - int64_t* recv_topk_idx_ptr = nullptr; - float* recv_topk_weights_ptr = nullptr; - float* recv_x_scales_ptr = nullptr; - if (topk_idx.has_value()) { - recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options()); - recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options()); - recv_topk_idx_ptr = recv_topk_idx->data_ptr(); - recv_topk_weights_ptr = recv_topk_weights->data_ptr(); - } - if (x_scales.has_value()) { - recv_x_scales = x_scales->dim() == 1 ? - torch::empty({num_recv_tokens}, x_scales->options()) : - torch::empty({num_recv_tokens, num_scales}, x_scales->options()); - recv_x_scales_ptr = recv_x_scales->data_ptr(); - } - - // Dispatch - EP_HOST_ASSERT(num_ranks * num_ranks * sizeof(int) + // Size prefix matrix - num_channels * num_ranks * sizeof(int) + // Channel start offset - num_channels * num_ranks * sizeof(int) + // Channel end offset - num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * recv_x.element_size() + // Data buffer - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) + // Source index buffer - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(int64_t) + // Top-k index buffer - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float) + // Top-k weight buffer - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(float) * num_scales // FP8 scale buffer - <= num_nvl_bytes); - intranode::dispatch(recv_x.data_ptr(), recv_x_scales_ptr, recv_src_idx.data_ptr(), recv_topk_idx_ptr, recv_topk_weights_ptr, recv_channel_prefix_matrix.data_ptr(), - send_head.data_ptr(), - x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr, - is_token_in_rank.data_ptr(), channel_prefix_matrix.data_ptr(), - num_tokens, static_cast(hidden * recv_x.element_size() / sizeof(int4)), num_topk, num_experts, num_scales, - buffer_ptrs_gpu, rank, num_ranks, comm_stream, config.num_sms, - config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens); - - // Wait streams - std::optional event; - if (async) { - event = EventHandle(comm_stream); - for (auto& t: {x, is_token_in_rank, rank_prefix_matrix, channel_prefix_matrix, recv_x, recv_src_idx, recv_channel_prefix_matrix, send_head}) { - t.record_stream(comm_stream); - if (allocate_on_comm_stream) - t.record_stream(compute_stream); - } - for (auto& to: {x_scales, topk_idx, topk_weights, num_tokens_per_rank, num_tokens_per_expert, cached_channel_prefix_matrix, cached_rank_prefix_matrix, recv_topk_idx, recv_topk_weights, recv_x_scales}) { - to.has_value() ? to->record_stream(comm_stream) : void(); - if (allocate_on_comm_stream) - to.has_value() ? 
to->record_stream(compute_stream) : void(); - } - } else { - stream_wait(compute_stream, comm_stream); - } - - // Switch back compute stream - if (allocate_on_comm_stream) - at::cuda::setCurrentCUDAStream(compute_stream); - - // Return values - return {recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, send_head, event}; -} - -std::tuple, std::optional> -Buffer::intranode_combine(const torch::Tensor& x, const std::optional& topk_weights, - const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix, - const torch::Tensor& send_head, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream) { - EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); - EP_HOST_ASSERT(src_idx.dim() == 1 and src_idx.is_contiguous() and src_idx.scalar_type() == torch::kInt32); - EP_HOST_ASSERT(send_head.dim() == 2 and send_head.is_contiguous() and send_head.scalar_type() == torch::kInt32); - EP_HOST_ASSERT(rank_prefix_matrix.dim() == 2 and rank_prefix_matrix.is_contiguous() and rank_prefix_matrix.scalar_type() == torch::kInt32); - EP_HOST_ASSERT(channel_prefix_matrix.dim() == 2 and channel_prefix_matrix.is_contiguous() and channel_prefix_matrix.scalar_type() == torch::kInt32); - - // One channel use two blocks, even-numbered blocks for sending, odd-numbered blocks for receiving. - EP_HOST_ASSERT(config.num_sms % 2 == 0); - int num_channels = config.num_sms / 2; - - auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)); - auto num_recv_tokens = static_cast(send_head.size(0)); - EP_HOST_ASSERT(src_idx.size(0) == num_tokens); - EP_HOST_ASSERT(send_head.size(1) == num_ranks); - EP_HOST_ASSERT(rank_prefix_matrix.size(0) == num_ranks and rank_prefix_matrix.size(1) == num_ranks); - EP_HOST_ASSERT(channel_prefix_matrix.size(0) == num_ranks and channel_prefix_matrix.size(1) == num_channels); - EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0); - - // Allocate all tensors on comm stream if set - // NOTES: do not allocate tensors upfront! 
- auto compute_stream = at::cuda::getCurrentCUDAStream(); - if (allocate_on_comm_stream) { - EP_HOST_ASSERT(previous_event.has_value() and async); - at::cuda::setCurrentCUDAStream(comm_stream); - } - - // Wait previous tasks to be finished - if (previous_event.has_value()) { - stream_wait(comm_stream, previous_event.value()); - } else { - stream_wait(comm_stream, compute_stream); - } - - int num_topk = 0; - auto recv_topk_weights = std::optional(); - float* topk_weights_ptr = nullptr; - float* recv_topk_weights_ptr = nullptr; - if (topk_weights.has_value()) { - EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous()); - EP_HOST_ASSERT(topk_weights->size(0) == num_tokens); - EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32); - num_topk = static_cast(topk_weights->size(1)); - topk_weights_ptr = topk_weights->data_ptr(); - recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options()); - recv_topk_weights_ptr = recv_topk_weights->data_ptr(); - } - - // Launch barrier and reset queue head and tail - EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 <= num_nvl_bytes); - intranode::cached_notify_combine(buffer_ptrs_gpu, send_head.data_ptr(), - num_channels, num_recv_tokens, num_channels * num_ranks * 2, - task_fifo_ptrs_gpu, head, rank, num_ranks, - comm_stream); - - // NOTES: this function uses two FIFO slots (barrier before and after) - move_fifo_slots(2); - - // Combine data - auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options()); - EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * x.element_size() + // Data buffer - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) + // Source index buffer - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float) // Top-k weight buffer - <= num_nvl_bytes); - intranode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()), - recv_x.data_ptr(), recv_topk_weights_ptr, - x.data_ptr(), topk_weights_ptr, - src_idx.data_ptr(), rank_prefix_matrix.data_ptr(), channel_prefix_matrix.data_ptr(), - send_head.data_ptr(), num_tokens, num_recv_tokens, hidden, num_topk, - buffer_ptrs_gpu, rank, num_ranks, - comm_stream, config.num_sms, - config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens); - - // Wait streams - std::optional event; - if (async) { - event = EventHandle(comm_stream); - for (auto& t: {x, src_idx, send_head, rank_prefix_matrix, channel_prefix_matrix, recv_x}) { - t.record_stream(comm_stream); - if (allocate_on_comm_stream) - t.record_stream(compute_stream); - } - for (auto& to: {topk_weights, recv_topk_weights}) { - to.has_value() ? to->record_stream(comm_stream) : void(); - if (allocate_on_comm_stream) - to.has_value() ? 
to->record_stream(compute_stream) : void(); - } - } else { - stream_wait(compute_stream, comm_stream); - } - - // Switch back compute stream - if (allocate_on_comm_stream) - at::cuda::setCurrentCUDAStream(compute_stream); - - return {recv_x, recv_topk_weights, event}; + return {recv_x, recv_topk_weights, event}; } // ----------------------------------------------------------------------------- @@ -927,651 +931,639 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional, std::optional, std::optional, std::vector, torch::Tensor, torch::Tensor, std::optional, torch::Tensor, std::optional, torch::Tensor, std::optional, std::optional, std::optional, std::optional> -Buffer::internode_dispatch(const torch::Tensor& x, const std::optional& x_scales, - const std::optional& topk_idx, const std::optional& topk_weights, - const std::optional& num_tokens_per_rank, const std::optional& num_tokens_per_rdma_rank, - const torch::Tensor& is_token_in_rank, const std::optional& num_tokens_per_expert, - int cached_num_recv_tokens, int cached_num_rdma_recv_tokens, - const std::optional& cached_rdma_channel_prefix_matrix, const std::optional& cached_recv_rdma_rank_prefix_sum, - const std::optional& cached_gbl_channel_prefix_matrix, const std::optional& cached_recv_gbl_rank_prefix_sum, - int expert_alignment, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream) { - // In dispatch, CPU will busy-wait until GPU receive tensor size metadata from other ranks, which can be quite long. - pybind11::gil_scoped_release release; +std::tuple, std::optional, std::optional, + std::vector, torch::Tensor, torch::Tensor, std::optional, torch::Tensor, + std::optional, torch::Tensor, std::optional, std::optional, + std::optional, std::optional> +Buffer::internode_dispatch( + const torch::Tensor& x, const std::optional& x_scales, const std::optional& topk_idx, + const std::optional& topk_weights, const std::optional& num_tokens_per_rank, + const std::optional& num_tokens_per_rdma_rank, const torch::Tensor& is_token_in_rank, + const std::optional& num_tokens_per_expert, int cached_num_recv_tokens, + int cached_num_rdma_recv_tokens, const std::optional& cached_rdma_channel_prefix_matrix, + const std::optional& cached_recv_rdma_rank_prefix_sum, + const std::optional& cached_gbl_channel_prefix_matrix, + const std::optional& cached_recv_gbl_rank_prefix_sum, int expert_alignment, const Config& config, + std::optional& previous_event, bool async, bool allocate_on_comm_stream) { + // In dispatch, CPU will busy-wait until GPU receive tensor size metadata from other ranks, which can be quite long. 
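// ----------------------------------------------------------------------------
// Editorial aside (illustrative sketch, not part of this patch): the comment above
// is the reason for the `pybind11::gil_scoped_release` that follows. The dispatch
// path resets host-mapped counters to -1, launches a notify kernel that writes the
// real token counts from the GPU, and then spins on the CPU until every counter is
// non-negative or a timeout fires, so the GIL must not be held while spinning.
// A compressed sketch of that handshake; the helper name `wait_for_mapped_counter`
// and its parameters are invented for illustration (<chrono> and <stdexcept> are
// already available in this translation unit).
inline int wait_for_mapped_counter(const volatile int* counter, int timeout_secs) {
  auto start_time = std::chrono::high_resolution_clock::now();
  while (true) {
    int value = *counter;          // written asynchronously by the GPU via a mapped host allocation
    if (value >= 0) return value;  // -1 is the "not published yet" sentinel set by the host
    auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
                       std::chrono::high_resolution_clock::now() - start_time)
                       .count();
    if (elapsed > timeout_secs)
      throw std::runtime_error("mscclpp::ep error: timeout (internode_dispatch CPU)");
  }
}
// ----------------------------------------------------------------------------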
+ pybind11::gil_scoped_release release; - const int num_channels = config.num_sms / 2; - EP_HOST_ASSERT(config.num_sms % 2 == 0); - EP_HOST_ASSERT(0 < get_num_rdma_ranks() and get_num_rdma_ranks() <= NUM_MAX_RDMA_PEERS); + const int num_channels = config.num_sms / 2; + EP_HOST_ASSERT(config.num_sms % 2 == 0); + EP_HOST_ASSERT(0 < get_num_rdma_ranks() and get_num_rdma_ranks() <= NUM_MAX_RDMA_PEERS); - bool cached_mode = cached_rdma_channel_prefix_matrix.has_value(); - if (cached_mode) { - EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix.has_value()); - EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum.has_value()); - EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix.has_value()); - EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum.has_value()); - } else { - EP_HOST_ASSERT(num_tokens_per_rank.has_value()); - EP_HOST_ASSERT(num_tokens_per_rdma_rank.has_value()); - EP_HOST_ASSERT(num_tokens_per_expert.has_value()); - } + bool cached_mode = cached_rdma_channel_prefix_matrix.has_value(); + if (cached_mode) { + EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix.has_value()); + EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum.has_value()); + EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix.has_value()); + EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum.has_value()); + } else { + EP_HOST_ASSERT(num_tokens_per_rank.has_value()); + EP_HOST_ASSERT(num_tokens_per_rdma_rank.has_value()); + EP_HOST_ASSERT(num_tokens_per_expert.has_value()); + } - // Type checks - if (cached_mode) { - EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->scalar_type() == torch::kInt32); - EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->scalar_type() == torch::kInt32); - EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->scalar_type() == torch::kInt32); - EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->scalar_type() == torch::kInt32); - } else { - EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == torch::kInt32); - EP_HOST_ASSERT(num_tokens_per_rdma_rank->scalar_type() == torch::kInt32); - EP_HOST_ASSERT(num_tokens_per_expert->scalar_type() == torch::kInt32); - } + // Type checks + if (cached_mode) { + EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->scalar_type() == torch::kInt32); + EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->scalar_type() == torch::kInt32); + EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->scalar_type() == torch::kInt32); + EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->scalar_type() == torch::kInt32); + } else { + EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == torch::kInt32); + EP_HOST_ASSERT(num_tokens_per_rdma_rank->scalar_type() == torch::kInt32); + EP_HOST_ASSERT(num_tokens_per_expert->scalar_type() == torch::kInt32); + } - // Shape and contiguous checks - EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); - EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0); - if (cached_mode) { - EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->dim() == 2 and cached_rdma_channel_prefix_matrix->is_contiguous()); - EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->size(0) == num_rdma_ranks and cached_rdma_channel_prefix_matrix->size(1) == num_channels); - EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->dim() == 1 and cached_recv_rdma_rank_prefix_sum->is_contiguous()); - EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->size(0) == num_rdma_ranks); - EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->dim() == 2 and cached_gbl_channel_prefix_matrix->is_contiguous()); - EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->size(0) == num_ranks and cached_gbl_channel_prefix_matrix->size(1) == num_channels); - 
EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->dim() == 1 and cached_recv_gbl_rank_prefix_sum->is_contiguous()); - EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->size(0) == num_ranks); - } else { - EP_HOST_ASSERT(num_tokens_per_rank->dim() == 1 and num_tokens_per_rank->is_contiguous()); - EP_HOST_ASSERT(num_tokens_per_rdma_rank->dim() == 1 and num_tokens_per_rdma_rank->is_contiguous()); - EP_HOST_ASSERT(num_tokens_per_expert->dim() == 1 and num_tokens_per_expert->is_contiguous()); - EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks); - EP_HOST_ASSERT(num_tokens_per_rdma_rank->size(0) == num_rdma_ranks); - EP_HOST_ASSERT(num_tokens_per_expert->size(0) % num_ranks == 0); - EP_HOST_ASSERT(num_tokens_per_expert->size(0) / num_ranks <= NUM_MAX_LOCAL_EXPERTS); - } + // Shape and contiguous checks + EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); + EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0); + if (cached_mode) { + EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->dim() == 2 and + cached_rdma_channel_prefix_matrix->is_contiguous()); + EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->size(0) == num_rdma_ranks and + cached_rdma_channel_prefix_matrix->size(1) == num_channels); + EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->dim() == 1 and cached_recv_rdma_rank_prefix_sum->is_contiguous()); + EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->size(0) == num_rdma_ranks); + EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->dim() == 2 and cached_gbl_channel_prefix_matrix->is_contiguous()); + EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->size(0) == num_ranks and + cached_gbl_channel_prefix_matrix->size(1) == num_channels); + EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->dim() == 1 and cached_recv_gbl_rank_prefix_sum->is_contiguous()); + EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->size(0) == num_ranks); + } else { + EP_HOST_ASSERT(num_tokens_per_rank->dim() == 1 and num_tokens_per_rank->is_contiguous()); + EP_HOST_ASSERT(num_tokens_per_rdma_rank->dim() == 1 and num_tokens_per_rdma_rank->is_contiguous()); + EP_HOST_ASSERT(num_tokens_per_expert->dim() == 1 and num_tokens_per_expert->is_contiguous()); + EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks); + EP_HOST_ASSERT(num_tokens_per_rdma_rank->size(0) == num_rdma_ranks); + EP_HOST_ASSERT(num_tokens_per_expert->size(0) % num_ranks == 0); + EP_HOST_ASSERT(num_tokens_per_expert->size(0) / num_ranks <= NUM_MAX_LOCAL_EXPERTS); + } - auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)), hidden_int4 = static_cast(x.size(1) * x.element_size() / sizeof(int4)); - auto num_experts = cached_mode ? 0 : static_cast(num_tokens_per_expert->size(0)), num_local_experts = num_experts / num_ranks; + auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)), + hidden_int4 = static_cast(x.size(1) * x.element_size() / sizeof(int4)); + auto num_experts = cached_mode ? 
0 : static_cast(num_tokens_per_expert->size(0)), + num_local_experts = num_experts / num_ranks; - // Top-k checks - int num_topk = 0; - int64_t* topk_idx_ptr = nullptr; - float* topk_weights_ptr = nullptr; - EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value()); - if (topk_idx.has_value()) { - num_topk = static_cast(topk_idx->size(1)); - EP_HOST_ASSERT(num_experts > 0); - EP_HOST_ASSERT(topk_idx->dim() == 2 and topk_idx->is_contiguous()); - EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous()); - EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0)); - EP_HOST_ASSERT(num_topk == topk_weights->size(1)); - EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32); - topk_idx_ptr = topk_idx->data_ptr(); - topk_weights_ptr = topk_weights->data_ptr(); - } + // Top-k checks + int num_topk = 0; + int64_t* topk_idx_ptr = nullptr; + float* topk_weights_ptr = nullptr; + EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value()); + if (topk_idx.has_value()) { + num_topk = static_cast(topk_idx->size(1)); + EP_HOST_ASSERT(num_experts > 0); + EP_HOST_ASSERT(topk_idx->dim() == 2 and topk_idx->is_contiguous()); + EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous()); + EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0)); + EP_HOST_ASSERT(num_topk == topk_weights->size(1)); + EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32); + topk_idx_ptr = topk_idx->data_ptr(); + topk_weights_ptr = topk_weights->data_ptr(); + } - // FP8 scales checks - float* x_scales_ptr = nullptr; - int num_scales = 0; - if (x_scales.has_value()) { - EP_HOST_ASSERT(x.element_size() == 1); - EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32); - EP_HOST_ASSERT(x_scales->dim() > 0 and x_scales->dim() < 3 and x_scales->is_contiguous()); - EP_HOST_ASSERT(x_scales->size(0) == num_tokens); - num_scales = x_scales->dim() == 1 ? 1 : static_cast(x_scales->size(1)); - x_scales_ptr = x_scales->data_ptr(); - } + // FP8 scales checks + float* x_scales_ptr = nullptr; + int num_scales = 0; + if (x_scales.has_value()) { + EP_HOST_ASSERT(x.element_size() == 1); + EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32); + EP_HOST_ASSERT(x_scales->dim() > 0 and x_scales->dim() < 3 and x_scales->is_contiguous()); + EP_HOST_ASSERT(x_scales->size(0) == num_tokens); + num_scales = x_scales->dim() == 1 ? 
1 : static_cast(x_scales->size(1)); + x_scales_ptr = x_scales->data_ptr(); + } - // Allocate all tensors on comm stream if set - auto compute_stream = at::cuda::getCurrentCUDAStream(); - if (allocate_on_comm_stream) { - EP_HOST_ASSERT(previous_event.has_value() and async); - at::cuda::setCurrentCUDAStream(comm_stream); - } + // Allocate all tensors on comm stream if set + auto compute_stream = at::cuda::getCurrentCUDAStream(); + if (allocate_on_comm_stream) { + EP_HOST_ASSERT(previous_event.has_value() and async); + at::cuda::setCurrentCUDAStream(comm_stream); + } - // Wait previous tasks to be finished - if (previous_event.has_value()) { - stream_wait(comm_stream, previous_event.value()); - } else { - stream_wait(comm_stream, compute_stream); - } + // Wait previous tasks to be finished + if (previous_event.has_value()) { + stream_wait(comm_stream, previous_event.value()); + } else { + stream_wait(comm_stream, compute_stream); + } - // Create handles (only return for non-cached mode) - int num_recv_tokens = -1, num_rdma_recv_tokens = -1; - auto rdma_channel_prefix_matrix = torch::Tensor(); - auto recv_rdma_rank_prefix_sum = torch::Tensor(); - auto gbl_channel_prefix_matrix = torch::Tensor(); - auto recv_gbl_rank_prefix_sum = torch::Tensor(); - std::vector num_recv_tokens_per_expert_list; + // Create handles (only return for non-cached mode) + int num_recv_tokens = -1, num_rdma_recv_tokens = -1; + auto rdma_channel_prefix_matrix = torch::Tensor(); + auto recv_rdma_rank_prefix_sum = torch::Tensor(); + auto gbl_channel_prefix_matrix = torch::Tensor(); + auto recv_gbl_rank_prefix_sum = torch::Tensor(); + std::vector num_recv_tokens_per_expert_list; - // Barrier or send sizes - if (cached_mode) { - num_recv_tokens = cached_num_recv_tokens; - num_rdma_recv_tokens = cached_num_rdma_recv_tokens; - rdma_channel_prefix_matrix = cached_rdma_channel_prefix_matrix.value(); - recv_rdma_rank_prefix_sum = cached_recv_rdma_rank_prefix_sum.value(); - gbl_channel_prefix_matrix = cached_gbl_channel_prefix_matrix.value(); - recv_gbl_rank_prefix_sum = cached_recv_gbl_rank_prefix_sum.value(); + // Barrier or send sizes + if (cached_mode) { + num_recv_tokens = cached_num_recv_tokens; + num_rdma_recv_tokens = cached_num_rdma_recv_tokens; + rdma_channel_prefix_matrix = cached_rdma_channel_prefix_matrix.value(); + recv_rdma_rank_prefix_sum = cached_recv_rdma_rank_prefix_sum.value(); + gbl_channel_prefix_matrix = cached_gbl_channel_prefix_matrix.value(); + recv_gbl_rank_prefix_sum = cached_recv_gbl_rank_prefix_sum.value(); - // Just a barrier and clean flags - internode::cached_notify(hidden_int4, num_scales, num_topk, num_topk, - num_ranks, num_channels, 0, nullptr, - nullptr, nullptr, nullptr, - rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens, - buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, - task_fifo_ptrs_gpu, head, rank, comm_stream, - config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), - num_nvl_bytes, true, low_latency_mode, - port_channel_handles_device_ptr.get(), - memory_channel_handles_device_ptr.get()); - move_fifo_slots(2); - } else { - rdma_channel_prefix_matrix = torch::empty({num_rdma_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); - recv_rdma_rank_prefix_sum = torch::empty({num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); - gbl_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); - recv_gbl_rank_prefix_sum = torch::empty({num_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); 
- - // Send sizes - *moe_recv_counter = -1, *moe_recv_rdma_counter = -1; - for (int i = 0; i < num_local_experts; ++i) - moe_recv_expert_counter[i] = -1; - internode::notify_dispatch(num_tokens_per_rank->data_ptr(), moe_recv_counter_mapped, num_ranks, - num_tokens_per_rdma_rank->data_ptr(), moe_recv_rdma_counter_mapped, - num_tokens_per_expert->data_ptr(), moe_recv_expert_counter_mapped, num_experts, - is_token_in_rank.data_ptr(), num_tokens, num_channels, - hidden_int4, num_scales, num_topk, expert_alignment, - rdma_channel_prefix_matrix.data_ptr(), recv_rdma_rank_prefix_sum.data_ptr(), - gbl_channel_prefix_matrix.data_ptr(), recv_gbl_rank_prefix_sum.data_ptr(), - rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens, - buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, - task_fifo_ptrs_gpu, head, rank, comm_stream, - config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), - num_nvl_bytes, low_latency_mode, port_channel_handles_device_ptr.get(), - memory_channel_handles_device_ptr.get()); - move_fifo_slots(3); - - // Synchronize total received tokens and tokens per expert - auto start_time = std::chrono::high_resolution_clock::now(); - while (true) { - num_recv_tokens = static_cast(*moe_recv_counter); - num_rdma_recv_tokens = static_cast(*moe_recv_rdma_counter); - - bool ready = (num_recv_tokens >= 0) and (num_rdma_recv_tokens >= 0); - for (int i = 0; i < num_local_experts and ready; ++i) - ready &= moe_recv_expert_counter[i] >= 0; - - if (ready) break; - - if (std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - start_time).count() > NUM_CPU_TIMEOUT_SECS) { - printf("Global rank: %d, num_recv_tokens: %d, num_rdma_recv_tokens: %d\n", rank, num_recv_tokens, num_rdma_recv_tokens); - for (int i = 0; i < num_local_experts; ++i) - printf("moe_recv_expert_counter[%d]: %d\n", i, moe_recv_expert_counter[i]); - throw std::runtime_error("mscclpp::ep error: timeout (internode_dispatch CPU)"); - } - } - num_recv_tokens_per_expert_list = std::vector(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts); - } - - // Allocate new tensors - auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options()); - auto recv_topk_idx = std::optional(), recv_topk_weights = std::optional(), recv_x_scales = std::optional(); - auto recv_src_meta = std::optional(); - auto recv_rdma_channel_prefix_matrix = std::optional(); - auto recv_gbl_channel_prefix_matrix = std::optional(); - auto send_rdma_head = std::optional(); - auto send_nvl_head = std::optional(); - if (not cached_mode) { - recv_src_meta = torch::empty({num_recv_tokens, internode::get_source_meta_bytes()}, dtype(torch::kByte).device(torch::kCUDA)); - recv_rdma_channel_prefix_matrix = torch::empty({num_rdma_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); - recv_gbl_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); - send_rdma_head = torch::empty({num_tokens, num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); - send_nvl_head = torch::empty({num_rdma_recv_tokens, NUM_MAX_NVL_PEERS}, dtype(torch::kInt32).device(torch::kCUDA)); - } - - int64_t* recv_topk_idx_ptr = nullptr; - float* recv_topk_weights_ptr = nullptr; - float* recv_x_scales_ptr = nullptr; - if (topk_idx.has_value()) { - recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options()); - recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options()); - recv_topk_idx_ptr = recv_topk_idx->data_ptr(); - 
recv_topk_weights_ptr = recv_topk_weights->data_ptr(); - } - if (x_scales.has_value()) { - recv_x_scales = x_scales->dim() == 1 ? - torch::empty({num_recv_tokens}, x_scales->options()) : - torch::empty({num_recv_tokens, num_scales}, x_scales->options()); - recv_x_scales_ptr = recv_x_scales->data_ptr(); - } - - // Launch data dispatch - internode::dispatch(recv_x.data_ptr(), recv_x_scales_ptr, recv_topk_idx_ptr, recv_topk_weights_ptr, - cached_mode ? nullptr : recv_src_meta->data_ptr(), - x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr, - cached_mode ? nullptr : send_rdma_head->data_ptr(), cached_mode ? nullptr : send_nvl_head->data_ptr(), - cached_mode ? nullptr : recv_rdma_channel_prefix_matrix->data_ptr(), - cached_mode ? nullptr : recv_gbl_channel_prefix_matrix->data_ptr(), - rdma_channel_prefix_matrix.data_ptr(), recv_rdma_rank_prefix_sum.data_ptr(), - gbl_channel_prefix_matrix.data_ptr(), recv_gbl_rank_prefix_sum.data_ptr(), - num_tokens, hidden_int4, num_scales, num_topk, num_experts, - is_token_in_rank.data_ptr(), - rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens, - buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens, - rank, num_ranks, cached_mode, - comm_stream, num_channels, low_latency_mode, - port_channel_handles_device_ptr.get(), - memory_channel_handles_device_ptr.get()); - - // Wait streams - std::optional event; - if (async) { - event = EventHandle(comm_stream); - for (auto& t: {x, is_token_in_rank, recv_x, - rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum}) { - t.record_stream(comm_stream); - if (allocate_on_comm_stream) - t.record_stream(compute_stream); - } - for (auto& to: {x_scales, topk_idx, topk_weights, - num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, - cached_rdma_channel_prefix_matrix, cached_recv_rdma_rank_prefix_sum, - cached_gbl_channel_prefix_matrix, cached_recv_gbl_rank_prefix_sum, - recv_topk_idx, recv_topk_weights, recv_x_scales, - recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix, send_rdma_head, send_nvl_head, - recv_src_meta}) { - to.has_value() ? to->record_stream(comm_stream) : void(); - if (allocate_on_comm_stream) - to.has_value() ? 
to->record_stream(compute_stream) : void(); - } - } else { - stream_wait(compute_stream, comm_stream); - } - - if (allocate_on_comm_stream) - at::cuda::setCurrentCUDAStream(compute_stream); - - return {recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, - rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, - recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, - recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, - recv_src_meta, send_rdma_head, send_nvl_head, event}; -} - -std::tuple, std::optional> -Buffer::internode_combine(const torch::Tensor& x, const std::optional& topk_weights, - const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank, - const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix, - const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head, - const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream) { - const int num_channels = config.num_sms / 2; - EP_HOST_ASSERT(config.num_sms % 2 == 0); - - // Shape and contiguous checks - EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); - EP_HOST_ASSERT(src_meta.dim() == 2 and src_meta.is_contiguous() and src_meta.scalar_type() == torch::kByte); - EP_HOST_ASSERT(is_combined_token_in_rank.dim() == 2 and is_combined_token_in_rank.is_contiguous() and is_combined_token_in_rank.scalar_type() == torch::kBool); - EP_HOST_ASSERT(rdma_channel_prefix_matrix.dim() == 2 and rdma_channel_prefix_matrix.is_contiguous() and rdma_channel_prefix_matrix.scalar_type() == torch::kInt32); - EP_HOST_ASSERT(rdma_rank_prefix_sum.dim() == 1 and rdma_rank_prefix_sum.is_contiguous() and rdma_rank_prefix_sum.scalar_type() == torch::kInt32); - EP_HOST_ASSERT(gbl_channel_prefix_matrix.dim() == 2 and gbl_channel_prefix_matrix.is_contiguous() and gbl_channel_prefix_matrix.scalar_type() == torch::kInt32); - EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and combined_rdma_head.is_contiguous() and combined_rdma_head.scalar_type() == torch::kInt32); - EP_HOST_ASSERT(combined_nvl_head.dim() == 2 and combined_nvl_head.is_contiguous() and combined_nvl_head.scalar_type() == torch::kInt32); - - auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)), hidden_int4 = static_cast(x.size(1) * x.element_size() / sizeof(int4)); - auto num_combined_tokens = static_cast(is_combined_token_in_rank.size(0)); - EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0); - EP_HOST_ASSERT(src_meta.size(1) == internode::get_source_meta_bytes()); - EP_HOST_ASSERT(is_combined_token_in_rank.size(1) == num_ranks); - EP_HOST_ASSERT(rdma_channel_prefix_matrix.size(0) == num_rdma_ranks and rdma_channel_prefix_matrix.size(1) == num_channels); - EP_HOST_ASSERT(rdma_rank_prefix_sum.size(0) == num_rdma_ranks); - EP_HOST_ASSERT(gbl_channel_prefix_matrix.size(0) == num_ranks and gbl_channel_prefix_matrix.size(1) == num_channels); - EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and combined_rdma_head.size(0) == num_combined_tokens and combined_rdma_head.size(1) == num_rdma_ranks); - EP_HOST_ASSERT(combined_nvl_head.dim() == 2 and combined_nvl_head.size(1) == NUM_MAX_NVL_PEERS); - - auto compute_stream = at::cuda::getCurrentCUDAStream(); - if (allocate_on_comm_stream) { - EP_HOST_ASSERT(previous_event.has_value() and async); - at::cuda::setCurrentCUDAStream(comm_stream); - } - - if (previous_event.has_value()) { - stream_wait(comm_stream, previous_event.value()); - } 
else { - stream_wait(comm_stream, compute_stream); - } - - int num_topk = 0; - auto combined_topk_weights = std::optional(); - float* topk_weights_ptr = nullptr; - float* combined_topk_weights_ptr = nullptr; - if (topk_weights.has_value()) { - EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous()); - EP_HOST_ASSERT(topk_weights->size(0) == num_tokens); - EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32); - num_topk = static_cast(topk_weights->size(1)); - topk_weights_ptr = topk_weights->data_ptr(); - combined_topk_weights = torch::empty({num_combined_tokens, num_topk}, topk_weights->options()); - combined_topk_weights_ptr = combined_topk_weights->data_ptr(); - } - - EP_HOST_ASSERT(config.num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0); - EP_HOST_ASSERT(config.num_max_nvl_chunked_send_tokens <= config.num_max_nvl_chunked_recv_tokens / num_rdma_ranks); - - internode::cached_notify(hidden_int4, 0, 0, num_topk, - num_ranks, num_channels, - num_combined_tokens, combined_rdma_head.data_ptr(), - rdma_channel_prefix_matrix.data_ptr(), rdma_rank_prefix_sum.data_ptr(), combined_nvl_head.data_ptr(), - rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens, - buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, - task_fifo_ptrs_gpu, head, rank, comm_stream, - config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), - num_nvl_bytes, false, low_latency_mode, - port_channel_handles_device_ptr.get(), + // Just a barrier and clean flags + internode::cached_notify(hidden_int4, num_scales, num_topk, num_topk, num_ranks, num_channels, 0, nullptr, nullptr, + nullptr, nullptr, rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens, + buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, task_fifo_ptrs_gpu, head, rank, + comm_stream, config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), + num_nvl_bytes, true, low_latency_mode, port_channel_handles_device_ptr.get(), memory_channel_handles_device_ptr.get()); move_fifo_slots(2); + } else { + rdma_channel_prefix_matrix = + torch::empty({num_rdma_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); + recv_rdma_rank_prefix_sum = torch::empty({num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); + gbl_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); + recv_gbl_rank_prefix_sum = torch::empty({num_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); - auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options()); - internode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()), - combined_x.data_ptr(), combined_topk_weights_ptr, - is_combined_token_in_rank.data_ptr(), - x.data_ptr(), topk_weights_ptr, - combined_rdma_head.data_ptr(), combined_nvl_head.data_ptr(), - src_meta.data_ptr(), rdma_channel_prefix_matrix.data_ptr(), rdma_rank_prefix_sum.data_ptr(), gbl_channel_prefix_matrix.data_ptr(), - num_tokens, num_combined_tokens, hidden, num_topk, - rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens, - buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens, - rank, num_ranks, comm_stream, num_channels, low_latency_mode, - port_channel_handles_device_ptr.get(), - memory_channel_handles_device_ptr.get()); + // Send sizes + *moe_recv_counter = -1, *moe_recv_rdma_counter = -1; + for (int i = 0; i < num_local_experts; ++i) moe_recv_expert_counter[i] = -1; + internode::notify_dispatch( + 
num_tokens_per_rank->data_ptr(), moe_recv_counter_mapped, num_ranks, + num_tokens_per_rdma_rank->data_ptr(), moe_recv_rdma_counter_mapped, num_tokens_per_expert->data_ptr(), + moe_recv_expert_counter_mapped, num_experts, is_token_in_rank.data_ptr(), num_tokens, num_channels, + hidden_int4, num_scales, num_topk, expert_alignment, rdma_channel_prefix_matrix.data_ptr(), + recv_rdma_rank_prefix_sum.data_ptr(), gbl_channel_prefix_matrix.data_ptr(), + recv_gbl_rank_prefix_sum.data_ptr(), rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens, + buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, task_fifo_ptrs_gpu, head, rank, comm_stream, + config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), num_nvl_bytes, low_latency_mode, + port_channel_handles_device_ptr.get(), memory_channel_handles_device_ptr.get()); + move_fifo_slots(3); - std::optional event; - if (async) { - event = EventHandle(comm_stream); - for (auto& t: {x, src_meta, - is_combined_token_in_rank, rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, - combined_x, combined_rdma_head, combined_nvl_head}) { - t.record_stream(comm_stream); - if (allocate_on_comm_stream) - t.record_stream(compute_stream); - } - for (auto& to: {topk_weights, combined_topk_weights}) { - to.has_value() ? to->record_stream(comm_stream) : void(); - if (allocate_on_comm_stream) - to.has_value() ? to->record_stream(compute_stream) : void(); - } - } else { - stream_wait(compute_stream, comm_stream); + // Synchronize total received tokens and tokens per expert + auto start_time = std::chrono::high_resolution_clock::now(); + while (true) { + num_recv_tokens = static_cast(*moe_recv_counter); + num_rdma_recv_tokens = static_cast(*moe_recv_rdma_counter); + + bool ready = (num_recv_tokens >= 0) and (num_rdma_recv_tokens >= 0); + for (int i = 0; i < num_local_experts and ready; ++i) ready &= moe_recv_expert_counter[i] >= 0; + + if (ready) break; + + if (std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - start_time) + .count() > NUM_CPU_TIMEOUT_SECS) { + printf("Global rank: %d, num_recv_tokens: %d, num_rdma_recv_tokens: %d\n", rank, num_recv_tokens, + num_rdma_recv_tokens); + for (int i = 0; i < num_local_experts; ++i) + printf("moe_recv_expert_counter[%d]: %d\n", i, moe_recv_expert_counter[i]); + throw std::runtime_error("mscclpp::ep error: timeout (internode_dispatch CPU)"); + } } + num_recv_tokens_per_expert_list = + std::vector(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts); + } - if (allocate_on_comm_stream) - at::cuda::setCurrentCUDAStream(compute_stream); + // Allocate new tensors + auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options()); + auto recv_topk_idx = std::optional(), recv_topk_weights = std::optional(), + recv_x_scales = std::optional(); + auto recv_src_meta = std::optional(); + auto recv_rdma_channel_prefix_matrix = std::optional(); + auto recv_gbl_channel_prefix_matrix = std::optional(); + auto send_rdma_head = std::optional(); + auto send_nvl_head = std::optional(); + if (not cached_mode) { + recv_src_meta = + torch::empty({num_recv_tokens, internode::get_source_meta_bytes()}, dtype(torch::kByte).device(torch::kCUDA)); + recv_rdma_channel_prefix_matrix = + torch::empty({num_rdma_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); + recv_gbl_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); + send_rdma_head = torch::empty({num_tokens, num_rdma_ranks}, 
dtype(torch::kInt32).device(torch::kCUDA)); + send_nvl_head = torch::empty({num_rdma_recv_tokens, NUM_MAX_NVL_PEERS}, dtype(torch::kInt32).device(torch::kCUDA)); + } - return {combined_x, combined_topk_weights, event}; + int64_t* recv_topk_idx_ptr = nullptr; + float* recv_topk_weights_ptr = nullptr; + float* recv_x_scales_ptr = nullptr; + if (topk_idx.has_value()) { + recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options()); + recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options()); + recv_topk_idx_ptr = recv_topk_idx->data_ptr(); + recv_topk_weights_ptr = recv_topk_weights->data_ptr(); + } + if (x_scales.has_value()) { + recv_x_scales = x_scales->dim() == 1 ? torch::empty({num_recv_tokens}, x_scales->options()) + : torch::empty({num_recv_tokens, num_scales}, x_scales->options()); + recv_x_scales_ptr = recv_x_scales->data_ptr(); + } + + // Launch data dispatch + internode::dispatch(recv_x.data_ptr(), recv_x_scales_ptr, recv_topk_idx_ptr, recv_topk_weights_ptr, + cached_mode ? nullptr : recv_src_meta->data_ptr(), x.data_ptr(), x_scales_ptr, topk_idx_ptr, + topk_weights_ptr, cached_mode ? nullptr : send_rdma_head->data_ptr(), + cached_mode ? nullptr : send_nvl_head->data_ptr(), + cached_mode ? nullptr : recv_rdma_channel_prefix_matrix->data_ptr(), + cached_mode ? nullptr : recv_gbl_channel_prefix_matrix->data_ptr(), + rdma_channel_prefix_matrix.data_ptr(), recv_rdma_rank_prefix_sum.data_ptr(), + gbl_channel_prefix_matrix.data_ptr(), recv_gbl_rank_prefix_sum.data_ptr(), num_tokens, + hidden_int4, num_scales, num_topk, num_experts, is_token_in_rank.data_ptr(), + rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens, + buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens, + rank, num_ranks, cached_mode, comm_stream, num_channels, low_latency_mode, + port_channel_handles_device_ptr.get(), memory_channel_handles_device_ptr.get()); + + // Wait streams + std::optional event; + if (async) { + event = EventHandle(comm_stream); + for (auto& t : {x, is_token_in_rank, recv_x, rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, + gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum}) { + t.record_stream(comm_stream); + if (allocate_on_comm_stream) t.record_stream(compute_stream); + } + for (auto& to : {x_scales, topk_idx, topk_weights, num_tokens_per_rank, num_tokens_per_rdma_rank, + num_tokens_per_expert, cached_rdma_channel_prefix_matrix, cached_recv_rdma_rank_prefix_sum, + cached_gbl_channel_prefix_matrix, cached_recv_gbl_rank_prefix_sum, recv_topk_idx, + recv_topk_weights, recv_x_scales, recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix, + send_rdma_head, send_nvl_head, recv_src_meta}) { + to.has_value() ? to->record_stream(comm_stream) : void(); + if (allocate_on_comm_stream) to.has_value() ? 
to->record_stream(compute_stream) : void(); + } + } else { + stream_wait(compute_stream, comm_stream); + } + + if (allocate_on_comm_stream) at::cuda::setCurrentCUDAStream(compute_stream); + + return {recv_x, + recv_x_scales, + recv_topk_idx, + recv_topk_weights, + num_recv_tokens_per_expert_list, + rdma_channel_prefix_matrix, + gbl_channel_prefix_matrix, + recv_rdma_channel_prefix_matrix, + recv_rdma_rank_prefix_sum, + recv_gbl_channel_prefix_matrix, + recv_gbl_rank_prefix_sum, + recv_src_meta, + send_rdma_head, + send_nvl_head, + event}; +} + +std::tuple, std::optional> Buffer::internode_combine( + const torch::Tensor& x, const std::optional& topk_weights, const torch::Tensor& src_meta, + const torch::Tensor& is_combined_token_in_rank, const torch::Tensor& rdma_channel_prefix_matrix, + const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix, + const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head, const Config& config, + std::optional& previous_event, bool async, bool allocate_on_comm_stream) { + const int num_channels = config.num_sms / 2; + EP_HOST_ASSERT(config.num_sms % 2 == 0); + + // Shape and contiguous checks + EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); + EP_HOST_ASSERT(src_meta.dim() == 2 and src_meta.is_contiguous() and src_meta.scalar_type() == torch::kByte); + EP_HOST_ASSERT(is_combined_token_in_rank.dim() == 2 and is_combined_token_in_rank.is_contiguous() and + is_combined_token_in_rank.scalar_type() == torch::kBool); + EP_HOST_ASSERT(rdma_channel_prefix_matrix.dim() == 2 and rdma_channel_prefix_matrix.is_contiguous() and + rdma_channel_prefix_matrix.scalar_type() == torch::kInt32); + EP_HOST_ASSERT(rdma_rank_prefix_sum.dim() == 1 and rdma_rank_prefix_sum.is_contiguous() and + rdma_rank_prefix_sum.scalar_type() == torch::kInt32); + EP_HOST_ASSERT(gbl_channel_prefix_matrix.dim() == 2 and gbl_channel_prefix_matrix.is_contiguous() and + gbl_channel_prefix_matrix.scalar_type() == torch::kInt32); + EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and combined_rdma_head.is_contiguous() and + combined_rdma_head.scalar_type() == torch::kInt32); + EP_HOST_ASSERT(combined_nvl_head.dim() == 2 and combined_nvl_head.is_contiguous() and + combined_nvl_head.scalar_type() == torch::kInt32); + + auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)), + hidden_int4 = static_cast(x.size(1) * x.element_size() / sizeof(int4)); + auto num_combined_tokens = static_cast(is_combined_token_in_rank.size(0)); + EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0); + EP_HOST_ASSERT(src_meta.size(1) == internode::get_source_meta_bytes()); + EP_HOST_ASSERT(is_combined_token_in_rank.size(1) == num_ranks); + EP_HOST_ASSERT(rdma_channel_prefix_matrix.size(0) == num_rdma_ranks and + rdma_channel_prefix_matrix.size(1) == num_channels); + EP_HOST_ASSERT(rdma_rank_prefix_sum.size(0) == num_rdma_ranks); + EP_HOST_ASSERT(gbl_channel_prefix_matrix.size(0) == num_ranks and gbl_channel_prefix_matrix.size(1) == num_channels); + EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and combined_rdma_head.size(0) == num_combined_tokens and + combined_rdma_head.size(1) == num_rdma_ranks); + EP_HOST_ASSERT(combined_nvl_head.dim() == 2 and combined_nvl_head.size(1) == NUM_MAX_NVL_PEERS); + + auto compute_stream = at::cuda::getCurrentCUDAStream(); + if (allocate_on_comm_stream) { + EP_HOST_ASSERT(previous_event.has_value() and async); + at::cuda::setCurrentCUDAStream(comm_stream); + } + + if (previous_event.has_value()) { + 
stream_wait(comm_stream, previous_event.value()); + } else { + stream_wait(comm_stream, compute_stream); + } + + int num_topk = 0; + auto combined_topk_weights = std::optional(); + float* topk_weights_ptr = nullptr; + float* combined_topk_weights_ptr = nullptr; + if (topk_weights.has_value()) { + EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous()); + EP_HOST_ASSERT(topk_weights->size(0) == num_tokens); + EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32); + num_topk = static_cast(topk_weights->size(1)); + topk_weights_ptr = topk_weights->data_ptr(); + combined_topk_weights = torch::empty({num_combined_tokens, num_topk}, topk_weights->options()); + combined_topk_weights_ptr = combined_topk_weights->data_ptr(); + } + + EP_HOST_ASSERT(config.num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0); + EP_HOST_ASSERT(config.num_max_nvl_chunked_send_tokens <= config.num_max_nvl_chunked_recv_tokens / num_rdma_ranks); + + internode::cached_notify( + hidden_int4, 0, 0, num_topk, num_ranks, num_channels, num_combined_tokens, combined_rdma_head.data_ptr(), + rdma_channel_prefix_matrix.data_ptr(), rdma_rank_prefix_sum.data_ptr(), + combined_nvl_head.data_ptr(), rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu, + config.num_max_nvl_chunked_recv_tokens, task_fifo_ptrs_gpu, head, rank, comm_stream, + config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), num_nvl_bytes, false, low_latency_mode, + port_channel_handles_device_ptr.get(), memory_channel_handles_device_ptr.get()); + move_fifo_slots(2); + + auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options()); + internode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()), combined_x.data_ptr(), + combined_topk_weights_ptr, is_combined_token_in_rank.data_ptr(), x.data_ptr(), + topk_weights_ptr, combined_rdma_head.data_ptr(), combined_nvl_head.data_ptr(), + src_meta.data_ptr(), rdma_channel_prefix_matrix.data_ptr(), + rdma_rank_prefix_sum.data_ptr(), gbl_channel_prefix_matrix.data_ptr(), num_tokens, + num_combined_tokens, hidden, num_topk, rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, + config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, + config.num_max_nvl_chunked_recv_tokens, rank, num_ranks, comm_stream, num_channels, + low_latency_mode, port_channel_handles_device_ptr.get(), memory_channel_handles_device_ptr.get()); + + std::optional event; + if (async) { + event = EventHandle(comm_stream); + for (auto& t : {x, src_meta, is_combined_token_in_rank, rdma_channel_prefix_matrix, rdma_rank_prefix_sum, + gbl_channel_prefix_matrix, combined_x, combined_rdma_head, combined_nvl_head}) { + t.record_stream(comm_stream); + if (allocate_on_comm_stream) t.record_stream(compute_stream); + } + for (auto& to : {topk_weights, combined_topk_weights}) { + to.has_value() ? to->record_stream(comm_stream) : void(); + if (allocate_on_comm_stream) to.has_value() ? 
to->record_stream(compute_stream) : void(); + } + } else { + stream_wait(compute_stream, comm_stream); + } + + if (allocate_on_comm_stream) at::cuda::setCurrentCUDAStream(compute_stream); + + return {combined_x, combined_topk_weights, event}; } void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) { - EP_HOST_ASSERT(low_latency_mode); + EP_HOST_ASSERT(low_latency_mode); - auto layout = LowLatencyLayout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); - auto clean_meta_0 = layout.buffers[0].clean_meta(); - auto clean_meta_1 = layout.buffers[1].clean_meta(); + auto layout = LowLatencyLayout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); + auto clean_meta_0 = layout.buffers[0].clean_meta(); + auto clean_meta_1 = layout.buffers[1].clean_meta(); - auto check_boundary = [=](void* ptr, size_t num_bytes) { - auto offset = reinterpret_cast(ptr) - reinterpret_cast(rdma_buffer_ptr); - EP_HOST_ASSERT(0 <= offset and offset + static_cast(num_bytes) <= num_rdma_bytes); - }; - check_boundary(clean_meta_0.first, clean_meta_0.second * sizeof(int)); - check_boundary(clean_meta_1.first, clean_meta_1.second * sizeof(int)); + auto check_boundary = [=](void* ptr, size_t num_bytes) { + auto offset = reinterpret_cast(ptr) - reinterpret_cast(rdma_buffer_ptr); + EP_HOST_ASSERT(0 <= offset and offset + static_cast(num_bytes) <= num_rdma_bytes); + }; + check_boundary(clean_meta_0.first, clean_meta_0.second * sizeof(int)); + check_boundary(clean_meta_1.first, clean_meta_1.second * sizeof(int)); - internode_ll::clean_low_latency_buffer(clean_meta_0.first, clean_meta_0.second, - clean_meta_1.first, clean_meta_1.second, - rank, num_ranks, - port_channel_handles_device_ptr.get(), - ll_memory_channel_handles_device_ptr ? - ll_memory_channel_handles_device_ptr.get() : nullptr, - ll_ipc_ready, - at::cuda::getCurrentCUDAStream()); + internode_ll::clean_low_latency_buffer( + clean_meta_0.first, clean_meta_0.second, clean_meta_1.first, clean_meta_1.second, rank, num_ranks, + port_channel_handles_device_ptr.get(), + ll_memory_channel_handles_device_ptr ? 
ll_memory_channel_handles_device_ptr.get() : nullptr, ll_ipc_ready, + at::cuda::getCurrentCUDAStream()); } -std::tuple, torch::Tensor, torch::Tensor, torch::Tensor, std::optional, std::optional>> +std::tuple, torch::Tensor, torch::Tensor, torch::Tensor, + std::optional, std::optional>> Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, - int num_max_dispatch_tokens_per_rank, int num_experts, - bool use_fp8, bool async, bool return_recv_hook, - const std::optional& out_packed_recv_x, + int num_max_dispatch_tokens_per_rank, int num_experts, bool use_fp8, bool async, + bool return_recv_hook, const std::optional& out_packed_recv_x, const std::optional& out_packed_recv_x_scales, const std::optional& out_packed_recv_src_info, const std::optional& out_packed_recv_layout_range, const std::optional& out_packed_recv_count) { - EP_HOST_ASSERT(low_latency_mode); + EP_HOST_ASSERT(low_latency_mode); - EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16); - EP_HOST_ASSERT(x.size(1) % sizeof(int4) == 0 and x.size(1) % 128 == 0); - EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous()); - EP_HOST_ASSERT(x.size(0) == topk_idx.size(0) and x.size(0) <= num_max_dispatch_tokens_per_rank); - EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64); - EP_HOST_ASSERT(num_experts % num_ranks == 0); + EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16); + EP_HOST_ASSERT(x.size(1) % sizeof(int4) == 0 and x.size(1) % 128 == 0); + EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous()); + EP_HOST_ASSERT(x.size(0) == topk_idx.size(0) and x.size(0) <= num_max_dispatch_tokens_per_rank); + EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64); + EP_HOST_ASSERT(num_experts % num_ranks == 0); - auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)); - auto num_scales = hidden / 128, num_topk = static_cast(topk_idx.size(1)); - int num_local_experts = num_experts / num_ranks; + auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)); + auto num_scales = hidden / 128, num_topk = static_cast(topk_idx.size(1)); + int num_local_experts = num_experts / num_ranks; - LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); - EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes); - auto buffer = layout.buffers[low_latency_buffer_idx]; - auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; + LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); + EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes); + auto buffer = layout.buffers[low_latency_buffer_idx]; + auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; - auto compute_stream = at::cuda::getCurrentCUDAStream(); - auto launch_stream = return_recv_hook ? compute_stream : comm_stream; - EP_HOST_ASSERT(not (async and return_recv_hook)); - if (not return_recv_hook) - stream_wait(launch_stream, compute_stream); + auto compute_stream = at::cuda::getCurrentCUDAStream(); + auto launch_stream = return_recv_hook ? compute_stream : comm_stream; + EP_HOST_ASSERT(not(async and return_recv_hook)); + if (not return_recv_hook) stream_wait(launch_stream, compute_stream); - // Reusable output tensors. 
The largest (`packed_recv_x` ~58 MB at 7K hidden) - // is what motivates the reuse path: a fresh torch::empty per call adds - // measurable host overhead (~10us cumulative for the 4 allocations) which - // shows up against NCCL-EP's preallocated bench at small payloads. - const auto recv_x_dtype = use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16; - torch::Tensor packed_recv_x; - if (out_packed_recv_x.has_value()) { - EP_HOST_ASSERT(out_packed_recv_x->dim() == 3 and out_packed_recv_x->is_contiguous()); - EP_HOST_ASSERT(out_packed_recv_x->size(0) == num_local_experts); - EP_HOST_ASSERT(out_packed_recv_x->size(1) == num_ranks * num_max_dispatch_tokens_per_rank); - EP_HOST_ASSERT(out_packed_recv_x->size(2) == hidden); - EP_HOST_ASSERT(out_packed_recv_x->scalar_type() == recv_x_dtype); - packed_recv_x = out_packed_recv_x.value(); + // Reusable output tensors. The largest (`packed_recv_x` ~58 MB at 7K hidden) + // is what motivates the reuse path: a fresh torch::empty per call adds + // measurable host overhead (~10us cumulative for the 4 allocations) which + // shows up against NCCL-EP's preallocated bench at small payloads. + const auto recv_x_dtype = use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16; + torch::Tensor packed_recv_x; + if (out_packed_recv_x.has_value()) { + EP_HOST_ASSERT(out_packed_recv_x->dim() == 3 and out_packed_recv_x->is_contiguous()); + EP_HOST_ASSERT(out_packed_recv_x->size(0) == num_local_experts); + EP_HOST_ASSERT(out_packed_recv_x->size(1) == num_ranks * num_max_dispatch_tokens_per_rank); + EP_HOST_ASSERT(out_packed_recv_x->size(2) == hidden); + EP_HOST_ASSERT(out_packed_recv_x->scalar_type() == recv_x_dtype); + packed_recv_x = out_packed_recv_x.value(); + } else { + packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, + x.options().dtype(recv_x_dtype)); + } + torch::Tensor packed_recv_src_info; + if (out_packed_recv_src_info.has_value()) { + EP_HOST_ASSERT(out_packed_recv_src_info->dim() == 2 and out_packed_recv_src_info->is_contiguous()); + EP_HOST_ASSERT(out_packed_recv_src_info->size(0) == num_local_experts); + EP_HOST_ASSERT(out_packed_recv_src_info->size(1) == num_ranks * num_max_dispatch_tokens_per_rank); + EP_HOST_ASSERT(out_packed_recv_src_info->scalar_type() == torch::kInt32); + packed_recv_src_info = out_packed_recv_src_info.value(); + } else { + packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, + torch::dtype(torch::kInt32).device(torch::kCUDA)); + } + torch::Tensor packed_recv_layout_range; + if (out_packed_recv_layout_range.has_value()) { + EP_HOST_ASSERT(out_packed_recv_layout_range->dim() == 2 and out_packed_recv_layout_range->is_contiguous()); + EP_HOST_ASSERT(out_packed_recv_layout_range->size(0) == num_local_experts); + EP_HOST_ASSERT(out_packed_recv_layout_range->size(1) == num_ranks); + EP_HOST_ASSERT(out_packed_recv_layout_range->scalar_type() == torch::kInt64); + packed_recv_layout_range = out_packed_recv_layout_range.value(); + } else { + packed_recv_layout_range = + torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA)); + } + torch::Tensor packed_recv_count; + if (out_packed_recv_count.has_value()) { + EP_HOST_ASSERT(out_packed_recv_count->dim() == 1 and out_packed_recv_count->is_contiguous()); + EP_HOST_ASSERT(out_packed_recv_count->size(0) == num_local_experts); + EP_HOST_ASSERT(out_packed_recv_count->scalar_type() == torch::kInt32); + packed_recv_count = out_packed_recv_count.value(); + } else { 
+ packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + } + + auto packed_recv_x_scales = std::optional(); + float* packed_recv_x_scales_ptr = nullptr; + if (use_fp8) { + EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and + "TMA requires the number of tokens to be multiple of 4"); + if (out_packed_recv_x_scales.has_value()) { + // Caller-provided scales tensor must already be in the kernel's + // expected (transposed) layout: shape [num_local_experts, + // num_ranks*max_tokens, num_scales], strides such that + // size(1)=num_ranks*max_tokens with the actual storage + // [num_local_experts, num_scales, num_ranks*max_tokens] (i.e. + // produced by `torch.empty(...).transpose(1, 2)`). + EP_HOST_ASSERT(out_packed_recv_x_scales->dim() == 3); + EP_HOST_ASSERT(out_packed_recv_x_scales->size(0) == num_local_experts); + EP_HOST_ASSERT(out_packed_recv_x_scales->size(1) == num_ranks * num_max_dispatch_tokens_per_rank); + EP_HOST_ASSERT(out_packed_recv_x_scales->size(2) == num_scales); + EP_HOST_ASSERT(out_packed_recv_x_scales->scalar_type() == torch::kFloat32); + packed_recv_x_scales = out_packed_recv_x_scales.value(); } else { - packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, - x.options().dtype(recv_x_dtype)); - } - torch::Tensor packed_recv_src_info; - if (out_packed_recv_src_info.has_value()) { - EP_HOST_ASSERT(out_packed_recv_src_info->dim() == 2 and out_packed_recv_src_info->is_contiguous()); - EP_HOST_ASSERT(out_packed_recv_src_info->size(0) == num_local_experts); - EP_HOST_ASSERT(out_packed_recv_src_info->size(1) == num_ranks * num_max_dispatch_tokens_per_rank); - EP_HOST_ASSERT(out_packed_recv_src_info->scalar_type() == torch::kInt32); - packed_recv_src_info = out_packed_recv_src_info.value(); - } else { - packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, - torch::dtype(torch::kInt32).device(torch::kCUDA)); - } - torch::Tensor packed_recv_layout_range; - if (out_packed_recv_layout_range.has_value()) { - EP_HOST_ASSERT(out_packed_recv_layout_range->dim() == 2 and out_packed_recv_layout_range->is_contiguous()); - EP_HOST_ASSERT(out_packed_recv_layout_range->size(0) == num_local_experts); - EP_HOST_ASSERT(out_packed_recv_layout_range->size(1) == num_ranks); - EP_HOST_ASSERT(out_packed_recv_layout_range->scalar_type() == torch::kInt64); - packed_recv_layout_range = out_packed_recv_layout_range.value(); - } else { - packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, - torch::dtype(torch::kInt64).device(torch::kCUDA)); - } - torch::Tensor packed_recv_count; - if (out_packed_recv_count.has_value()) { - EP_HOST_ASSERT(out_packed_recv_count->dim() == 1 and out_packed_recv_count->is_contiguous()); - EP_HOST_ASSERT(out_packed_recv_count->size(0) == num_local_experts); - EP_HOST_ASSERT(out_packed_recv_count->scalar_type() == torch::kInt32); - packed_recv_count = out_packed_recv_count.value(); - } else { - packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + packed_recv_x_scales = torch::empty({num_local_experts, num_scales, num_ranks * num_max_dispatch_tokens_per_rank}, + torch::dtype(torch::kFloat32).device(torch::kCUDA)); + packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2); } + packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr(); + } - auto packed_recv_x_scales = std::optional(); - float* 
packed_recv_x_scales_ptr = nullptr; - if (use_fp8) { - EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4"); - if (out_packed_recv_x_scales.has_value()) { - // Caller-provided scales tensor must already be in the kernel's - // expected (transposed) layout: shape [num_local_experts, - // num_ranks*max_tokens, num_scales], strides such that - // size(1)=num_ranks*max_tokens with the actual storage - // [num_local_experts, num_scales, num_ranks*max_tokens] (i.e. - // produced by `torch.empty(...).transpose(1, 2)`). - EP_HOST_ASSERT(out_packed_recv_x_scales->dim() == 3); - EP_HOST_ASSERT(out_packed_recv_x_scales->size(0) == num_local_experts); - EP_HOST_ASSERT(out_packed_recv_x_scales->size(1) == num_ranks * num_max_dispatch_tokens_per_rank); - EP_HOST_ASSERT(out_packed_recv_x_scales->size(2) == num_scales); - EP_HOST_ASSERT(out_packed_recv_x_scales->scalar_type() == torch::kFloat32); - packed_recv_x_scales = out_packed_recv_x_scales.value(); - } else { - packed_recv_x_scales = torch::empty({num_local_experts, num_scales, num_ranks * num_max_dispatch_tokens_per_rank}, - torch::dtype(torch::kFloat32).device(torch::kCUDA)); - packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2); - } - packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr(); - } + auto next_clean_meta = next_buffer.clean_meta(); + auto port_handles = port_channel_handles_device_ptr.get(); + auto mem_handles = ll_memory_channel_handles_device_ptr ? ll_memory_channel_handles_device_ptr.get() : nullptr; + auto peer_bases = peer_rdma_bases_gpu; + const bool use_ipc = ll_ipc_ready; + auto rdma_base = rdma_buffer_ptr; + auto launcher = [=](int phases) { + internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr, packed_recv_src_info.data_ptr(), + packed_recv_layout_range.data_ptr(), packed_recv_count.data_ptr(), + buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer, + buffer.dispatch_rdma_send_buffer, x.data_ptr(), topk_idx.data_ptr(), + next_clean_meta.first, next_clean_meta.second, num_tokens, hidden, + num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, num_ranks, use_fp8, workspace, + launch_stream, phases, rdma_base, port_handles, peer_bases, mem_handles, use_ipc); + }; + launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); - auto next_clean_meta = next_buffer.clean_meta(); - auto port_handles = port_channel_handles_device_ptr.get(); - auto mem_handles = ll_memory_channel_handles_device_ptr ? - ll_memory_channel_handles_device_ptr.get() : nullptr; - auto peer_bases = peer_rdma_bases_gpu; - const bool use_ipc = ll_ipc_ready; - auto rdma_base = rdma_buffer_ptr; - auto launcher = [=](int phases) { - internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr, - packed_recv_src_info.data_ptr(), packed_recv_layout_range.data_ptr(), - packed_recv_count.data_ptr(), - buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer, - buffer.dispatch_rdma_send_buffer, - x.data_ptr(), topk_idx.data_ptr(), - next_clean_meta.first, next_clean_meta.second, - num_tokens, hidden, num_max_dispatch_tokens_per_rank, - num_topk, num_experts, rank, num_ranks, use_fp8, - workspace, launch_stream, phases, - rdma_base, port_handles, - peer_bases, mem_handles, use_ipc); - }; - launcher(return_recv_hook ? 
LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); + std::optional event; + if (async) { + event = EventHandle(launch_stream); + } else if (not return_recv_hook) { + stream_wait(compute_stream, launch_stream); + } - std::optional event; - if (async) { - event = EventHandle(launch_stream); - } else if (not return_recv_hook) { - stream_wait(compute_stream, launch_stream); - } + std::optional> recv_hook = std::nullopt; + if (return_recv_hook) recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); }; - std::optional> recv_hook = std::nullopt; - if (return_recv_hook) - recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); }; - - return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook}; + return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, + recv_hook}; } -std::tuple, std::optional>> -Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, - const torch::Tensor& src_info, const torch::Tensor& layout_range, - int num_max_dispatch_tokens_per_rank, int num_experts, - bool zero_copy, bool async, bool return_recv_hook, - const std::optional& out) { - EP_HOST_ASSERT(low_latency_mode); +std::tuple, std::optional>> Buffer::low_latency_combine( + const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, + const torch::Tensor& src_info, const torch::Tensor& layout_range, int num_max_dispatch_tokens_per_rank, + int num_experts, bool zero_copy, bool async, bool return_recv_hook, const std::optional& out) { + EP_HOST_ASSERT(low_latency_mode); - EP_HOST_ASSERT(x.dim() == 3 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16); - EP_HOST_ASSERT(x.size(0) == num_experts / num_ranks); - EP_HOST_ASSERT(x.size(1) == num_ranks * num_max_dispatch_tokens_per_rank); - EP_HOST_ASSERT(x.size(2) % sizeof(int4) == 0 and x.size(2) % 128 == 0); - EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous()); - EP_HOST_ASSERT(topk_idx.size(0) == topk_weights.size(0) and topk_idx.size(1) == topk_weights.size(1)); - EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64); - EP_HOST_ASSERT(topk_weights.dim() == 2 and topk_weights.is_contiguous()); - EP_HOST_ASSERT(topk_weights.size(0) <= num_max_dispatch_tokens_per_rank); - EP_HOST_ASSERT(topk_weights.scalar_type() == torch::kFloat32); - EP_HOST_ASSERT(src_info.dim() == 2 and src_info.is_contiguous()); - EP_HOST_ASSERT(src_info.scalar_type() == torch::kInt32 and x.size(0) == src_info.size(0)); - EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous()); - EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64); - EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks); - auto hidden = static_cast(x.size(2)); - auto num_topk = static_cast(topk_weights.size(1)); - auto num_combined_tokens = static_cast(topk_weights.size(0)); + EP_HOST_ASSERT(x.dim() == 3 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16); + EP_HOST_ASSERT(x.size(0) == num_experts / num_ranks); + EP_HOST_ASSERT(x.size(1) == num_ranks * num_max_dispatch_tokens_per_rank); + EP_HOST_ASSERT(x.size(2) % sizeof(int4) == 0 and x.size(2) % 128 == 0); + EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous()); + EP_HOST_ASSERT(topk_idx.size(0) == topk_weights.size(0) and topk_idx.size(1) == topk_weights.size(1)); + EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64); + 
EP_HOST_ASSERT(topk_weights.dim() == 2 and topk_weights.is_contiguous()); + EP_HOST_ASSERT(topk_weights.size(0) <= num_max_dispatch_tokens_per_rank); + EP_HOST_ASSERT(topk_weights.scalar_type() == torch::kFloat32); + EP_HOST_ASSERT(src_info.dim() == 2 and src_info.is_contiguous()); + EP_HOST_ASSERT(src_info.scalar_type() == torch::kInt32 and x.size(0) == src_info.size(0)); + EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous()); + EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64); + EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks); + auto hidden = static_cast(x.size(2)); + auto num_topk = static_cast(topk_weights.size(1)); + auto num_combined_tokens = static_cast(topk_weights.size(0)); - LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); - EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes); - auto buffer = layout.buffers[low_latency_buffer_idx]; - auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; + LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); + EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes); + auto buffer = layout.buffers[low_latency_buffer_idx]; + auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; - auto compute_stream = at::cuda::getCurrentCUDAStream(); - auto launch_stream = return_recv_hook ? compute_stream : comm_stream; - EP_HOST_ASSERT(not (async and return_recv_hook)); - if (not return_recv_hook) - stream_wait(launch_stream, compute_stream); + auto compute_stream = at::cuda::getCurrentCUDAStream(); + auto launch_stream = return_recv_hook ? compute_stream : comm_stream; + EP_HOST_ASSERT(not(async and return_recv_hook)); + if (not return_recv_hook) stream_wait(launch_stream, compute_stream); - torch::Tensor combined_x; - if (out.has_value()) { - EP_HOST_ASSERT(out->dim() == 2 and out->is_contiguous()); - EP_HOST_ASSERT(out->size(0) == num_combined_tokens and out->size(1) == hidden); - EP_HOST_ASSERT(out->scalar_type() == x.scalar_type()); - combined_x = out.value(); - } else { - combined_x = torch::empty({num_combined_tokens, hidden}, x.options()); - } + torch::Tensor combined_x; + if (out.has_value()) { + EP_HOST_ASSERT(out->dim() == 2 and out->is_contiguous()); + EP_HOST_ASSERT(out->size(0) == num_combined_tokens and out->size(1) == hidden); + EP_HOST_ASSERT(out->scalar_type() == x.scalar_type()); + combined_x = out.value(); + } else { + combined_x = torch::empty({num_combined_tokens, hidden}, x.options()); + } - auto next_clean_meta = next_buffer.clean_meta(); - auto port_handles = port_channel_handles_device_ptr.get(); - auto mem_handles = ll_memory_channel_handles_device_ptr ? - ll_memory_channel_handles_device_ptr.get() : nullptr; - auto peer_bases = peer_rdma_bases_gpu; - const bool use_ipc = ll_ipc_ready; - auto rdma_base = rdma_buffer_ptr; - auto launcher = [=](int phases) { - internode_ll::combine(combined_x.data_ptr(), - buffer.combine_rdma_recv_data_buffer, buffer.combine_rdma_recv_flag_buffer, - buffer.combine_rdma_send_buffer, - x.data_ptr(), topk_idx.data_ptr(), topk_weights.data_ptr(), - src_info.data_ptr(), layout_range.data_ptr(), - next_clean_meta.first, next_clean_meta.second, - num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank, - num_topk, num_experts, rank, num_ranks, - workspace, launch_stream, - phases, zero_copy, - rdma_base, port_handles, - peer_bases, mem_handles, use_ipc); - }; - launcher(return_recv_hook ? 
LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); + auto next_clean_meta = next_buffer.clean_meta(); + auto port_handles = port_channel_handles_device_ptr.get(); + auto mem_handles = ll_memory_channel_handles_device_ptr ? ll_memory_channel_handles_device_ptr.get() : nullptr; + auto peer_bases = peer_rdma_bases_gpu; + const bool use_ipc = ll_ipc_ready; + auto rdma_base = rdma_buffer_ptr; + auto launcher = [=](int phases) { + internode_ll::combine( + combined_x.data_ptr(), buffer.combine_rdma_recv_data_buffer, buffer.combine_rdma_recv_flag_buffer, + buffer.combine_rdma_send_buffer, x.data_ptr(), topk_idx.data_ptr(), topk_weights.data_ptr(), + src_info.data_ptr(), layout_range.data_ptr(), next_clean_meta.first, next_clean_meta.second, + num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, num_ranks, + workspace, launch_stream, phases, zero_copy, rdma_base, port_handles, peer_bases, mem_handles, use_ipc); + }; + launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); - std::optional event; - if (async) { - event = EventHandle(launch_stream); - } else if (not return_recv_hook) { - stream_wait(compute_stream, launch_stream); - } + std::optional event; + if (async) { + event = EventHandle(launch_stream); + } else if (not return_recv_hook) { + stream_wait(compute_stream, launch_stream); + } - std::optional> recv_hook = std::nullopt; - if (return_recv_hook) - recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); }; + std::optional> recv_hook = std::nullopt; + if (return_recv_hook) recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); }; - return {combined_x, event, recv_hook}; + return {combined_x, event, recv_hook}; } -torch::Tensor Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) { - LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); - auto buffer = layout.buffers[low_latency_buffer_idx]; - auto dtype = torch::kBFloat16; - auto num_msg_elems = static_cast(buffer.num_bytes_per_combine_msg / elementSize(torch::kBFloat16)); +torch::Tensor Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, + int num_experts) { + LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); + auto buffer = layout.buffers[low_latency_buffer_idx]; + auto dtype = torch::kBFloat16; + auto num_msg_elems = static_cast(buffer.num_bytes_per_combine_msg / elementSize(torch::kBFloat16)); - EP_HOST_ASSERT(buffer.num_bytes_per_combine_msg % elementSize(torch::kBFloat16) == 0); - return torch::from_blob(buffer.combine_rdma_send_buffer_data_start, - {num_experts / num_ranks, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, - {num_ranks * num_max_dispatch_tokens_per_rank * num_msg_elems, num_msg_elems, 1}, - torch::TensorOptions().dtype(dtype).device(torch::kCUDA)); + EP_HOST_ASSERT(buffer.num_bytes_per_combine_msg % elementSize(torch::kBFloat16) == 0); + return torch::from_blob(buffer.combine_rdma_send_buffer_data_start, + {num_experts / num_ranks, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, + {num_ranks * num_max_dispatch_tokens_per_rank * num_msg_elems, num_msg_elems, 1}, + torch::TensorOptions().dtype(dtype).device(torch::kCUDA)); } -} // namespace ep -} // namespace mscclpp +} // namespace ep +} // namespace mscclpp diff --git a/src/ext/ep/buffer.hpp b/src/ext/ep/buffer.hpp index 
4937ee8a..7c2a0540 100644 --- a/src/ext/ep/buffer.hpp +++ b/src/ext/ep/buffer.hpp @@ -5,11 +5,12 @@ #include #include #include + +#include +#include +#include #include #include -#include -#include -#include #include "config.hpp" #include "event.hpp" @@ -20,172 +21,186 @@ #define TORCH_EXTENSION_NAME mscclpp_ep_cpp #endif -namespace mscclpp { namespace ep { +namespace mscclpp { +namespace ep { struct Buffer { - EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "The number of maximum NVLink peers must be 8"); + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "The number of maximum NVLink peers must be 8"); -private: - // Low-latency mode buffer - int low_latency_buffer_idx = 0; - bool low_latency_mode = false; + private: + // Low-latency mode buffer + int low_latency_buffer_idx = 0; + bool low_latency_mode = false; - // NVLink Buffer - int64_t num_nvl_bytes; - void* buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; - void** buffer_ptrs_gpu = nullptr; + // NVLink Buffer + int64_t num_nvl_bytes; + void* buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; + void** buffer_ptrs_gpu = nullptr; - // NVSHMEM Buffer - int64_t num_rdma_bytes; - void* rdma_buffer_ptr = nullptr; + // NVSHMEM Buffer + int64_t num_rdma_bytes; + void* rdma_buffer_ptr = nullptr; - // Device info and communication - int device_id; - int rank, rdma_rank, nvl_rank; - int num_ranks, num_rdma_ranks, num_nvl_ranks; - cudaIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS]; + // Device info and communication + int device_id; + int rank, rdma_rank, nvl_rank; + int num_ranks, num_rdma_ranks, num_nvl_ranks; + cudaIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS]; - // Stream for communication - at::cuda::CUDAStream comm_stream; + // Stream for communication + at::cuda::CUDAStream comm_stream; - // After IPC/NVSHMEM synchronization, this flag will be true - bool available = false; + // After IPC/NVSHMEM synchronization, this flag will be true + bool available = false; - // Task fifo - int head = 0; - int* task_fifo_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; - int** task_fifo_ptrs_gpu = nullptr; + // Task fifo + int head = 0; + int* task_fifo_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; + int** task_fifo_ptrs_gpu = nullptr; - // Workspace - void* workspace = nullptr; + // Workspace + void* workspace = nullptr; - // Host-side MoE info - volatile int* moe_recv_counter = nullptr; - int* moe_recv_counter_mapped = nullptr; + // Host-side MoE info + volatile int* moe_recv_counter = nullptr; + int* moe_recv_counter_mapped = nullptr; - // Host-side expert-level MoE info - volatile int* moe_recv_expert_counter = nullptr; - int* moe_recv_expert_counter_mapped = nullptr; + // Host-side expert-level MoE info + volatile int* moe_recv_expert_counter = nullptr; + int* moe_recv_expert_counter_mapped = nullptr; - // Host-side RDMA-level MoE info - volatile int* moe_recv_rdma_counter = nullptr; - int* moe_recv_rdma_counter_mapped = nullptr; + // Host-side RDMA-level MoE info + volatile int* moe_recv_rdma_counter = nullptr; + int* moe_recv_rdma_counter_mapped = nullptr; - std::shared_ptr bootstrap; - // One ProxyService spawns a single proxy thread that drains every PortChannel - // FIFO it owns. With LL combine pushing thousands of triggers per iter, the - // single thread becomes the wall-clock bottleneck on cross-node runs. We - // shard channels across `proxy_services` so each gets its own thread/FIFO, - // increasing host-side dispatch parallelism (no kernel changes required). - // Count is resolved at construction (env `MSCCLPP_EP_NUM_PROXIES` or - // arch-aware default). 
- int num_proxy_services = 1; - std::vector> proxy_services; - std::shared_ptr communicator; - std::vector port_channels; - std::vector memory_channels; - std::shared_ptr port_channel_handles_device_ptr; - std::shared_ptr memory_channel_handles_device_ptr; + std::shared_ptr bootstrap; + // One ProxyService spawns a single proxy thread that drains every PortChannel + // FIFO it owns. With LL combine pushing thousands of triggers per iter, the + // single thread becomes the wall-clock bottleneck on cross-node runs. We + // shard channels across `proxy_services` so each gets its own thread/FIFO, + // increasing host-side dispatch parallelism (no kernel changes required). + // Count is resolved at construction (env `MSCCLPP_EP_NUM_PROXIES` or + // arch-aware default). + int num_proxy_services = 1; + std::vector> proxy_services; + std::shared_ptr communicator; + std::vector port_channels; + std::vector memory_channels; + std::shared_ptr port_channel_handles_device_ptr; + std::shared_ptr memory_channel_handles_device_ptr; - // Intra-node LL only: peer-mapped RDMA buffer pointers (CUDA IPC). - // ``peer_rdma_bases[r]`` aliases rank ``r``'s ``rdma_buffer_ptr`` via - // ``cudaIpcOpenMemHandle`` (lazy peer access). Populated in ``sync()`` when - // ``low_latency_mode && num_rdma_ranks == 1``; null otherwise. - cudaIpcMemHandle_t rdma_ipc_handles[NUM_MAX_NVL_PEERS]; - void* peer_rdma_bases[NUM_MAX_NVL_PEERS] = {nullptr}; - void** peer_rdma_bases_gpu = nullptr; - // MemoryChannels over CUDA IPC used only for the LL barrier ring. - std::vector ll_memory_channels; - std::shared_ptr ll_memory_channel_handles_device_ptr; - bool ll_ipc_ready = false; + // Intra-node LL only: peer-mapped RDMA buffer pointers (CUDA IPC). + // ``peer_rdma_bases[r]`` aliases rank ``r``'s ``rdma_buffer_ptr`` via + // ``cudaIpcOpenMemHandle`` (lazy peer access). Populated in ``sync()`` when + // ``low_latency_mode && num_rdma_ranks == 1``; null otherwise. + cudaIpcMemHandle_t rdma_ipc_handles[NUM_MAX_NVL_PEERS]; + void* peer_rdma_bases[NUM_MAX_NVL_PEERS] = {nullptr}; + void** peer_rdma_bases_gpu = nullptr; + // MemoryChannels over CUDA IPC used only for the LL barrier ring. 
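+  // Note on the peer-mapped pointers above (an illustrative sketch of the
+  // standard CUDA-IPC flow, not necessarily the exact code in ``sync()``):
+  //   void* base = nullptr;
+  //   cudaIpcOpenMemHandle(&base, rdma_ipc_handles[r], cudaIpcMemLazyEnablePeerAccess);
+  //   peer_rdma_bases[r] = base;
+  // with ``rdma_ipc_handles[r]`` taken from the IPC handles gathered across
+  // ranks and passed into ``sync()``.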
+ std::vector ll_memory_channels; + std::shared_ptr ll_memory_channel_handles_device_ptr; + bool ll_ipc_ready = false; -private: - void move_fifo_slots(int num_slots = 1); + private: + void move_fifo_slots(int num_slots = 1); -public: - Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode); + public: + Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode); - ~Buffer() noexcept(false); + ~Buffer() noexcept(false); - bool is_available() const; + bool is_available() const; - bool is_internode_available() const; + bool is_internode_available() const; - int get_num_rdma_ranks() const; + int get_num_rdma_ranks() const; - int get_rdma_rank() const; + int get_rdma_rank() const; - int get_root_rdma_rank(bool global) const; + int get_root_rdma_rank(bool global) const; - int get_local_device_id() const; + int get_local_device_id() const; - pybind11::bytearray get_local_ipc_handle() const; + pybind11::bytearray get_local_ipc_handle() const; - pybind11::bytearray get_local_nvshmem_unique_id() const; + pybind11::bytearray get_local_nvshmem_unique_id() const; - torch::Tensor get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const; + torch::Tensor get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const; - mscclpp::UniqueId create_unique_id() const; + mscclpp::UniqueId create_unique_id() const; - void connect(mscclpp::UniqueId root_id); + void connect(mscclpp::UniqueId root_id); - void sync(const std::vector& device_ids, const std::vector>& all_gathered_handles, const std::optional& root_unique_id_opt); + void sync(const std::vector& device_ids, + const std::vector>& all_gathered_handles, + const std::optional& root_unique_id_opt); - std::tuple, torch::Tensor, torch::Tensor, std::optional> - get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, std::optional& previous_event, - bool async, bool allocate_on_comm_stream); + std::tuple, torch::Tensor, torch::Tensor, std::optional> + get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, std::optional& previous_event, + bool async, bool allocate_on_comm_stream); - std::tuple, std::optional, std::optional, std::vector, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional> - intranode_dispatch(const torch::Tensor& x, const std::optional& x_scales, - const std::optional& topk_idx, const std::optional& topk_weights, - const std::optional& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional& num_tokens_per_expert, - int cached_num_recv_tokens, const std::optional& cached_rank_prefix_matrix, const std::optional& cached_channel_prefix_matrix, - int expert_alignment, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream); + std::tuple, std::optional, std::optional, + std::vector, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, + std::optional> + intranode_dispatch(const torch::Tensor& x, const std::optional& x_scales, + const std::optional& topk_idx, const std::optional& topk_weights, + const std::optional& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, + const std::optional& num_tokens_per_expert, int cached_num_recv_tokens, + const std::optional& cached_rank_prefix_matrix, + const std::optional& cached_channel_prefix_matrix, int expert_alignment, + const Config& config, std::optional& previous_event, bool async, + bool 
allocate_on_comm_stream); - std::tuple, std::optional> - intranode_combine(const torch::Tensor& x, const std::optional& topk_weights, - const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix, - const torch::Tensor& send_head, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream); + std::tuple, std::optional> intranode_combine( + const torch::Tensor& x, const std::optional& topk_weights, const torch::Tensor& src_idx, + const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix, + const torch::Tensor& send_head, const Config& config, std::optional& previous_event, bool async, + bool allocate_on_comm_stream); - std::tuple, std::optional, std::optional, std::vector, torch::Tensor, torch::Tensor, std::optional, torch::Tensor, std::optional, torch::Tensor, std::optional, std::optional, std::optional, std::optional> - internode_dispatch(const torch::Tensor& x, const std::optional& x_scales, - const std::optional& topk_idx, const std::optional& topk_weights, - const std::optional& num_tokens_per_rank, const std::optional& num_tokens_per_rdma_rank, - const torch::Tensor& is_token_in_rank, const std::optional& num_tokens_per_expert, - int cached_num_recv_tokens, int cached_num_rdma_recv_tokens, - const std::optional& cached_rdma_channel_prefix_matrix, const std::optional& cached_recv_rdma_rank_prefix_sum, - const std::optional& cached_gbl_channel_prefix_matrix, const std::optional& cached_recv_gbl_rank_prefix_sum, - int expert_alignment, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream); + std::tuple, std::optional, std::optional, + std::vector, torch::Tensor, torch::Tensor, std::optional, torch::Tensor, + std::optional, torch::Tensor, std::optional, std::optional, + std::optional, std::optional> + internode_dispatch(const torch::Tensor& x, const std::optional& x_scales, + const std::optional& topk_idx, const std::optional& topk_weights, + const std::optional& num_tokens_per_rank, + const std::optional& num_tokens_per_rdma_rank, + const torch::Tensor& is_token_in_rank, const std::optional& num_tokens_per_expert, + int cached_num_recv_tokens, int cached_num_rdma_recv_tokens, + const std::optional& cached_rdma_channel_prefix_matrix, + const std::optional& cached_recv_rdma_rank_prefix_sum, + const std::optional& cached_gbl_channel_prefix_matrix, + const std::optional& cached_recv_gbl_rank_prefix_sum, int expert_alignment, + const Config& config, std::optional& previous_event, bool async, + bool allocate_on_comm_stream); - std::tuple, std::optional> - internode_combine(const torch::Tensor& x, const std::optional& topk_weights, - const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank, - const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix, - const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head, - const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream); + std::tuple, std::optional> internode_combine( + const torch::Tensor& x, const std::optional& topk_weights, const torch::Tensor& src_meta, + const torch::Tensor& is_combined_token_in_rank, const torch::Tensor& rdma_channel_prefix_matrix, + const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix, + const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head, const 
Config& config, + std::optional& previous_event, bool async, bool allocate_on_comm_stream); - void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts); + void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts); - std::tuple, torch::Tensor, torch::Tensor, torch::Tensor, std::optional, std::optional>> - low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, - int num_max_dispatch_tokens_per_rank, int num_experts, - bool use_fp8, bool async, bool return_recv_hook, - const std::optional& out_packed_recv_x = std::nullopt, - const std::optional& out_packed_recv_x_scales = std::nullopt, - const std::optional& out_packed_recv_src_info = std::nullopt, - const std::optional& out_packed_recv_layout_range = std::nullopt, - const std::optional& out_packed_recv_count = std::nullopt); + std::tuple, torch::Tensor, torch::Tensor, torch::Tensor, + std::optional, std::optional>> + low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, int num_max_dispatch_tokens_per_rank, + int num_experts, bool use_fp8, bool async, bool return_recv_hook, + const std::optional& out_packed_recv_x = std::nullopt, + const std::optional& out_packed_recv_x_scales = std::nullopt, + const std::optional& out_packed_recv_src_info = std::nullopt, + const std::optional& out_packed_recv_layout_range = std::nullopt, + const std::optional& out_packed_recv_count = std::nullopt); - std::tuple, std::optional>> - low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, - const torch::Tensor& src_info, const torch::Tensor& layout_range, - int num_max_dispatch_tokens_per_rank, int num_experts, - bool zero_copy, bool async, bool return_recv_hook, - const std::optional& out = std::nullopt); + std::tuple, std::optional>> low_latency_combine( + const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, + const torch::Tensor& src_info, const torch::Tensor& layout_range, int num_max_dispatch_tokens_per_rank, + int num_experts, bool zero_copy, bool async, bool return_recv_hook, + const std::optional& out = std::nullopt); - torch::Tensor - get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts); + torch::Tensor get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts); }; -} // namespace ep -} // namespace mscclpp +} // namespace ep +} // namespace mscclpp diff --git a/src/ext/ep/config.hpp b/src/ext/ep/config.hpp index 35f7aacf..d1f0b5be 100644 --- a/src/ext/ep/config.hpp +++ b/src/ext/ep/config.hpp @@ -5,185 +5,190 @@ #include "kernels/api.cuh" #include "kernels/exception.cuh" -namespace mscclpp { namespace ep { +namespace mscclpp { +namespace ep { template dtype_t cell_div(dtype_t a, dtype_t b) { - return (a + b - 1) / b; + return (a + b - 1) / b; } template dtype_t align(dtype_t a, dtype_t b) { - return cell_div(a, b) * b; + return cell_div(a, b) * b; } struct Config { - int num_sms; - int num_max_nvl_chunked_send_tokens; - int num_max_nvl_chunked_recv_tokens; - int num_max_rdma_chunked_send_tokens; - int num_max_rdma_chunked_recv_tokens; + int num_sms; + int num_max_nvl_chunked_send_tokens; + int num_max_nvl_chunked_recv_tokens; + int num_max_rdma_chunked_send_tokens; + int num_max_rdma_chunked_recv_tokens; - Config(int num_sms, - int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, - int num_max_rdma_chunked_send_tokens, int 
num_max_rdma_chunked_recv_tokens) : - num_sms(num_sms), - num_max_nvl_chunked_send_tokens(num_max_nvl_chunked_send_tokens), - num_max_nvl_chunked_recv_tokens(num_max_nvl_chunked_recv_tokens), - num_max_rdma_chunked_send_tokens(num_max_rdma_chunked_send_tokens), - num_max_rdma_chunked_recv_tokens(num_max_rdma_chunked_recv_tokens) { - EP_HOST_ASSERT(num_sms >= 0); - EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens > 0 and num_max_nvl_chunked_recv_tokens > 0); - EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens < num_max_nvl_chunked_recv_tokens); - EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens > 0 and num_max_rdma_chunked_recv_tokens > 0); + Config(int num_sms, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, + int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens) + : num_sms(num_sms), + num_max_nvl_chunked_send_tokens(num_max_nvl_chunked_send_tokens), + num_max_nvl_chunked_recv_tokens(num_max_nvl_chunked_recv_tokens), + num_max_rdma_chunked_send_tokens(num_max_rdma_chunked_send_tokens), + num_max_rdma_chunked_recv_tokens(num_max_rdma_chunked_recv_tokens) { + EP_HOST_ASSERT(num_sms >= 0); + EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens > 0 and num_max_nvl_chunked_recv_tokens > 0); + EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens < num_max_nvl_chunked_recv_tokens); + EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens > 0 and num_max_rdma_chunked_recv_tokens > 0); - // Ceil up RDMA buffer size - this->num_max_rdma_chunked_recv_tokens = align(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens); - EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens < num_max_rdma_chunked_recv_tokens); - // NOTES: this assertion is related to RDMA lazy head update, we must ensure senders always have space to push - EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens <= num_max_rdma_chunked_recv_tokens / 2); - } + // Ceil up RDMA buffer size + this->num_max_rdma_chunked_recv_tokens = + align(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens); + EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens < num_max_rdma_chunked_recv_tokens); + // NOTES: this assertion is related to RDMA lazy head update, we must ensure senders always have space to push + EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens <= num_max_rdma_chunked_recv_tokens / 2); + } - size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks) const { - // Below are some assumptions - // TODO: add assertions - constexpr int kNumMaxTopK = 128; - constexpr int kNumMaxScales = 128; - EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0); - EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % 2 == 0); - const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1); - const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS); - const int num_channels = num_sms / 2; + size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks) const { + // Below are some assumptions + // TODO: add assertions + constexpr int kNumMaxTopK = 128; + constexpr int kNumMaxScales = 128; + EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0); + EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % 2 == 0); + const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1); + const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS); + const int num_channels = num_sms / 2; - size_t num_bytes = 0; - num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int); - num_bytes += num_channels * 
num_nvl_ranks * num_max_nvl_chunked_recv_tokens * hidden_bytes; - num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * internode::get_source_meta_bytes(); - num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(int64_t); - num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(float); - num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxScales * sizeof(float); - num_bytes = ((num_bytes + 127) / 128) * 128; - return num_bytes; - } + size_t num_bytes = 0; + num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int); + num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * hidden_bytes; + num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * internode::get_source_meta_bytes(); + num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(int64_t); + num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(float); + num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxScales * sizeof(float); + num_bytes = ((num_bytes + 127) / 128) * 128; + return num_bytes; + } - size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const { - // Legacy mode - if (num_ranks <= NUM_MAX_NVL_PEERS) - return 0; + size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const { + // Legacy mode + if (num_ranks <= NUM_MAX_NVL_PEERS) return 0; - // Below are some assumptions - // TODO: add assertions - constexpr int kNumMaxTopK = 128; - constexpr int kNumMaxScales = 128; - EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0); - EP_HOST_ASSERT(num_sms % 2 == 0); - const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; - const int num_channels = num_sms / 2; + // Below are some assumptions + // TODO: add assertions + constexpr int kNumMaxTopK = 128; + constexpr int kNumMaxScales = 128; + EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0); + EP_HOST_ASSERT(num_sms % 2 == 0); + const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; + const int num_channels = num_sms / 2; - size_t num_bytes = 0; - num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int); - num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2; - num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * internode::get_source_meta_bytes() * 2; - num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(int64_t) * 2; - num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(float) * 2; - num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxScales * sizeof(float) * 2; - num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2; - num_bytes = ((num_bytes + 127) / 128) * 128; - return num_bytes; - } + size_t num_bytes = 0; + num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int); + num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2; + num_bytes += + num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * internode::get_source_meta_bytes() * 2; + num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(int64_t) * 2; 
+ num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(float) * 2; + num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxScales * sizeof(float) * 2; + num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2; + num_bytes = ((num_bytes + 127) / 128) * 128; + return num_bytes; + } }; struct LowLatencyBuffer { - int num_clean_int = 0; + int num_clean_int = 0; - void* dispatch_rdma_send_buffer = nullptr; - void* dispatch_rdma_recv_data_buffer = nullptr; - // NOTE: signaling buffers are int64_t (not int) so that IB atomic ops - // (IBV_WR_ATOMIC_FETCH_AND_ADD is a 64-bit, 8-byte-aligned op) always - // target an 8-byte-aligned address. Using int32 slots produced unaligned - // atomics at odd indices that the NIC silently drops. - int64_t* dispatch_rdma_recv_count_buffer = nullptr; + void* dispatch_rdma_send_buffer = nullptr; + void* dispatch_rdma_recv_data_buffer = nullptr; + // NOTE: signaling buffers are int64_t (not int) so that IB atomic ops + // (IBV_WR_ATOMIC_FETCH_AND_ADD is a 64-bit, 8-byte-aligned op) always + // target an 8-byte-aligned address. Using int32 slots produced unaligned + // atomics at odd indices that the NIC silently drops. + int64_t* dispatch_rdma_recv_count_buffer = nullptr; - void* combine_rdma_send_buffer = nullptr; - void* combine_rdma_recv_data_buffer = nullptr; - int64_t* combine_rdma_recv_flag_buffer = nullptr; + void* combine_rdma_send_buffer = nullptr; + void* combine_rdma_recv_data_buffer = nullptr; + int64_t* combine_rdma_recv_flag_buffer = nullptr; - void* combine_rdma_send_buffer_data_start = nullptr; - size_t num_bytes_per_combine_msg = 0; + void* combine_rdma_send_buffer_data_start = nullptr; + size_t num_bytes_per_combine_msg = 0; - std::pair clean_meta() { - EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer); - return {dispatch_rdma_recv_count_buffer, num_clean_int}; - } + std::pair clean_meta() { + EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer); + return {dispatch_rdma_recv_count_buffer, num_clean_int}; + } }; struct LowLatencyLayout { - size_t total_bytes = 0; - LowLatencyBuffer buffers[2]; + size_t total_bytes = 0; + LowLatencyBuffer buffers[2]; - template - out_ptr_t advance(const in_ptr_t& ptr, size_t count) { - return reinterpret_cast(reinterpret_cast(ptr) + count); - } - - LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) { - const int num_scales = hidden / 128; - - // Dispatch and combine layout: - // - 2 symmetric odd/even send buffer - // - 2 symmetric odd/even receive buffers - // - 2 symmetric odd/even signaling buffers - - // Message sizes - // NOTES: you should add a control `int4` for combine messages if you want to do data transformation - EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden); - size_t num_bytes_per_dispatch_msg = sizeof(int4) + std::max(hidden * sizeof(nv_bfloat16), hidden + num_scales * sizeof(float)); - size_t num_bytes_per_combine_msg = hidden * sizeof(nv_bfloat16); - - // Send buffer - size_t dispatch_send_buffer_bytes = num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg; - size_t combine_send_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg; - size_t send_buffer_bytes = std::max(dispatch_send_buffer_bytes, combine_send_buffer_bytes); - EP_HOST_ASSERT(send_buffer_bytes % sizeof(int4) == 0); - total_bytes += 
send_buffer_bytes * 2; - - // Symmetric receive buffers - // TODO: optimize memory usages - size_t dispatch_recv_data_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg; - size_t combine_recv_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg; - size_t recv_buffer_bytes = std::max(dispatch_recv_data_buffer_bytes, combine_recv_buffer_bytes); - EP_HOST_ASSERT(recv_buffer_bytes % sizeof(int4) == 0); - total_bytes += recv_buffer_bytes * 2; - - // Symmetric signaling buffers (int64_t slots for 8-byte-aligned IB atomics). - size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int64_t); - size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes; - size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes); - total_bytes += signaling_buffer_bytes * 2; - - // Assign pointers - // NOTES: we still leave some space for distinguishing dispatch/combine buffer, - // so you may see some parameters are duplicated - for (int i = 0; i < 2; ++ i) { - buffers[i] = { - static_cast(signaling_buffer_bytes / sizeof(int64_t)), - advance(rdma_buffer, send_buffer_bytes * i), - advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i), - advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i), - advance(rdma_buffer, send_buffer_bytes * i), - advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i), - advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i), - advance(rdma_buffer, send_buffer_bytes * i), - num_bytes_per_combine_msg - }; - } + template + out_ptr_t advance(const in_ptr_t& ptr, size_t count) { + return reinterpret_cast(reinterpret_cast(ptr) + count); + } + + LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, + int num_experts) { + const int num_scales = hidden / 128; + + // Dispatch and combine layout: + // - 2 symmetric odd/even send buffer + // - 2 symmetric odd/even receive buffers + // - 2 symmetric odd/even signaling buffers + + // Message sizes + // NOTES: you should add a control `int4` for combine messages if you want to do data transformation + EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden); + size_t num_bytes_per_dispatch_msg = + sizeof(int4) + std::max(hidden * sizeof(nv_bfloat16), hidden + num_scales * sizeof(float)); + size_t num_bytes_per_combine_msg = hidden * sizeof(nv_bfloat16); + + // Send buffer + size_t dispatch_send_buffer_bytes = num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg; + size_t combine_send_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg; + size_t send_buffer_bytes = std::max(dispatch_send_buffer_bytes, combine_send_buffer_bytes); + EP_HOST_ASSERT(send_buffer_bytes % sizeof(int4) == 0); + total_bytes += send_buffer_bytes * 2; + + // Symmetric receive buffers + // TODO: optimize memory usages + size_t dispatch_recv_data_buffer_bytes = + num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg; + size_t combine_recv_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg; + size_t recv_buffer_bytes = std::max(dispatch_recv_data_buffer_bytes, combine_recv_buffer_bytes); + EP_HOST_ASSERT(recv_buffer_bytes % sizeof(int4) == 0); + total_bytes += recv_buffer_bytes * 2; + + // Symmetric signaling buffers (int64_t slots for 8-byte-aligned IB atomics). 
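+    // (Same rationale as the LowLatencyBuffer note above: IBV_WR_ATOMIC_FETCH_AND_ADD is a
+    // 64-bit op on an 8-byte-aligned remote address. With int64_t slots the flag for expert
+    // `e` sits at `base + 8 * e`, which stays 8-byte aligned whenever `base` is; with int32
+    // slots the flag for any odd `e` would land on a 4-byte boundary, and the NIC silently
+    // drops such unaligned atomics.)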
+ size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int64_t); + size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes; + size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes); + total_bytes += signaling_buffer_bytes * 2; + + // Assign pointers + // NOTES: we still leave some space for distinguishing dispatch/combine buffer, + // so you may see some parameters are duplicated + for (int i = 0; i < 2; ++i) { + buffers[i] = { + static_cast(signaling_buffer_bytes / sizeof(int64_t)), + advance(rdma_buffer, send_buffer_bytes * i), + advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i), + advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i), + advance(rdma_buffer, send_buffer_bytes * i), + advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i), + advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i), + advance(rdma_buffer, send_buffer_bytes * i), + num_bytes_per_combine_msg}; } + } }; -inline size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) { - auto num_bytes = LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts).total_bytes; - return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * NUM_BUFFER_ALIGNMENT_BYTES; +inline size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, + int num_experts) { + auto num_bytes = + LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts).total_bytes; + return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * NUM_BUFFER_ALIGNMENT_BYTES; } -} // namespace ep -} // namespace mscclpp +} // namespace ep +} // namespace mscclpp diff --git a/src/ext/ep/event.hpp b/src/ext/ep/event.hpp index d5a77526..b5aae97f 100644 --- a/src/ext/ep/event.hpp +++ b/src/ext/ep/event.hpp @@ -3,46 +3,44 @@ #pragma once #include + #include #include "kernels/exception.cuh" -namespace mscclpp { namespace ep { +namespace mscclpp { +namespace ep { struct EventHandle { - std::shared_ptr event; + std::shared_ptr event; - EventHandle() { - event = std::make_shared(torch::kCUDA); - event->record(at::cuda::getCurrentCUDAStream()); - } + EventHandle() { + event = std::make_shared(torch::kCUDA); + event->record(at::cuda::getCurrentCUDAStream()); + } - explicit EventHandle(const at::cuda::CUDAStream& stream) { - event = std::make_shared(torch::kCUDA); - event->record(stream); - } + explicit EventHandle(const at::cuda::CUDAStream& stream) { + event = std::make_shared(torch::kCUDA); + event->record(stream); + } - EventHandle(const EventHandle& other) = default; + EventHandle(const EventHandle& other) = default; - void current_stream_wait() const { - at::cuda::getCurrentCUDAStream().unwrap().wait(*event); - } + void current_stream_wait() const { at::cuda::getCurrentCUDAStream().unwrap().wait(*event); } }; -inline torch::Event create_event(const at::cuda::CUDAStream &s) { - auto event = torch::Event(torch::kCUDA); - event.record(s); - return event; +inline torch::Event create_event(const at::cuda::CUDAStream& s) { + auto event = torch::Event(torch::kCUDA); + event.record(s); + return event; } inline void stream_wait(const at::cuda::CUDAStream& s_0, const at::cuda::CUDAStream& s_1) { - EP_HOST_ASSERT(s_0.id() != s_1.id()); - s_0.unwrap().wait(create_event(s_1)); + 
EP_HOST_ASSERT(s_0.id() != s_1.id()); + s_0.unwrap().wait(create_event(s_1)); } -inline void stream_wait(const at::cuda::CUDAStream& s, const EventHandle& event) { - s.unwrap().wait(*event.event); -} +inline void stream_wait(const at::cuda::CUDAStream& s, const EventHandle& event) { s.unwrap().wait(*event.event); } -} // namespace ep -} // namespace mscclpp +} // namespace ep +} // namespace mscclpp diff --git a/src/ext/ep/kernels/api.cuh b/src/ext/ep/kernels/api.cuh index 7647cb97..e83a480d 100644 --- a/src/ext/ep/kernels/api.cuh +++ b/src/ext/ep/kernels/api.cuh @@ -12,10 +12,10 @@ #include #include -#include #include #include +#include namespace mscclpp { namespace ep { @@ -30,20 +30,17 @@ void barrier(int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks, const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix, - int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment, - void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank, - cudaStream_t stream, int num_sms); + int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment, void** buffer_ptrs, + int** task_fifo_ptrs, int head, int rank, cudaStream_t stream, int num_sms); -void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int, - void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank, int num_ranks, - cudaStream_t stream); +void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int, void** buffer_ptrs, int** task_fifo_ptrs, + int head, int rank, int num_ranks, cudaStream_t stream); void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset, int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, - const float* topk_weights, const bool* is_token_in_rank, const int* channel_prefix_matrix, - int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales, - void** buffer_ptrs, int rank, int num_ranks, cudaStream_t stream, int num_sms, - int num_max_send_tokens, int num_recv_buffer_tokens); + const float* topk_weights, const bool* is_token_in_rank, const int* channel_prefix_matrix, int num_tokens, + int hidden_int4, int num_topk, int num_experts, int num_scales, void** buffer_ptrs, int rank, + int num_ranks, cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens); void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int, int** task_fifo_ptrs, int head, int rank, int num_ranks, @@ -51,9 +48,8 @@ void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, void combine(cudaDataType_t type, void* recv_x, float* recv_topk_weights, const void* x, const float* topk_weights, const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix, int* send_head, - int num_tokens, int num_recv_tokens, int hidden, int num_topk, - void** buffer_ptrs, int rank, int num_ranks, cudaStream_t stream, int num_sms, - int num_max_send_tokens, int num_recv_buffer_tokens); + int num_tokens, int num_recv_tokens, int hidden, int num_topk, void** buffer_ptrs, int rank, int num_ranks, + cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens); } // namespace intranode @@ -74,55 +70,44 @@ void 
get_dispatch_layout(const int64_t* topk_idx, int* num_tokens_per_rank, int* void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks, const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped, const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, - const bool* is_token_in_rank, int num_tokens, int num_channels, - int hidden_int4, int num_scales, int num_topk, int expert_alignment, - int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum, - int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum, - void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, - int** task_fifo_ptrs, int head, int rank, - cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes, - bool low_latency_mode, + const bool* is_token_in_rank, int num_tokens, int num_channels, int hidden_int4, int num_scales, + int num_topk, int expert_alignment, int* rdma_channel_prefix_matrix, + int* recv_rdma_rank_prefix_sum, int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum, + void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, + int num_max_nvl_chunked_recv_tokens, int** task_fifo_ptrs, int head, int rank, cudaStream_t stream, + int64_t num_rdma_bytes, int64_t num_nvl_bytes, bool low_latency_mode, mscclpp::PortChannelDeviceHandle* port_channel_handles, mscclpp::MemoryChannelDeviceHandle* memory_channel_handles); void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights, - int* send_rdma_head, int* send_nvl_head, - int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix, - const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum, - const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum, - int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts, - const bool* is_token_in_rank, - void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, - int rank, int num_ranks, bool is_cached_dispatch, - cudaStream_t stream, int num_channels, bool low_latency_mode, + int* send_rdma_head, int* send_nvl_head, int* recv_rdma_channel_prefix_matrix, + int* recv_gbl_channel_prefix_matrix, const int* rdma_channel_prefix_matrix, + const int* recv_rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, + const int* recv_gbl_rank_prefix_sum, int num_tokens, int hidden_int4, int num_scales, int num_topk, + int num_experts, const bool* is_token_in_rank, void* rdma_buffer_ptr, + int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, + int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank, int num_ranks, + bool is_cached_dispatch, cudaStream_t stream, int num_channels, bool low_latency_mode, mscclpp::PortChannelDeviceHandle* port_channel_handles, mscclpp::MemoryChannelDeviceHandle* memory_channel_handles); -void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, - int num_ranks, int num_channels, int num_combined_tokens, int* combined_rdma_head, +void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, int num_ranks, + int num_channels, int 
num_combined_tokens, int* combined_rdma_head, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head, - void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, - int** task_fifo_ptrs, int head, int rank, cudaStream_t stream, - int64_t num_rdma_bytes, int64_t num_nvl_bytes, - bool is_cached_dispatch, bool low_latency_mode, + void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, + int num_max_nvl_chunked_recv_tokens, int** task_fifo_ptrs, int head, int rank, cudaStream_t stream, + int64_t num_rdma_bytes, int64_t num_nvl_bytes, bool is_cached_dispatch, bool low_latency_mode, mscclpp::PortChannelDeviceHandle* port_channel_handles, mscclpp::MemoryChannelDeviceHandle* memory_channel_handles); -void combine(cudaDataType_t type, - void* combined_x, float* combined_topk_weights, - const bool* is_combined_token_in_rank, - const void* x, const float* topk_weights, - const int* combined_rdma_head, const int* combined_nvl_head, +void combine(cudaDataType_t type, void* combined_x, float* combined_topk_weights, const bool* is_combined_token_in_rank, + const void* x, const float* topk_weights, const int* combined_rdma_head, const int* combined_nvl_head, const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, - const int* gbl_channel_prefix_matrix, - int num_tokens, int num_combined_tokens, int hidden, int num_topk, + const int* gbl_channel_prefix_matrix, int num_tokens, int num_combined_tokens, int hidden, int num_topk, void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, - int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode, + void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank, + int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode, mscclpp::PortChannelDeviceHandle* port_channel_handles, mscclpp::MemoryChannelDeviceHandle* memory_channel_handles); @@ -135,43 +120,26 @@ void combine(cudaDataType_t type, // =========================================================================== namespace internode_ll { -void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, - int64_t* clean_1, int num_clean_int_1, - int rank, int num_ranks, - mscclpp::PortChannelDeviceHandle* port_channel_handles, - mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, - bool use_ipc_path, +void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, int64_t* clean_1, int num_clean_int_1, int rank, + int num_ranks, mscclpp::PortChannelDeviceHandle* port_channel_handles, + mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path, cudaStream_t stream); -void dispatch(void* packed_recv_x, float* packed_recv_x_scales, - int* packed_recv_src_info, int64_t* packed_recv_layout_range, - int* packed_recv_count, - void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x, - const void* x, const int64_t* topk_idx, - int64_t* next_clean, int num_next_clean_int, - int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, - int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8, - void* workspace, cudaStream_t stream, int phases, - void* rdma_buffer_ptr, - mscclpp::PortChannelDeviceHandle* port_channel_handles, - void* const* peer_rdma_bases, - 
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, - bool use_ipc_path); +void dispatch(void* packed_recv_x, float* packed_recv_x_scales, int* packed_recv_src_info, + int64_t* packed_recv_layout_range, int* packed_recv_count, void* rdma_recv_x, int64_t* rdma_recv_count, + void* rdma_x, const void* x, const int64_t* topk_idx, int64_t* next_clean, int num_next_clean_int, + int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts, int rank, + int num_ranks, bool use_fp8, void* workspace, cudaStream_t stream, int phases, void* rdma_buffer_ptr, + mscclpp::PortChannelDeviceHandle* port_channel_handles, void* const* peer_rdma_bases, + mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path); -void combine(void* combined_x, - void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x, - const void* x, const int64_t* topk_idx, const float* topk_weights, - const int* src_info, const int64_t* layout_range, - int64_t* next_clean, int num_next_clean_int, - int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank, - int num_topk, int num_experts, int rank, int num_ranks, - void* workspace, cudaStream_t stream, - int phases, bool zero_copy, - void* rdma_buffer_ptr, - mscclpp::PortChannelDeviceHandle* port_channel_handles, - void* const* peer_rdma_bases, - mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, - bool use_ipc_path); +void combine(void* combined_x, void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x, const void* x, + const int64_t* topk_idx, const float* topk_weights, const int* src_info, const int64_t* layout_range, + int64_t* next_clean, int num_next_clean_int, int num_combined_tokens, int hidden, + int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts, int rank, int num_ranks, + void* workspace, cudaStream_t stream, int phases, bool zero_copy, void* rdma_buffer_ptr, + mscclpp::PortChannelDeviceHandle* port_channel_handles, void* const* peer_rdma_bases, + mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path); } // namespace internode_ll diff --git a/src/ext/ep/kernels/buffer.cuh b/src/ext/ep/kernels/buffer.cuh index 84e5d230..b48bd588 100644 --- a/src/ext/ep/kernels/buffer.cuh +++ b/src/ext/ep/kernels/buffer.cuh @@ -5,137 +5,131 @@ #include "configs.cuh" #include "exception.cuh" -namespace mscclpp { namespace ep { +namespace mscclpp { +namespace ep { template struct Buffer { -private: - uint8_t* ptr; + private: + uint8_t* ptr; -public: - int total_bytes; + public: + int total_bytes; - __device__ __forceinline__ Buffer() : ptr(nullptr), total_bytes(0) {} + __device__ __forceinline__ Buffer() : ptr(nullptr), total_bytes(0) {} - __device__ __forceinline__ Buffer(void* &gbl_ptr, int num_elems, int offset = 0) { - total_bytes = num_elems * sizeof(dtype_t); - ptr = reinterpret_cast(gbl_ptr) + offset * sizeof(dtype_t); - gbl_ptr = reinterpret_cast(gbl_ptr) + total_bytes; - } + __device__ __forceinline__ Buffer(void*& gbl_ptr, int num_elems, int offset = 0) { + total_bytes = num_elems * sizeof(dtype_t); + ptr = reinterpret_cast(gbl_ptr) + offset * sizeof(dtype_t); + gbl_ptr = reinterpret_cast(gbl_ptr) + total_bytes; + } - __device__ __forceinline__ Buffer advance_also(void* &gbl_ptr) { - gbl_ptr = reinterpret_cast(gbl_ptr) + total_bytes; - return *this; - } + __device__ __forceinline__ Buffer advance_also(void*& gbl_ptr) { + gbl_ptr = reinterpret_cast(gbl_ptr) + total_bytes; + return *this; + } - __device__ __forceinline__ dtype_t* buffer() { - return 
reinterpret_cast(ptr); - } + __device__ __forceinline__ dtype_t* buffer() { return reinterpret_cast(ptr); } - __device__ __forceinline__ dtype_t& operator[](int idx) { - return buffer()[idx]; - } + __device__ __forceinline__ dtype_t& operator[](int idx) { return buffer()[idx]; } }; template struct AsymBuffer { -private: - uint8_t* ptrs[kNumRanks]; - int num_bytes; + private: + uint8_t* ptrs[kNumRanks]; + int num_bytes; -public: - int total_bytes; + public: + int total_bytes; - __device__ __forceinline__ AsymBuffer(void* &gbl_ptr, int num_elems, int num_ranks, - int sm_id = 0, int num_sms = 1, int offset = 0) { - EP_STATIC_ASSERT(kNumRanks == 1, ""); - num_bytes = num_elems * sizeof(dtype_t); + __device__ __forceinline__ AsymBuffer(void*& gbl_ptr, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1, + int offset = 0) { + EP_STATIC_ASSERT(kNumRanks == 1, ""); + num_bytes = num_elems * sizeof(dtype_t); - int per_channel_bytes = num_bytes * num_ranks; - total_bytes = per_channel_bytes * num_sms; - ptrs[0] = reinterpret_cast(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset; - gbl_ptr = reinterpret_cast(gbl_ptr) + total_bytes; + int per_channel_bytes = num_bytes * num_ranks; + total_bytes = per_channel_bytes * num_sms; + ptrs[0] = reinterpret_cast(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset; + gbl_ptr = reinterpret_cast(gbl_ptr) + total_bytes; + } + + __device__ __forceinline__ AsymBuffer(void** gbl_ptrs, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1, + int offset = 0) { + EP_STATIC_ASSERT(kNumRanks > 1, ""); + num_bytes = num_elems * sizeof(dtype_t); + + int per_channel_bytes = num_bytes * num_ranks; + total_bytes = per_channel_bytes * num_sms; + for (int i = 0; i < kNumRanks; ++i) { + ptrs[i] = reinterpret_cast(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset; + gbl_ptrs[i] = reinterpret_cast(gbl_ptrs[i]) + total_bytes; } + } - __device__ __forceinline__ AsymBuffer(void** gbl_ptrs, int num_elems, int num_ranks, - int sm_id = 0, int num_sms = 1, int offset = 0) { - EP_STATIC_ASSERT(kNumRanks > 1, ""); - num_bytes = num_elems * sizeof(dtype_t); + __device__ __forceinline__ void advance(int shift) { +#pragma unroll + for (int i = 0; i < kNumRanks; ++i) ptrs[i] = ptrs[i] + shift * sizeof(dtype_t); + } - int per_channel_bytes = num_bytes * num_ranks; - total_bytes = per_channel_bytes * num_sms; - for (int i = 0; i < kNumRanks; ++ i) { - ptrs[i] = reinterpret_cast(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset; - gbl_ptrs[i] = reinterpret_cast(gbl_ptrs[i]) + total_bytes; - } - } + __device__ __forceinline__ AsymBuffer advance_also(void*& gbl_ptr) { + gbl_ptr = reinterpret_cast(gbl_ptr) + total_bytes; + return *this; + } - __device__ __forceinline__ void advance(int shift) { - #pragma unroll - for (int i = 0; i < kNumRanks; ++ i) - ptrs[i] = ptrs[i] + shift * sizeof(dtype_t); - } + template + __device__ __forceinline__ AsymBuffer advance_also(void** gbl_ptrs) { + for (int i = 0; i < kNumAlsoRanks; ++i) gbl_ptrs[i] = reinterpret_cast(gbl_ptrs[i]) + total_bytes; + return *this; + } - __device__ __forceinline__ AsymBuffer advance_also(void* &gbl_ptr) { - gbl_ptr = reinterpret_cast(gbl_ptr) + total_bytes; - return *this; - } + __device__ __forceinline__ dtype_t* buffer(int idx = 0) { + EP_STATIC_ASSERT(kNumRanks == 1, "`buffer` is only available for single rank case"); + return reinterpret_cast(ptrs[0] + num_bytes * idx); + } - template - __device__ __forceinline__ AsymBuffer advance_also(void** gbl_ptrs) { - for (int i = 0; i < 
kNumAlsoRanks; ++ i) - gbl_ptrs[i] = reinterpret_cast(gbl_ptrs[i]) + total_bytes; - return *this; - } - - __device__ __forceinline__ dtype_t* buffer(int idx = 0) { - EP_STATIC_ASSERT(kNumRanks == 1, "`buffer` is only available for single rank case"); - return reinterpret_cast(ptrs[0] + num_bytes * idx); - } - - __device__ __forceinline__ dtype_t* buffer_by(int rank_idx, int idx = 0) { - EP_STATIC_ASSERT(kNumRanks > 1, "`buffer` is only available for single rank case"); - return reinterpret_cast(ptrs[rank_idx] + num_bytes * idx); - } + __device__ __forceinline__ dtype_t* buffer_by(int rank_idx, int idx = 0) { + EP_STATIC_ASSERT(kNumRanks > 1, "`buffer` is only available for single rank case"); + return reinterpret_cast(ptrs[rank_idx] + num_bytes * idx); + } }; template struct SymBuffer { -private: - // NOTES: for non-decoupled case, `recv_ptr` is not used - uint8_t* send_ptr; - uint8_t* recv_ptr; - int num_bytes; + private: + // NOTES: for non-decoupled case, `recv_ptr` is not used + uint8_t* send_ptr; + uint8_t* recv_ptr; + int num_bytes; -public: - int total_bytes; + public: + int total_bytes; - __device__ __forceinline__ SymBuffer(void* &gbl_ptr, int num_elems, int num_ranks, - int sm_id = 0, int num_sms = 1) { - num_bytes = num_elems * sizeof(dtype_t); + __device__ __forceinline__ SymBuffer(void*& gbl_ptr, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1) { + num_bytes = num_elems * sizeof(dtype_t); - int per_channel_bytes = num_bytes * num_ranks; - total_bytes = per_channel_bytes * num_sms * (static_cast(kDecoupled) + 1); - send_ptr = reinterpret_cast(gbl_ptr) + per_channel_bytes * sm_id; - recv_ptr = reinterpret_cast(gbl_ptr) + per_channel_bytes * (sm_id + num_sms); - gbl_ptr = reinterpret_cast(gbl_ptr) + total_bytes; - } + int per_channel_bytes = num_bytes * num_ranks; + total_bytes = per_channel_bytes * num_sms * (static_cast(kDecoupled) + 1); + send_ptr = reinterpret_cast(gbl_ptr) + per_channel_bytes * sm_id; + recv_ptr = reinterpret_cast(gbl_ptr) + per_channel_bytes * (sm_id + num_sms); + gbl_ptr = reinterpret_cast(gbl_ptr) + total_bytes; + } - __device__ __forceinline__ dtype_t* send_buffer(int idx = 0) { - EP_STATIC_ASSERT(kDecoupled, "`send_buffer` is only available for non-decoupled case"); - return reinterpret_cast(send_ptr + num_bytes * idx); - } + __device__ __forceinline__ dtype_t* send_buffer(int idx = 0) { + EP_STATIC_ASSERT(kDecoupled, "`send_buffer` is only available for non-decoupled case"); + return reinterpret_cast(send_ptr + num_bytes * idx); + } - __device__ __forceinline__ dtype_t* recv_buffer(int idx = 0) { - EP_STATIC_ASSERT(kDecoupled, "`recv_buffer` is only available for non-decoupled case"); - return reinterpret_cast(recv_ptr + num_bytes * idx); - } + __device__ __forceinline__ dtype_t* recv_buffer(int idx = 0) { + EP_STATIC_ASSERT(kDecoupled, "`recv_buffer` is only available for non-decoupled case"); + return reinterpret_cast(recv_ptr + num_bytes * idx); + } - __device__ __forceinline__ dtype_t* buffer(int idx = 0) { - EP_STATIC_ASSERT(not kDecoupled, "`buffer` is only available for decoupled case"); - return reinterpret_cast(send_ptr + num_bytes * idx); - } + __device__ __forceinline__ dtype_t* buffer(int idx = 0) { + EP_STATIC_ASSERT(not kDecoupled, "`buffer` is only available for decoupled case"); + return reinterpret_cast(send_ptr + num_bytes * idx); + } }; -} // namespace ep -} // namespace mscclpp +} // namespace ep +} // namespace mscclpp diff --git a/src/ext/ep/kernels/exception.cuh b/src/ext/ep/kernels/exception.cuh index 
192ef863..60c09556 100644 --- a/src/ext/ep/kernels/exception.cuh +++ b/src/ext/ep/kernels/exception.cuh @@ -2,8 +2,8 @@ // Licensed under the MIT License. #pragma once -#include #include +#include #include "configs.cuh" @@ -11,38 +11,39 @@ #define EP_STATIC_ASSERT(cond, reason) static_assert(cond, reason) #endif -class EPException: public std::exception { -private: - std::string message = {}; +class EPException : public std::exception { + private: + std::string message = {}; -public: - explicit EPException(const char *name, const char* file, const int line, const std::string& error) { - message = std::string("Failed: ") + name + " error " + file + ":" + std::to_string(line) + " '" + error + "'"; - } + public: + explicit EPException(const char* name, const char* file, const int line, const std::string& error) { + message = std::string("Failed: ") + name + " error " + file + ":" + std::to_string(line) + " '" + error + "'"; + } - const char *what() const noexcept override { return message.c_str(); } + const char* what() const noexcept override { return message.c_str(); } }; #ifndef CUDA_CHECK -#define CUDA_CHECK(cmd) \ -do { \ - cudaError_t e = (cmd); \ - if (e != cudaSuccess) { \ - throw EPException("CUDA", __FILE__, __LINE__, cudaGetErrorString(e)); \ - } \ -} while (0) +#define CUDA_CHECK(cmd) \ + do { \ + cudaError_t e = (cmd); \ + if (e != cudaSuccess) { \ + throw EPException("CUDA", __FILE__, __LINE__, cudaGetErrorString(e)); \ + } \ + } while (0) #endif #ifndef EP_HOST_ASSERT -#define EP_HOST_ASSERT(cond) \ -do { \ - if (not (cond)) { \ - throw EPException("Assertion", __FILE__, __LINE__, #cond); \ - } \ -} while (0) +#define EP_HOST_ASSERT(cond) \ + do { \ + if (not(cond)) { \ + throw EPException("Assertion", __FILE__, __LINE__, #cond); \ + } \ + } while (0) #endif #ifndef EP_DEVICE_ASSERT -// #define EP_DEVICE_ASSERT(cond) do { if (not (cond)) { printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); asm("trap;"); } } while (0) +// #define EP_DEVICE_ASSERT(cond) do { if (not (cond)) { printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, +// __LINE__, #cond); asm("trap;"); } } while (0) #define EP_DEVICE_ASSERT(cond) #endif diff --git a/src/ext/ep/kernels/internode.cu b/src/ext/ep/kernels/internode.cu index 98d29809..44a3b06d 100644 --- a/src/ext/ep/kernels/internode.cu +++ b/src/ext/ep/kernels/internode.cu @@ -1,1818 +1,1893 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. 
#include +#include +#include -#include "configs.cuh" #include "buffer.cuh" +#include "configs.cuh" #include "exception.cuh" #include "launch.cuh" #include "utils.cuh" -#include -#include -namespace mscclpp { namespace ep { +namespace mscclpp { +namespace ep { namespace internode { -template +template __global__ void __launch_bounds__(kNumThreads, 1) -get_dispatch_layout(const int64_t* topk_idx, - int* num_tokens_per_rank, int* num_tokens_per_rdma_rank, - int* num_tokens_per_expert, bool* is_token_in_rank, - int num_tokens, int num_topk, int num_ranks, int num_experts) { - auto sm_id = static_cast(blockIdx.x); - auto thread_id = static_cast(threadIdx.x); + get_dispatch_layout(const int64_t* topk_idx, int* num_tokens_per_rank, int* num_tokens_per_rdma_rank, + int* num_tokens_per_expert, bool* is_token_in_rank, int num_tokens, int num_topk, int num_ranks, + int num_experts) { + auto sm_id = static_cast(blockIdx.x); + auto thread_id = static_cast(threadIdx.x); - // Count expert statistics - __shared__ int num_tokens_per_expert_per_thread[kNumThreads][kNumExpertsPerSM]; - int expert_begin_idx = sm_id * kNumExpertsPerSM, expert_end_idx = min(expert_begin_idx + kNumExpertsPerSM, num_experts); - if (expert_begin_idx < expert_end_idx) { - // Per-thread count - #pragma unroll - for (int i = 0; i < kNumExpertsPerSM; ++ i) - num_tokens_per_expert_per_thread[thread_id][i] = 0; - #pragma unroll - for (int i = thread_id; i < num_tokens; i += kNumThreads) { - auto shifted_topk_idx = topk_idx + i * num_topk; - #pragma unroll - for (int j = 0, expert_idx; j < num_topk; ++ j) { - expert_idx = static_cast(shifted_topk_idx[j]); - if (expert_begin_idx <= expert_idx and expert_idx < expert_end_idx) - ++ num_tokens_per_expert_per_thread[thread_id][expert_idx - expert_begin_idx]; - } - } - __syncthreads(); + // Count expert statistics + __shared__ int num_tokens_per_expert_per_thread[kNumThreads][kNumExpertsPerSM]; + int expert_begin_idx = sm_id * kNumExpertsPerSM, + expert_end_idx = min(expert_begin_idx + kNumExpertsPerSM, num_experts); + if (expert_begin_idx < expert_end_idx) { +// Per-thread count +#pragma unroll + for (int i = 0; i < kNumExpertsPerSM; ++i) num_tokens_per_expert_per_thread[thread_id][i] = 0; +#pragma unroll + for (int i = thread_id; i < num_tokens; i += kNumThreads) { + auto shifted_topk_idx = topk_idx + i * num_topk; +#pragma unroll + for (int j = 0, expert_idx; j < num_topk; ++j) { + expert_idx = static_cast(shifted_topk_idx[j]); + if (expert_begin_idx <= expert_idx and expert_idx < expert_end_idx) + ++num_tokens_per_expert_per_thread[thread_id][expert_idx - expert_begin_idx]; + } + } + __syncthreads(); - // Sum up - EP_STATIC_ASSERT(kNumExpertsPerSM <= kNumThreads, "Too many experts per SM"); - if (expert_begin_idx + thread_id < expert_end_idx) { - int sum = 0; - #pragma unroll - for (int i = 0; i < kNumThreads; ++ i) - sum += num_tokens_per_expert_per_thread[i][thread_id]; - num_tokens_per_expert[expert_begin_idx + thread_id] = sum; + // Sum up + EP_STATIC_ASSERT(kNumExpertsPerSM <= kNumThreads, "Too many experts per SM"); + if (expert_begin_idx + thread_id < expert_end_idx) { + int sum = 0; +#pragma unroll + for (int i = 0; i < kNumThreads; ++i) sum += num_tokens_per_expert_per_thread[i][thread_id]; + num_tokens_per_expert[expert_begin_idx + thread_id] = sum; + } + return; + } + + if (num_tokens_per_rdma_rank != nullptr) + EP_DEVICE_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0 and num_ranks > NUM_MAX_NVL_PEERS); + + // Count rank statistics + constexpr int kNumRDMARanksPerSM = kNumRanksPerSM 
/ NUM_MAX_NVL_PEERS; + __shared__ int num_tokens_per_rank_per_thread[kNumThreads][kNumRanksPerSM]; + __shared__ int num_tokens_per_rdma_rank_per_thread[kNumThreads][kNumRDMARanksPerSM]; + auto sm_begin = (num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM; + int rank_begin_idx = (sm_id - sm_begin) * kNumRanksPerSM, + rank_end_idx = min(rank_begin_idx + kNumRanksPerSM, num_ranks); + int rdma_rank_begin_idx = rank_begin_idx / NUM_MAX_NVL_PEERS, rdma_rank_end_idx = rank_end_idx / NUM_MAX_NVL_PEERS; + if (rank_begin_idx < rank_end_idx) { + const auto num_expert_per_rank = num_experts / num_ranks; + auto expert_begin = rank_begin_idx * num_expert_per_rank; + auto expert_end = rank_end_idx * num_expert_per_rank; + +// Per-thread count +#pragma unroll + for (int i = 0; i < kNumRanksPerSM; ++i) num_tokens_per_rank_per_thread[thread_id][i] = 0; +#pragma unroll + for (int i = 0; i < kNumRDMARanksPerSM; ++i) num_tokens_per_rdma_rank_per_thread[thread_id][i] = 0; +#pragma unroll + for (int i = thread_id; i < num_tokens; i += kNumThreads) { + auto shifted_topk_idx = topk_idx + i * num_topk; + int is_in_rank[kNumRanksPerSM] = {0}, is_in_rdma_rank[kNumRDMARanksPerSM] = {0}; +#pragma unroll + for (int j = 0, expert_idx, rank_idx; j < num_topk; ++j) { + expert_idx = static_cast(shifted_topk_idx[j]); + if (expert_begin <= expert_idx and expert_idx < expert_end) { + // Count single rank + rank_idx = expert_idx / num_expert_per_rank - rank_begin_idx; + is_in_rank[rank_idx]++, is_in_rdma_rank[rank_idx / NUM_MAX_NVL_PEERS]++; } - return; + } + + auto shifted_is_token_in_rank = is_token_in_rank + i * num_ranks; +#pragma unroll + for (int j = 0; j + rank_begin_idx < rank_end_idx; ++j) { + shifted_is_token_in_rank[j + rank_begin_idx] = (is_in_rank[j] > 0); + num_tokens_per_rank_per_thread[thread_id][j] += (is_in_rank[j] > 0); + } + +#pragma unroll + for (int j = 0; j + rdma_rank_begin_idx < rdma_rank_end_idx; ++j) + num_tokens_per_rdma_rank_per_thread[thread_id][j] += (is_in_rdma_rank[j] > 0); + } + __syncthreads(); + + // Sum up + EP_STATIC_ASSERT(kNumRanksPerSM <= kNumThreads, "Too many ranks per SM"); + if (rank_begin_idx + thread_id < rank_end_idx) { + int sum = 0; +#pragma unroll + for (int i = 0; i < kNumThreads; ++i) sum += num_tokens_per_rank_per_thread[i][thread_id]; + num_tokens_per_rank[rank_begin_idx + thread_id] = sum; } - if (num_tokens_per_rdma_rank != nullptr) - EP_DEVICE_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0 and num_ranks > NUM_MAX_NVL_PEERS); - - // Count rank statistics - constexpr int kNumRDMARanksPerSM = kNumRanksPerSM / NUM_MAX_NVL_PEERS; - __shared__ int num_tokens_per_rank_per_thread[kNumThreads][kNumRanksPerSM]; - __shared__ int num_tokens_per_rdma_rank_per_thread[kNumThreads][kNumRDMARanksPerSM]; - auto sm_begin = (num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM; - int rank_begin_idx = (sm_id - sm_begin) * kNumRanksPerSM, rank_end_idx = min(rank_begin_idx + kNumRanksPerSM, num_ranks); - int rdma_rank_begin_idx = rank_begin_idx / NUM_MAX_NVL_PEERS, rdma_rank_end_idx = rank_end_idx / NUM_MAX_NVL_PEERS; - if (rank_begin_idx < rank_end_idx) { - const auto num_expert_per_rank = num_experts / num_ranks; - auto expert_begin = rank_begin_idx * num_expert_per_rank; - auto expert_end = rank_end_idx * num_expert_per_rank; - - // Per-thread count - #pragma unroll - for (int i = 0; i < kNumRanksPerSM; ++ i) - num_tokens_per_rank_per_thread[thread_id][i] = 0; - #pragma unroll - for (int i = 0; i < kNumRDMARanksPerSM; ++ i) - num_tokens_per_rdma_rank_per_thread[thread_id][i] = 0; - 
#pragma unroll - for (int i = thread_id; i < num_tokens; i += kNumThreads) { - auto shifted_topk_idx = topk_idx + i * num_topk; - int is_in_rank[kNumRanksPerSM] = {0}, is_in_rdma_rank[kNumRDMARanksPerSM] = {0}; - #pragma unroll - for (int j = 0, expert_idx, rank_idx; j < num_topk; ++j) { - expert_idx = static_cast(shifted_topk_idx[j]); - if (expert_begin <= expert_idx and expert_idx < expert_end) { - // Count single rank - rank_idx = expert_idx / num_expert_per_rank - rank_begin_idx; - is_in_rank[rank_idx] ++, is_in_rdma_rank[rank_idx / NUM_MAX_NVL_PEERS] ++; - } - } - - auto shifted_is_token_in_rank = is_token_in_rank + i * num_ranks; - #pragma unroll - for (int j = 0; j + rank_begin_idx < rank_end_idx; ++ j) { - shifted_is_token_in_rank[j + rank_begin_idx] = (is_in_rank[j] > 0); - num_tokens_per_rank_per_thread[thread_id][j] += (is_in_rank[j] > 0); - } - - #pragma unroll - for (int j = 0; j + rdma_rank_begin_idx < rdma_rank_end_idx; ++ j) - num_tokens_per_rdma_rank_per_thread[thread_id][j] += (is_in_rdma_rank[j] > 0); - } - __syncthreads(); - - // Sum up - EP_STATIC_ASSERT(kNumRanksPerSM <= kNumThreads, "Too many ranks per SM"); - if (rank_begin_idx + thread_id < rank_end_idx) { - int sum = 0; - #pragma unroll - for (int i = 0; i < kNumThreads; ++ i) - sum += num_tokens_per_rank_per_thread[i][thread_id]; - num_tokens_per_rank[rank_begin_idx + thread_id] = sum; - } - - if (num_tokens_per_rdma_rank != nullptr and rdma_rank_begin_idx + thread_id < rdma_rank_end_idx) { - int sum = 0; - #pragma unroll - for (int i = 0; i < kNumThreads; ++ i) - sum += num_tokens_per_rdma_rank_per_thread[i][thread_id]; - num_tokens_per_rdma_rank[rdma_rank_begin_idx + thread_id] = sum; - } + if (num_tokens_per_rdma_rank != nullptr and rdma_rank_begin_idx + thread_id < rdma_rank_end_idx) { + int sum = 0; +#pragma unroll + for (int i = 0; i < kNumThreads; ++i) sum += num_tokens_per_rdma_rank_per_thread[i][thread_id]; + num_tokens_per_rdma_rank[rdma_rank_begin_idx + thread_id] = sum; } + } } -void get_dispatch_layout(const int64_t* topk_idx, - int* num_tokens_per_rank, int* num_tokens_per_rdma_rank, - int* num_tokens_per_expert, bool* is_token_in_rank, - int num_tokens, int num_topk, int num_ranks, int num_experts, - cudaStream_t stream) { - constexpr int kNumThreads = 256, kNumExpertsPerSM = 32, kNumRanksPerSM = 8; - int num_sms = ((num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM) + (num_ranks + kNumRanksPerSM - 1) / kNumRanksPerSM; - EP_STATIC_ASSERT(kNumExpertsPerSM % NUM_MAX_NVL_PEERS == 0, "Invalid number of experts per SM"); +void get_dispatch_layout(const int64_t* topk_idx, int* num_tokens_per_rank, int* num_tokens_per_rdma_rank, + int* num_tokens_per_expert, bool* is_token_in_rank, int num_tokens, int num_topk, + int num_ranks, int num_experts, cudaStream_t stream) { + constexpr int kNumThreads = 256, kNumExpertsPerSM = 32, kNumRanksPerSM = 8; + int num_sms = + ((num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM) + (num_ranks + kNumRanksPerSM - 1) / kNumRanksPerSM; + EP_STATIC_ASSERT(kNumExpertsPerSM % NUM_MAX_NVL_PEERS == 0, "Invalid number of experts per SM"); - SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream); - LAUNCH_KERNEL(&cfg, (get_dispatch_layout), - topk_idx, num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, - num_tokens, num_topk, num_ranks, num_experts); + SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream); + LAUNCH_KERNEL(&cfg, (get_dispatch_layout), topk_idx, + num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, 
is_token_in_rank, num_tokens, + num_topk, num_ranks, num_experts); } struct SourceMeta { - int src_rdma_rank, is_token_in_nvl_rank_bits; + int src_rdma_rank, is_token_in_nvl_rank_bits; - EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "Invalid number of maximum NVL peers"); + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "Invalid number of maximum NVL peers"); - __forceinline__ SourceMeta() = default; + __forceinline__ SourceMeta() = default; - // TODO: faster encoding - __device__ __forceinline__ SourceMeta(int rdma_rank, const bool* is_token_in_nvl_ranks) { - src_rdma_rank = rdma_rank; - is_token_in_nvl_rank_bits = is_token_in_nvl_ranks[0]; - #pragma unroll - for (int i = 1; i < NUM_MAX_NVL_PEERS; ++ i) - is_token_in_nvl_rank_bits |= is_token_in_nvl_ranks[i] << i; - } + // TODO: faster encoding + __device__ __forceinline__ SourceMeta(int rdma_rank, const bool* is_token_in_nvl_ranks) { + src_rdma_rank = rdma_rank; + is_token_in_nvl_rank_bits = is_token_in_nvl_ranks[0]; +#pragma unroll + for (int i = 1; i < NUM_MAX_NVL_PEERS; ++i) is_token_in_nvl_rank_bits |= is_token_in_nvl_ranks[i] << i; + } - __device__ __forceinline__ bool is_token_in_nvl_rank(int nvl_rank) const { - return (is_token_in_nvl_rank_bits >> nvl_rank) & 1; - } + __device__ __forceinline__ bool is_token_in_nvl_rank(int nvl_rank) const { + return (is_token_in_nvl_rank_bits >> nvl_rank) & 1; + } }; EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, "Invalid size of `SourceMeta`"); -int get_source_meta_bytes() { - return sizeof(SourceMeta); +int get_source_meta_bytes() { return sizeof(SourceMeta); } + +__host__ __device__ __forceinline__ int get_num_bytes_per_rdma_token(int hidden_int4, int num_scales, int num_topk_idx, + int num_topk_weights) { + return static_cast(align(hidden_int4 * sizeof(int4) + sizeof(SourceMeta) + num_scales * sizeof(float) + + num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float), + sizeof(int4))); } -__host__ __device__ __forceinline__ -int get_num_bytes_per_rdma_token(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights) { - return static_cast(align(hidden_int4 * sizeof(int4) + sizeof(SourceMeta) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float), sizeof(int4))); +__host__ __device__ __forceinline__ std::pair get_rdma_clean_meta(int hidden_int4, int num_scales, + int num_topk_idx, int num_topk_weights, + int num_rdma_ranks, + int num_rdma_recv_buffer_tokens, + int num_sms) { + // Return `int32_t` offset and count to clean + return {(get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * + num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_sms) / + sizeof(int), + (NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_sms}; } -__host__ __device__ __forceinline__ -std::pair get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, int num_rdma_ranks, int num_rdma_recv_buffer_tokens, int num_sms) { - // Return `int32_t` offset and count to clean - return { - (get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_sms) / sizeof(int), - (NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_sms - }; -} - -__host__ __device__ __forceinline__ -std::pair get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, int num_rdma_ranks, int num_nvl_ranks, int num_nvl_recv_buffer_tokens, int num_sms) { - // Return `int32_t` offset and to clean - 
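As a concrete illustration of the per-token RDMA footprint computed by `get_num_bytes_per_rdma_token`, the host-side sketch below mirrors the same arithmetic with made-up shapes (a 7168-wide bf16 hidden state, one fp32 scale per 128 elements, top-8 routing); `align_up` is a local stand-in for the project's `align` helper, and the numbers are only an example.

#include <cstdio>

// Round x up to a multiple of `unit` (stand-in for the project's `align`).
static int align_up(int x, int unit) { return (x + unit - 1) / unit * unit; }

int main() {
  const int hidden_int4 = 7168 * 2 / 16;  // 896 int4 (16-byte) units of bf16 payload
  const int num_scales = 7168 / 128;      // 56 fp32 scales
  const int num_topk_idx = 8;             // top-8 expert indices (int each)
  const int num_topk_weights = 8;         // top-8 weights (fp32 each)
  const int source_meta_bytes = 8;        // SourceMeta: two ints

  const int raw = hidden_int4 * 16 + source_meta_bytes + num_scales * 4 +
                  num_topk_idx * 4 + num_topk_weights * 4;
  std::printf("per-token RDMA bytes: %d raw, %d aligned to int4\n", raw, align_up(raw, 16));
  // For these shapes: 14632 raw, 14640 after rounding up to a 16-byte boundary.
  return 0;
}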
EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, "Invalid size of `SourceMeta`"); - return { - (num_nvl_recv_buffer_tokens * (hidden_int4 * sizeof(int4) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float) + sizeof(SourceMeta)) * num_nvl_ranks * num_sms) / sizeof(int), - num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_sms, - }; +__host__ __device__ __forceinline__ std::pair get_nvl_clean_meta(int hidden_int4, int num_scales, + int num_topk_idx, int num_topk_weights, + int num_rdma_ranks, int num_nvl_ranks, + int num_nvl_recv_buffer_tokens, + int num_sms) { + // Return `int32_t` offset and to clean + EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, "Invalid size of `SourceMeta`"); + return { + (num_nvl_recv_buffer_tokens * + (hidden_int4 * sizeof(int4) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) + + num_topk_weights * sizeof(float) + sizeof(SourceMeta)) * + num_nvl_ranks * num_sms) / + sizeof(int), + num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_sms, + }; } template -__global__ void -notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks, - const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped, - const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, - const bool* is_token_in_rank, int num_tokens, int num_channels, int expert_alignment, - const int rdma_clean_offset, const int rdma_num_int_clean, - const int nvl_clean_offset, const int nvl_num_int_clean, - int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum, - int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum, - void* rdma_buffer_ptr, - void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank, - mscclpp::PortChannelDeviceHandle* port_channel_handles, - mscclpp::MemoryChannelDeviceHandle *memory_channel_handles) { - auto sm_id = static_cast(blockIdx.x); - auto thread_id = static_cast(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id(); - auto num_threads = static_cast(blockDim.x), num_warps = num_threads / 32; +__global__ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks, + const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped, + const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, + const bool* is_token_in_rank, int num_tokens, int num_channels, int expert_alignment, + const int rdma_clean_offset, const int rdma_num_int_clean, const int nvl_clean_offset, + const int nvl_num_int_clean, int* rdma_channel_prefix_matrix, + int* recv_rdma_rank_prefix_sum, int* gbl_channel_prefix_matrix, + int* recv_gbl_rank_prefix_sum, void* rdma_buffer_ptr, void** buffer_ptrs, + int** task_fifo_ptrs, int head, int rank, + mscclpp::PortChannelDeviceHandle* port_channel_handles, + mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) { + auto sm_id = static_cast(blockIdx.x); + auto thread_id = static_cast(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id(); + auto num_threads = static_cast(blockDim.x), num_warps = num_threads / 32; - auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; - auto num_rdma_experts = num_experts / kNumRDMARanks, num_nvl_experts = num_rdma_experts / NUM_MAX_NVL_PEERS; + auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; + auto num_rdma_experts = num_experts / kNumRDMARanks, num_nvl_experts = num_rdma_experts / NUM_MAX_NVL_PEERS; - if (sm_id == 0) { - // Communication with 
others - // Global barrier: the first warp do intra-node sync, the second warp do internode sync - EP_DEVICE_ASSERT(num_warps > 1); - EP_DEVICE_ASSERT(kNumRDMARanks + 32 <= num_threads); - const auto barrier_thread_id = thread_id - 32; - const bool run_barrier = (barrier_thread_id >= 0) && (barrier_thread_id < kNumRDMARanks) && (barrier_thread_id != rdma_rank); - const auto barrier_channel_idx = kLowLatencyMode ? barrier_thread_id : (barrier_thread_id * NUM_MAX_NVL_PEERS + nvl_rank); - if (run_barrier) { - port_channel_handles[barrier_channel_idx].signal(); - port_channel_handles[barrier_channel_idx].wait(); - } - if constexpr (!kLowLatencyMode) { - // kLowLatencyMode==false requires sync of all ranks, which can be done by running intra-node sync - // after the inter-node sync is done. - __syncthreads(); - } -#if 1 - barrier_device(task_fifo_ptrs, head, nvl_rank); - move_fifo_slots(head); -#else - // TODO(chhwang): make memory channels work - if (thread_id < NUM_MAX_NVL_PEERS && thread_id != nvl_rank) { - memory_channel_handles[thread_id].relaxedSignal(); - memory_channel_handles[thread_id].relaxedWait(); - } -#endif - __syncthreads(); - - // Send numbers of tokens per rank/expert to RDMA ranks - auto rdma_buffer_ptr_int = reinterpret_cast(rdma_buffer_ptr); - auto num_elems = NUM_MAX_NVL_PEERS + num_rdma_experts + 1; - auto num_bytes = num_elems * sizeof(int); - auto per_channel_bytes = num_bytes * kNumRDMARanks; - auto rdma_recv_num_tokens_mixed = SymBuffer(rdma_buffer_ptr, num_elems, kNumRDMARanks); - - // Clean up for later data dispatch - EP_DEVICE_ASSERT(rdma_recv_num_tokens_mixed.total_bytes <= rdma_clean_offset * sizeof(int)); - #pragma unroll - for (int i = thread_id; i < rdma_num_int_clean; i += num_threads) - rdma_buffer_ptr_int[rdma_clean_offset + i] = 0; - - // Copy to send buffer - #pragma unroll - for (int i = thread_id; i < num_ranks; i += num_threads) - rdma_recv_num_tokens_mixed.send_buffer(i / NUM_MAX_NVL_PEERS)[i % NUM_MAX_NVL_PEERS] = num_tokens_per_rank[i]; - #pragma unroll - for (int i = thread_id; i < num_experts; i += num_threads) - rdma_recv_num_tokens_mixed.send_buffer(i / num_rdma_experts)[NUM_MAX_NVL_PEERS + i % num_rdma_experts] = num_tokens_per_expert[i]; - if (thread_id < kNumRDMARanks) - rdma_recv_num_tokens_mixed.send_buffer(thread_id)[NUM_MAX_NVL_PEERS + num_rdma_experts] = num_tokens_per_rdma_rank[thread_id]; - __syncthreads(); - - // Issue send - // TODO: more light fence or barrier or signaling - // TODO: overlap EP barrier and NVL cleaning - if (thread_id < kNumRDMARanks) { - auto dst_offset = rdma_rank * num_bytes + per_channel_bytes; - auto src_offset = thread_id * num_bytes; - auto peer_rank = kLowLatencyMode ? thread_id : (thread_id * NUM_MAX_NVL_PEERS + nvl_rank); - port_channel_handles[peer_rank].putWithSignal(dst_offset, src_offset, num_bytes); - port_channel_handles[peer_rank].wait(); - } - __syncthreads(); - - // NVL buffers - auto nvl_send_buffer = thread_id < NUM_MAX_NVL_PEERS ? 
buffer_ptrs[thread_id] : nullptr; - auto nvl_recv_buffer = buffer_ptrs[nvl_rank]; - auto nvl_reduced_num_tokens_per_expert = Buffer(nvl_recv_buffer, num_rdma_experts).advance_also(nvl_send_buffer); - auto nvl_send_num_tokens_per_rank = AsymBuffer(nvl_send_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS); - auto nvl_send_num_tokens_per_expert = AsymBuffer(nvl_send_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS); - auto nvl_recv_num_tokens_per_rank = AsymBuffer(nvl_recv_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS); - auto nvl_recv_num_tokens_per_expert = AsymBuffer(nvl_recv_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS); - - // Clean up for later data dispatch - auto nvl_buffer_ptr_int = reinterpret_cast(buffer_ptrs[nvl_rank]); - EP_DEVICE_ASSERT(nvl_reduced_num_tokens_per_expert.total_bytes + nvl_send_num_tokens_per_rank.total_bytes + - nvl_send_num_tokens_per_expert.total_bytes <= nvl_clean_offset * sizeof(int)); - #pragma unroll - for (int i = thread_id; i < nvl_num_int_clean; i += num_threads) - nvl_buffer_ptr_int[nvl_clean_offset + i] = 0; - - // Reduce number of tokens per expert into the NVL send buffer - // TODO: may use NVSHMEM reduction - EP_DEVICE_ASSERT(num_rdma_experts <= num_threads); - if (thread_id < num_rdma_experts) { - int sum = 0; - #pragma unroll - for (int i = 0; i < kNumRDMARanks; ++ i) - sum += rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + thread_id]; - nvl_reduced_num_tokens_per_expert[thread_id] = sum; - } - __syncthreads(); - - // Reduce RDMA received tokens - if (thread_id == 0) { - int sum = 0; - #pragma unroll - for (int i = 0; i < kNumRDMARanks; ++ i) { - sum += rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + num_rdma_experts]; - recv_rdma_rank_prefix_sum[i] = sum; - } - while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1); - *moe_recv_rdma_counter_mapped = sum; - } - - // Send numbers of tokens per rank/expert to NVL ranks - EP_DEVICE_ASSERT(NUM_MAX_NVL_PEERS <= num_threads); - if (thread_id < NUM_MAX_NVL_PEERS) { - #pragma unroll - for (int i = 0; i < kNumRDMARanks; ++ i) - nvl_send_num_tokens_per_rank.buffer(nvl_rank)[i] = rdma_recv_num_tokens_mixed.recv_buffer(i)[thread_id]; - #pragma unroll - for (int i = 0; i < num_nvl_experts; ++ i) - nvl_send_num_tokens_per_expert.buffer(nvl_rank)[i] = nvl_reduced_num_tokens_per_expert[thread_id * num_nvl_experts + i]; - } - memory_fence(); - __syncthreads(); - barrier_device(task_fifo_ptrs, head, nvl_rank); - move_fifo_slots(head); - __syncthreads(); - - // Reduce number of tokens per rank/expert - EP_DEVICE_ASSERT(num_nvl_experts <= num_threads); - if (thread_id == 0) { - int sum = 0; - #pragma unroll - for (int i = 0; i < num_ranks; ++ i) { - int src_rdma_rank = i / NUM_MAX_NVL_PEERS, src_nvl_rank = i % NUM_MAX_NVL_PEERS; - sum += nvl_recv_num_tokens_per_rank.buffer(src_nvl_rank)[src_rdma_rank]; - recv_gbl_rank_prefix_sum[i] = sum; - } - while (ld_volatile_global(moe_recv_counter_mapped) != -1); - *moe_recv_counter_mapped = sum; - } - if (thread_id < num_nvl_experts) { - int sum = 0; - #pragma unroll - for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i) - sum += nvl_recv_num_tokens_per_expert.buffer(i)[thread_id]; - sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment; - while (ld_volatile_global(moe_recv_expert_counter_mapped + thread_id) != -1); - moe_recv_expert_counter_mapped[thread_id] = sum; - } - - // Finally barrier - __syncthreads(); - - if (run_barrier) { - port_channel_handles[barrier_channel_idx].signal(); - port_channel_handles[barrier_channel_idx].wait(); - } - if 
constexpr (!kLowLatencyMode) { - // kLowLatencyMode==false requires sync of all ranks, which can be done by running intra-node sync - // after the inter-node sync is done. - __syncthreads(); - } - barrier_device(task_fifo_ptrs, head, nvl_rank); - move_fifo_slots(head); - } else { - // Calculate meta data - int dst_rdma_rank = sm_id - 1; - for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) { - int token_start_idx, token_end_idx; - get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx); - - // Iterate over tokens - int total_count = 0, per_nvl_rank_count[NUM_MAX_NVL_PEERS] = {0}; - for (int64_t i = token_start_idx + lane_id; i < token_end_idx; i += 32) { - EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers"); - auto is_token_in_rank_uint64 = *reinterpret_cast(is_token_in_rank + i * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS); - auto is_token_in_rank_values = reinterpret_cast(&is_token_in_rank_uint64); - #pragma unroll - for (int j = 0; j < NUM_MAX_NVL_PEERS; ++ j) - per_nvl_rank_count[j] += is_token_in_rank_values[j]; - total_count += (is_token_in_rank_uint64 != 0); - } - - // Warp reduce - total_count = warp_reduce_sum(total_count); - #pragma unroll - for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i) - per_nvl_rank_count[i] = warp_reduce_sum(per_nvl_rank_count[i]); - - // Write into channel matrix - if (lane_id == 0) { - #pragma unroll - for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i) - gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + i) * num_channels + channel_id] = per_nvl_rank_count[i]; - rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] = total_count; - } - } - - // Calculate prefix sum - __syncthreads(); - if (thread_id == 0) { - auto prefix_row = rdma_channel_prefix_matrix + dst_rdma_rank * num_channels; - #pragma unroll - for (int i = 1; i < num_channels; ++ i) - prefix_row[i] += prefix_row[i - 1]; - } - - EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVL peers"); - if (thread_id < NUM_MAX_NVL_PEERS) { - auto prefix_row = gbl_channel_prefix_matrix + (dst_rdma_rank * NUM_MAX_NVL_PEERS + thread_id) * num_channels; - #pragma unroll - for (int i = 1; i < num_channels; ++ i) - prefix_row[i] += prefix_row[i - 1]; - } + if (sm_id == 0) { + // Communication with others + // Global barrier: the first warp do intra-node sync, the second warp do internode sync + EP_DEVICE_ASSERT(num_warps > 1); + EP_DEVICE_ASSERT(kNumRDMARanks + 32 <= num_threads); + const auto barrier_thread_id = thread_id - 32; + const bool run_barrier = + (barrier_thread_id >= 0) && (barrier_thread_id < kNumRDMARanks) && (barrier_thread_id != rdma_rank); + const auto barrier_channel_idx = + kLowLatencyMode ? barrier_thread_id : (barrier_thread_id * NUM_MAX_NVL_PEERS + nvl_rank); + if (run_barrier) { + port_channel_handles[barrier_channel_idx].signal(); + port_channel_handles[barrier_channel_idx].wait(); } + if constexpr (!kLowLatencyMode) { + // kLowLatencyMode==false requires sync of all ranks, which can be done by running intra-node sync + // after the inter-node sync is done. 
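+      // (The port-channel signal/wait above is issued by threads 32..32+kNumRDMARanks-1,
+      // one per peer RDMA rank; the barrier_device() on the task fifo below handles the
+      // intra-node sync across the NUM_MAX_NVL_PEERS local GPUs, so the two stages together
+      // order this rank against every other rank.)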
+ __syncthreads(); + } +#if 1 + barrier_device(task_fifo_ptrs, head, nvl_rank); + move_fifo_slots(head); +#else + // TODO(chhwang): make memory channels work + if (thread_id < NUM_MAX_NVL_PEERS && thread_id != nvl_rank) { + memory_channel_handles[thread_id].relaxedSignal(); + memory_channel_handles[thread_id].relaxedWait(); + } +#endif + __syncthreads(); + + // Send numbers of tokens per rank/expert to RDMA ranks + auto rdma_buffer_ptr_int = reinterpret_cast(rdma_buffer_ptr); + auto num_elems = NUM_MAX_NVL_PEERS + num_rdma_experts + 1; + auto num_bytes = num_elems * sizeof(int); + auto per_channel_bytes = num_bytes * kNumRDMARanks; + auto rdma_recv_num_tokens_mixed = SymBuffer(rdma_buffer_ptr, num_elems, kNumRDMARanks); + + // Clean up for later data dispatch + EP_DEVICE_ASSERT(rdma_recv_num_tokens_mixed.total_bytes <= rdma_clean_offset * sizeof(int)); +#pragma unroll + for (int i = thread_id; i < rdma_num_int_clean; i += num_threads) rdma_buffer_ptr_int[rdma_clean_offset + i] = 0; + +// Copy to send buffer +#pragma unroll + for (int i = thread_id; i < num_ranks; i += num_threads) + rdma_recv_num_tokens_mixed.send_buffer(i / NUM_MAX_NVL_PEERS)[i % NUM_MAX_NVL_PEERS] = num_tokens_per_rank[i]; +#pragma unroll + for (int i = thread_id; i < num_experts; i += num_threads) + rdma_recv_num_tokens_mixed.send_buffer(i / num_rdma_experts)[NUM_MAX_NVL_PEERS + i % num_rdma_experts] = + num_tokens_per_expert[i]; + if (thread_id < kNumRDMARanks) + rdma_recv_num_tokens_mixed.send_buffer(thread_id)[NUM_MAX_NVL_PEERS + num_rdma_experts] = + num_tokens_per_rdma_rank[thread_id]; + __syncthreads(); + + // Issue send + // TODO: more light fence or barrier or signaling + // TODO: overlap EP barrier and NVL cleaning + if (thread_id < kNumRDMARanks) { + auto dst_offset = rdma_rank * num_bytes + per_channel_bytes; + auto src_offset = thread_id * num_bytes; + auto peer_rank = kLowLatencyMode ? thread_id : (thread_id * NUM_MAX_NVL_PEERS + nvl_rank); + port_channel_handles[peer_rank].putWithSignal(dst_offset, src_offset, num_bytes); + port_channel_handles[peer_rank].wait(); + } + __syncthreads(); + + // NVL buffers + auto nvl_send_buffer = thread_id < NUM_MAX_NVL_PEERS ? 
buffer_ptrs[thread_id] : nullptr; + auto nvl_recv_buffer = buffer_ptrs[nvl_rank]; + auto nvl_reduced_num_tokens_per_expert = + Buffer(nvl_recv_buffer, num_rdma_experts).advance_also(nvl_send_buffer); + auto nvl_send_num_tokens_per_rank = AsymBuffer(nvl_send_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS); + auto nvl_send_num_tokens_per_expert = AsymBuffer(nvl_send_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS); + auto nvl_recv_num_tokens_per_rank = AsymBuffer(nvl_recv_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS); + auto nvl_recv_num_tokens_per_expert = AsymBuffer(nvl_recv_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS); + + // Clean up for later data dispatch + auto nvl_buffer_ptr_int = reinterpret_cast(buffer_ptrs[nvl_rank]); + EP_DEVICE_ASSERT(nvl_reduced_num_tokens_per_expert.total_bytes + nvl_send_num_tokens_per_rank.total_bytes + + nvl_send_num_tokens_per_expert.total_bytes <= + nvl_clean_offset * sizeof(int)); +#pragma unroll + for (int i = thread_id; i < nvl_num_int_clean; i += num_threads) nvl_buffer_ptr_int[nvl_clean_offset + i] = 0; + + // Reduce number of tokens per expert into the NVL send buffer + // TODO: may use NVSHMEM reduction + EP_DEVICE_ASSERT(num_rdma_experts <= num_threads); + if (thread_id < num_rdma_experts) { + int sum = 0; +#pragma unroll + for (int i = 0; i < kNumRDMARanks; ++i) + sum += rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + thread_id]; + nvl_reduced_num_tokens_per_expert[thread_id] = sum; + } + __syncthreads(); + + // Reduce RDMA received tokens + if (thread_id == 0) { + int sum = 0; +#pragma unroll + for (int i = 0; i < kNumRDMARanks; ++i) { + sum += rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + num_rdma_experts]; + recv_rdma_rank_prefix_sum[i] = sum; + } + while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1) + ; + *moe_recv_rdma_counter_mapped = sum; + } + + // Send numbers of tokens per rank/expert to NVL ranks + EP_DEVICE_ASSERT(NUM_MAX_NVL_PEERS <= num_threads); + if (thread_id < NUM_MAX_NVL_PEERS) { +#pragma unroll + for (int i = 0; i < kNumRDMARanks; ++i) + nvl_send_num_tokens_per_rank.buffer(nvl_rank)[i] = rdma_recv_num_tokens_mixed.recv_buffer(i)[thread_id]; +#pragma unroll + for (int i = 0; i < num_nvl_experts; ++i) + nvl_send_num_tokens_per_expert.buffer(nvl_rank)[i] = + nvl_reduced_num_tokens_per_expert[thread_id * num_nvl_experts + i]; + } + memory_fence(); + __syncthreads(); + barrier_device(task_fifo_ptrs, head, nvl_rank); + move_fifo_slots(head); + __syncthreads(); + + // Reduce number of tokens per rank/expert + EP_DEVICE_ASSERT(num_nvl_experts <= num_threads); + if (thread_id == 0) { + int sum = 0; +#pragma unroll + for (int i = 0; i < num_ranks; ++i) { + int src_rdma_rank = i / NUM_MAX_NVL_PEERS, src_nvl_rank = i % NUM_MAX_NVL_PEERS; + sum += nvl_recv_num_tokens_per_rank.buffer(src_nvl_rank)[src_rdma_rank]; + recv_gbl_rank_prefix_sum[i] = sum; + } + while (ld_volatile_global(moe_recv_counter_mapped) != -1) + ; + *moe_recv_counter_mapped = sum; + } + if (thread_id < num_nvl_experts) { + int sum = 0; +#pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) sum += nvl_recv_num_tokens_per_expert.buffer(i)[thread_id]; + sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment; + while (ld_volatile_global(moe_recv_expert_counter_mapped + thread_id) != -1) + ; + moe_recv_expert_counter_mapped[thread_id] = sum; + } + + // Finally barrier + __syncthreads(); + + if (run_barrier) { + port_channel_handles[barrier_channel_idx].signal(); + port_channel_handles[barrier_channel_idx].wait(); + } + if 
constexpr (!kLowLatencyMode) { + // kLowLatencyMode==false requires sync of all ranks, which can be done by running intra-node sync + // after the inter-node sync is done. + __syncthreads(); + } + barrier_device(task_fifo_ptrs, head, nvl_rank); + move_fifo_slots(head); + } else { + // Calculate meta data + int dst_rdma_rank = sm_id - 1; + for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) { + int token_start_idx, token_end_idx; + get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx); + + // Iterate over tokens + int total_count = 0, per_nvl_rank_count[NUM_MAX_NVL_PEERS] = {0}; + for (int64_t i = token_start_idx + lane_id; i < token_end_idx; i += 32) { + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers"); + auto is_token_in_rank_uint64 = + *reinterpret_cast(is_token_in_rank + i * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS); + auto is_token_in_rank_values = reinterpret_cast(&is_token_in_rank_uint64); +#pragma unroll + for (int j = 0; j < NUM_MAX_NVL_PEERS; ++j) per_nvl_rank_count[j] += is_token_in_rank_values[j]; + total_count += (is_token_in_rank_uint64 != 0); + } + + // Warp reduce + total_count = warp_reduce_sum(total_count); +#pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) per_nvl_rank_count[i] = warp_reduce_sum(per_nvl_rank_count[i]); + + // Write into channel matrix + if (lane_id == 0) { +#pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) + gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + i) * num_channels + channel_id] = + per_nvl_rank_count[i]; + rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] = total_count; + } + } + + // Calculate prefix sum + __syncthreads(); + if (thread_id == 0) { + auto prefix_row = rdma_channel_prefix_matrix + dst_rdma_rank * num_channels; +#pragma unroll + for (int i = 1; i < num_channels; ++i) prefix_row[i] += prefix_row[i - 1]; + } + + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVL peers"); + if (thread_id < NUM_MAX_NVL_PEERS) { + auto prefix_row = gbl_channel_prefix_matrix + (dst_rdma_rank * NUM_MAX_NVL_PEERS + thread_id) * num_channels; +#pragma unroll + for (int i = 1; i < num_channels; ++i) prefix_row[i] += prefix_row[i - 1]; + } + } } void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks, const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped, const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, - const bool* is_token_in_rank, int num_tokens, int num_channels, - int hidden_int4, int num_scales, int num_topk, int expert_alignment, - int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum, - int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum, - void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, - int** task_fifo_ptrs, int head, int rank, - cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes, - bool low_latency_mode, - mscclpp::PortChannelDeviceHandle *port_channel_handles, - mscclpp::MemoryChannelDeviceHandle *memory_channel_handles) { -#define NOTIFY_DISPATCH_LAUNCH_CASE(num_rdma_ranks) { \ - auto notify_dispatch_func = low_latency_mode ? 
\ - notify_dispatch : notify_dispatch; \ - LAUNCH_KERNEL(&cfg, notify_dispatch_func, \ - num_tokens_per_rank, moe_recv_counter_mapped, num_ranks, \ - num_tokens_per_rdma_rank, moe_recv_rdma_counter_mapped, \ - num_tokens_per_expert, moe_recv_expert_counter_mapped, num_experts, \ - is_token_in_rank, num_tokens, num_channels, expert_alignment, \ - rdma_clean_meta.first, rdma_clean_meta.second, \ - nvl_clean_meta.first, nvl_clean_meta.second, \ - rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \ - gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \ - rdma_buffer_ptr, \ - buffer_ptrs, task_fifo_ptrs, head, rank, \ - port_channel_handles, memory_channel_handles); } break + const bool* is_token_in_rank, int num_tokens, int num_channels, int hidden_int4, int num_scales, + int num_topk, int expert_alignment, int* rdma_channel_prefix_matrix, + int* recv_rdma_rank_prefix_sum, int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum, + void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, + int num_max_nvl_chunked_recv_tokens, int** task_fifo_ptrs, int head, int rank, cudaStream_t stream, + int64_t num_rdma_bytes, int64_t num_nvl_bytes, bool low_latency_mode, + mscclpp::PortChannelDeviceHandle* port_channel_handles, + mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) { +#define NOTIFY_DISPATCH_LAUNCH_CASE(num_rdma_ranks) \ + { \ + auto notify_dispatch_func = \ + low_latency_mode ? notify_dispatch : notify_dispatch; \ + LAUNCH_KERNEL(&cfg, notify_dispatch_func, num_tokens_per_rank, moe_recv_counter_mapped, num_ranks, \ + num_tokens_per_rdma_rank, moe_recv_rdma_counter_mapped, num_tokens_per_expert, \ + moe_recv_expert_counter_mapped, num_experts, is_token_in_rank, num_tokens, num_channels, \ + expert_alignment, rdma_clean_meta.first, rdma_clean_meta.second, nvl_clean_meta.first, \ + nvl_clean_meta.second, rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \ + gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, rdma_buffer_ptr, buffer_ptrs, task_fifo_ptrs, \ + head, rank, port_channel_handles, memory_channel_handles); \ + } \ + break - constexpr int kNumThreads = 512; - const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; + constexpr int kNumThreads = 512; + const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; - // Get clean meta - auto rdma_clean_meta = get_rdma_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks, num_max_rdma_chunked_recv_tokens, num_channels); - auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks, NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels); - EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes); - EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes); - EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits::max()); - EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits::max()); + // Get clean meta + auto rdma_clean_meta = get_rdma_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks, + num_max_rdma_chunked_recv_tokens, num_channels); + auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks, + NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels); + EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes); + EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes); + EP_HOST_ASSERT(num_rdma_bytes < 
std::numeric_limits::max()); + EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits::max()); - // Launch kernel - SETUP_LAUNCH_CONFIG(1 + num_rdma_ranks, kNumThreads, stream); - SWITCH_RDMA_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE); + // Launch kernel + SETUP_LAUNCH_CONFIG(1 + num_rdma_ranks, kNumThreads, stream); + SWITCH_RDMA_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE); #undef NOTIFY_DISPATCH_LAUNCH_CASE } // At most 8 RDMA ranks to be sent -constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) { - return num_rdma_ranks < 8 ? num_rdma_ranks : 8; -} +constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) { return num_rdma_ranks < 8 ? num_rdma_ranks : 8; } -template +template __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32), 1) -dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, SourceMeta* recv_src_meta, - const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights, - int* send_rdma_head, int* send_nvl_head, - int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix, - const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum, - const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum, - int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts, - const bool* is_token_in_rank, - void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, - int rank, int num_ranks, - mscclpp::PortChannelDeviceHandle *port_channel_handles, - mscclpp::MemoryChannelDeviceHandle *memory_channel_handles) { - enum class WarpRole { - kRDMASender, - kRDMASenderCoordinator, - kRDMAAndNVLForwarder, - kForwarderCoordinator, - kNVLReceivers - }; + dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, + SourceMeta* recv_src_meta, const int4* x, const float* x_scales, const int64_t* topk_idx, + const float* topk_weights, int* send_rdma_head, int* send_nvl_head, int* recv_rdma_channel_prefix_matrix, + int* recv_gbl_channel_prefix_matrix, const int* rdma_channel_prefix_matrix, + const int* recv_rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, + const int* recv_gbl_rank_prefix_sum, int num_tokens, int hidden_int4, int num_scales, int num_topk, + int num_experts, const bool* is_token_in_rank, void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, + int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, + int num_max_nvl_chunked_recv_tokens, int rank, int num_ranks, + mscclpp::PortChannelDeviceHandle* port_channel_handles, + mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) { + enum class WarpRole { + kRDMASender, + kRDMASenderCoordinator, + kRDMAAndNVLForwarder, + kForwarderCoordinator, + kNVLReceivers + }; - const auto sm_id = static_cast(blockIdx.x); - const auto thread_id = static_cast(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id(); - const auto num_channels = static_cast(gridDim.x) / 2, channel_id = sm_id / 2; - const bool is_forwarder = sm_id % 2 == 0; - const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; + const auto sm_id = static_cast(blockIdx.x); + const auto thread_id = static_cast(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id(); + const auto num_channels = static_cast(gridDim.x) / 2, channel_id = sm_id / 2; + const bool 
is_forwarder = sm_id % 2 == 0; + const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; - const auto role_meta = [=]() -> std::pair { - if (is_forwarder) { - if (warp_id < NUM_MAX_NVL_PEERS) { - return {WarpRole::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS}; - } else { - return {WarpRole::kForwarderCoordinator, warp_id - NUM_MAX_NVL_PEERS}; - } - } else if (warp_id < kNumDispatchRDMASenderWarps) { - return {WarpRole::kRDMASender, -1}; - } else if (warp_id == kNumDispatchRDMASenderWarps) { - return {WarpRole::kRDMASenderCoordinator, -1}; - } else { - return {WarpRole::kNVLReceivers, (warp_id + channel_id - kNumDispatchRDMASenderWarps) % NUM_MAX_NVL_PEERS}; - } - }(); - auto warp_role = role_meta.first; - auto target_rank = role_meta.second; // Not applicable for RDMA senders - EP_DEVICE_ASSERT(num_warps == kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS); - - // Data checks - EP_DEVICE_ASSERT(num_topk <= 32); - - // RDMA symmetric layout - EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers"); - auto hidden_bytes = hidden_int4 * sizeof(int4); - auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk, num_topk); - auto rdma_channel_data = SymBuffer(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels); - auto rdma_channel_meta = SymBuffer(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels); - auto rdma_channel_head = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); - auto rdma_channel_tail = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); - - auto data_send_offset = sizeof(int8_t) * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) * kNumRDMARanks * channel_id; - auto data_recv_offset = sizeof(int8_t) * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) * kNumRDMARanks * (channel_id + num_channels); - auto meta_offset = sizeof(int8_t) * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) * kNumRDMARanks * num_channels * 2; - auto meta_send_offset = meta_offset + sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2) * kNumRDMARanks * channel_id; - auto meta_recv_offset = meta_offset + sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2) * kNumRDMARanks * (channel_id + num_channels); - auto head_offset = meta_offset + sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2) * kNumRDMARanks * num_channels * 2; - auto head_send_offset = head_offset + sizeof(uint64_t) * kNumRDMARanks * channel_id; - auto tail_offset = head_offset + sizeof(uint64_t) * kNumRDMARanks * num_channels; - auto tail_send_offset = tail_offset + sizeof(uint64_t) * kNumRDMARanks * channel_id; - - // NVL buffer layouts - // NOTES: `rs_wr_buffer_ptr` means "Read for Senders, Write for Receivers", `ws_rr_buffer_ptr` means "Write for Senders, Read for Receivers" - void *rs_wr_buffer_ptr = nullptr, *ws_rr_buffer_ptr = nullptr; - int rs_wr_rank = 0, ws_rr_rank = 0; - if (warp_role == WarpRole::kRDMAAndNVLForwarder) - rs_wr_buffer_ptr = buffer_ptrs[nvl_rank], ws_rr_buffer_ptr = buffer_ptrs[target_rank], rs_wr_rank = nvl_rank, ws_rr_rank = target_rank; - if (warp_role == WarpRole::kNVLReceivers) - rs_wr_buffer_ptr = buffer_ptrs[target_rank], ws_rr_buffer_ptr = buffer_ptrs[nvl_rank], rs_wr_rank = target_rank, ws_rr_rank = nvl_rank; - - // Allocate buffers - auto nvl_channel_x = AsymBuffer(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, 
NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); - auto nvl_channel_src_meta = AsymBuffer(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); - auto nvl_channel_x_scales = AsymBuffer(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_scales, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); - auto nvl_channel_topk_idx = AsymBuffer(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); - auto nvl_channel_topk_weights = AsymBuffer(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); - auto nvl_channel_prefix_start = AsymBuffer(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); - auto nvl_channel_prefix_end = AsymBuffer(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); - auto nvl_channel_head = AsymBuffer(rs_wr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, ws_rr_rank).advance_also(ws_rr_buffer_ptr); - auto nvl_channel_tail = AsymBuffer(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); - - // RDMA sender warp synchronization - __shared__ volatile int rdma_send_next_token_idx; - __shared__ volatile int rdma_send_channel_tail[kNumRDMARanks]; - __shared__ volatile int rdma_send_channel_next_tail[kNumRDMARanks]; - auto sync_rdma_sender_smem = []() { asm volatile("bar.sync 0, %0;" :: "r"((kNumDispatchRDMASenderWarps + 1) * 32)); }; - - // Forward warp synchronization - __shared__ volatile int forward_channel_head[NUM_MAX_NVL_PEERS][kNumRDMARanks]; - __shared__ volatile bool forward_channel_retired[NUM_MAX_NVL_PEERS]; - auto sync_forwarder_smem = []() { asm volatile("bar.sync 1, %0;" :: "r"((NUM_MAX_NVL_PEERS + 1) * 32)); }; - - if (warp_role == WarpRole::kRDMASender) { - // Get tasks - int token_start_idx, token_end_idx; - get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx); - - // Clean shared memory - EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA ranks"); - (warp_id == 0 and lane_id == 0) ? (rdma_send_next_token_idx = token_start_idx) : 0; - (warp_id == 0 and lane_id < kNumRDMARanks) ? (rdma_send_channel_tail[lane_id] = 0) : 0; - (warp_id == 0 and lane_id < kNumRDMARanks) ? (rdma_send_channel_next_tail[lane_id] = 0) : 0; - - // Send number of tokens in this channel by `-value - 1` - EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= 32, "Invalid number of NVL peers"); - for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) { - auto dst_ptr = dst_rdma_rank == rdma_rank ? rdma_channel_meta.recv_buffer(dst_rdma_rank) : rdma_channel_meta.send_buffer(dst_rdma_rank); - if (lane_id < NUM_MAX_NVL_PEERS) { - dst_ptr[lane_id] = -(channel_id == 0 ? 
0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1; - } else if (lane_id < NUM_MAX_NVL_PEERS * 2) { - dst_ptr[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1; - } else if (lane_id == NUM_MAX_NVL_PEERS * 2) { - dst_ptr[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1; - } else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) { - dst_ptr[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1; - } - __syncwarp(); - - if (dst_rdma_rank == rdma_rank) continue; - - // Issue RDMA for non-local ranks - if (lane_id == 0) { - auto num_bytes = sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2); - auto dst_offset = rdma_rank * num_bytes + meta_recv_offset; - auto src_offset = dst_rdma_rank * num_bytes + meta_send_offset; - auto port_channel_idx = kLowLatencyMode ? (channel_id * kNumRDMARanks + dst_rdma_rank) : (channel_id * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank); - port_channel_handles[port_channel_idx].put(dst_offset, src_offset, num_bytes); - // port_channel_handles[port_channel_idx].flush(); - } - __syncwarp(); - } - sync_rdma_sender_smem(); - - // Iterate over tokens and copy into buffer - int64_t token_idx; - int cached_rdma_channel_head = 0, last_rdma_tail_idx = -1; - auto send_buffer = lane_id == rdma_rank ? rdma_channel_data.recv_buffer(lane_id) : rdma_channel_data.send_buffer(lane_id); - for (token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumDispatchRDMASenderWarps) { - // Read RDMA rank existence - uint64_t is_token_in_rank_uint64 = 0; - if (lane_id < kNumRDMARanks) - is_token_in_rank_uint64 = *reinterpret_cast(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS); - - // Acquire sequential lock - while (lane_id == 0 and rdma_send_next_token_idx != token_idx); - __syncwarp(); - - // Acquire next tail - int rdma_tail_idx = -1; - if (is_token_in_rank_uint64 != 0) { - rdma_tail_idx = rdma_send_channel_next_tail[lane_id] ++; - while (rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) - cached_rdma_channel_head = static_cast(ld_volatile_global(rdma_channel_head.buffer(lane_id))); - } - __syncwarp(); - - // Store RDMA head for combine - if (lane_id < kNumRDMARanks and not kCachedMode) - send_rdma_head[token_idx * kNumRDMARanks + lane_id] = rdma_tail_idx; - - // Update last token tail - if (last_rdma_tail_idx >= 0) - st_release_cta(const_cast(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1); - last_rdma_tail_idx = rdma_tail_idx; - - // Release sequential lock - lane_id == 0 ? 
(rdma_send_next_token_idx += 1) : 0; - - // Broadcast tails - SourceMeta src_meta; - int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks]; - void* dst_send_buffers[kNumTopkRDMARanks]; - #pragma unroll - for (int i = 0, slot_idx; i < kNumRDMARanks; ++ i) if ((slot_idx = __shfl_sync(0xffffffff, rdma_tail_idx, i)) >= 0) { - slot_idx = slot_idx % num_max_rdma_chunked_recv_tokens; - topk_ranks[num_topk_ranks] = i; - auto recv_is_token_in_rank_uint64 = broadcast(is_token_in_rank_uint64, i); - auto recv_is_token_in_rank_values = reinterpret_cast(&recv_is_token_in_rank_uint64); - if (lane_id == num_topk_ranks) - src_meta = SourceMeta(rdma_rank, recv_is_token_in_rank_values); - dst_send_buffers[num_topk_ranks ++] = reinterpret_cast(broadcast(send_buffer, i)) + slot_idx * num_bytes_per_rdma_token; - } - EP_DEVICE_ASSERT(num_topk_ranks <= kNumTopkRDMARanks); - - // Copy `x` into symmetric send buffer - auto st_broadcast = [=](const int key, const int4& value) { - #pragma unroll - for (int j = 0; j < num_topk_ranks; ++ j) - st_na_global(reinterpret_cast(dst_send_buffers[j]) + key, value); - }; - UNROLLED_WARP_COPY(5, lane_id, hidden_int4, 0, x + token_idx * hidden_int4, ld_nc_global, st_broadcast); - #pragma unroll - for (int i = 0; i < num_topk_ranks; ++ i) - dst_send_buffers[i] = reinterpret_cast(dst_send_buffers[i]) + hidden_int4; - - // Copy source metadata into symmetric send buffer - if (lane_id < num_topk_ranks) - st_na_global(reinterpret_cast(dst_send_buffers[lane_id]), src_meta); - #pragma unroll - for (int i = 0; i < num_topk_ranks; ++ i) - dst_send_buffers[i] = reinterpret_cast(dst_send_buffers[i]) + 1; - - // Copy `x_scales` into symmetric send buffer - #pragma unroll - for (int i = lane_id; i < num_scales; i += 32) { - auto value = ld_nc_global(x_scales + token_idx * num_scales + i); - #pragma unroll - for (int j = 0; j < num_topk_ranks; ++ j) - st_na_global(reinterpret_cast(dst_send_buffers[j]) + i, value); - } - #pragma unroll - for (int i = 0; i < num_topk_ranks; ++ i) - dst_send_buffers[i] = reinterpret_cast(dst_send_buffers[i]) + num_scales; - - // Copy `topk_idx` and `topk_weights` into symmetric send buffer - #pragma unroll - for (int i = lane_id; i < num_topk * num_topk_ranks; i += 32) { - auto rank_idx = i / num_topk, copy_idx = i % num_topk; - auto idx_value = static_cast(ld_nc_global(topk_idx + token_idx * num_topk + copy_idx)); - auto weight_value = ld_nc_global(topk_weights + token_idx * num_topk + copy_idx); - st_na_global(reinterpret_cast(dst_send_buffers[rank_idx]) + copy_idx, idx_value); - st_na_global(reinterpret_cast(dst_send_buffers[rank_idx]) + num_topk + copy_idx, weight_value); - } - } - - // Epilogue - // Acquire sequential lock - while (lane_id == 0 and rdma_send_next_token_idx != token_idx); - __syncwarp(); - - // Update last token tail - if (last_rdma_tail_idx >= 0) - st_release_cta(const_cast(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1); - - // Release sequential lock - lane_id == 0 ? 
(rdma_send_next_token_idx += 1) : 0; - } else if (warp_role == WarpRole::kRDMASenderCoordinator) { - // NOTES: in case of splitting the issued put at the end of the buffer - EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0); - - // Synchronize shared memory - sync_rdma_sender_smem(); - - // Get number of tokens to send for each RDMA rank - int num_tokens_to_send = 0; - if (lane_id < kNumRDMARanks) { - num_tokens_to_send = rdma_channel_prefix_matrix[lane_id * num_channels + channel_id]; - if (channel_id > 0) - num_tokens_to_send -= rdma_channel_prefix_matrix[lane_id * num_channels + channel_id - 1]; - } - - // Iterate all RDMA ranks - int last_issued_tail = 0; - while (__any_sync(0xffffffff, num_tokens_to_send > 0)) { - #pragma unroll - for (int i = 0; i < kNumRDMARanks; ++i, __syncwarp()) { - // To mitigate incast congestion, shuffle the starting index of target rank for different ranks and channels - const int dst_rdma_rank = (i + channel_id + rdma_rank) % kNumRDMARanks; - if (lane_id != dst_rdma_rank) continue; - if (num_tokens_to_send == 0) continue; - - // Read progress - auto processed_tail = ld_acquire_cta(const_cast(rdma_send_channel_tail + dst_rdma_rank)); - auto num_tokens_processed = processed_tail - last_issued_tail; - if (num_tokens_processed != num_tokens_to_send && num_tokens_processed < num_max_rdma_chunked_send_tokens) - continue; - - // Issue RDMA send - int num_tokens_to_issue = min(num_tokens_processed, num_max_rdma_chunked_send_tokens); - EP_DEVICE_ASSERT(num_tokens_to_issue >= 0 && num_tokens_to_issue <= num_tokens_to_send); - if (num_tokens_to_issue == 0) continue; - - if (dst_rdma_rank == rdma_rank) { - // Update tails - mscclpp::atomicFetchAdd(reinterpret_cast(rdma_channel_tail.buffer(rdma_rank)), (uint64_t)num_tokens_to_issue, mscclpp::memoryOrderRelease); - } else { - const auto dst_slot_idx = last_issued_tail % num_max_rdma_chunked_recv_tokens; - const size_t num_bytes_per_msg = num_bytes_per_rdma_token * num_tokens_to_issue; - const auto dst_offset = rdma_rank * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) + dst_slot_idx * num_bytes_per_rdma_token + data_recv_offset; - const auto src_offset = dst_rdma_rank * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) + dst_slot_idx * num_bytes_per_rdma_token + data_send_offset; - const auto port_channel_idx = kLowLatencyMode ? (channel_id * kNumRDMARanks + dst_rdma_rank) : (channel_id * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank); - auto& handle = port_channel_handles[port_channel_idx]; - handle.put(dst_offset, src_offset, num_bytes_per_msg); - - // Remote atomic add on the peer's tail counter: +num_tokens_to_issue. 
- handle.atomicAdd(rdma_rank * sizeof(uint64_t) + tail_send_offset, (int64_t)num_tokens_to_issue); - // handle.flush(); - } - last_issued_tail += num_tokens_to_issue; - num_tokens_to_send -= num_tokens_to_issue; - } - } - } else if (warp_role == WarpRole::kRDMAAndNVLForwarder) { - // RDMA consumers and NVL producers - const auto dst_nvl_rank = target_rank; - const auto dst_rank = rdma_rank * NUM_MAX_NVL_PEERS + dst_nvl_rank; - const auto dst_rank_expert_begin = dst_rank * (num_experts / num_ranks); - const auto dst_rank_expert_end = dst_rank_expert_begin + (num_experts / num_ranks); - - // Wait counters to arrive - int num_tokens_to_recv_from_rdma = 0, src_rdma_channel_prefix = 0; - EP_DEVICE_ASSERT(kNumRDMARanks <= 32); - auto start_time = clock64(); - if (lane_id < kNumRDMARanks) { - while (true) { - auto meta_0 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + dst_nvl_rank); - auto meta_1 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS + dst_nvl_rank); - auto meta_2 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2); - auto meta_3 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2 + 1); - if (meta_0 < 0 and meta_1 < 0 and meta_2 < 0 and meta_3 < 0) { - // Notify NVL ranks - int start_sum = -meta_0 - 1, end_sum = -meta_1 - 1; - EP_DEVICE_ASSERT(start_sum >= 0 and end_sum >= 0 and end_sum >= start_sum); - st_relaxed_sys_global(nvl_channel_prefix_start.buffer() + lane_id, -start_sum - 1); - st_relaxed_sys_global(nvl_channel_prefix_end.buffer() + lane_id, -end_sum - 1); - - // Save RDMA channel received token count - src_rdma_channel_prefix = -meta_2 - 1; - auto src_rdma_channel_prefix_1 = -meta_3 - 1; - num_tokens_to_recv_from_rdma = src_rdma_channel_prefix_1 - src_rdma_channel_prefix; - if (not kCachedMode) - recv_rdma_channel_prefix_matrix[lane_id * num_channels + channel_id] = src_rdma_channel_prefix_1; - src_rdma_channel_prefix += lane_id == 0 ? 
0 : recv_rdma_rank_prefix_sum[lane_id - 1]; - EP_DEVICE_ASSERT(num_tokens_to_recv_from_rdma >= 0); - break; - } - - // Timeout check - if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { - printf("DeepEP dispatch forwarder timeout (RDMA meta), channel: %d, RDMA: %d, nvl: %d, src RDMA lane: %d, dst NVL: %d, meta: %d, %d, %d, %d\n", - channel_id, rdma_rank, nvl_rank, lane_id, dst_nvl_rank, meta_0, meta_1, meta_2, meta_3); - trap(); - } - } - } - __syncwarp(); - - // Shift cached head - send_nvl_head += src_rdma_channel_prefix * NUM_MAX_NVL_PEERS + dst_nvl_rank; - - // Wait shared memory to be cleaned - sync_forwarder_smem(); - - // Forward tokens from RDMA buffer - // NOTES: always start from the local rank - int src_rdma_rank = sm_id % kNumRDMARanks; - int cached_rdma_channel_head = 0, cached_rdma_channel_tail = 0; - int cached_nvl_channel_head = 0, cached_nvl_channel_tail = 0, rdma_nvl_token_idx = 0; - while (__any_sync(0xffffffff, num_tokens_to_recv_from_rdma > 0)) { - // Check destination queue emptiness, or wait a buffer to be released - start_time = clock64(); - while (lane_id == 0) { - int num_used_slots = cached_nvl_channel_tail - cached_nvl_channel_head; - if (num_max_nvl_chunked_recv_tokens - num_used_slots >= num_max_nvl_chunked_send_tokens) - break; - cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer()); - - // Timeout check - if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { - printf("DeepEP dispatch forwarder timeout (NVL check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, head: %d, tail: %d\n", - channel_id, rdma_rank, nvl_rank, dst_nvl_rank, ld_volatile_global(nvl_channel_head.buffer()), cached_nvl_channel_tail); - trap(); - } - } - __syncwarp(); - - // Find next source RDMA rank (round-robin) - start_time = clock64(); - while (true) { - src_rdma_rank = (src_rdma_rank + 1) % kNumRDMARanks; - if (__shfl_sync(0xffffffff, num_tokens_to_recv_from_rdma, src_rdma_rank) > 0) { - if (lane_id == src_rdma_rank and cached_rdma_channel_head == cached_rdma_channel_tail) - cached_rdma_channel_tail = static_cast(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank))); - if (__shfl_sync(0xffffffff, cached_rdma_channel_tail > cached_rdma_channel_head, src_rdma_rank)) - break; - } - - // Timeout check - if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) { - printf("DeepEP dispatch forwarder timeout (RDMA check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, src RDMA lane: %d, head: %d, tail: %d, expected: %d\n", - channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id, cached_rdma_channel_head, cached_rdma_channel_tail, num_tokens_to_recv_from_rdma); - trap(); - } - } - auto src_rdma_head = __shfl_sync(0xffffffff, cached_rdma_channel_head, src_rdma_rank); - auto src_rdma_tail = __shfl_sync(0xffffffff, cached_rdma_channel_tail, src_rdma_rank); - - // Iterate over every token from the RDMA buffer - for (int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++ i) { - auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens; - void* shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token; - auto src_meta = ld_nc_global(reinterpret_cast(reinterpret_cast(shifted) + hidden_bytes)); - lane_id == src_rdma_rank ? (num_tokens_to_recv_from_rdma -= 1) : 0; - bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank); - if (lane_id == src_rdma_rank) { - auto cached_head = is_in_dst_nvl_rank ? 
rdma_nvl_token_idx : -1; - rdma_nvl_token_idx += is_in_dst_nvl_rank; - if (not kCachedMode) - send_nvl_head[i * NUM_MAX_NVL_PEERS] = cached_head; - } - if (not is_in_dst_nvl_rank) - continue; - - // Get an empty slot - int dst_slot_idx = (cached_nvl_channel_tail ++) % num_max_nvl_chunked_recv_tokens; - - // Copy data - UNROLLED_WARP_COPY(5, lane_id, hidden_int4, - nvl_channel_x.buffer() + dst_slot_idx * hidden_int4, - reinterpret_cast(shifted), - ld_nc_global, st_na_global); - shifted = reinterpret_cast(shifted) + hidden_int4; - - // Copy source meta - if (lane_id == 0) - st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, src_meta); - shifted = reinterpret_cast(shifted) + 1; - - // Copy `x_scales` - UNROLLED_WARP_COPY(1, lane_id, num_scales, - nvl_channel_x_scales.buffer() + dst_slot_idx * num_scales, - reinterpret_cast(shifted), - ld_nc_global, st_na_global); - shifted = reinterpret_cast(shifted) + num_scales; - - // Copy `topk_idx` and `topk_weights` - // NOTES: do not use `shifted` after this `if`, because only several lanes are shifted - if (lane_id < num_topk) { - // Read - auto idx_value = ld_nc_global(reinterpret_cast(shifted) + lane_id); - shifted = reinterpret_cast(shifted) + num_topk; - auto weight_value = ld_nc_global(reinterpret_cast(shifted) + lane_id); - - // Transform and write - idx_value = (idx_value >= dst_rank_expert_begin and idx_value < dst_rank_expert_end) ? idx_value - dst_rank_expert_begin : -1; - st_na_global(nvl_channel_topk_idx.buffer() + dst_slot_idx * num_topk + lane_id, idx_value); - weight_value = idx_value >= 0 ? weight_value : 0.0f; - st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id, weight_value); - } - - // In case of insufficient NVL buffers, early stopping - if ((++ num_tokens_sent) == num_max_nvl_chunked_send_tokens) - src_rdma_tail = i + 1; - } - - // Sync head index - if (lane_id == src_rdma_rank) - forward_channel_head[dst_nvl_rank][src_rdma_rank] = (cached_rdma_channel_head = src_rdma_tail); - - // Move tail index - __syncwarp(); - if (lane_id == 0) - st_release_sys_global(nvl_channel_tail.buffer(), cached_nvl_channel_tail); - } - - // Retired - __syncwarp(); - if (lane_id == 0) - forward_channel_retired[dst_nvl_rank] = true; - } else if (warp_role == WarpRole::kForwarderCoordinator) { - // Extra warps for forwarder coordinator should exit directly - if (target_rank > 0) - return; - - // Forward warp coordinator - EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); - - // Clean shared memory - EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVL peers"); - #pragma unroll - for (int i = lane_id; i < kNumRDMARanks * NUM_MAX_NVL_PEERS; i += 32) - forward_channel_head[i % NUM_MAX_NVL_PEERS][i / NUM_MAX_NVL_PEERS] = 0; - if (lane_id < NUM_MAX_NVL_PEERS) - forward_channel_retired[lane_id] = false; - sync_forwarder_smem(); - - int last_head = 0, target_rdma = lane_id < kNumRDMARanks ? 
lane_id : 0; - while (true) { - // Find minimum head - int min_head = std::numeric_limits::max(); - #pragma unroll - for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i) if (not forward_channel_retired[i]) - min_head = min(min_head, forward_channel_head[i][target_rdma]); - if (__all_sync(0xffffffff, min_head == std::numeric_limits::max())) - break; - - // Update remote head - if (min_head != std::numeric_limits::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) { - if (lane_id == rdma_rank) { - mscclpp::atomicFetchAdd(static_cast(rdma_channel_head.buffer(rdma_rank)), (uint64_t)(min_head - last_head), mscclpp::memoryOrderRelease); - } else { - auto dst_offset = rdma_rank * sizeof(uint64_t) + head_send_offset; - auto port_channel_idx = kLowLatencyMode ? (channel_id * kNumRDMARanks + lane_id) : (channel_id * num_ranks + lane_id * NUM_MAX_NVL_PEERS + nvl_rank); - auto& handle = port_channel_handles[port_channel_idx]; - // Remote atomic add on the peer's head counter. - handle.atomicAdd(dst_offset, (int64_t)(min_head - last_head)); - } - last_head = min_head; - } - - // Nanosleep and let other warps work - __nanosleep(NUM_WAIT_NANOSECONDS); - } + const auto role_meta = [=]() -> std::pair { + if (is_forwarder) { + if (warp_id < NUM_MAX_NVL_PEERS) { + return {WarpRole::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS}; + } else { + return {WarpRole::kForwarderCoordinator, warp_id - NUM_MAX_NVL_PEERS}; + } + } else if (warp_id < kNumDispatchRDMASenderWarps) { + return {WarpRole::kRDMASender, -1}; + } else if (warp_id == kNumDispatchRDMASenderWarps) { + return {WarpRole::kRDMASenderCoordinator, -1}; } else { - // NVL consumers - // Retrieve rank offset from barrier results (each lane's register stores an RDMA rank) - int src_nvl_rank = target_rank, total_offset = 0; - EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); - if (lane_id < kNumRDMARanks and lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank > 0) - total_offset = recv_gbl_rank_prefix_sum[lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank - 1]; - - // Receive channel offsets - int start_offset = 0, end_offset = 0, num_tokens_to_recv; - auto start_time = clock64(); - while (lane_id < kNumRDMARanks) { - start_offset = ld_volatile_global(nvl_channel_prefix_start.buffer() + lane_id); - end_offset = ld_volatile_global(nvl_channel_prefix_end.buffer() + lane_id); - if (start_offset < 0 and end_offset < 0) { - start_offset = -start_offset - 1, end_offset = -end_offset - 1; - total_offset += start_offset; - break; - } - - // Timeout check - if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { - printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, src nvl: %d, start: %d, end: %d\n", - channel_id, rdma_rank, nvl_rank, lane_id, src_nvl_rank, start_offset, end_offset); - trap(); - } - } - num_tokens_to_recv = warp_reduce_sum(end_offset - start_offset); - - // Save for combine usage - if (lane_id < kNumRDMARanks and not kCachedMode) - recv_gbl_channel_prefix_matrix[(lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank) * num_channels + channel_id] = total_offset; - __syncwarp(); - - int cached_channel_head_idx = 0, cached_channel_tail_idx = 0; - while (num_tokens_to_recv > 0) { - // Check channel status by lane 0 - start_time = clock64(); - while (lane_id == 0) { - // Ready to copy - if (cached_channel_head_idx != cached_channel_tail_idx) - break; - cached_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer()); - - // Timeout check - if (clock64() - 
start_time > NUM_TIMEOUT_CYCLES) { - printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, head: %d, tail: %d\n", - channel_id, rdma_rank, nvl_rank, src_nvl_rank, cached_channel_head_idx, cached_channel_tail_idx); - trap(); - } - } - - // Sync queue tail - cached_channel_tail_idx = __shfl_sync(0xffffffff, cached_channel_tail_idx, 0); - - // Copy data - int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx; - for (int chunk_idx = 0; chunk_idx < num_recv_tokens; ++ chunk_idx, -- num_tokens_to_recv) { - int token_idx_in_buffer = (cached_channel_head_idx ++) % num_max_nvl_chunked_recv_tokens; - auto meta = ld_nc_global(nvl_channel_src_meta.buffer() + token_idx_in_buffer); - int64_t recv_token_idx = __shfl_sync(0xffffffff, total_offset, meta.src_rdma_rank); - (lane_id == meta.src_rdma_rank) ? (total_offset += 1) : 0; - - // Copy data - UNROLLED_WARP_COPY(5, lane_id, hidden_int4, - recv_x + recv_token_idx * hidden_int4, - nvl_channel_x.buffer() + token_idx_in_buffer * hidden_int4, - ld_nc_global, st_na_global); - - // Copy source meta - if (lane_id == 0 and not kCachedMode) - st_na_global(recv_src_meta + recv_token_idx, meta); - - // Copy scales - UNROLLED_WARP_COPY(1, lane_id, num_scales, - recv_x_scales + recv_token_idx * num_scales, - nvl_channel_x_scales.buffer() + token_idx_in_buffer * num_scales, - ld_nc_global, st_na_global); - - // Copy `topk_idx` and `topk_weights` - if (lane_id < num_topk) { - auto recv_idx = recv_token_idx * num_topk + lane_id; - auto buffer_idx = token_idx_in_buffer * num_topk + lane_id; - st_na_global(recv_topk_idx + recv_idx, static_cast(ld_nc_global(nvl_channel_topk_idx.buffer() + buffer_idx))); - st_na_global(recv_topk_weights + recv_idx, ld_nc_global(nvl_channel_topk_weights.buffer() + buffer_idx)); - } - } - - // Move queue - __syncwarp(); - if (lane_id == 0) - st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx); - } + return {WarpRole::kNVLReceivers, (warp_id + channel_id - kNumDispatchRDMASenderWarps) % NUM_MAX_NVL_PEERS}; } + }(); + auto warp_role = role_meta.first; + auto target_rank = role_meta.second; // Not applicable for RDMA senders + EP_DEVICE_ASSERT(num_warps == kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS); + + // Data checks + EP_DEVICE_ASSERT(num_topk <= 32); + + // RDMA symmetric layout + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers"); + auto hidden_bytes = hidden_int4 * sizeof(int4); + auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk, num_topk); + auto rdma_channel_data = + SymBuffer(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, + channel_id, num_channels); + auto rdma_channel_meta = + SymBuffer(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels); + auto rdma_channel_head = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); + auto rdma_channel_tail = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); + + auto data_send_offset = + sizeof(int8_t) * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) * kNumRDMARanks * channel_id; + auto data_recv_offset = sizeof(int8_t) * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) * + kNumRDMARanks * (channel_id + num_channels); + auto meta_offset = + sizeof(int8_t) * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) * kNumRDMARanks * num_channels * 2; + auto 
meta_send_offset = meta_offset + sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2) * kNumRDMARanks * channel_id; + auto meta_recv_offset = + meta_offset + sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2) * kNumRDMARanks * (channel_id + num_channels); + auto head_offset = meta_offset + sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2) * kNumRDMARanks * num_channels * 2; + auto head_send_offset = head_offset + sizeof(uint64_t) * kNumRDMARanks * channel_id; + auto tail_offset = head_offset + sizeof(uint64_t) * kNumRDMARanks * num_channels; + auto tail_send_offset = tail_offset + sizeof(uint64_t) * kNumRDMARanks * channel_id; + + // NVL buffer layouts + // NOTES: `rs_wr_buffer_ptr` means "Read for Senders, Write for Receivers", `ws_rr_buffer_ptr` means "Write for + // Senders, Read for Receivers" + void *rs_wr_buffer_ptr = nullptr, *ws_rr_buffer_ptr = nullptr; + int rs_wr_rank = 0, ws_rr_rank = 0; + if (warp_role == WarpRole::kRDMAAndNVLForwarder) + rs_wr_buffer_ptr = buffer_ptrs[nvl_rank], ws_rr_buffer_ptr = buffer_ptrs[target_rank], rs_wr_rank = nvl_rank, + ws_rr_rank = target_rank; + if (warp_role == WarpRole::kNVLReceivers) + rs_wr_buffer_ptr = buffer_ptrs[target_rank], ws_rr_buffer_ptr = buffer_ptrs[nvl_rank], rs_wr_rank = target_rank, + ws_rr_rank = nvl_rank; + + // Allocate buffers + auto nvl_channel_x = AsymBuffer(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, + NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank) + .advance_also(rs_wr_buffer_ptr); + auto nvl_channel_src_meta = AsymBuffer(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens, + NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank) + .advance_also(rs_wr_buffer_ptr); + auto nvl_channel_x_scales = AsymBuffer(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_scales, + NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank) + .advance_also(rs_wr_buffer_ptr); + auto nvl_channel_topk_idx = AsymBuffer(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, + NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank) + .advance_also(rs_wr_buffer_ptr); + auto nvl_channel_topk_weights = AsymBuffer(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, + NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank) + .advance_also(rs_wr_buffer_ptr); + auto nvl_channel_prefix_start = + AsymBuffer(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank) + .advance_also(rs_wr_buffer_ptr); + auto nvl_channel_prefix_end = + AsymBuffer(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank) + .advance_also(rs_wr_buffer_ptr); + auto nvl_channel_head = AsymBuffer(rs_wr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, ws_rr_rank) + .advance_also(ws_rr_buffer_ptr); + auto nvl_channel_tail = AsymBuffer(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank) + .advance_also(rs_wr_buffer_ptr); + + // RDMA sender warp synchronization + __shared__ volatile int rdma_send_next_token_idx; + __shared__ volatile int rdma_send_channel_tail[kNumRDMARanks]; + __shared__ volatile int rdma_send_channel_next_tail[kNumRDMARanks]; + auto sync_rdma_sender_smem = []() { asm volatile("bar.sync 0, %0;" ::"r"((kNumDispatchRDMASenderWarps + 1) * 32)); }; + + // Forward warp synchronization + __shared__ volatile int forward_channel_head[NUM_MAX_NVL_PEERS][kNumRDMARanks]; + __shared__ volatile bool forward_channel_retired[NUM_MAX_NVL_PEERS]; + auto sync_forwarder_smem = []() { asm volatile("bar.sync 1, %0;" ::"r"((NUM_MAX_NVL_PEERS + 1) * 32)); }; + + if 
(warp_role == WarpRole::kRDMASender) { + // Get tasks + int token_start_idx, token_end_idx; + get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx); + + // Clean shared memory + EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA ranks"); + (warp_id == 0 and lane_id == 0) ? (rdma_send_next_token_idx = token_start_idx) : 0; + (warp_id == 0 and lane_id < kNumRDMARanks) ? (rdma_send_channel_tail[lane_id] = 0) : 0; + (warp_id == 0 and lane_id < kNumRDMARanks) ? (rdma_send_channel_next_tail[lane_id] = 0) : 0; + + // Send number of tokens in this channel by `-value - 1` + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= 32, "Invalid number of NVL peers"); + for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) { + auto dst_ptr = dst_rdma_rank == rdma_rank ? rdma_channel_meta.recv_buffer(dst_rdma_rank) + : rdma_channel_meta.send_buffer(dst_rdma_rank); + if (lane_id < NUM_MAX_NVL_PEERS) { + dst_ptr[lane_id] = + -(channel_id == 0 ? 0 + : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + + channel_id - 1]) - + 1; + } else if (lane_id < NUM_MAX_NVL_PEERS * 2) { + dst_ptr[lane_id] = + -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * + num_channels + + channel_id] - + 1; + } else if (lane_id == NUM_MAX_NVL_PEERS * 2) { + dst_ptr[lane_id] = + -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1; + } else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) { + dst_ptr[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1; + } + __syncwarp(); + + if (dst_rdma_rank == rdma_rank) continue; + + // Issue RDMA for non-local ranks + if (lane_id == 0) { + auto num_bytes = sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2); + auto dst_offset = rdma_rank * num_bytes + meta_recv_offset; + auto src_offset = dst_rdma_rank * num_bytes + meta_send_offset; + auto port_channel_idx = kLowLatencyMode + ? (channel_id * kNumRDMARanks + dst_rdma_rank) + : (channel_id * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank); + port_channel_handles[port_channel_idx].put(dst_offset, src_offset, num_bytes); + // port_channel_handles[port_channel_idx].flush(); + } + __syncwarp(); + } + sync_rdma_sender_smem(); + + // Iterate over tokens and copy into buffer + int64_t token_idx; + int cached_rdma_channel_head = 0, last_rdma_tail_idx = -1; + auto send_buffer = + lane_id == rdma_rank ? 
rdma_channel_data.recv_buffer(lane_id) : rdma_channel_data.send_buffer(lane_id); + for (token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumDispatchRDMASenderWarps) { + // Read RDMA rank existence + uint64_t is_token_in_rank_uint64 = 0; + if (lane_id < kNumRDMARanks) + is_token_in_rank_uint64 = + *reinterpret_cast(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS); + + // Acquire sequential lock + while (lane_id == 0 and rdma_send_next_token_idx != token_idx) + ; + __syncwarp(); + + // Acquire next tail + int rdma_tail_idx = -1; + if (is_token_in_rank_uint64 != 0) { + rdma_tail_idx = rdma_send_channel_next_tail[lane_id]++; + while (rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) + cached_rdma_channel_head = static_cast(ld_volatile_global(rdma_channel_head.buffer(lane_id))); + } + __syncwarp(); + + // Store RDMA head for combine + if (lane_id < kNumRDMARanks and not kCachedMode) + send_rdma_head[token_idx * kNumRDMARanks + lane_id] = rdma_tail_idx; + + // Update last token tail + if (last_rdma_tail_idx >= 0) + st_release_cta(const_cast(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1); + last_rdma_tail_idx = rdma_tail_idx; + + // Release sequential lock + lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0; + + // Broadcast tails + SourceMeta src_meta; + int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks]; + void* dst_send_buffers[kNumTopkRDMARanks]; +#pragma unroll + for (int i = 0, slot_idx; i < kNumRDMARanks; ++i) + if ((slot_idx = __shfl_sync(0xffffffff, rdma_tail_idx, i)) >= 0) { + slot_idx = slot_idx % num_max_rdma_chunked_recv_tokens; + topk_ranks[num_topk_ranks] = i; + auto recv_is_token_in_rank_uint64 = broadcast(is_token_in_rank_uint64, i); + auto recv_is_token_in_rank_values = reinterpret_cast(&recv_is_token_in_rank_uint64); + if (lane_id == num_topk_ranks) src_meta = SourceMeta(rdma_rank, recv_is_token_in_rank_values); + dst_send_buffers[num_topk_ranks++] = + reinterpret_cast(broadcast(send_buffer, i)) + slot_idx * num_bytes_per_rdma_token; + } + EP_DEVICE_ASSERT(num_topk_ranks <= kNumTopkRDMARanks); + + // Copy `x` into symmetric send buffer + auto st_broadcast = [=](const int key, const int4& value) { +#pragma unroll + for (int j = 0; j < num_topk_ranks; ++j) + st_na_global(reinterpret_cast(dst_send_buffers[j]) + key, value); + }; + UNROLLED_WARP_COPY(5, lane_id, hidden_int4, 0, x + token_idx * hidden_int4, ld_nc_global, st_broadcast); +#pragma unroll + for (int i = 0; i < num_topk_ranks; ++i) + dst_send_buffers[i] = reinterpret_cast(dst_send_buffers[i]) + hidden_int4; + + // Copy source metadata into symmetric send buffer + if (lane_id < num_topk_ranks) st_na_global(reinterpret_cast(dst_send_buffers[lane_id]), src_meta); +#pragma unroll + for (int i = 0; i < num_topk_ranks; ++i) + dst_send_buffers[i] = reinterpret_cast(dst_send_buffers[i]) + 1; + +// Copy `x_scales` into symmetric send buffer +#pragma unroll + for (int i = lane_id; i < num_scales; i += 32) { + auto value = ld_nc_global(x_scales + token_idx * num_scales + i); +#pragma unroll + for (int j = 0; j < num_topk_ranks; ++j) st_na_global(reinterpret_cast(dst_send_buffers[j]) + i, value); + } +#pragma unroll + for (int i = 0; i < num_topk_ranks; ++i) + dst_send_buffers[i] = reinterpret_cast(dst_send_buffers[i]) + num_scales; + +// Copy `topk_idx` and `topk_weights` into symmetric send buffer +#pragma unroll + for (int i = lane_id; i < num_topk * num_topk_ranks; i += 32) { + auto rank_idx = i / num_topk, copy_idx = i % num_topk; 
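      // (Annotation, not part of the patch.) The loop walks the flattened
      // (destination rank, top-k slot) index space with a stride of 32, one element per lane per
      // pass. Worked example with hypothetical sizes num_topk = 8 and num_topk_ranks = 3: there
      // are 24 elements, so a single pass suffices; lane 0 handles (rank 0, slot 0),
      // lane 9 handles (rank 1, slot 1), lane 23 handles (rank 2, slot 7), and lanes 24-31 skip
      // the body entirely.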
+ auto idx_value = static_cast(ld_nc_global(topk_idx + token_idx * num_topk + copy_idx)); + auto weight_value = ld_nc_global(topk_weights + token_idx * num_topk + copy_idx); + st_na_global(reinterpret_cast(dst_send_buffers[rank_idx]) + copy_idx, idx_value); + st_na_global(reinterpret_cast(dst_send_buffers[rank_idx]) + num_topk + copy_idx, weight_value); + } + } + + // Epilogue + // Acquire sequential lock + while (lane_id == 0 and rdma_send_next_token_idx != token_idx) + ; + __syncwarp(); + + // Update last token tail + if (last_rdma_tail_idx >= 0) + st_release_cta(const_cast(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1); + + // Release sequential lock + lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0; + } else if (warp_role == WarpRole::kRDMASenderCoordinator) { + // NOTES: in case of splitting the issued put at the end of the buffer + EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0); + + // Synchronize shared memory + sync_rdma_sender_smem(); + + // Get number of tokens to send for each RDMA rank + int num_tokens_to_send = 0; + if (lane_id < kNumRDMARanks) { + num_tokens_to_send = rdma_channel_prefix_matrix[lane_id * num_channels + channel_id]; + if (channel_id > 0) num_tokens_to_send -= rdma_channel_prefix_matrix[lane_id * num_channels + channel_id - 1]; + } + + // Iterate all RDMA ranks + int last_issued_tail = 0; + while (__any_sync(0xffffffff, num_tokens_to_send > 0)) { +#pragma unroll + for (int i = 0; i < kNumRDMARanks; ++i, __syncwarp()) { + // To mitigate incast congestion, shuffle the starting index of target rank for different ranks and channels + const int dst_rdma_rank = (i + channel_id + rdma_rank) % kNumRDMARanks; + if (lane_id != dst_rdma_rank) continue; + if (num_tokens_to_send == 0) continue; + + // Read progress + auto processed_tail = ld_acquire_cta(const_cast(rdma_send_channel_tail + dst_rdma_rank)); + auto num_tokens_processed = processed_tail - last_issued_tail; + if (num_tokens_processed != num_tokens_to_send && num_tokens_processed < num_max_rdma_chunked_send_tokens) + continue; + + // Issue RDMA send + int num_tokens_to_issue = min(num_tokens_processed, num_max_rdma_chunked_send_tokens); + EP_DEVICE_ASSERT(num_tokens_to_issue >= 0 && num_tokens_to_issue <= num_tokens_to_send); + if (num_tokens_to_issue == 0) continue; + + if (dst_rdma_rank == rdma_rank) { + // Update tails + mscclpp::atomicFetchAdd(reinterpret_cast(rdma_channel_tail.buffer(rdma_rank)), + (uint64_t)num_tokens_to_issue, mscclpp::memoryOrderRelease); + } else { + const auto dst_slot_idx = last_issued_tail % num_max_rdma_chunked_recv_tokens; + const size_t num_bytes_per_msg = num_bytes_per_rdma_token * num_tokens_to_issue; + const auto dst_offset = rdma_rank * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) + + dst_slot_idx * num_bytes_per_rdma_token + data_recv_offset; + const auto src_offset = dst_rdma_rank * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) + + dst_slot_idx * num_bytes_per_rdma_token + data_send_offset; + const auto port_channel_idx = kLowLatencyMode + ? (channel_id * kNumRDMARanks + dst_rdma_rank) + : (channel_id * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank); + auto& handle = port_channel_handles[port_channel_idx]; + handle.put(dst_offset, src_offset, num_bytes_per_msg); + + // Remote atomic add on the peer's tail counter: +num_tokens_to_issue. 
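          // (Annotation, not part of the patch; a hedged sketch of the flow control this call
          // participates in.) Together with the put() above it forms the producer side of a
          // per-(channel, RDMA peer) ring buffer:
          //   sender:   put(num_tokens_to_issue token slots)     // payload first
          //             remote tail += num_tokens_to_issue       // then publish (this call)
          //   receiver: forwarder warps poll the tail via ld_acquire_sys_global() and consume;
          //             the coordinator warp later bumps the sender's head counter the same way.
          // Back-pressure sits on the sender side: the RDMA-sender warps stall while
          // tail - cached head >= num_max_rdma_chunked_recv_tokens, so the remote receive ring
          // is never overrun.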
+ handle.atomicAdd(rdma_rank * sizeof(uint64_t) + tail_send_offset, (int64_t)num_tokens_to_issue); + // handle.flush(); + } + last_issued_tail += num_tokens_to_issue; + num_tokens_to_send -= num_tokens_to_issue; + } + } + } else if (warp_role == WarpRole::kRDMAAndNVLForwarder) { + // RDMA consumers and NVL producers + const auto dst_nvl_rank = target_rank; + const auto dst_rank = rdma_rank * NUM_MAX_NVL_PEERS + dst_nvl_rank; + const auto dst_rank_expert_begin = dst_rank * (num_experts / num_ranks); + const auto dst_rank_expert_end = dst_rank_expert_begin + (num_experts / num_ranks); + + // Wait counters to arrive + int num_tokens_to_recv_from_rdma = 0, src_rdma_channel_prefix = 0; + EP_DEVICE_ASSERT(kNumRDMARanks <= 32); + auto start_time = clock64(); + if (lane_id < kNumRDMARanks) { + while (true) { + auto meta_0 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + dst_nvl_rank); + auto meta_1 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS + dst_nvl_rank); + auto meta_2 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2); + auto meta_3 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2 + 1); + if (meta_0 < 0 and meta_1 < 0 and meta_2 < 0 and meta_3 < 0) { + // Notify NVL ranks + int start_sum = -meta_0 - 1, end_sum = -meta_1 - 1; + EP_DEVICE_ASSERT(start_sum >= 0 and end_sum >= 0 and end_sum >= start_sum); + st_relaxed_sys_global(nvl_channel_prefix_start.buffer() + lane_id, -start_sum - 1); + st_relaxed_sys_global(nvl_channel_prefix_end.buffer() + lane_id, -end_sum - 1); + + // Save RDMA channel received token count + src_rdma_channel_prefix = -meta_2 - 1; + auto src_rdma_channel_prefix_1 = -meta_3 - 1; + num_tokens_to_recv_from_rdma = src_rdma_channel_prefix_1 - src_rdma_channel_prefix; + if (not kCachedMode) + recv_rdma_channel_prefix_matrix[lane_id * num_channels + channel_id] = src_rdma_channel_prefix_1; + src_rdma_channel_prefix += lane_id == 0 ? 
0 : recv_rdma_rank_prefix_sum[lane_id - 1]; + EP_DEVICE_ASSERT(num_tokens_to_recv_from_rdma >= 0); + break; + } + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf( + "DeepEP dispatch forwarder timeout (RDMA meta), channel: %d, RDMA: %d, nvl: %d, src RDMA lane: %d, dst " + "NVL: %d, meta: %d, %d, %d, %d\n", + channel_id, rdma_rank, nvl_rank, lane_id, dst_nvl_rank, meta_0, meta_1, meta_2, meta_3); + trap(); + } + } + } + __syncwarp(); + + // Shift cached head + send_nvl_head += src_rdma_channel_prefix * NUM_MAX_NVL_PEERS + dst_nvl_rank; + + // Wait shared memory to be cleaned + sync_forwarder_smem(); + + // Forward tokens from RDMA buffer + // NOTES: always start from the local rank + int src_rdma_rank = sm_id % kNumRDMARanks; + int cached_rdma_channel_head = 0, cached_rdma_channel_tail = 0; + int cached_nvl_channel_head = 0, cached_nvl_channel_tail = 0, rdma_nvl_token_idx = 0; + while (__any_sync(0xffffffff, num_tokens_to_recv_from_rdma > 0)) { + // Check destination queue emptiness, or wait a buffer to be released + start_time = clock64(); + while (lane_id == 0) { + int num_used_slots = cached_nvl_channel_tail - cached_nvl_channel_head; + if (num_max_nvl_chunked_recv_tokens - num_used_slots >= num_max_nvl_chunked_send_tokens) break; + cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer()); + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf( + "DeepEP dispatch forwarder timeout (NVL check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, head: %d, " + "tail: %d\n", + channel_id, rdma_rank, nvl_rank, dst_nvl_rank, ld_volatile_global(nvl_channel_head.buffer()), + cached_nvl_channel_tail); + trap(); + } + } + __syncwarp(); + + // Find next source RDMA rank (round-robin) + start_time = clock64(); + while (true) { + src_rdma_rank = (src_rdma_rank + 1) % kNumRDMARanks; + if (__shfl_sync(0xffffffff, num_tokens_to_recv_from_rdma, src_rdma_rank) > 0) { + if (lane_id == src_rdma_rank and cached_rdma_channel_head == cached_rdma_channel_tail) + cached_rdma_channel_tail = static_cast(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank))); + if (__shfl_sync(0xffffffff, cached_rdma_channel_tail > cached_rdma_channel_head, src_rdma_rank)) break; + } + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) { + printf( + "DeepEP dispatch forwarder timeout (RDMA check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, src RDMA " + "lane: %d, head: %d, tail: %d, expected: %d\n", + channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id, cached_rdma_channel_head, + cached_rdma_channel_tail, num_tokens_to_recv_from_rdma); + trap(); + } + } + auto src_rdma_head = __shfl_sync(0xffffffff, cached_rdma_channel_head, src_rdma_rank); + auto src_rdma_tail = __shfl_sync(0xffffffff, cached_rdma_channel_tail, src_rdma_rank); + + // Iterate over every token from the RDMA buffer + for (int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++i) { + auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens; + void* shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token; + auto src_meta = ld_nc_global(reinterpret_cast(reinterpret_cast(shifted) + hidden_bytes)); + lane_id == src_rdma_rank ? (num_tokens_to_recv_from_rdma -= 1) : 0; + bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank); + if (lane_id == src_rdma_rank) { + auto cached_head = is_in_dst_nvl_rank ? 
rdma_nvl_token_idx : -1; + rdma_nvl_token_idx += is_in_dst_nvl_rank; + if (not kCachedMode) send_nvl_head[i * NUM_MAX_NVL_PEERS] = cached_head; + } + if (not is_in_dst_nvl_rank) continue; + + // Get an empty slot + int dst_slot_idx = (cached_nvl_channel_tail++) % num_max_nvl_chunked_recv_tokens; + + // Copy data + UNROLLED_WARP_COPY(5, lane_id, hidden_int4, nvl_channel_x.buffer() + dst_slot_idx * hidden_int4, + reinterpret_cast(shifted), ld_nc_global, st_na_global); + shifted = reinterpret_cast(shifted) + hidden_int4; + + // Copy source meta + if (lane_id == 0) st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, src_meta); + shifted = reinterpret_cast(shifted) + 1; + + // Copy `x_scales` + UNROLLED_WARP_COPY(1, lane_id, num_scales, nvl_channel_x_scales.buffer() + dst_slot_idx * num_scales, + reinterpret_cast(shifted), ld_nc_global, st_na_global); + shifted = reinterpret_cast(shifted) + num_scales; + + // Copy `topk_idx` and `topk_weights` + // NOTES: do not use `shifted` after this `if`, because only several lanes are shifted + if (lane_id < num_topk) { + // Read + auto idx_value = ld_nc_global(reinterpret_cast(shifted) + lane_id); + shifted = reinterpret_cast(shifted) + num_topk; + auto weight_value = ld_nc_global(reinterpret_cast(shifted) + lane_id); + + // Transform and write + idx_value = (idx_value >= dst_rank_expert_begin and idx_value < dst_rank_expert_end) + ? idx_value - dst_rank_expert_begin + : -1; + st_na_global(nvl_channel_topk_idx.buffer() + dst_slot_idx * num_topk + lane_id, idx_value); + weight_value = idx_value >= 0 ? weight_value : 0.0f; + st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id, weight_value); + } + + // In case of insufficient NVL buffers, early stopping + if ((++num_tokens_sent) == num_max_nvl_chunked_send_tokens) src_rdma_tail = i + 1; + } + + // Sync head index + if (lane_id == src_rdma_rank) + forward_channel_head[dst_nvl_rank][src_rdma_rank] = (cached_rdma_channel_head = src_rdma_tail); + + // Move tail index + __syncwarp(); + if (lane_id == 0) st_release_sys_global(nvl_channel_tail.buffer(), cached_nvl_channel_tail); + } + + // Retired + __syncwarp(); + if (lane_id == 0) forward_channel_retired[dst_nvl_rank] = true; + } else if (warp_role == WarpRole::kForwarderCoordinator) { + // Extra warps for forwarder coordinator should exit directly + if (target_rank > 0) return; + + // Forward warp coordinator + EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); + + // Clean shared memory + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVL peers"); +#pragma unroll + for (int i = lane_id; i < kNumRDMARanks * NUM_MAX_NVL_PEERS; i += 32) + forward_channel_head[i % NUM_MAX_NVL_PEERS][i / NUM_MAX_NVL_PEERS] = 0; + if (lane_id < NUM_MAX_NVL_PEERS) forward_channel_retired[lane_id] = false; + sync_forwarder_smem(); + + int last_head = 0, target_rdma = lane_id < kNumRDMARanks ? 
lane_id : 0; + while (true) { + // Find minimum head + int min_head = std::numeric_limits::max(); +#pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) + if (not forward_channel_retired[i]) min_head = min(min_head, forward_channel_head[i][target_rdma]); + if (__all_sync(0xffffffff, min_head == std::numeric_limits::max())) break; + + // Update remote head + if (min_head != std::numeric_limits::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and + lane_id < kNumRDMARanks) { + if (lane_id == rdma_rank) { + mscclpp::atomicFetchAdd(static_cast(rdma_channel_head.buffer(rdma_rank)), + (uint64_t)(min_head - last_head), mscclpp::memoryOrderRelease); + } else { + auto dst_offset = rdma_rank * sizeof(uint64_t) + head_send_offset; + auto port_channel_idx = kLowLatencyMode ? (channel_id * kNumRDMARanks + lane_id) + : (channel_id * num_ranks + lane_id * NUM_MAX_NVL_PEERS + nvl_rank); + auto& handle = port_channel_handles[port_channel_idx]; + // Remote atomic add on the peer's head counter. + handle.atomicAdd(dst_offset, (int64_t)(min_head - last_head)); + } + last_head = min_head; + } + + // Nanosleep and let other warps work + __nanosleep(NUM_WAIT_NANOSECONDS); + } + } else { + // NVL consumers + // Retrieve rank offset from barrier results (each lane's register stores an RDMA rank) + int src_nvl_rank = target_rank, total_offset = 0; + EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); + if (lane_id < kNumRDMARanks and lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank > 0) + total_offset = recv_gbl_rank_prefix_sum[lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank - 1]; + + // Receive channel offsets + int start_offset = 0, end_offset = 0, num_tokens_to_recv; + auto start_time = clock64(); + while (lane_id < kNumRDMARanks) { + start_offset = ld_volatile_global(nvl_channel_prefix_start.buffer() + lane_id); + end_offset = ld_volatile_global(nvl_channel_prefix_end.buffer() + lane_id); + if (start_offset < 0 and end_offset < 0) { + start_offset = -start_offset - 1, end_offset = -end_offset - 1; + total_offset += start_offset; + break; + } + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf( + "DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, src nvl: %d, start: " + "%d, end: %d\n", + channel_id, rdma_rank, nvl_rank, lane_id, src_nvl_rank, start_offset, end_offset); + trap(); + } + } + num_tokens_to_recv = warp_reduce_sum(end_offset - start_offset); + + // Save for combine usage + if (lane_id < kNumRDMARanks and not kCachedMode) + recv_gbl_channel_prefix_matrix[(lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank) * num_channels + channel_id] = + total_offset; + __syncwarp(); + + int cached_channel_head_idx = 0, cached_channel_tail_idx = 0; + while (num_tokens_to_recv > 0) { + // Check channel status by lane 0 + start_time = clock64(); + while (lane_id == 0) { + // Ready to copy + if (cached_channel_head_idx != cached_channel_tail_idx) break; + cached_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer()); + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf( + "DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, head: %d, tail: %d\n", + channel_id, rdma_rank, nvl_rank, src_nvl_rank, cached_channel_head_idx, cached_channel_tail_idx); + trap(); + } + } + + // Sync queue tail + cached_channel_tail_idx = __shfl_sync(0xffffffff, cached_channel_tail_idx, 0); + + // Copy data + int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx; + 
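+ // Drain every token currently visible between the cached head and tail. The lane whose
+ // id equals `meta.src_rdma_rank` owns that source's running output offset
+ // (`total_offset`), so tokens from different RDMA sources land in their
+ // prefix-sum-assigned rows of `recv_x`.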
for (int chunk_idx = 0; chunk_idx < num_recv_tokens; ++chunk_idx, --num_tokens_to_recv) { + int token_idx_in_buffer = (cached_channel_head_idx++) % num_max_nvl_chunked_recv_tokens; + auto meta = ld_nc_global(nvl_channel_src_meta.buffer() + token_idx_in_buffer); + int64_t recv_token_idx = __shfl_sync(0xffffffff, total_offset, meta.src_rdma_rank); + (lane_id == meta.src_rdma_rank) ? (total_offset += 1) : 0; + + // Copy data + UNROLLED_WARP_COPY(5, lane_id, hidden_int4, recv_x + recv_token_idx * hidden_int4, + nvl_channel_x.buffer() + token_idx_in_buffer * hidden_int4, ld_nc_global, st_na_global); + + // Copy source meta + if (lane_id == 0 and not kCachedMode) st_na_global(recv_src_meta + recv_token_idx, meta); + + // Copy scales + UNROLLED_WARP_COPY(1, lane_id, num_scales, recv_x_scales + recv_token_idx * num_scales, + nvl_channel_x_scales.buffer() + token_idx_in_buffer * num_scales, ld_nc_global, + st_na_global); + + // Copy `topk_idx` and `topk_weights` + if (lane_id < num_topk) { + auto recv_idx = recv_token_idx * num_topk + lane_id; + auto buffer_idx = token_idx_in_buffer * num_topk + lane_id; + st_na_global(recv_topk_idx + recv_idx, + static_cast(ld_nc_global(nvl_channel_topk_idx.buffer() + buffer_idx))); + st_na_global(recv_topk_weights + recv_idx, ld_nc_global(nvl_channel_topk_weights.buffer() + buffer_idx)); + } + } + + // Move queue + __syncwarp(); + if (lane_id == 0) st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx); + } + } } void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights, - int* send_rdma_head, int* send_nvl_head, - int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix, - const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum, - const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum, - int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts, - const bool* is_token_in_rank, - void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, - int rank, int num_ranks, bool is_cached_dispatch, - cudaStream_t stream, int num_channels, bool low_latency_mode, - mscclpp::PortChannelDeviceHandle *port_channel_handles, - mscclpp::MemoryChannelDeviceHandle *memory_channel_handles) { - constexpr int kNumDispatchRDMASenderWarps = 7; + int* send_rdma_head, int* send_nvl_head, int* recv_rdma_channel_prefix_matrix, + int* recv_gbl_channel_prefix_matrix, const int* rdma_channel_prefix_matrix, + const int* recv_rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, + const int* recv_gbl_rank_prefix_sum, int num_tokens, int hidden_int4, int num_scales, int num_topk, + int num_experts, const bool* is_token_in_rank, void* rdma_buffer_ptr, + int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, + int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank, int num_ranks, + bool is_cached_dispatch, cudaStream_t stream, int num_channels, bool low_latency_mode, + mscclpp::PortChannelDeviceHandle* port_channel_handles, + mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) { + constexpr int kNumDispatchRDMASenderWarps = 7; -#define DISPATCH_LAUNCH_CASE(num_rdma_ranks) { \ - auto dispatch_func = low_latency_mode ? 
\ - (is_cached_dispatch ? dispatch : dispatch) : \ - (is_cached_dispatch ? dispatch : dispatch); \ - LAUNCH_KERNEL(&cfg, dispatch_func, \ - reinterpret_cast(recv_x), recv_x_scales, recv_topk_idx, recv_topk_weights, reinterpret_cast(recv_src_meta), \ - reinterpret_cast(x), x_scales, topk_idx, topk_weights, \ - send_rdma_head, send_nvl_head, \ - recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix, \ - rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \ - gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \ - num_tokens, hidden_int4, num_scales, num_topk, num_experts, \ - is_token_in_rank, \ - rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, \ - buffer_ptrs, num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, \ - rank, num_ranks, \ - port_channel_handles, memory_channel_handles); } break +#define DISPATCH_LAUNCH_CASE(num_rdma_ranks) \ + { \ + auto dispatch_func = \ + low_latency_mode ? (is_cached_dispatch ? dispatch \ + : dispatch) \ + : (is_cached_dispatch ? dispatch \ + : dispatch); \ + LAUNCH_KERNEL(&cfg, dispatch_func, reinterpret_cast(recv_x), recv_x_scales, recv_topk_idx, \ + recv_topk_weights, reinterpret_cast(recv_src_meta), reinterpret_cast(x), \ + x_scales, topk_idx, topk_weights, send_rdma_head, send_nvl_head, recv_rdma_channel_prefix_matrix, \ + recv_gbl_channel_prefix_matrix, rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \ + gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, num_tokens, hidden_int4, num_scales, num_topk, \ + num_experts, is_token_in_rank, rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, \ + num_max_rdma_chunked_recv_tokens, buffer_ptrs, num_max_nvl_chunked_send_tokens, \ + num_max_nvl_chunked_recv_tokens, rank, num_ranks, port_channel_handles, memory_channel_handles); \ + } \ + break - EP_HOST_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr)); - EP_HOST_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr)); + EP_HOST_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr)); + EP_HOST_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr)); - SETUP_LAUNCH_CONFIG(num_channels * 2, (kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32, stream); - SWITCH_RDMA_RANKS(DISPATCH_LAUNCH_CASE); + SETUP_LAUNCH_CONFIG(num_channels * 2, (kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32, stream); + SWITCH_RDMA_RANKS(DISPATCH_LAUNCH_CASE); #undef DISPATCH_LAUNCH_CASE } template -__global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, - const int nvl_clean_offset, const int nvl_num_int_clean, - int* combined_rdma_head, int num_combined_tokens, int num_channels, - const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head, - void* rdma_buffer_ptr, - void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank, int num_ranks, - bool is_cached_dispatch, - mscclpp::PortChannelDeviceHandle *port_channel_handles, - mscclpp::MemoryChannelDeviceHandle *memory_channel_handles) { - auto sm_id = static_cast(blockIdx.x); - auto thread_id = static_cast(threadIdx.x); - auto num_threads = static_cast(blockDim.x); - auto warp_id = thread_id / 32; - auto lane_id = get_lane_id(); +__global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const int nvl_clean_offset, + const int nvl_num_int_clean, int* combined_rdma_head, int num_combined_tokens, + int num_channels, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, + int* combined_nvl_head, void* 
rdma_buffer_ptr, void** buffer_ptrs, int** task_fifo_ptrs, + int head, int rank, int num_ranks, bool is_cached_dispatch, + mscclpp::PortChannelDeviceHandle* port_channel_handles, + mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) { + auto sm_id = static_cast(blockIdx.x); + auto thread_id = static_cast(threadIdx.x); + auto num_threads = static_cast(blockDim.x); + auto warp_id = thread_id / 32; + auto lane_id = get_lane_id(); - auto rdma_rank = rank / NUM_MAX_NVL_PEERS; - auto nvl_rank = rank % NUM_MAX_NVL_PEERS; - auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; + auto rdma_rank = rank / NUM_MAX_NVL_PEERS; + auto nvl_rank = rank % NUM_MAX_NVL_PEERS; + auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; - // Using two SMs, which clean the RDMA/NVL buffer respectively - if (sm_id == 0) { - // Barrier for RDMA + // Using two SMs, which clean the RDMA/NVL buffer respectively + if (sm_id == 0) { + // Barrier for RDMA - // TODO(chhwang): it should be a global barrier when kLowLatencyMode is false - const bool run_barrier = (threadIdx.x < num_rdma_ranks) && (threadIdx.x != rdma_rank); - const auto barrier_channel_idx = kLowLatencyMode ? threadIdx.x : (threadIdx.x * NUM_MAX_NVL_PEERS + nvl_rank); - if (run_barrier) { - port_channel_handles[barrier_channel_idx].signal(); - port_channel_handles[barrier_channel_idx].wait(); - } - __syncthreads(); - - // Clean - auto rdma_buffer_ptr_int = reinterpret_cast(rdma_buffer_ptr); - #pragma unroll - for (int i = thread_id; i < rdma_num_int_clean; i += num_threads) - rdma_buffer_ptr_int[rdma_clean_offset + i] = 0; - // Make the cleanup visible to the proxy + remote peers before the barrier. - // DeepEP used `nvshmem_fence()` here; we fall back to a system-scope - // threadfence because the actual remote visibility is provided by the - // subsequent port-channel barrier (signal + flush + wait). 
- __threadfence_system(); - __syncthreads(); - - // Barrier again - if (run_barrier) { - port_channel_handles[barrier_channel_idx].signal(); - port_channel_handles[barrier_channel_idx].flush(); - port_channel_handles[barrier_channel_idx].wait(); - } - } else if (sm_id == 1) { - // Barrier for NVL - barrier_device(task_fifo_ptrs, head, nvl_rank); - move_fifo_slots(head); - __syncthreads(); - - // Clean - auto nvl_buffer_ptr_int = reinterpret_cast(buffer_ptrs[nvl_rank]); - #pragma unroll - for (int i = thread_id; i < nvl_num_int_clean; i += num_threads) - nvl_buffer_ptr_int[nvl_clean_offset + i] = 0; - memory_fence(); - __syncthreads(); - - // Barrier again - barrier_device(task_fifo_ptrs, head, nvl_rank); - move_fifo_slots(head); - } else if (sm_id == 2) { - if (is_cached_dispatch) - return; - - EP_DEVICE_ASSERT(num_warps >= num_channels); - EP_DEVICE_ASSERT(num_rdma_ranks <= 32); - - // Iterate in reverse order - if (lane_id < num_rdma_ranks and warp_id < num_channels) { - int token_start_idx, token_end_idx; - get_channel_task_range(num_combined_tokens, num_channels, warp_id, token_start_idx, token_end_idx); - - // NOTES: `1 << 25` is a heuristic large number - int last_head = 1 << 25; - for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; -- token_idx) { - auto current_head = __ldg(combined_rdma_head + token_idx * num_rdma_ranks + lane_id); - if (current_head < 0) { - combined_rdma_head[token_idx * num_rdma_ranks + lane_id] = -last_head - 1; - } else { - last_head = current_head; - } - } - } - } else { - if (is_cached_dispatch) - return; - - EP_DEVICE_ASSERT(num_warps >= num_channels); - EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and rdma_rank_prefix_sum != nullptr); - EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Too many NVL peers"); - - if (lane_id < NUM_MAX_NVL_PEERS and warp_id < num_channels) { - for (int dst_rdma_rank = sm_id - 3; dst_rdma_rank < num_rdma_ranks; dst_rdma_rank += num_channels * 2 - 3) { - // Iterate in reverse order - int token_start_idx = warp_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id - 1]; - int token_end_idx = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id]; - int shift = dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1]; - token_start_idx += shift, token_end_idx += shift; - - // NOTES: `1 << 25` is a heuristic large number - int last_head = 1 << 25; - #pragma unroll - for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; -- token_idx) { - auto current_head = __ldg(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id); - if (current_head < 0) { - combined_nvl_head[token_idx * NUM_MAX_NVL_PEERS + lane_id] = -last_head - 1; - } else { - last_head = current_head; - } - } - } - } + // TODO(chhwang): it should be a global barrier when kLowLatencyMode is false + const bool run_barrier = (threadIdx.x < num_rdma_ranks) && (threadIdx.x != rdma_rank); + const auto barrier_channel_idx = kLowLatencyMode ? threadIdx.x : (threadIdx.x * NUM_MAX_NVL_PEERS + nvl_rank); + if (run_barrier) { + port_channel_handles[barrier_channel_idx].signal(); + port_channel_handles[barrier_channel_idx].wait(); } + __syncthreads(); + + // Clean + auto rdma_buffer_ptr_int = reinterpret_cast(rdma_buffer_ptr); +#pragma unroll + for (int i = thread_id; i < rdma_num_int_clean; i += num_threads) rdma_buffer_ptr_int[rdma_clean_offset + i] = 0; + // Make the cleanup visible to the proxy + remote peers before the barrier. 
+ // DeepEP used `nvshmem_fence()` here; we fall back to a system-scope + // threadfence because the actual remote visibility is provided by the + // subsequent port-channel barrier (signal + flush + wait). + __threadfence_system(); + __syncthreads(); + + // Barrier again + if (run_barrier) { + port_channel_handles[barrier_channel_idx].signal(); + port_channel_handles[barrier_channel_idx].flush(); + port_channel_handles[barrier_channel_idx].wait(); + } + } else if (sm_id == 1) { + // Barrier for NVL + barrier_device(task_fifo_ptrs, head, nvl_rank); + move_fifo_slots(head); + __syncthreads(); + + // Clean + auto nvl_buffer_ptr_int = reinterpret_cast(buffer_ptrs[nvl_rank]); +#pragma unroll + for (int i = thread_id; i < nvl_num_int_clean; i += num_threads) nvl_buffer_ptr_int[nvl_clean_offset + i] = 0; + memory_fence(); + __syncthreads(); + + // Barrier again + barrier_device(task_fifo_ptrs, head, nvl_rank); + move_fifo_slots(head); + } else if (sm_id == 2) { + if (is_cached_dispatch) return; + + EP_DEVICE_ASSERT(num_warps >= num_channels); + EP_DEVICE_ASSERT(num_rdma_ranks <= 32); + + // Iterate in reverse order + if (lane_id < num_rdma_ranks and warp_id < num_channels) { + int token_start_idx, token_end_idx; + get_channel_task_range(num_combined_tokens, num_channels, warp_id, token_start_idx, token_end_idx); + + // NOTES: `1 << 25` is a heuristic large number + int last_head = 1 << 25; + for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) { + auto current_head = __ldg(combined_rdma_head + token_idx * num_rdma_ranks + lane_id); + if (current_head < 0) { + combined_rdma_head[token_idx * num_rdma_ranks + lane_id] = -last_head - 1; + } else { + last_head = current_head; + } + } + } + } else { + if (is_cached_dispatch) return; + + EP_DEVICE_ASSERT(num_warps >= num_channels); + EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and rdma_rank_prefix_sum != nullptr); + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Too many NVL peers"); + + if (lane_id < NUM_MAX_NVL_PEERS and warp_id < num_channels) { + for (int dst_rdma_rank = sm_id - 3; dst_rdma_rank < num_rdma_ranks; dst_rdma_rank += num_channels * 2 - 3) { + // Iterate in reverse order + int token_start_idx = warp_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id - 1]; + int token_end_idx = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id]; + int shift = dst_rdma_rank == 0 ? 
0 : rdma_rank_prefix_sum[dst_rdma_rank - 1]; + token_start_idx += shift, token_end_idx += shift; + + // NOTES: `1 << 25` is a heuristic large number + int last_head = 1 << 25; +#pragma unroll + for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) { + auto current_head = __ldg(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id); + if (current_head < 0) { + combined_nvl_head[token_idx * NUM_MAX_NVL_PEERS + lane_id] = -last_head - 1; + } else { + last_head = current_head; + } + } + } + } + } } -void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, - int num_ranks, int num_channels, int num_combined_tokens, int* combined_rdma_head, +void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, int num_ranks, + int num_channels, int num_combined_tokens, int* combined_rdma_head, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head, - void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, - int** task_fifo_ptrs, int head, int rank, cudaStream_t stream, - int64_t num_rdma_bytes, int64_t num_nvl_bytes, - bool is_cached_dispatch, bool low_latency_mode, - mscclpp::PortChannelDeviceHandle *port_channel_handles, - mscclpp::MemoryChannelDeviceHandle *memory_channel_handles) { - const int num_threads = std::max(128, 32 * num_channels); - const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; + void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, + int num_max_nvl_chunked_recv_tokens, int** task_fifo_ptrs, int head, int rank, cudaStream_t stream, + int64_t num_rdma_bytes, int64_t num_nvl_bytes, bool is_cached_dispatch, bool low_latency_mode, + mscclpp::PortChannelDeviceHandle* port_channel_handles, + mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) { + const int num_threads = std::max(128, 32 * num_channels); + const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; - // Get clean meta - auto rdma_clean_meta = get_rdma_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, num_max_rdma_chunked_recv_tokens, num_channels); - auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels); - EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes); - EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes); - EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits::max()); - EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits::max()); - EP_HOST_ASSERT(num_channels * 2 > 3); + // Get clean meta + auto rdma_clean_meta = get_rdma_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, + num_max_rdma_chunked_recv_tokens, num_channels); + auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, + NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels); + EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes); + EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes); + EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits::max()); + EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits::max()); + EP_HOST_ASSERT(num_channels * 2 > 3); - // Launch kernel - auto cached_notify_func 
= low_latency_mode ? cached_notify : cached_notify; - SETUP_LAUNCH_CONFIG(num_channels * 2, num_threads, stream); - LAUNCH_KERNEL(&cfg, cached_notify_func, - rdma_clean_meta.first, rdma_clean_meta.second, - nvl_clean_meta.first, nvl_clean_meta.second, - combined_rdma_head, num_combined_tokens, num_channels, - rdma_channel_prefix_matrix, rdma_rank_prefix_sum, combined_nvl_head, - rdma_buffer_ptr, - buffer_ptrs, task_fifo_ptrs, head, rank, num_ranks, - is_cached_dispatch, - port_channel_handles, memory_channel_handles); + // Launch kernel + auto cached_notify_func = low_latency_mode ? cached_notify : cached_notify; + SETUP_LAUNCH_CONFIG(num_channels * 2, num_threads, stream); + LAUNCH_KERNEL(&cfg, cached_notify_func, rdma_clean_meta.first, rdma_clean_meta.second, nvl_clean_meta.first, + nvl_clean_meta.second, combined_rdma_head, num_combined_tokens, num_channels, + rdma_channel_prefix_matrix, rdma_rank_prefix_sum, combined_nvl_head, rdma_buffer_ptr, buffer_ptrs, + task_fifo_ptrs, head, rank, num_ranks, is_cached_dispatch, port_channel_handles, + memory_channel_handles); } template -__device__ int combine_token(bool is_token_in_rank, int head_idx, - int lane_id, int hidden_int4, int num_topk, - int4* combined_row, float* combined_topk_weights, - int num_max_recv_tokens, const ReceiveFn& recv_fn, const ReceiveTWFn& recv_tw_fn) { - constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t); +__device__ int combine_token(bool is_token_in_rank, int head_idx, int lane_id, int hidden_int4, int num_topk, + int4* combined_row, float* combined_topk_weights, int num_max_recv_tokens, + const ReceiveFn& recv_fn, const ReceiveTWFn& recv_tw_fn) { + constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t); - // Broadcast current heads - // Lane `i` holds the head of rank `i` and `is_token_in_rank` - EP_STATIC_ASSERT(kMaxNumRanks <= 32, "Too many ranks"); - int num_topk_ranks = 0, topk_ranks[kMaxNumRanks], slot_indices[kMaxNumRanks]; - #pragma unroll - for (int i = 0; i < kNumRanks; ++ i) if (__shfl_sync(0xffffffff, is_token_in_rank, i)) { - slot_indices[num_topk_ranks] = __shfl_sync(0xffffffff, head_idx, i) % num_max_recv_tokens; - topk_ranks[num_topk_ranks ++] = i; + // Broadcast current heads + // Lane `i` holds the head of rank `i` and `is_token_in_rank` + EP_STATIC_ASSERT(kMaxNumRanks <= 32, "Too many ranks"); + int num_topk_ranks = 0, topk_ranks[kMaxNumRanks], slot_indices[kMaxNumRanks]; +#pragma unroll + for (int i = 0; i < kNumRanks; ++i) + if (__shfl_sync(0xffffffff, is_token_in_rank, i)) { + slot_indices[num_topk_ranks] = __shfl_sync(0xffffffff, head_idx, i) % num_max_recv_tokens; + topk_ranks[num_topk_ranks++] = i; } - EP_DEVICE_ASSERT(num_topk_ranks <= kMaxNumRanks); + EP_DEVICE_ASSERT(num_topk_ranks <= kMaxNumRanks); - // Reduce data - #pragma unroll - for (int i = lane_id; i < hidden_int4; i += 32) { - // Read buffers - // TODO: maybe too many registers here - int4 recv_value_int4[kMaxNumRanks]; - #pragma unroll - for (int j = 0; j < num_topk_ranks; ++ j) - recv_value_int4[j] = recv_fn(topk_ranks[j], slot_indices[j], i); +// Reduce data +#pragma unroll + for (int i = lane_id; i < hidden_int4; i += 32) { + // Read buffers + // TODO: maybe too many registers here + int4 recv_value_int4[kMaxNumRanks]; +#pragma unroll + for (int j = 0; j < num_topk_ranks; ++j) recv_value_int4[j] = recv_fn(topk_ranks[j], slot_indices[j], i); - // Reduce all-to-all results - float values[kDtypePerInt4] = {0}; - #pragma unroll - for (int j = 0; j < num_topk_ranks; ++ j) { - auto recv_value_dtypes = 
reinterpret_cast(&recv_value_int4[j]); - #pragma unroll - for (int k = 0; k < kDtypePerInt4; ++ k) - values[k] += static_cast(recv_value_dtypes[k]); - } - - // Cast back to `dtype_t` and write - int4 out_int4; - auto out_dtypes = reinterpret_cast(&out_int4); - #pragma unroll - for (int j = 0; j < kDtypePerInt4; ++ j) - out_dtypes[j] = static_cast(values[j]); - st_na_global(combined_row + i, out_int4); + // Reduce all-to-all results + float values[kDtypePerInt4] = {0}; +#pragma unroll + for (int j = 0; j < num_topk_ranks; ++j) { + auto recv_value_dtypes = reinterpret_cast(&recv_value_int4[j]); +#pragma unroll + for (int k = 0; k < kDtypePerInt4; ++k) values[k] += static_cast(recv_value_dtypes[k]); } - // Reduce `topk_weights` - if (lane_id < num_topk) { - float value = 0; - #pragma unroll - for (int i = 0; i < num_topk_ranks; ++ i) - value += recv_tw_fn(topk_ranks[i], slot_indices[i], lane_id); - st_na_global(combined_topk_weights + lane_id, value); - } + // Cast back to `dtype_t` and write + int4 out_int4; + auto out_dtypes = reinterpret_cast(&out_int4); +#pragma unroll + for (int j = 0; j < kDtypePerInt4; ++j) out_dtypes[j] = static_cast(values[j]); + st_na_global(combined_row + i, out_int4); + } - // Return the minimum top-k rank - return topk_ranks[0]; + // Reduce `topk_weights` + if (lane_id < num_topk) { + float value = 0; +#pragma unroll + for (int i = 0; i < num_topk_ranks; ++i) value += recv_tw_fn(topk_ranks[i], slot_indices[i], lane_id); + st_na_global(combined_topk_weights + lane_id, value); + } + + // Return the minimum top-k rank + return topk_ranks[0]; } -template 0) ? kNumCombineForwarderWarps / kNumRDMARanks : 1, - int kNumForwarders = kNumRDMARanks * kNumWarpsPerForwarder, - int kNumRDMAReceivers = kNumForwarders + NUM_MAX_NVL_PEERS> +template 0) ? 
kNumCombineForwarderWarps / kNumRDMARanks : 1, + int kNumForwarders = kNumRDMARanks* kNumWarpsPerForwarder, + int kNumRDMAReceivers = kNumForwarders + NUM_MAX_NVL_PEERS> __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, 1) -combine(int4* combined_x, float* combined_topk_weights, - const bool* is_combined_token_in_rank, - const int4* x, const float* topk_weights, - const int* combined_rdma_head, const int* combined_nvl_head, - const SourceMeta* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, - int num_tokens, int num_combined_tokens, int hidden, int num_topk, - void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, - int rank, int num_ranks, - mscclpp::PortChannelDeviceHandle *port_channel_handles, - mscclpp::MemoryChannelDeviceHandle *memory_channel_handles) { - enum class WarpRole { - kNVLSender, - kNVLAndRDMAForwarder, - kRDMAReceiver, - kCoordinator - }; + combine(int4* combined_x, float* combined_topk_weights, const bool* is_combined_token_in_rank, const int4* x, + const float* topk_weights, const int* combined_rdma_head, const int* combined_nvl_head, + const SourceMeta* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, + const int* gbl_channel_prefix_matrix, int num_tokens, int num_combined_tokens, int hidden, int num_topk, + void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank, + int num_ranks, mscclpp::PortChannelDeviceHandle* port_channel_handles, + mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) { + enum class WarpRole { kNVLSender, kNVLAndRDMAForwarder, kRDMAReceiver, kCoordinator }; - const auto sm_id = static_cast(blockIdx.x); - const auto thread_id = static_cast(threadIdx.x), lane_id = get_lane_id(); - const auto num_channels = static_cast(gridDim.x) / 2, channel_id = sm_id / 2; - const bool is_rdma_receiver_sm = sm_id % 2 == 1; + const auto sm_id = static_cast(blockIdx.x); + const auto thread_id = static_cast(threadIdx.x), lane_id = get_lane_id(); + const auto num_channels = static_cast(gridDim.x) / 2, channel_id = sm_id / 2; + const bool is_rdma_receiver_sm = sm_id % 2 == 1; - EP_DEVICE_ASSERT(num_topk <= 32); - EP_DEVICE_ASSERT(hidden % (sizeof(int4) / sizeof(dtype_t)) == 0); - const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t)); + EP_DEVICE_ASSERT(num_topk <= 32); + EP_DEVICE_ASSERT(hidden % (sizeof(int4) / sizeof(dtype_t)) == 0); + const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t)); - // NOTES: we decouple a channel into 2 SMs - const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; - auto role_meta = [=]() -> std::pair { - auto warp_id = thread_id / 32; - if (not is_rdma_receiver_sm) { - if (warp_id < NUM_MAX_NVL_PEERS) { - auto shuffled_warp_id = warp_id; - shuffled_warp_id = (shuffled_warp_id + channel_id) % NUM_MAX_NVL_PEERS; - return {WarpRole::kNVLSender, shuffled_warp_id}; - } else if (warp_id < NUM_MAX_NVL_PEERS + kNumForwarders) { - auto shuffled_warp_id = warp_id - NUM_MAX_NVL_PEERS; - shuffled_warp_id = (shuffled_warp_id + channel_id) % kNumForwarders; - return {WarpRole::kNVLAndRDMAForwarder, shuffled_warp_id}; - } else { - return {WarpRole::kCoordinator, 0}; - } - } 
else { - if (warp_id < NUM_MAX_NVL_PEERS + kNumForwarders) { - return {WarpRole::kRDMAReceiver, warp_id}; - } else { - return {WarpRole::kCoordinator, 0}; - } + // NOTES: we decouple a channel into 2 SMs + const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; + auto role_meta = [=]() -> std::pair { + auto warp_id = thread_id / 32; + if (not is_rdma_receiver_sm) { + if (warp_id < NUM_MAX_NVL_PEERS) { + auto shuffled_warp_id = warp_id; + shuffled_warp_id = (shuffled_warp_id + channel_id) % NUM_MAX_NVL_PEERS; + return {WarpRole::kNVLSender, shuffled_warp_id}; + } else if (warp_id < NUM_MAX_NVL_PEERS + kNumForwarders) { + auto shuffled_warp_id = warp_id - NUM_MAX_NVL_PEERS; + shuffled_warp_id = (shuffled_warp_id + channel_id) % kNumForwarders; + return {WarpRole::kNVLAndRDMAForwarder, shuffled_warp_id}; + } else { + return {WarpRole::kCoordinator, 0}; + } + } else { + if (warp_id < NUM_MAX_NVL_PEERS + kNumForwarders) { + return {WarpRole::kRDMAReceiver, warp_id}; + } else { + return {WarpRole::kCoordinator, 0}; + } + } + }(); + auto warp_role = role_meta.first; + auto warp_id = role_meta.second; + + EP_DEVICE_ASSERT(num_warps == NUM_MAX_NVL_PEERS + kNumForwarders + 1); + auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks; + + if (warp_role == WarpRole::kNVLSender) { + // NVL producers + const auto dst_nvl_rank = warp_id; + + // NVL layouts + // NOTES: to avoid deadlocks, we use separate NVL buffers for different RDMA sources + auto dst_buffer_ptr = buffer_ptrs[dst_nvl_rank], local_buffer_ptr = buffer_ptrs[nvl_rank]; + auto nvl_channel_x = AsymBuffer(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, + NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank) + .advance_also(local_buffer_ptr); + auto nvl_channel_src_meta = AsymBuffer(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens, + NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank) + .advance_also(local_buffer_ptr); + auto nvl_channel_topk_weights = AsymBuffer(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, + NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank) + .advance_also(local_buffer_ptr); + auto nvl_channel_head = + AsymBuffer(local_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, dst_nvl_rank) + .advance_also(dst_buffer_ptr); + auto nvl_channel_tail = + AsymBuffer(dst_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank) + .advance_also(local_buffer_ptr); + + // Get tasks for each RDMA lane + int token_start_idx = 0, token_end_idx = 0; + if (lane_id < kNumRDMARanks) { + int prefix_idx = (lane_id * NUM_MAX_NVL_PEERS + dst_nvl_rank) * num_channels + channel_id; + token_start_idx = gbl_channel_prefix_matrix[prefix_idx]; + token_end_idx = + (prefix_idx == num_channels * num_ranks - 1) ? 
num_tokens : gbl_channel_prefix_matrix[prefix_idx + 1]; + } + __syncwarp(); + + // NOTES: here the cached value of each lane is only responsible for a single RDMA buffer + int cached_channel_head_idx = 0, cached_channel_tail_idx = 0; + EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); + + // Iterate over all tokens and send by chunks + while (true) { + // Exit if possible + if (__all_sync(0xffffffff, token_start_idx >= token_end_idx)) break; + + // Decide next RDMA buffer to send + bool is_lane_ready = false; + auto start_time = clock64(); + while (true) { + int num_used_slots = cached_channel_tail_idx - cached_channel_head_idx; + is_lane_ready = lane_id < kNumRDMARanks and token_start_idx < token_end_idx and + num_max_nvl_chunked_recv_tokens_per_rdma - num_used_slots >= num_max_nvl_chunked_send_tokens; + if (__any_sync(0xffffffff, is_lane_ready)) break; + + // Retry + if (lane_id < kNumRDMARanks and token_start_idx < token_end_idx) + cached_channel_head_idx = ld_volatile_global(nvl_channel_head.buffer() + lane_id); + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) { + printf( + "DeepEP combine NVL sender timeout, channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, RDMA lane: %d, head: " + "%d, tail: %d, start: %d, end: %d\n", + channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id, + ld_volatile_global(nvl_channel_head.buffer() + lane_id), cached_channel_tail_idx, token_start_idx, + token_end_idx); + trap(); } - }(); - auto warp_role = role_meta.first; - auto warp_id = role_meta.second; + } - EP_DEVICE_ASSERT(num_warps == NUM_MAX_NVL_PEERS + kNumForwarders + 1); - auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks; + // Sync token start index and count + for (int current_rdma_idx = 0; current_rdma_idx < kNumRDMARanks; ++current_rdma_idx) { + if (__shfl_sync(0xffffffff, (token_start_idx >= token_end_idx) or (not is_lane_ready), current_rdma_idx)) + continue; - if (warp_role == WarpRole::kNVLSender) { - // NVL producers - const auto dst_nvl_rank = warp_id; + // Sync token start index + auto token_idx = static_cast(__shfl_sync(0xffffffff, token_start_idx, current_rdma_idx)); + int num_tokens_in_chunk = __shfl_sync( + 0xffffffff, min(num_max_nvl_chunked_send_tokens, token_end_idx - token_start_idx), current_rdma_idx); - // NVL layouts - // NOTES: to avoid deadlocks, we use separate NVL buffers for different RDMA sources - auto dst_buffer_ptr = buffer_ptrs[dst_nvl_rank], local_buffer_ptr = buffer_ptrs[nvl_rank]; - auto nvl_channel_x = AsymBuffer(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr); - auto nvl_channel_src_meta = AsymBuffer(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr); - auto nvl_channel_topk_weights = AsymBuffer(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr); - auto nvl_channel_head = AsymBuffer(local_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, dst_nvl_rank).advance_also(dst_buffer_ptr); - auto nvl_channel_tail = AsymBuffer(dst_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr); + // Send by chunk + for (int chunk_idx = 0; chunk_idx < num_tokens_in_chunk; ++chunk_idx, ++token_idx) { + // Get an empty 
slot + int dst_slot_idx = 0; + if (lane_id == current_rdma_idx) { + dst_slot_idx = (cached_channel_tail_idx++) % num_max_nvl_chunked_recv_tokens_per_rdma; + dst_slot_idx = current_rdma_idx * num_max_nvl_chunked_recv_tokens_per_rdma + dst_slot_idx; + } + dst_slot_idx = __shfl_sync(0xffffffff, dst_slot_idx, current_rdma_idx); - // Get tasks for each RDMA lane - int token_start_idx = 0, token_end_idx = 0; + // Copy data + auto shifted_x_buffers = nvl_channel_x.buffer() + dst_slot_idx * hidden_int4; + auto shifted_x = x + token_idx * hidden_int4; + UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global); + + // Copy source meta + if (lane_id == 0) + st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, ld_nc_global(src_meta + token_idx)); + + // Copy `topk_weights` + if (lane_id < num_topk) + st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id, + ld_nc_global(topk_weights + token_idx * num_topk + lane_id)); + } + lane_id == current_rdma_idx ? (token_start_idx = static_cast(token_idx)) : 0; + } + + // Move queue tail + __syncwarp(); + if (lane_id < kNumRDMARanks and is_lane_ready) + st_release_sys_global(nvl_channel_tail.buffer() + lane_id, cached_channel_tail_idx); + } + } else { + // Combiners and coordinators + // RDMA symmetric layout + auto hidden_bytes = hidden_int4 * sizeof(int4); + auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, 0, 0, num_topk); + auto rdma_channel_data = + SymBuffer(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, + channel_id, num_channels); + auto rdma_channel_head = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); + auto rdma_channel_tail = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); + + auto data_send_offset = + sizeof(int8_t) * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) * kNumRDMARanks * channel_id; + auto data_recv_offset = sizeof(int8_t) * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) * + kNumRDMARanks * (channel_id + num_channels); + auto head_offset = sizeof(int8_t) * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) * kNumRDMARanks * + num_channels * 2; + auto head_send_offset = head_offset + sizeof(uint64_t) * kNumRDMARanks * channel_id; + auto tail_offset = head_offset + sizeof(uint64_t) * kNumRDMARanks * num_channels; + auto tail_send_offset = tail_offset + sizeof(uint64_t) * kNumRDMARanks * channel_id; + + // NVL layouts + void* local_nvl_buffer = buffer_ptrs[nvl_rank]; + void* nvl_buffers[NUM_MAX_NVL_PEERS]; +#pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) nvl_buffers[i] = buffer_ptrs[i]; + auto nvl_channel_x = AsymBuffer(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * hidden_int4, + NUM_MAX_NVL_PEERS, channel_id, num_channels) + .advance_also(nvl_buffers); + auto nvl_channel_src_meta = AsymBuffer(local_nvl_buffer, num_max_nvl_chunked_recv_tokens, + NUM_MAX_NVL_PEERS, channel_id, num_channels) + .advance_also(nvl_buffers); + auto nvl_channel_topk_weights = AsymBuffer(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * num_topk, + NUM_MAX_NVL_PEERS, channel_id, num_channels) + .advance_also(nvl_buffers); + auto nvl_channel_head = AsymBuffer(nvl_buffers, kNumRDMARanks, NUM_MAX_NVL_PEERS, + channel_id, num_channels, nvl_rank) + .advance_also(local_nvl_buffer); + auto nvl_channel_tail = + AsymBuffer(local_nvl_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels) + 
.advance_also(nvl_buffers); + + // Combiner warp synchronization + __shared__ volatile int forwarder_nvl_head[kNumForwarders][NUM_MAX_NVL_PEERS]; + __shared__ volatile bool forwarder_retired[kNumForwarders]; + __shared__ volatile int rdma_receiver_rdma_head[kNumRDMAReceivers][kNumRDMARanks]; + __shared__ volatile bool rdma_receiver_retired[kNumRDMAReceivers]; + auto sync_forwarder_smem = [=]() { asm volatile("bar.sync 0, %0;" ::"r"((kNumForwarders + 1) * 32)); }; + auto sync_rdma_receiver_smem = [=]() { asm volatile("bar.sync 1, %0;" ::"r"((kNumRDMAReceivers + 1) * 32)); }; + + if (warp_role == WarpRole::kNVLAndRDMAForwarder) { + // Receive from NVL ranks and forward to RDMA ranks + // NOTES: this part is using "large warps" for each RDMA ranks + const auto dst_rdma_rank = warp_id / kNumWarpsPerForwarder; + const auto sub_warp_id = warp_id % kNumWarpsPerForwarder; + auto send_buffer = dst_rdma_rank == rdma_rank ? rdma_channel_data.recv_buffer(dst_rdma_rank) + : rdma_channel_data.send_buffer(dst_rdma_rank); + auto sync_large_warp = [=]() { + if (kNumWarpsPerForwarder == 1) { + __syncwarp(); + } else { + asm volatile("bar.sync %0, %1;" ::"r"(dst_rdma_rank + 2), "r"(kNumWarpsPerForwarder * 32)); + } + }; + EP_STATIC_ASSERT(kNumWarpsPerForwarder == 1 or kNumRDMARanks + 2 <= 16, "Barriers are not enough"); + + // Advance to the corresponding NVL buffer + nvl_channel_x.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * hidden_int4); + nvl_channel_src_meta.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma); + nvl_channel_topk_weights.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * num_topk); + nvl_channel_head.advance(dst_rdma_rank); + nvl_channel_tail.advance(dst_rdma_rank); + + // Clean shared memory and sync + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVL peers"); + lane_id < NUM_MAX_NVL_PEERS ? (forwarder_nvl_head[warp_id][lane_id] = 0) : 0; + lane_id == 0 ? (forwarder_retired[warp_id] = false) : false; + sync_forwarder_smem(); + + // Get count and cached head + int cached_nvl_channel_tail_idx = 0; + int num_tokens_to_combine = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id]; + int num_tokens_prefix = + channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]; + num_tokens_to_combine -= num_tokens_prefix; + num_tokens_prefix += dst_rdma_rank == 0 ? 
0 : rdma_rank_prefix_sum[dst_rdma_rank - 1]; + combined_nvl_head += num_tokens_prefix * NUM_MAX_NVL_PEERS; + + // Iterate over all tokens and combine by chunks + for (int token_start_idx = 0; token_start_idx < num_tokens_to_combine; + token_start_idx += num_max_rdma_chunked_send_tokens) { + // Check destination queue emptiness, or wait a buffer to be released + auto token_end_idx = min(token_start_idx + num_max_rdma_chunked_send_tokens, num_tokens_to_combine); + auto num_chunked_tokens = token_end_idx - token_start_idx; + auto start_time = clock64(); + while (sub_warp_id == 0 and lane_id == 0) { + // Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens` + // Here, `token_start_idx` is the actual tail + int num_used_slots = token_start_idx - ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)); + if (num_max_rdma_chunked_recv_tokens - num_used_slots >= num_chunked_tokens) break; + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf( + "DeepEP combine forwarder (RDMA check) timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA: %d, head: " + "%ld, tail: %d, chunked: %d\n", + channel_id, rdma_rank, nvl_rank, dst_rdma_rank, + ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)), token_start_idx, num_chunked_tokens); + trap(); + } + } + sync_large_warp(); + + // Combine and write to the RDMA buffer + for (int token_idx = token_start_idx + sub_warp_id; token_idx < token_end_idx; + token_idx += kNumWarpsPerForwarder) { + // Read expected head + EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); + int expected_head = -1; + if (lane_id < NUM_MAX_NVL_PEERS) + expected_head = ld_nc_global(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id); + + // Wait lanes to be ready + start_time = clock64(); + while (cached_nvl_channel_tail_idx <= expected_head) { + cached_nvl_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer(lane_id)); + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < NUM_MAX_NVL_PEERS) { + printf( + "DeepEP combine forwarder (NVL check) timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, dst " + "RDMA: %d, tail: %d, waiting: %d, total: %d, sub: %d, large: %d, expected: %d\n", + channel_id, rdma_rank, nvl_rank, lane_id, dst_rdma_rank, cached_nvl_channel_tail_idx, token_idx, + num_tokens_to_combine, sub_warp_id, kNumWarpsPerForwarder, expected_head); + trap(); + } + } + + // Combine current token + auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens; + void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_rdma_token; + auto recv_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4 { + return ld_nc_global(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4 + hidden_int4_idx); + }; + auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { + return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + slot_idx * num_topk + topk_idx); + }; + combine_token( + expected_head >= 0, expected_head, lane_id, hidden_int4, num_topk, reinterpret_cast(shifted), + reinterpret_cast(reinterpret_cast(shifted) + hidden_bytes + sizeof(SourceMeta)), + num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn); + + // Update head + if (lane_id < NUM_MAX_NVL_PEERS) + expected_head < 0 ? 
(forwarder_nvl_head[warp_id][lane_id] = -expected_head - 1) + : (forwarder_nvl_head[warp_id][lane_id] = expected_head + 1); + } + sync_large_warp(); + + // Issue RDMA send + if (sub_warp_id == kNumWarpsPerForwarder - 1) { + if (lane_id == 0) { + if (dst_rdma_rank == rdma_rank) { + mscclpp::atomicFetchAdd(reinterpret_cast(rdma_channel_tail.buffer(rdma_rank)), + (uint64_t)num_chunked_tokens, mscclpp::memoryOrderRelease); + } else { + auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens; + const size_t num_bytes_per_msg = num_chunked_tokens * num_bytes_per_rdma_token; + auto dst_offset = rdma_rank * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) + + rdma_slot_idx * num_bytes_per_rdma_token + data_recv_offset; + auto src_offset = dst_rdma_rank * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) + + rdma_slot_idx * num_bytes_per_rdma_token + data_send_offset; + auto port_channel_idx = kLowLatencyMode + ? (channel_id * kNumRDMARanks + dst_rdma_rank) + : (channel_id * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank); + auto& handle = port_channel_handles[port_channel_idx]; + handle.put(dst_offset, src_offset, num_bytes_per_msg); + + // Remote atomic add on the peer's tail counter: +num_chunked_tokens. + handle.atomicAdd(rdma_rank * sizeof(uint64_t) + tail_send_offset, (int64_t)num_chunked_tokens); + } + } + __syncwarp(); + } + } + + // Retired + __syncwarp(); + if (lane_id == 0) forwarder_retired[warp_id] = true; + } else if (warp_role == WarpRole::kRDMAReceiver) { + // Receive from RDMA ranks and write to the output tensor + // Clean shared memory and sync + EP_DEVICE_ASSERT(kNumRDMARanks <= 32); + lane_id < kNumRDMARanks ? (rdma_receiver_rdma_head[warp_id][lane_id] = 0) : 0; + lane_id == 0 ? (rdma_receiver_retired[warp_id] = false) : 0; + sync_rdma_receiver_smem(); + + // The same tokens as the dispatch process + int token_start_idx, token_end_idx; + get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx, token_end_idx); + + // Iterate over all tokens and combine + int cached_channel_tail_idx = 0; + for (int64_t token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumRDMAReceivers) { + // Read expected head + EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); + int expected_head = -1; if (lane_id < kNumRDMARanks) { - int prefix_idx = (lane_id * NUM_MAX_NVL_PEERS + dst_nvl_rank) * num_channels + channel_id; - token_start_idx = gbl_channel_prefix_matrix[prefix_idx]; - token_end_idx = (prefix_idx == num_channels * num_ranks - 1) ? num_tokens : gbl_channel_prefix_matrix[prefix_idx + 1]; + expected_head = ld_nc_global(combined_rdma_head + token_idx * kNumRDMARanks + lane_id); + (expected_head < 0) ? 
(rdma_receiver_rdma_head[warp_id][lane_id] = -expected_head - 1) + : (rdma_receiver_rdma_head[warp_id][lane_id] = expected_head); + } + + // Wait lanes to be ready + auto start_time = clock64(); + while (cached_channel_tail_idx <= expected_head) { + cached_channel_tail_idx = static_cast(ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id))); + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf( + "DeepEP combine RDMA receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, tail: %d, " + "waiting: %ld, expect: %d\n", + channel_id, rdma_rank, nvl_rank, lane_id, cached_channel_tail_idx, token_idx, expected_head); + trap(); + } } __syncwarp(); - // NOTES: here the cached value of each lane is only responsible for a single RDMA buffer - int cached_channel_head_idx = 0, cached_channel_tail_idx = 0; - EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); + // Combine current token + auto recv_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { + return ld_nc_global(reinterpret_cast(rdma_channel_data.recv_buffer(src_rdma_rank) + + slot_idx * num_bytes_per_rdma_token) + + hidden_int4_idx); + }; + auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { + return ld_nc_global(reinterpret_cast(rdma_channel_data.recv_buffer(src_rdma_rank) + + slot_idx * num_bytes_per_rdma_token + hidden_bytes + + sizeof(SourceMeta)) + + topk_idx); + }; + combine_token( + expected_head >= 0, expected_head, lane_id, hidden_int4, num_topk, combined_x + token_idx * hidden_int4, + combined_topk_weights + token_idx * num_topk, num_max_rdma_chunked_recv_tokens, recv_fn, recv_tw_fn); + } - // Iterate over all tokens and send by chunks - while (true) { - // Exit if possible - if (__all_sync(0xffffffff, token_start_idx >= token_end_idx)) - break; - - // Decide next RDMA buffer to send - bool is_lane_ready = false; - auto start_time = clock64(); - while (true) { - int num_used_slots = cached_channel_tail_idx - cached_channel_head_idx; - is_lane_ready = lane_id < kNumRDMARanks and token_start_idx < token_end_idx and num_max_nvl_chunked_recv_tokens_per_rdma - num_used_slots >= num_max_nvl_chunked_send_tokens; - if (__any_sync(0xffffffff, is_lane_ready)) - break; - - // Retry - if (lane_id < kNumRDMARanks and token_start_idx < token_end_idx) - cached_channel_head_idx = ld_volatile_global(nvl_channel_head.buffer() + lane_id); - - // Timeout check - if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) { - printf("DeepEP combine NVL sender timeout, channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, RDMA lane: %d, head: %d, tail: %d, start: %d, end: %d\n", - channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id, ld_volatile_global(nvl_channel_head.buffer() + lane_id), cached_channel_tail_idx, - token_start_idx, token_end_idx); - trap(); - } - } - - // Sync token start index and count - for (int current_rdma_idx = 0; current_rdma_idx < kNumRDMARanks; ++ current_rdma_idx) { - if (__shfl_sync(0xffffffff, (token_start_idx >= token_end_idx) or (not is_lane_ready), current_rdma_idx)) - continue; - - // Sync token start index - auto token_idx = static_cast(__shfl_sync(0xffffffff, token_start_idx, current_rdma_idx)); - int num_tokens_in_chunk = __shfl_sync(0xffffffff, min(num_max_nvl_chunked_send_tokens, token_end_idx - token_start_idx), current_rdma_idx); - - // Send by chunk - for (int chunk_idx = 0; chunk_idx < num_tokens_in_chunk; ++ chunk_idx, ++ token_idx) { - // Get an empty slot - int dst_slot_idx = 0; - if (lane_id == 
current_rdma_idx) { - dst_slot_idx = (cached_channel_tail_idx ++) % num_max_nvl_chunked_recv_tokens_per_rdma; - dst_slot_idx = current_rdma_idx * num_max_nvl_chunked_recv_tokens_per_rdma + dst_slot_idx; - } - dst_slot_idx = __shfl_sync(0xffffffff, dst_slot_idx, current_rdma_idx); - - // Copy data - auto shifted_x_buffers = nvl_channel_x.buffer() + dst_slot_idx * hidden_int4; - auto shifted_x = x + token_idx * hidden_int4; - UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global); - - // Copy source meta - if (lane_id == 0) - st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, ld_nc_global(src_meta + token_idx)); - - // Copy `topk_weights` - if (lane_id < num_topk) - st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id, ld_nc_global(topk_weights + token_idx * num_topk + lane_id)); - } - lane_id == current_rdma_idx ? (token_start_idx = static_cast(token_idx)) : 0; - } - - // Move queue tail - __syncwarp(); - if (lane_id < kNumRDMARanks and is_lane_ready) - st_release_sys_global(nvl_channel_tail.buffer() + lane_id, cached_channel_tail_idx); - } + // Retired + __syncwarp(); + if (lane_id == 0) rdma_receiver_retired[warp_id] = true; } else { - // Combiners and coordinators - // RDMA symmetric layout - auto hidden_bytes = hidden_int4 * sizeof(int4); - auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, 0, 0, num_topk); - auto rdma_channel_data = SymBuffer(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels); - auto rdma_channel_head = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); - auto rdma_channel_tail = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); + // Coordinator + // Sync shared memory status + is_rdma_receiver_sm ? sync_rdma_receiver_smem() : sync_forwarder_smem(); + const auto num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks; - auto data_send_offset = sizeof(int8_t) * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) * kNumRDMARanks * channel_id; - auto data_recv_offset = sizeof(int8_t) * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) * kNumRDMARanks * (channel_id + num_channels); - auto head_offset = sizeof(int8_t) * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) * kNumRDMARanks * num_channels * 2; - auto head_send_offset = head_offset + sizeof(uint64_t) * kNumRDMARanks * channel_id; - auto tail_offset = head_offset + sizeof(uint64_t) * kNumRDMARanks * num_channels; - auto tail_send_offset = tail_offset + sizeof(uint64_t) * kNumRDMARanks * channel_id; + int last_rdma_head = 0; + int last_nvl_head[kNumRDMARanks] = {0}; + int dst_rdma_rank = lane_id < kNumRDMARanks ? lane_id : 0; + int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? 
lane_id : 0; + EP_STATIC_ASSERT(kNumCombineForwarderWarps <= 32, "Invalid number of forwarder warps"); + while (true) { + // Retired + if (is_rdma_receiver_sm and + __all_sync(0xffffffff, lane_id >= kNumRDMAReceivers or rdma_receiver_retired[lane_id])) + break; + if (not is_rdma_receiver_sm and __all_sync(0xffffffff, lane_id >= kNumForwarders or forwarder_retired[lane_id])) + break; - // NVL layouts - void* local_nvl_buffer = buffer_ptrs[nvl_rank]; - void* nvl_buffers[NUM_MAX_NVL_PEERS]; - #pragma unroll - for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i) - nvl_buffers[i] = buffer_ptrs[i]; - auto nvl_channel_x = AsymBuffer(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also(nvl_buffers); - auto nvl_channel_src_meta = AsymBuffer(local_nvl_buffer, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also(nvl_buffers); - auto nvl_channel_topk_weights = AsymBuffer(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also(nvl_buffers); - auto nvl_channel_head = AsymBuffer(nvl_buffers, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_nvl_buffer); - auto nvl_channel_tail = AsymBuffer(local_nvl_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also(nvl_buffers); - - // Combiner warp synchronization - __shared__ volatile int forwarder_nvl_head[kNumForwarders][NUM_MAX_NVL_PEERS]; - __shared__ volatile bool forwarder_retired[kNumForwarders]; - __shared__ volatile int rdma_receiver_rdma_head[kNumRDMAReceivers][kNumRDMARanks]; - __shared__ volatile bool rdma_receiver_retired[kNumRDMAReceivers]; - auto sync_forwarder_smem = [=]() { asm volatile("bar.sync 0, %0;" :: "r"((kNumForwarders + 1) * 32)); }; - auto sync_rdma_receiver_smem = [=]() { asm volatile("bar.sync 1, %0;" :: "r"((kNumRDMAReceivers + 1) * 32)); }; - - if (warp_role == WarpRole::kNVLAndRDMAForwarder) { - // Receive from NVL ranks and forward to RDMA ranks - // NOTES: this part is using "large warps" for each RDMA ranks - const auto dst_rdma_rank = warp_id / kNumWarpsPerForwarder; - const auto sub_warp_id = warp_id % kNumWarpsPerForwarder; - auto send_buffer = dst_rdma_rank == rdma_rank ? rdma_channel_data.recv_buffer(dst_rdma_rank) : rdma_channel_data.send_buffer(dst_rdma_rank); - auto sync_large_warp = [=]() { - if (kNumWarpsPerForwarder == 1) { - __syncwarp(); - } else { - asm volatile("bar.sync %0, %1;" :: "r"(dst_rdma_rank + 2), "r"(kNumWarpsPerForwarder * 32)); - } - }; - EP_STATIC_ASSERT(kNumWarpsPerForwarder == 1 or kNumRDMARanks + 2 <= 16, "Barriers are not enough"); - - // Advance to the corresponding NVL buffer - nvl_channel_x.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * hidden_int4); - nvl_channel_src_meta.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma); - nvl_channel_topk_weights.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * num_topk); - nvl_channel_head.advance(dst_rdma_rank); - nvl_channel_tail.advance(dst_rdma_rank); - - // Clean shared memory and sync - EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVL peers"); - lane_id < NUM_MAX_NVL_PEERS ? (forwarder_nvl_head[warp_id][lane_id] = 0) : 0; - lane_id == 0 ? 
(forwarder_retired[warp_id] = false) : false; - sync_forwarder_smem(); - - // Get count and cached head - int cached_nvl_channel_tail_idx = 0; - int num_tokens_to_combine = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id]; - int num_tokens_prefix = channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]; - num_tokens_to_combine -= num_tokens_prefix; - num_tokens_prefix += dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1]; - combined_nvl_head += num_tokens_prefix * NUM_MAX_NVL_PEERS; - - // Iterate over all tokens and combine by chunks - for (int token_start_idx = 0; token_start_idx < num_tokens_to_combine; token_start_idx += num_max_rdma_chunked_send_tokens) { - // Check destination queue emptiness, or wait a buffer to be released - auto token_end_idx = min(token_start_idx + num_max_rdma_chunked_send_tokens, num_tokens_to_combine); - auto num_chunked_tokens = token_end_idx - token_start_idx; - auto start_time = clock64(); - while (sub_warp_id == 0 and lane_id == 0) { - // Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens` - // Here, `token_start_idx` is the actual tail - int num_used_slots = token_start_idx - ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)); - if (num_max_rdma_chunked_recv_tokens - num_used_slots >= num_chunked_tokens) - break; - - // Timeout check - if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { - printf("DeepEP combine forwarder (RDMA check) timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA: %d, head: %ld, tail: %d, chunked: %d\n", - channel_id, rdma_rank, nvl_rank, dst_rdma_rank, ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)), token_start_idx, num_chunked_tokens); - trap(); - } - } - sync_large_warp(); - - // Combine and write to the RDMA buffer - for (int token_idx = token_start_idx + sub_warp_id; token_idx < token_end_idx; token_idx += kNumWarpsPerForwarder) { - // Read expected head - EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); - int expected_head = -1; - if (lane_id < NUM_MAX_NVL_PEERS) - expected_head = ld_nc_global(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id); - - // Wait lanes to be ready - start_time = clock64(); - while (cached_nvl_channel_tail_idx <= expected_head) { - cached_nvl_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer(lane_id)); - - // Timeout check - if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < NUM_MAX_NVL_PEERS) { - printf("DeepEP combine forwarder (NVL check) timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, dst RDMA: %d, tail: %d, waiting: %d, total: %d, sub: %d, large: %d, expected: %d\n", - channel_id, rdma_rank, nvl_rank, lane_id, dst_rdma_rank, cached_nvl_channel_tail_idx, token_idx, num_tokens_to_combine, sub_warp_id, kNumWarpsPerForwarder, expected_head); - trap(); - } - } - - // Combine current token - auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens; - void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_rdma_token; - auto recv_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4 + hidden_int4_idx); }; - auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + slot_idx * num_topk + topk_idx); }; - combine_token(expected_head >= 0, - expected_head, lane_id, - hidden_int4, num_topk, - reinterpret_cast(shifted), - 
reinterpret_cast(reinterpret_cast(shifted) + hidden_bytes + sizeof(SourceMeta)), - num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn); - - // Update head - if (lane_id < NUM_MAX_NVL_PEERS) - expected_head < 0 ? (forwarder_nvl_head[warp_id][lane_id] = -expected_head - 1) : (forwarder_nvl_head[warp_id][lane_id] = expected_head + 1); - } - sync_large_warp(); - - // Issue RDMA send - if (sub_warp_id == kNumWarpsPerForwarder - 1) { - if (lane_id == 0) { - if (dst_rdma_rank == rdma_rank) { - mscclpp::atomicFetchAdd(reinterpret_cast(rdma_channel_tail.buffer(rdma_rank)), (uint64_t)num_chunked_tokens, mscclpp::memoryOrderRelease); - } else { - auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens; - const size_t num_bytes_per_msg = num_chunked_tokens * num_bytes_per_rdma_token; - auto dst_offset = rdma_rank * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) + rdma_slot_idx * num_bytes_per_rdma_token + data_recv_offset; - auto src_offset = dst_rdma_rank * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) + rdma_slot_idx * num_bytes_per_rdma_token + data_send_offset; - auto port_channel_idx = kLowLatencyMode ? (channel_id * kNumRDMARanks + dst_rdma_rank) : (channel_id * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank); - auto& handle = port_channel_handles[port_channel_idx]; - handle.put(dst_offset, src_offset, num_bytes_per_msg); - - // Remote atomic add on the peer's tail counter: +num_chunked_tokens. - handle.atomicAdd(rdma_rank * sizeof(uint64_t) + tail_send_offset, (int64_t)num_chunked_tokens); - } - } - __syncwarp(); - } + // Find minimum head for RDMA ranks + if (is_rdma_receiver_sm) { + int min_head = std::numeric_limits::max(); +#pragma unroll + for (int i = 0; i < kNumRDMAReceivers; ++i) + if (not rdma_receiver_retired[i]) min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]); + if (min_head != std::numeric_limits::max() and + min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) { + if (dst_rdma_rank == rdma_rank) { + mscclpp::atomicFetchAdd(static_cast(rdma_channel_head.buffer(rdma_rank)), + (uint64_t)(min_head - last_rdma_head), mscclpp::memoryOrderRelease); + } else { + auto dst_offset = rdma_rank * sizeof(uint64_t) + head_send_offset; + auto port_channel_idx = kLowLatencyMode + ? (channel_id * kNumRDMARanks + dst_rdma_rank) + : (channel_id * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank); + auto& handle = port_channel_handles[port_channel_idx]; + // Remote atomic add on the peer's head counter. + handle.atomicAdd(dst_offset, (int64_t)(min_head - last_rdma_head)); } - - // Retired - __syncwarp(); - if (lane_id == 0) - forwarder_retired[warp_id] = true; - } else if (warp_role == WarpRole::kRDMAReceiver) { - // Receive from RDMA ranks and write to the output tensor - // Clean shared memory and sync - EP_DEVICE_ASSERT(kNumRDMARanks <= 32); - lane_id < kNumRDMARanks ? (rdma_receiver_rdma_head[warp_id][lane_id] = 0) : 0; - lane_id == 0 ? 
(rdma_receiver_retired[warp_id] = false) : 0; - sync_rdma_receiver_smem(); - - // The same tokens as the dispatch process - int token_start_idx, token_end_idx; - get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx, token_end_idx); - - // Iterate over all tokens and combine - int cached_channel_tail_idx = 0; - for (int64_t token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumRDMAReceivers) { - // Read expected head - EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); - int expected_head = -1; - if (lane_id < kNumRDMARanks) { - expected_head = ld_nc_global(combined_rdma_head + token_idx * kNumRDMARanks + lane_id); - (expected_head < 0) ? (rdma_receiver_rdma_head[warp_id][lane_id] = -expected_head - 1) : (rdma_receiver_rdma_head[warp_id][lane_id] = expected_head); - } - - // Wait lanes to be ready - auto start_time = clock64(); - while (cached_channel_tail_idx <= expected_head) { - cached_channel_tail_idx = static_cast(ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id))); - - // Timeout check - if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { - printf("DeepEP combine RDMA receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, tail: %d, waiting: %ld, expect: %d\n", - channel_id, rdma_rank, nvl_rank, lane_id, cached_channel_tail_idx, token_idx, expected_head); - trap(); - } - } - __syncwarp(); - - // Combine current token - auto recv_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx);}; - auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);}; - combine_token(expected_head >= 0, - expected_head, lane_id, - hidden_int4, num_topk, - combined_x + token_idx * hidden_int4, - combined_topk_weights + token_idx * num_topk, - num_max_rdma_chunked_recv_tokens, recv_fn, recv_tw_fn); - } - - // Retired - __syncwarp(); - if (lane_id == 0) - rdma_receiver_retired[warp_id] = true; + last_rdma_head = min_head; + } } else { - // Coordinator - // Sync shared memory status - is_rdma_receiver_sm ? sync_rdma_receiver_smem() : sync_forwarder_smem(); - const auto num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks; - - int last_rdma_head = 0; - int last_nvl_head[kNumRDMARanks] = {0}; - int dst_rdma_rank = lane_id < kNumRDMARanks ? lane_id : 0; - int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? 
lane_id : 0; - EP_STATIC_ASSERT(kNumCombineForwarderWarps <= 32, "Invalid number of forwarder warps"); - while (true) { - // Retired - if (is_rdma_receiver_sm and __all_sync(0xffffffff, lane_id >= kNumRDMAReceivers or rdma_receiver_retired[lane_id])) - break; - if (not is_rdma_receiver_sm and __all_sync(0xffffffff, lane_id >= kNumForwarders or forwarder_retired[lane_id])) - break; - - // Find minimum head for RDMA ranks - if (is_rdma_receiver_sm) { - int min_head = std::numeric_limits::max(); - #pragma unroll - for (int i = 0; i < kNumRDMAReceivers; ++ i) if (not rdma_receiver_retired[i]) - min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]); - if (min_head != std::numeric_limits::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) { - if (dst_rdma_rank == rdma_rank) { - mscclpp::atomicFetchAdd(static_cast(rdma_channel_head.buffer(rdma_rank)), (uint64_t)(min_head - last_rdma_head), mscclpp::memoryOrderRelease); - } else { - auto dst_offset = rdma_rank * sizeof(uint64_t) + head_send_offset; - auto port_channel_idx = kLowLatencyMode ? (channel_id * kNumRDMARanks + dst_rdma_rank) : (channel_id * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank); - auto& handle = port_channel_handles[port_channel_idx]; - // Remote atomic add on the peer's head counter. - handle.atomicAdd(dst_offset, (int64_t)(min_head - last_rdma_head)); - } - last_rdma_head = min_head; - } - } else { - // Find minimum head for NVL ranks - #pragma unroll - for (int i = 0; i < kNumRDMARanks; ++ i) { - int min_head = std::numeric_limits::max(); - #pragma unroll - for (int j = 0; j < num_warps_per_rdma_rank; ++ j) if (not forwarder_retired[i * num_warps_per_rdma_rank + j]) - min_head = min(min_head, forwarder_nvl_head[i * num_warps_per_rdma_rank + j][dst_nvl_rank]); - if (min_head != std::numeric_limits::max() and min_head > last_nvl_head[i] and lane_id < NUM_MAX_NVL_PEERS) - st_relaxed_sys_global(nvl_channel_head.buffer_by(dst_nvl_rank) + i, last_nvl_head[i] = min_head); - } - } - - // Nanosleep and let other warps work - __nanosleep(NUM_WAIT_NANOSECONDS); - } +// Find minimum head for NVL ranks +#pragma unroll + for (int i = 0; i < kNumRDMARanks; ++i) { + int min_head = std::numeric_limits::max(); +#pragma unroll + for (int j = 0; j < num_warps_per_rdma_rank; ++j) + if (not forwarder_retired[i * num_warps_per_rdma_rank + j]) + min_head = min(min_head, forwarder_nvl_head[i * num_warps_per_rdma_rank + j][dst_nvl_rank]); + if (min_head != std::numeric_limits::max() and min_head > last_nvl_head[i] and + lane_id < NUM_MAX_NVL_PEERS) + st_relaxed_sys_global(nvl_channel_head.buffer_by(dst_nvl_rank) + i, last_nvl_head[i] = min_head); + } } + + // Nanosleep and let other warps work + __nanosleep(NUM_WAIT_NANOSECONDS); + } } + } } -void combine(cudaDataType_t type, - void* combined_x, float* combined_topk_weights, - const bool* is_combined_token_in_rank, - const void* x, const float* topk_weights, - const int* combined_rdma_head, const int* combined_nvl_head, - const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, - int num_tokens, int num_combined_tokens, int hidden, int num_topk, +void combine(cudaDataType_t type, void* combined_x, float* combined_topk_weights, const bool* is_combined_token_in_rank, + const void* x, const float* topk_weights, const int* combined_rdma_head, const int* combined_nvl_head, + const void* src_meta, const int* rdma_channel_prefix_matrix, const int* 
rdma_rank_prefix_sum, + const int* gbl_channel_prefix_matrix, int num_tokens, int num_combined_tokens, int hidden, int num_topk, void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, - int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode, - mscclpp::PortChannelDeviceHandle *port_channel_handles, - mscclpp::MemoryChannelDeviceHandle *memory_channel_handles) { - constexpr int kNumCombineForwarderWarps = 16; + void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank, + int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode, + mscclpp::PortChannelDeviceHandle* port_channel_handles, + mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) { + constexpr int kNumCombineForwarderWarps = 16; -#define COMBINE_LAUNCH_CASE(num_rdma_ranks) { \ - auto combine_func = low_latency_mode ? \ - combine : combine; \ - LAUNCH_KERNEL(&cfg, combine_func, \ - reinterpret_cast(combined_x), combined_topk_weights, is_combined_token_in_rank, \ - reinterpret_cast(x), topk_weights, \ - combined_rdma_head, combined_nvl_head, \ - reinterpret_cast(src_meta), rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, \ - num_tokens, num_combined_tokens, hidden, num_topk, \ - rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, \ - buffer_ptrs, num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, \ - rank, num_ranks, \ - port_channel_handles, memory_channel_handles); } break +#define COMBINE_LAUNCH_CASE(num_rdma_ranks) \ + { \ + auto combine_func = low_latency_mode ? combine \ + : combine; \ + LAUNCH_KERNEL(&cfg, combine_func, reinterpret_cast(combined_x), combined_topk_weights, \ + is_combined_token_in_rank, reinterpret_cast(x), topk_weights, combined_rdma_head, \ + combined_nvl_head, reinterpret_cast(src_meta), rdma_channel_prefix_matrix, \ + rdma_rank_prefix_sum, gbl_channel_prefix_matrix, num_tokens, num_combined_tokens, hidden, num_topk, \ + rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, buffer_ptrs, \ + num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, rank, num_ranks, \ + port_channel_handles, memory_channel_handles); \ + } \ + break - int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; - auto num_warps_per_forwarder = std::max(kNumCombineForwarderWarps / num_rdma_ranks, 1); - int num_forwarder_warps = num_rdma_ranks * num_warps_per_forwarder; - EP_HOST_ASSERT(num_forwarder_warps > 0 and num_forwarder_warps % num_rdma_ranks == 0); - EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0); - EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks > std::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens)); - EP_HOST_ASSERT(type == CUDA_R_16BF); + int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; + auto num_warps_per_forwarder = std::max(kNumCombineForwarderWarps / num_rdma_ranks, 1); + int num_forwarder_warps = num_rdma_ranks * num_warps_per_forwarder; + EP_HOST_ASSERT(num_forwarder_warps > 0 and num_forwarder_warps % num_rdma_ranks == 0); + EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0); + EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks > + std::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens)); + EP_HOST_ASSERT(type == CUDA_R_16BF); - 
SETUP_LAUNCH_CONFIG(num_channels * 2, (NUM_MAX_NVL_PEERS + num_forwarder_warps + 1) * 32, stream); - SWITCH_RDMA_RANKS(COMBINE_LAUNCH_CASE); + SETUP_LAUNCH_CONFIG(num_channels * 2, (NUM_MAX_NVL_PEERS + num_forwarder_warps + 1) * 32, stream); + SWITCH_RDMA_RANKS(COMBINE_LAUNCH_CASE); #undef COMBINE_LAUNCH_CASE } -} // namespace internode +} // namespace internode -} // namespace ep -} // namespace mscclpp +} // namespace ep +} // namespace mscclpp diff --git a/src/ext/ep/kernels/internode_ll.cu b/src/ext/ep/kernels/internode_ll.cu index 613f9e63..5c451577 100644 --- a/src/ext/ep/kernels/internode_ll.cu +++ b/src/ext/ep/kernels/internode_ll.cu @@ -26,15 +26,16 @@ // `test/python/ext/ep/test_low_latency_multirank.py`. Performance does NOT // match IBGDA (host-proxy adds latency); see README for measurements. +#include + +#include +#include + #include "configs.cuh" #include "exception.cuh" #include "launch.cuh" #include "utils.cuh" -#include -#include -#include - namespace cg = cooperative_groups; namespace mscclpp { @@ -50,102 +51,91 @@ namespace internode_ll { // pointers passed to the DeepEP LL kernels are virtual addresses aliased into // the caller's symmetric RDMA buffer; MSCCL++ expects offsets. __device__ __forceinline__ uint64_t rdma_offset_of(uint64_t ptr, void* rdma_buffer_ptr) { - return ptr - reinterpret_cast(rdma_buffer_ptr); + return ptr - reinterpret_cast(rdma_buffer_ptr); } // Translate a local pointer aliased into our own rdma_buffer_ptr into the // peer-mapped pointer for rank `peer_rank`. Only valid when the IPC fast path // is active (`peer_rdma_bases != nullptr`). -__device__ __forceinline__ uint64_t peer_ptr_of(uint64_t local_ptr, void* const* peer_rdma_bases, - void* rdma_buffer_ptr, int peer_rank) { - const auto off = local_ptr - reinterpret_cast(rdma_buffer_ptr); - return reinterpret_cast(peer_rdma_bases[peer_rank]) + off; +__device__ __forceinline__ uint64_t peer_ptr_of(uint64_t local_ptr, void* const* peer_rdma_bases, void* rdma_buffer_ptr, + int peer_rank) { + const auto off = local_ptr - reinterpret_cast(rdma_buffer_ptr); + return reinterpret_cast(peer_rdma_bases[peer_rank]) + off; } // Cross-rank barrier via port-channel signal/wait ring. // Uses port channel `qp=0` across all connected peers. -__device__ __forceinline__ void port_channel_barrier_block( - mscclpp::PortChannelDeviceHandle* port_channel_handles, - int rank, int num_ranks) { - const int tid = threadIdx.x; - if (tid < num_ranks && tid != rank) { - // Index: qp 0, peer = tid's rank (assumes peer_idx == rank in LL topology). - port_channel_handles[tid].signal(); - port_channel_handles[tid].wait(); - } - __syncthreads(); +__device__ __forceinline__ void port_channel_barrier_block(mscclpp::PortChannelDeviceHandle* port_channel_handles, + int rank, int num_ranks) { + const int tid = threadIdx.x; + if (tid < num_ranks && tid != rank) { + // Index: qp 0, peer = tid's rank (assumes peer_idx == rank in LL topology). + port_channel_handles[tid].signal(); + port_channel_handles[tid].wait(); + } + __syncthreads(); } // Same barrier but using MemoryChannels (CUDA IPC, NVLink). Used on the // intra-node LL fast path. 
-__device__ __forceinline__ void memory_channel_barrier_block( - mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, - int rank, int num_ranks) { - const int tid = threadIdx.x; - if (tid < num_ranks && tid != rank) { - memory_channel_handles[tid].signal(); - memory_channel_handles[tid].wait(); - } - __syncthreads(); +__device__ __forceinline__ void memory_channel_barrier_block(mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, + int rank, int num_ranks) { + const int tid = threadIdx.x; + if (tid < num_ranks && tid != rank) { + memory_channel_handles[tid].signal(); + memory_channel_handles[tid].wait(); + } + __syncthreads(); } template -__device__ __forceinline__ void ll_barrier_block( - mscclpp::PortChannelDeviceHandle* port_channel_handles, - mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, - int rank, int num_ranks) { - if constexpr (kIpcPath) { - memory_channel_barrier_block(memory_channel_handles, rank, num_ranks); - } else { - port_channel_barrier_block(port_channel_handles, rank, num_ranks); - } +__device__ __forceinline__ void ll_barrier_block(mscclpp::PortChannelDeviceHandle* port_channel_handles, + mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, int rank, + int num_ranks) { + if constexpr (kIpcPath) { + memory_channel_barrier_block(memory_channel_handles, rank, num_ranks); + } else { + port_channel_barrier_block(port_channel_handles, rank, num_ranks); + } } // --------------------------------------------------------------------------- // clean_low_latency_buffer // --------------------------------------------------------------------------- -template __launch_bounds__(kNumThreads, 1) -__global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, - int64_t* clean_1, int num_clean_int_1, - mscclpp::PortChannelDeviceHandle* port_channel_handles, - mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, - int rank, int num_ranks) { - // Barrier before cleaning (in case of unfinished chunked EP) - ll_barrier_block(port_channel_handles, memory_channel_handles, rank, num_ranks); +template +__launch_bounds__(kNumThreads, 1) __global__ + void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, int64_t* clean_1, int num_clean_int_1, + mscclpp::PortChannelDeviceHandle* port_channel_handles, + mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, int rank, int num_ranks) { + // Barrier before cleaning (in case of unfinished chunked EP) + ll_barrier_block(port_channel_handles, memory_channel_handles, rank, num_ranks); - // Clean - auto thread_id = static_cast(threadIdx.x); - #pragma unroll - for (int i = thread_id; i < num_clean_int_0; i += kNumThreads) - clean_0[i] = 0; - #pragma unroll - for (int i = thread_id; i < num_clean_int_1; i += kNumThreads) - clean_1[i] = 0; + // Clean + auto thread_id = static_cast(threadIdx.x); +#pragma unroll + for (int i = thread_id; i < num_clean_int_0; i += kNumThreads) clean_0[i] = 0; +#pragma unroll + for (int i = thread_id; i < num_clean_int_1; i += kNumThreads) clean_1[i] = 0; - // Barrier after cleaning (make sure low-latency mode work fine) - ll_barrier_block(port_channel_handles, memory_channel_handles, rank, num_ranks); + // Barrier after cleaning (make sure low-latency mode work fine) + ll_barrier_block(port_channel_handles, memory_channel_handles, rank, num_ranks); } -void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, - int64_t* clean_1, int num_clean_int_1, - int rank, int num_ranks, - mscclpp::PortChannelDeviceHandle* port_channel_handles, - 
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, - bool use_ipc_path, +void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, int64_t* clean_1, int num_clean_int_1, int rank, + int num_ranks, mscclpp::PortChannelDeviceHandle* port_channel_handles, + mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path, cudaStream_t stream) { - constexpr int kNumThreads = 256; + constexpr int kNumThreads = 256; - SETUP_LAUNCH_CONFIG(1, kNumThreads, stream); - if (use_ipc_path) { - LAUNCH_KERNEL(&cfg, (clean_low_latency_buffer), - clean_0, num_clean_int_0, clean_1, num_clean_int_1, - port_channel_handles, memory_channel_handles, rank, num_ranks); - } else { - LAUNCH_KERNEL(&cfg, (clean_low_latency_buffer), - clean_0, num_clean_int_0, clean_1, num_clean_int_1, - port_channel_handles, memory_channel_handles, rank, num_ranks); - } + SETUP_LAUNCH_CONFIG(1, kNumThreads, stream); + if (use_ipc_path) { + LAUNCH_KERNEL(&cfg, (clean_low_latency_buffer), clean_0, num_clean_int_0, clean_1, + num_clean_int_1, port_channel_handles, memory_channel_handles, rank, num_ranks); + } else { + LAUNCH_KERNEL(&cfg, (clean_low_latency_buffer), clean_0, num_clean_int_0, clean_1, + num_clean_int_1, port_channel_handles, memory_channel_handles, rank, num_ranks); + } } // --------------------------------------------------------------------------- @@ -153,362 +143,333 @@ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, // --------------------------------------------------------------------------- template -__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void -dispatch(void* packed_recv_x, float* packed_recv_x_scales, - int* packed_recv_src_info, int64_t* packed_recv_layout_range, - int* packed_recv_count, - void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x, - const void* x, const int64_t* topk_idx, - int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert, - int64_t* next_clean, int num_next_clean_int, - int num_tokens, int num_max_dispatch_tokens_per_rank, - int num_topk, int num_experts, int rank, int num_ranks, - int phases, - void* rdma_buffer_ptr, - mscclpp::PortChannelDeviceHandle* port_channel_handles, - void* const* peer_rdma_bases, - mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) { - const auto sm_id = static_cast(blockIdx.x); - const auto thread_id = static_cast(threadIdx.x); - const auto warp_id = thread_id / 32, lane_id = get_lane_id(); - const auto num_sms = static_cast(gridDim.x); - const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup; - const auto num_local_experts = num_experts / num_ranks; - const auto warp_group_id = warp_id / kNumWarpsPerGroup; - const auto sub_warp_id = warp_id % kNumWarpsPerGroup; - const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id; +__global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dispatch( + void* packed_recv_x, float* packed_recv_x_scales, int* packed_recv_src_info, int64_t* packed_recv_layout_range, + int* packed_recv_count, void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x, const void* x, + const int64_t* topk_idx, int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert, int64_t* next_clean, + int num_next_clean_int, int num_tokens, int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts, + int rank, int num_ranks, int phases, void* rdma_buffer_ptr, mscclpp::PortChannelDeviceHandle* port_channel_handles, + void* const* peer_rdma_bases, mscclpp::MemoryChannelDeviceHandle* 
memory_channel_handles) {
+  const auto sm_id = static_cast<int>(blockIdx.x);
+  const auto thread_id = static_cast<int>(threadIdx.x);
+  const auto warp_id = thread_id / 32, lane_id = get_lane_id();
+  const auto num_sms = static_cast<int>(gridDim.x);
+  const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
+  const auto num_local_experts = num_experts / num_ranks;
+  const auto warp_group_id = warp_id / kNumWarpsPerGroup;
+  const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
+  const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;

-    // FP8 staffs
-    constexpr int kNumPerChannels = 128;
-    constexpr float kFP8Margin = 1e-4, kFP8Amax = 448, kFP8AmaxInv = 1.0f / 448.0f;
-    const int num_scales = kHidden / kNumPerChannels;
-    const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16));
-    const size_t hidden_int4 = hidden_bytes / sizeof(int4);
+  // FP8 stuff
+  constexpr int kNumPerChannels = 128;
+  constexpr float kFP8Margin = 1e-4, kFP8Amax = 448, kFP8AmaxInv = 1.0f / 448.0f;
+  const int num_scales = kHidden / kNumPerChannels;
+  const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16));
+  const size_t hidden_int4 = hidden_bytes / sizeof(int4);

-    // Message package: hidden data, FP8 scales, index at source
-    using vec_t = typename std::conditional<kUseFP8, int2, int4>::type;
-    const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(nv_bfloat16)));
-    const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
-    EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
+  // Message package: hidden data, FP8 scales, index at source
+  using vec_t = typename std::conditional<kUseFP8, int2, int4>::type;
+  const size_t num_bytes_per_msg =
+      sizeof(int4) + (kUseFP8 ?
(kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(nv_bfloat16)));
+  const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
+  EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);

-    // Sending phase
-    if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
-        goto LOW_LATENCY_DISPATCH_RECV;
+  // Sending phase
+  if ((phases & LOW_LATENCY_SEND_PHASE) == 0) goto LOW_LATENCY_DISPATCH_RECV;

-    // Expert counts
-    __shared__ int shared_num_tokens_sent_per_expert[kNumWarpGroups];
+  // Expert counts
+  __shared__ int shared_num_tokens_sent_per_expert[kNumWarpGroups];

-    if (warp_id < num_warps - 1) {
-        constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(nv_bfloat16);
-        EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
-        EP_STATIC_ASSERT(kNumElemsPerRead * 32 % kNumPerChannels == 0, "Invalid vectorization");
-        const auto num_threads = (num_warps - 1) * 32;
-        const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead;
+  if (warp_id < num_warps - 1) {
+    constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(nv_bfloat16);
+    EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
+    EP_STATIC_ASSERT(kNumElemsPerRead * 32 % kNumPerChannels == 0, "Invalid vectorization");
+    const auto num_threads = (num_warps - 1) * 32;
+    const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead;

-        for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
-            const auto x_int4 = reinterpret_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
-            const auto rdma_x_src_idx = reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
-            const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
-            const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);
+    for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
+      const auto x_int4 = reinterpret_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
+      const auto rdma_x_src_idx =
+          reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
+      const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
+      const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);

-            auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
-            thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
+      auto dst_expert_idx =
+          warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
+      thread_id == 0 ?
(*rdma_x_src_idx = token_idx) : 0; - // FP8 cast - #pragma unroll - for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) { - auto int4_value = __ldg(x_int4 + i); +// FP8 cast +#pragma unroll + for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) { + auto int4_value = __ldg(x_int4 + i); - if (kUseFP8) { - auto bf16_values = reinterpret_cast(&int4_value); - float fp32_values[kNumElemsPerRead]; - float amax = kFP8Margin, scale, scale_inv; - #pragma unroll - for (int j = 0; j < kNumElemsPerRead; ++ j) { - fp32_values[j] = static_cast(bf16_values[j]); - amax = fmaxf(amax, fabsf(fp32_values[j])); - } + if (kUseFP8) { + auto bf16_values = reinterpret_cast(&int4_value); + float fp32_values[kNumElemsPerRead]; + float amax = kFP8Margin, scale, scale_inv; +#pragma unroll + for (int j = 0; j < kNumElemsPerRead; ++j) { + fp32_values[j] = static_cast(bf16_values[j]); + amax = fmaxf(amax, fabsf(fp32_values[j])); + } - EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization"); - amax = half_warp_reduce_max(amax), scale = kFP8Amax / amax, scale_inv = amax * kFP8AmaxInv; - if (lane_id == 0 or lane_id == 16) - rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv; + EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization"); + amax = half_warp_reduce_max(amax), scale = kFP8Amax / amax, scale_inv = amax * kFP8AmaxInv; + if (lane_id == 0 or lane_id == 16) rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv; - vec_t int2_value; - auto fp8x2_values = reinterpret_cast<__nv_fp8x2_storage_t*>(&int2_value); - #pragma unroll - for (int j = 0; j < kNumElemsPerRead; j += 2) { - float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale}; - fp8x2_values[j / 2] = __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3); - } - rdma_x_vec[i] = int2_value; - } else { - rdma_x_vec[i] = *reinterpret_cast(&int4_value); - } - } - asm volatile("bar.sync 1, %0;" :: "r"(num_threads)); - - // Issue sends - if (dst_expert_idx >= 0) { - int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0; - slot_idx = __shfl_sync(0xffffffff, slot_idx, 0); - const auto dst_rank = dst_expert_idx / num_local_experts; - const auto dst_expert_local_idx = dst_expert_idx % num_local_experts; - const auto src_ptr = reinterpret_cast(rdma_x_src_idx); - const auto dst_ptr = reinterpret_cast(rdma_recv_x) + - dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + - rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + - slot_idx * num_bytes_per_msg; - if (dst_rank != rank) { - if constexpr (kIpcPath) { - // Peer-mapped warp copy over NVLink (CUDA IPC). - const auto peer_dst = peer_ptr_of(dst_ptr, peer_rdma_bases, rdma_buffer_ptr, dst_rank); - const auto* src_int4_ptr = reinterpret_cast(src_ptr); - const auto* dst_int4_ptr = reinterpret_cast(peer_dst); - UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global); - } else { - // MSCCL++ port-channel PUT (lane 0 issues one request). 
- if (lane_id == 0) { - const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr); - const auto src_off = rdma_offset_of(src_ptr, rdma_buffer_ptr); - port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank] - .put(dst_off, src_off, num_bytes_per_msg); - } - __syncwarp(); - } - } else { - const auto* src_int4_ptr = reinterpret_cast(src_ptr); - const auto* dst_int4_ptr = reinterpret_cast(dst_ptr); - UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global); - } - - __syncwarp(); - lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0; - } - } - } else if (warp_id == num_warps - 1) { - EP_DEVICE_ASSERT(num_sms > 1); - if (sm_id == 0) { - // NOTE: DeepEP asserts `ibgda_get_state()->num_rc_per_pe >= num_local_experts` - // here. The MSCCL++ port relies on Buffer::sync() provisioning enough QPs - // (see `num_ib_connections_per_rank` / `num_port_channels_per_rank`). - - #pragma unroll - for (int i = lane_id; i < num_next_clean_int; i += 32) - next_clean[i] = 0; - - __syncwarp(); - #pragma unroll - for (int i = lane_id; i < num_experts; i += 32) - atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG); - } - - int expert_count[kNumWarpGroups] = {0}; - const auto expert_begin_idx = sm_id * kNumWarpGroups; - const auto expert_end_idx = min(expert_begin_idx + kNumWarpGroups, num_experts); - - #pragma unroll 8 - for (int i = lane_id; i < num_tokens * num_topk; i += 32) { - auto idx = static_cast(__ldg(topk_idx + i)); - if (idx >= expert_begin_idx and idx < expert_end_idx) - expert_count[idx - expert_begin_idx] ++; - } - - #pragma unroll - for (int i = expert_begin_idx; i < expert_end_idx; ++ i) { - auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]); - if (lane_id == 0) { - shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum; - atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum); - } - } - } - __syncthreads(); - - // Issue count sends - if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) { - const auto dst_rank = responsible_expert_idx / num_local_experts; - const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts; - const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * kNumWarpGroups]; - - while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2); - if (dst_rank != rank) { - if constexpr (kIpcPath) { - // Single writer per (dst_expert_local_idx, rank) slot, so a - // release-ordered store on the peer-mapped counter is enough. 
- auto peer_counter = reinterpret_cast(peer_ptr_of( - reinterpret_cast(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), - peer_rdma_bases, rdma_buffer_ptr, dst_rank)); - st_na_release(peer_counter, static_cast(-num_tokens_sent - 1)); - } else { - auto* counter_ptr = rdma_recv_count + dst_expert_local_idx * num_ranks + rank; - const auto off = rdma_offset_of(reinterpret_cast(counter_ptr), rdma_buffer_ptr); - port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank] - .atomicAdd(off, static_cast(-num_tokens_sent - 1)); - } + vec_t int2_value; + auto fp8x2_values = reinterpret_cast<__nv_fp8x2_storage_t*>(&int2_value); +#pragma unroll + for (int j = 0; j < kNumElemsPerRead; j += 2) { + float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale}; + fp8x2_values[j / 2] = __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3); + } + rdma_x_vec[i] = int2_value; } else { - st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, - static_cast(-num_tokens_sent - 1)); + rdma_x_vec[i] = *reinterpret_cast(&int4_value); } + } + asm volatile("bar.sync 1, %0;" ::"r"(num_threads)); - atomic_counter_per_expert[responsible_expert_idx] = 0; - atomic_finish_counter_per_expert[responsible_expert_idx] = 0; - - if (dst_rank == 0) - packed_recv_count[dst_expert_local_idx] = 0; - } - __syncwarp(); - - // Receiving phase - LOW_LATENCY_DISPATCH_RECV: - if ((phases & LOW_LATENCY_RECV_PHASE) == 0) - return; - - if (phases & LOW_LATENCY_SEND_PHASE) - cg::this_grid().sync(); - - if (responsible_expert_idx < num_experts) { - const auto src_rank = responsible_expert_idx / num_local_experts; - const auto local_expert_idx = responsible_expert_idx % num_local_experts; - const auto rdma_recv_x_uint8 = reinterpret_cast(rdma_recv_x) + - local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + - src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg; - const auto recv_x_int4 = reinterpret_cast(packed_recv_x) + - local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4; - const auto recv_x_scales = packed_recv_x_scales + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_scales; - const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank; - const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks; - - __shared__ int shared_num_recv_tokens[kNumWarpGroups], shared_recv_token_begin_idx[kNumWarpGroups]; - - int num_recv_tokens, recv_token_begin_idx; - EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group"); - if (sub_warp_id == 1 and lane_id == 0) { - int64_t raw; - while ((raw = ld_acquire_sys_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0); - num_recv_tokens = static_cast(-raw - 1); - recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens); - shared_num_recv_tokens[warp_group_id] = num_recv_tokens; - shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx; - recv_range[src_rank] = pack2(num_recv_tokens, recv_token_begin_idx); - } - asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(kNumWarpsPerGroup * 32)); - num_recv_tokens = shared_num_recv_tokens[warp_group_id]; - recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id]; - - EP_DEVICE_ASSERT(num_scales <= 64); - for (int i = sub_warp_id; i < num_recv_tokens; i += kNumWarpsPerGroup) { - const auto src_src_idx = reinterpret_cast(rdma_recv_x_uint8 + i * 
num_bytes_per_msg); - if (lane_id == 0) - recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx); - __syncwarp(); - - const auto src_data = reinterpret_cast(reinterpret_cast(src_src_idx) + sizeof(int4)); - const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4; - UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global); - - if (kUseFP8) { - const auto src_scales = reinterpret_cast(reinterpret_cast(src_data) + hidden_bytes); - const auto dst_scales = reinterpret_cast(recv_x_scales + recv_token_begin_idx + i); - const auto scale_stride = num_ranks * num_max_dispatch_tokens_per_rank; - auto scale_0 = lane_id < num_scales ? ld_nc_global(src_scales + lane_id) : 0; - auto scale_1 = (lane_id + 32) < num_scales ? ld_nc_global(src_scales + lane_id + 32) : 0; - lane_id < num_scales ? dst_scales[lane_id * scale_stride] = scale_0 : 0.0f; - (lane_id + 32) < num_scales ? dst_scales[(lane_id + 32) * scale_stride] = scale_1 : 0.0f; + // Issue sends + if (dst_expert_idx >= 0) { + int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0; + slot_idx = __shfl_sync(0xffffffff, slot_idx, 0); + const auto dst_rank = dst_expert_idx / num_local_experts; + const auto dst_expert_local_idx = dst_expert_idx % num_local_experts; + const auto src_ptr = reinterpret_cast(rdma_x_src_idx); + const auto dst_ptr = reinterpret_cast(rdma_recv_x) + + dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + + rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + slot_idx * num_bytes_per_msg; + if (dst_rank != rank) { + if constexpr (kIpcPath) { + // Peer-mapped warp copy over NVLink (CUDA IPC). + const auto peer_dst = peer_ptr_of(dst_ptr, peer_rdma_bases, rdma_buffer_ptr, dst_rank); + const auto* src_int4_ptr = reinterpret_cast(src_ptr); + const auto* dst_int4_ptr = reinterpret_cast(peer_dst); + UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global); + } else { + // MSCCL++ port-channel PUT (lane 0 issues one request). + if (lane_id == 0) { + const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr); + const auto src_off = rdma_offset_of(src_ptr, rdma_buffer_ptr); + port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank].put(dst_off, src_off, + num_bytes_per_msg); } + __syncwarp(); + } + } else { + const auto* src_int4_ptr = reinterpret_cast(src_ptr); + const auto* dst_int4_ptr = reinterpret_cast(dst_ptr); + UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global); } + + __syncwarp(); + lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0; + } } + } else if (warp_id == num_warps - 1) { + EP_DEVICE_ASSERT(num_sms > 1); + if (sm_id == 0) { + // NOTE: DeepEP asserts `ibgda_get_state()->num_rc_per_pe >= num_local_experts` + // here. The MSCCL++ port relies on Buffer::sync() provisioning enough QPs + // (see `num_ib_connections_per_rank` / `num_port_channels_per_rank`). 
+ +#pragma unroll + for (int i = lane_id; i < num_next_clean_int; i += 32) next_clean[i] = 0; + + __syncwarp(); +#pragma unroll + for (int i = lane_id; i < num_experts; i += 32) + atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG); + } + + int expert_count[kNumWarpGroups] = {0}; + const auto expert_begin_idx = sm_id * kNumWarpGroups; + const auto expert_end_idx = min(expert_begin_idx + kNumWarpGroups, num_experts); + +#pragma unroll 8 + for (int i = lane_id; i < num_tokens * num_topk; i += 32) { + auto idx = static_cast(__ldg(topk_idx + i)); + if (idx >= expert_begin_idx and idx < expert_end_idx) expert_count[idx - expert_begin_idx]++; + } + +#pragma unroll + for (int i = expert_begin_idx; i < expert_end_idx; ++i) { + auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]); + if (lane_id == 0) { + shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum; + atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum); + } + } + } + __syncthreads(); + + // Issue count sends + if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) { + const auto dst_rank = responsible_expert_idx / num_local_experts; + const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts; + const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * kNumWarpGroups]; + + while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2) + ; + if (dst_rank != rank) { + if constexpr (kIpcPath) { + // Single writer per (dst_expert_local_idx, rank) slot, so a + // release-ordered store on the peer-mapped counter is enough. + auto peer_counter = reinterpret_cast( + peer_ptr_of(reinterpret_cast(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), + peer_rdma_bases, rdma_buffer_ptr, dst_rank)); + st_na_release(peer_counter, static_cast(-num_tokens_sent - 1)); + } else { + auto* counter_ptr = rdma_recv_count + dst_expert_local_idx * num_ranks + rank; + const auto off = rdma_offset_of(reinterpret_cast(counter_ptr), rdma_buffer_ptr); + port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank].atomicAdd( + off, static_cast(-num_tokens_sent - 1)); + } + } else { + st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, + static_cast(-num_tokens_sent - 1)); + } + + atomic_counter_per_expert[responsible_expert_idx] = 0; + atomic_finish_counter_per_expert[responsible_expert_idx] = 0; + + if (dst_rank == 0) packed_recv_count[dst_expert_local_idx] = 0; + } + __syncwarp(); + +// Receiving phase +LOW_LATENCY_DISPATCH_RECV: + if ((phases & LOW_LATENCY_RECV_PHASE) == 0) return; + + if (phases & LOW_LATENCY_SEND_PHASE) cg::this_grid().sync(); + + if (responsible_expert_idx < num_experts) { + const auto src_rank = responsible_expert_idx / num_local_experts; + const auto local_expert_idx = responsible_expert_idx % num_local_experts; + const auto rdma_recv_x_uint8 = reinterpret_cast(rdma_recv_x) + + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + + src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg; + const auto recv_x_int4 = reinterpret_cast(packed_recv_x) + + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4; + const auto recv_x_scales = + packed_recv_x_scales + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_scales; + const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * 
num_max_dispatch_tokens_per_rank; + const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks; + + __shared__ int shared_num_recv_tokens[kNumWarpGroups], shared_recv_token_begin_idx[kNumWarpGroups]; + + int num_recv_tokens, recv_token_begin_idx; + EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group"); + if (sub_warp_id == 1 and lane_id == 0) { + int64_t raw; + while ((raw = ld_acquire_sys_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0) + ; + num_recv_tokens = static_cast(-raw - 1); + recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens); + shared_num_recv_tokens[warp_group_id] = num_recv_tokens; + shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx; + recv_range[src_rank] = pack2(num_recv_tokens, recv_token_begin_idx); + } + asm volatile("bar.sync %0, %1;" ::"r"(warp_group_id + 2), "r"(kNumWarpsPerGroup * 32)); + num_recv_tokens = shared_num_recv_tokens[warp_group_id]; + recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id]; + + EP_DEVICE_ASSERT(num_scales <= 64); + for (int i = sub_warp_id; i < num_recv_tokens; i += kNumWarpsPerGroup) { + const auto src_src_idx = reinterpret_cast(rdma_recv_x_uint8 + i * num_bytes_per_msg); + if (lane_id == 0) recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx); + __syncwarp(); + + const auto src_data = reinterpret_cast(reinterpret_cast(src_src_idx) + sizeof(int4)); + const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4; + UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global); + + if (kUseFP8) { + const auto src_scales = reinterpret_cast(reinterpret_cast(src_data) + hidden_bytes); + const auto dst_scales = reinterpret_cast(recv_x_scales + recv_token_begin_idx + i); + const auto scale_stride = num_ranks * num_max_dispatch_tokens_per_rank; + auto scale_0 = lane_id < num_scales ? ld_nc_global(src_scales + lane_id) : 0; + auto scale_1 = (lane_id + 32) < num_scales ? ld_nc_global(src_scales + lane_id + 32) : 0; + lane_id < num_scales ? dst_scales[lane_id * scale_stride] = scale_0 : 0.0f; + (lane_id + 32) < num_scales ? dst_scales[(lane_id + 32) * scale_stride] = scale_1 : 0.0f; + } + } + } } -void dispatch(void* packed_recv_x, float* packed_recv_x_scales, - int* packed_recv_src_info, int64_t* packed_recv_layout_range, - int* packed_recv_count, - void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x, - const void* x, const int64_t* topk_idx, - int64_t* next_clean, int num_next_clean_int, - int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, - int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8, - void* workspace, cudaStream_t stream, int phases, - void* rdma_buffer_ptr, - mscclpp::PortChannelDeviceHandle* port_channel_handles, - void* const* peer_rdma_bases, - mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, - bool use_ipc_path) { - constexpr int kNumMaxTopK = 9; - // (kNumWarpGroups, kNumWarpsPerGroup) is path-dependent. Intra-node IPC - // benefits from 1 expert per SM with 32 warps cooperating on the recv-side - // body (matches NCCL-EP's structure for num_experts <= num_sms). The - // PortChannel path is IB-bound and a wider grid only adds host-proxy FIFO - // contention and a costlier cg::this_grid().sync(), so we keep (3, 10). 
- constexpr int kNumWarpsPerGroupIpc = 32; - constexpr int kNumWarpGroupsIpc = 1; - constexpr int kNumWarpsPerGroupRdma = 10; - constexpr int kNumWarpGroupsRdma = 3; - EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroupsIpc * kNumWarpsPerGroupIpc, "Too many top-k selections"); - EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroupsRdma * kNumWarpsPerGroupRdma, "Too many top-k selections"); +void dispatch(void* packed_recv_x, float* packed_recv_x_scales, int* packed_recv_src_info, + int64_t* packed_recv_layout_range, int* packed_recv_count, void* rdma_recv_x, int64_t* rdma_recv_count, + void* rdma_x, const void* x, const int64_t* topk_idx, int64_t* next_clean, int num_next_clean_int, + int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts, int rank, + int num_ranks, bool use_fp8, void* workspace, cudaStream_t stream, int phases, void* rdma_buffer_ptr, + mscclpp::PortChannelDeviceHandle* port_channel_handles, void* const* peer_rdma_bases, + mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path) { + constexpr int kNumMaxTopK = 9; + // (kNumWarpGroups, kNumWarpsPerGroup) is path-dependent. Intra-node IPC + // benefits from 1 expert per SM with 32 warps cooperating on the recv-side + // body (matches NCCL-EP's structure for num_experts <= num_sms). The + // PortChannel path is IB-bound and a wider grid only adds host-proxy FIFO + // contention and a costlier cg::this_grid().sync(), so we keep (3, 10). + constexpr int kNumWarpsPerGroupIpc = 32; + constexpr int kNumWarpGroupsIpc = 1; + constexpr int kNumWarpsPerGroupRdma = 10; + constexpr int kNumWarpGroupsRdma = 3; + EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroupsIpc * kNumWarpsPerGroupIpc, "Too many top-k selections"); + EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroupsRdma * kNumWarpsPerGroupRdma, "Too many top-k selections"); - const int kNumWarpGroups = use_ipc_path ? kNumWarpGroupsIpc : kNumWarpGroupsRdma; - const int kNumWarpsPerGroup = use_ipc_path ? kNumWarpsPerGroupIpc : kNumWarpsPerGroupRdma; - const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup; - const auto num_sms_base = cell_div(num_experts, kNumWarpGroups); - // LL dispatch/combine are latency-bound at typical problem sizes: for - // num_experts=32 the base grid is cell_div(32,3)=11 blocks, i.e. 8% of a - // 132-SM H100. The recv-side bodies stride tokens by `sm_id`, so extra - // blocks parallelize token work linearly when the transport is cheap. - // - // Only enabled on the IPC path: on the PortChannel path each extra block - // issues more concurrent PUTs into the host proxy FIFO, and the - // cg::this_grid().sync() barrier between phases costs more with a larger - // grid, which empirically regresses cross-node dispatch. - int device_num_sms = num_sms_base; - if (use_ipc_path) { - int cur_dev = 0; - cudaGetDevice(&cur_dev); - cudaDeviceGetAttribute(&device_num_sms, cudaDevAttrMultiProcessorCount, cur_dev); - } - const auto num_sms = std::max(num_sms_base, - std::min(device_num_sms, std::max(num_tokens, num_sms_base))); - EP_HOST_ASSERT(num_topk <= kNumMaxTopK); + const int kNumWarpGroups = use_ipc_path ? kNumWarpGroupsIpc : kNumWarpGroupsRdma; + const int kNumWarpsPerGroup = use_ipc_path ? kNumWarpsPerGroupIpc : kNumWarpsPerGroupRdma; + const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup; + const auto num_sms_base = cell_div(num_experts, kNumWarpGroups); + // LL dispatch/combine are latency-bound at typical problem sizes: for + // num_experts=32 the base grid is cell_div(32,3)=11 blocks, i.e. 
8% of a + // 132-SM H100. The recv-side bodies stride tokens by `sm_id`, so extra + // blocks parallelize token work linearly when the transport is cheap. + // + // Only enabled on the IPC path: on the PortChannel path each extra block + // issues more concurrent PUTs into the host proxy FIFO, and the + // cg::this_grid().sync() barrier between phases costs more with a larger + // grid, which empirically regresses cross-node dispatch. + int device_num_sms = num_sms_base; + if (use_ipc_path) { + int cur_dev = 0; + cudaGetDevice(&cur_dev); + cudaDeviceGetAttribute(&device_num_sms, cudaDevAttrMultiProcessorCount, cur_dev); + } + const auto num_sms = + std::max(num_sms_base, std::min(device_num_sms, std::max(num_tokens, num_sms_base))); + EP_HOST_ASSERT(num_topk <= kNumMaxTopK); - auto atomic_counter_per_expert = reinterpret_cast(workspace); - auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts; - EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES); + auto atomic_counter_per_expert = reinterpret_cast(workspace); + auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts; + EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES); -#define DISPATCH_LAUNCH_CASE(hidden_case) { \ -if (use_ipc_path) { \ - auto dispatch_func = use_fp8 ? dispatch \ - : dispatch; \ - LAUNCH_KERNEL(&cfg, dispatch_func, \ - packed_recv_x, packed_recv_x_scales, \ - packed_recv_src_info, packed_recv_layout_range, \ - packed_recv_count, \ - rdma_recv_x, rdma_recv_count, rdma_x, \ - x, topk_idx, \ - atomic_counter_per_expert, atomic_finish_counter_per_expert, \ - next_clean, num_next_clean_int, \ - num_tokens, num_max_dispatch_tokens_per_rank, \ - num_topk, num_experts, rank, num_ranks, phases, \ - rdma_buffer_ptr, port_channel_handles, \ - peer_rdma_bases, memory_channel_handles); \ -} else { \ - auto dispatch_func = use_fp8 ? dispatch \ - : dispatch; \ - LAUNCH_KERNEL(&cfg, dispatch_func, \ - packed_recv_x, packed_recv_x_scales, \ - packed_recv_src_info, packed_recv_layout_range, \ - packed_recv_count, \ - rdma_recv_x, rdma_recv_count, rdma_x, \ - x, topk_idx, \ - atomic_counter_per_expert, atomic_finish_counter_per_expert, \ - next_clean, num_next_clean_int, \ - num_tokens, num_max_dispatch_tokens_per_rank, \ - num_topk, num_experts, rank, num_ranks, phases, \ - rdma_buffer_ptr, port_channel_handles, \ - peer_rdma_bases, memory_channel_handles); \ -} } break +#define DISPATCH_LAUNCH_CASE(hidden_case) \ + { \ + if (use_ipc_path) { \ + auto dispatch_func = use_fp8 ? dispatch \ + : dispatch; \ + LAUNCH_KERNEL(&cfg, dispatch_func, packed_recv_x, packed_recv_x_scales, packed_recv_src_info, \ + packed_recv_layout_range, packed_recv_count, rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx, \ + atomic_counter_per_expert, atomic_finish_counter_per_expert, next_clean, num_next_clean_int, \ + num_tokens, num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, num_ranks, phases, \ + rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, memory_channel_handles); \ + } else { \ + auto dispatch_func = use_fp8 ? 
dispatch \ + : dispatch; \ + LAUNCH_KERNEL(&cfg, dispatch_func, packed_recv_x, packed_recv_x_scales, packed_recv_src_info, \ + packed_recv_layout_range, packed_recv_count, rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx, \ + atomic_counter_per_expert, atomic_finish_counter_per_expert, next_clean, num_next_clean_int, \ + num_tokens, num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, num_ranks, phases, \ + rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, memory_channel_handles); \ + } \ + } \ + break - SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream); - SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE); + SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream); + SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE); #undef DISPATCH_LAUNCH_CASE } @@ -517,247 +478,221 @@ if (use_ipc_path) { \ // --------------------------------------------------------------------------- template -__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void -combine(void* combined_x, - void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x, - const void* x, const int64_t* topk_idx, const float* topk_weights, - const int* src_info, const int64_t* layout_range, - int64_t* next_clean, int num_next_clean_int, - int* atomic_clean_flag, - int num_combined_tokens, int hidden, int num_topk, - int num_max_dispatch_tokens_per_rank, - int num_experts, int rank, int num_ranks, - int phases, bool zero_copy, - void* rdma_buffer_ptr, - mscclpp::PortChannelDeviceHandle* port_channel_handles, - void* const* peer_rdma_bases, - mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) { - const auto sm_id = static_cast(blockIdx.x); - const auto num_sms = static_cast(gridDim.x); - const auto thread_id = static_cast(threadIdx.x); - const auto num_threads = static_cast(blockDim.x); - const auto warp_id = thread_id / 32, lane_id = get_lane_id(); - const auto num_local_experts = num_experts / num_ranks; - const auto warp_group_id = warp_id / kNumWarpsPerGroup; - const auto sub_warp_id = warp_id % kNumWarpsPerGroup; - const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id; +__global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void combine( + void* combined_x, void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x, const void* x, + const int64_t* topk_idx, const float* topk_weights, const int* src_info, const int64_t* layout_range, + int64_t* next_clean, int num_next_clean_int, int* atomic_clean_flag, int num_combined_tokens, int hidden, + int num_topk, int num_max_dispatch_tokens_per_rank, int num_experts, int rank, int num_ranks, int phases, + bool zero_copy, void* rdma_buffer_ptr, mscclpp::PortChannelDeviceHandle* port_channel_handles, + void* const* peer_rdma_bases, mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) { + const auto sm_id = static_cast(blockIdx.x); + const auto num_sms = static_cast(gridDim.x); + const auto thread_id = static_cast(threadIdx.x); + const auto num_threads = static_cast(blockDim.x); + const auto warp_id = thread_id / 32, lane_id = get_lane_id(); + const auto num_local_experts = num_experts / num_ranks; + const auto warp_group_id = warp_id / kNumWarpsPerGroup; + const auto sub_warp_id = warp_id % kNumWarpsPerGroup; + const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id; - constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(nv_bfloat16); - const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4; + constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(nv_bfloat16); + const size_t hidden_bf16_int4 = kHidden / 
kNumElemsPerInt4; - constexpr size_t num_bytes_per_slot = kHidden * sizeof(nv_bfloat16); - EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization"); + constexpr size_t num_bytes_per_slot = kHidden * sizeof(nv_bfloat16); + EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization"); - if ((phases & LOW_LATENCY_SEND_PHASE) == 0) - goto LOW_LATENCY_COMBINE_RECV; + if ((phases & LOW_LATENCY_SEND_PHASE) == 0) goto LOW_LATENCY_COMBINE_RECV; - if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) { - #pragma unroll - for (int i = lane_id; i < num_next_clean_int; i += 32) - next_clean[i] = 0; + if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) { +#pragma unroll + for (int i = lane_id; i < num_next_clean_int; i += 32) next_clean[i] = 0; - __syncwarp(); - if (lane_id == 0) - atomic_add_release_global(atomic_clean_flag, num_experts); + __syncwarp(); + if (lane_id == 0) atomic_add_release_global(atomic_clean_flag, num_experts); + } + + if (responsible_expert_idx < num_experts) { + const auto dst_rank = responsible_expert_idx / num_local_experts; + const auto local_expert_idx = responsible_expert_idx % num_local_experts; + const auto global_expert_idx = rank * num_local_experts + local_expert_idx; + const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank); + const auto local_x = reinterpret_cast(x) + + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4; + const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank; + const auto rdma_send_x_vec = reinterpret_cast(rdma_send_x) + + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot; + + int offset, num_tokens_to_send; + unpack2(layout, num_tokens_to_send, offset); + + for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; + token_idx += kNumWarpsPerGroup) { + const auto x_int4 = local_x + token_idx * hidden_bf16_int4; + const auto rdma_send_type_row = reinterpret_cast(rdma_send_x_vec + token_idx * num_bytes_per_slot); + const auto rdma_send_x_vec_row = reinterpret_cast(rdma_send_type_row); + + auto src_idx = __ldg(local_src_info + token_idx); + const auto buf_ptr = reinterpret_cast(rdma_send_x_vec_row); + const auto dst_ptr = reinterpret_cast(rdma_recv_x) + + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot; + if (dst_rank == rank) { + const auto dst_int4_ptr = reinterpret_cast(dst_ptr); + UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global); + } else { + if constexpr (kIpcPath) { + // Peer-mapped warp copy over NVLink. `zero_copy` is irrelevant + // on this path because we skip the rdma_send staging buffer. + const auto peer_dst = peer_ptr_of(dst_ptr, peer_rdma_bases, rdma_buffer_ptr, dst_rank); + const auto peer_dst_int4 = reinterpret_cast(peer_dst); + UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, peer_dst_int4, x_int4, ld_nc_global, st_na_global); + } else { + const auto buf_int4_ptr = reinterpret_cast(buf_ptr); + if (not zero_copy) + UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global); + // MSCCL++ port-channel PUT. 
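+          // Lane 0 issues a single put() covering the whole token row
+          // (hidden * sizeof(nv_bfloat16) bytes). Port channels address the registered
+          // RDMA buffer by byte offset rather than by raw pointer, hence the
+          // rdma_offset_of() translation of both the staging slot in `rdma_send_x` and
+          // the destination slot in the receiver's `rdma_recv_x`. Arrival is signalled
+          // separately: after the group barrier below, one lane bumps the per-expert
+          // `rdma_recv_flag` entry that the receive phase polls.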
+ if (lane_id == 0) { + const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr); + const auto src_off = rdma_offset_of(static_cast(buf_ptr), rdma_buffer_ptr); + port_channel_handles[local_expert_idx * num_ranks + dst_rank].put(dst_off, src_off, + hidden * sizeof(nv_bfloat16)); + } + __syncwarp(); + } + } } - if (responsible_expert_idx < num_experts) { - const auto dst_rank = responsible_expert_idx / num_local_experts; - const auto local_expert_idx = responsible_expert_idx % num_local_experts; - const auto global_expert_idx = rank * num_local_experts + local_expert_idx; - const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank); - const auto local_x = reinterpret_cast(x) + - local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4; - const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank; - const auto rdma_send_x_vec = reinterpret_cast(rdma_send_x) + - local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot; + EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group"); + asm volatile("bar.sync %0, %1;" ::"r"(warp_group_id + 1), "r"(kNumWarpsPerGroup * 32)); + if (sub_warp_id == 1 and lane_id == 0) { + while (ld_acquire_global(atomic_clean_flag) == 0) + ; + if (dst_rank != rank) { + if constexpr (kIpcPath) { + auto peer_flag = + reinterpret_cast(peer_ptr_of(reinterpret_cast(rdma_recv_flag + global_expert_idx), + peer_rdma_bases, rdma_buffer_ptr, dst_rank)); + st_na_release(peer_flag, static_cast(1)); + } else { + auto* flag_ptr = rdma_recv_flag + global_expert_idx; + const auto off = rdma_offset_of(reinterpret_cast(flag_ptr), rdma_buffer_ptr); + port_channel_handles[local_expert_idx * num_ranks + dst_rank].atomicAdd(off, static_cast(1)); + } + } else { + st_na_release(rdma_recv_flag + global_expert_idx, static_cast(1)); + } + atomic_add_release_global(atomic_clean_flag, -1); + } + __syncwarp(); + } - int offset, num_tokens_to_send; - unpack2(layout, num_tokens_to_send, offset); +LOW_LATENCY_COMBINE_RECV: + if ((phases & LOW_LATENCY_RECV_PHASE) == 0) return; - for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += kNumWarpsPerGroup) { - const auto x_int4 = local_x + token_idx * hidden_bf16_int4; - const auto rdma_send_type_row = reinterpret_cast(rdma_send_x_vec + token_idx * num_bytes_per_slot); - const auto rdma_send_x_vec_row = reinterpret_cast(rdma_send_type_row); + if (responsible_expert_idx < num_experts) { + EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group"); + if (sub_warp_id == 0 and lane_id == 0) + while (ld_acquire_sys_global(rdma_recv_flag + responsible_expert_idx) == 0) + ; + } + cg::this_grid().sync(); - auto src_idx = __ldg(local_src_info + token_idx); - const auto buf_ptr = reinterpret_cast(rdma_send_x_vec_row); - const auto dst_ptr = reinterpret_cast(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot; - if (dst_rank == rank) { - const auto dst_int4_ptr = reinterpret_cast(dst_ptr); - UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global); - } else { - if constexpr (kIpcPath) { - // Peer-mapped warp copy over NVLink. `zero_copy` is irrelevant - // on this path because we skip the rdma_send staging buffer. 
- const auto peer_dst = peer_ptr_of(dst_ptr, peer_rdma_bases, rdma_buffer_ptr, dst_rank); - const auto peer_dst_int4 = reinterpret_cast(peer_dst); - UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, peer_dst_int4, x_int4, ld_nc_global, st_na_global); - } else { - const auto buf_int4_ptr = reinterpret_cast(buf_ptr); - if (not zero_copy) - UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global); - // MSCCL++ port-channel PUT. - if (lane_id == 0) { - const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr); - const auto src_off = rdma_offset_of(static_cast(buf_ptr), rdma_buffer_ptr); - port_channel_handles[local_expert_idx * num_ranks + dst_rank] - .put(dst_off, src_off, hidden * sizeof(nv_bfloat16)); - } - __syncwarp(); - } - } + EP_DEVICE_ASSERT(num_topk <= 32 and hidden_bf16_int4 <= num_threads); + EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerInt4) == 0, "Invalid vectorization"); + if (thread_id < hidden_bf16_int4) { + for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) { + int reg_topk_idx[kNumMaxTopk]; + float reg_topk_weights[kNumMaxTopk]; +#pragma unroll + for (int i = 0; i < num_topk; ++i) { + reg_topk_idx[i] = static_cast(__ldg(topk_idx + token_idx * num_topk + i)); + reg_topk_weights[i] = __ldg(topk_weights + token_idx * num_topk + i); + } + + float combined_values[kNumElemsPerInt4] = {0.0f}; +#pragma unroll + for (int i = 0; i < num_topk; ++i) + if (reg_topk_idx[i] >= 0) { + auto rdma_buffer_type = reinterpret_cast( + reinterpret_cast(rdma_recv_x) + + (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot); + auto rdma_buffer_row = reinterpret_cast(rdma_buffer_type); + + auto x_vec = ld_nc_global(reinterpret_cast(rdma_buffer_row) + thread_id); + const auto x_bf16 = reinterpret_cast(&x_vec); +#pragma unroll + for (int j = 0; j < kNumElemsPerInt4; ++j) + combined_values[j] += static_cast(x_bf16[j]) * reg_topk_weights[i]; } - EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group"); - asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 1), "r"(kNumWarpsPerGroup * 32)); - if (sub_warp_id == 1 and lane_id == 0) { - while (ld_acquire_global(atomic_clean_flag) == 0); - if (dst_rank != rank) { - if constexpr (kIpcPath) { - auto peer_flag = reinterpret_cast(peer_ptr_of( - reinterpret_cast(rdma_recv_flag + global_expert_idx), - peer_rdma_bases, rdma_buffer_ptr, dst_rank)); - st_na_release(peer_flag, static_cast(1)); - } else { - auto* flag_ptr = rdma_recv_flag + global_expert_idx; - const auto off = rdma_offset_of(reinterpret_cast(flag_ptr), rdma_buffer_ptr); - port_channel_handles[local_expert_idx * num_ranks + dst_rank] - .atomicAdd(off, static_cast(1)); - } - } else { - st_na_release(rdma_recv_flag + global_expert_idx, static_cast(1)); - } - atomic_add_release_global(atomic_clean_flag, -1); - } - __syncwarp(); - } - - LOW_LATENCY_COMBINE_RECV: - if ((phases & LOW_LATENCY_RECV_PHASE) == 0) - return; - - if (responsible_expert_idx < num_experts) { - EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group"); - if (sub_warp_id == 0 and lane_id == 0) - while (ld_acquire_sys_global(rdma_recv_flag + responsible_expert_idx) == 0); - } - cg::this_grid().sync(); - - EP_DEVICE_ASSERT(num_topk <= 32 and hidden_bf16_int4 <= num_threads); - EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerInt4) == 0, "Invalid vectorization"); - if (thread_id < hidden_bf16_int4) { - for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) { - int 
reg_topk_idx[kNumMaxTopk]; - float reg_topk_weights[kNumMaxTopk]; - #pragma unroll - for (int i = 0; i < num_topk; ++ i) { - reg_topk_idx[i] = static_cast(__ldg(topk_idx + token_idx * num_topk + i)); - reg_topk_weights[i] = __ldg(topk_weights + token_idx * num_topk + i); - } - - float combined_values[kNumElemsPerInt4] = {0.0f}; - #pragma unroll - for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) { - auto rdma_buffer_type = reinterpret_cast(reinterpret_cast(rdma_recv_x) + (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot); - auto rdma_buffer_row = reinterpret_cast(rdma_buffer_type); - - auto x_vec = ld_nc_global(reinterpret_cast(rdma_buffer_row) + thread_id); - const auto x_bf16 = reinterpret_cast(&x_vec); - #pragma unroll - for (int j = 0; j < kNumElemsPerInt4; ++ j) - combined_values[j] += static_cast(x_bf16[j]) * reg_topk_weights[i]; - } - - int4& combined_int4 = *reinterpret_cast(combined_values); - auto combined_bf16 = reinterpret_cast(&combined_values); - #pragma unroll - for (int j = 0; j < kNumElemsPerInt4; ++ j) - combined_bf16[j] = static_cast(combined_values[j]); - (reinterpret_cast(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4; - } + int4& combined_int4 = *reinterpret_cast(combined_values); + auto combined_bf16 = reinterpret_cast(&combined_values); +#pragma unroll + for (int j = 0; j < kNumElemsPerInt4; ++j) combined_bf16[j] = static_cast(combined_values[j]); + (reinterpret_cast(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4; } + } } -void combine(void* combined_x, - void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x, - const void* x, const int64_t* topk_idx, const float* topk_weights, - const int* src_info, const int64_t* layout_range, - int64_t* next_clean, int num_next_clean_int, - int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank, - int num_topk, int num_experts, int rank, int num_ranks, - void* workspace, cudaStream_t stream, - int phases, bool zero_copy, - void* rdma_buffer_ptr, - mscclpp::PortChannelDeviceHandle* port_channel_handles, - void* const* peer_rdma_bases, - mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, - bool use_ipc_path) { - // See the comment in `dispatch()`: (kNumWarpGroups, kNumWarpsPerGroup) - // is path-dependent. IPC uses (1, 32) to mirror NCCL-EP; PortChannel keeps - // (3, 10) to avoid host-proxy FIFO contention on the IB path. - constexpr int kNumWarpsPerGroupIpc = 32; - constexpr int kNumWarpGroupsIpc = 1; - constexpr int kNumWarpsPerGroupRdma = 10; - constexpr int kNumWarpGroupsRdma = 3; - constexpr int kNumMaxTopk = 9; +void combine(void* combined_x, void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x, const void* x, + const int64_t* topk_idx, const float* topk_weights, const int* src_info, const int64_t* layout_range, + int64_t* next_clean, int num_next_clean_int, int num_combined_tokens, int hidden, + int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts, int rank, int num_ranks, + void* workspace, cudaStream_t stream, int phases, bool zero_copy, void* rdma_buffer_ptr, + mscclpp::PortChannelDeviceHandle* port_channel_handles, void* const* peer_rdma_bases, + mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path) { + // See the comment in `dispatch()`: (kNumWarpGroups, kNumWarpsPerGroup) + // is path-dependent. IPC uses (1, 32) to mirror NCCL-EP; PortChannel keeps + // (3, 10) to avoid host-proxy FIFO contention on the IB path. 
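+  // Concretely, with the SETUP_LAUNCH_CONFIG below this launches 1 x 32 = 32 warps
+  // (1024 threads) per block on the IPC path and 3 x 10 = 30 warps (960 threads) per
+  // block on the PortChannel path, over a grid of at least
+  // cell_div(num_experts, kNumWarpGroups) blocks.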
+ constexpr int kNumWarpsPerGroupIpc = 32; + constexpr int kNumWarpGroupsIpc = 1; + constexpr int kNumWarpsPerGroupRdma = 10; + constexpr int kNumWarpGroupsRdma = 3; + constexpr int kNumMaxTopk = 9; - const int kNumWarpGroups = use_ipc_path ? kNumWarpGroupsIpc : kNumWarpGroupsRdma; - const int kNumWarpsPerGroup = use_ipc_path ? kNumWarpsPerGroupIpc : kNumWarpsPerGroupRdma; - const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup; - const auto num_sms_base = cell_div(num_experts, kNumWarpGroups); - // See the comment in `dispatch()` above: combine-recv's per-token loop - // strides by `sm_id`, so extra blocks parallelize the weighted reduction - // linearly on the IPC path. Keep the baseline grid on the PortChannel - // path to avoid the cooperative-sync / proxy-FIFO overhead that regressed - // cross-node dispatch. - int device_num_sms = num_sms_base; - if (use_ipc_path) { - int cur_dev = 0; - cudaGetDevice(&cur_dev); - cudaDeviceGetAttribute(&device_num_sms, cudaDevAttrMultiProcessorCount, cur_dev); - } - const auto num_sms = std::max(num_sms_base, - std::min(device_num_sms, - std::max(num_combined_tokens, num_sms_base))); + const int kNumWarpGroups = use_ipc_path ? kNumWarpGroupsIpc : kNumWarpGroupsRdma; + const int kNumWarpsPerGroup = use_ipc_path ? kNumWarpsPerGroupIpc : kNumWarpsPerGroupRdma; + const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup; + const auto num_sms_base = cell_div(num_experts, kNumWarpGroups); + // See the comment in `dispatch()` above: combine-recv's per-token loop + // strides by `sm_id`, so extra blocks parallelize the weighted reduction + // linearly on the IPC path. Keep the baseline grid on the PortChannel + // path to avoid the cooperative-sync / proxy-FIFO overhead that regressed + // cross-node dispatch. + int device_num_sms = num_sms_base; + if (use_ipc_path) { + int cur_dev = 0; + cudaGetDevice(&cur_dev); + cudaDeviceGetAttribute(&device_num_sms, cudaDevAttrMultiProcessorCount, cur_dev); + } + const auto num_sms = + std::max(num_sms_base, std::min(device_num_sms, std::max(num_combined_tokens, num_sms_base))); - auto atomic_clean_flag = reinterpret_cast(workspace); - EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES); - EP_HOST_ASSERT(num_topk <= kNumMaxTopk); + auto atomic_clean_flag = reinterpret_cast(workspace); + EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES); + EP_HOST_ASSERT(num_topk <= kNumMaxTopk); -#define COMBINE_LAUNCH_CASE(hidden_case) { \ -if (use_ipc_path) { \ - auto combine_func = combine; \ - LAUNCH_KERNEL(&cfg, combine_func, \ - combined_x, \ - rdma_recv_x, rdma_recv_flag, rdma_send_x, \ - x, topk_idx, topk_weights, src_info, layout_range, \ - next_clean, num_next_clean_int, \ - atomic_clean_flag, \ - num_combined_tokens, hidden, num_topk, \ - num_max_dispatch_tokens_per_rank, \ - num_experts, rank, num_ranks, \ - phases, zero_copy, \ - rdma_buffer_ptr, port_channel_handles, \ - peer_rdma_bases, memory_channel_handles); \ -} else { \ - auto combine_func = combine; \ - LAUNCH_KERNEL(&cfg, combine_func, \ - combined_x, \ - rdma_recv_x, rdma_recv_flag, rdma_send_x, \ - x, topk_idx, topk_weights, src_info, layout_range, \ - next_clean, num_next_clean_int, \ - atomic_clean_flag, \ - num_combined_tokens, hidden, num_topk, \ - num_max_dispatch_tokens_per_rank, \ - num_experts, rank, num_ranks, \ - phases, zero_copy, \ - rdma_buffer_ptr, port_channel_handles, \ - peer_rdma_bases, memory_channel_handles); \ -} } break +#define COMBINE_LAUNCH_CASE(hidden_case) \ + { \ + if (use_ipc_path) { \ + auto combine_func = combine; \ + 
LAUNCH_KERNEL(&cfg, combine_func, combined_x, rdma_recv_x, rdma_recv_flag, rdma_send_x, x, topk_idx, \ + topk_weights, src_info, layout_range, next_clean, num_next_clean_int, atomic_clean_flag, \ + num_combined_tokens, hidden, num_topk, num_max_dispatch_tokens_per_rank, num_experts, rank, \ + num_ranks, phases, zero_copy, rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, \ + memory_channel_handles); \ + } else { \ + auto combine_func = combine; \ + LAUNCH_KERNEL(&cfg, combine_func, combined_x, rdma_recv_x, rdma_recv_flag, rdma_send_x, x, topk_idx, \ + topk_weights, src_info, layout_range, next_clean, num_next_clean_int, atomic_clean_flag, \ + num_combined_tokens, hidden, num_topk, num_max_dispatch_tokens_per_rank, num_experts, rank, \ + num_ranks, phases, zero_copy, rdma_buffer_ptr, port_channel_handles, peer_rdma_bases, \ + memory_channel_handles); \ + } \ + } \ + break - SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream); - SWITCH_HIDDEN(COMBINE_LAUNCH_CASE); + SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream); + SWITCH_HIDDEN(COMBINE_LAUNCH_CASE); #undef COMBINE_LAUNCH_CASE } diff --git a/src/ext/ep/kernels/intranode_kernel.cu b/src/ext/ep/kernels/intranode_kernel.cu index f6af5c66..b97fdb13 100644 --- a/src/ext/ep/kernels/intranode_kernel.cu +++ b/src/ext/ep/kernels/intranode_kernel.cu @@ -2,829 +2,822 @@ // Licensed under the MIT License. #include -#include "configs.cuh" #include "buffer.cuh" +#include "configs.cuh" #include "exception.cuh" #include "launch.cuh" #include "utils.cuh" -namespace mscclpp { namespace ep { +namespace mscclpp { +namespace ep { namespace intranode { -template -__global__ void -notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, - const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, - int num_tokens, int num_channels, const bool* is_token_in_rank, int* channel_prefix_matrix, - int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment, - void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank) { - auto sm_id = static_cast(blockIdx.x); - auto thread_id = static_cast(threadIdx.x), num_threads = static_cast(blockDim.x); - auto lane_id = thread_id % 32, warp_id = thread_id / 32, num_warps = num_threads / 32; +template +__global__ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, + const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, + int num_tokens, int num_channels, const bool* is_token_in_rank, + int* channel_prefix_matrix, int* rank_prefix_matrix_copy, int num_memset_int, + int expert_alignment, void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank) { + auto sm_id = static_cast(blockIdx.x); + auto thread_id = static_cast(threadIdx.x), num_threads = static_cast(blockDim.x); + auto lane_id = thread_id % 32, warp_id = thread_id / 32, num_warps = num_threads / 32; - if (sm_id == 0) { - // Barrier first - barrier_device(task_fifo_ptrs, head, rank); - move_fifo_slots(head); - __syncthreads(); + if (sm_id == 0) { + // Barrier first + barrier_device(task_fifo_ptrs, head, rank); + move_fifo_slots(head); + __syncthreads(); - int *per_rank_buffer, *per_expert_buffer; - if (thread_id < kNumRanks) { - per_rank_buffer = reinterpret_cast(buffer_ptrs[thread_id]); - per_expert_buffer = per_rank_buffer + kNumRanks * kNumRanks; - } - - // After this loop: - // - `per_rank_buffer[rank][i, j]` means the number of tokens from rank i to rank j - // - `per_expert_buffer[rank][i, j]` means the number of tokens from rank i 
to local expert j - int num_experts_per_rank = num_experts / kNumRanks; - if (thread_id < kNumRanks) { - #pragma unroll - for (int i = 0; i < kNumRanks; ++ i) - per_rank_buffer[rank * kNumRanks + i] = num_tokens_per_rank[i]; - #pragma unroll - for (int i = 0; i < num_experts_per_rank; ++ i) - per_expert_buffer[rank * num_experts_per_rank + i] = num_tokens_per_expert[thread_id * num_experts_per_rank + i]; - } - __syncthreads(); - - // Wait for all ranks to be finished - barrier_device(task_fifo_ptrs, head, rank); - move_fifo_slots(head); - __syncthreads(); - - // Sum per-rank counts and return to CPU - // Also pre-compute the prefix sum for data sending - auto local_per_rank_buffer = reinterpret_cast(buffer_ptrs[rank]); - if (thread_id < kNumRanks) { - #pragma unroll - for (int i = 1; i < kNumRanks; ++ i) - local_per_rank_buffer[i * kNumRanks + thread_id] += local_per_rank_buffer[(i - 1) * kNumRanks + thread_id]; - if (thread_id == rank) - *moe_recv_counter_mapped = local_per_rank_buffer[(kNumRanks - 1) * kNumRanks + rank]; - } - - // Sum per-experts counts and return to CPU - auto local_per_expert_buffer = local_per_rank_buffer + kNumRanks * kNumRanks; - if (thread_id < num_experts_per_rank) { - int sum = 0; - #pragma unroll - for (int i = 0; i < kNumRanks; ++ i) - sum += local_per_expert_buffer[i * num_experts_per_rank + thread_id]; - sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment; - moe_recv_expert_counter_mapped[thread_id] = sum; - } - __syncthreads(); - - // Copy rank size prefix matrix to another tensor - #pragma unroll - for (int i = thread_id; i < kNumRanks * kNumRanks; i += num_threads) - rank_prefix_matrix_copy[i] = local_per_rank_buffer[i]; - - // Extra memset for later communication queue - #pragma unroll - for (int i = thread_id; i < num_memset_int; i += num_threads) - local_per_expert_buffer[i] = 0; - - // Barrier - memory_fence(); - __syncthreads(); - barrier_device(task_fifo_ptrs, head, rank); - } else { - int dst_rank = sm_id - 1; - for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) { - int token_start_idx, token_end_idx; - get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx); - - // Iterate over tokens - int count = 0; - for (int64_t i = token_start_idx + lane_id; i < token_end_idx; i += 32) - count += is_token_in_rank[i * kNumRanks + dst_rank]; - count = warp_reduce_sum(count); - if (lane_id == 0) - channel_prefix_matrix[dst_rank * num_channels + channel_id] = count; - } - __syncthreads(); - - // Pre-compute prefix sum for all channels - if (thread_id == 0) { - #pragma unroll - for (int i = 1; i < num_channels; ++ i) - channel_prefix_matrix[dst_rank * num_channels + i] += channel_prefix_matrix[dst_rank * num_channels + i - 1]; - } + int *per_rank_buffer, *per_expert_buffer; + if (thread_id < kNumRanks) { + per_rank_buffer = reinterpret_cast(buffer_ptrs[thread_id]); + per_expert_buffer = per_rank_buffer + kNumRanks * kNumRanks; } + + // After this loop: + // - `per_rank_buffer[rank][i, j]` means the number of tokens from rank i to rank j + // - `per_expert_buffer[rank][i, j]` means the number of tokens from rank i to local expert j + int num_experts_per_rank = num_experts / kNumRanks; + if (thread_id < kNumRanks) { +#pragma unroll + for (int i = 0; i < kNumRanks; ++i) per_rank_buffer[rank * kNumRanks + i] = num_tokens_per_rank[i]; +#pragma unroll + for (int i = 0; i < num_experts_per_rank; ++i) + per_expert_buffer[rank * num_experts_per_rank + i] = + 
num_tokens_per_expert[thread_id * num_experts_per_rank + i]; + } + __syncthreads(); + + // Wait for all ranks to be finished + barrier_device(task_fifo_ptrs, head, rank); + move_fifo_slots(head); + __syncthreads(); + + // Sum per-rank counts and return to CPU + // Also pre-compute the prefix sum for data sending + auto local_per_rank_buffer = reinterpret_cast(buffer_ptrs[rank]); + if (thread_id < kNumRanks) { +#pragma unroll + for (int i = 1; i < kNumRanks; ++i) + local_per_rank_buffer[i * kNumRanks + thread_id] += local_per_rank_buffer[(i - 1) * kNumRanks + thread_id]; + if (thread_id == rank) *moe_recv_counter_mapped = local_per_rank_buffer[(kNumRanks - 1) * kNumRanks + rank]; + } + + // Sum per-experts counts and return to CPU + auto local_per_expert_buffer = local_per_rank_buffer + kNumRanks * kNumRanks; + if (thread_id < num_experts_per_rank) { + int sum = 0; +#pragma unroll + for (int i = 0; i < kNumRanks; ++i) sum += local_per_expert_buffer[i * num_experts_per_rank + thread_id]; + sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment; + moe_recv_expert_counter_mapped[thread_id] = sum; + } + __syncthreads(); + +// Copy rank size prefix matrix to another tensor +#pragma unroll + for (int i = thread_id; i < kNumRanks * kNumRanks; i += num_threads) + rank_prefix_matrix_copy[i] = local_per_rank_buffer[i]; + +// Extra memset for later communication queue +#pragma unroll + for (int i = thread_id; i < num_memset_int; i += num_threads) local_per_expert_buffer[i] = 0; + + // Barrier + memory_fence(); + __syncthreads(); + barrier_device(task_fifo_ptrs, head, rank); + } else { + int dst_rank = sm_id - 1; + for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) { + int token_start_idx, token_end_idx; + get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx); + + // Iterate over tokens + int count = 0; + for (int64_t i = token_start_idx + lane_id; i < token_end_idx; i += 32) + count += is_token_in_rank[i * kNumRanks + dst_rank]; + count = warp_reduce_sum(count); + if (lane_id == 0) channel_prefix_matrix[dst_rank * num_channels + channel_id] = count; + } + __syncthreads(); + + // Pre-compute prefix sum for all channels + if (thread_id == 0) { +#pragma unroll + for (int i = 1; i < num_channels; ++i) + channel_prefix_matrix[dst_rank * num_channels + i] += channel_prefix_matrix[dst_rank * num_channels + i - 1]; + } + } } void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks, const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix, - int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment, - void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank, - cudaStream_t stream, int num_channels) { -#define NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \ - LAUNCH_KERNEL(&cfg, notify_dispatch, \ - num_tokens_per_rank, moe_recv_counter_mapped, \ - num_tokens_per_expert, moe_recv_expert_counter_mapped, num_experts, \ - num_tokens, num_channels, is_token_in_rank, channel_prefix_matrix, \ - rank_prefix_matrix_copy, num_memset_int, expert_alignment, \ - buffer_ptrs, task_fifo_ptrs, head, rank); \ - break + int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment, void** buffer_ptrs, + int** task_fifo_ptrs, int head, int rank, cudaStream_t stream, int num_channels) { +#define NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \ + LAUNCH_KERNEL(&cfg, notify_dispatch, 
num_tokens_per_rank, moe_recv_counter_mapped, num_tokens_per_expert, \ + moe_recv_expert_counter_mapped, num_experts, num_tokens, num_channels, is_token_in_rank, \ + channel_prefix_matrix, rank_prefix_matrix_copy, num_memset_int, expert_alignment, buffer_ptrs, \ + task_fifo_ptrs, head, rank); \ + break - constexpr int kNumThreads = 128; - EP_HOST_ASSERT(num_experts % num_ranks == 0); - EP_HOST_ASSERT(num_experts / num_ranks <= kNumThreads and num_ranks <= kNumThreads); + constexpr int kNumThreads = 128; + EP_HOST_ASSERT(num_experts % num_ranks == 0); + EP_HOST_ASSERT(num_experts / num_ranks <= kNumThreads and num_ranks <= kNumThreads); - SETUP_LAUNCH_CONFIG(1 + num_ranks, kNumThreads, stream); - SWITCH_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE); + SETUP_LAUNCH_CONFIG(1 + num_ranks, kNumThreads, stream); + SWITCH_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE); #undef NOTIFY_DISPATCH_LAUNCH_CASE } -template -__global__ void -cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int, - void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank) { - // A simplified version for cached handles - barrier_device(task_fifo_ptrs, head, rank); - move_fifo_slots(head); - __syncthreads(); +template +__global__ void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int, void** buffer_ptrs, + int** task_fifo_ptrs, int head, int rank) { + // A simplified version for cached handles + barrier_device(task_fifo_ptrs, head, rank); + move_fifo_slots(head); + __syncthreads(); - // Copy and clean - auto thread_id = static_cast(threadIdx.x), num_threads = static_cast(blockDim.x); - auto ptr = reinterpret_cast(buffer_ptrs[rank]); - #pragma unroll - for (int i = thread_id; i < kNumRanks * kNumRanks; i += num_threads) - ptr[i] = rank_prefix_matrix[i]; - #pragma unroll - for (int i = thread_id; i < num_memset_int; i += num_threads) - ptr[kNumRanks * kNumRanks + i] = 0; - memory_fence(); - __syncthreads(); + // Copy and clean + auto thread_id = static_cast(threadIdx.x), num_threads = static_cast(blockDim.x); + auto ptr = reinterpret_cast(buffer_ptrs[rank]); +#pragma unroll + for (int i = thread_id; i < kNumRanks * kNumRanks; i += num_threads) ptr[i] = rank_prefix_matrix[i]; +#pragma unroll + for (int i = thread_id; i < num_memset_int; i += num_threads) ptr[kNumRanks * kNumRanks + i] = 0; + memory_fence(); + __syncthreads(); - // Barrier after cleaning - barrier_device(task_fifo_ptrs, head, rank); + // Barrier after cleaning + barrier_device(task_fifo_ptrs, head, rank); } -void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int, - void** buffer_ptrs, int** task_fifo_ptrs, +void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int, void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream) { -#define CACHED_NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \ - LAUNCH_KERNEL(&cfg, cached_notify_dispatch, \ - rank_prefix_matrix, num_memset_int, buffer_ptrs, task_fifo_ptrs, head, rank); \ - break +#define CACHED_NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \ + LAUNCH_KERNEL(&cfg, cached_notify_dispatch, rank_prefix_matrix, num_memset_int, buffer_ptrs, task_fifo_ptrs, \ + head, rank); \ + break - SETUP_LAUNCH_CONFIG(1, 128, stream); - SWITCH_RANKS(CACHED_NOTIFY_DISPATCH_LAUNCH_CASE); + SETUP_LAUNCH_CONFIG(1, 128, stream); + SWITCH_RANKS(CACHED_NOTIFY_DISPATCH_LAUNCH_CASE); #undef CACHED_NOTIFY_DISPATCH_LAUNCH_CASE } template __global__ void __launch_bounds__(kNumThreads, 1) -dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* 
recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
-    int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
-    const bool* is_token_in_rank, const int* channel_prefix_matrix,
-    int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
-    void **buffer_ptrs, int rank,
-    int num_max_send_tokens, int num_recv_buffer_tokens) {
-    const auto num_sms = static_cast(gridDim.x), sm_id = static_cast(blockIdx.x);
-    const auto thread_id = static_cast(threadIdx.x);
-    const bool is_sender = sm_id % 2 == 0;
-    EP_DEVICE_ASSERT(num_sms % 2 == 0);
+    dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights,
+             int* recv_channel_offset, int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx,
+             const float* topk_weights, const bool* is_token_in_rank, const int* channel_prefix_matrix, int num_tokens,
+             int hidden_int4, int num_topk, int num_experts, int num_scales, void** buffer_ptrs, int rank,
+             int num_max_send_tokens, int num_recv_buffer_tokens) {
+  const auto num_sms = static_cast(gridDim.x), sm_id = static_cast(blockIdx.x);
+  const auto thread_id = static_cast(threadIdx.x);
+  const bool is_sender = sm_id % 2 == 0;
+  EP_DEVICE_ASSERT(num_sms % 2 == 0);

-    // Several warps are response for a single rank
-    const auto num_threads_per_rank = kNumThreads / kNumRanks;
-    const auto num_channels = num_sms / 2;
-    const auto responsible_rank = (static_cast(thread_id)) / num_threads_per_rank;
-    // Even-numbered blocks for sending, odd-numbered blocks for receiving.
-    const auto responsible_channel = sm_id / 2;
+  // Several warps are responsible for a single rank
+  const auto num_threads_per_rank = kNumThreads / kNumRanks;
+  const auto num_channels = num_sms / 2;
+  const auto responsible_rank = (static_cast(thread_id)) / num_threads_per_rank;
+  // Even-numbered blocks for sending, odd-numbered blocks for receiving.
+  const auto responsible_channel = sm_id / 2;

-    int num_experts_per_rank = num_experts / kNumRanks;
-    EP_DEVICE_ASSERT(num_experts_per_rank > 0 or num_topk == 0);
-    EP_DEVICE_ASSERT(num_topk <= 32);
-    EP_DEVICE_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr));
-    EP_DEVICE_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr));
+  int num_experts_per_rank = num_experts / kNumRanks;
+  EP_DEVICE_ASSERT(num_experts_per_rank > 0 or num_topk == 0);
+  EP_DEVICE_ASSERT(num_topk <= 32);
+  EP_DEVICE_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr));
+  EP_DEVICE_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr));

-    // Calculate pointers by the specific layout
-    // `rank_prefix_matrix`: kNumRanks * kNumRanks * sizeof(int)
-    auto ptr = reinterpret_cast(reinterpret_cast(buffer_ptrs[is_sender ? responsible_rank : rank]) + kNumRanks * kNumRanks * sizeof(int));
-    int target_rank = is_sender ? rank : responsible_rank;
-    auto num_channels_total = num_channels * kNumRanks;
-    auto channel_rank_offset = responsible_channel * kNumRanks + target_rank;
+  // Calculate pointers by the specific layout
+  // `rank_prefix_matrix`: kNumRanks * kNumRanks * sizeof(int)
+  auto ptr = reinterpret_cast(reinterpret_cast(buffer_ptrs[is_sender ? responsible_rank : rank]) +
+                              kNumRanks * kNumRanks * sizeof(int));
+  int target_rank = is_sender ? 
rank : responsible_rank; + auto num_channels_total = num_channels * kNumRanks; + auto channel_rank_offset = responsible_channel * kNumRanks + target_rank; - // Channel buffer metadata - // Senders are responsible for tails, and receivers are responsible for heads - // Stored on the receiver side - // The retired signals are actually boolean flags, but to align with 16 bytes, we make it `int64_t` - // `start_offset`: kNumChannels * kNumRanks * sizeof(int) - // `end_offset`: kNumChannels * kNumRanks * sizeof(int) - // `head_idx`: kNumChannels * kNumRanks * sizeof(int) - // `tail_idx`: kNumChannels * kNumRanks * sizeof(int) - auto channel_start_offset = Buffer(ptr, num_channels_total, channel_rank_offset); - auto channel_end_offset = Buffer(ptr, num_channels_total, channel_rank_offset); - auto channel_head_idx = Buffer(ptr, num_channels_total, channel_rank_offset); - auto channel_tail_idx = Buffer(ptr, num_channels_total, channel_rank_offset); + // Channel buffer metadata + // Senders are responsible for tails, and receivers are responsible for heads + // Stored on the receiver side + // The retired signals are actually boolean flags, but to align with 16 bytes, we make it `int64_t` + // `start_offset`: kNumChannels * kNumRanks * sizeof(int) + // `end_offset`: kNumChannels * kNumRanks * sizeof(int) + // `head_idx`: kNumChannels * kNumRanks * sizeof(int) + // `tail_idx`: kNumChannels * kNumRanks * sizeof(int) + auto channel_start_offset = Buffer(ptr, num_channels_total, channel_rank_offset); + auto channel_end_offset = Buffer(ptr, num_channels_total, channel_rank_offset); + auto channel_head_idx = Buffer(ptr, num_channels_total, channel_rank_offset); + auto channel_tail_idx = Buffer(ptr, num_channels_total, channel_rank_offset); - // Channel data buffers, stored on the receiver side - // `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4) - // `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int) - // `topk_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(int64_t) - // `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float) - // `x_scales_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_scales * sizeof(float) - auto channel_x_buffers = Buffer(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4); - auto channel_src_idx_buffers = Buffer(ptr, num_channels_total * num_recv_buffer_tokens, channel_rank_offset * num_recv_buffer_tokens); - auto channel_topk_idx_buffers = Buffer(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk); - auto channel_topk_weights_buffers = Buffer(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk); - auto channel_x_scales_buffers = Buffer(ptr, num_channels_total * num_recv_buffer_tokens * num_scales, channel_rank_offset * num_recv_buffer_tokens * num_scales); + // Channel data buffers, stored on the receiver side + // `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4) + // `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int) + // `topk_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(int64_t) + // `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float) + // 
`x_scales_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_scales * sizeof(float) + auto channel_x_buffers = Buffer(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, + channel_rank_offset * num_recv_buffer_tokens * hidden_int4); + auto channel_src_idx_buffers = + Buffer(ptr, num_channels_total * num_recv_buffer_tokens, channel_rank_offset * num_recv_buffer_tokens); + auto channel_topk_idx_buffers = Buffer(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, + channel_rank_offset * num_recv_buffer_tokens * num_topk); + auto channel_topk_weights_buffers = Buffer(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, + channel_rank_offset * num_recv_buffer_tokens * num_topk); + auto channel_x_scales_buffers = Buffer(ptr, num_channels_total * num_recv_buffer_tokens * num_scales, + channel_rank_offset * num_recv_buffer_tokens * num_scales); - if (is_sender) { - // Workers for sending - constexpr int num_send_warps = kNumThreads / 32; - constexpr int num_send_warps_per_rank = num_send_warps / kNumRanks; - const auto send_thread_id = thread_id; - const auto send_lane_id = send_thread_id % 32; - const auto send_warp_id_in_rank = send_thread_id % num_threads_per_rank / 32; - EP_DEVICE_ASSERT(kNumRanks <= 32); - EP_DEVICE_ASSERT(num_send_warps % kNumRanks == 0); + if (is_sender) { + // Workers for sending + constexpr int num_send_warps = kNumThreads / 32; + constexpr int num_send_warps_per_rank = num_send_warps / kNumRanks; + const auto send_thread_id = thread_id; + const auto send_lane_id = send_thread_id % 32; + const auto send_warp_id_in_rank = send_thread_id % num_threads_per_rank / 32; + EP_DEVICE_ASSERT(kNumRanks <= 32); + EP_DEVICE_ASSERT(num_send_warps % kNumRanks == 0); - // Send offset by `-value - 1`, e.g. 0 -> -1, 1 -> -2 - // NOTES: this is for distinguishing zero tokens - if (send_lane_id == 0 and send_warp_id_in_rank == 0) { - int value = responsible_channel > 0 ? 
channel_prefix_matrix[responsible_rank * num_channels + responsible_channel - 1] : 0; - st_relaxed_sys_global(channel_start_offset.buffer(), -value - 1); - value = channel_prefix_matrix[responsible_rank * num_channels + responsible_channel]; - st_relaxed_sys_global(channel_end_offset.buffer(), -value - 1); - } - __syncwarp(); - - // Get tasks - int token_start_idx, token_end_idx; - get_channel_task_range(num_tokens, num_channels, responsible_channel, token_start_idx, token_end_idx); - - // Iterate over all tokens and send by chunks - int cached_channel_tail_idx = 0; - for (int64_t token_idx = token_start_idx; token_idx < token_end_idx; ) { - // Check destination queue emptiness, or wait a buffer to be released (rare cases) - // NOTES: the head index received by different warps may not be the same - auto start_time = clock64(); - while (send_lane_id == 0) { - // NOTES: we only consider the worst case, because counting the real numbers are time-consuming - int num_used_slots = cached_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer()); - if (num_recv_buffer_tokens - num_used_slots >= num_max_send_tokens) - break; - - // Rare cases to loop again - if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { - printf("DeepEP timeout for dispatch senders, rank %d, responsible_channel = %d\n", rank, responsible_channel); - trap(); - } - } - __syncwarp(); - - int chunk_token_idx = 0; - while (chunk_token_idx < num_max_send_tokens and token_idx < token_end_idx) { - // NOTES: for the same token, the warp assigned to save `send_head` may be different from the warp assigned to send subsequent data - if (send_lane_id == 0 and token_idx % num_send_warps_per_rank == send_warp_id_in_rank) - send_head[token_idx * kNumRanks + responsible_rank] = is_token_in_rank[token_idx * kNumRanks + responsible_rank] ? cached_channel_tail_idx : -1; - - // Skip if not selected - if (not is_token_in_rank[token_idx * kNumRanks + responsible_rank]) { - token_idx ++; - continue; - } - - // Get an empty slot - int dst_slot_idx = (cached_channel_tail_idx ++) % num_recv_buffer_tokens; - if (cached_channel_tail_idx % num_send_warps_per_rank == send_warp_id_in_rank) { - // Copy data - auto shifted_channel_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4; - auto shifted_x = x + token_idx * hidden_int4; - UNROLLED_WARP_COPY(5, send_lane_id, hidden_int4, shifted_channel_x_buffers, shifted_x, - __ldg, st_na_global); - - // Copy source index - if (send_lane_id == 0) - channel_src_idx_buffers[dst_slot_idx] = static_cast(token_idx); - - // Copy `topk_idx` and `topk_weights` with transformed index - if (send_lane_id < num_topk) { - // Top-k index - int recv_expert_begin = responsible_rank * num_experts_per_rank, recv_expert_end = (responsible_rank + 1) * num_experts_per_rank; - auto idx_value = __ldg(topk_idx + token_idx * num_topk + send_lane_id); - idx_value = (idx_value >= recv_expert_begin and idx_value < recv_expert_end) ? idx_value - recv_expert_begin : -1; - channel_topk_idx_buffers[dst_slot_idx * num_topk + send_lane_id] = idx_value; - - // Top-k weights - auto weight_value = __ldg(topk_weights + token_idx * num_topk + send_lane_id); - weight_value = (idx_value >= 0) ? 
weight_value : 0.0f; - channel_topk_weights_buffers[dst_slot_idx * num_topk + send_lane_id] = weight_value; - } - - // Copy `x_scales` - #pragma unroll - for (int i = send_lane_id; i < num_scales; i += 32) - channel_x_scales_buffers[dst_slot_idx * num_scales + i] = __ldg(x_scales + token_idx * num_scales + i); - } - - // Move token index - chunk_token_idx ++, token_idx ++; - } - - // Move tail index - // NOTES: here all warps should share the same new tail - asm volatile("bar.sync %0, %1;" :: "r"(responsible_rank), "r"(num_threads_per_rank)); - if (send_warp_id_in_rank == 0 and send_lane_id == 0) - st_release_sys_global(channel_tail_idx.buffer(), cached_channel_tail_idx); - } - } else { - // Workers for receiving and copying into buffer - constexpr int num_recv_warps = kNumThreads / 32; - constexpr int num_recv_warps_per_rank = num_recv_warps / kNumRanks; - const auto recv_thread_id = thread_id; - const auto recv_lane_id = recv_thread_id % 32; - const auto recv_thread_id_in_rank = recv_thread_id % num_threads_per_rank; - const auto recv_warp_id_in_rank = recv_thread_id_in_rank / 32; - EP_DEVICE_ASSERT(kNumRanks <= 32); - EP_DEVICE_ASSERT(recv_thread_id >= 0 and num_recv_warps % kNumRanks == 0); - - // Calculate offset first - auto rank_prefix_matrix = reinterpret_cast(buffer_ptrs[rank]); - int rank_offset = responsible_rank > 0 ? rank_prefix_matrix[(responsible_rank - 1) * kNumRanks + rank] : 0; - - // Receive channel offset - int total_offset, num_tokens_to_recv; - while (recv_lane_id == 0 and (total_offset = ld_volatile_global(channel_start_offset.buffer())) == 0); - while (recv_lane_id == 0 and (num_tokens_to_recv = ld_volatile_global(channel_end_offset.buffer())) == 0); - if (recv_lane_id == 0) { - total_offset = -total_offset - 1, num_tokens_to_recv = -num_tokens_to_recv - 1; - if (recv_warp_id_in_rank == 0) - recv_channel_offset[responsible_rank * num_channels + responsible_channel] = total_offset; - num_tokens_to_recv -= total_offset; - } - total_offset = __shfl_sync(0xffffffff, total_offset, 0); - total_offset += rank_offset; - num_tokens_to_recv = __shfl_sync(0xffffffff, num_tokens_to_recv, 0); - - // Shared tail indices for different warps - __shared__ volatile int shared_channel_tail_idx[kNumRanks]; - - auto start_time = clock64(); - int cached_channel_head_idx = 0, cached_channel_tail_idx = 0; - while (num_tokens_to_recv > 0) { - // NOTES: unlike the sender, the receiver must ensure that the tail indices hold by different warps are same - while (recv_thread_id_in_rank == 0) { - cached_channel_tail_idx = ld_acquire_sys_global(channel_tail_idx.buffer());; - - // Ready to copy - if (cached_channel_head_idx != cached_channel_tail_idx) { - shared_channel_tail_idx[responsible_rank] = cached_channel_tail_idx; - break; - } - - // Timeout check - if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { - printf("DeepEP timeout for dispatch receivers, rank %d, responsible_channel = %d, tokens remained: %d\n", rank, responsible_channel, num_tokens_to_recv); - trap(); - } - } - - // Synchronize queue tail - asm volatile("bar.sync %0, %1;" :: "r"(responsible_rank), "r"(num_threads_per_rank)); - cached_channel_tail_idx = shared_channel_tail_idx[responsible_rank]; - - // Copy data - int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx; - for (int chunk_idx = recv_warp_id_in_rank; chunk_idx < num_recv_tokens; chunk_idx += num_recv_warps_per_rank) { - int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens; - auto shifted_buffer_x_int4 = 
channel_x_buffers.buffer() + token_idx_in_buffer * hidden_int4; - auto shifted_recv_x_int4 = recv_x + static_cast(total_offset + chunk_idx) * hidden_int4; - UNROLLED_WARP_COPY(5, recv_lane_id, hidden_int4, shifted_recv_x_int4, shifted_buffer_x_int4, - ld_nc_global, st_na_global); - } - - // Copy `src_idx` - #pragma unroll 4 - for (int chunk_idx = cached_channel_head_idx + recv_thread_id_in_rank; chunk_idx < cached_channel_tail_idx; chunk_idx += 32 * num_recv_warps_per_rank) - recv_src_idx[total_offset + chunk_idx - cached_channel_head_idx] = ld_nc_global(channel_src_idx_buffers.buffer() + chunk_idx % num_recv_buffer_tokens); - - // Copy `topk_idx` and `topk_weights` - #pragma unroll 4 - for (int idx = recv_thread_id_in_rank; idx < num_recv_tokens * num_topk; idx += 32 * num_recv_warps_per_rank) { - int chunk_idx = idx / num_topk, token_topk_idx = idx % num_topk; - int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens; - auto recv_idx = static_cast(total_offset + chunk_idx) * num_topk + token_topk_idx; - auto buffer_idx = token_idx_in_buffer * num_topk + token_topk_idx; - recv_topk_idx[recv_idx] = ld_nc_global(channel_topk_idx_buffers.buffer() + buffer_idx); - recv_topk_weights[recv_idx] = ld_nc_global(channel_topk_weights_buffers.buffer() + buffer_idx); - } - - // Copy `x_scales` - #pragma unroll 4 - for (int i = recv_thread_id_in_rank; i < num_recv_tokens * num_scales; i += 32 * num_recv_warps_per_rank) { - int chunk_idx = i / num_scales, scales_idx = i % num_scales; - int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens; - recv_x_scales[static_cast(total_offset + chunk_idx) * num_scales + scales_idx] = - ld_nc_global(channel_x_scales_buffers.buffer() + token_idx_in_buffer * num_scales + scales_idx); - } - - // Move queue - cached_channel_head_idx += num_recv_tokens; - total_offset += num_recv_tokens; - asm volatile("bar.sync %0, %1;" :: "r"(responsible_rank), "r"(num_threads_per_rank)); - if (recv_warp_id_in_rank == num_recv_warps_per_rank - 1 and recv_lane_id == 0) - st_relaxed_sys_global(channel_head_idx.buffer(), cached_channel_head_idx); - - // Exit - num_tokens_to_recv -= num_recv_tokens; - } + // Send offset by `-value - 1`, e.g. 0 -> -1, 1 -> -2 + // NOTES: this is for distinguishing zero tokens + if (send_lane_id == 0 and send_warp_id_in_rank == 0) { + int value = responsible_channel > 0 + ? 
channel_prefix_matrix[responsible_rank * num_channels + responsible_channel - 1] + : 0; + st_relaxed_sys_global(channel_start_offset.buffer(), -value - 1); + value = channel_prefix_matrix[responsible_rank * num_channels + responsible_channel]; + st_relaxed_sys_global(channel_end_offset.buffer(), -value - 1); } + __syncwarp(); + + // Get tasks + int token_start_idx, token_end_idx; + get_channel_task_range(num_tokens, num_channels, responsible_channel, token_start_idx, token_end_idx); + + // Iterate over all tokens and send by chunks + int cached_channel_tail_idx = 0; + for (int64_t token_idx = token_start_idx; token_idx < token_end_idx;) { + // Check destination queue emptiness, or wait a buffer to be released (rare cases) + // NOTES: the head index received by different warps may not be the same + auto start_time = clock64(); + while (send_lane_id == 0) { + // NOTES: we only consider the worst case, because counting the real numbers are time-consuming + int num_used_slots = cached_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer()); + if (num_recv_buffer_tokens - num_used_slots >= num_max_send_tokens) break; + + // Rare cases to loop again + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf("DeepEP timeout for dispatch senders, rank %d, responsible_channel = %d\n", rank, responsible_channel); + trap(); + } + } + __syncwarp(); + + int chunk_token_idx = 0; + while (chunk_token_idx < num_max_send_tokens and token_idx < token_end_idx) { + // NOTES: for the same token, the warp assigned to save `send_head` may be different from the warp assigned to + // send subsequent data + if (send_lane_id == 0 and token_idx % num_send_warps_per_rank == send_warp_id_in_rank) + send_head[token_idx * kNumRanks + responsible_rank] = + is_token_in_rank[token_idx * kNumRanks + responsible_rank] ? cached_channel_tail_idx : -1; + + // Skip if not selected + if (not is_token_in_rank[token_idx * kNumRanks + responsible_rank]) { + token_idx++; + continue; + } + + // Get an empty slot + int dst_slot_idx = (cached_channel_tail_idx++) % num_recv_buffer_tokens; + if (cached_channel_tail_idx % num_send_warps_per_rank == send_warp_id_in_rank) { + // Copy data + auto shifted_channel_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4; + auto shifted_x = x + token_idx * hidden_int4; + UNROLLED_WARP_COPY(5, send_lane_id, hidden_int4, shifted_channel_x_buffers, shifted_x, __ldg, st_na_global); + + // Copy source index + if (send_lane_id == 0) channel_src_idx_buffers[dst_slot_idx] = static_cast(token_idx); + + // Copy `topk_idx` and `topk_weights` with transformed index + if (send_lane_id < num_topk) { + // Top-k index + int recv_expert_begin = responsible_rank * num_experts_per_rank, + recv_expert_end = (responsible_rank + 1) * num_experts_per_rank; + auto idx_value = __ldg(topk_idx + token_idx * num_topk + send_lane_id); + idx_value = + (idx_value >= recv_expert_begin and idx_value < recv_expert_end) ? idx_value - recv_expert_begin : -1; + channel_topk_idx_buffers[dst_slot_idx * num_topk + send_lane_id] = idx_value; + + // Top-k weights + auto weight_value = __ldg(topk_weights + token_idx * num_topk + send_lane_id); + weight_value = (idx_value >= 0) ? 
weight_value : 0.0f; + channel_topk_weights_buffers[dst_slot_idx * num_topk + send_lane_id] = weight_value; + } + +// Copy `x_scales` +#pragma unroll + for (int i = send_lane_id; i < num_scales; i += 32) + channel_x_scales_buffers[dst_slot_idx * num_scales + i] = __ldg(x_scales + token_idx * num_scales + i); + } + + // Move token index + chunk_token_idx++, token_idx++; + } + + // Move tail index + // NOTES: here all warps should share the same new tail + asm volatile("bar.sync %0, %1;" ::"r"(responsible_rank), "r"(num_threads_per_rank)); + if (send_warp_id_in_rank == 0 and send_lane_id == 0) + st_release_sys_global(channel_tail_idx.buffer(), cached_channel_tail_idx); + } + } else { + // Workers for receiving and copying into buffer + constexpr int num_recv_warps = kNumThreads / 32; + constexpr int num_recv_warps_per_rank = num_recv_warps / kNumRanks; + const auto recv_thread_id = thread_id; + const auto recv_lane_id = recv_thread_id % 32; + const auto recv_thread_id_in_rank = recv_thread_id % num_threads_per_rank; + const auto recv_warp_id_in_rank = recv_thread_id_in_rank / 32; + EP_DEVICE_ASSERT(kNumRanks <= 32); + EP_DEVICE_ASSERT(recv_thread_id >= 0 and num_recv_warps % kNumRanks == 0); + + // Calculate offset first + auto rank_prefix_matrix = reinterpret_cast(buffer_ptrs[rank]); + int rank_offset = responsible_rank > 0 ? rank_prefix_matrix[(responsible_rank - 1) * kNumRanks + rank] : 0; + + // Receive channel offset + int total_offset, num_tokens_to_recv; + while (recv_lane_id == 0 and (total_offset = ld_volatile_global(channel_start_offset.buffer())) == 0) + ; + while (recv_lane_id == 0 and (num_tokens_to_recv = ld_volatile_global(channel_end_offset.buffer())) == 0) + ; + if (recv_lane_id == 0) { + total_offset = -total_offset - 1, num_tokens_to_recv = -num_tokens_to_recv - 1; + if (recv_warp_id_in_rank == 0) + recv_channel_offset[responsible_rank * num_channels + responsible_channel] = total_offset; + num_tokens_to_recv -= total_offset; + } + total_offset = __shfl_sync(0xffffffff, total_offset, 0); + total_offset += rank_offset; + num_tokens_to_recv = __shfl_sync(0xffffffff, num_tokens_to_recv, 0); + + // Shared tail indices for different warps + __shared__ volatile int shared_channel_tail_idx[kNumRanks]; + + auto start_time = clock64(); + int cached_channel_head_idx = 0, cached_channel_tail_idx = 0; + while (num_tokens_to_recv > 0) { + // NOTES: unlike the sender, the receiver must ensure that the tail indices hold by different warps are same + while (recv_thread_id_in_rank == 0) { + cached_channel_tail_idx = ld_acquire_sys_global(channel_tail_idx.buffer()); + ; + + // Ready to copy + if (cached_channel_head_idx != cached_channel_tail_idx) { + shared_channel_tail_idx[responsible_rank] = cached_channel_tail_idx; + break; + } + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf("DeepEP timeout for dispatch receivers, rank %d, responsible_channel = %d, tokens remained: %d\n", + rank, responsible_channel, num_tokens_to_recv); + trap(); + } + } + + // Synchronize queue tail + asm volatile("bar.sync %0, %1;" ::"r"(responsible_rank), "r"(num_threads_per_rank)); + cached_channel_tail_idx = shared_channel_tail_idx[responsible_rank]; + + // Copy data + int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx; + for (int chunk_idx = recv_warp_id_in_rank; chunk_idx < num_recv_tokens; chunk_idx += num_recv_warps_per_rank) { + int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens; + auto shifted_buffer_x_int4 = 
channel_x_buffers.buffer() + token_idx_in_buffer * hidden_int4; + auto shifted_recv_x_int4 = recv_x + static_cast(total_offset + chunk_idx) * hidden_int4; + UNROLLED_WARP_COPY(5, recv_lane_id, hidden_int4, shifted_recv_x_int4, shifted_buffer_x_int4, ld_nc_global, + st_na_global); + } + +// Copy `src_idx` +#pragma unroll 4 + for (int chunk_idx = cached_channel_head_idx + recv_thread_id_in_rank; chunk_idx < cached_channel_tail_idx; + chunk_idx += 32 * num_recv_warps_per_rank) + recv_src_idx[total_offset + chunk_idx - cached_channel_head_idx] = + ld_nc_global(channel_src_idx_buffers.buffer() + chunk_idx % num_recv_buffer_tokens); + +// Copy `topk_idx` and `topk_weights` +#pragma unroll 4 + for (int idx = recv_thread_id_in_rank; idx < num_recv_tokens * num_topk; idx += 32 * num_recv_warps_per_rank) { + int chunk_idx = idx / num_topk, token_topk_idx = idx % num_topk; + int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens; + auto recv_idx = static_cast(total_offset + chunk_idx) * num_topk + token_topk_idx; + auto buffer_idx = token_idx_in_buffer * num_topk + token_topk_idx; + recv_topk_idx[recv_idx] = ld_nc_global(channel_topk_idx_buffers.buffer() + buffer_idx); + recv_topk_weights[recv_idx] = ld_nc_global(channel_topk_weights_buffers.buffer() + buffer_idx); + } + +// Copy `x_scales` +#pragma unroll 4 + for (int i = recv_thread_id_in_rank; i < num_recv_tokens * num_scales; i += 32 * num_recv_warps_per_rank) { + int chunk_idx = i / num_scales, scales_idx = i % num_scales; + int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens; + recv_x_scales[static_cast(total_offset + chunk_idx) * num_scales + scales_idx] = + ld_nc_global(channel_x_scales_buffers.buffer() + token_idx_in_buffer * num_scales + scales_idx); + } + + // Move queue + cached_channel_head_idx += num_recv_tokens; + total_offset += num_recv_tokens; + asm volatile("bar.sync %0, %1;" ::"r"(responsible_rank), "r"(num_threads_per_rank)); + if (recv_warp_id_in_rank == num_recv_warps_per_rank - 1 and recv_lane_id == 0) + st_relaxed_sys_global(channel_head_idx.buffer(), cached_channel_head_idx); + + // Exit + num_tokens_to_recv -= num_recv_tokens; + } + } } -void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset, - int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights, - const bool* is_token_in_rank, const int* channel_prefix_matrix, - int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales, - void** buffer_ptrs, int rank, int num_ranks, - cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) { - constexpr int kNumThreads = 512; +void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, + int* recv_channel_offset, int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, + const float* topk_weights, const bool* is_token_in_rank, const int* channel_prefix_matrix, int num_tokens, + int hidden_int4, int num_topk, int num_experts, int num_scales, void** buffer_ptrs, int rank, + int num_ranks, cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) { + constexpr int kNumThreads = 512; -#define DISPATCH_LAUNCH_CASE(ranks) \ -LAUNCH_KERNEL(&cfg, dispatch, \ - reinterpret_cast(recv_x), recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \ - 
send_head, reinterpret_cast(x), x_scales, topk_idx, topk_weights, \ - is_token_in_rank, channel_prefix_matrix, \ - num_tokens, hidden_int4, num_topk, num_experts, num_scales, \ - buffer_ptrs, rank, \ - num_max_send_tokens, num_recv_buffer_tokens); \ -break +#define DISPATCH_LAUNCH_CASE(ranks) \ + LAUNCH_KERNEL(&cfg, dispatch, reinterpret_cast(recv_x), recv_x_scales, recv_src_idx, \ + recv_topk_idx, recv_topk_weights, recv_channel_offset, send_head, reinterpret_cast(x), \ + x_scales, topk_idx, topk_weights, is_token_in_rank, channel_prefix_matrix, num_tokens, hidden_int4, \ + num_topk, num_experts, num_scales, buffer_ptrs, rank, num_max_send_tokens, num_recv_buffer_tokens); \ + break - // Even-numbered blocks for sending, odd-numbered blocks for receiving. - EP_HOST_ASSERT(num_sms % 2 == 0); - SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream); - SWITCH_RANKS(DISPATCH_LAUNCH_CASE); + // Even-numbered blocks for sending, odd-numbered blocks for receiving. + EP_HOST_ASSERT(num_sms % 2 == 0); + SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream); + SWITCH_RANKS(DISPATCH_LAUNCH_CASE); #undef DISPATCH_LAUNCH_CASE } -template -__global__ void -cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int, - int** task_fifo_ptrs, int head, int rank) { - const auto sm_id = static_cast(blockIdx.x); - if (sm_id == 0) { - // Barrier before cleaning - barrier_device(task_fifo_ptrs, head, rank); - move_fifo_slots(head); - __syncthreads(); +template +__global__ void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, + int num_memset_int, int** task_fifo_ptrs, int head, int rank) { + const auto sm_id = static_cast(blockIdx.x); + if (sm_id == 0) { + // Barrier before cleaning + barrier_device(task_fifo_ptrs, head, rank); + move_fifo_slots(head); + __syncthreads(); - // Clean - auto thread_id = static_cast(threadIdx.x), num_threads = static_cast(blockDim.x); - auto ptr = reinterpret_cast(buffer_ptrs[rank]); - #pragma unroll - for (int i = thread_id; i < num_memset_int; i += num_threads) - ptr[i] = 0; - memory_fence(); - __syncthreads(); + // Clean + auto thread_id = static_cast(threadIdx.x), num_threads = static_cast(blockDim.x); + auto ptr = reinterpret_cast(buffer_ptrs[rank]); +#pragma unroll + for (int i = thread_id; i < num_memset_int; i += num_threads) ptr[i] = 0; + memory_fence(); + __syncthreads(); - // Barrier after cleaning - barrier_device(task_fifo_ptrs, head, rank); - } else { - const auto channel_id = sm_id - 1; - const auto thread_id = static_cast(threadIdx.x); - const auto rank_id = thread_id / 32; - const auto lane_id = thread_id % 32; - if (rank_id >= kNumRanks) - return; + // Barrier after cleaning + barrier_device(task_fifo_ptrs, head, rank); + } else { + const auto channel_id = sm_id - 1; + const auto thread_id = static_cast(threadIdx.x); + const auto rank_id = thread_id / 32; + const auto lane_id = thread_id % 32; + if (rank_id >= kNumRanks) return; - int token_start_idx, token_end_idx; - get_channel_task_range(num_recv_tokens, num_channels, channel_id, token_start_idx, token_end_idx); + int token_start_idx, token_end_idx; + get_channel_task_range(num_recv_tokens, num_channels, channel_id, token_start_idx, token_end_idx); - // NOTES: `1 << 25` is a heuristic large number - int last_head = 1 << 25; - #pragma unroll - for (int token_idx_tail = token_end_idx - 1; token_idx_tail >= token_start_idx; token_idx_tail -= 32) { - int token_idx = token_idx_tail - lane_id, expected_head = 0; - auto 
current_head = (token_idx >= token_start_idx) ? __ldg(send_head + token_idx * kNumRanks + rank_id) : -1; - for (int i = 0; i < min(32, token_idx_tail - token_start_idx + 1); ++ i) { - head = __shfl_sync(0xffffffff, current_head, i); - if (head < 0) { - if (lane_id == i) - expected_head = -last_head - 1; - } else { - last_head = head; - } - } - if (current_head < 0 and token_idx >= token_start_idx) - send_head[token_idx * kNumRanks + rank_id] = expected_head; + // NOTES: `1 << 25` is a heuristic large number + int last_head = 1 << 25; +#pragma unroll + for (int token_idx_tail = token_end_idx - 1; token_idx_tail >= token_start_idx; token_idx_tail -= 32) { + int token_idx = token_idx_tail - lane_id, expected_head = 0; + auto current_head = (token_idx >= token_start_idx) ? __ldg(send_head + token_idx * kNumRanks + rank_id) : -1; + for (int i = 0; i < min(32, token_idx_tail - token_start_idx + 1); ++i) { + head = __shfl_sync(0xffffffff, current_head, i); + if (head < 0) { + if (lane_id == i) expected_head = -last_head - 1; + } else { + last_head = head; } + } + if (current_head < 0 and token_idx >= token_start_idx) send_head[token_idx * kNumRanks + rank_id] = expected_head; } + } } -void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, - int num_recv_tokens, int num_memset_int, - int** task_fifo_ptrs, int head, int rank, int num_ranks, +void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, + int num_memset_int, int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream) { -#define CACHED_NOTIFY_COMBINE(ranks) \ - LAUNCH_KERNEL(&cfg, cached_notify_combine, \ - buffer_ptrs, send_head, num_channels, num_recv_tokens, num_memset_int, task_fifo_ptrs, head, rank); \ - break +#define CACHED_NOTIFY_COMBINE(ranks) \ + LAUNCH_KERNEL(&cfg, cached_notify_combine, buffer_ptrs, send_head, num_channels, num_recv_tokens, \ + num_memset_int, task_fifo_ptrs, head, rank); \ + break - const int num_threads = std::max(128, 32 * num_ranks); - EP_HOST_ASSERT(num_ranks <= num_threads); - EP_HOST_ASSERT(num_threads <= 1024); - EP_HOST_ASSERT(1 + num_channels <= num_channels * 2); - SETUP_LAUNCH_CONFIG(1 + num_channels, num_threads, stream); - SWITCH_RANKS(CACHED_NOTIFY_COMBINE); + const int num_threads = std::max(128, 32 * num_ranks); + EP_HOST_ASSERT(num_ranks <= num_threads); + EP_HOST_ASSERT(num_threads <= 1024); + EP_HOST_ASSERT(1 + num_channels <= num_channels * 2); + SETUP_LAUNCH_CONFIG(1 + num_channels, num_threads, stream); + SWITCH_RANKS(CACHED_NOTIFY_COMBINE); #undef CACHED_NOTIFY_COMBINE } -template +template __global__ void __launch_bounds__(kNumThreads, 1) -combine(dtype_t* recv_x, float* recv_topk_weights, - const dtype_t* x, const float* topk_weights, - const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix, - int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk, - void **buffer_ptrs, int rank, - int num_max_send_tokens, int num_recv_buffer_tokens) { - const auto num_sms = static_cast(gridDim.x); - const auto thread_id = static_cast(threadIdx.x); - const auto sm_id = static_cast(blockIdx.x); - const auto num_channels = num_sms / 2; - const bool is_sender = sm_id % 2 == 0; - const int responsible_channel = sm_id / 2; - EP_DEVICE_ASSERT(num_topk <= 32); + combine(dtype_t* recv_x, float* recv_topk_weights, const dtype_t* x, const float* topk_weights, const int* src_idx, + const int* rank_prefix_matrix, const int* channel_prefix_matrix, int* send_head, int 
num_tokens, + int num_recv_tokens, int hidden, int num_topk, void** buffer_ptrs, int rank, int num_max_send_tokens, + int num_recv_buffer_tokens) { + const auto num_sms = static_cast(gridDim.x); + const auto thread_id = static_cast(threadIdx.x); + const auto sm_id = static_cast(blockIdx.x); + const auto num_channels = num_sms / 2; + const bool is_sender = sm_id % 2 == 0; + const int responsible_channel = sm_id / 2; + EP_DEVICE_ASSERT(num_topk <= 32); - constexpr int kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t); - int hidden_int4 = hidden * sizeof(dtype_t) / sizeof(int4); - auto x_int4 = reinterpret_cast(x); - auto recv_int4 = reinterpret_cast(recv_x); + constexpr int kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t); + int hidden_int4 = hidden * sizeof(dtype_t) / sizeof(int4); + auto x_int4 = reinterpret_cast(x); + auto recv_int4 = reinterpret_cast(recv_x); - if (is_sender) { - // Workers for sending - // Several warps are responsible for a single rank - constexpr int num_send_warps = kNumThreads / 32; - constexpr int num_send_warps_per_rank = num_send_warps / kNumRanks; - const auto num_threads_per_rank = num_send_warps_per_rank * 32; - const auto send_thread_id = thread_id; - const auto send_lane_id = send_thread_id % 32; - const auto send_rank_id = thread_id / num_threads_per_rank; - const auto send_warp_id_in_rank = send_thread_id % num_threads_per_rank / 32; + if (is_sender) { + // Workers for sending + // Several warps are responsible for a single rank + constexpr int num_send_warps = kNumThreads / 32; + constexpr int num_send_warps_per_rank = num_send_warps / kNumRanks; + const auto num_threads_per_rank = num_send_warps_per_rank * 32; + const auto send_thread_id = thread_id; + const auto send_lane_id = send_thread_id % 32; + const auto send_rank_id = thread_id / num_threads_per_rank; + const auto send_warp_id_in_rank = send_thread_id % num_threads_per_rank / 32; - // Calculate pointers by the specific layout - auto ptr = reinterpret_cast(reinterpret_cast(buffer_ptrs[send_rank_id])); - auto num_channels_total = num_channels * kNumRanks; - auto channel_rank_offset = responsible_channel * kNumRanks + rank; + // Calculate pointers by the specific layout + auto ptr = reinterpret_cast(reinterpret_cast(buffer_ptrs[send_rank_id])); + auto num_channels_total = num_channels * kNumRanks; + auto channel_rank_offset = responsible_channel * kNumRanks + rank; - // Channel meta data - // `head_idx`: kNumChannels * kNumRanks * sizeof(int) - // `tail_idx`: kNumChannels * kNumRanks * sizeof(int) - // `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4) - // `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int) - // `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float) - auto channel_head_idx = Buffer(ptr, num_channels_total, channel_rank_offset); - auto channel_tail_idx = Buffer(ptr, num_channels_total, channel_rank_offset); - auto channel_x_buffers = Buffer(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4); - auto channel_src_idx_buffers = Buffer(ptr, num_channels_total * num_recv_buffer_tokens, channel_rank_offset * num_recv_buffer_tokens); - auto channel_topk_weights_buffers = Buffer(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk); + // Channel meta data + // `head_idx`: kNumChannels * kNumRanks * sizeof(int) + // `tail_idx`: kNumChannels * 
kNumRanks * sizeof(int) + // `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4) + // `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int) + // `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float) + auto channel_head_idx = Buffer(ptr, num_channels_total, channel_rank_offset); + auto channel_tail_idx = Buffer(ptr, num_channels_total, channel_rank_offset); + auto channel_x_buffers = Buffer(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, + channel_rank_offset * num_recv_buffer_tokens * hidden_int4); + auto channel_src_idx_buffers = + Buffer(ptr, num_channels_total * num_recv_buffer_tokens, channel_rank_offset * num_recv_buffer_tokens); + auto channel_topk_weights_buffers = Buffer(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, + channel_rank_offset * num_recv_buffer_tokens * num_topk); - // Get tasks - // NOTES: `channel_offset` is already shifted - int rank_offset = send_rank_id > 0 ? rank_prefix_matrix[(send_rank_id - 1) * kNumRanks + rank] : 0; - int num_rank_tokens = rank_prefix_matrix[send_rank_id * kNumRanks + rank] - rank_offset; - int channel_offset = channel_prefix_matrix[send_rank_id * num_channels + responsible_channel]; - int num_channel_tokens = (responsible_channel == num_channels - 1 ? num_rank_tokens : channel_prefix_matrix[send_rank_id * num_channels + responsible_channel + 1]) - channel_offset; - int token_start_idx = rank_offset + channel_offset, token_end_idx = rank_offset + channel_offset + num_channel_tokens; + // Get tasks + // NOTES: `channel_offset` is already shifted + int rank_offset = send_rank_id > 0 ? rank_prefix_matrix[(send_rank_id - 1) * kNumRanks + rank] : 0; + int num_rank_tokens = rank_prefix_matrix[send_rank_id * kNumRanks + rank] - rank_offset; + int channel_offset = channel_prefix_matrix[send_rank_id * num_channels + responsible_channel]; + int num_channel_tokens = (responsible_channel == num_channels - 1 + ? 
num_rank_tokens + : channel_prefix_matrix[send_rank_id * num_channels + responsible_channel + 1]) - + channel_offset; + int token_start_idx = rank_offset + channel_offset, + token_end_idx = rank_offset + channel_offset + num_channel_tokens; - // Iterate over all tokens and send by chunks - int current_channel_tail_idx = 0; - for (int64_t token_idx = token_start_idx; token_idx < token_end_idx; ) { - // Check destination queue emptiness, or wait a buffer to be released (rare cases) - auto start_time = clock64(); - int num_round_tokens = min(num_max_send_tokens, token_end_idx - static_cast(token_idx)); - while (send_lane_id == 0) { - // NOTES: we only consider the worst case, because counting the real numbers are time-consuming - int num_used_slots = current_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer()); - if (num_recv_buffer_tokens - num_used_slots >= num_round_tokens) - break; + // Iterate over all tokens and send by chunks + int current_channel_tail_idx = 0; + for (int64_t token_idx = token_start_idx; token_idx < token_end_idx;) { + // Check destination queue emptiness, or wait a buffer to be released (rare cases) + auto start_time = clock64(); + int num_round_tokens = min(num_max_send_tokens, token_end_idx - static_cast(token_idx)); + while (send_lane_id == 0) { + // NOTES: we only consider the worst case, because counting the real numbers are time-consuming + int num_used_slots = current_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer()); + if (num_recv_buffer_tokens - num_used_slots >= num_round_tokens) break; - // Rare cases to loop again - if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { - printf("DeepEP timeout for combine senders, rank %d, responsible_channel = %d\n", rank, responsible_channel); - trap(); - } - } - __syncwarp(); - - // Send by chunk - #pragma unroll - for (int i = send_warp_id_in_rank; i < num_round_tokens; i += num_send_warps_per_rank) { - // Get an empty slot - int dst_slot_idx = (current_channel_tail_idx + i) % num_recv_buffer_tokens; - - // Copy data - auto shifted_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4; - auto shifted_x = x_int4 + (token_idx + i) * hidden_int4; - UNROLLED_WARP_COPY(4, send_lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global); - - // Send source index - if (send_lane_id == 0) - channel_src_idx_buffers[dst_slot_idx] = __ldg(src_idx + token_idx + i); - - // Send `topk_weights` - if (num_topk > 0 and send_lane_id < num_topk) - channel_topk_weights_buffers[dst_slot_idx * num_topk + send_lane_id] = __ldg(topk_weights + (token_idx + i) * num_topk + send_lane_id); - } - token_idx += num_round_tokens; - current_channel_tail_idx += num_round_tokens; - - // Move tail index - asm volatile("bar.sync %0, %1;" :: "r"(send_rank_id), "r"(num_threads_per_rank)); - if (send_lane_id == 0 and send_warp_id_in_rank == 0) - st_release_sys_global(channel_tail_idx.buffer(), current_channel_tail_idx); + // Rare cases to loop again + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf("DeepEP timeout for combine senders, rank %d, responsible_channel = %d\n", rank, responsible_channel); + trap(); } - } else { - // Workers for receiving - // One warp for moving the queue head, others for reduction - constexpr int num_recv_warps = kNumThreads / 32; - const auto recv_warp_id = thread_id / 32; - const auto recv_lane_id = thread_id % 32; - EP_DEVICE_ASSERT(kNumRanks <= 32 and kNumThreads > 32); - EP_DEVICE_ASSERT(thread_id >= 0 and kNumThreads % 32 == 0); + } + __syncwarp(); - // 
Shared head, tail and retired flags for receiver warps - __shared__ volatile int warp_channel_head_idx[num_recv_warps][kNumRanks]; - __shared__ volatile int channel_tail_idx[kNumRanks]; - __shared__ volatile bool warp_retired[num_recv_warps]; - if (thread_id < num_recv_warps) - warp_retired[thread_id] = false; - if (recv_lane_id < kNumRanks) - warp_channel_head_idx[recv_warp_id][recv_lane_id] = 0; - if (thread_id < kNumRanks) - channel_tail_idx[thread_id] = 0; - asm volatile("bar.sync 0, %0;" :: "r"(kNumThreads)); +// Send by chunk +#pragma unroll + for (int i = send_warp_id_in_rank; i < num_round_tokens; i += num_send_warps_per_rank) { + // Get an empty slot + int dst_slot_idx = (current_channel_tail_idx + i) % num_recv_buffer_tokens; - if (thread_id < 32) { - int* channel_head_idx_ptr = reinterpret_cast(buffer_ptrs[rank]) + responsible_channel * kNumRanks + recv_lane_id; - int* channel_tail_idx_ptr = channel_head_idx_ptr + num_channels * kNumRanks; + // Copy data + auto shifted_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4; + auto shifted_x = x_int4 + (token_idx + i) * hidden_int4; + UNROLLED_WARP_COPY(4, send_lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global); - // Queue head updater - int last_head = 0; - while (recv_lane_id < kNumRanks) { - // Check retired - bool retired = true; - #pragma unroll - for (int i = 1; i < num_recv_warps; ++ i) - retired = retired and warp_retired[i]; - if (retired) - break; + // Send source index + if (send_lane_id == 0) channel_src_idx_buffers[dst_slot_idx] = __ldg(src_idx + token_idx + i); - // Update queue tail - channel_tail_idx[recv_lane_id] = ld_acquire_sys_global(channel_tail_idx_ptr); + // Send `topk_weights` + if (num_topk > 0 and send_lane_id < num_topk) + channel_topk_weights_buffers[dst_slot_idx * num_topk + send_lane_id] = + __ldg(topk_weights + (token_idx + i) * num_topk + send_lane_id); + } + token_idx += num_round_tokens; + current_channel_tail_idx += num_round_tokens; - // Update minimum head - int min_head = std::numeric_limits::max(); - #pragma unroll - for (int i = 1; i < num_recv_warps; ++ i) if (not warp_retired[i]) - min_head = min(min_head, warp_channel_head_idx[i][recv_lane_id]); - if (min_head != std::numeric_limits::max() and min_head > last_head) - st_relaxed_sys_global(channel_head_idx_ptr, last_head = min_head); - } - } else { - // Receivers - // Channel metadata - // All lanes will use data buffer, but only rank lane will use `head/tail/src_idx` - Buffer channel_x_buffers[kNumRanks]; - Buffer channel_topk_weights_buffers[kNumRanks]; - - // Calculate pointers by the specific layout - #pragma unroll - for (int i = 0; i < kNumRanks; ++ i) { - auto channel_rank_offset = responsible_channel * kNumRanks + i; - auto num_channels_total = num_channels * kNumRanks; - // `head_idx` & `tail_idx`: kNumChannels * kNumRanks * sizeof(int) - auto ptr = reinterpret_cast(reinterpret_cast(buffer_ptrs[rank]) + 2 * num_channels * kNumRanks * sizeof(int)); - - // `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4) - channel_x_buffers[i] = Buffer(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4); - - // `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int) - ptr = reinterpret_cast(reinterpret_cast(ptr) + num_channels_total * num_recv_buffer_tokens * sizeof(int)); - - // `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * 
sizeof(float) - channel_topk_weights_buffers[i] = Buffer(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk); - } - - // The same tokens as the dispatch process - int token_start_idx, token_end_idx; - get_channel_task_range(num_recv_tokens, num_channels, responsible_channel, token_start_idx, token_end_idx); - - // Iterate over all tokens and combine - for (int64_t token_idx = token_start_idx + recv_warp_id - 1; token_idx < token_end_idx; token_idx += num_recv_warps - 1) { - // Read expected head - int expected_head = -1; - if (recv_lane_id < kNumRanks) - expected_head = ld_nc_global(send_head + token_idx * kNumRanks + recv_lane_id); - - auto start_time = clock64(); - while (expected_head >= 0 and channel_tail_idx[recv_lane_id] <= expected_head) { - // Timeout check - if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { - printf("DeepEP timeout for combine receivers, rank %d, responsible_channel = %d, expect = %d\n", rank, responsible_channel, expected_head); - trap(); - } - } - __syncwarp(); - - // Broadcast current heads - int num_topk_ranks = 0, topk_ranks[kNumRanks], slot_indices[kNumRanks]; - #pragma unroll - for (int i = 0; i < kNumRanks; ++ i) { - auto expected_head_i = __shfl_sync(0xffffffff, expected_head, i); - if (expected_head_i >= 0) { - slot_indices[num_topk_ranks] = expected_head_i % num_recv_buffer_tokens; - topk_ranks[num_topk_ranks ++] = i; - } - } - - // Reduce data - #pragma unroll - for (int i = recv_lane_id; i < hidden_int4; i += 32) { - // Read buffers - int4 recv_value_int4[kNumRanks]; - #pragma unroll - for (int j = 0; j < num_topk_ranks; ++ j) - recv_value_int4[j] = ld_nc_global(channel_x_buffers[topk_ranks[j]].buffer() + slot_indices[j] * hidden_int4 + i); - - // Reduce all-to-all results - float values[kDtypePerInt4] = {0}; - #pragma unroll - for (int j = 0; j < num_topk_ranks; ++ j) { - auto recv_value_dtypes = reinterpret_cast(&recv_value_int4[j]); - #pragma unroll - for (int k = 0; k < kDtypePerInt4; ++ k) - values[k] += static_cast(recv_value_dtypes[k]); - } - - // Cast back to `dtype_t` and write - int4 out_int4; - auto out_dtypes = reinterpret_cast(&out_int4); - #pragma unroll - for (int j = 0; j < kDtypePerInt4; ++ j) - out_dtypes[j] = static_cast(values[j]); - recv_int4[token_idx * hidden_int4 + i] = out_int4; - } - - // Reduce `topk_weights` - if (recv_lane_id < num_topk) { - float value = 0; - #pragma unroll - for (int i = 0; i < num_topk_ranks; ++ i) - value += ld_nc_global(channel_topk_weights_buffers[topk_ranks[i]].buffer() + slot_indices[i] * num_topk + recv_lane_id); - recv_topk_weights[token_idx * num_topk + recv_lane_id] = value; - } - - // Update head - if (recv_lane_id < kNumRanks) - warp_channel_head_idx[recv_warp_id][recv_lane_id] = (expected_head < 0) ? 
-expected_head - 1 : expected_head + 1; - } - - // Retired - __syncwarp(); - if (recv_lane_id == 0) - warp_retired[recv_warp_id] = true; - } + // Move tail index + asm volatile("bar.sync %0, %1;" ::"r"(send_rank_id), "r"(num_threads_per_rank)); + if (send_lane_id == 0 and send_warp_id_in_rank == 0) + st_release_sys_global(channel_tail_idx.buffer(), current_channel_tail_idx); } + } else { + // Workers for receiving + // One warp for moving the queue head, others for reduction + constexpr int num_recv_warps = kNumThreads / 32; + const auto recv_warp_id = thread_id / 32; + const auto recv_lane_id = thread_id % 32; + EP_DEVICE_ASSERT(kNumRanks <= 32 and kNumThreads > 32); + EP_DEVICE_ASSERT(thread_id >= 0 and kNumThreads % 32 == 0); + + // Shared head, tail and retired flags for receiver warps + __shared__ volatile int warp_channel_head_idx[num_recv_warps][kNumRanks]; + __shared__ volatile int channel_tail_idx[kNumRanks]; + __shared__ volatile bool warp_retired[num_recv_warps]; + if (thread_id < num_recv_warps) warp_retired[thread_id] = false; + if (recv_lane_id < kNumRanks) warp_channel_head_idx[recv_warp_id][recv_lane_id] = 0; + if (thread_id < kNumRanks) channel_tail_idx[thread_id] = 0; + asm volatile("bar.sync 0, %0;" ::"r"(kNumThreads)); + + if (thread_id < 32) { + int* channel_head_idx_ptr = + reinterpret_cast(buffer_ptrs[rank]) + responsible_channel * kNumRanks + recv_lane_id; + int* channel_tail_idx_ptr = channel_head_idx_ptr + num_channels * kNumRanks; + + // Queue head updater + int last_head = 0; + while (recv_lane_id < kNumRanks) { + // Check retired + bool retired = true; +#pragma unroll + for (int i = 1; i < num_recv_warps; ++i) retired = retired and warp_retired[i]; + if (retired) break; + + // Update queue tail + channel_tail_idx[recv_lane_id] = ld_acquire_sys_global(channel_tail_idx_ptr); + + // Update minimum head + int min_head = std::numeric_limits::max(); +#pragma unroll + for (int i = 1; i < num_recv_warps; ++i) + if (not warp_retired[i]) min_head = min(min_head, warp_channel_head_idx[i][recv_lane_id]); + if (min_head != std::numeric_limits::max() and min_head > last_head) + st_relaxed_sys_global(channel_head_idx_ptr, last_head = min_head); + } + } else { + // Receivers + // Channel metadata + // All lanes will use data buffer, but only rank lane will use `head/tail/src_idx` + Buffer channel_x_buffers[kNumRanks]; + Buffer channel_topk_weights_buffers[kNumRanks]; + +// Calculate pointers by the specific layout +#pragma unroll + for (int i = 0; i < kNumRanks; ++i) { + auto channel_rank_offset = responsible_channel * kNumRanks + i; + auto num_channels_total = num_channels * kNumRanks; + // `head_idx` & `tail_idx`: kNumChannels * kNumRanks * sizeof(int) + auto ptr = reinterpret_cast(reinterpret_cast(buffer_ptrs[rank]) + + 2 * num_channels * kNumRanks * sizeof(int)); + + // `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4) + channel_x_buffers[i] = Buffer(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, + channel_rank_offset * num_recv_buffer_tokens * hidden_int4); + + // `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int) + ptr = reinterpret_cast(reinterpret_cast(ptr) + + num_channels_total * num_recv_buffer_tokens * sizeof(int)); + + // `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float) + channel_topk_weights_buffers[i] = Buffer(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, + channel_rank_offset * num_recv_buffer_tokens * 
num_topk); + } + + // The same tokens as the dispatch process + int token_start_idx, token_end_idx; + get_channel_task_range(num_recv_tokens, num_channels, responsible_channel, token_start_idx, token_end_idx); + + // Iterate over all tokens and combine + for (int64_t token_idx = token_start_idx + recv_warp_id - 1; token_idx < token_end_idx; + token_idx += num_recv_warps - 1) { + // Read expected head + int expected_head = -1; + if (recv_lane_id < kNumRanks) expected_head = ld_nc_global(send_head + token_idx * kNumRanks + recv_lane_id); + + auto start_time = clock64(); + while (expected_head >= 0 and channel_tail_idx[recv_lane_id] <= expected_head) { + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf("DeepEP timeout for combine receivers, rank %d, responsible_channel = %d, expect = %d\n", rank, + responsible_channel, expected_head); + trap(); + } + } + __syncwarp(); + + // Broadcast current heads + int num_topk_ranks = 0, topk_ranks[kNumRanks], slot_indices[kNumRanks]; +#pragma unroll + for (int i = 0; i < kNumRanks; ++i) { + auto expected_head_i = __shfl_sync(0xffffffff, expected_head, i); + if (expected_head_i >= 0) { + slot_indices[num_topk_ranks] = expected_head_i % num_recv_buffer_tokens; + topk_ranks[num_topk_ranks++] = i; + } + } + +// Reduce data +#pragma unroll + for (int i = recv_lane_id; i < hidden_int4; i += 32) { + // Read buffers + int4 recv_value_int4[kNumRanks]; +#pragma unroll + for (int j = 0; j < num_topk_ranks; ++j) + recv_value_int4[j] = + ld_nc_global(channel_x_buffers[topk_ranks[j]].buffer() + slot_indices[j] * hidden_int4 + i); + + // Reduce all-to-all results + float values[kDtypePerInt4] = {0}; +#pragma unroll + for (int j = 0; j < num_topk_ranks; ++j) { + auto recv_value_dtypes = reinterpret_cast(&recv_value_int4[j]); +#pragma unroll + for (int k = 0; k < kDtypePerInt4; ++k) values[k] += static_cast(recv_value_dtypes[k]); + } + + // Cast back to `dtype_t` and write + int4 out_int4; + auto out_dtypes = reinterpret_cast(&out_int4); +#pragma unroll + for (int j = 0; j < kDtypePerInt4; ++j) out_dtypes[j] = static_cast(values[j]); + recv_int4[token_idx * hidden_int4 + i] = out_int4; + } + + // Reduce `topk_weights` + if (recv_lane_id < num_topk) { + float value = 0; +#pragma unroll + for (int i = 0; i < num_topk_ranks; ++i) + value += ld_nc_global(channel_topk_weights_buffers[topk_ranks[i]].buffer() + slot_indices[i] * num_topk + + recv_lane_id); + recv_topk_weights[token_idx * num_topk + recv_lane_id] = value; + } + + // Update head + if (recv_lane_id < kNumRanks) + warp_channel_head_idx[recv_warp_id][recv_lane_id] = + (expected_head < 0) ? 
-expected_head - 1 : expected_head + 1; + } + + // Retired + __syncwarp(); + if (recv_lane_id == 0) warp_retired[recv_warp_id] = true; + } + } } -void combine(cudaDataType_t type, - void* recv_x, float* recv_topk_weights, - const void* x, const float* topk_weights, - const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix, - int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk, - void** buffer_ptrs, int rank, int num_ranks, - cudaStream_t stream, int num_sms, - int num_max_send_tokens, int num_recv_buffer_tokens) { - constexpr int kNumThreads = 768; +void combine(cudaDataType_t type, void* recv_x, float* recv_topk_weights, const void* x, const float* topk_weights, + const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix, int* send_head, + int num_tokens, int num_recv_tokens, int hidden, int num_topk, void** buffer_ptrs, int rank, int num_ranks, + cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) { + constexpr int kNumThreads = 768; -#define COMBINE_LAUNCH_CASE(dtype, ranks) \ - LAUNCH_KERNEL(&cfg, (combine), \ - reinterpret_cast(recv_x), recv_topk_weights, \ - reinterpret_cast(x), topk_weights, \ - src_idx, rank_prefix_matrix, channel_prefix_matrix, \ - send_head, num_tokens, num_recv_tokens, hidden, num_topk, \ - buffer_ptrs, rank, \ - num_max_send_tokens, num_recv_buffer_tokens); \ - break -#define COMBINE_DTYPE_LAUNCH_CASE(dtype) SWITCH_RANKS_WITH_DTYPE(dtype, COMBINE_LAUNCH_CASE); break +#define COMBINE_LAUNCH_CASE(dtype, ranks) \ + LAUNCH_KERNEL(&cfg, (combine), reinterpret_cast(recv_x), recv_topk_weights, \ + reinterpret_cast(x), topk_weights, src_idx, rank_prefix_matrix, channel_prefix_matrix, \ + send_head, num_tokens, num_recv_tokens, hidden, num_topk, buffer_ptrs, rank, num_max_send_tokens, \ + num_recv_buffer_tokens); \ + break +#define COMBINE_DTYPE_LAUNCH_CASE(dtype) \ + SWITCH_RANKS_WITH_DTYPE(dtype, COMBINE_LAUNCH_CASE); \ + break - // Even-numbered blocks for sending, odd-numbered blocks for receiving - EP_HOST_ASSERT(num_sms % 2 == 0); - EP_HOST_ASSERT(kNumThreads >= num_ranks * 32); - SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream); - SWITCH_TYPES(COMBINE_DTYPE_LAUNCH_CASE); + // Even-numbered blocks for sending, odd-numbered blocks for receiving + EP_HOST_ASSERT(num_sms % 2 == 0); + EP_HOST_ASSERT(kNumThreads >= num_ranks * 32); + SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream); + SWITCH_TYPES(COMBINE_DTYPE_LAUNCH_CASE); #undef COMBINE_DTYPE_LAUNCH_CASE #undef COMBINE_LAUNCH_CASE } -} // namespace intranode +} // namespace intranode -} // namespace ep -} // namespace mscclpp +} // namespace ep +} // namespace mscclpp diff --git a/src/ext/ep/kernels/launch.cuh b/src/ext/ep/kernels/launch.cuh index 94f9eb72..763f4ea1 100644 --- a/src/ext/ep/kernels/launch.cuh +++ b/src/ext/ep/kernels/launch.cuh @@ -5,69 +5,93 @@ #include "configs.cuh" #ifndef SETUP_LAUNCH_CONFIG -#define SETUP_LAUNCH_CONFIG(num_sms, num_threads, stream) \ - cudaLaunchConfig_t cfg = {(num_sms), (num_threads), 0, stream, nullptr, 0}; \ - cudaLaunchAttribute attr[1]; \ - attr[0].id = cudaLaunchAttributeCooperative; \ - attr[0].val.cooperative = 1; \ - cfg.attrs = attr; \ - cfg.numAttrs = 1 +#define SETUP_LAUNCH_CONFIG(num_sms, num_threads, stream) \ + cudaLaunchConfig_t cfg = {(num_sms), (num_threads), 0, stream, nullptr, 0}; \ + cudaLaunchAttribute attr[1]; \ + attr[0].id = cudaLaunchAttributeCooperative; \ + attr[0].val.cooperative = 1; \ + cfg.attrs = attr; \ + cfg.numAttrs = 1 #endif 
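// [Editor's note: illustrative sketch, not part of this patch or of launch.cuh.]
// The two helpers around this point compose as follows: SETUP_LAUNCH_CONFIG declares a
// local `cudaLaunchConfig_t cfg` carrying the cooperative-launch attribute, and
// LAUNCH_KERNEL (defined just below) hands that config to cudaLaunchKernelEx together
// with the kernel and its arguments. A minimal, hypothetical usage -- kernel and wrapper
// names are invented here, and both macros are assumed to be in scope -- looks like:

__global__ void example_kernel(int* data, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] += 1;
}

void launch_example(int* data, int n, cudaStream_t stream) {
  SETUP_LAUNCH_CONFIG(4, 128, stream);           // 4 blocks x 128 threads, cooperative, on `stream`
  LAUNCH_KERNEL(&cfg, example_kernel, data, n);  // expands to cudaLaunchKernelEx(&cfg, example_kernel, data, n)
}

// This mirrors how dispatch()/combine() above pair SETUP_LAUNCH_CONFIG with one of the
// SWITCH_* macros below, whose selected case expands to a single LAUNCH_KERNEL call.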
#ifndef LAUNCH_KERNEL #define LAUNCH_KERNEL(config, kernel, ...) CUDA_CHECK(cudaLaunchKernelEx(config, kernel, ##__VA_ARGS__)) #endif -#define SWITCH_RANKS(case_macro) \ - do { \ - switch (num_ranks) { \ - case 2: case_macro(2); \ - case 4: case_macro(4); \ - case 8: case_macro(8); \ - default: EP_HOST_ASSERT(false and "Unsupported ranks"); \ - } \ - } while (false) +#define SWITCH_RANKS(case_macro) \ + do { \ + switch (num_ranks) { \ + case 2: \ + case_macro(2); \ + case 4: \ + case_macro(4); \ + case 8: \ + case_macro(8); \ + default: \ + EP_HOST_ASSERT(false and "Unsupported ranks"); \ + } \ + } while (false) -#define SWITCH_RDMA_RANKS(case_macro) \ - do { \ - switch (num_ranks / NUM_MAX_NVL_PEERS) { \ - case 2: case_macro(2); \ - case 3: case_macro(3); \ - case 4: case_macro(4); \ - case 8: case_macro(8); \ - case 16: case_macro(16); \ - case 18: case_macro(18); \ - case 20: case_macro(20); \ - default: EP_HOST_ASSERT(false and "Unsupported RDMA ranks"); \ - } \ - } while (false) +#define SWITCH_RDMA_RANKS(case_macro) \ + do { \ + switch (num_ranks / NUM_MAX_NVL_PEERS) { \ + case 2: \ + case_macro(2); \ + case 3: \ + case_macro(3); \ + case 4: \ + case_macro(4); \ + case 8: \ + case_macro(8); \ + case 16: \ + case_macro(16); \ + case 18: \ + case_macro(18); \ + case 20: \ + case_macro(20); \ + default: \ + EP_HOST_ASSERT(false and "Unsupported RDMA ranks"); \ + } \ + } while (false) -#define SWITCH_RANKS_WITH_DTYPE(dtype, case_macro) \ - do { \ - switch (num_ranks) { \ - case 2: case_macro(dtype, 2); \ - case 4: case_macro(dtype, 4); \ - case 8: case_macro(dtype, 8); \ - default: EP_HOST_ASSERT(false && "Unsupported ranks"); \ - } \ - } while (false) +#define SWITCH_RANKS_WITH_DTYPE(dtype, case_macro) \ + do { \ + switch (num_ranks) { \ + case 2: \ + case_macro(dtype, 2); \ + case 4: \ + case_macro(dtype, 4); \ + case 8: \ + case_macro(dtype, 8); \ + default: \ + EP_HOST_ASSERT(false && "Unsupported ranks"); \ + } \ + } while (false) -#define SWITCH_TYPES(case_macro) \ - do { \ - switch (type) { \ - case CUDA_R_16BF: case_macro(nv_bfloat16); \ - case CUDA_R_32F: case_macro(float); \ - default: EP_HOST_ASSERT(false && "Unsupported type"); \ - } \ - } while (false) +#define SWITCH_TYPES(case_macro) \ + do { \ + switch (type) { \ + case CUDA_R_16BF: \ + case_macro(nv_bfloat16); \ + case CUDA_R_32F: \ + case_macro(float); \ + default: \ + EP_HOST_ASSERT(false && "Unsupported type"); \ + } \ + } while (false) -#define SWITCH_HIDDEN(case_macro) \ - do { \ - switch (hidden) { \ - case 2560: case_macro(2560); \ - case 4096: case_macro(4096); \ - case 5120: case_macro(5120); \ - case 7168: case_macro(7168); \ - default: EP_HOST_ASSERT(false && "Unsupported hidden"); \ - } \ - } while (false) +#define SWITCH_HIDDEN(case_macro) \ + do { \ + switch (hidden) { \ + case 2560: \ + case_macro(2560); \ + case 4096: \ + case_macro(4096); \ + case 5120: \ + case_macro(5120); \ + case 7168: \ + case_macro(7168); \ + default: \ + EP_HOST_ASSERT(false && "Unsupported hidden"); \ + } \ + } while (false) diff --git a/src/ext/ep/kernels/runtime.cu b/src/ext/ep/kernels/runtime.cu index 4526fac1..d32653fa 100644 --- a/src/ext/ep/kernels/runtime.cu +++ b/src/ext/ep/kernels/runtime.cu @@ -24,7 +24,7 @@ __global__ void barrier(int** task_fifo_ptrs, int head, int rank) { } void barrier(int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream) { -#define BARRIER_LAUNCH_CASE(ranks) \ +#define BARRIER_LAUNCH_CASE(ranks) \ LAUNCH_KERNEL(&cfg, barrier, task_fifo_ptrs, head, rank); \ break diff --git 
a/src/ext/ep/kernels/utils.cuh b/src/ext/ep/kernels/utils.cuh index 70ca21a4..3fb01fe4 100644 --- a/src/ext/ep/kernels/utils.cuh +++ b/src/ext/ep/kernels/utils.cuh @@ -6,150 +6,156 @@ #include "exception.cuh" -#define UNROLLED_WARP_COPY(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \ -{ \ - constexpr int kLoopStride = 32 * (UNROLL_FACTOR); \ +#define UNROLLED_WARP_COPY(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \ + { \ + constexpr int kLoopStride = 32 * (UNROLL_FACTOR); \ typename std::remove_reference::type unrolled_values[(UNROLL_FACTOR)]; \ - auto __src = (SRC); \ - auto __dst = (DST); \ - for (int __i = (LANE_ID); __i < ((N) / kLoopStride) * kLoopStride; __i += kLoopStride) { \ - _Pragma("unroll") \ - for (int __j = 0; __j < (UNROLL_FACTOR); ++ __j) \ - unrolled_values[__j] = LD_FUNC(__src + __i + __j * 32); \ - _Pragma("unroll") \ - for (int __j = 0; __j < (UNROLL_FACTOR); ++ __j) \ - ST_FUNC(__dst + __i + __j * 32, unrolled_values[__j]); \ - } \ - for (int __i = ((N) / kLoopStride) * kLoopStride + (LANE_ID); __i < (N); __i += 32) \ - ST_FUNC(__dst + __i, LD_FUNC(__src + __i)); \ -} + auto __src = (SRC); \ + auto __dst = (DST); \ + for (int __i = (LANE_ID); __i < ((N) / kLoopStride) * kLoopStride; __i += kLoopStride) { \ + _Pragma("unroll") for (int __j = 0; __j < (UNROLL_FACTOR); ++__j) unrolled_values[__j] = \ + LD_FUNC(__src + __i + __j * 32); \ + _Pragma("unroll") for (int __j = 0; __j < (UNROLL_FACTOR); ++__j) \ + ST_FUNC(__dst + __i + __j * 32, unrolled_values[__j]); \ + } \ + for (int __i = ((N) / kLoopStride) * kLoopStride + (LANE_ID); __i < (N); __i += 32) \ + ST_FUNC(__dst + __i, LD_FUNC(__src + __i)); \ + } -namespace mscclpp { namespace ep { +namespace mscclpp { +namespace ep { template struct VecInt {}; -template<> struct VecInt<1> { using vec_t = int8_t; }; -template<> struct VecInt<2> { using vec_t = int16_t; }; -template<> struct VecInt<4> { using vec_t = int; }; -template<> struct VecInt<8> { using vec_t = int64_t; }; -template<> struct VecInt<16> { using vec_t = int4; }; +template <> +struct VecInt<1> { + using vec_t = int8_t; +}; +template <> +struct VecInt<2> { + using vec_t = int16_t; +}; +template <> +struct VecInt<4> { + using vec_t = int; +}; +template <> +struct VecInt<8> { + using vec_t = int64_t; +}; +template <> +struct VecInt<16> { + using vec_t = int4; +}; -__device__ __forceinline__ void trap() { - asm("trap;"); +__device__ __forceinline__ void trap() { asm("trap;"); } + +__device__ __forceinline__ void memory_fence() { asm volatile("fence.acq_rel.sys;" ::: "memory"); } + +__device__ __forceinline__ void memory_fence_gpu() { asm volatile("fence.acq_rel.gpu;" ::: "memory"); } + +__device__ __forceinline__ void memory_fence_cta() { asm volatile("fence.acq_rel.cta;" ::: "memory"); } + +__device__ __forceinline__ void st_relaxed_sys_global(const int *ptr, int val) { + asm volatile("st.relaxed.sys.global.s32 [%0], %1;" ::"l"(ptr), "r"(val) : "memory"); } -__device__ __forceinline__ void memory_fence() { - asm volatile("fence.acq_rel.sys;":: : "memory"); +__device__ __forceinline__ void st_release_sys_global(const int *ptr, int val) { + asm volatile("st.release.sys.global.s32 [%0], %1;" ::"l"(ptr), "r"(val) : "memory"); } -__device__ __forceinline__ void memory_fence_gpu() { - asm volatile("fence.acq_rel.gpu;":: : "memory"); -} - -__device__ __forceinline__ void memory_fence_cta() { - asm volatile("fence.acq_rel.cta;":: : "memory"); -} - -__device__ __forceinline__ void st_relaxed_sys_global(const int *ptr, int val) { - asm 
volatile("st.relaxed.sys.global.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory");
-}
-
-__device__ __forceinline__ void st_release_sys_global(const int *ptr, int val) {
-    asm volatile("st.release.sys.global.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory");
-}
-
-__device__ __forceinline__ void st_release_cta(const int *ptr, int val) {
-    asm volatile("st.release.cta.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory");
+__device__ __forceinline__ void st_release_cta(const int *ptr, int val) {
+  asm volatile("st.release.cta.s32 [%0], %1;" ::"l"(ptr), "r"(val) : "memory");
 }
 
 __device__ __forceinline__ int ld_acquire_sys_global(const int *ptr) {
-    int ret;
-    asm volatile("ld.acquire.sys.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
-    return ret;
+  int ret;
+  asm volatile("ld.acquire.sys.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
+  return ret;
 }
 
 __device__ __forceinline__ uint64_t ld_acquire_sys_global(const uint64_t *ptr) {
-    uint64_t ret;
-    asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr));
-    return ret;
+  uint64_t ret;
+  asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr));
+  return ret;
 }
 
 __device__ __forceinline__ int64_t ld_acquire_sys_global(const int64_t *ptr) {
-    int64_t ret;
-    asm volatile("ld.acquire.sys.global.s64 %0, [%1];" : "=l"(ret) : "l"(ptr));
-    return ret;
+  int64_t ret;
+  asm volatile("ld.acquire.sys.global.s64 %0, [%1];" : "=l"(ret) : "l"(ptr));
+  return ret;
 }
 
 __device__ __forceinline__ int ld_acquire_global(const int *ptr) {
-    int ret;
-    asm volatile("ld.acquire.gpu.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
-    return ret;
+  int ret;
+  asm volatile("ld.acquire.gpu.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
+  return ret;
 }
 
-__device__ __forceinline__ int atomic_add_release_sys_global(const int* ptr, int value) {
-    int ret;
-    asm volatile("atom.add.release.sys.global.s32 %0, [%1], %2;" : "=r"(ret) : "l"(ptr), "r"(value));
-    return ret;
+__device__ __forceinline__ int atomic_add_release_sys_global(const int *ptr, int value) {
+  int ret;
+  asm volatile("atom.add.release.sys.global.s32 %0, [%1], %2;" : "=r"(ret) : "l"(ptr), "r"(value));
+  return ret;
 }
 
-__device__ __forceinline__ int atomic_add_release_global(const int* ptr, int value) {
-    int ret;
-    asm volatile("atom.add.release.gpu.global.s32 %0, [%1], %2;" : "=r"(ret) : "l"(ptr), "r"(value));
-    return ret;
+__device__ __forceinline__ int atomic_add_release_global(const int *ptr, int value) {
+  int ret;
+  asm volatile("atom.add.release.gpu.global.s32 %0, [%1], %2;" : "=r"(ret) : "l"(ptr), "r"(value));
+  return ret;
 }
 
 __device__ __forceinline__ int ld_acquire_cta(const int *ptr) {
-    int ret;
-    asm volatile("ld.acquire.cta.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
-    return ret;
+  int ret;
+  asm volatile("ld.acquire.cta.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
+  return ret;
 }
 
 __device__ __forceinline__ uint8_t ld_na_relaxed(const uint8_t *ptr) {
-    uint16_t ret;
-    asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b8 %0, [%1];" : "=h"(ret) : "l"(ptr));
-    return static_cast<uint8_t>(ret);
+  uint16_t ret;
+  asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b8 %0, [%1];" : "=h"(ret) : "l"(ptr));
+  return static_cast<uint8_t>(ret);
 }
 
 __device__ __forceinline__ uint16_t ld_na_relaxed(const uint16_t *ptr) {
-    uint16_t ret;
-    asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b16 %0, [%1];" : "=h"(ret) : "l"(ptr));
-    return ret;
+  uint16_t ret;
+  asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b16 %0, [%1];" : "=h"(ret) : "l"(ptr));
+  return ret;
 }
 
 __device__ __forceinline__ uint32_t ld_na_relaxed(const uint32_t *ptr) {
-    uint32_t ret;
-    asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b32 %0, [%1];" : "=r"(ret) : "l"(ptr));
-    return ret;
+  uint32_t ret;
+  asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b32 %0, [%1];" : "=r"(ret) : "l"(ptr));
+  return ret;
 }
 
 __device__ __forceinline__ uint64_t ld_na_relaxed(const uint64_t *ptr) {
-    uint64_t ret;
-    asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b64 %0, [%1];" : "=l"(ret) : "l"(ptr));
-    return ret;
+  uint64_t ret;
+  asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b64 %0, [%1];" : "=l"(ret) : "l"(ptr));
+  return ret;
 }
 
-__device__ __forceinline__ int ld_volatile_global(const int *ptr) {
-    int ret;
-    asm volatile("ld.volatile.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
-    return ret;
+__device__ __forceinline__ int ld_volatile_global(const int *ptr) {
+  int ret;
+  asm volatile("ld.volatile.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
+  return ret;
 }
 
-__device__ __forceinline__ float ld_volatile_global(const float *ptr) {
-    float ret;
-    asm volatile("ld.volatile.global.f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
-    return ret;
+__device__ __forceinline__ float ld_volatile_global(const float *ptr) {
+  float ret;
+  asm volatile("ld.volatile.global.f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
+  return ret;
 }
 
-__device__ __forceinline__ int64_t ld_volatile_global(const int64_t *ptr) {
-    int64_t ret;
-    asm volatile("ld.volatile.global.s64 %0, [%1];" : "=l"(ret) : "l"(ptr));
-    return ret;
+__device__ __forceinline__ int64_t ld_volatile_global(const int64_t *ptr) {
+  int64_t ret;
+  asm volatile("ld.volatile.global.s64 %0, [%1];" : "=l"(ret) : "l"(ptr));
+  return ret;
 }
 
-__device__ __forceinline__ int64_t ld_volatile_global(const uint64_t *ptr) {
-    int64_t ret;
-    asm volatile("ld.volatile.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr));
-    return ret;
+__device__ __forceinline__ int64_t ld_volatile_global(const uint64_t *ptr) {
+  int64_t ret;
+  asm volatile("ld.volatile.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr));
+  return ret;
 }
 
 #ifndef DISABLE_AGGRESSIVE_PTX_INSTRS
@@ -160,90 +166,93 @@ __device__ __forceinline__ int64_t ld_volatile_global(const uint64_t *ptr) {
 // `ld.global.nc.L1::no_allocate` will be translated into `LDG.E.NA.[width].CONSTANT` in SASS
 template <typename dtype_t>
-__device__ __forceinline__ dtype_t ld_nc_global(const dtype_t *ptr) {
-    auto ret = ld_nc_global(reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t*>(ptr));
-    return *reinterpret_cast<dtype_t*>(&ret);
+__device__ __forceinline__ dtype_t ld_nc_global(const dtype_t *ptr) {
+  auto ret = ld_nc_global(reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t *>(ptr));
+  return *reinterpret_cast<dtype_t *>(&ret);
 }
 
 template <>
-__device__ __forceinline__ uint8_t ld_nc_global(const uint8_t *ptr) {
-    uint16_t ret;
-    // NOTES: we must use `uint16_t` as inline ASM does not support 8-bit constraint letter (`h` below means unsigned 16-bit)
-    asm volatile(LD_NC_FUNC ".u8 %0, [%1];" : "=h"(ret) : "l"(ptr));
-    return static_cast<uint8_t>(ret);
+__device__ __forceinline__ uint8_t ld_nc_global(const uint8_t *ptr) {
+  uint16_t ret;
+  // NOTES: we must use `uint16_t` as inline ASM does not support 8-bit constraint letter (`h` below means unsigned
+  // 16-bit)
+  asm volatile(LD_NC_FUNC ".u8 %0, [%1];" : "=h"(ret) : "l"(ptr));
+  return static_cast<uint8_t>(ret);
 }
 
 template <>
-__device__ __forceinline__ int ld_nc_global(const int *ptr) {
-    int ret;
-    asm volatile(LD_NC_FUNC ".s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
-    return ret;
+__device__ __forceinline__ int ld_nc_global(const int *ptr) {
+  int ret;
+  asm volatile(LD_NC_FUNC ".s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
+  return ret;
 }
 
 template <>
-__device__ __forceinline__ int64_t ld_nc_global(const int64_t *ptr) {
-    int64_t ret;
-    asm volatile(LD_NC_FUNC ".s64 %0, [%1];" : "=l"(ret) : "l"(ptr));
-    return ret;
+__device__ __forceinline__ int64_t ld_nc_global(const int64_t *ptr) {
+  int64_t ret;
+  asm volatile(LD_NC_FUNC ".s64 %0, [%1];" : "=l"(ret) : "l"(ptr));
+  return ret;
 }
 
 template <>
-__device__ __forceinline__ float ld_nc_global(const float *ptr) {
-    float ret;
-    asm volatile(LD_NC_FUNC ".f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
-    return ret;
+__device__ __forceinline__ float ld_nc_global(const float *ptr) {
+  float ret;
+  asm volatile(LD_NC_FUNC ".f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
+  return ret;
 }
 
 template <>
-__device__ __forceinline__ int2 ld_nc_global(const int2 *ptr) {
-    int2 ret;
-    asm volatile(LD_NC_FUNC ".v2.s32 {%0, %1}, [%2];" : "=r"(ret.x), "=r"(ret.y) : "l"(ptr));
-    return ret;
+__device__ __forceinline__ int2 ld_nc_global(const int2 *ptr) {
+  int2 ret;
+  asm volatile(LD_NC_FUNC ".v2.s32 {%0, %1}, [%2];" : "=r"(ret.x), "=r"(ret.y) : "l"(ptr));
+  return ret;
 }
 
 template <>
-__device__ __forceinline__ int4 ld_nc_global(const int4 *ptr) {
-    int4 ret;
-    asm volatile(LD_NC_FUNC ".v4.s32 {%0, %1, %2, %3}, [%4];"
-                 : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr));
-    return ret;
+__device__ __forceinline__ int4 ld_nc_global(const int4 *ptr) {
+  int4 ret;
+  asm volatile(LD_NC_FUNC ".v4.s32 {%0, %1, %2, %3}, [%4];"
+               : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w)
+               : "l"(ptr));
+  return ret;
 }
 
 __device__ __forceinline__ void st_na_relaxed(const uint8_t *ptr, uint8_t val) {
-    asm volatile("st.relaxed.gpu.global.L1::no_allocate.b8 [%0], %1;" : : "l"(ptr), "h"(static_cast<uint16_t>(val)));
+  asm volatile("st.relaxed.gpu.global.L1::no_allocate.b8 [%0], %1;" : : "l"(ptr), "h"(static_cast<uint16_t>(val)));
 }
 
 __device__ __forceinline__ void st_na_relaxed(const uint16_t *ptr, uint16_t val) {
-    asm volatile("st.relaxed.gpu.global.L1::no_allocate.b16 [%0], %1;" : : "l"(ptr), "h"(val));
+  asm volatile("st.relaxed.gpu.global.L1::no_allocate.b16 [%0], %1;" : : "l"(ptr), "h"(val));
 }
 
 __device__ __forceinline__ void st_na_relaxed(const uint32_t *ptr, uint32_t val) {
-    asm volatile("st.relaxed.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
+  asm volatile("st.relaxed.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
 }
 
 __device__ __forceinline__ void st_na_relaxed(const int *ptr, int val) {
-    asm volatile("st.relaxed.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
+  asm volatile("st.relaxed.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
 }
 
 __device__ __forceinline__ void st_na_relaxed(const int4 *ptr, int4 val) {
-    asm volatile("st.relaxed.gpu.global.L1::no_allocate.v4.s32 [%0], {%1, %2, %3, %4};"
-                 : : "l"(ptr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w));
+  asm volatile("st.relaxed.gpu.global.L1::no_allocate.v4.s32 [%0], {%1, %2, %3, %4};"
+               :
+               : "l"(ptr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w));
 }
 
 __device__ __forceinline__ void st_na_release(const int *ptr, int val) {
-    asm volatile("st.release.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
+  asm volatile("st.release.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
 }
 
 __device__ __forceinline__ void st_na_release(const uint32_t *ptr, uint32_t val) {
-    asm volatile("st.release.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
+  asm volatile("st.release.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
 }
 
 __device__ __forceinline__ void st_na_release(const uint64_t *ptr, uint64_t val) {
-    asm volatile("st.release.gpu.global.L1::no_allocate.b64 [%0], %1;" : : "l"(ptr), "l"(val));
+  asm volatile("st.release.gpu.global.L1::no_allocate.b64 [%0], %1;" : : "l"(ptr), "l"(val));
 }
 
 __device__ __forceinline__ void st_na_release(const int64_t *ptr, int64_t val) {
-    asm volatile("st.release.gpu.global.L1::no_allocate.b64 [%0], %1;" : : "l"(ptr), "l"(val));
+  asm volatile("st.release.gpu.global.L1::no_allocate.b64 [%0], %1;" : : "l"(ptr), "l"(val));
 }
 
 // `st.global.L1::no_allocate` will be translated into `ST.E.NA.[width]` in SASS
@@ -254,140 +263,137 @@ __device__ __forceinline__ void st_na_release(const int64_t *ptr, int64_t val) {
 #endif
 
 template <typename dtype_t>
-__device__ __forceinline__ void st_na_global(const dtype_t *ptr, const dtype_t& value) {
-    st_na_global(reinterpret_cast<typename VecInt<sizeof(dtype_t)>::vec_t*>(ptr),
-                 *reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t*>(&value));
+__device__ __forceinline__ void st_na_global(const dtype_t *ptr, const dtype_t &value) {
+  st_na_global(reinterpret_cast<typename VecInt<sizeof(dtype_t)>::vec_t *>(ptr),
+               *reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t *>(&value));
 }
 
 template <>
-__device__ __forceinline__ void st_na_global(const int *ptr, const int& value) {
-    asm volatile(ST_NA_FUNC ".s32 [%0], %1;" ::"l"(ptr), "r"(value));
+__device__ __forceinline__ void st_na_global(const int *ptr, const int &value) {
+  asm volatile(ST_NA_FUNC ".s32 [%0], %1;" ::"l"(ptr), "r"(value));
 }
 
 template <>
-__device__ __forceinline__ void st_na_global(const int64_t *ptr, const int64_t& value) {
-    asm volatile(ST_NA_FUNC ".s64 [%0], %1;" ::"l"(ptr), "l"(value));
+__device__ __forceinline__ void st_na_global(const int64_t *ptr, const int64_t &value) {
+  asm volatile(ST_NA_FUNC ".s64 [%0], %1;" ::"l"(ptr), "l"(value));
 }
 
 template <>
-__device__ __forceinline__ void st_na_global(const float *ptr, const float& value) {
-    asm volatile(ST_NA_FUNC ".f32 [%0], %1;" ::"l"(ptr), "f"(value));
+__device__ __forceinline__ void st_na_global(const float *ptr, const float &value) {
+  asm volatile(ST_NA_FUNC ".f32 [%0], %1;" ::"l"(ptr), "f"(value));
 }
 
 template <>
-__device__ __forceinline__ void st_na_global(const int4 *ptr, const int4& value) {
-    asm volatile(ST_NA_FUNC ".v4.s32 [%0], {%1, %2, %3, %4};"
-                 ::"l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w));
+__device__ __forceinline__ void st_na_global(const int4 *ptr, const int4 &value) {
+  asm volatile(ST_NA_FUNC ".v4.s32 [%0], {%1, %2, %3, %4};" ::"l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z),
+               "r"(value.w));
 }
 
 template <typename dtype_t>
 __host__ __device__ dtype_t cell_div(dtype_t a, dtype_t b) {
-    return (a + b - 1) / b;
+  return (a + b - 1) / b;
 }
 
 template <typename dtype_t>
 __host__ __device__ dtype_t align(dtype_t a, dtype_t b) {
-    return cell_div(a, b) * b;
+  return cell_div(a, b) * b;
 }
 
-__forceinline__ __device__ void get_channel_task_range(int num_tokens, int num_sms, int sm_id,
-                                                        int& token_start_idx, int& token_end_idx) {
-    int num_tokens_per_sm = cell_div(num_tokens, num_sms);
-    token_start_idx = min(num_tokens_per_sm * sm_id, num_tokens);
-    token_end_idx = min(token_start_idx + num_tokens_per_sm, num_tokens);
+__forceinline__ __device__ void get_channel_task_range(int num_tokens, int num_sms, int sm_id, int &token_start_idx,
+                                                       int &token_end_idx) {
+  int num_tokens_per_sm = cell_div(num_tokens, num_sms);
+  token_start_idx = min(num_tokens_per_sm * sm_id, num_tokens);
+  token_end_idx = min(token_start_idx + num_tokens_per_sm, num_tokens);
 }
 
 template <typename dtype_a_t, typename dtype_b_t>
-__device__ __forceinline__ dtype_b_t pack2(const dtype_a_t& x, const dtype_a_t& y) {
-    EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
-    dtype_b_t packed;
-    auto unpacked_ptr = reinterpret_cast<dtype_a_t*>(&packed);
-    unpacked_ptr[0] = x, unpacked_ptr[1] = y;
-    return packed;
+__device__ __forceinline__ dtype_b_t pack2(const dtype_a_t &x, const dtype_a_t &y) {
+  EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
+  dtype_b_t packed;
+  auto unpacked_ptr = reinterpret_cast<dtype_a_t *>(&packed);
+  unpacked_ptr[0] = x, unpacked_ptr[1] = y;
+  return packed;
 }
 
 template <typename dtype_a_t, typename dtype_b_t>
-__device__ __forceinline__ void unpack2(const dtype_b_t& packed, dtype_a_t& x, dtype_a_t& y) {
-    EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
-    auto unpacked_ptr = reinterpret_cast<const dtype_a_t*>(&packed);
-    x = unpacked_ptr[0], y = unpacked_ptr[1];
+__device__ __forceinline__ void unpack2(const dtype_b_t &packed, dtype_a_t &x, dtype_a_t &y) {
+  EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
+  auto unpacked_ptr = reinterpret_cast<const dtype_a_t *>(&packed);
+  x = unpacked_ptr[0], y = unpacked_ptr[1];
 }
 
 template <typename dtype_t>
-__device__ __forceinline__ dtype_t broadcast(dtype_t& ptr, int src_lane_idx) {
-    EP_STATIC_ASSERT(sizeof(dtype_t) % sizeof(int) == 0, "");
-    auto send_int_values = reinterpret_cast<int*>(&ptr);
-    int recv_int_values[sizeof(dtype_t) / sizeof(int)];
-    #pragma unroll
-    for (int i = 0; i < sizeof(dtype_t) / sizeof(int); ++ i)
-        recv_int_values[i] = __shfl_sync(0xffffffff, send_int_values[i], src_lane_idx);
-    return *reinterpret_cast<dtype_t*>(recv_int_values);
+__device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
+  EP_STATIC_ASSERT(sizeof(dtype_t) % sizeof(int) == 0, "");
+  auto send_int_values = reinterpret_cast<int *>(&ptr);
+  int recv_int_values[sizeof(dtype_t) / sizeof(int)];
+#pragma unroll
+  for (int i = 0; i < sizeof(dtype_t) / sizeof(int); ++i)
+    recv_int_values[i] = __shfl_sync(0xffffffff, send_int_values[i], src_lane_idx);
+  return *reinterpret_cast<dtype_t *>(recv_int_values);
 }
 
 __forceinline__ __device__ int warp_reduce_sum(int value) {
-    value += __shfl_xor_sync(0xffffffff, value, 16);
-    value += __shfl_xor_sync(0xffffffff, value, 8);
-    value += __shfl_xor_sync(0xffffffff, value, 4);
-    value += __shfl_xor_sync(0xffffffff, value, 2);
-    value += __shfl_xor_sync(0xffffffff, value, 1);
-    return value;
+  value += __shfl_xor_sync(0xffffffff, value, 16);
+  value += __shfl_xor_sync(0xffffffff, value, 8);
+  value += __shfl_xor_sync(0xffffffff, value, 4);
+  value += __shfl_xor_sync(0xffffffff, value, 2);
+  value += __shfl_xor_sync(0xffffffff, value, 1);
+  return value;
 }
 
 __forceinline__ __device__ float half_warp_reduce_max(float value) {
-    auto mask = __activemask();
-    // The mask be in `{0xffffffff, 0xffff}`
-    value = max(value, __shfl_xor_sync(mask, value, 8));
-    value = max(value, __shfl_xor_sync(mask, value, 4));
-    value = max(value, __shfl_xor_sync(mask, value, 2));
-    value = max(value, __shfl_xor_sync(mask, value, 1));
-    return value;
+  auto mask = __activemask();
+  // The mask must be in `{0xffffffff, 0xffff}`
+  value = max(value, __shfl_xor_sync(mask, value, 8));
+  value = max(value, __shfl_xor_sync(mask, value, 4));
+  value = max(value, __shfl_xor_sync(mask, value, 2));
+  value = max(value, __shfl_xor_sync(mask, value, 1));
+  return value;
 }
 
 __forceinline__ __device__ int get_lane_id() {
-    int lane_id;
-    asm("mov.s32 %0, %laneid;" : "=r"(lane_id));
-    return lane_id;
+  int lane_id;
+  asm("mov.s32 %0, %laneid;" : "=r"(lane_id));
+  return lane_id;
 }
 
 template <int kNumRanks>
 __forceinline__ __device__ void move_fifo_slots(int &head) {
-    head = (head + kNumRanks) % NUM_MAX_FIFO_SLOTS;
+  head = (head + kNumRanks) % NUM_MAX_FIFO_SLOTS;
 }
 
 template <int kNumRanks>
 __device__ __forceinline__ bool not_finished(int *task, int expected) {
-    auto result = false;
-    auto lane_id = threadIdx.x % 32;
-    if (lane_id < kNumRanks)
-        result = ld_volatile_global(task + lane_id) != expected;
-    return __any_sync(0xffffffff, result);
+  auto result = false;
+  auto lane_id = threadIdx.x % 32;
+  if (lane_id < kNumRanks) result = ld_volatile_global(task + lane_id) != expected;
+  return __any_sync(0xffffffff, result);
 }
 
 template <int kNumRanks>
-__forceinline__ __device__ void
-timeout_check(int **task_fifo_ptrs, int head, int rank, int expected, int tag = 0) {
-    auto start_time = clock64();
-    while (not_finished<kNumRanks>(task_fifo_ptrs[rank] + head, expected)) {
-        if (clock64() - start_time > NUM_TIMEOUT_CYCLES and threadIdx.x == 0) {
-            printf("DeepEP timeout check failed: %d (rank = %d)\n", tag, rank);
-            trap();
-        }
+__forceinline__ __device__ void timeout_check(int **task_fifo_ptrs, int head, int rank, int expected, int tag = 0) {
+  auto start_time = clock64();
+  while (not_finished<kNumRanks>(task_fifo_ptrs[rank] + head, expected)) {
+    if (clock64() - start_time > NUM_TIMEOUT_CYCLES and threadIdx.x == 0) {
+      printf("DeepEP timeout check failed: %d (rank = %d)\n", tag, rank);
+      trap();
     }
+  }
 }
 
 template <int kNumRanks>
-__forceinline__ __device__ void
-barrier_device(int **task_fifo_ptrs, int head, int rank, int tag = 0) {
-    auto thread_id = static_cast<int>(threadIdx.x);
-    EP_DEVICE_ASSERT(kNumRanks <= 32);
+__forceinline__ __device__ void barrier_device(int **task_fifo_ptrs, int head, int rank, int tag = 0) {
+  auto thread_id = static_cast<int>(threadIdx.x);
+  EP_DEVICE_ASSERT(kNumRanks <= 32);
 
-    if (thread_id < kNumRanks) {
-        atomicAdd_system(task_fifo_ptrs[rank] + head + thread_id, FINISHED_SUM_TAG);
-        memory_fence();
-        atomicSub_system(task_fifo_ptrs[thread_id] + head + rank, FINISHED_SUM_TAG);
-    }
-    timeout_check<kNumRanks>(task_fifo_ptrs, head, rank, 0, tag);
+  if (thread_id < kNumRanks) {
+    atomicAdd_system(task_fifo_ptrs[rank] + head + thread_id, FINISHED_SUM_TAG);
+    memory_fence();
+    atomicSub_system(task_fifo_ptrs[thread_id] + head + rank, FINISHED_SUM_TAG);
+  }
+  timeout_check<kNumRanks>(task_fifo_ptrs, head, rank, 0, tag);
 }
 
-} // namespace ep
-} // namespace mscclpp
+}  // namespace ep
+}  // namespace mscclpp
diff --git a/test/python/ext/ep/test_internode_multirank.py b/test/python/ext/ep/test_internode_multirank.py
index 09679305..b8a1ba0b 100644
--- a/test/python/ext/ep/test_internode_multirank.py
+++ b/test/python/ext/ep/test_internode_multirank.py
@@ -46,8 +46,9 @@ def init_dist():
     world_size = int(os.environ["WORLD_SIZE"])
     local_rank = int(os.environ.get("LOCAL_RANK", rank % 8))
     torch.cuda.set_device(local_rank)
-    dist.init_process_group(backend="nccl", world_size=world_size, rank=rank,
-                            device_id=torch.device(f"cuda:{local_rank}"))
+    dist.init_process_group(
+        backend="nccl", world_size=world_size, rank=rank, device_id=torch.device(f"cuda:{local_rank}")
+    )
     return rank, world_size, local_rank, dist.new_group(list(range(world_size)))
 
 
@@ -71,8 +72,9 @@ def main():
     from mscclpp.ext import ep
 
     NUM_MAX_NVL_PEERS = 8
-    assert num_ranks % NUM_MAX_NVL_PEERS == 0 and num_ranks > NUM_MAX_NVL_PEERS, \
-        f"expected >1 node with 8 GPUs each, got num_ranks={num_ranks}"
+    assert (
+        num_ranks % NUM_MAX_NVL_PEERS == 0 and num_ranks > NUM_MAX_NVL_PEERS
+    ), f"expected >1 node with 8 GPUs each, got num_ranks={num_ranks}"
     num_nodes = num_ranks // NUM_MAX_NVL_PEERS
     num_local_ranks = NUM_MAX_NVL_PEERS
 
@@ -80,7 +82,7 @@ def main():
     num_tokens = 128
     hidden = 1024
num_topk = min(4, num_ranks) - num_experts = (num_ranks * 4) # multiple of num_ranks + num_experts = num_ranks * 4 # multiple of num_ranks torch.manual_seed(0xA1B2 + rank) @@ -125,19 +127,25 @@ def main(): num_nvl_bytes = cfg.get_nvl_buffer_size_hint(_buf_hidden * x.element_size(), num_ranks) num_rdma_bytes = cfg.get_rdma_buffer_size_hint(_buf_hidden * x.element_size(), num_ranks) if rank == 0: - print(f"[cfg] num_nodes={num_nodes} num_ranks={num_ranks} num_tokens={num_tokens} " - f"hidden={hidden} num_experts={num_experts} num_topk={num_topk} " - f"num_nvl_bytes={num_nvl_bytes} num_rdma_bytes={num_rdma_bytes}", - flush=True) + print( + f"[cfg] num_nodes={num_nodes} num_ranks={num_ranks} num_tokens={num_tokens} " + f"hidden={hidden} num_experts={num_experts} num_topk={num_topk} " + f"num_nvl_bytes={num_nvl_bytes} num_rdma_bytes={num_rdma_bytes}", + flush=True, + ) print(f"[rank {rank}] creating Buffer", flush=True) buf = ep.Buffer(group, num_nvl_bytes=num_nvl_bytes, num_rdma_bytes=num_rdma_bytes, low_latency_mode=False) - print(f"[rank {rank}] Buffer created is_available={buf.is_available()} " - f"is_internode={buf.is_internode_available()}", flush=True) + print( + f"[rank {rank}] Buffer created is_available={buf.is_available()} " + f"is_internode={buf.is_internode_available()}", + flush=True, + ) assert buf.is_available() and buf.is_internode_available() - ref_rank, ref_rdma_rank, ref_exp, ref_in_rank, _ = \ - buf.runtime.get_dispatch_layout(topk_idx, num_experts, None, False, False) + ref_rank, ref_rdma_rank, ref_exp, ref_in_rank, _ = buf.runtime.get_dispatch_layout( + topk_idx, num_experts, None, False, False + ) assert torch.allclose(ref_rank, num_tokens_per_rank) assert torch.allclose(ref_rdma_rank, num_tokens_per_rdma_rank) assert torch.allclose(ref_exp, num_tokens_per_expert) @@ -153,17 +161,42 @@ def main(): # cached_rdma_channel_prefix_matrix=None, cached_recv_rdma_rank_prefix_sum=None, # cached_gbl_channel_prefix_matrix=None, cached_recv_gbl_rank_prefix_sum=None, # expert_alignment, config, previous_event, async, allocate_on_comm_stream) - (recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, - num_recv_tokens_per_expert_list, - rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, - recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, - recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, - recv_src_meta, send_rdma_head, send_nvl_head, _event) = buf.runtime.internode_dispatch( - x, None, topk_idx, topk_weights, - num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert, - 0, 0, - None, None, None, None, - 1, cfg, None, False, False, + ( + recv_x, + recv_x_scales, + recv_topk_idx, + recv_topk_weights, + num_recv_tokens_per_expert_list, + rdma_channel_prefix_matrix, + gbl_channel_prefix_matrix, + recv_rdma_channel_prefix_matrix, + recv_rdma_rank_prefix_sum, + recv_gbl_channel_prefix_matrix, + recv_gbl_rank_prefix_sum, + recv_src_meta, + send_rdma_head, + send_nvl_head, + _event, + ) = buf.runtime.internode_dispatch( + x, + None, + topk_idx, + topk_weights, + num_tokens_per_rank, + num_tokens_per_rdma_rank, + is_token_in_rank, + num_tokens_per_expert, + 0, + 0, + None, + None, + None, + None, + 1, + cfg, + None, + False, + False, ) dist.barrier(group=group) @@ -176,9 +209,9 @@ def main(): if block.numel(): lo = block.float().amin().item() hi = block.float().amax().item() - assert abs(lo - src) < 1e-3 and abs(hi - src) < 1e-3, ( - f"rank{rank}: block from src={src} has range=[{lo}, {hi}], expected {src}" - ) + assert ( + abs(lo - src) < 
1e-3 and abs(hi - src) < 1e-3 + ), f"rank{rank}: block from src={src} has range=[{lo}, {hi}], expected {src}" start = end if rank == 0: print(f"[dispatch] OK (recv {recv_x.size(0)} tokens)", flush=True) @@ -202,11 +235,19 @@ def main(): # (`recv_rdma_channel_prefix_matrix`, `recv_rdma_rank_prefix_sum`, # `recv_gbl_channel_prefix_matrix`) — not the sender-side ones. combined_x, combined_topk_weights, _ = buf.runtime.internode_combine( - recv_x, recv_topk_weights, - recv_src_meta, is_token_in_rank, - recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix, - send_rdma_head, send_nvl_head, - cfg, None, False, False, + recv_x, + recv_topk_weights, + recv_src_meta, + is_token_in_rank, + recv_rdma_channel_prefix_matrix, + recv_rdma_rank_prefix_sum, + recv_gbl_channel_prefix_matrix, + send_rdma_head, + send_nvl_head, + cfg, + None, + False, + False, ) num_dst = is_token_in_rank.sum(dim=1).to(torch.float32) @@ -235,19 +276,17 @@ def main(): # NCCL-EP's `ep_bench -a ht` defaults (256 experts, top-8). The functional # check above still uses the smaller (num_experts=num_ranks*4, topk=4) # configuration. - bench_num_experts = int(os.environ.get( - "MSCCLPP_EP_BENCH_EXPERTS", str(num_experts))) - bench_num_topk = int(os.environ.get( - "MSCCLPP_EP_BENCH_TOPK", str(num_topk))) + bench_num_experts = int(os.environ.get("MSCCLPP_EP_BENCH_EXPERTS", str(num_experts))) + bench_num_topk = int(os.environ.get("MSCCLPP_EP_BENCH_TOPK", str(num_topk))) if bench_num_experts % num_ranks != 0: if rank == 0: - print(f"[bench] skip: num_experts={bench_num_experts} not divisible " - f"by num_ranks={num_ranks}", flush=True) + print( + f"[bench] skip: num_experts={bench_num_experts} not divisible " f"by num_ranks={num_ranks}", flush=True + ) return if bench_num_topk > bench_num_experts: if rank == 0: - print(f"[bench] skip: topk={bench_num_topk} > experts={bench_num_experts}", - flush=True) + print(f"[bench] skip: topk={bench_num_topk} > experts={bench_num_experts}", flush=True) return # Respect the Buffer's pre-sized num_nvl_bytes / num_rdma_bytes budget. 
@@ -294,20 +333,43 @@ def main(): def _dispatch(): return buf.runtime.internode_dispatch( - x_b, None, topk_idx_b, topk_weights_b, - num_tokens_per_rank_b, num_tokens_per_rdma_rank_b, is_token_in_rank_b, num_tokens_per_expert_b, - 0, 0, None, None, None, None, - 1, cfg, None, False, False, + x_b, + None, + topk_idx_b, + topk_weights_b, + num_tokens_per_rank_b, + num_tokens_per_rdma_rank_b, + is_token_in_rank_b, + num_tokens_per_expert_b, + 0, + 0, + None, + None, + None, + None, + 1, + cfg, + None, + False, + False, ) def _combine(dout): - (rx, _rxs, _rti, rtw, _lst, - _rpm, _gpm, rrcpm, rrps, rgpm, _rgps, - rsm, sh_rdma, sh_nvl, _ev) = dout + rx, _rxs, _rti, rtw, _lst, _rpm, _gpm, rrcpm, rrps, rgpm, _rgps, rsm, sh_rdma, sh_nvl, _ev = dout buf.runtime.internode_combine( - rx, rtw, rsm, is_token_in_rank_b, - rrcpm, rrps, rgpm, - sh_rdma, sh_nvl, cfg, None, False, False, + rx, + rtw, + rsm, + is_token_in_rank_b, + rrcpm, + rrps, + rgpm, + sh_rdma, + sh_nvl, + cfg, + None, + False, + False, ) # Warmup (full round-trip with the sync/barrier guard between phases, @@ -369,16 +431,16 @@ def main(): num_tokens_per_rank_b.to(torch.int64), group=group, ) - src_node = (torch.arange(num_ranks, device="cuda") // num_local_ranks) + src_node = torch.arange(num_ranks, device="cuda") // num_local_ranks remote_mask = (src_node != local_node).to(torch.int64) total_recv_tokens_local = int(recv_from_src.sum().item()) rdma_recv_tokens_local = int((recv_from_src * remote_mask).sum().item()) # Average per-rank token counts across ranks (matches NCCL-EP `Byte counts (per rank avg)`). counts_t = torch.tensor( - [total_send_tokens_local, rdma_send_tokens_local, - total_recv_tokens_local, rdma_recv_tokens_local], - dtype=torch.float64, device="cuda", + [total_send_tokens_local, rdma_send_tokens_local, total_recv_tokens_local, rdma_recv_tokens_local], + dtype=torch.float64, + device="cuda", ) dist.all_reduce(counts_t, op=dist.ReduceOp.SUM, group=group) counts_avg = (counts_t / num_ranks).tolist() @@ -469,6 +531,7 @@ if __name__ == "__main__": main() except Exception: import traceback + traceback.print_exc() sys.exit(1) finally: diff --git a/test/python/ext/ep/test_intranode_multirank.py b/test/python/ext/ep/test_intranode_multirank.py index fa2c0dc6..08bf4355 100644 --- a/test/python/ext/ep/test_intranode_multirank.py +++ b/test/python/ext/ep/test_intranode_multirank.py @@ -111,9 +111,11 @@ def main(): _buf_hidden = max(hidden, int(os.environ.get("MSCCLPP_EP_BENCH_HIDDEN", "0"))) if _bench_on else hidden num_nvl_bytes = cfg.get_nvl_buffer_size_hint(_buf_hidden * x.element_size(), num_ranks) if rank == 0: - print(f"[cfg] num_ranks={num_ranks} num_tokens={num_tokens} hidden={hidden} " - f"num_experts={num_experts} num_topk={num_topk} num_nvl_bytes={num_nvl_bytes}", - flush=True) + print( + f"[cfg] num_ranks={num_ranks} num_tokens={num_tokens} hidden={hidden} " + f"num_experts={num_experts} num_topk={num_topk} num_nvl_bytes={num_nvl_bytes}", + flush=True, + ) print(f"[rank {rank}] creating Buffer", flush=True) buf = ep.Buffer(group, num_nvl_bytes=num_nvl_bytes, num_rdma_bytes=0, low_latency_mode=False) @@ -129,14 +131,34 @@ def main(): print("[layout] OK", flush=True) # Dispatch - (recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, - num_recv_tokens_per_expert_list, - rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, - send_head, _event) = buf.runtime.intranode_dispatch( - x, None, topk_idx, topk_weights, - num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, - 0, None, 
None, - 1, cfg, None, False, False, + ( + recv_x, + recv_x_scales, + recv_topk_idx, + recv_topk_weights, + num_recv_tokens_per_expert_list, + rank_prefix_matrix, + channel_prefix_matrix, + recv_channel_prefix_matrix, + recv_src_idx, + send_head, + _event, + ) = buf.runtime.intranode_dispatch( + x, + None, + topk_idx, + topk_weights, + num_tokens_per_rank, + is_token_in_rank, + num_tokens_per_expert, + 0, + None, + None, + 1, + cfg, + None, + False, + False, ) dist.barrier(group=group) @@ -149,9 +171,7 @@ def main(): block = recv_x[start:end] if block.numel(): actual = block.float().amin().item() - assert abs(actual - src) < 1e-3, ( - f"rank{rank}: block from src={src} has min={actual}, expected {src}" - ) + assert abs(actual - src) < 1e-3, f"rank{rank}: block from src={src} has min={actual}, expected {src}" assert abs(block.float().amax().item() - src) < 1e-3 start = end if rank == 0: @@ -165,9 +185,16 @@ def main(): handle_channel_prefix_matrix = recv_channel_prefix_matrix combined_x, combined_topk_weights, _ = buf.runtime.intranode_combine( - recv_x, recv_topk_weights, - handle_recv_src_idx, handle_rank_prefix_matrix, handle_channel_prefix_matrix, - send_head, cfg, None, False, False, + recv_x, + recv_topk_weights, + handle_recv_src_idx, + handle_rank_prefix_matrix, + handle_channel_prefix_matrix, + send_head, + cfg, + None, + False, + False, ) # Expected: we dispatched with x = rank * ones, so every destination r @@ -201,19 +228,17 @@ def main(): # NCCL-EP's `ep_bench -a ht` defaults (256 experts, top-8). The functional # check above still uses the smaller (num_experts=num_ranks*4, topk=4) # configuration. - bench_num_experts = int(os.environ.get( - "MSCCLPP_EP_BENCH_EXPERTS", str(num_experts))) - bench_num_topk = int(os.environ.get( - "MSCCLPP_EP_BENCH_TOPK", str(num_topk))) + bench_num_experts = int(os.environ.get("MSCCLPP_EP_BENCH_EXPERTS", str(num_experts))) + bench_num_topk = int(os.environ.get("MSCCLPP_EP_BENCH_TOPK", str(num_topk))) if bench_num_experts % num_ranks != 0: if rank == 0: - print(f"[bench] skip: num_experts={bench_num_experts} not divisible " - f"by num_ranks={num_ranks}", flush=True) + print( + f"[bench] skip: num_experts={bench_num_experts} not divisible " f"by num_ranks={num_ranks}", flush=True + ) return if bench_num_topk > bench_num_experts: if rank == 0: - print(f"[bench] skip: topk={bench_num_topk} > experts={bench_num_experts}", - flush=True) + print(f"[bench] skip: topk={bench_num_topk} > experts={bench_num_experts}", flush=True) return # Rebuild inputs at bench size. Keep same layout recipe as above but at @@ -253,15 +278,36 @@ def main(): def _dispatch(): return buf.runtime.intranode_dispatch( - x_b, None, topk_idx_b, topk_weights_b, - num_tokens_per_rank_b, is_token_in_rank_b, num_tokens_per_expert_b, - 0, None, None, 1, cfg, None, False, False, + x_b, + None, + topk_idx_b, + topk_weights_b, + num_tokens_per_rank_b, + is_token_in_rank_b, + num_tokens_per_expert_b, + 0, + None, + None, + 1, + cfg, + None, + False, + False, ) def _combine(dout): - (rx, _rxs, _rti, rtw, _lst, rpm, _cpm, rcpm, rsi, sh, _ev) = dout + rx, _rxs, _rti, rtw, _lst, rpm, _cpm, rcpm, rsi, sh, _ev = dout buf.runtime.intranode_combine( - rx, rtw, rsi, rpm, rcpm, sh, cfg, None, False, False, + rx, + rtw, + rsi, + rpm, + rcpm, + sh, + cfg, + None, + False, + False, ) # Warmup (full round-trip). @@ -319,9 +365,9 @@ def main(): # Average per-rank token counts across ranks (matches NCCL-EP `Byte counts (per rank avg)`). 
counts_t = torch.tensor( - [total_send_tokens_local, rdma_send_tokens_local, - total_recv_tokens_local, rdma_recv_tokens_local], - dtype=torch.float64, device="cuda", + [total_send_tokens_local, rdma_send_tokens_local, total_recv_tokens_local, rdma_recv_tokens_local], + dtype=torch.float64, + device="cuda", ) dist.all_reduce(counts_t, op=dist.ReduceOp.SUM, group=group) counts_avg = (counts_t / num_ranks).tolist() @@ -410,6 +456,7 @@ if __name__ == "__main__": main() except Exception: import traceback + traceback.print_exc() sys.exit(1) finally: diff --git a/test/python/ext/ep/test_low_latency_multirank.py b/test/python/ext/ep/test_low_latency_multirank.py index a5b3e26d..18e64a4e 100644 --- a/test/python/ext/ep/test_low_latency_multirank.py +++ b/test/python/ext/ep/test_low_latency_multirank.py @@ -70,7 +70,9 @@ def main(): assert num_ranks - rank_offset < 257, "too many ranks for bf16 precision anchor" num_tokens = int(os.environ.get("MSCCLPP_EP_LL_TOKENS", "64")) - hidden = int(os.environ.get("MSCCLPP_EP_LL_HIDDEN", "7168")) # LL kernels are compiled for a fixed set; see SWITCH_HIDDEN + hidden = int( + os.environ.get("MSCCLPP_EP_LL_HIDDEN", "7168") + ) # LL kernels are compiled for a fixed set; see SWITCH_HIDDEN num_topk = int(os.environ.get("MSCCLPP_EP_LL_TOPK", "4")) num_experts_per_rank = int(os.environ.get("MSCCLPP_EP_LL_EXPERTS_PER_RANK", "4")) num_experts = num_ranks * num_experts_per_rank @@ -83,9 +85,7 @@ def main(): x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * (rank - rank_offset) # Encode the per-token index into the last 128 elements so the receiver # can verify which source token it is looking at. - x[:, -128:] = ( - torch.arange(num_tokens, device="cuda").to(torch.bfloat16).view(-1, 1) - ) + x[:, -128:] = torch.arange(num_tokens, device="cuda").to(torch.bfloat16).view(-1, 1) scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs() + 1 topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1] topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device="cuda").abs() @@ -94,9 +94,7 @@ def main(): for _ in range(min(10, num_tokens)): topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1 - num_rdma_bytes = ep.Buffer.get_low_latency_rdma_size_hint( - num_tokens, hidden, num_ranks, num_experts - ) + num_rdma_bytes = ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts) if rank == 0: print( f"[cfg] num_ranks={num_ranks} num_tokens={num_tokens} hidden={hidden} " @@ -129,12 +127,21 @@ def main(): # packed_recv_count, packed_recv_src_info, packed_recv_layout_range, # event, hook ( - packed_recv_x, _packed_recv_x_scales, - packed_recv_count, packed_recv_src_info, packed_recv_layout_range, - _event, recv_hook, + packed_recv_x, + _packed_recv_x_scales, + packed_recv_count, + packed_recv_src_info, + packed_recv_layout_range, + _event, + recv_hook, ) = buf.low_latency_dispatch( - x, topk_idx, num_tokens, num_experts, - False, False, True, # use_fp8, async, return_recv_hook + x, + topk_idx, + num_tokens, + num_experts, + False, + False, + True, # use_fp8, async, return_recv_hook ) # Send phase launched on compute_stream; wait for local launch. 
torch.cuda.synchronize() @@ -158,12 +165,12 @@ def main(): expected_count = int((all_topk_idx == expert_id).sum().item()) recv_layout_range = handle[1][i] layout_sum = int((recv_layout_range & int_mask).sum().item()) - assert recv_count == expected_count, ( - f"rank{rank} expert{expert_id}: recv_count={recv_count} != expected={expected_count}" - ) - assert layout_sum == recv_count, ( - f"rank{rank} expert{expert_id}: layout range sum {layout_sum} != recv_count {recv_count}" - ) + assert ( + recv_count == expected_count + ), f"rank{rank} expert{expert_id}: recv_count={recv_count} != expected={expected_count}" + assert ( + layout_sum == recv_count + ), f"rank{rank} expert{expert_id}: layout range sum {layout_sum} != recv_count {recv_count}" if recv_count: recv_x = packed_recv_x[i, :recv_count] @@ -186,10 +193,16 @@ def main(): # zero_copy, async, return_recv_hook, out) src_info, layout_range = handle[0], handle[1] combined_x, _event, _hook = buf.low_latency_combine( - simulated_gemm_x, topk_idx, topk_weights, - src_info, layout_range, - num_tokens, num_experts, - False, False, False, # zero_copy, async, return_recv_hook + simulated_gemm_x, + topk_idx, + topk_weights, + src_info, + layout_range, + num_tokens, + num_experts, + False, + False, + False, # zero_copy, async, return_recv_hook out, ) @@ -230,23 +243,34 @@ def main(): num_local_experts = num_experts // num_ranks bench_packed_recv_x = torch.empty( (num_local_experts, num_ranks * num_tokens, hidden), - dtype=torch.bfloat16, device="cuda", + dtype=torch.bfloat16, + device="cuda", ) bench_packed_recv_src_info = torch.empty( (num_local_experts, num_ranks * num_tokens), - dtype=torch.int32, device="cuda", + dtype=torch.int32, + device="cuda", ) bench_packed_recv_layout_range = torch.empty( - (num_local_experts, num_ranks), dtype=torch.int64, device="cuda", + (num_local_experts, num_ranks), + dtype=torch.int64, + device="cuda", ) bench_packed_recv_count = torch.empty( - (num_local_experts,), dtype=torch.int32, device="cuda", + (num_local_experts,), + dtype=torch.int32, + device="cuda", ) def _dispatch(): return buf.low_latency_dispatch( - x, topk_idx, num_tokens, num_experts, - False, False, False, # use_fp8, async, return_recv_hook + x, + topk_idx, + num_tokens, + num_experts, + False, + False, + False, # use_fp8, async, return_recv_hook bench_packed_recv_x, None, # x_scales (FP8 only) bench_packed_recv_src_info, @@ -261,12 +285,18 @@ def main(): bench_out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") def _combine(dout, out_): - (recv_x, _scales, _cnt, src_info_, layout_range_, _ev, _hk) = dout + recv_x, _scales, _cnt, src_info_, layout_range_, _ev, _hk = dout buf.low_latency_combine( - recv_x, topk_idx, topk_weights, - src_info_, layout_range_, - num_tokens, num_experts, - False, False, False, + recv_x, + topk_idx, + topk_weights, + src_info_, + layout_range_, + num_tokens, + num_experts, + False, + False, + False, out_, )