mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-20 06:18:59 +00:00
@@ -244,6 +244,10 @@ class FP8SafeTensorLoader(SafeTensorLoader):
|
||||
- DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight
|
||||
- Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight
|
||||
|
||||
Supported scale formats (auto-detected):
|
||||
- Block-wise: weight_scale_inv (DeepSeek FP8)
|
||||
- Per-channel: weight_scale (GLM-4.7-FP8)
|
||||
|
||||
The format is auto-detected during initialization.
|
||||
"""
|
||||
|
||||
@@ -253,13 +257,28 @@ class FP8SafeTensorLoader(SafeTensorLoader):
|
||||
"mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"),
|
||||
}
|
||||
|
||||
def __init__(self, file_path: str):
|
||||
def __init__(self, file_path: str, scale_suffix: str = None):
|
||||
"""Initialize FP8 loader with optional scale suffix override.
|
||||
|
||||
Args:
|
||||
file_path: Path to safetensor files
|
||||
scale_suffix: Optional scale key suffix. If None, auto-detect between
|
||||
'weight_scale_inv' (block-wise) and 'weight_scale' (per-channel).
|
||||
"""
|
||||
super().__init__(file_path)
|
||||
self._detected_format = None
|
||||
self._scale_suffix = scale_suffix # None means auto-detect
|
||||
# Set per_channel based on explicit scale_suffix if provided
|
||||
if scale_suffix == "weight_scale":
|
||||
self._is_per_channel = True
|
||||
elif scale_suffix == "weight_scale_inv":
|
||||
self._is_per_channel = False
|
||||
else:
|
||||
self._is_per_channel = False # Will be updated in _detect_format if auto-detect
|
||||
self._detect_format()
|
||||
|
||||
def _detect_format(self):
|
||||
"""Auto-detect the MoE naming format by checking tensor keys."""
|
||||
"""Auto-detect the MoE naming format and scale format by checking tensor keys."""
|
||||
# Sample some tensor names to detect format
|
||||
sample_keys = list(self.tensor_file_map.keys())[:1000]
|
||||
|
||||
@@ -272,15 +291,42 @@ class FP8SafeTensorLoader(SafeTensorLoader):
|
||||
if "block_sparse_moe.experts" in key and fmt_name == "mixtral":
|
||||
self._detected_format = fmt_name
|
||||
print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}")
|
||||
return
|
||||
break
|
||||
elif "mlp.experts" in key and "block_sparse_moe" not in key and fmt_name == "deepseek":
|
||||
self._detected_format = fmt_name
|
||||
print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}")
|
||||
return
|
||||
break
|
||||
if self._detected_format:
|
||||
break
|
||||
|
||||
# Default to deepseek if no format detected
|
||||
self._detected_format = "deepseek"
|
||||
print("[FP8SafeTensorLoader] No MoE format detected, defaulting to: deepseek")
|
||||
if not self._detected_format:
|
||||
self._detected_format = "deepseek"
|
||||
print("[FP8SafeTensorLoader] No MoE format detected, defaulting to: deepseek")
|
||||
|
||||
# Auto-detect scale suffix if not specified
|
||||
if self._scale_suffix is None:
|
||||
_, gate, _, _ = self.MOE_FORMATS[self._detected_format]
|
||||
# Check for per-channel scale (weight_scale) vs block-wise (weight_scale_inv)
|
||||
for key in sample_keys:
|
||||
if f".{gate}.weight_scale_inv" in key:
|
||||
self._scale_suffix = "weight_scale_inv"
|
||||
self._is_per_channel = False
|
||||
print("[FP8SafeTensorLoader] Detected scale format: block-wise (weight_scale_inv)")
|
||||
return
|
||||
elif f".{gate}.weight_scale" in key and "weight_scale_inv" not in key:
|
||||
self._scale_suffix = "weight_scale"
|
||||
self._is_per_channel = True
|
||||
print("[FP8SafeTensorLoader] Detected scale format: per-channel (weight_scale)")
|
||||
return
|
||||
# Default to weight_scale_inv
|
||||
self._scale_suffix = "weight_scale_inv"
|
||||
self._is_per_channel = False
|
||||
print("[FP8SafeTensorLoader] No scale format detected, defaulting to: weight_scale_inv")
|
||||
else:
|
||||
# Scale suffix was explicitly provided
|
||||
scale_type = "per-channel" if self._is_per_channel else "block-wise"
|
||||
print(f"[FP8SafeTensorLoader] Using explicit scale format: {scale_type} ({self._scale_suffix})")
|
||||
|
||||
def _get_experts_prefix(self, base_key: str) -> str:
|
||||
"""Get the experts prefix based on detected format."""
|
||||
@@ -305,7 +351,11 @@ class FP8SafeTensorLoader(SafeTensorLoader):
|
||||
return tensor.to(device)
|
||||
|
||||
def load_experts(self, base_key: str, device: str = "cpu"):
|
||||
"""Load FP8 expert weights and their block-wise scale_inv tensors."""
|
||||
"""Load FP8 expert weights and their scale tensors.
|
||||
|
||||
Supports both block-wise (weight_scale_inv) and per-channel (weight_scale) formats.
|
||||
Per-channel scales are squeezed from [N, 1] to [N] if needed.
|
||||
"""
|
||||
experts_prefix = self._get_experts_prefix(base_key)
|
||||
gate_name, up_name, down_name = self._get_proj_names()
|
||||
|
||||
@@ -327,16 +377,30 @@ class FP8SafeTensorLoader(SafeTensorLoader):
|
||||
gate_w_key = f"{experts_prefix}.{exp_id}.{gate_name}.weight"
|
||||
up_w_key = f"{experts_prefix}.{exp_id}.{up_name}.weight"
|
||||
down_w_key = f"{experts_prefix}.{exp_id}.{down_name}.weight"
|
||||
gate_s_key = f"{experts_prefix}.{exp_id}.{gate_name}.weight_scale_inv"
|
||||
up_s_key = f"{experts_prefix}.{exp_id}.{up_name}.weight_scale_inv"
|
||||
down_s_key = f"{experts_prefix}.{exp_id}.{down_name}.weight_scale_inv"
|
||||
gate_s_key = f"{experts_prefix}.{exp_id}.{gate_name}.{self._scale_suffix}"
|
||||
up_s_key = f"{experts_prefix}.{exp_id}.{up_name}.{self._scale_suffix}"
|
||||
down_s_key = f"{experts_prefix}.{exp_id}.{down_name}.{self._scale_suffix}"
|
||||
|
||||
gate_weights[exp_id] = self.load_tensor(gate_w_key, device).contiguous()
|
||||
up_weights[exp_id] = self.load_tensor(up_w_key, device).contiguous()
|
||||
down_weights[exp_id] = self.load_tensor(down_w_key, device).contiguous()
|
||||
gate_scales[exp_id] = self.load_tensor(gate_s_key, device).contiguous()
|
||||
up_scales[exp_id] = self.load_tensor(up_s_key, device).contiguous()
|
||||
down_scales[exp_id] = self.load_tensor(down_s_key, device).contiguous()
|
||||
|
||||
gate_scale = self.load_tensor(gate_s_key, device)
|
||||
up_scale = self.load_tensor(up_s_key, device)
|
||||
down_scale = self.load_tensor(down_s_key, device)
|
||||
|
||||
# For per-channel scales, squeeze [N, 1] -> [N] if needed
|
||||
if self._is_per_channel:
|
||||
if gate_scale.dim() == 2 and gate_scale.shape[1] == 1:
|
||||
gate_scale = gate_scale.squeeze(1)
|
||||
if up_scale.dim() == 2 and up_scale.shape[1] == 1:
|
||||
up_scale = up_scale.squeeze(1)
|
||||
if down_scale.dim() == 2 and down_scale.shape[1] == 1:
|
||||
down_scale = down_scale.squeeze(1)
|
||||
|
||||
gate_scales[exp_id] = gate_scale.contiguous()
|
||||
up_scales[exp_id] = up_scale.contiguous()
|
||||
down_scales[exp_id] = down_scale.contiguous()
|
||||
|
||||
return {
|
||||
"gate": gate_weights,
|
||||
@@ -347,6 +411,103 @@ class FP8SafeTensorLoader(SafeTensorLoader):
|
||||
"down_scale": down_scales,
|
||||
}
|
||||
|
||||
def is_per_channel(self) -> bool:
|
||||
"""Return True if using per-channel quantization, False for block-wise."""
|
||||
return self._is_per_channel
|
||||
|
||||
|
||||
class BF16SafeTensorLoader(SafeTensorLoader):
|
||||
"""Loader for native BF16 expert weights (no quantization, no scales).
|
||||
|
||||
Supported formats:
|
||||
- DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight
|
||||
- Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight
|
||||
|
||||
The format is auto-detected during initialization.
|
||||
"""
|
||||
|
||||
MOE_FORMATS = {
|
||||
"deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"),
|
||||
"mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"),
|
||||
}
|
||||
|
||||
def __init__(self, file_path: str):
|
||||
super().__init__(file_path)
|
||||
self._detected_format = None
|
||||
self._detect_format()
|
||||
|
||||
def _detect_format(self):
|
||||
"""Auto-detect the MoE naming format by checking tensor keys."""
|
||||
sample_keys = list(self.tensor_file_map.keys())[:1000]
|
||||
|
||||
for fmt_name, (path_tpl, gate, up, down) in self.MOE_FORMATS.items():
|
||||
for key in sample_keys:
|
||||
if ".experts." in key and f".{gate}.weight" in key:
|
||||
if "block_sparse_moe.experts" in key and fmt_name == "mixtral":
|
||||
self._detected_format = fmt_name
|
||||
print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}")
|
||||
return
|
||||
elif "mlp.experts" in key and "block_sparse_moe" not in key and fmt_name == "deepseek":
|
||||
self._detected_format = fmt_name
|
||||
print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}")
|
||||
return
|
||||
|
||||
self._detected_format = "deepseek"
|
||||
print("[BF16SafeTensorLoader] No MoE format detected, defaulting to: deepseek")
|
||||
|
||||
def _get_experts_prefix(self, base_key: str) -> str:
|
||||
"""Get the experts prefix based on detected format."""
|
||||
path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format]
|
||||
return path_tpl.format(base=base_key)
|
||||
|
||||
def _get_proj_names(self):
|
||||
"""Get projection names (gate, up, down) based on detected format."""
|
||||
_, gate, up, down = self.MOE_FORMATS[self._detected_format]
|
||||
return gate, up, down
|
||||
|
||||
def load_tensor(self, key: str, device: str = "cpu"):
|
||||
if key not in self.tensor_file_map:
|
||||
raise KeyError(f"Key {key} not found in Safetensor files")
|
||||
file = self.tensor_file_map[key]
|
||||
f = self.file_handle_map.get(file)
|
||||
if f is None:
|
||||
raise FileNotFoundError(f"File {file} not found in Safetensor files")
|
||||
tensor = f.get_tensor(key)
|
||||
if device == "cpu":
|
||||
return tensor
|
||||
return tensor.to(device)
|
||||
|
||||
def load_experts(self, base_key: str, device: str = "cpu"):
|
||||
"""Load BF16 expert weights (no scales needed)."""
|
||||
experts_prefix = self._get_experts_prefix(base_key)
|
||||
gate_name, up_name, down_name = self._get_proj_names()
|
||||
|
||||
expert_count = 0
|
||||
while self.has_tensor(f"{experts_prefix}.{expert_count}.{gate_name}.weight"):
|
||||
expert_count += 1
|
||||
|
||||
if expert_count == 0:
|
||||
raise ValueError(f"No experts found for key {experts_prefix}")
|
||||
|
||||
gate_weights = [None] * expert_count
|
||||
up_weights = [None] * expert_count
|
||||
down_weights = [None] * expert_count
|
||||
|
||||
for exp_id in range(expert_count):
|
||||
gate_w_key = f"{experts_prefix}.{exp_id}.{gate_name}.weight"
|
||||
up_w_key = f"{experts_prefix}.{exp_id}.{up_name}.weight"
|
||||
down_w_key = f"{experts_prefix}.{exp_id}.{down_name}.weight"
|
||||
|
||||
gate_weights[exp_id] = self.load_tensor(gate_w_key, device).contiguous()
|
||||
up_weights[exp_id] = self.load_tensor(up_w_key, device).contiguous()
|
||||
down_weights[exp_id] = self.load_tensor(down_w_key, device).contiguous()
|
||||
|
||||
return {
|
||||
"gate": gate_weights,
|
||||
"up": up_weights,
|
||||
"down": down_weights,
|
||||
}
|
||||
|
||||
|
||||
class BF16SafeTensorLoader(SafeTensorLoader):
|
||||
"""Loader for native BF16 expert weights (no quantization, no scales).
|
||||
|
||||
Reference in New Issue
Block a user