// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#pragma once

#include <functional>
#include <memory>
#include <optional>
#include <tuple>
#include <vector>

#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <pybind11/pybind11.h>
#include <torch/types.h>

#include <mscclpp/core.hpp>
#include <mscclpp/memory_channel.hpp>
#include <mscclpp/port_channel.hpp>

#include "config.hpp"
#include "event.hpp"
#include "kernels/configs.cuh"
#include "kernels/exception.cuh"

#if defined(USE_IBVERBS) && defined(MSCCLPP_USE_MLX5DV) && !defined(MSCCLPP_USE_ROCM)
#define MSCCLPP_EP_HAVE_IBGDA 1
namespace mscclpp {
namespace ep {
struct IbgdaSetup;
}  // namespace ep
}  // namespace mscclpp
#endif

#ifndef TORCH_EXTENSION_NAME
#define TORCH_EXTENSION_NAME mscclpp_ep_cpp
#endif

namespace mscclpp {
namespace ep {

struct Buffer {
  EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "The number of maximum NVLink peers must be 8");

 private:
  // Low-latency mode buffer
  int low_latency_buffer_idx = 0;
  bool low_latency_mode = false;

  // NVLink buffer
  int64_t num_nvl_bytes;
  void* buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
  void** buffer_ptrs_gpu = nullptr;

  // NVSHMEM buffer
  int64_t num_rdma_bytes;
  void* rdma_buffer_ptr = nullptr;

  // Device info and communication
  int device_id;
  int rank, rdma_rank, nvl_rank;
  int num_ranks, num_rdma_ranks, num_nvl_ranks;
  cudaIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS];

  // Stream for communication
  at::cuda::CUDAStream comm_stream;

  // After IPC/NVSHMEM synchronization, this flag will be true
  bool available = false;

  // Task FIFO
  int head = 0;
  int* task_fifo_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
  int** task_fifo_ptrs_gpu = nullptr;

  // Workspace
  void* workspace = nullptr;

  // Host-side MoE info
  volatile int* moe_recv_counter = nullptr;
  int* moe_recv_counter_mapped = nullptr;

  // Host-side expert-level MoE info
  volatile int* moe_recv_expert_counter = nullptr;
  int* moe_recv_expert_counter_mapped = nullptr;

  // Host-side RDMA-level MoE info
  volatile int* moe_recv_rdma_counter = nullptr;
  int* moe_recv_rdma_counter_mapped = nullptr;

  std::shared_ptr<mscclpp::Bootstrap> bootstrap;

  // One ProxyService spawns a single proxy thread that drains every PortChannel
  // FIFO it owns. With LL combine pushing thousands of triggers per iteration,
  // that single thread becomes the wall-clock bottleneck on cross-node runs. We
  // therefore shard channels across `proxy_services` so each service gets its
  // own thread/FIFO, increasing host-side dispatch parallelism (no kernel
  // changes required). The count is resolved at construction (env
  // `MSCCLPP_EP_NUM_PROXIES` or an arch-aware default).
  int num_proxy_services = 1;
  std::vector<std::shared_ptr<mscclpp::ProxyService>> proxy_services;

  std::shared_ptr<mscclpp::Communicator> communicator;
  std::vector<mscclpp::PortChannel> port_channels;
  std::vector<mscclpp::MemoryChannel> memory_channels;
  std::shared_ptr<mscclpp::DeviceHandle<mscclpp::PortChannel>> port_channel_handles_device_ptr;
  std::shared_ptr<mscclpp::DeviceHandle<mscclpp::MemoryChannel>> memory_channel_handles_device_ptr;

  // Intra-node LL only: peer-mapped RDMA buffer pointers (CUDA IPC).
  // `peer_rdma_bases[r]` aliases rank `r`'s `rdma_buffer_ptr` via
  // `cudaIpcOpenMemHandle` (lazy peer access). Populated in `sync()` when
  // `low_latency_mode && num_rdma_ranks == 1`; null otherwise.
  cudaIpcMemHandle_t rdma_ipc_handles[NUM_MAX_NVL_PEERS];
  void* peer_rdma_bases[NUM_MAX_NVL_PEERS] = {nullptr};
  void** peer_rdma_bases_gpu = nullptr;

  // MemoryChannels over CUDA IPC used only for the LL barrier ring.
  std::vector<mscclpp::MemoryChannel> ll_memory_channels;
  std::shared_ptr<mscclpp::DeviceHandle<mscclpp::MemoryChannel>> ll_memory_channel_handles_device_ptr;
  bool ll_ipc_ready = false;
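  // A minimal sketch of the `sync()`-time mapping described above, assuming the
  // `rdma_ipc_handles` array has already been all-gathered over `bootstrap` and
  // that a `CUDA_CHECK` error macro is available (illustrative only; not part
  // of the class API):
  //
  //   for (int r = 0; r < num_nvl_ranks; ++r) {
  //     if (r == nvl_rank) {
  //       peer_rdma_bases[r] = rdma_buffer_ptr;  // Self: use the local pointer
  //       continue;
  //     }
  //     // Map peer r's `rdma_buffer_ptr` into this process's address space.
  //     CUDA_CHECK(cudaIpcOpenMemHandle(&peer_rdma_bases[r], rdma_ipc_handles[r],
  //                                     cudaIpcMemLazyEnablePeerAccess));
  //   }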

  // ------------------------------------------------------------------
  // Phase 11 — Hybrid LL fast path.
  //
  // In multi-node LL with IBGDA, also open CUDA IPC peer pointers for
  // same-node neighbors so the kernel can prefer NVLink for intranode
  // peers and IBGDA for internode peers (matching nccl-ep's behavior).
  //
  // `hybrid_peer_bases` is sparse: indexed by global rank, populated
  // only for same-node peers (rank' / NUM_MAX_NVL_PEERS == rdma_rank
  // && rank' != rank). Cross-node and self entries are nullptr; the
  // kernel checks for nullptr to decide IPC vs IBGDA per peer.
  //
  // Built lazily in `sync()` when:
  //   - low_latency_mode && num_rdma_ranks > 1
  //   - env MSCCLPP_EP_USE_IBGDA=1 && IBGDA setup succeeds
  //   - env MSCCLPP_EP_HYBRID_LL is not set to "0"
  // ------------------------------------------------------------------
  std::vector<cudaIpcMemHandle_t> hybrid_ipc_handles;
  std::vector<void*> hybrid_peer_bases;    // Size num_ranks; same-node entries non-null
  void** hybrid_peer_bases_gpu = nullptr;  // GPU array of size num_ranks
  bool hybrid_ll_ready = false;
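  // A minimal device-side sketch of the per-peer nullptr check described above
  // (hypothetical kernel fragment; `ibgda_put` stands in for the real IBGDA
  // send path, which lives in the kernels):
  //
  //   __device__ void ll_send(void* const* hybrid_peer_bases, int dst_rank,
  //                           size_t offset, const char* src, size_t bytes) {
  //     if (void* base = hybrid_peer_bases[dst_rank]) {
  //       // Same-node peer: write through the CUDA IPC mapping over NVLink.
  //       memcpy(static_cast<char*>(base) + offset, src, bytes);
  //     } else {
  //       // Cross-node peer: fall back to the IBGDA path.
  //       ibgda_put(dst_rank, offset, src, bytes);
  //     }
  //   }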

  // ------------------------------------------------------------------
  // Native IBGDA path (Stage 4b). Built lazily in `sync()` when env
  // `MSCCLPP_EP_USE_IBGDA=1` is set AND the run is cross-node.
  // The kernels do NOT consume `ibgda_setup_` until 4b.2 lands; for now
  // it is constructed-but-unused, so existing tests are unaffected.
  // ------------------------------------------------------------------
  bool use_ibgda_path_ = false;
#ifdef MSCCLPP_EP_HAVE_IBGDA
  std::unique_ptr<IbgdaSetup> ibgda_setup_;
#endif

 private:
  void move_fifo_slots(int num_slots = 1);

 public:
  Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode);

  ~Buffer() noexcept(false);

  bool is_available() const;

  bool is_internode_available() const;

  int get_num_rdma_ranks() const;

  int get_rdma_rank() const;

  int get_root_rdma_rank(bool global) const;

  int get_local_device_id() const;

  pybind11::bytearray get_local_ipc_handle() const;

  pybind11::bytearray get_local_nvshmem_unique_id() const;

  torch::Tensor get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const;

  mscclpp::UniqueId create_unique_id() const;

  void connect(mscclpp::UniqueId root_id);

  void sync(const std::vector<int>& device_ids,
            const std::vector<std::optional<pybind11::bytearray>>& all_gathered_handles,
            const std::optional<pybind11::bytearray>& root_unique_id_opt);

  std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
  get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, std::optional<EventHandle>& previous_event,
                      bool async, bool allocate_on_comm_stream);

  std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
             std::vector<int>, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
             std::optional<EventHandle>>
  intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,
                     const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
                     const std::optional<torch::Tensor>& num_tokens_per_rank, const torch::Tensor& is_token_in_rank,
                     const std::optional<torch::Tensor>& num_tokens_per_expert, int cached_num_recv_tokens,
                     const std::optional<torch::Tensor>& cached_rank_prefix_matrix,
                     const std::optional<torch::Tensor>& cached_channel_prefix_matrix, int expert_alignment,
                     const Config& config, std::optional<EventHandle>& previous_event, bool async,
                     bool allocate_on_comm_stream);

  std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>> intranode_combine(
      const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights, const torch::Tensor& src_idx,
      const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix,
      const torch::Tensor& send_head, const Config& config, std::optional<EventHandle>& previous_event, bool async,
      bool allocate_on_comm_stream);

  std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
             std::vector<int>, torch::Tensor, torch::Tensor, std::optional<torch::Tensor>, torch::Tensor,
             std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
             std::optional<torch::Tensor>, std::optional<EventHandle>>
  internode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,
                     const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
                     const std::optional<torch::Tensor>& num_tokens_per_rank,
                     const std::optional<torch::Tensor>& num_tokens_per_rdma_rank,
                     const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
                     int cached_num_recv_tokens, int cached_num_rdma_recv_tokens,
                     const std::optional<torch::Tensor>& cached_rdma_channel_prefix_matrix,
                     const std::optional<torch::Tensor>& cached_recv_rdma_rank_prefix_sum,
                     const std::optional<torch::Tensor>& cached_gbl_channel_prefix_matrix,
                     const std::optional<torch::Tensor>& cached_recv_gbl_rank_prefix_sum, int expert_alignment,
                     const Config& config, std::optional<EventHandle>& previous_event, bool async,
                     bool allocate_on_comm_stream);

  std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>> internode_combine(
      const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights, const torch::Tensor& src_meta,
      const torch::Tensor& is_combined_token_in_rank, const torch::Tensor& rdma_channel_prefix_matrix,
      const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix,
      const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head, const Config& config,
      std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);

  void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);

  std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor,
             std::optional<EventHandle>, std::optional<std::function<void()>>>
  low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, int num_max_dispatch_tokens_per_rank,
                       int num_experts, bool use_fp8, bool async, bool return_recv_hook,
                       const std::optional<torch::Tensor>& out_packed_recv_x = std::nullopt,
                       const std::optional<torch::Tensor>& out_packed_recv_x_scales = std::nullopt,
                       const std::optional<torch::Tensor>& out_packed_recv_src_info = std::nullopt,
                       const std::optional<torch::Tensor>& out_packed_recv_layout_range = std::nullopt,
                       const std::optional<torch::Tensor>& out_packed_recv_count = std::nullopt);

  std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> low_latency_combine(
      const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
      const torch::Tensor& src_info, const torch::Tensor& layout_range, int num_max_dispatch_tokens_per_rank,
      int num_experts, bool zero_copy, bool async, bool return_recv_hook,
      const std::optional<torch::Tensor>& out = std::nullopt);

  torch::Tensor get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);
};

}  // namespace ep
}  // namespace mscclpp
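
// A hedged end-to-end usage sketch of the low-latency path (host side). The
// bootstrap glue is elided and `rank`, `num_ranks`, `num_rdma_bytes`, `x`, and
// `topk_idx` are assumed to exist; the calls mirror the declarations above:
//
//   auto buffer = std::make_shared<mscclpp::ep::Buffer>(
//       rank, num_ranks, /*num_nvl_bytes=*/0, num_rdma_bytes,
//       /*low_latency_mode=*/true);
//   // ... exchange IPC handles / the root unique ID, then buffer->sync(...) ...
//   auto [packed_recv_x, packed_recv_x_scales, packed_recv_src_info,
//         packed_recv_layout_range, packed_recv_count, event, recv_hook] =
//       buffer->low_latency_dispatch(x, topk_idx, num_max_dispatch_tokens_per_rank,
//                                    num_experts, /*use_fp8=*/true, /*async=*/false,
//                                    /*return_recv_hook=*/false);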