support GLM 4.7 (#1791)

support GLM 4.7
Oql
2026-01-13 17:36:25 +08:00
committed by GitHub
parent 667030d6e6
commit 6277da4c2b
14 changed files with 2336 additions and 144 deletions


@@ -244,6 +244,10 @@ class FP8SafeTensorLoader(SafeTensorLoader):
- DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight
- Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight
Supported scale formats (auto-detected):
- Block-wise: weight_scale_inv (DeepSeek FP8)
- Per-channel: weight_scale (GLM-4.7-FP8)
The format is auto-detected during initialization.
"""
@@ -253,13 +257,28 @@ class FP8SafeTensorLoader(SafeTensorLoader):
"mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"),
}
-def __init__(self, file_path: str):
+def __init__(self, file_path: str, scale_suffix: str = None):
"""Initialize FP8 loader with optional scale suffix override.
Args:
file_path: Path to safetensor files
scale_suffix: Optional scale key suffix. If None, auto-detect between
'weight_scale_inv' (block-wise) and 'weight_scale' (per-channel).
"""
super().__init__(file_path)
self._detected_format = None
self._scale_suffix = scale_suffix # None means auto-detect
# Set per_channel based on explicit scale_suffix if provided
if scale_suffix == "weight_scale":
self._is_per_channel = True
elif scale_suffix == "weight_scale_inv":
self._is_per_channel = False
else:
self._is_per_channel = False # Will be updated in _detect_format if auto-detect
self._detect_format()
def _detect_format(self):
"""Auto-detect the MoE naming format by checking tensor keys."""
"""Auto-detect the MoE naming format and scale format by checking tensor keys."""
# Sample some tensor names to detect format
sample_keys = list(self.tensor_file_map.keys())[:1000]
@@ -272,15 +291,42 @@ class FP8SafeTensorLoader(SafeTensorLoader):
if "block_sparse_moe.experts" in key and fmt_name == "mixtral":
self._detected_format = fmt_name
print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}")
-return
+break
elif "mlp.experts" in key and "block_sparse_moe" not in key and fmt_name == "deepseek":
self._detected_format = fmt_name
print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}")
-return
+break
if self._detected_format:
break
# Default to deepseek if no format detected
self._detected_format = "deepseek"
print("[FP8SafeTensorLoader] No MoE format detected, defaulting to: deepseek")
if not self._detected_format:
self._detected_format = "deepseek"
print("[FP8SafeTensorLoader] No MoE format detected, defaulting to: deepseek")
# Auto-detect scale suffix if not specified
if self._scale_suffix is None:
_, gate, _, _ = self.MOE_FORMATS[self._detected_format]
# Check for per-channel scale (weight_scale) vs block-wise (weight_scale_inv)
for key in sample_keys:
if f".{gate}.weight_scale_inv" in key:
self._scale_suffix = "weight_scale_inv"
self._is_per_channel = False
print("[FP8SafeTensorLoader] Detected scale format: block-wise (weight_scale_inv)")
return
elif f".{gate}.weight_scale" in key and "weight_scale_inv" not in key:
self._scale_suffix = "weight_scale"
self._is_per_channel = True
print("[FP8SafeTensorLoader] Detected scale format: per-channel (weight_scale)")
return
# Default to weight_scale_inv
self._scale_suffix = "weight_scale_inv"
self._is_per_channel = False
print("[FP8SafeTensorLoader] No scale format detected, defaulting to: weight_scale_inv")
else:
# Scale suffix was explicitly provided
scale_type = "per-channel" if self._is_per_channel else "block-wise"
print(f"[FP8SafeTensorLoader] Using explicit scale format: {scale_type} ({self._scale_suffix})")
def _get_experts_prefix(self, base_key: str) -> str:
"""Get the experts prefix based on detected format."""
@@ -305,7 +351,11 @@ class FP8SafeTensorLoader(SafeTensorLoader):
return tensor.to(device)
def load_experts(self, base_key: str, device: str = "cpu"):
"""Load FP8 expert weights and their block-wise scale_inv tensors."""
"""Load FP8 expert weights and their scale tensors.
Supports both block-wise (weight_scale_inv) and per-channel (weight_scale) formats.
Per-channel scales are squeezed from [N, 1] to [N] if needed.
"""
experts_prefix = self._get_experts_prefix(base_key)
gate_name, up_name, down_name = self._get_proj_names()
@@ -327,16 +377,30 @@ class FP8SafeTensorLoader(SafeTensorLoader):
gate_w_key = f"{experts_prefix}.{exp_id}.{gate_name}.weight"
up_w_key = f"{experts_prefix}.{exp_id}.{up_name}.weight"
down_w_key = f"{experts_prefix}.{exp_id}.{down_name}.weight"
gate_s_key = f"{experts_prefix}.{exp_id}.{gate_name}.weight_scale_inv"
up_s_key = f"{experts_prefix}.{exp_id}.{up_name}.weight_scale_inv"
down_s_key = f"{experts_prefix}.{exp_id}.{down_name}.weight_scale_inv"
gate_s_key = f"{experts_prefix}.{exp_id}.{gate_name}.{self._scale_suffix}"
up_s_key = f"{experts_prefix}.{exp_id}.{up_name}.{self._scale_suffix}"
down_s_key = f"{experts_prefix}.{exp_id}.{down_name}.{self._scale_suffix}"
gate_weights[exp_id] = self.load_tensor(gate_w_key, device).contiguous()
up_weights[exp_id] = self.load_tensor(up_w_key, device).contiguous()
down_weights[exp_id] = self.load_tensor(down_w_key, device).contiguous()
-gate_scales[exp_id] = self.load_tensor(gate_s_key, device).contiguous()
-up_scales[exp_id] = self.load_tensor(up_s_key, device).contiguous()
-down_scales[exp_id] = self.load_tensor(down_s_key, device).contiguous()
+gate_scale = self.load_tensor(gate_s_key, device)
+up_scale = self.load_tensor(up_s_key, device)
+down_scale = self.load_tensor(down_s_key, device)
+# For per-channel scales, squeeze [N, 1] -> [N] if needed
+if self._is_per_channel:
+if gate_scale.dim() == 2 and gate_scale.shape[1] == 1:
+gate_scale = gate_scale.squeeze(1)
+if up_scale.dim() == 2 and up_scale.shape[1] == 1:
+up_scale = up_scale.squeeze(1)
+if down_scale.dim() == 2 and down_scale.shape[1] == 1:
+down_scale = down_scale.squeeze(1)
+gate_scales[exp_id] = gate_scale.contiguous()
+up_scales[exp_id] = up_scale.contiguous()
+down_scales[exp_id] = down_scale.contiguous()
return {
"gate": gate_weights,
@@ -347,6 +411,103 @@ class FP8SafeTensorLoader(SafeTensorLoader):
"down_scale": down_scales,
}
def is_per_channel(self) -> bool:
"""Return True if using per-channel quantization, False for block-wise."""
return self._is_per_channel
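Continuing the construction sketch above, a hedged example of how a caller might consume the result; the base key and expected scale shapes are illustrative assumptions, not taken from this commit:

experts = auto_loader.load_experts("model.layers.3", device="cpu")  # hypothetical base key
gate_w = experts["gate"][0]        # FP8 weight of expert 0's gate projection
gate_s = experts["gate_scale"][0]  # matching scale tensor
if auto_loader.is_per_channel():
    # GLM-4.7 style: one scale per output row, squeezed by the loader to shape [N]
    assert gate_s.dim() == 1
else:
    # DeepSeek style: 2-D block-wise scale_inv grid, e.g. one scale per weight block
    assert gate_s.dim() == 2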
class BF16SafeTensorLoader(SafeTensorLoader):
"""Loader for native BF16 expert weights (no quantization, no scales).
Supported formats:
- DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight
- Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight
The format is auto-detected during initialization.
"""
MOE_FORMATS = {
"deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"),
"mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"),
}
def __init__(self, file_path: str):
super().__init__(file_path)
self._detected_format = None
self._detect_format()
def _detect_format(self):
"""Auto-detect the MoE naming format by checking tensor keys."""
sample_keys = list(self.tensor_file_map.keys())[:1000]
for fmt_name, (path_tpl, gate, up, down) in self.MOE_FORMATS.items():
for key in sample_keys:
if ".experts." in key and f".{gate}.weight" in key:
if "block_sparse_moe.experts" in key and fmt_name == "mixtral":
self._detected_format = fmt_name
print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}")
return
elif "mlp.experts" in key and "block_sparse_moe" not in key and fmt_name == "deepseek":
self._detected_format = fmt_name
print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}")
return
self._detected_format = "deepseek"
print("[BF16SafeTensorLoader] No MoE format detected, defaulting to: deepseek")
def _get_experts_prefix(self, base_key: str) -> str:
"""Get the experts prefix based on detected format."""
path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format]
return path_tpl.format(base=base_key)
def _get_proj_names(self):
"""Get projection names (gate, up, down) based on detected format."""
_, gate, up, down = self.MOE_FORMATS[self._detected_format]
return gate, up, down
def load_tensor(self, key: str, device: str = "cpu"):
if key not in self.tensor_file_map:
raise KeyError(f"Key {key} not found in Safetensor files")
file = self.tensor_file_map[key]
f = self.file_handle_map.get(file)
if f is None:
raise FileNotFoundError(f"File {file} not found in Safetensor files")
tensor = f.get_tensor(key)
if device == "cpu":
return tensor
return tensor.to(device)
def load_experts(self, base_key: str, device: str = "cpu"):
"""Load BF16 expert weights (no scales needed)."""
experts_prefix = self._get_experts_prefix(base_key)
gate_name, up_name, down_name = self._get_proj_names()
expert_count = 0
while self.has_tensor(f"{experts_prefix}.{expert_count}.{gate_name}.weight"):
expert_count += 1
if expert_count == 0:
raise ValueError(f"No experts found for key {experts_prefix}")
gate_weights = [None] * expert_count
up_weights = [None] * expert_count
down_weights = [None] * expert_count
for exp_id in range(expert_count):
gate_w_key = f"{experts_prefix}.{exp_id}.{gate_name}.weight"
up_w_key = f"{experts_prefix}.{exp_id}.{up_name}.weight"
down_w_key = f"{experts_prefix}.{exp_id}.{down_name}.weight"
gate_weights[exp_id] = self.load_tensor(gate_w_key, device).contiguous()
up_weights[exp_id] = self.load_tensor(up_w_key, device).contiguous()
down_weights[exp_id] = self.load_tensor(down_w_key, device).contiguous()
return {
"gate": gate_weights,
"up": up_weights,
"down": down_weights,
}
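And the analogous sketch for the BF16 loader, which returns only raw weights with no scale entries (path and base key remain placeholders):

bf16_loader = BF16SafeTensorLoader("/models/GLM-4.7")    # placeholder path
bf16_experts = bf16_loader.load_experts("model.layers.3")
print(len(bf16_experts["gate"]))                          # number of experts in the layer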