import glob
import os
from typing import Optional

import torch

# Use relative imports for package structure
from ..experts_base import BaseMoEWrapper
from .loader import SafeTensorLoader

from kt_kernel_ext.moe import MOEConfig

# The int8/int4 kernel backends are optional compile-time features of
# kt_kernel_ext; import them defensively and record their availability.
try:
    from kt_kernel_ext.moe import Int8_KERNEL_MOE

    _HAS_INT8_SUPPORT = True
except (ImportError, AttributeError):
    Int8_KERNEL_MOE = None
    _HAS_INT8_SUPPORT = False

try:
    from kt_kernel_ext.moe import Int4_KERNEL_MOE

    _HAS_INT4_SUPPORT = True
except (ImportError, AttributeError):
    Int4_KERNEL_MOE = None
    _HAS_INT4_SUPPORT = False


class GeneralMoEWrapper(BaseMoEWrapper):
    """
    General MoE wrapper backed by the kt_kernel_ext MoE kernels.

    Supports the MOE_INT4 and MOE_INT8 quantization methods.
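
    A minimal usage sketch (the parameter values below are hypothetical and
    "weights/" is a placeholder path, not a documented default):

        wrapper = GeneralMoEWrapper(
            layer_idx=0,
            num_experts=64,
            num_experts_per_tok=8,
            hidden_size=4096,
            moe_intermediate_size=1408,
            gpu_experts_mask=None,  # all experts on CPU
            cpuinfer_threads=32,
            threadpool_count=2,
            weight_path="weights/",
            chunked_prefill_size=8192,
            method="MOE_INT8",
        )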
"""
_safetensor_loader_instance = None # Singleton SafeTensorLoader
    def __init__(
        self,
        layer_idx: int,
        num_experts: int,
        num_experts_per_tok: int,
        hidden_size: int,
        moe_intermediate_size: int,
        gpu_experts_mask: Optional[torch.Tensor],
        cpuinfer_threads: int,
        threadpool_count: int,
        weight_path: str,
        chunked_prefill_size: int,
        cpu_save: bool = False,
        max_deferred_experts_per_token: Optional[int] = None,
        method: str = "MOE_INT8",
    ):
"""
Initialize general MoE Wrapper.
Args:
layer_idx: Layer index
num_experts: Total number of experts
num_experts_per_tok: Number of experts per token (top-k)
hidden_size: Hidden dimension size
moe_intermediate_size: MoE intermediate size
gpu_experts_mask: Boolean mask indicating which experts are on GPU.
Shape: [num_experts], dtype: torch.bool.
mask[i] = True means expert i is on GPU.
If None, all experts are on CPU.
cpuinfer_threads: Number of CPU inference threads
threadpool_count: Number of NUMA subpools
weight_path: Path to weights (SafeTensor format)
chunked_prefill_size: Maximum prefill chunk size
cpu_save: Whether to save weights to CPU memory
max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0.
method: general quantization method ("MOE_INT4" or "MOE_INT8")
"""
        if method == "MOE_INT4" and not _HAS_INT4_SUPPORT:
            raise RuntimeError(
                "MOE_INT4 backend not available: kt_kernel_ext was not compiled with int4 support.\n"
                "Please recompile with int4 support enabled."
            )
        if method == "MOE_INT8" and not _HAS_INT8_SUPPORT:
            raise RuntimeError(
                "MOE_INT8 backend not available: kt_kernel_ext was not compiled with int8 support.\n"
                "Please recompile with int8 support enabled."
            )
        # Initialize base class
        super().__init__(
            layer_idx=layer_idx,
            num_experts=num_experts,
            num_experts_per_tok=num_experts_per_tok,
            hidden_size=hidden_size,
            moe_intermediate_size=moe_intermediate_size,
            gpu_experts_mask=gpu_experts_mask,
            cpuinfer_threads=cpuinfer_threads,
            threadpool_count=threadpool_count,
            weight_path=weight_path,
            chunked_prefill_size=chunked_prefill_size,
            cpu_save=cpu_save,
            max_deferred_experts_per_token=max_deferred_experts_per_token,
            method=method,
        )
        # moe-specific: load merged safetensor weights if any exist at weight_path
        self.load_merged_weight = bool(glob.glob(os.path.join(weight_path, "*.safetensors")))

        # Initialize SafeTensor loader (singleton, shared across all layers)
        if self.load_merged_weight:
            if GeneralMoEWrapper._safetensor_loader_instance is None:
                GeneralMoEWrapper._safetensor_loader_instance = SafeTensorLoader(weight_path)
            self.safetensor_loader = GeneralMoEWrapper._safetensor_loader_instance

        # moe-specific weight storage (populated by load_weights when merged weights are used)
        self.gate_weights = None
        self.up_weights = None
        self.down_weights = None
        self.gate_scales = None
        self.up_scales = None
        self.down_scales = None

    def load_weights_from_tensors(
        self,
        gate_proj: torch.Tensor,
        up_proj: torch.Tensor,
        down_proj: torch.Tensor,
        physical_to_logical_map_cpu: torch.Tensor,
    ):
        """
        Load and quantize weights from BF16/FP16 tensors (online quantization).

        Args:
            gate_proj: Gate projection weights [num_experts, intermediate_size, hidden_size]
            up_proj: Up projection weights [num_experts, intermediate_size, hidden_size]
            down_proj: Down projection weights [num_experts, hidden_size, intermediate_size]
            physical_to_logical_map_cpu: Mapping from physical to logical expert IDs
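
        A minimal sketch of a call, under the assumption that E, I, H match
        (num_experts, moe_intermediate_size, hidden_size); the int64 dtype of
        the mapping is an assumption, not a documented requirement:

            E, I, H = 64, 1408, 4096  # hypothetical sizes
            gate = torch.randn(E, I, H, dtype=torch.bfloat16)
            up = torch.randn(E, I, H, dtype=torch.bfloat16)
            down = torch.randn(E, H, I, dtype=torch.bfloat16)
            p2l = torch.arange(E, dtype=torch.int64)  # identity mapping
            wrapper.load_weights_from_tensors(gate, up, down, p2l)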
"""
# Store tensors as instance variables to keep them alive
self.gate_proj = gate_proj.contiguous()
self.up_proj = up_proj.contiguous()
self.down_proj = down_proj.contiguous()
# Configure MoE with online quantization (cpu_save mode)
moe_config = MOEConfig(
self.num_experts,
self.num_experts_per_tok,
self.hidden_size,
self.moe_intermediate_size,
self.gpu_experts_mask.data_ptr(),
)
moe_config.layer_idx = self.layer_idx
moe_config.pool = self.cpu_infer.backend_
moe_config.max_len = self.chunked_prefill_size
# Enable save mode for online quantization
moe_config.save = True
moe_config.load = False
# Set weight pointers
moe_config.gate_proj = self.gate_proj.data_ptr()
moe_config.up_proj = self.up_proj.data_ptr()
moe_config.down_proj = self.down_proj.data_ptr()
# Set output path for quantized weights
moe_config.path = self.weight_path
# Create MoE module based on method
if self.method == "MOE_INT4":
self.moe = Int4_KERNEL_MOE(moe_config)
elif self.method == "MOE_INT8":
self.moe = Int8_KERNEL_MOE(moe_config)
else:
raise NotImplementedError(f"Unsupported MoE method: {self.method}")
# Submit quantization and save task
self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
self.cpu_infer.sync()

    def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
        """
        Load weights for this layer and initialize the MoE module.

        Args:
            physical_to_logical_map_cpu: Mapping from physical to logical expert IDs
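
        A minimal sketch of a call with an identity mapping (the int64 dtype
        is an assumption, not a documented requirement):

            p2l = torch.arange(wrapper.num_experts, dtype=torch.int64)
            wrapper.load_weights(p2l)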
"""
gate_ptr = 0
up_ptr = 0
down_ptr = 0
gate_ptrs = []
up_ptrs = []
down_ptrs = []
gate_scale_ptrs = []
up_scale_ptrs = []
down_scale_ptrs = []
if self.load_merged_weight:
base_key = f"blk.{self.layer_idx}"
w = self.safetensor_loader.load_experts(base_key)
self.gate_weights = w["gate"]
self.up_weights = w["up"]
self.down_weights = w["down"]
self.gate_scales = w["gate_scale"]
self.up_scales = w["up_scale"]
self.down_scales = w["down_scale"]
            # Get raw data pointers for each per-NUMA weight/scale array; the
            # arrays are nested as [numa_node][expert_tensor], and
            # et.ctypes.data is already the integer address of the buffer
            def _numa_ptrs(arrays):
                return [[et.ctypes.data for et in numa_array] for numa_array in arrays]

            gate_ptrs = _numa_ptrs(self.gate_weights)
            up_ptrs = _numa_ptrs(self.up_weights)
            down_ptrs = _numa_ptrs(self.down_weights)
            gate_scale_ptrs = _numa_ptrs(self.gate_scales)
            up_scale_ptrs = _numa_ptrs(self.up_scales)
            down_scale_ptrs = _numa_ptrs(self.down_scales)
        # Configure MoE
        moe_config = MOEConfig(
            self.num_experts,
            self.num_experts_per_tok,
            self.hidden_size,
            self.moe_intermediate_size,
            self.gpu_experts_mask.data_ptr(),
        )
        moe_config.layer_idx = self.layer_idx
        moe_config.pool = self.cpu_infer.backend_
        moe_config.max_len = self.chunked_prefill_size
        moe_config.gate_proj = gate_ptr
        moe_config.up_proj = up_ptr
        moe_config.down_proj = down_ptr
        moe_config.gate_projs = gate_ptrs
        moe_config.up_projs = up_ptrs
        moe_config.down_projs = down_ptrs
        moe_config.gate_scales = gate_scale_ptrs
        moe_config.up_scales = up_scale_ptrs
        moe_config.down_scales = down_scale_ptrs

        if self.cpu_save:
            # Online quantization: read BF16/FP16 expert weights from the
            # original safetensors and save quantized weights
            moe_config.save = True
            moe_config.load = False
            base_key = f"model.layers.{self.layer_idx}"
            w = self.safetensor_loader.load_experts(base_key)
            self.gate_proj = torch.cat(w["gate_weight"], dim=0).contiguous()
            self.up_proj = torch.cat(w["up_weight"], dim=0).contiguous()
            self.down_proj = torch.cat(w["down_weight"], dim=0).contiguous()
            moe_config.gate_proj = self.gate_proj.data_ptr()
            moe_config.up_proj = self.up_proj.data_ptr()
            moe_config.down_proj = self.down_proj.data_ptr()
        else:
            moe_config.load = True

        if not self.load_merged_weight:
            moe_config.path = self.weight_path
        # Create MoE module based on the configured method
        if self.method == "MOE_INT4":
            self.moe = Int4_KERNEL_MOE(moe_config)
        elif self.method == "MOE_INT8":
            self.moe = Int8_KERNEL_MOE(moe_config)
        else:
            raise NotImplementedError(f"Unsupported MoE method: {self.method}")

        # Load weights
        self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
        self.cpu_infer.sync()

        # Clean up temporary weight storage if using merged weights
        if self.load_merged_weight:
            del self.gate_weights
            del self.up_weights
            del self.down_weights
            del self.gate_scales
            del self.up_scales
            del self.down_scales