# Wrapper for AMX MoE CPU inference operations # This module encapsulates CPU inference engine, weight loading, and buffer management # SPDX-License-Identifier: Apache-2.0 """ Expert wrappers for CPU-based MoE inference. This module provides high-level Python wrappers around the low-level C++ kernel implementations, handling weight loading, buffer management, and forward inference. """ from __future__ import annotations import torch from typing import Dict, List, Optional, Tuple from safetensors import safe_open import os import ctypes # Import the C++ extension module (compiled as kt_kernel_ext) import kt_kernel_ext from kt_kernel_ext.moe import MOEConfig, AMXInt4_MOE, AMXInt8_MOE class SafeTensorLoader: tensor_file_map: dict tensor_type_map: dict file_handle_map: dict tensor_device_map: dict def __init__(self, file_path: str): self.__load_tensor_file_map(file_path) def __load_tensor_file_map(self, file_path: str): if not os.path.exists(file_path): raise FileNotFoundError(f"Path not found: {file_path}") if os.path.isfile(file_path): folder_path = os.path.dirname(file_path) else: folder_path = file_path self.file_handle_map = {} self.tensor_file_map = {} self.tensor_type_map = {} self.tensor_device_map = {} found_safetensor = False for root, _, files in os.walk(folder_path): files = sorted(files) for file in files: if file.endswith(".safetensors"): found_safetensor = True file_path = os.path.join(root, file) if file not in self.file_handle_map: try: handle = safe_open(file_path, framework="pt") self.file_handle_map[file] = handle except Exception as e: print(f"Error opening Safetensor file {file_path}: {e}") continue f = self.file_handle_map.get(file) if f is None: continue try: for key in f.keys(): self.tensor_file_map[key] = file except Exception as e: print(f"Error reading Safetensor file {file_path}: {e}") if not found_safetensor: raise FileNotFoundError(f"No Safetensor files found in {folder_path}") def load_tensor(self, key: str, device: str = "cpu"): if key not in self.tensor_file_map: raise KeyError(f"Key {key} not found in Safetensor files") file = self.tensor_file_map[key] f = self.file_handle_map.get(file) if f is None: raise FileNotFoundError(f"File {file} not found in Safetensor files") tensor = f.get_tensor(key) return tensor.to(device) def close_all_handles(self): for handle in self.file_handle_map.values(): handle.close() self.file_handle_map.clear() def load_experts(self, base_key: str, device: str = "cpu"): # base_key: blk.{layer_index} # blk.{layer_index}.ffn_[up, down, gate]_exps.{expert_id}.numa.{numa_id}.weight up_base_key = f"{base_key}.ffn_up_exps" gate_base_key = f"{base_key}.ffn_gate_exps" down_base_key = f"{base_key}.ffn_down_exps" max_numa_id = -1 max_experts_count = -1 while self.has_tensor(f"{up_base_key}.{max_experts_count+1}.numa.{0}.weight"): max_experts_count += 1 if max_experts_count == 0: raise ValueError(f"No experts found for key {base_key}") while self.has_tensor(f"{up_base_key}.{0}.numa.{max_numa_id+1}.weight"): max_numa_id += 1 # Initialize empty lists to store tensors for each projection type up_weights = [[] for _ in range(max_numa_id + 1)] gate_weights = [[] for _ in range(max_numa_id + 1)] down_weights = [[] for _ in range(max_numa_id + 1)] up_scales = [[] for _ in range(max_numa_id + 1)] gate_scales = [[] for _ in range(max_numa_id + 1)] down_scales = [[] for _ in range(max_numa_id + 1)] for numa_id in range(max_numa_id + 1): for expert_id in range(max_experts_count + 1): up_key = f"{up_base_key}.{expert_id}.numa.{numa_id}.weight" gate_key = f"{gate_base_key}.{expert_id}.numa.{numa_id}.weight" down_key = f"{down_base_key}.{expert_id}.numa.{numa_id}.weight" up_scale_key = f"{up_base_key}.{expert_id}.numa.{numa_id}.scale" gate_scale_key = f"{gate_base_key}.{expert_id}.numa.{numa_id}.scale" down_scale_key = f"{down_base_key}.{expert_id}.numa.{numa_id}.scale" # make sure contiguous up_tensor = self.load_tensor(up_key, device).numpy() gate_tensor = self.load_tensor(gate_key, device).numpy() down_tensor = self.load_tensor(down_key, device).numpy() up_scale_tensor = self.load_tensor(up_scale_key, device).numpy() gate_scale_tensor = self.load_tensor(gate_scale_key, device).numpy() down_scale_tensor = self.load_tensor(down_scale_key, device).numpy() up_weights[numa_id].append(up_tensor) gate_weights[numa_id].append(gate_tensor) down_weights[numa_id].append(down_tensor) up_scales[numa_id].append(up_scale_tensor) gate_scales[numa_id].append(gate_scale_tensor) down_scales[numa_id].append(down_scale_tensor) return { "up": up_weights, "gate": gate_weights, "down": down_weights, "up_scale": up_scales, "gate_scale": gate_scales, "down_scale": down_scales, } def has_tensor(self, name: str): return name in self.tensor_file_map class KExpertsCPUBuffer: capture_bs: List = list() capture_buffers: Dict = dict() temp_bs: int = 0 temp_buffer: tuple = tuple() buffer_depth: int = 2 @classmethod def get_buffer(cls, hidden_states: torch.Tensor, num_experts_per_tok): hidden_size = hidden_states.shape[-1] batch_size = hidden_states.shape[0] if batch_size in cls.capture_buffers: return cls.capture_buffers[batch_size] if batch_size == cls.temp_bs: return cls.temp_buffer input_tensor_cpu = [ torch.zeros((batch_size, hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16) for _ in range(cls.buffer_depth) ] immediate_experts_ids_cpu = [ torch.zeros((batch_size, num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True) for _ in range(cls.buffer_depth) ] deferred_experts_ids_cpu = [ torch.full((batch_size, num_experts_per_tok), -1, device="cpu", dtype=torch.long, pin_memory=True) for _ in range(cls.buffer_depth) ] weights_cpu = [ torch.zeros((batch_size, num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True) for _ in range(cls.buffer_depth) ] output_cpu = [ torch.zeros((batch_size, hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16) for _ in range(cls.buffer_depth) ] bsz_tensor_cpu = [ torch.zeros((1,), device="cpu", dtype=torch.int32, pin_memory=True) for _ in range(cls.buffer_depth) ] output_gpu = [ torch.zeros((batch_size, hidden_size), device=hidden_states.device, dtype=hidden_states.dtype) for _ in range(cls.buffer_depth) ] cur_buffer = ( input_tensor_cpu, immediate_experts_ids_cpu, deferred_experts_ids_cpu, weights_cpu, output_cpu, bsz_tensor_cpu, output_gpu, ) if batch_size in cls.capture_bs: cls.capture_buffers[batch_size] = cur_buffer cls.temp_bs = batch_size cls.temp_buffer = cur_buffer return cur_buffer class AMXMoEWrapper: """ Wrapper for AMX MoE CPU inference operations. Manages CPU inference engine, weight loading, and buffer management. """ _cpu_infer_instance = None _safetensor_loader_instance = None _layer_has_pending_deferred: Dict[int, bool] = {} def __init__( self, layer_idx: int, num_experts: int, num_experts_per_tok: int, hidden_size: int, moe_intermediate_size: int, num_gpu_experts: int, cpuinfer_threads: int, threadpool_count: int, amx_weight_path: str, chunked_prefill_size: int, cpu_save: bool = False, max_deferred_experts_per_token: Optional[int] = None, amx_method: str = "AMXINT4", ): """ Initialize AMX MoE Wrapper. Args: layer_idx: Layer index num_experts: Total number of experts num_experts_per_tok: Number of experts per token (top-k) hidden_size: Hidden dimension size moe_intermediate_size: MoE intermediate size num_gpu_experts: Number of experts to run on GPU cpuinfer_threads: Number of CPU inference threads threadpool_count: Number of NUMA subpools amx_weight_path: Path to AMX weights chunked_prefill_size: Maximum prefill chunk size cpu_save: Whether to save weights to CPU memory max_deferred_experts_per_token: Number of experts per token to defer on this layer. Defaults to 0 (no defer). amx_method: AMX quantization method ("AMXINT4" or "AMXINT8") """ self.layer_idx = layer_idx self.num_experts = num_experts self.num_experts_per_tok = num_experts_per_tok self.hidden_size = hidden_size self.moe_intermediate_size = moe_intermediate_size self.num_gpu_experts = num_gpu_experts self.amx_weight_path = amx_weight_path self.chunked_prefill_size = chunked_prefill_size self.cpu_save = cpu_save self.max_deferred_experts_per_token = int(max_deferred_experts_per_token) if max_deferred_experts_per_token is not None else 0 AMXMoEWrapper._layer_has_pending_deferred[self.layer_idx] = False self.amx_method = amx_method # Initialize CPU inference engine (singleton) if AMXMoEWrapper._cpu_infer_instance is None: worker_config = kt_kernel_ext.WorkerPoolConfig() subpool_numa_map = list(range(threadpool_count)) subpool_thread_count = [ cpuinfer_threads // threadpool_count + (1 if i < cpuinfer_threads % threadpool_count else 0) for i in range(threadpool_count) ] worker_config.subpool_count = threadpool_count worker_config.subpool_numa_map = subpool_numa_map worker_config.subpool_thread_count = subpool_thread_count AMXMoEWrapper._cpu_infer_instance = kt_kernel_ext.CPUInfer(worker_config) self.cpu_infer = AMXMoEWrapper._cpu_infer_instance # Check if we should load merged safetensor weights self.load_merged_weight = False import glob if glob.glob(os.path.join(amx_weight_path, "*.safetensors")): self.load_merged_weight = True # Initialize SafeTensor loader (singleton) if self.load_merged_weight: if AMXMoEWrapper._safetensor_loader_instance is None: AMXMoEWrapper._safetensor_loader_instance = SafeTensorLoader(amx_weight_path) self.safetensor_loader = AMXMoEWrapper._safetensor_loader_instance self.moe = None self.gate_weights = None self.up_weights = None self.down_weights = None self.gate_scales = None self.up_scales = None self.down_scales = None def load_weights_from_tensors( self, gate_proj: torch.Tensor, up_proj: torch.Tensor, down_proj: torch.Tensor, physical_to_logical_map_cpu: torch.Tensor, ): """ Load and quantize weights from BF16/FP16 tensors (online quantization). Args: gate_proj: Gate projection weights [num_experts, intermediate_size, hidden_size] up_proj: Up projection weights [num_experts, intermediate_size, hidden_size] down_proj: Down projection weights [num_experts, hidden_size, intermediate_size] physical_to_logical_map_cpu: Mapping from physical to logical expert IDs """ # Store tensors as instance variables to keep them alive self.gate_proj = gate_proj.contiguous() self.up_proj = up_proj.contiguous() self.down_proj = down_proj.contiguous() # Configure MoE with online quantization (cpu_save mode) moe_config = MOEConfig( self.num_experts, self.num_experts_per_tok, self.hidden_size, self.moe_intermediate_size, self.num_gpu_experts, ) moe_config.layer_idx = self.layer_idx moe_config.pool = self.cpu_infer.backend_ moe_config.max_len = self.chunked_prefill_size # Enable save mode for online quantization moe_config.save = True moe_config.load = False # Set weight pointers moe_config.gate_proj = self.gate_proj.data_ptr() moe_config.up_proj = self.up_proj.data_ptr() moe_config.down_proj = self.down_proj.data_ptr() # Set output path for quantized weights moe_config.path = self.amx_weight_path # Create MoE module based on AMX method if self.amx_method == "AMXINT4": self.moe = AMXInt4_MOE(moe_config) elif self.amx_method == "AMXINT8": self.moe = AMXInt8_MOE(moe_config) else: raise NotImplementedError(f"Unsupported AMX method: {self.amx_method}") # Submit quantization and save task self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr())) self.cpu_infer.sync() def load_weights(self, physical_to_logical_map_cpu: torch.Tensor): """ Load weights for this layer and initialize the MoE module. Args: physical_to_logical_map_cpu: Mapping from physical to logical expert IDs """ gate_ptr = 0 up_ptr = 0 down_ptr = 0 gate_ptrs = [] up_ptrs = [] down_ptrs = [] gate_scale_ptrs = [] up_scale_ptrs = [] down_scale_ptrs = [] if self.load_merged_weight: base_key = f"blk.{self.layer_idx}" w = self.safetensor_loader.load_experts(base_key) self.gate_weights = w["gate"] self.up_weights = w["up"] self.down_weights = w["down"] self.gate_scales = w["gate_scale"] self.up_scales = w["up_scale"] self.down_scales = w["down_scale"] # Get pointers to weight arrays gate_ptrs = [ [ ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents) for et in numa_array ] for numa_array in self.gate_weights ] up_ptrs = [ [ ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents) for et in numa_array ] for numa_array in self.up_weights ] down_ptrs = [ [ ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents) for et in numa_array ] for numa_array in self.down_weights ] gate_scale_ptrs = [ [ ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents) for et in numa_array ] for numa_array in self.gate_scales ] up_scale_ptrs = [ [ ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents) for et in numa_array ] for numa_array in self.up_scales ] down_scale_ptrs = [ [ ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents) for et in numa_array ] for numa_array in self.down_scales ] # Configure MoE moe_config = MOEConfig( self.num_experts, self.num_experts_per_tok, self.hidden_size, self.moe_intermediate_size, self.num_gpu_experts, ) moe_config.layer_idx = self.layer_idx moe_config.pool = self.cpu_infer.backend_ moe_config.max_len = self.chunked_prefill_size moe_config.gate_proj = gate_ptr moe_config.up_proj = up_ptr moe_config.down_proj = down_ptr moe_config.gate_projs = gate_ptrs moe_config.up_projs = up_ptrs moe_config.down_projs = down_ptrs moe_config.gate_scales = gate_scale_ptrs moe_config.up_scales = up_scale_ptrs moe_config.down_scales = down_scale_ptrs if self.cpu_save: moe_config.save = True moe_config.load = False base_key = f"model.layers.{self.layer_idx}" w = self.safetensor_loader.load_experts(base_key) self.gate_proj = torch.cat(w["gate_weight"], dim=0).contiguous() self.up_proj = torch.cat(w["up_weight"], dim=0).contiguous() self.down_proj = torch.cat(w["down_weight"], dim=0).contiguous() moe_config.gate_proj = self.gate_proj.data_ptr() moe_config.up_proj = self.up_proj.data_ptr() moe_config.down_proj = self.down_proj.data_ptr() else: moe_config.load = True if not self.load_merged_weight: moe_config.path = self.amx_weight_path # Create MoE module based on AMX method if self.amx_method == "AMXINT4": self.moe = AMXInt4_MOE(moe_config) elif self.amx_method == "AMXINT8": self.moe = AMXInt8_MOE(moe_config) else: raise NotImplementedError(f"Unsupported AMX method: {self.amx_method}") # Load weights self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr())) self.cpu_infer.sync() # Clean up temporary weight storage if using merged weights if self.load_merged_weight: del self.gate_weights del self.up_weights del self.down_weights del self.gate_scales del self.up_scales del self.down_scales def select_deferred_experts( self, expert_ids: torch.Tensor, expert_scores: torch.Tensor, protected_k: int, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: batch, topk = expert_ids.shape device = expert_ids.device protected_k = max(0, min(int(protected_k), topk)) if protected_k == 0: deferred_ids = expert_ids.clone() immediate_ids = torch.full_like(expert_ids, -1) return immediate_ids, deferred_ids topk_result = torch.topk(expert_scores, k=protected_k, dim=-1, largest=True, sorted=False) protected_indices = topk_result.indices protected_ids = torch.gather(expert_ids, -1, protected_indices) protected_flag = torch.zeros((self.num_experts,), dtype=torch.int32, device=device) protected_flag.scatter_(0, protected_ids.reshape(-1), 1) protected_mask_flat = torch.gather(protected_flag, 0, expert_ids.reshape(-1)).ne(0) protected_mask = protected_mask_flat.view(batch, topk) immediate_ids = expert_ids.clone().masked_fill(~protected_mask, -1) deferred_ids = expert_ids.clone().masked_fill(protected_mask, -1) return immediate_ids, deferred_ids def submit_forward( self, hidden_states: torch.Tensor, topk_ids: torch.Tensor, topk_weights: torch.Tensor, cuda_stream, ): """ Submit forward inference task to CPU (non-blocking). Args: hidden_states: Input hidden states [batch_size, hidden_size] topk_ids: Top-k expert IDs [batch_size, num_experts_per_tok] topk_weights: Top-k expert weights [batch_size, num_experts_per_tok] cuda_stream: CUDA stream for synchronization """ flat_hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) batch_size = flat_hidden_states.shape[0] ( input_tensor_cpu, immediate_experts_ids_cpu, deferred_experts_ids_cpu, weights_cpu, output_cpu, bsz_tensor_cpu, _output_gpu, ) = KExpertsCPUBuffer.get_buffer(flat_hidden_states, self.num_experts_per_tok) current_slot = self.layer_idx % KExpertsCPUBuffer.buffer_depth next_slot = (current_slot + 1) % KExpertsCPUBuffer.buffer_depth bsz_slot_tensor = bsz_tensor_cpu[current_slot] bsz_slot_tensor.fill_(batch_size) deferred_experts_ids_cpu[current_slot].fill_(-1) topk_ids_long = topk_ids.to(torch.long) immediate_ids: torch.Tensor deferred_ids: Optional[torch.Tensor] if self.max_deferred_experts_per_token > 0: protected_k = self.num_experts_per_tok - self.max_deferred_experts_per_token immediate_ids, deferred_ids = self.select_deferred_experts(topk_ids_long, topk_weights, protected_k) else: immediate_ids = topk_ids_long deferred_ids = None input_tensor_cpu[current_slot].copy_(flat_hidden_states, non_blocking=True) weights_cpu[current_slot].copy_(topk_weights, non_blocking=True) immediate_experts_ids_cpu[current_slot].copy_(immediate_ids, non_blocking=True) incremental = AMXMoEWrapper._layer_has_pending_deferred.get(self.layer_idx - 1, False) self.cpu_infer.submit_with_cuda_stream( cuda_stream, self.moe.forward_task( bsz_slot_tensor.data_ptr(), immediate_experts_ids_cpu[current_slot].size(-1), immediate_experts_ids_cpu[current_slot].data_ptr(), weights_cpu[current_slot].data_ptr(), input_tensor_cpu[current_slot].data_ptr(), output_cpu[current_slot].data_ptr(), incremental, ), ) AMXMoEWrapper._layer_has_pending_deferred[self.layer_idx] = False if deferred_ids is not None: deferred_experts_ids_cpu[current_slot].copy_(deferred_ids, non_blocking=True) self.cpu_infer.submit_with_cuda_stream( cuda_stream, self.moe.forward_task( bsz_slot_tensor.data_ptr(), deferred_experts_ids_cpu[current_slot].size(-1), deferred_experts_ids_cpu[current_slot].data_ptr(), weights_cpu[current_slot].data_ptr(), input_tensor_cpu[current_slot].data_ptr(), output_cpu[next_slot].data_ptr(), False, ), ) AMXMoEWrapper._layer_has_pending_deferred[self.layer_idx] = True def sync_forward(self, hidden_states: torch.Tensor, cuda_stream) -> torch.Tensor: """ Synchronize and retrieve forward inference results. Args: hidden_states: Original input hidden states (for getting buffer) cuda_stream: CUDA stream for synchronization Returns: output_gpu: Output tensor on GPU """ flat_hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) ( input_tensor_cpu, immediate_experts_ids_cpu, _deferred_experts_ids_cpu, weights_cpu, output_cpu, _bsz_tensor_cpu, output_gpu, ) = KExpertsCPUBuffer.get_buffer(flat_hidden_states, self.num_experts_per_tok) current_slot = self.layer_idx % KExpertsCPUBuffer.buffer_depth allow_pending = 1 if AMXMoEWrapper._layer_has_pending_deferred.get(self.layer_idx, False) else 0 self.cpu_infer.sync_with_cuda_stream(cuda_stream, allow_pending) output_gpu[current_slot].copy_(output_cpu[current_slot], non_blocking=True) return output_gpu[current_slot] def forward( self, hidden_states: torch.Tensor, topk_ids: torch.Tensor, topk_weights: torch.Tensor, cuda_stream, ) -> torch.Tensor: """ Execute forward inference synchronously (submit + sync). Args: hidden_states: Input hidden states [batch_size, hidden_size] topk_ids: Top-k expert IDs [batch_size, num_experts_per_tok] topk_weights: Top-k expert weights [batch_size, num_experts_per_tok] cuda_stream: CUDA stream for synchronization Returns: Output tensor on GPU """ self.submit_forward(hidden_states, topk_ids, topk_weights, cuda_stream) return self.sync_forward(hidden_states, cuda_stream) @staticmethod def set_capture_batch_sizes(capture_bs: List[int]): """ Set batch sizes to capture and cache buffers for. This allows pre-allocation of CPU buffers for specific batch sizes, improving performance by avoiding buffer re-allocation during inference. Args: capture_bs: List of batch sizes to capture (e.g., [1, 2, 4, 8, 16]) Example: >>> AMXMoEWrapper.set_capture_batch_sizes([1, 2, 4, 8, 16]) """ KExpertsCPUBuffer.capture_bs = capture_bs @staticmethod def get_capture_batch_sizes() -> List[int]: """ Get currently configured capture batch sizes. Returns: List of batch sizes that are being captured """ return KExpertsCPUBuffer.capture_bs @staticmethod def clear_buffer_cache(): """ Clear all cached buffers. This frees up memory by clearing the buffer cache. Useful when you want to reset the buffer state or free memory. """ KExpertsCPUBuffer.capture_buffers.clear() KExpertsCPUBuffer.temp_bs = 0 KExpertsCPUBuffer.temp_buffer = tuple()