mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2026-04-19 22:09:10 +00:00)
@@ -77,7 +77,7 @@ class KTMoEWrapper:
             chunked_prefill_size: Maximum prefill chunk size
             cpu_save: Whether to save weights to CPU memory
             max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0.
-            method: Backend method ("AMXINT4", "AMXINT8", "RAWINT4", "FP8", "LLAMAFILE", "MOE_INT4", "MOE_INT8")
+            method: Backend method ("AMXINT4", "AMXINT8", "RAWINT4", "FP8", "BF16", "LLAMAFILE", "MOE_INT4", "MOE_INT8")

         Returns:
             An instance of the appropriate backend implementation (e.g., AMXMoEWrapper)
@@ -85,7 +85,7 @@ class KTMoEWrapper:
         # Select backend based on method
         if method in ["AMXINT4", "AMXINT8"]:
             backend_cls = AMXMoEWrapper
-        elif method in ["RAWINT4", "FP8"]:
+        elif method in ["RAWINT4", "FP8", "BF16"]:
             backend_cls = NativeMoEWrapper
         elif method == "LLAMAFILE":
             backend_cls = LlamafileMoEWrapper
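
To make the routing change concrete, here is a standalone sketch of the selection logic above. The method strings and class names mirror the diff, but this helper is illustrative, not the library's actual entry point.

# Illustrative sketch (not library code): which wrapper class serves each
# backend method string after this change.
def pick_backend(method: str) -> str:
    if method in ["AMXINT4", "AMXINT8"]:
        return "AMXMoEWrapper"
    elif method in ["RAWINT4", "FP8", "BF16"]:  # "BF16" is newly routed here
        return "NativeMoEWrapper"
    elif method == "LLAMAFILE":
        return "LlamafileMoEWrapper"
    raise ValueError(f"Unknown method: {method}")

assert pick_backend("BF16") == "NativeMoEWrapper"
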
@@ -4,16 +4,16 @@ import ctypes

 # Use relative imports for package structure
 from ..experts_base import BaseMoEWrapper
-from .loader import SafeTensorLoader, CompressedSafeTensorLoader, FP8SafeTensorLoader
+from .loader import SafeTensorLoader, CompressedSafeTensorLoader, FP8SafeTensorLoader, BF16SafeTensorLoader
 from kt_kernel_ext.moe import MOEConfig

 try:
-    from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE
+    from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE, AMXBF16_MOE

     _HAS_AMX_SUPPORT = True
 except (ImportError, AttributeError):
     _HAS_AMX_SUPPORT = False
-    AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE = None, None, None, None
+    AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE, AMXBF16_MOE = None, None, None, None, None

 from typing import Optional

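
The sentinel import above is what the availability guards later in this diff check against. A self-contained sketch of the same pattern, assuming only that the compiled binding may be absent:

# If the C++ extension lacks the BF16 kernel, bind the name to None so the
# module still imports; callers must test for None before constructing.
try:
    from kt_kernel_ext.moe import AMXBF16_MOE
except (ImportError, AttributeError):
    AMXBF16_MOE = None

def require_bf16_kernel():
    # Mirrors the RuntimeError guard added to NativeMoEWrapper below.
    if AMXBF16_MOE is None:
        raise RuntimeError("AMX backend with BF16 support is not available.")
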
@@ -304,7 +304,7 @@ class AMXMoEWrapper(BaseMoEWrapper):


 class NativeMoEWrapper(BaseMoEWrapper):
-    """Wrapper for RAWINT4/FP8 experts stored in compressed SafeTensor format."""
+    """Wrapper for RAWINT4/FP8/BF16 experts stored in compressed SafeTensor format."""

     _native_loader_instance = None

@@ -330,6 +330,8 @@ class NativeMoEWrapper(BaseMoEWrapper):
             raise RuntimeError("AMX backend with RAWINT4 support is not available.")
         if method == "FP8" and AMXFP8_MOE is None:
             raise RuntimeError("AMX backend with FP8 support is not available.")
+        if method == "BF16" and AMXBF16_MOE is None:
+            raise RuntimeError("AMX backend with BF16 support is not available.")

         super().__init__(
             layer_idx=layer_idx,
@@ -352,6 +354,8 @@ class NativeMoEWrapper(BaseMoEWrapper):
                 NativeMoEWrapper._native_loader_instance = CompressedSafeTensorLoader(weight_path)
             elif method == "FP8":
                 NativeMoEWrapper._native_loader_instance = FP8SafeTensorLoader(weight_path)
+            elif method == "BF16":
+                NativeMoEWrapper._native_loader_instance = BF16SafeTensorLoader(weight_path)
             else:
                 raise NotImplementedError(f"Unsupported method for NativeMoEWrapper: {method}")
         self.loader = NativeMoEWrapper._native_loader_instance
@@ -386,28 +390,42 @@ class NativeMoEWrapper(BaseMoEWrapper):
         self.up_weights = weights["up"]
         self.down_weights = weights["down"]

-        # Convert scales to bf16 individually
-        # self.gate_scales = [t.to(torch.bfloat16).contiguous() for t in weights["gate_scale"]]
-        # self.up_scales = [t.to(torch.bfloat16).contiguous() for t in weights["up_scale"]]
-        # self.down_scales = [t.to(torch.bfloat16).contiguous() for t in weights["down_scale"]]
-        self.gate_scales = weights["gate_scale"]
-        self.up_scales = weights["up_scale"]
-        self.down_scales = weights["down_scale"]
-        if self.method == "RAWINT4":
-            assert self.gate_scales[0].dtype == torch.bfloat16, "Expected bf16 scales for RAWINT4"
-        elif self.method == "FP8":
-            assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8"
+        # BF16 has no scales, others have scales
+        if self.method == "BF16":
+            # BF16 doesn't have scales
+            self.gate_scales = None
+            self.up_scales = None
+            self.down_scales = None
+        else:
+            # Convert scales to bf16 individually
+            # self.gate_scales = [t.to(torch.bfloat16).contiguous() for t in weights["gate_scale"]]
+            # self.up_scales = [t.to(torch.bfloat16).contiguous() for t in weights["up_scale"]]
+            # self.down_scales = [t.to(torch.bfloat16).contiguous() for t in weights["down_scale"]]
+            self.gate_scales = weights["gate_scale"]
+            self.up_scales = weights["up_scale"]
+            self.down_scales = weights["down_scale"]
+            if self.method == "RAWINT4":
+                assert self.gate_scales[0].dtype == torch.bfloat16, "Expected bf16 scales for RAWINT4"
+            elif self.method == "FP8":
+                assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8"

         t2 = time.time()

         # Build pointer lists: [numa_id][expert_id] -> pointer
-        # Since RAWINT4 has no numa sharding, numa dimension is 1
+        # Since RAWINT4/FP8/BF16 has no numa sharding, numa dimension is 1
         gate_ptrs = [[t.data_ptr() for t in self.gate_weights]]
         up_ptrs = [[t.data_ptr() for t in self.up_weights]]
         down_ptrs = [[t.data_ptr() for t in self.down_weights]]
-        gate_scale_ptrs = [[t.data_ptr() for t in self.gate_scales]]
-        up_scale_ptrs = [[t.data_ptr() for t in self.up_scales]]
-        down_scale_ptrs = [[t.data_ptr() for t in self.down_scales]]
+
+        # BF16 has no scales, pass empty lists (will use 0/nullptr for consistency)
+        if self.method == "BF16":
+            gate_scale_ptrs = [[0 for _ in self.gate_weights]]
+            up_scale_ptrs = [[0 for _ in self.up_weights]]
+            down_scale_ptrs = [[0 for _ in self.down_weights]]
+        else:
+            gate_scale_ptrs = [[t.data_ptr() for t in self.gate_scales]]
+            up_scale_ptrs = [[t.data_ptr() for t in self.up_scales]]
+            down_scale_ptrs = [[t.data_ptr() for t in self.down_scales]]
         t3 = time.time()

         moe_config = MOEConfig(
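
The 0/nullptr convention above works because `Tensor.data_ptr()` returns the buffer address as a plain Python int, so a literal 0 passed in its place reaches the C++ binding as a null pointer. A quick self-contained check using only standard torch and ctypes calls:

import ctypes

import torch

t = torch.empty(8, dtype=torch.bfloat16)
addr = t.data_ptr()  # integer address of the tensor's storage
assert isinstance(addr, int) and addr != 0

# A zero address converts to NULL on the C side: c_void_p(0) holds None.
assert ctypes.c_void_p(0).value is None
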
@@ -444,6 +462,9 @@ class NativeMoEWrapper(BaseMoEWrapper):
             moe_config.quant_config.group_size = 128
             moe_config.quant_config.zero_point = False
             self.moe = AMXFP8_MOE(moe_config)
+        elif self.method == "BF16":
+            # BF16 has no quantization config needed
+            self.moe = AMXBF16_MOE(moe_config)
         t4 = time.time()

         self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
@@ -453,9 +474,10 @@ class NativeMoEWrapper(BaseMoEWrapper):
         del self.gate_weights
         del self.up_weights
         del self.down_weights
-        del self.gate_scales
-        del self.up_scales
-        del self.down_scales
+        if self.gate_scales is not None:
+            del self.gate_scales
+            del self.up_scales
+            del self.down_scales
         t6 = time.time()

         print(
@@ -348,6 +348,99 @@ class FP8SafeTensorLoader(SafeTensorLoader):
         }


+class BF16SafeTensorLoader(SafeTensorLoader):
+    """Loader for native BF16 expert weights (no quantization, no scales).
+
+    Supported formats:
+    - DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight
+    - Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight
+
+    The format is auto-detected during initialization.
+    """
+
+    MOE_FORMATS = {
+        "deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"),
+        "mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"),
+    }
+
+    def __init__(self, file_path: str):
+        super().__init__(file_path)
+        self._detected_format = None
+        self._detect_format()
+
+    def _detect_format(self):
+        """Auto-detect the MoE naming format by checking tensor keys."""
+        sample_keys = list(self.tensor_file_map.keys())[:1000]
+
+        for fmt_name, (path_tpl, gate, up, down) in self.MOE_FORMATS.items():
+            for key in sample_keys:
+                if ".experts." in key and f".{gate}.weight" in key:
+                    if "block_sparse_moe.experts" in key and fmt_name == "mixtral":
+                        self._detected_format = fmt_name
+                        print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}")
+                        return
+                    elif "mlp.experts" in key and "block_sparse_moe" not in key and fmt_name == "deepseek":
+                        self._detected_format = fmt_name
+                        print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}")
+                        return
+
+        self._detected_format = "deepseek"
+        print("[BF16SafeTensorLoader] No MoE format detected, defaulting to: deepseek")
+
+    def _get_experts_prefix(self, base_key: str) -> str:
+        """Get the experts prefix based on detected format."""
+        path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format]
+        return path_tpl.format(base=base_key)
+
+    def _get_proj_names(self):
+        """Get projection names (gate, up, down) based on detected format."""
+        _, gate, up, down = self.MOE_FORMATS[self._detected_format]
+        return gate, up, down
+
+    def load_tensor(self, key: str, device: str = "cpu"):
+        if key not in self.tensor_file_map:
+            raise KeyError(f"Key {key} not found in Safetensor files")
+        file = self.tensor_file_map[key]
+        f = self.file_handle_map.get(file)
+        if f is None:
+            raise FileNotFoundError(f"File {file} not found in Safetensor files")
+        tensor = f.get_tensor(key)
+        if device == "cpu":
+            return tensor
+        return tensor.to(device)
+
+    def load_experts(self, base_key: str, device: str = "cpu"):
+        """Load BF16 expert weights (no scales needed)."""
+        experts_prefix = self._get_experts_prefix(base_key)
+        gate_name, up_name, down_name = self._get_proj_names()
+
+        expert_count = 0
+        while self.has_tensor(f"{experts_prefix}.{expert_count}.{gate_name}.weight"):
+            expert_count += 1
+
+        if expert_count == 0:
+            raise ValueError(f"No experts found for key {experts_prefix}")
+
+        gate_weights = [None] * expert_count
+        up_weights = [None] * expert_count
+        down_weights = [None] * expert_count
+
+        for exp_id in range(expert_count):
+            gate_w_key = f"{experts_prefix}.{exp_id}.{gate_name}.weight"
+            up_w_key = f"{experts_prefix}.{exp_id}.{up_name}.weight"
+            down_w_key = f"{experts_prefix}.{exp_id}.{down_name}.weight"
+
+            gate_weights[exp_id] = self.load_tensor(gate_w_key, device).contiguous()
+            up_weights[exp_id] = self.load_tensor(up_w_key, device).contiguous()
+            down_weights[exp_id] = self.load_tensor(down_w_key, device).contiguous()
+
+        return {
+            "gate": gate_weights,
+            "up": up_weights,
+            "down": down_weights,
+        }
+
+
 class CompressedSafeTensorLoader(SafeTensorLoader):
     """Loader for compressed SafeTensor layouts (RAWINT4 weights)."""

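
For reference, a hypothetical use of the new loader; the model path and per-layer base_key are assumed examples, but the key layout follows the MOE_FORMATS table above:

# Example tensor keys the two auto-detected formats would match:
#   deepseek: "model.layers.3.mlp.experts.17.gate_proj.weight"
#   mixtral:  "model.layers.3.block_sparse_moe.experts.17.w1.weight"
loader = BF16SafeTensorLoader("/models/deepseek-bf16")  # path is an assumption

# base_key is the per-layer prefix; "model.layers.3" is an assumed example.
experts = loader.load_experts("model.layers.3", device="cpu")
gate_w0 = experts["gate"][0]  # bf16 weight tensor for expert 0's gate projection
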