mirror of https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-29 10:41:38 +00:00

[feat]: add mistral moe loader compatibility (#1873)
Co-authored-by: chenht2022 <chenht2022@users.noreply.github.com>
@@ -448,6 +448,10 @@ class NativeMoEWrapper(BaseMoEWrapper):
             self.down_scales = [t.to(torch.float32).contiguous() for t in weights["down_scale"]]
             assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8"
         elif self.method == "FP8_PERCHANNEL":
+            if self.gate_scales[0].dtype != torch.float32:
+                self.gate_scales = [t.to(torch.float32).contiguous() for t in weights["gate_scale"]]
+                self.up_scales = [t.to(torch.float32).contiguous() for t in weights["up_scale"]]
+                self.down_scales = [t.to(torch.float32).contiguous() for t in weights["down_scale"]]
             assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8_PERCHANNEL"
 
         t2 = time.time()
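The added branch mirrors the existing FP8 path: per-channel scales that arrive in a narrower dtype are upcast to float32 and made contiguous before the assertion runs. A minimal standalone sketch (the bfloat16 input here is hypothetical):

```python
import torch

# Hypothetical per-channel scales stored in bfloat16; the loader upcasts
# them to float32 and makes them contiguous before the assert fires.
scales = [torch.rand(128, dtype=torch.bfloat16) for _ in range(4)]
scales = [t.to(torch.float32).contiguous() for t in scales]
assert all(t.dtype == torch.float32 for t in scales)
```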
@@ -243,6 +243,7 @@ class FP8SafeTensorLoader(SafeTensorLoader):
     Supported formats:
     - DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight
     - Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight
+    - Mistral style: {base}.experts.{id}.{w1,w3,w2}.weight
 
     Supported scale formats (auto-detected):
     - Block-wise: weight_scale_inv (DeepSeek FP8)
@@ -255,6 +256,7 @@ class FP8SafeTensorLoader(SafeTensorLoader):
     MOE_FORMATS = {
         "deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"),
         "mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"),
+        "mistral": ("{base}.experts", "w1", "w3", "w2"),
     }
 
     def __init__(self, file_path: str, scale_suffix: str = None):
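Each `MOE_FORMATS` entry is a `(path_template, gate, up, down)` tuple that the loader expands into per-expert tensor keys. A minimal sketch of that expansion; `expert_keys`, the layer prefix, and the expert index are illustrative, not part of the loader:

```python
MOE_FORMATS = {
    "deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"),
    "mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"),
    "mistral": ("{base}.experts", "w1", "w3", "w2"),
}

def expert_keys(fmt: str, base: str, expert_id: int) -> list[str]:
    """Build the gate/up/down tensor keys for one expert."""
    path_tpl, gate, up, down = MOE_FORMATS[fmt]
    prefix = path_tpl.format(base=base)
    return [f"{prefix}.{expert_id}.{proj}.weight" for proj in (gate, up, down)]

# e.g. Mistral expert 0 under layer 3:
# ['layers.3.experts.0.w1.weight', 'layers.3.experts.0.w3.weight',
#  'layers.3.experts.0.w2.weight']
print(expert_keys("mistral", "layers.3", 0))
```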
@@ -297,6 +299,10 @@ class FP8SafeTensorLoader(SafeTensorLoader):
                     self._detected_format = fmt_name
                     print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}")
                     break
+                elif fmt_name == "mistral" and ".mlp.experts" not in key and ".block_sparse_moe.experts" not in key:
+                    self._detected_format = fmt_name
+                    print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}")
+                    break
             if self._detected_format:
                 break
 
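The extra substring guards on the `mistral` branch matter because `.experts` is a substring of both `.mlp.experts` and `.block_sparse_moe.experts`, so a bare match on the generic Mistral pattern would also claim DeepSeek and Mixtral keys. A standalone sketch of that ordering, with `detect_format` as an illustrative helper rather than the loader's API:

```python
def detect_format(key: str) -> str | None:
    # More specific patterns first; plain ".experts." is claimed by
    # "mistral" only when no longer pattern matches the same key.
    if ".mlp.experts." in key:
        return "deepseek"
    if ".block_sparse_moe.experts." in key:
        return "mixtral"
    if ".experts." in key:
        return "mistral"
    return None

assert detect_format("model.layers.0.mlp.experts.0.gate_proj.weight") == "deepseek"
assert detect_format("model.layers.0.block_sparse_moe.experts.0.w1.weight") == "mixtral"
assert detect_format("layers.0.experts.0.w1.weight") == "mistral"
```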
@@ -321,8 +327,21 @@ class FP8SafeTensorLoader(SafeTensorLoader):
                 return
             elif f".{gate}.weight_scale" in key and "weight_scale_inv" not in key:
                 self._scale_suffix = "weight_scale"
-                self._is_per_channel = True
-                print("[FP8SafeTensorLoader] Detected scale format: per-channel (weight_scale)")
+                # Some models (e.g., Mistral) use block-wise FP8 scales but keep
+                # the key suffix as `weight_scale` (without `_inv`). Infer format
+                # from scale tensor shape instead of suffix alone:
+                # - per-channel: [N] or [N, 1]
+                # - block-wise: [N_block, K_block] (both dims > 1)
+                scale_tensor = self.load_tensor(key, device="cpu")
+                if scale_tensor.dim() == 1:
+                    self._is_per_channel = True
+                elif scale_tensor.dim() == 2 and scale_tensor.shape[1] == 1:
+                    self._is_per_channel = True
+                else:
+                    self._is_per_channel = False
+
+                scale_kind = "per-channel" if self._is_per_channel else "block-wise"
+                print(f"[FP8SafeTensorLoader] Detected scale format: {scale_kind} (weight_scale)")
                 return
         # Default to weight_scale_inv
         self._scale_suffix = "weight_scale_inv"
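Shape, not key suffix, now decides between per-channel and block-wise scales: a 1-D `[N]` tensor or an `[N, 1]` column is per-channel, while a 2-D grid with both dims greater than 1 is block-wise. A minimal sketch of the same rule with made-up shapes:

```python
import torch

def is_per_channel(scale: torch.Tensor) -> bool:
    """Classify an FP8 scale tensor by shape, mirroring the loader's rule."""
    if scale.dim() == 1:
        return True  # [N] -> per-channel
    if scale.dim() == 2 and scale.shape[1] == 1:
        return True  # [N, 1] -> per-channel (squeezable to [N])
    return False     # [N_block, K_block] -> block-wise

assert is_per_channel(torch.ones(4096))
assert is_per_channel(torch.ones(4096, 1))
assert not is_per_channel(torch.ones(32, 56))  # e.g., 128x128 block grid
```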
@@ -333,12 +352,20 @@ class FP8SafeTensorLoader(SafeTensorLoader):
             scale_type = "per-channel" if self._is_per_channel else "block-wise"
             print(f"[FP8SafeTensorLoader] Using explicit scale format: {scale_type} ({self._scale_suffix})")
 
-    def _get_experts_prefix(self, base_key: str) -> str:
-        """Get the experts prefix based on detected format."""
+    def _get_experts_prefix_candidates(self, base_key: str) -> list[str]:
+        """Get candidate experts prefixes based on detected format and base key variants."""
         path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format]
+        candidates = []
         if self._is_vl_model:
             base_key = base_key.replace("model.layers", "model.language_model.layers")
-        return path_tpl.format(base=base_key)
+        candidates.append(path_tpl.format(base=base_key))
+
+        # Some model weights (e.g., Mistral native format) do not have "model." prefix.
+        if base_key.startswith("model."):
+            candidates.append(path_tpl.format(base=base_key[len("model.") :]))
+
+        # Deduplicate while preserving order.
+        return list(dict.fromkeys(candidates))
 
     def _get_proj_names(self):
         """Get projection names (gate, up, down) based on detected format."""
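For a Mistral-style checkpoint whose keys lack the `model.` prefix, the method now returns both spellings. An illustrative trace of the new logic, assuming the detected format is `mistral` and `_is_vl_model` is false:

```python
# Hypothetical walk-through of _get_experts_prefix_candidates:
path_tpl = "{base}.experts"  # MOE_FORMATS["mistral"][0]
base_key = "model.layers.7"

candidates = [path_tpl.format(base=base_key)]
if base_key.startswith("model."):
    candidates.append(path_tpl.format(base=base_key[len("model.") :]))

# dict.fromkeys deduplicates while preserving insertion order.
assert list(dict.fromkeys(candidates)) == ["model.layers.7.experts", "layers.7.experts"]
```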
@@ -363,15 +390,21 @@ class FP8SafeTensorLoader(SafeTensorLoader):
         Supports both block-wise (weight_scale_inv) and per-channel (weight_scale) formats.
         Per-channel scales are squeezed from [N, 1] to [N] if needed.
         """
-        experts_prefix = self._get_experts_prefix(base_key)
+        experts_prefix_candidates = self._get_experts_prefix_candidates(base_key)
         gate_name, up_name, down_name = self._get_proj_names()
 
         expert_count = 0
-        while self.has_tensor(f"{experts_prefix}.{expert_count}.{gate_name}.weight"):
-            expert_count += 1
+        experts_prefix = None
+        for prefix in experts_prefix_candidates:
+            expert_count = 0
+            while self.has_tensor(f"{prefix}.{expert_count}.{gate_name}.weight"):
+                expert_count += 1
+            if expert_count > 0:
+                experts_prefix = prefix
+                break
 
-        if expert_count == 0:
-            raise ValueError(f"No experts found for key {experts_prefix}")
+        if expert_count == 0 or experts_prefix is None:
+            raise ValueError(f"No experts found for keys: {experts_prefix_candidates}")
 
         gate_weights = [None] * expert_count
         up_weights = [None] * expert_count
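The rewritten loop probes each candidate prefix in order and keeps the first one that resolves to at least one expert tensor. A self-contained sketch where a plain key set stands in for `has_tensor`:

```python
# Toy stand-in for self.has_tensor(): membership in a key set.
keys = {f"layers.7.experts.{i}.w1.weight" for i in range(8)}

experts_prefix = None
expert_count = 0
for prefix in ["model.layers.7.experts", "layers.7.experts"]:
    expert_count = 0
    while f"{prefix}.{expert_count}.w1.weight" in keys:
        expert_count += 1
    if expert_count > 0:
        experts_prefix = prefix  # first prefix with real tensors wins
        break

assert experts_prefix == "layers.7.experts" and expert_count == 8
```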
@@ -423,13 +456,13 @@ class FP8SafeTensorLoader(SafeTensorLoader):
         return self._is_per_channel
 
 
-
 class BF16SafeTensorLoader(SafeTensorLoader):
     """Loader for native BF16 expert weights (no quantization, no scales).
 
     Supported formats:
     - DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight
     - Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight
+    - Mistral style: {base}.experts.{id}.{w1,w3,w2}.weight
 
     The format is auto-detected during initialization.
     """
@@ -437,6 +470,7 @@ class BF16SafeTensorLoader(SafeTensorLoader):
     MOE_FORMATS = {
         "deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"),
         "mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"),
+        "mistral": ("{base}.experts", "w1", "w3", "w2"),
     }
 
     def __init__(self, file_path: str):
@@ -466,14 +500,24 @@ class BF16SafeTensorLoader(SafeTensorLoader):
                     self._detected_format = fmt_name
                     print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}")
                     return
+                elif fmt_name == "mistral" and ".mlp.experts" not in key and ".block_sparse_moe.experts" not in key:
+                    self._detected_format = fmt_name
+                    print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}")
+                    return
 
         self._detected_format = "deepseek"
         print("[BF16SafeTensorLoader] No MoE format detected, defaulting to: deepseek")
 
-    def _get_experts_prefix(self, base_key: str) -> str:
-        """Get the experts prefix based on detected format."""
+    def _get_experts_prefix_candidates(self, base_key: str) -> list[str]:
+        """Get candidate experts prefixes based on detected format and base key variants."""
         path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format]
-        return path_tpl.format(base=base_key)
+        candidates = [path_tpl.format(base=base_key)]
+
+        # Some model weights (e.g., Mistral native format) do not have "model." prefix.
+        if base_key.startswith("model."):
+            candidates.append(path_tpl.format(base=base_key[len("model.") :]))
+
+        return list(dict.fromkeys(candidates))
 
     def _get_proj_names(self):
         """Get projection names (gate, up, down) based on detected format."""
@@ -497,15 +541,21 @@ class BF16SafeTensorLoader(SafeTensorLoader):
         if self._detected_format == "packed":
            return self._load_experts_packed(base_key, device)
 
-        experts_prefix = self._get_experts_prefix(base_key)
+        experts_prefix_candidates = self._get_experts_prefix_candidates(base_key)
         gate_name, up_name, down_name = self._get_proj_names()
 
         expert_count = 0
-        while self.has_tensor(f"{experts_prefix}.{expert_count}.{gate_name}.weight"):
-            expert_count += 1
+        experts_prefix = None
+        for prefix in experts_prefix_candidates:
+            expert_count = 0
+            while self.has_tensor(f"{prefix}.{expert_count}.{gate_name}.weight"):
+                expert_count += 1
+            if expert_count > 0:
+                experts_prefix = prefix
+                break
 
-        if expert_count == 0:
-            raise ValueError(f"No experts found for key {experts_prefix}")
+        if expert_count == 0 or experts_prefix is None:
+            raise ValueError(f"No experts found for keys: {experts_prefix_candidates}")
 
         gate_weights = [None] * expert_count
         up_weights = [None] * expert_count