Merge branch 'main' into develop-cht
@@ -17,9 +17,9 @@ from safetensors import safe_open
 import os
 import ctypes
 
-# Import the C++ extension module (compiled as cpuinfer_ext)
-import cpuinfer_ext
-from cpuinfer_ext.moe import MOEConfig, AMXInt4_MOE, AMXInt8_MOE
+# Import the C++ extension module (compiled as kt_kernel_ext)
+import kt_kernel_ext
+from kt_kernel_ext.moe import MOEConfig, AMXInt4_MOE, AMXInt8_MOE
 
 
 class SafeTensorLoader:
@@ -227,6 +227,7 @@ class AMXMoEWrapper:
         chunked_prefill_size: int,
         cpu_save: bool = False,
         max_deferred_experts_per_token: Optional[int] = None,
+        amx_method: str = "AMXINT4",
     ):
         """
         Initialize AMX MoE Wrapper.
@@ -244,6 +245,7 @@ class AMXMoEWrapper:
             chunked_prefill_size: Maximum prefill chunk size
             cpu_save: Whether to save weights to CPU memory
             max_deferred_experts_per_token: Number of experts per token to defer on this layer. Defaults to 0 (no defer).
+            amx_method: AMX quantization method ("AMXINT4" or "AMXINT8")
         """
 
         self.layer_idx = layer_idx
@@ -258,10 +260,11 @@ class AMXMoEWrapper:
         self.max_deferred_experts_per_token = int(max_deferred_experts_per_token) if max_deferred_experts_per_token is not None else 0
 
         AMXMoEWrapper._layer_has_pending_deferred[self.layer_idx] = False
+        self.amx_method = amx_method
 
         # Initialize CPU inference engine (singleton)
         if AMXMoEWrapper._cpu_infer_instance is None:
-            worker_config = cpuinfer_ext.WorkerPoolConfig()
+            worker_config = kt_kernel_ext.WorkerPoolConfig()
 
             subpool_numa_map = list(range(threadpool_count))
             subpool_thread_count = [
@@ -272,7 +275,7 @@ class AMXMoEWrapper:
             worker_config.subpool_count = threadpool_count
             worker_config.subpool_numa_map = subpool_numa_map
             worker_config.subpool_thread_count = subpool_thread_count
-            AMXMoEWrapper._cpu_infer_instance = cpuinfer_ext.CPUInfer(worker_config)
+            AMXMoEWrapper._cpu_infer_instance = kt_kernel_ext.CPUInfer(worker_config)
 
         self.cpu_infer = AMXMoEWrapper._cpu_infer_instance
 
@@ -343,13 +346,12 @@ class AMXMoEWrapper:
         moe_config.path = self.amx_weight_path
 
         # Create MoE module based on AMX method
-        amx_method = os.environ.get("AMX_METHOD", "AMXINT4")
-        if amx_method == "AMXINT4":
+        if self.amx_method == "AMXINT4":
             self.moe = AMXInt4_MOE(moe_config)
-        elif amx_method == "AMXINT8":
+        elif self.amx_method == "AMXINT8":
             self.moe = AMXInt8_MOE(moe_config)
         else:
-            raise NotImplementedError(f"Unsupported AMX method: {amx_method}")
+            raise NotImplementedError(f"Unsupported AMX method: {self.amx_method}")
 
         # Submit quantization and save task
         self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
@@ -476,13 +478,12 @@ class AMXMoEWrapper:
         moe_config.path = self.amx_weight_path
 
         # Create MoE module based on AMX method
-        amx_method = os.environ.get("AMX_METHOD", "AMXINT4")
-        if amx_method == "AMXINT4":
+        if self.amx_method == "AMXINT4":
             self.moe = AMXInt4_MOE(moe_config)
-        elif amx_method == "AMXINT8":
+        elif self.amx_method == "AMXINT8":
             self.moe = AMXInt8_MOE(moe_config)
         else:
-            raise NotImplementedError(f"Unsupported AMX method: {amx_method}")
+            raise NotImplementedError(f"Unsupported AMX method: {self.amx_method}")
 
         # Load weights
         self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
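As a minimal standalone sketch of the pattern this commit adopts (dispatching on a per-instance amx_method attribute instead of the AMX_METHOD environment variable), the snippet below uses stand-in classes and a simplified wrapper; the names Int4MoESketch, Int8MoESketch, and MoEWrapperSketch are illustrative only and are not part of the real kt_kernel_ext API.

import os

class Int4MoESketch:  # stand-in for AMXInt4_MOE
    def __init__(self, cfg):
        self.cfg = cfg

class Int8MoESketch:  # stand-in for AMXInt8_MOE
    def __init__(self, cfg):
        self.cfg = cfg

class MoEWrapperSketch:
    def __init__(self, amx_method: str = "AMXINT4"):
        # After this commit: the quantization method is stored on the instance
        self.amx_method = amx_method

    def build(self, moe_config):
        # Dispatch on the stored attribute, mirroring the if/elif/else in the diff
        if self.amx_method == "AMXINT4":
            return Int4MoESketch(moe_config)
        elif self.amx_method == "AMXINT8":
            return Int8MoESketch(moe_config)
        raise NotImplementedError(f"Unsupported AMX method: {self.amx_method}")

# Before this commit, the equivalent choice came from the environment:
legacy_method = os.environ.get("AMX_METHOD", "AMXINT4")

# Prints "Int8MoESketch": the per-instance setting wins, no env var needed
print(type(MoEWrapperSketch("AMXINT8").build({"path": "/tmp/amx_weights"})).__name__)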