mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-20 06:18:59 +00:00
@@ -4,16 +4,16 @@ import ctypes
|
||||
|
||||
# Use relative imports for package structure
|
||||
from ..experts_base import BaseMoEWrapper
|
||||
from .loader import SafeTensorLoader, CompressedSafeTensorLoader
|
||||
from .loader import SafeTensorLoader, CompressedSafeTensorLoader, FP8SafeTensorLoader
|
||||
from kt_kernel_ext.moe import MOEConfig
|
||||
|
||||
try:
|
||||
from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE
|
||||
from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE
|
||||
|
||||
_HAS_AMX_SUPPORT = True
|
||||
except (ImportError, AttributeError):
|
||||
_HAS_AMX_SUPPORT = False
|
||||
AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE = None, None, None
|
||||
AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE = None, None, None, None
|
||||
|
||||
from typing import Optional
|
||||
|
||||
@@ -303,10 +303,10 @@ class AMXMoEWrapper(BaseMoEWrapper):
|
||||
del self.down_scales
|
||||
|
||||
|
||||
class RAWAMXMoEWrapper(BaseMoEWrapper):
|
||||
"""Wrapper for RAWINT4 experts stored in compressed SafeTensor format."""
|
||||
class NativeMoEWrapper(BaseMoEWrapper):
|
||||
"""Wrapper for RAWINT4/FP8 experts stored in compressed SafeTensor format."""
|
||||
|
||||
_compressed_loader_instance = None
|
||||
_native_loader_instance = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -324,8 +324,12 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
|
||||
max_deferred_experts_per_token: Optional[int] = None,
|
||||
method: str = "RAWINT4",
|
||||
):
|
||||
if not _HAS_AMX_SUPPORT or AMXInt4_KGroup_MOE is None:
|
||||
if not _HAS_AMX_SUPPORT:
|
||||
raise RuntimeError("AMX backend is not available.")
|
||||
if method == "RAWINT4" and AMXInt4_KGroup_MOE is None:
|
||||
raise RuntimeError("AMX backend with RAWINT4 support is not available.")
|
||||
if method == "FP8" and AMXFP8_MOE is None:
|
||||
raise RuntimeError("AMX backend with FP8 support is not available.")
|
||||
|
||||
super().__init__(
|
||||
layer_idx=layer_idx,
|
||||
@@ -343,9 +347,14 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
|
||||
method=method,
|
||||
)
|
||||
|
||||
if RAWAMXMoEWrapper._compressed_loader_instance is None:
|
||||
RAWAMXMoEWrapper._compressed_loader_instance = CompressedSafeTensorLoader(weight_path)
|
||||
self.loader = RAWAMXMoEWrapper._compressed_loader_instance
|
||||
if NativeMoEWrapper._native_loader_instance is None:
|
||||
if method == "RAWINT4":
|
||||
NativeMoEWrapper._native_loader_instance = CompressedSafeTensorLoader(weight_path)
|
||||
elif method == "FP8":
|
||||
NativeMoEWrapper._native_loader_instance = FP8SafeTensorLoader(weight_path)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported method for NativeMoEWrapper: {method}")
|
||||
self.loader = NativeMoEWrapper._native_loader_instance
|
||||
|
||||
self.gate_weights = None
|
||||
self.up_weights = None
|
||||
@@ -378,9 +387,17 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
|
||||
self.down_weights = weights["down"]
|
||||
|
||||
# Convert scales to bf16 individually
|
||||
self.gate_scales = [t.to(torch.bfloat16).contiguous() for t in weights["gate_scale"]]
|
||||
self.up_scales = [t.to(torch.bfloat16).contiguous() for t in weights["up_scale"]]
|
||||
self.down_scales = [t.to(torch.bfloat16).contiguous() for t in weights["down_scale"]]
|
||||
# self.gate_scales = [t.to(torch.bfloat16).contiguous() for t in weights["gate_scale"]]
|
||||
# self.up_scales = [t.to(torch.bfloat16).contiguous() for t in weights["up_scale"]]
|
||||
# self.down_scales = [t.to(torch.bfloat16).contiguous() for t in weights["down_scale"]]
|
||||
self.gate_scales = weights["gate_scale"]
|
||||
self.up_scales = weights["up_scale"]
|
||||
self.down_scales = weights["down_scale"]
|
||||
if self.method == "RAWINT4":
|
||||
assert self.gate_scales[0].dtype == torch.bfloat16, "Expected bf16 scales for RAWINT4"
|
||||
elif self.method == "FP8":
|
||||
assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8"
|
||||
|
||||
t2 = time.time()
|
||||
|
||||
# Build pointer lists: [numa_id][expert_id] -> pointer
|
||||
@@ -404,18 +421,6 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
|
||||
moe_config.pool = self.cpu_infer.backend_
|
||||
moe_config.max_len = self.chunked_prefill_size
|
||||
|
||||
# Infer group_size from scale shape (column-major layout)
|
||||
# For gate/up projection: in_features = hidden_size
|
||||
# So: group_size = hidden_size / scale.shape[1]
|
||||
scale_shape = self.gate_scales[0].shape
|
||||
group_size = self.hidden_size // scale_shape[1]
|
||||
print(f"[RAWAMXMoEWrapper Layer {self.layer_idx}] Inferred group_size: {group_size}")
|
||||
|
||||
moe_config.quant_config.bits = 4
|
||||
moe_config.quant_config.group_size = group_size
|
||||
|
||||
moe_config.quant_config.zero_point = False
|
||||
|
||||
# Use gate_projs instead of gate_proj for per-expert pointers
|
||||
moe_config.gate_projs = gate_ptrs
|
||||
moe_config.up_projs = up_ptrs
|
||||
@@ -424,7 +429,21 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
|
||||
moe_config.up_scales = up_scale_ptrs
|
||||
moe_config.down_scales = down_scale_ptrs
|
||||
|
||||
self.moe = AMXInt4_KGroup_MOE(moe_config)
|
||||
# Infer group_size from scale shape (column-major layout)
|
||||
# For gate/up projection: in_features = hidden_size
|
||||
# So: group_size = hidden_size / scale.shape[1]
|
||||
|
||||
if self.method == "RAWINT4":
|
||||
group_size = self.hidden_size // self.gate_scales[0].shape[1]
|
||||
moe_config.quant_config.bits = 4
|
||||
moe_config.quant_config.group_size = group_size
|
||||
moe_config.quant_config.zero_point = False
|
||||
self.moe = AMXInt4_KGroup_MOE(moe_config)
|
||||
elif self.method == "FP8":
|
||||
moe_config.quant_config.bits = 8
|
||||
moe_config.quant_config.group_size = 128
|
||||
moe_config.quant_config.zero_point = False
|
||||
self.moe = AMXFP8_MOE(moe_config)
|
||||
t4 = time.time()
|
||||
|
||||
self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
|
||||
@@ -440,7 +459,7 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
|
||||
t6 = time.time()
|
||||
|
||||
print(
|
||||
f"[RAWAMXMoEWrapper Layer {self.layer_idx}] "
|
||||
f"[NativeMoEWrapper Layer {self.layer_idx}] "
|
||||
f"load_experts: {(t1-t0)*1000:.1f}ms, "
|
||||
f"prepare_tensors: {(t2-t1)*1000:.1f}ms, "
|
||||
f"build_ptrs: {(t3-t2)*1000:.1f}ms, "
|
||||
@@ -453,7 +472,7 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
|
||||
def submit_write_weight_scale_to_buffer(
|
||||
self,
|
||||
gpu_tp_count: int,
|
||||
gpu_experts_num: int,
|
||||
expert_id: int,
|
||||
w13_weight_ptrs,
|
||||
w13_scale_ptrs,
|
||||
w2_weight_ptrs,
|
||||
@@ -477,7 +496,7 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
|
||||
self.cpu_infer.submit(
|
||||
self.moe.write_weight_scale_to_buffer_task(
|
||||
gpu_tp_count,
|
||||
gpu_experts_num,
|
||||
expert_id,
|
||||
w13_weight_ptrs,
|
||||
w13_scale_ptrs,
|
||||
w2_weight_ptrs,
|
||||
|
||||
Reference in New Issue
Block a user