# Base classes for MoE CPU inference operations # SPDX-License-Identifier: Apache-2.0 """ Base infrastructure for CPU-based MoE inference. This module contains base classes and utilities shared across all backend implementations. """ from __future__ import annotations import torch from typing import Dict, List, Optional, Tuple from abc import ABC, abstractmethod import os import ctypes import kt_kernel_ext class KExpertsCPUBuffer: """ CPU buffer management for expert computation. Manages pinned memory buffers for efficient GPU-CPU data transfer. """ capture_bs: List = list() capture_buffers: Dict = dict() temp_bs: int = 0 temp_buffer: tuple = tuple() buffer_depth: int = 2 @classmethod def get_buffer(cls, hidden_states: torch.Tensor, num_experts_per_tok): hidden_size = hidden_states.shape[-1] batch_size = hidden_states.shape[0] if batch_size in cls.capture_buffers: return cls.capture_buffers[batch_size] if batch_size == cls.temp_bs: return cls.temp_buffer input_tensor_cpu = [ torch.zeros((batch_size, hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16) for _ in range(cls.buffer_depth) ] immediate_experts_ids_cpu = [ torch.zeros((batch_size, num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True) for _ in range(cls.buffer_depth) ] deferred_experts_ids_cpu = [ torch.full((batch_size, num_experts_per_tok), -1, device="cpu", dtype=torch.long, pin_memory=True) for _ in range(cls.buffer_depth) ] weights_cpu = [ torch.zeros((batch_size, num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True) for _ in range(cls.buffer_depth) ] output_cpu = [ torch.zeros((batch_size, hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16) for _ in range(cls.buffer_depth) ] bsz_tensor_cpu = [ torch.full((1,), batch_size, device="cpu", dtype=torch.int32, pin_memory=True) for _ in range(cls.buffer_depth) ] output_gpu = [ torch.zeros((batch_size, hidden_size), device=hidden_states.device, dtype=hidden_states.dtype) for _ in range(cls.buffer_depth) ] cur_buffer = ( input_tensor_cpu, immediate_experts_ids_cpu, deferred_experts_ids_cpu, weights_cpu, output_cpu, bsz_tensor_cpu, output_gpu, ) if batch_size in cls.capture_bs: cls.capture_buffers[batch_size] = cur_buffer cls.temp_bs = batch_size cls.temp_buffer = cur_buffer return cur_buffer class BaseMoEWrapper(ABC): """ Base class for MoE CPU inference operations. Provides common functionality for all backend implementations. """ _cpu_infer_instance = None _layer_has_pending_deferred: Dict[int, bool] = {} def __init__( self, layer_idx: int, num_experts: int, num_experts_per_tok: int, hidden_size: int, moe_intermediate_size: int, num_gpu_experts: int, cpuinfer_threads: int, threadpool_count: int, weight_path: str, chunked_prefill_size: int, cpu_save: bool = False, max_deferred_experts_per_token: Optional[int] = None, method: str = "AMXINT4", ): """ Initialize base MoE Wrapper. Args: layer_idx: Layer index num_experts: Total number of experts num_experts_per_tok: Number of experts per token (top-k) hidden_size: Hidden dimension size moe_intermediate_size: MoE intermediate size num_gpu_experts: Number of experts to run on GPU cpuinfer_threads: Number of CPU inference threads threadpool_count: Number of NUMA subpools weight_path: Path to weights chunked_prefill_size: Maximum prefill chunk size cpu_save: Whether to save weights to CPU memory max_deferred_experts_per_token: Number of experts per token to defer on this layer. Defaults to 0 (no defer). method: Backend method string """ self.layer_idx = layer_idx self.num_experts = num_experts self.num_experts_per_tok = num_experts_per_tok self.hidden_size = hidden_size self.moe_intermediate_size = moe_intermediate_size self.num_gpu_experts = num_gpu_experts self.weight_path = weight_path self.chunked_prefill_size = chunked_prefill_size self.cpu_save = cpu_save self.max_deferred_experts_per_token = ( int(max_deferred_experts_per_token) if max_deferred_experts_per_token is not None else 0 ) BaseMoEWrapper._layer_has_pending_deferred[self.layer_idx] = False self.method = method # Initialize CPU inference engine (singleton) if BaseMoEWrapper._cpu_infer_instance is None: worker_config = kt_kernel_ext.WorkerPoolConfig() subpool_numa_map = list(range(threadpool_count)) subpool_thread_count = [ cpuinfer_threads // threadpool_count + (1 if i < cpuinfer_threads % threadpool_count else 0) for i in range(threadpool_count) ] worker_config.subpool_count = threadpool_count worker_config.subpool_numa_map = subpool_numa_map worker_config.subpool_thread_count = subpool_thread_count BaseMoEWrapper._cpu_infer_instance = kt_kernel_ext.CPUInfer(worker_config) self.cpu_infer = BaseMoEWrapper._cpu_infer_instance # Backend-specific initialization happens in subclasses self.moe = None @abstractmethod def load_weights_from_tensors( self, gate_proj: torch.Tensor, up_proj: torch.Tensor, down_proj: torch.Tensor, physical_to_logical_map_cpu: torch.Tensor, ): """ Load and quantize weights from BF16/FP16 tensors (online quantization). Args: gate_proj: Gate projection weights [num_experts, intermediate_size, hidden_size] up_proj: Up projection weights [num_experts, intermediate_size, hidden_size] down_proj: Down projection weights [num_experts, hidden_size, intermediate_size] physical_to_logical_map_cpu: Mapping from physical to logical expert IDs """ pass @abstractmethod def load_weights(self, physical_to_logical_map_cpu: torch.Tensor): """ Load weights for this layer and initialize the MoE module. Args: physical_to_logical_map_cpu: Mapping from physical to logical expert IDs """ pass def select_deferred_experts( self, expert_ids: torch.Tensor, expert_scores: torch.Tensor, protected_k: int, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: batch, topk = expert_ids.shape device = expert_ids.device protected_k = max(0, min(int(protected_k), topk)) if protected_k == 0: deferred_ids = expert_ids.clone() immediate_ids = torch.full_like(expert_ids, -1) return immediate_ids, deferred_ids topk_result = torch.topk(expert_scores, k=protected_k, dim=-1, largest=True, sorted=False) protected_indices = topk_result.indices protected_ids = torch.gather(expert_ids, -1, protected_indices) protected_flag = torch.zeros((self.num_experts,), dtype=torch.int32, device=device) protected_flag.scatter_(0, protected_ids.reshape(-1), 1) protected_mask_flat = torch.gather(protected_flag, 0, expert_ids.reshape(-1)).ne(0) protected_mask = protected_mask_flat.view(batch, topk) immediate_ids = expert_ids.clone().masked_fill(~protected_mask, -1) deferred_ids = expert_ids.clone().masked_fill(protected_mask, -1) return immediate_ids, deferred_ids def submit_forward( self, hidden_states: torch.Tensor, topk_ids: torch.Tensor, topk_weights: torch.Tensor, cuda_stream, ): """ Submit forward inference task to CPU (non-blocking). Args: hidden_states: Input hidden states [batch_size, hidden_size] topk_ids: Top-k expert IDs [batch_size, num_experts_per_tok] topk_weights: Top-k expert weights [batch_size, num_experts_per_tok] cuda_stream: CUDA stream for synchronization """ flat_hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) batch_size = flat_hidden_states.shape[0] ( input_tensor_cpu, immediate_experts_ids_cpu, deferred_experts_ids_cpu, weights_cpu, output_cpu, bsz_tensor_cpu, _output_gpu, ) = KExpertsCPUBuffer.get_buffer(flat_hidden_states, self.num_experts_per_tok) current_slot = self.layer_idx % KExpertsCPUBuffer.buffer_depth next_slot = (current_slot + 1) % KExpertsCPUBuffer.buffer_depth bsz_slot_tensor = bsz_tensor_cpu[current_slot] topk_ids_long = topk_ids.to(torch.long) immediate_ids: torch.Tensor deferred_ids: Optional[torch.Tensor] if self.max_deferred_experts_per_token > 0: protected_k = self.num_experts_per_tok - self.max_deferred_experts_per_token immediate_ids, deferred_ids = self.select_deferred_experts(topk_ids_long, topk_weights, protected_k) else: immediate_ids = topk_ids_long deferred_ids = None input_tensor_cpu[current_slot].copy_(flat_hidden_states, non_blocking=True) weights_cpu[current_slot].copy_(topk_weights, non_blocking=True) immediate_experts_ids_cpu[current_slot].copy_(immediate_ids, non_blocking=True) incremental = BaseMoEWrapper._layer_has_pending_deferred.get(self.layer_idx - 1, False) self.cpu_infer.submit_with_cuda_stream( cuda_stream, self.moe.forward_task( bsz_slot_tensor.data_ptr(), immediate_experts_ids_cpu[current_slot].size(-1), immediate_experts_ids_cpu[current_slot].data_ptr(), weights_cpu[current_slot].data_ptr(), input_tensor_cpu[current_slot].data_ptr(), output_cpu[current_slot].data_ptr(), incremental, ), ) BaseMoEWrapper._layer_has_pending_deferred[self.layer_idx] = False if deferred_ids is not None: deferred_experts_ids_cpu[current_slot].copy_(deferred_ids, non_blocking=True) self.cpu_infer.submit_with_cuda_stream( cuda_stream, self.moe.forward_task( bsz_slot_tensor.data_ptr(), deferred_experts_ids_cpu[current_slot].size(-1), deferred_experts_ids_cpu[current_slot].data_ptr(), weights_cpu[current_slot].data_ptr(), input_tensor_cpu[current_slot].data_ptr(), output_cpu[next_slot].data_ptr(), False, ), ) BaseMoEWrapper._layer_has_pending_deferred[self.layer_idx] = True def sync_forward(self, hidden_states: torch.Tensor, cuda_stream) -> torch.Tensor: """ Synchronize and retrieve forward inference results. Args: hidden_states: Original input hidden states (for getting buffer) cuda_stream: CUDA stream for synchronization Returns: output_gpu: Output tensor on GPU """ flat_hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) ( _input_tensor_cpu, _immediate_experts_ids_cpu, _deferred_experts_ids_cpu, _weights_cpu, output_cpu, _bsz_tensor_cpu, output_gpu, ) = KExpertsCPUBuffer.get_buffer(flat_hidden_states, self.num_experts_per_tok) current_slot = self.layer_idx % KExpertsCPUBuffer.buffer_depth allow_pending = 1 if BaseMoEWrapper._layer_has_pending_deferred.get(self.layer_idx, False) else 0 self.cpu_infer.sync_with_cuda_stream(cuda_stream, allow_pending) output_gpu[current_slot].copy_(output_cpu[current_slot], non_blocking=True) return output_gpu[current_slot] def forward( self, hidden_states: torch.Tensor, topk_ids: torch.Tensor, topk_weights: torch.Tensor, cuda_stream, ) -> torch.Tensor: """ Execute forward inference synchronously (submit + sync). Args: hidden_states: Input hidden states [batch_size, hidden_size] topk_ids: Top-k expert IDs [batch_size, num_experts_per_tok] topk_weights: Top-k expert weights [batch_size, num_experts_per_tok] cuda_stream: CUDA stream for synchronization Returns: Output tensor on GPU """ self.submit_forward(hidden_states, topk_ids, topk_weights, cuda_stream) return self.sync_forward(hidden_states, cuda_stream) @staticmethod def set_capture_batch_sizes(capture_bs: List[int]): """ Set batch sizes to capture and cache buffers for. This allows pre-allocation of CPU buffers for specific batch sizes, improving performance by avoiding buffer re-allocation during inference. Args: capture_bs: List of batch sizes to capture (e.g., [1, 2, 4, 8, 16]) Example: >>> BaseMoEWrapper.set_capture_batch_sizes([1, 2, 4, 8, 16]) """ KExpertsCPUBuffer.capture_bs = capture_bs @staticmethod def get_capture_batch_sizes() -> List[int]: """ Get currently configured capture batch sizes. Returns: List of batch sizes that are being captured """ return KExpertsCPUBuffer.capture_bs @staticmethod def clear_buffer_cache(): """ Clear all cached buffers. This frees up memory by clearing the buffer cache. Useful when you want to reset the buffer state or free memory. """ KExpertsCPUBuffer.capture_buffers.clear() KExpertsCPUBuffer.temp_bs = 0 KExpertsCPUBuffer.temp_buffer = tuple()