Merge branch 'main' into develop-cht

This commit is contained in:
chenht2022
2025-11-03 14:35:44 +00:00
192 changed files with 22265 additions and 12592 deletions

View File

@@ -17,9 +17,9 @@ from safetensors import safe_open
import os
import ctypes
# Import the C++ extension module (compiled as cpuinfer_ext)
import cpuinfer_ext
from cpuinfer_ext.moe import MOEConfig, AMXInt4_MOE, AMXInt8_MOE
# Import the C++ extension module (compiled as kt_kernel_ext)
import kt_kernel_ext
from kt_kernel_ext.moe import MOEConfig, AMXInt4_MOE, AMXInt8_MOE
class SafeTensorLoader:
@@ -227,6 +227,7 @@ class AMXMoEWrapper:
chunked_prefill_size: int,
cpu_save: bool = False,
max_deferred_experts_per_token: Optional[int] = None,
amx_method: str = "AMXINT4",
):
"""
Initialize AMX MoE Wrapper.
@@ -244,6 +245,7 @@ class AMXMoEWrapper:
chunked_prefill_size: Maximum prefill chunk size
cpu_save: Whether to save weights to CPU memory
max_deferred_experts_per_token: Number of experts per token to defer on this layer. Defaults to None, which is treated as 0 (no deferral).
amx_method: AMX quantization method ("AMXINT4" or "AMXINT8")
"""
self.layer_idx = layer_idx
@@ -258,10 +260,11 @@ class AMXMoEWrapper:
self.max_deferred_experts_per_token = int(max_deferred_experts_per_token) if max_deferred_experts_per_token is not None else 0
AMXMoEWrapper._layer_has_pending_deferred[self.layer_idx] = False
self.amx_method = amx_method
# Initialize CPU inference engine (singleton)
if AMXMoEWrapper._cpu_infer_instance is None:
worker_config = cpuinfer_ext.WorkerPoolConfig()
worker_config = kt_kernel_ext.WorkerPoolConfig()
subpool_numa_map = list(range(threadpool_count))
subpool_thread_count = [
@@ -272,7 +275,7 @@ class AMXMoEWrapper:
worker_config.subpool_count = threadpool_count
worker_config.subpool_numa_map = subpool_numa_map
worker_config.subpool_thread_count = subpool_thread_count
AMXMoEWrapper._cpu_infer_instance = cpuinfer_ext.CPUInfer(worker_config)
AMXMoEWrapper._cpu_infer_instance = kt_kernel_ext.CPUInfer(worker_config)
self.cpu_infer = AMXMoEWrapper._cpu_infer_instance
@@ -343,13 +346,12 @@ class AMXMoEWrapper:
moe_config.path = self.amx_weight_path
# Create MoE module based on AMX method
amx_method = os.environ.get("AMX_METHOD", "AMXINT4")
if amx_method == "AMXINT4":
if self.amx_method == "AMXINT4":
self.moe = AMXInt4_MOE(moe_config)
elif amx_method == "AMXINT8":
elif self.amx_method == "AMXINT8":
self.moe = AMXInt8_MOE(moe_config)
else:
raise NotImplementedError(f"Unsupported AMX method: {amx_method}")
raise NotImplementedError(f"Unsupported AMX method: {self.amx_method}")
# Submit quantization and save task
self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
@@ -476,13 +478,12 @@ class AMXMoEWrapper:
moe_config.path = self.amx_weight_path
# Create MoE module based on AMX method
amx_method = os.environ.get("AMX_METHOD", "AMXINT4")
if amx_method == "AMXINT4":
if self.amx_method == "AMXINT4":
self.moe = AMXInt4_MOE(moe_config)
elif amx_method == "AMXINT8":
elif self.amx_method == "AMXINT8":
self.moe = AMXInt8_MOE(moe_config)
else:
raise NotImplementedError(f"Unsupported AMX method: {amx_method}")
raise NotImplementedError(f"Unsupported AMX method: {self.amx_method}")
# Load weights
self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))