Files
ktransformers/kt-kernel/python/utils/amx.py
Jiaqi Liao 9bc00e587b Refactor KTMoEWrapper backend (#1587)
* universal backend for cpu inference
* expert defer
2025-11-10 20:26:15 +08:00

302 lines
10 KiB
Python

import os
import torch
import ctypes
# Use relative imports for package structure
from ..experts_base import BaseMoEWrapper
from .loader import SafeTensorLoader
from kt_kernel_ext.moe import MOEConfig
# Probe for the AMX kernel classes. The import is attempted once at module
# load; the outcome is recorded in _HAS_AMX_SUPPORT so the module itself stays
# importable when the extension does not provide them.
try:
    from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE
    _HAS_AMX_SUPPORT = True
except (ImportError, AttributeError):
    # Fall back to placeholders; AMXMoEWrapper.__init__ checks the flag and
    # raises a RuntimeError before either name would be called.
    _HAS_AMX_SUPPORT = False
    AMXInt4_MOE, AMXInt8_MOE = None, None
from typing import Optional
class AMXMoEWrapper(BaseMoEWrapper):
    """
    AMX-based MoE wrapper implementation.

    Supports the AMXINT4 and AMXINT8 quantization methods. Weights can either
    be loaded from disk (``load_weights``) or quantized online from
    floating-point tensors (``load_weights_from_tensors``).
    """

    # Class-level SafeTensorLoader shared by all instances. It is created by
    # the first instance whose weight_path contains *.safetensors files.
    # NOTE(review): later instances reuse it even if they pass a different
    # weight_path -- confirm that a single weight directory is always used.
    _safetensor_loader_instance = None

    # (instance attribute, key in the dict returned by load_experts) for the
    # temporary per-layer buffers used when loading merged safetensor weights.
    _MERGED_BUFFERS = (
        ("gate_weights", "gate"),
        ("up_weights", "up"),
        ("down_weights", "down"),
        ("gate_scales", "gate_scale"),
        ("up_scales", "up_scale"),
        ("down_scales", "down_scale"),
    )

    def __init__(
        self,
        layer_idx: int,
        num_experts: int,
        num_experts_per_tok: int,
        hidden_size: int,
        moe_intermediate_size: int,
        num_gpu_experts: int,
        cpuinfer_threads: int,
        threadpool_count: int,
        weight_path: str,
        chunked_prefill_size: int,
        cpu_save: bool = False,
        max_deferred_experts_per_token: Optional[int] = None,
        method: str = "AMXINT4",
    ):
        """
        Initialize AMX MoE Wrapper.

        Args:
            layer_idx: Layer index
            num_experts: Total number of experts
            num_experts_per_tok: Number of experts per token (top-k)
            hidden_size: Hidden dimension size
            moe_intermediate_size: MoE intermediate size
            num_gpu_experts: Number of experts to run on GPU
            cpuinfer_threads: Number of CPU inference threads
            threadpool_count: Number of NUMA subpools
            weight_path: Path to AMX weights (SafeTensor format)
            chunked_prefill_size: Maximum prefill chunk size
            cpu_save: Whether to save weights to CPU memory
            max_deferred_experts_per_token: Number of experts per token to defer.
                Defaults to None; the value is forwarded unchanged to the base class.
            method: AMX quantization method ("AMXINT4" or "AMXINT8")

        Raises:
            RuntimeError: If the AMX kernel classes could not be imported.
        """
        if not _HAS_AMX_SUPPORT:
            raise RuntimeError(
                "AMX backend not available. kt_kernel_ext was not compiled with AMX support.\n"
                "Please recompile with AMX enabled."
            )

        # Initialize base class
        super().__init__(
            layer_idx=layer_idx,
            num_experts=num_experts,
            num_experts_per_tok=num_experts_per_tok,
            hidden_size=hidden_size,
            moe_intermediate_size=moe_intermediate_size,
            num_gpu_experts=num_gpu_experts,
            cpuinfer_threads=cpuinfer_threads,
            threadpool_count=threadpool_count,
            weight_path=weight_path,
            chunked_prefill_size=chunked_prefill_size,
            cpu_save=cpu_save,
            max_deferred_experts_per_token=max_deferred_experts_per_token,
            method=method,
        )

        # AMX-specific: merged safetensor weights are used whenever the weight
        # directory contains at least one *.safetensors file.
        import glob

        self.load_merged_weight = bool(glob.glob(os.path.join(weight_path, "*.safetensors")))

        # Initialize SafeTensor loader (singleton). The attribute is only
        # assigned when merged weights were found.
        if self.load_merged_weight:
            if AMXMoEWrapper._safetensor_loader_instance is None:
                AMXMoEWrapper._safetensor_loader_instance = SafeTensorLoader(weight_path)
            self.safetensor_loader = AMXMoEWrapper._safetensor_loader_instance

        # AMX-specific weight storage (filled temporarily by load_weights)
        self.gate_weights = None
        self.up_weights = None
        self.down_weights = None
        self.gate_scales = None
        self.up_scales = None
        self.down_scales = None

    @staticmethod
    def _collect_numa_ptrs(nested_arrays):
        """
        Return the data address of every array in a two-level nested sequence.

        The nesting of the input is preserved in the result. Each element must
        expose ``.ctypes.data`` (as NumPy arrays do).
        """
        return [
            [
                ctypes.addressof(ctypes.cast(arr.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
                for arr in numa_array
            ]
            for numa_array in nested_arrays
        ]

    def _new_amx_moe_config(self):
        """Create a MOEConfig carrying the settings shared by both load paths."""
        moe_config = MOEConfig(
            self.num_experts,
            self.num_experts_per_tok,
            self.hidden_size,
            self.moe_intermediate_size,
            self.num_gpu_experts,
        )
        moe_config.layer_idx = self.layer_idx
        moe_config.pool = self.cpu_infer.backend_
        moe_config.max_len = self.chunked_prefill_size
        return moe_config

    def _build_amx_moe(self, moe_config):
        """Instantiate the kernel module matching ``self.method``."""
        if self.method == "AMXINT4":
            return AMXInt4_MOE(moe_config)
        if self.method == "AMXINT8":
            return AMXInt8_MOE(moe_config)
        raise NotImplementedError(f"Unsupported AMX method: {self.method}")

    def _submit_weight_load(self, physical_to_logical_map_cpu: torch.Tensor):
        """Submit the module's weight-loading task and wait for it to finish."""
        self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
        self.cpu_infer.sync()

    def load_weights_from_tensors(
        self,
        gate_proj: torch.Tensor,
        up_proj: torch.Tensor,
        down_proj: torch.Tensor,
        physical_to_logical_map_cpu: torch.Tensor,
    ):
        """
        Load and quantize weights from BF16/FP16 tensors (online quantization).

        Args:
            gate_proj: Gate projection weights [num_experts, intermediate_size, hidden_size]
            up_proj: Up projection weights [num_experts, intermediate_size, hidden_size]
            down_proj: Down projection weights [num_experts, hidden_size, intermediate_size]
            physical_to_logical_map_cpu: Mapping from physical to logical expert IDs

        Raises:
            NotImplementedError: If ``self.method`` is not "AMXINT4" or "AMXINT8".
        """
        # Store contiguous tensors on the instance so the memory behind the raw
        # pointers handed to the config stays alive.
        self.gate_proj = gate_proj.contiguous()
        self.up_proj = up_proj.contiguous()
        self.down_proj = down_proj.contiguous()

        moe_config = self._new_amx_moe_config()

        # Enable save mode for online quantization
        moe_config.save = True
        moe_config.load = False

        # Set weight pointers
        moe_config.gate_proj = self.gate_proj.data_ptr()
        moe_config.up_proj = self.up_proj.data_ptr()
        moe_config.down_proj = self.down_proj.data_ptr()

        # Set output path for quantized weights
        moe_config.path = self.weight_path

        self.moe = self._build_amx_moe(moe_config)

        # Submit quantization and save task, then block until it completes
        self._submit_weight_load(physical_to_logical_map_cpu)

    def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
        """
        Load weights for this layer and initialize the MoE module.

        Args:
            physical_to_logical_map_cpu: Mapping from physical to logical expert IDs

        Raises:
            RuntimeError: If ``cpu_save`` is set but no SafeTensor loader was
                initialized for this instance.
            NotImplementedError: If ``self.method`` is not "AMXINT4" or "AMXINT8".
        """
        # Pointer lists stay empty unless merged safetensor weights are used.
        ptrs = {attr: [] for attr, _ in self._MERGED_BUFFERS}

        if self.load_merged_weight:
            w = self.safetensor_loader.load_experts(f"blk.{self.layer_idx}")

            # Keep the loaded arrays on the instance while their addresses are in use.
            for attr, key in self._MERGED_BUFFERS:
                setattr(self, attr, w[key])

            # Get pointers to weight arrays
            for attr, _ in self._MERGED_BUFFERS:
                ptrs[attr] = self._collect_numa_ptrs(getattr(self, attr))

        # Configure MoE
        moe_config = self._new_amx_moe_config()
        moe_config.gate_proj = 0
        moe_config.up_proj = 0
        moe_config.down_proj = 0
        moe_config.gate_projs = ptrs["gate_weights"]
        moe_config.up_projs = ptrs["up_weights"]
        moe_config.down_projs = ptrs["down_weights"]
        moe_config.gate_scales = ptrs["gate_scales"]
        moe_config.up_scales = ptrs["up_scales"]
        moe_config.down_scales = ptrs["down_scales"]

        if self.cpu_save:
            moe_config.save = True
            moe_config.load = False

            # The loader only exists when *.safetensors files were found at
            # construction time; fail with an explicit message otherwise.
            loader = getattr(self, "safetensor_loader", None)
            if loader is None:
                raise RuntimeError(
                    "cpu_save=True requires SafeTensor weights, but no SafeTensor loader was "
                    f"initialized (no *.safetensors files found in {self.weight_path!r})."
                )

            w = loader.load_experts(f"model.layers.{self.layer_idx}")
            self.gate_proj = torch.cat(w["gate_weight"], dim=0).contiguous()
            self.up_proj = torch.cat(w["up_weight"], dim=0).contiguous()
            self.down_proj = torch.cat(w["down_weight"], dim=0).contiguous()

            moe_config.gate_proj = self.gate_proj.data_ptr()
            moe_config.up_proj = self.up_proj.data_ptr()
            moe_config.down_proj = self.down_proj.data_ptr()
        else:
            moe_config.load = True

        if not self.load_merged_weight:
            moe_config.path = self.weight_path

        # Create MoE module based on AMX method
        self.moe = self._build_amx_moe(moe_config)

        # Load weights and wait for completion
        self._submit_weight_load(physical_to_logical_map_cpu)

        # Clean up temporary weight storage if using merged weights
        if self.load_merged_weight:
            for attr, _ in self._MERGED_BUFFERS:
                delattr(self, attr)