[feat]: add mistral moe loader compatibility (#1873)

Co-authored-by: chenht2022 <chenht2022@users.noreply.github.com>
This commit is contained in:
Chen Hongtao
2026-02-28 17:50:23 +08:00
committed by GitHub
parent 19887e4363
commit 9e69fccb02
2 changed files with 73 additions and 19 deletions

View File

@@ -448,6 +448,10 @@ class NativeMoEWrapper(BaseMoEWrapper):
self.down_scales = [t.to(torch.float32).contiguous() for t in weights["down_scale"]]
assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8"
elif self.method == "FP8_PERCHANNEL":
if self.gate_scales[0].dtype != torch.float32:
self.gate_scales = [t.to(torch.float32).contiguous() for t in weights["gate_scale"]]
self.up_scales = [t.to(torch.float32).contiguous() for t in weights["up_scale"]]
self.down_scales = [t.to(torch.float32).contiguous() for t in weights["down_scale"]]
assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8_PERCHANNEL"
t2 = time.time()

View File

@@ -243,6 +243,7 @@ class FP8SafeTensorLoader(SafeTensorLoader):
Supported formats:
- DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight
- Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight
- Mistral style: {base}.experts.{id}.{w1,w3,w2}.weight
Supported scale formats (auto-detected):
- Block-wise: weight_scale_inv (DeepSeek FP8)
@@ -255,6 +256,7 @@ class FP8SafeTensorLoader(SafeTensorLoader):
MOE_FORMATS = {
"deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"),
"mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"),
"mistral": ("{base}.experts", "w1", "w3", "w2"),
}
def __init__(self, file_path: str, scale_suffix: str = None):
@@ -297,6 +299,10 @@ class FP8SafeTensorLoader(SafeTensorLoader):
self._detected_format = fmt_name
print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}")
break
elif fmt_name == "mistral" and ".mlp.experts" not in key and ".block_sparse_moe.experts" not in key:
self._detected_format = fmt_name
print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}")
break
if self._detected_format:
break
@@ -321,8 +327,21 @@ class FP8SafeTensorLoader(SafeTensorLoader):
return
elif f".{gate}.weight_scale" in key and "weight_scale_inv" not in key:
self._scale_suffix = "weight_scale"
self._is_per_channel = True
print("[FP8SafeTensorLoader] Detected scale format: per-channel (weight_scale)")
# Some models (e.g., Mistral) use block-wise FP8 scales but keep
# the key suffix as `weight_scale` (without `_inv`). Infer format
# from scale tensor shape instead of suffix alone:
# - per-channel: [N] or [N, 1]
# - block-wise: [N_block, K_block] (both dims > 1)
scale_tensor = self.load_tensor(key, device="cpu")
if scale_tensor.dim() == 1:
self._is_per_channel = True
elif scale_tensor.dim() == 2 and scale_tensor.shape[1] == 1:
self._is_per_channel = True
else:
self._is_per_channel = False
scale_kind = "per-channel" if self._is_per_channel else "block-wise"
print(f"[FP8SafeTensorLoader] Detected scale format: {scale_kind} (weight_scale)")
return
# Default to weight_scale_inv
self._scale_suffix = "weight_scale_inv"
@@ -333,12 +352,20 @@ class FP8SafeTensorLoader(SafeTensorLoader):
scale_type = "per-channel" if self._is_per_channel else "block-wise"
print(f"[FP8SafeTensorLoader] Using explicit scale format: {scale_type} ({self._scale_suffix})")
def _get_experts_prefix(self, base_key: str) -> str:
"""Get the experts prefix based on detected format."""
def _get_experts_prefix_candidates(self, base_key: str) -> list[str]:
"""Get candidate experts prefixes based on detected format and base key variants."""
path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format]
candidates = []
if self._is_vl_model:
base_key = base_key.replace("model.layers", "model.language_model.layers")
return path_tpl.format(base=base_key)
candidates.append(path_tpl.format(base=base_key))
# Some model weights (e.g., Mistral native format) do not have "model." prefix.
if base_key.startswith("model."):
candidates.append(path_tpl.format(base=base_key[len("model.") :]))
# Deduplicate while preserving order.
return list(dict.fromkeys(candidates))
def _get_proj_names(self):
"""Get projection names (gate, up, down) based on detected format."""
@@ -363,15 +390,21 @@ class FP8SafeTensorLoader(SafeTensorLoader):
Supports both block-wise (weight_scale_inv) and per-channel (weight_scale) formats.
Per-channel scales are squeezed from [N, 1] to [N] if needed.
"""
experts_prefix = self._get_experts_prefix(base_key)
experts_prefix_candidates = self._get_experts_prefix_candidates(base_key)
gate_name, up_name, down_name = self._get_proj_names()
expert_count = 0
while self.has_tensor(f"{experts_prefix}.{expert_count}.{gate_name}.weight"):
expert_count += 1
experts_prefix = None
for prefix in experts_prefix_candidates:
expert_count = 0
while self.has_tensor(f"{prefix}.{expert_count}.{gate_name}.weight"):
expert_count += 1
if expert_count > 0:
experts_prefix = prefix
break
if expert_count == 0:
raise ValueError(f"No experts found for key {experts_prefix}")
if expert_count == 0 or experts_prefix is None:
raise ValueError(f"No experts found for keys: {experts_prefix_candidates}")
gate_weights = [None] * expert_count
up_weights = [None] * expert_count
@@ -423,13 +456,13 @@ class FP8SafeTensorLoader(SafeTensorLoader):
return self._is_per_channel
class BF16SafeTensorLoader(SafeTensorLoader):
"""Loader for native BF16 expert weights (no quantization, no scales).
Supported formats:
- DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight
- Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight
- Mistral style: {base}.experts.{id}.{w1,w3,w2}.weight
The format is auto-detected during initialization.
"""
@@ -437,6 +470,7 @@ class BF16SafeTensorLoader(SafeTensorLoader):
MOE_FORMATS = {
"deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"),
"mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"),
"mistral": ("{base}.experts", "w1", "w3", "w2"),
}
def __init__(self, file_path: str):
@@ -466,14 +500,24 @@ class BF16SafeTensorLoader(SafeTensorLoader):
self._detected_format = fmt_name
print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}")
return
elif fmt_name == "mistral" and ".mlp.experts" not in key and ".block_sparse_moe.experts" not in key:
self._detected_format = fmt_name
print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}")
return
self._detected_format = "deepseek"
print("[BF16SafeTensorLoader] No MoE format detected, defaulting to: deepseek")
def _get_experts_prefix(self, base_key: str) -> str:
"""Get the experts prefix based on detected format."""
def _get_experts_prefix_candidates(self, base_key: str) -> list[str]:
"""Get candidate experts prefixes based on detected format and base key variants."""
path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format]
return path_tpl.format(base=base_key)
candidates = [path_tpl.format(base=base_key)]
# Some model weights (e.g., Mistral native format) do not have "model." prefix.
if base_key.startswith("model."):
candidates.append(path_tpl.format(base=base_key[len("model.") :]))
return list(dict.fromkeys(candidates))
def _get_proj_names(self):
"""Get projection names (gate, up, down) based on detected format."""
@@ -497,15 +541,21 @@ class BF16SafeTensorLoader(SafeTensorLoader):
if self._detected_format == "packed":
return self._load_experts_packed(base_key, device)
experts_prefix = self._get_experts_prefix(base_key)
experts_prefix_candidates = self._get_experts_prefix_candidates(base_key)
gate_name, up_name, down_name = self._get_proj_names()
expert_count = 0
while self.has_tensor(f"{experts_prefix}.{expert_count}.{gate_name}.weight"):
expert_count += 1
experts_prefix = None
for prefix in experts_prefix_candidates:
expert_count = 0
while self.has_tensor(f"{prefix}.{expert_count}.{gate_name}.weight"):
expert_count += 1
if expert_count > 0:
experts_prefix = prefix
break
if expert_count == 0:
raise ValueError(f"No experts found for key {experts_prefix}")
if expert_count == 0 or experts_prefix is None:
raise ValueError(f"No experts found for keys: {experts_prefix_candidates}")
gate_weights = [None] * expert_count
up_weights = [None] * expert_count