mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-19 22:09:10 +00:00
[feat](moe_kernel): add amd blis support (int8) (#1600)
* [feat]: init amd adaption * [feat]: add blis support * [fix]: fix setup and moe kernel wrapper * [fix](setup.py): support rebuild with cache and import kt_kernel works fine * [feat]: add moe_kernel converter for amd and implement the load method (not yet tested) * [feat](moe_kernel/moe.hpp): delete unused memory when using save * [fix](moe_kernel): update PLAIN for pack * [fix](moe_kernel): rm printf debug * [fix](moe_kernel): skip gpu experts * [fix](moe_kernel/moe.hpp): update include memory path * [feat](moe_kernel/moe.hpp): support expert deferral * [feat]: finish amd --------- Co-authored-by: mrhaoxx <mr.haoxx@gmail.com>
This commit is contained in:
@@ -19,6 +19,7 @@ from .experts_base import BaseMoEWrapper, KExpertsCPUBuffer
|
||||
# Import backend implementations
|
||||
from .utils.amx import AMXMoEWrapper
|
||||
from .utils.llamafile import LlamafileMoEWrapper
|
||||
from .utils.moe_kernel import GeneralMoEWrapper
|
||||
|
||||
|
||||
class KTMoEWrapper:
|
||||
@@ -76,7 +77,7 @@ class KTMoEWrapper:
|
||||
chunked_prefill_size: Maximum prefill chunk size
|
||||
cpu_save: Whether to save weights to CPU memory
|
||||
max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0.
|
||||
method: Backend method ("AMXINT4", "AMXINT8", "LLAMAFILE")
|
||||
method: Backend method ("AMXINT4", "AMXINT8", "LLAMAFILE", "MOE_INT4", "MOE_INT8")
|
||||
|
||||
Returns:
|
||||
An instance of the appropriate backend implementation (e.g., AMXMoEWrapper)
|
||||
@@ -86,6 +87,8 @@ class KTMoEWrapper:
|
||||
backend_cls = AMXMoEWrapper
|
||||
elif method == "LLAMAFILE":
|
||||
backend_cls = LlamafileMoEWrapper
|
||||
elif method in ["MOE_INT4", "MOE_INT8"]:
|
||||
backend_cls = GeneralMoEWrapper
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported method: {method}")
|
||||
|
||||
|
||||
315
kt-kernel/python/utils/moe_kernel.py
Normal file
315
kt-kernel/python/utils/moe_kernel.py
Normal file
@@ -0,0 +1,315 @@
|
||||
import os
|
||||
import torch
|
||||
import ctypes
|
||||
|
||||
# Use relative imports for package structure
|
||||
from ..experts_base import BaseMoEWrapper
|
||||
from .loader import SafeTensorLoader
|
||||
from kt_kernel_ext.moe import MOEConfig
|
||||
|
||||
try:
|
||||
from kt_kernel_ext.moe import Int8_KERNEL_MOE
|
||||
|
||||
_HAS_INT8_SUPPORT = True
|
||||
except (ImportError, AttributeError):
|
||||
Int8_KERNEL_MOE = None
|
||||
_HAS_INT8_SUPPORT = False
|
||||
try:
|
||||
from kt_kernel_ext.moe import Int4_KERNEL_MOE
|
||||
|
||||
_HAS_INT4_SUPPORT = True
|
||||
except (ImportError, AttributeError):
|
||||
Int4_KERNEL_MOE = None
|
||||
_HAS_INT4_SUPPORT = False
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class GeneralMoEWrapper(BaseMoEWrapper):
    """
    moe-based MoE wrapper implementation.
    Supports MOE_INT4 and MOE_INT8 quantization methods.
    """

    # Singleton SafeTensorLoader shared across all layer instances
    _safetensor_loader_instance = None

    def __init__(
        self,
        layer_idx: int,
        num_experts: int,
        num_experts_per_tok: int,
        hidden_size: int,
        moe_intermediate_size: int,
        num_gpu_experts: int,
        cpuinfer_threads: int,
        threadpool_count: int,
        weight_path: str,
        chunked_prefill_size: int,
        cpu_save: bool = False,
        max_deferred_experts_per_token: Optional[int] = None,
        method: str = "MOE_INT8",
    ):
        """
        Initialize general MoE Wrapper.

        Args:
            layer_idx: Layer index
            num_experts: Total number of experts
            num_experts_per_tok: Number of experts per token (top-k)
            hidden_size: Hidden dimension size
            moe_intermediate_size: MoE intermediate size
            num_gpu_experts: Number of experts to run on GPU
            cpuinfer_threads: Number of CPU inference threads
            threadpool_count: Number of NUMA subpools
            weight_path: Path to weights (SafeTensor format)
            chunked_prefill_size: Maximum prefill chunk size
            cpu_save: Whether to save weights to CPU memory
            max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0.
            method: general quantization method ("MOE_INT4" or "MOE_INT8")

        Raises:
            RuntimeError: If the requested kernel backend was not compiled
                into kt_kernel_ext.
        """
        # Fail fast if the requested kernel was not compiled in.
        if not _HAS_INT4_SUPPORT and method == "MOE_INT4":
            raise RuntimeError(
                "MoE_INT4 backend not available. kt_kernel_ext was not compiled with int4 support.\n"
                "Please recompile with int4 enabled."
            )
        if not _HAS_INT8_SUPPORT and method == "MOE_INT8":
            raise RuntimeError(
                "MoE_INT8 backend not available. kt_kernel_ext was not compiled with int8 support.\n"
                "Please recompile with int8 enabled."
            )

        # Initialize base class
        super().__init__(
            layer_idx=layer_idx,
            num_experts=num_experts,
            num_experts_per_tok=num_experts_per_tok,
            hidden_size=hidden_size,
            moe_intermediate_size=moe_intermediate_size,
            num_gpu_experts=num_gpu_experts,
            cpuinfer_threads=cpuinfer_threads,
            threadpool_count=threadpool_count,
            weight_path=weight_path,
            chunked_prefill_size=chunked_prefill_size,
            cpu_save=cpu_save,
            max_deferred_experts_per_token=max_deferred_experts_per_token,
            method=method,
        )

        # moe-specific: merged safetensor weights are detected by the presence
        # of *.safetensors files directly under weight_path.
        import glob

        self.load_merged_weight = bool(glob.glob(os.path.join(weight_path, "*.safetensors")))

        # Initialize SafeTensor loader (singleton, created on first use)
        if self.load_merged_weight:
            if GeneralMoEWrapper._safetensor_loader_instance is None:
                GeneralMoEWrapper._safetensor_loader_instance = SafeTensorLoader(weight_path)
            self.safetensor_loader = GeneralMoEWrapper._safetensor_loader_instance

        # moe-specific weight storage; populated by load_weights and reset to
        # None again once the kernel has consumed the buffers.
        self.gate_weights = None
        self.up_weights = None
        self.down_weights = None
        self.gate_scales = None
        self.up_scales = None
        self.down_scales = None

    @staticmethod
    def _numa_ptrs(nested):
        """Return raw data pointers for a [numa][expert] nested array structure.

        Each leaf is expected to expose ``.ctypes.data`` (numpy-style), which
        is already the integer address of the underlying buffer — no ctypes
        cast round-trip is needed.
        """
        return [[arr.ctypes.data for arr in numa_group] for numa_group in nested]

    def _build_moe_config(self):
        """Create an MOEConfig pre-filled with the per-layer common fields."""
        moe_config = MOEConfig(
            self.num_experts,
            self.num_experts_per_tok,
            self.hidden_size,
            self.moe_intermediate_size,
            self.num_gpu_experts,
        )
        moe_config.layer_idx = self.layer_idx
        moe_config.pool = self.cpu_infer.backend_
        moe_config.max_len = self.chunked_prefill_size
        return moe_config

    def _create_moe(self, moe_config):
        """Instantiate the kernel MoE module matching ``self.method``.

        Raises:
            NotImplementedError: If ``self.method`` is not a supported method.
        """
        if self.method == "MOE_INT4":
            return Int4_KERNEL_MOE(moe_config)
        if self.method == "MOE_INT8":
            return Int8_KERNEL_MOE(moe_config)
        raise NotImplementedError(f"Unsupported MoE method: {self.method}")

    def load_weights_from_tensors(
        self,
        gate_proj: torch.Tensor,
        up_proj: torch.Tensor,
        down_proj: torch.Tensor,
        physical_to_logical_map_cpu: torch.Tensor,
    ):
        """
        Load and quantize weights from BF16/FP16 tensors (online quantization).

        Args:
            gate_proj: Gate projection weights [num_experts, intermediate_size, hidden_size]
            up_proj: Up projection weights [num_experts, intermediate_size, hidden_size]
            down_proj: Down projection weights [num_experts, hidden_size, intermediate_size]
            physical_to_logical_map_cpu: Mapping from physical to logical expert IDs
        """
        # Keep contiguous copies alive as attributes so the raw pointers
        # handed to the kernel stay valid for the lifetime of this wrapper.
        self.gate_proj = gate_proj.contiguous()
        self.up_proj = up_proj.contiguous()
        self.down_proj = down_proj.contiguous()

        # Configure MoE with online quantization (cpu_save mode):
        # save=True quantizes and writes out; load=False skips reading back.
        moe_config = self._build_moe_config()
        moe_config.save = True
        moe_config.load = False

        # Set weight pointers
        moe_config.gate_proj = self.gate_proj.data_ptr()
        moe_config.up_proj = self.up_proj.data_ptr()
        moe_config.down_proj = self.down_proj.data_ptr()

        # Set output path for quantized weights
        moe_config.path = self.weight_path

        self.moe = self._create_moe(moe_config)

        # Submit quantization-and-save task and block until it finishes.
        self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
        self.cpu_infer.sync()

    def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
        """
        Load weights for this layer and initialize the MoE module.

        Args:
            physical_to_logical_map_cpu: Mapping from physical to logical expert IDs
        """
        gate_ptrs = []
        up_ptrs = []
        down_ptrs = []
        gate_scale_ptrs = []
        up_scale_ptrs = []
        down_scale_ptrs = []

        if self.load_merged_weight:
            base_key = f"blk.{self.layer_idx}"
            w = self.safetensor_loader.load_experts(base_key)

            # Hold references so the buffers stay alive while the kernel
            # consumes the raw pointers below.
            self.gate_weights = w["gate"]
            self.up_weights = w["up"]
            self.down_weights = w["down"]
            self.gate_scales = w["gate_scale"]
            self.up_scales = w["up_scale"]
            self.down_scales = w["down_scale"]

            # Get pointers to weight/scale arrays, per NUMA node per expert.
            gate_ptrs = self._numa_ptrs(self.gate_weights)
            up_ptrs = self._numa_ptrs(self.up_weights)
            down_ptrs = self._numa_ptrs(self.down_weights)
            gate_scale_ptrs = self._numa_ptrs(self.gate_scales)
            up_scale_ptrs = self._numa_ptrs(self.up_scales)
            down_scale_ptrs = self._numa_ptrs(self.down_scales)

        # Configure MoE
        moe_config = self._build_moe_config()

        # Single-pointer fields default to 0 in this path; the per-NUMA
        # pointer lists below carry the actual addresses (cpu_save overrides
        # the single pointers further down).
        moe_config.gate_proj = 0
        moe_config.up_proj = 0
        moe_config.down_proj = 0
        moe_config.gate_projs = gate_ptrs
        moe_config.up_projs = up_ptrs
        moe_config.down_projs = down_ptrs
        moe_config.gate_scales = gate_scale_ptrs
        moe_config.up_scales = up_scale_ptrs
        moe_config.down_scales = down_scale_ptrs

        if self.cpu_save:
            # Online quantization: read raw weights, quantize, save.
            moe_config.save = True
            moe_config.load = False
            # NOTE(review): this branch uses self.safetensor_loader, which is
            # only created when merged safetensors were detected in __init__ —
            # confirm cpu_save is only used together with merged weights.
            base_key = f"model.layers.{self.layer_idx}"
            w = self.safetensor_loader.load_experts(base_key)

            self.gate_proj = torch.cat(w["gate_weight"], dim=0).contiguous()
            self.up_proj = torch.cat(w["up_weight"], dim=0).contiguous()
            self.down_proj = torch.cat(w["down_weight"], dim=0).contiguous()

            moe_config.gate_proj = self.gate_proj.data_ptr()
            moe_config.up_proj = self.up_proj.data_ptr()
            moe_config.down_proj = self.down_proj.data_ptr()
        else:
            moe_config.load = True

        if not self.load_merged_weight:
            moe_config.path = self.weight_path

        # Create MoE module based on moe method
        self.moe = self._create_moe(moe_config)

        # Load weights (blocking)
        self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
        self.cpu_infer.sync()

        # Release temporary weight storage once the kernel has packed it.
        # Reset to None (matching __init__) rather than `del` so later
        # attribute access stays well-defined instead of raising
        # AttributeError.
        if self.load_merged_weight:
            self.gate_weights = None
            self.up_weights = None
            self.down_weights = None
            self.gate_scales = None
            self.up_scales = None
            self.down_scales = None
||||
Reference in New Issue
Block a user