Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2026-03-14 18:37:23 +00:00)
[feat]: add mistral moe loader compatibility (#1873)
Co-authored-by: chenht2022 <chenht2022@users.noreply.github.com>
@@ -448,6 +448,10 @@ class NativeMoEWrapper(BaseMoEWrapper):
             self.down_scales = [t.to(torch.float32).contiguous() for t in weights["down_scale"]]
             assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8"
         elif self.method == "FP8_PERCHANNEL":
+            if self.gate_scales[0].dtype != torch.float32:
+                self.gate_scales = [t.to(torch.float32).contiguous() for t in weights["gate_scale"]]
+                self.up_scales = [t.to(torch.float32).contiguous() for t in weights["up_scale"]]
+                self.down_scales = [t.to(torch.float32).contiguous() for t in weights["down_scale"]]
             assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8_PERCHANNEL"

         t2 = time.time()
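The hunk above makes the FP8_PERCHANNEL path tolerate scales that are not already float32, which is what lets Mistral checkpoints (whose per-channel scales may be stored in bfloat16) load through it. A minimal standalone sketch of that conversion, assuming torch is available; the function name is illustrative, not taken from the patch:

    import torch

    def normalize_scales(scales):
        # Hypothetical standalone form of the hunk's logic: Mistral-style
        # checkpoints may store per-channel FP8 scales in bfloat16, while
        # the downstream kernels expect contiguous float32 tensors.
        if scales[0].dtype != torch.float32:
            scales = [t.to(torch.float32).contiguous() for t in scales]
        assert scales[0].dtype == torch.float32, "Expected float32 scales for FP8_PERCHANNEL"
        return scales

    # Two experts, per-channel scales of width 8, stored as bfloat16.
    scales = normalize_scales([torch.rand(8, dtype=torch.bfloat16) for _ in range(2)])
    print(scales[0].dtype)  # torch.float32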
@@ -243,6 +243,7 @@ class FP8SafeTensorLoader(SafeTensorLoader):
     Supported formats:
     - DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight
     - Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight
+    - Mistral style: {base}.experts.{id}.{w1,w3,w2}.weight

     Supported scale formats (auto-detected):
     - Block-wise: weight_scale_inv (DeepSeek FP8)
@@ -255,6 +256,7 @@ class FP8SafeTensorLoader(SafeTensorLoader):
     MOE_FORMATS = {
         "deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"),
         "mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"),
+        "mistral": ("{base}.experts", "w1", "w3", "w2"),
     }

     def __init__(self, file_path: str, scale_suffix: str = None):
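To make the table concrete, the following demo copies MOE_FORMATS from the diff and prints the first expert's gate-projection key for each layout; the demo itself is illustrative, not part of the patch:

    MOE_FORMATS = {
        "deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"),
        "mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"),
        "mistral": ("{base}.experts", "w1", "w3", "w2"),
    }

    base_key = "model.layers.0"
    for fmt, (path_tpl, gate, _up, _down) in MOE_FORMATS.items():
        print(f"{fmt}: {path_tpl.format(base=base_key)}.0.{gate}.weight")
    # deepseek: model.layers.0.mlp.experts.0.gate_proj.weight
    # mixtral: model.layers.0.block_sparse_moe.experts.0.w1.weight
    # mistral: model.layers.0.experts.0.w1.weight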
@@ -297,6 +299,10 @@ class FP8SafeTensorLoader(SafeTensorLoader):
                     self._detected_format = fmt_name
                     print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}")
                     break
+                elif fmt_name == "mistral" and ".mlp.experts" not in key and ".block_sparse_moe.experts" not in key:
+                    self._detected_format = fmt_name
+                    print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}")
+                    break
             if self._detected_format:
                 break

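The new elif matches Mistral only when neither of the other layouts' markers appears in the key, because a bare ".experts" substring also occurs inside the DeepSeek and Mixtral paths. A compact sketch of that rule with a hypothetical detect_format helper:

    def detect_format(keys):
        """Return the first matching MoE layout name, or None."""
        for key in keys:
            if ".mlp.experts." in key:
                return "deepseek"
            if ".block_sparse_moe.experts." in key:
                return "mixtral"
            # ".experts." is a substring of both layouts above, so Mistral
            # is only chosen when their markers are absent.
            if ".experts." in key:
                return "mistral"
        return None

    print(detect_format(["layers.0.experts.0.w1.weight"]))                   # mistral
    print(detect_format(["model.layers.0.mlp.experts.0.gate_proj.weight"]))  # deepseek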
@@ -321,8 +327,21 @@ class FP8SafeTensorLoader(SafeTensorLoader):
                 return
             elif f".{gate}.weight_scale" in key and "weight_scale_inv" not in key:
                 self._scale_suffix = "weight_scale"
-                self._is_per_channel = True
-                print("[FP8SafeTensorLoader] Detected scale format: per-channel (weight_scale)")
+                # Some models (e.g., Mistral) use block-wise FP8 scales but keep
+                # the key suffix as `weight_scale` (without `_inv`). Infer format
+                # from scale tensor shape instead of suffix alone:
+                # - per-channel: [N] or [N, 1]
+                # - block-wise: [N_block, K_block] (both dims > 1)
+                scale_tensor = self.load_tensor(key, device="cpu")
+                if scale_tensor.dim() == 1:
+                    self._is_per_channel = True
+                elif scale_tensor.dim() == 2 and scale_tensor.shape[1] == 1:
+                    self._is_per_channel = True
+                else:
+                    self._is_per_channel = False
+
+                scale_kind = "per-channel" if self._is_per_channel else "block-wise"
+                print(f"[FP8SafeTensorLoader] Detected scale format: {scale_kind} (weight_scale)")
                 return
         # Default to weight_scale_inv
         self._scale_suffix = "weight_scale_inv"
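The replacement logic infers the scale layout from tensor shape instead of trusting the key suffix. A standalone version of the heuristic, assuming scales are either per-channel ([N] or [N, 1]) or block-wise ([N_block, K_block]):

    import torch

    def is_per_channel(scale):
        """True for [N] or [N, 1] scales, False for block-wise [N_block, K_block]."""
        if scale.dim() == 1:
            return True
        if scale.dim() == 2 and scale.shape[1] == 1:
            return True
        return False

    print(is_per_channel(torch.ones(256)))     # True  ([N])
    print(is_per_channel(torch.ones(256, 1)))  # True  ([N, 1])
    print(is_per_channel(torch.ones(16, 32)))  # False (block-wise)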
@@ -333,12 +352,20 @@ class FP8SafeTensorLoader(SafeTensorLoader):
             scale_type = "per-channel" if self._is_per_channel else "block-wise"
             print(f"[FP8SafeTensorLoader] Using explicit scale format: {scale_type} ({self._scale_suffix})")

-    def _get_experts_prefix(self, base_key: str) -> str:
-        """Get the experts prefix based on detected format."""
+    def _get_experts_prefix_candidates(self, base_key: str) -> list[str]:
+        """Get candidate experts prefixes based on detected format and base key variants."""
         path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format]
+        candidates = []
         if self._is_vl_model:
             base_key = base_key.replace("model.layers", "model.language_model.layers")
-        return path_tpl.format(base=base_key)
+        candidates.append(path_tpl.format(base=base_key))
+
+        # Some model weights (e.g., Mistral native format) do not have "model." prefix.
+        if base_key.startswith("model."):
+            candidates.append(path_tpl.format(base=base_key[len("model.") :]))
+
+        # Deduplicate while preserving order.
+        return list(dict.fromkeys(candidates))

     def _get_proj_names(self):
         """Get projection names (gate, up, down) based on detected format."""
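_get_experts_prefix_candidates now returns every plausible prefix rather than a single string, deduplicated with dict.fromkeys (which preserves insertion order, unlike a set). A pure-Python rendering of the same logic for a non-VL model, with a hypothetical standalone signature:

    def experts_prefix_candidates(base_key, path_tpl):
        candidates = [path_tpl.format(base=base_key)]
        # Mistral native checkpoints name layers without the "model." prefix.
        if base_key.startswith("model."):
            candidates.append(path_tpl.format(base=base_key[len("model.") :]))
        # dict.fromkeys keeps insertion order, so this dedups without reordering.
        return list(dict.fromkeys(candidates))

    print(experts_prefix_candidates("model.layers.3", "{base}.experts"))
    # ['model.layers.3.experts', 'layers.3.experts']
    print(experts_prefix_candidates("layers.3", "{base}.experts"))
    # ['layers.3.experts']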
@@ -363,15 +390,21 @@ class FP8SafeTensorLoader(SafeTensorLoader):
         Supports both block-wise (weight_scale_inv) and per-channel (weight_scale) formats.
         Per-channel scales are squeezed from [N, 1] to [N] if needed.
         """
-        experts_prefix = self._get_experts_prefix(base_key)
+        experts_prefix_candidates = self._get_experts_prefix_candidates(base_key)
         gate_name, up_name, down_name = self._get_proj_names()

-        expert_count = 0
-        while self.has_tensor(f"{experts_prefix}.{expert_count}.{gate_name}.weight"):
-            expert_count += 1
+        experts_prefix = None
+        for prefix in experts_prefix_candidates:
+            expert_count = 0
+            while self.has_tensor(f"{prefix}.{expert_count}.{gate_name}.weight"):
+                expert_count += 1
+            if expert_count > 0:
+                experts_prefix = prefix
+                break

-        if expert_count == 0:
-            raise ValueError(f"No experts found for key {experts_prefix}")
+        if expert_count == 0 or experts_prefix is None:
+            raise ValueError(f"No experts found for keys: {experts_prefix_candidates}")

         gate_weights = [None] * expert_count
         up_weights = [None] * expert_count
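The probing loop tries each candidate prefix in turn and keeps the first one that yields at least one expert tensor. A sketch with has_tensor stubbed by a set of key names (the real loader consults the safetensors index, not a Python set):

    # Keys as a Mistral-style checkpoint (no "model." prefix) would expose them.
    available = {
        "layers.0.experts.0.w1.weight",
        "layers.0.experts.1.w1.weight",
    }

    def has_tensor(key):
        return key in available

    candidates = ["model.layers.0.experts", "layers.0.experts"]
    experts_prefix = None
    for prefix in candidates:
        expert_count = 0
        while has_tensor(f"{prefix}.{expert_count}.w1.weight"):
            expert_count += 1
        if expert_count > 0:
            experts_prefix = prefix
            break

    print(experts_prefix, expert_count)  # layers.0.experts 2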
@@ -423,13 +456,13 @@ class FP8SafeTensorLoader(SafeTensorLoader):
         return self._is_per_channel


 class BF16SafeTensorLoader(SafeTensorLoader):
     """Loader for native BF16 expert weights (no quantization, no scales).

     Supported formats:
     - DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight
     - Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight
+    - Mistral style: {base}.experts.{id}.{w1,w3,w2}.weight

     The format is auto-detected during initialization.
     """
@@ -437,6 +470,7 @@ class BF16SafeTensorLoader(SafeTensorLoader):
     MOE_FORMATS = {
         "deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"),
         "mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"),
+        "mistral": ("{base}.experts", "w1", "w3", "w2"),
     }

     def __init__(self, file_path: str):
@@ -466,14 +500,24 @@ class BF16SafeTensorLoader(SafeTensorLoader):
                     self._detected_format = fmt_name
                     print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}")
                     return
+                elif fmt_name == "mistral" and ".mlp.experts" not in key and ".block_sparse_moe.experts" not in key:
+                    self._detected_format = fmt_name
+                    print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}")
+                    return

         self._detected_format = "deepseek"
         print("[BF16SafeTensorLoader] No MoE format detected, defaulting to: deepseek")

-    def _get_experts_prefix(self, base_key: str) -> str:
-        """Get the experts prefix based on detected format."""
+    def _get_experts_prefix_candidates(self, base_key: str) -> list[str]:
+        """Get candidate experts prefixes based on detected format and base key variants."""
         path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format]
-        return path_tpl.format(base=base_key)
+        candidates = [path_tpl.format(base=base_key)]
+
+        # Some model weights (e.g., Mistral native format) do not have "model." prefix.
+        if base_key.startswith("model."):
+            candidates.append(path_tpl.format(base=base_key[len("model.") :]))
+
+        return list(dict.fromkeys(candidates))

     def _get_proj_names(self):
         """Get projection names (gate, up, down) based on detected format."""
@@ -497,15 +541,21 @@ class BF16SafeTensorLoader(SafeTensorLoader):
         if self._detected_format == "packed":
             return self._load_experts_packed(base_key, device)

-        experts_prefix = self._get_experts_prefix(base_key)
+        experts_prefix_candidates = self._get_experts_prefix_candidates(base_key)
         gate_name, up_name, down_name = self._get_proj_names()

-        expert_count = 0
-        while self.has_tensor(f"{experts_prefix}.{expert_count}.{gate_name}.weight"):
-            expert_count += 1
+        experts_prefix = None
+        for prefix in experts_prefix_candidates:
+            expert_count = 0
+            while self.has_tensor(f"{prefix}.{expert_count}.{gate_name}.weight"):
+                expert_count += 1
+            if expert_count > 0:
+                experts_prefix = prefix
+                break

-        if expert_count == 0:
-            raise ValueError(f"No experts found for key {experts_prefix}")
+        if expert_count == 0 or experts_prefix is None:
+            raise ValueError(f"No experts found for keys: {experts_prefix_candidates}")

         gate_weights = [None] * expert_count
         up_weights = [None] * expert_count
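When nothing matches, the new error lists every prefix that was tried rather than a single key, which makes mismatched checkpoints easier to diagnose. A minimal reproduction with a hypothetical find_experts stub:

    def find_experts(candidates, keys):
        for prefix in candidates:
            n = 0
            while f"{prefix}.{n}.w1.weight" in keys:
                n += 1
            if n > 0:
                return prefix, n
        raise ValueError(f"No experts found for keys: {candidates}")

    try:
        find_experts(["model.layers.0.experts", "layers.0.experts"], set())
    except ValueError as e:
        print(e)
    # No experts found for keys: ['model.layers.0.experts', 'layers.0.experts']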