[feat](kt-kernel): refactor convert_cpu_weights.py to support conversation for GLM-4.6V (#1687)

Signed-off-by: mrhaoxx <mr.haoxx@gmail.com>
This commit is contained in:
mrhaoxx
2025-12-09 14:24:41 +08:00
committed by GitHub
parent ac69ea891e
commit 503295fc88

View File

@@ -75,10 +75,8 @@ def load_model_config(input_path: str, input_type: str = None) -> Dict:
if "text_config" in config:
text_cfg = config["text_config"]
kt_cvt_type = "vl"
else:
text_cfg = config
kt_cvt_type = "base"
# Extract required fields with fallbacks
model_config = {
@@ -86,11 +84,10 @@ def load_model_config(input_path: str, input_type: str = None) -> Dict:
"num_experts_per_tok": text_cfg.get("num_experts_per_tok", 2),
"hidden_size": text_cfg.get("hidden_size"),
"moe_intermediate_size": text_cfg.get("moe_intermediate_size", text_cfg.get("intermediate_size")),
"_kt_cvt_type": kt_cvt_type,
}
# Validate required fields
missing_fields = [k for k, v in model_config.items() if k != "_kt_cvt_type" and v is None]
missing_fields = [k for k, v in model_config.items() if v is None]
if missing_fields:
raise ValueError(f"Missing required config fields: {missing_fields}")
@@ -120,8 +117,6 @@ def load_model_config(input_path: str, input_type: str = None) -> Dict:
print(f"FP8 quantization config detected:")
print(f" format: {quant_config.get('fmt', 'unknown')}")
print(f" weight_block_size: {weight_block_size}")
print(f"Model Type: {model_config['_kt_cvt_type']}")
return model_config
@@ -262,6 +257,7 @@ class ConverterBase:
self.input_type = input_type
self.merge_to_safetensor = merge_to_safetensor
self.tensor_file_map: Dict[str, str] = {} # key -> filename
self.tensor_key_map: Dict[str, str] = {} # old key -> new key
self.file_handle_map: Dict[str, any] = {} # filename -> file
# Extract commonly used config values for convenience
@@ -269,7 +265,7 @@ class ConverterBase:
self.num_experts_per_tok = model_config["num_experts_per_tok"]
self.hidden_size = model_config["hidden_size"]
self.moe_intermediate_size = model_config["moe_intermediate_size"]
self.kt_cvt_type = model_config.get("_kt_cvt_type", "base")
self.layout = "base"
# Load input safetensors files
self._load_input_files()
@@ -288,9 +284,19 @@ class ConverterBase:
try:
handle = safe_open(file_path, framework="pt")
self.file_handle_map[file] = handle
renamed = False
for key in handle.keys():
self.tensor_file_map[key] = file
print(f" Loaded: {file} ({len(list(handle.keys()))} tensors)")
if "language_model" in key:
key_ = key.replace("language_model.", "")
# print(" Renaming key:", key, "->", key_)
renamed = True
else:
key_ = key
self.tensor_key_map[key_] = key
self.tensor_file_map[key_] = file
print(
f" Loaded: {file} ({len(list(handle.keys()))} tensors){' (renamed keys)' if renamed else ''}"
)
except Exception as e:
print(f" Error loading {file}: {e}")
@@ -306,22 +312,26 @@ class ConverterBase:
file = self.tensor_file_map[key]
handle = self.file_handle_map[file]
return handle.get_tensor(key)
return handle.get_tensor(self.tensor_key_map.get(key, key))
# layers_id -> list[experts_id]
def _find_expert_layers(self) -> Dict[int, List[int]]:
"""Find all layers and experts in the model"""
layers = defaultdict(set)
# vl weights have a fused layout
# Pattern: model.language_model.layers.{layer}.mlp.experts.{proj}
if self.kt_cvt_type == "vl":
# detect layout
for key in self.tensor_file_map.keys():
if "mlp.experts" in key and "gate_up" in key:
self.layout = "fused"
break
if self.layout == "fused": # Pattern: model.layers.{layer}.mlp.experts.{proj}
layers = set()
for key in self.tensor_file_map.keys():
if "model.language_model.layers." in key and ".mlp.experts." in key:
if "model.layers." in key and ".mlp.experts." in key:
parts = key.split(".")
if len(parts) >= 7:
layer_idx = int(parts[3])
if len(parts) >= 6:
layer_idx = int(parts[2])
layers.add(layer_idx)
result: Dict[int, List[int]] = {}
@@ -703,42 +713,41 @@ class OnlineQuantConverter(ConverterBase):
def _convert_layer_experts(self, layer_idx: int, expert_ids: List[int]) -> Dict[str, torch.Tensor]:
"""Convert all experts in a layer using online quantization via AMXMoEWrapper"""
start_time = time.time()
print(f"Converting layer {layer_idx} with {len(expert_ids) if self.kt_cvt_type == 'base' else 'fused'} experts via online quantization...")
print(
f"Converting layer {layer_idx} with {len(expert_ids) if self.layout == 'base' else 'fused'} experts via online quantization..."
)
# Load all expert weights for this layer
if self.kt_cvt_type == "vl":
if self.layout == "fused":
if self.input_type not in ["bf16", "fp16"]:
raise ValueError(f"VL path currently supports bf16/fp16 only, got input_type={self.input_type}")
raise ValueError(f"Fused path currently supports bf16/fp16 only, got input_type={self.input_type}")
proj_set = set()
prefix = f"model.language_model.layers.{layer_idx}.mlp.experts."
prefix = f"model.layers.{layer_idx}.mlp.experts."
for key in self.tensor_file_map.keys():
if key.startswith(prefix):
parts = key.split(".")
if len(parts) >= 7:
proj_set.add(parts[6])
if len(parts) >= 6:
proj_set.add(parts[5])
if not proj_set:
raise ValueError(
f"[VL] No fused MoE experts found for layer {layer_idx} under 'model.language_model.layers'"
)
raise ValueError(f"[Fused] No fused MoE experts found for layer {layer_idx} under 'model.layers'")
projs = sorted(proj_set)
print(f" [VL] layer {layer_idx} fused proj keys: {projs}")
print(f" [Fused] layer {layer_idx} fused proj keys: {projs}")
if len(projs) < 2:
raise ValueError(
f"[VL] Expect at least 2 fused tensors (down & gate_up) in layer {layer_idx}, got {len(projs)}"
f"[Fused] Expect at least 2 fused tensors (down & gate_up) in layer {layer_idx}, got {len(projs)}"
)
fused_tensors = []
for p in projs:
key = f"model.language_model.layers.{layer_idx}.mlp.experts.{p}"
key = f"model.layers.{layer_idx}.mlp.experts.{p}"
if key not in self.tensor_file_map:
raise KeyError(f"[VL] Missing fused tensor {key} for layer {layer_idx}")
raise KeyError(f"[Fused] Missing fused tensor {key} for layer {layer_idx}")
w = self._load_tensor(key)
if self.input_type == "fp16":
w = w.to(torch.bfloat16)
print(f" [VL] tensor {p} shape: {tuple(w.shape)}")
print(f" [Fused] tensor {p} shape: {tuple(w.shape)}")
fused_tensors.append(w)
# fused_tensors[0] : down-like, [E, I, H]
@@ -748,23 +757,23 @@ class OnlineQuantConverter(ConverterBase):
# gate_up_fused: [E, H, 2I] -> [E, 2I, H] -> gate / up
if gate_up_fused.dim() != 3:
raise ValueError(f"[VL] Expect gate_up fused tensor to be 3D, got shape {tuple(gate_up_fused.shape)}")
raise ValueError(
f"[Fused] Expect gate_up fused tensor to be 3D, got shape {tuple(gate_up_fused.shape)}"
)
E, H, twoI = gate_up_fused.shape
if twoI % 2 != 0:
raise ValueError(f"[VL] gate_up last dim (2I) not even: {twoI}")
raise ValueError(f"[Fused] gate_up last dim (2I) not even: {twoI}")
I = twoI // 2
gate_up_T = gate_up_fused.transpose(1, 2).contiguous() # [E, 2I, H]
gate_proj = gate_up_T[:, :I, :] # [E, I, H]
up_proj = gate_up_T[:, I:, :] # [E, I, H]
gate_up_T = gate_up_fused.transpose(1, 2).contiguous() # [E, 2I, H]
gate_proj = gate_up_T[:, :I, :] # [E, I, H]
up_proj = gate_up_T[:, I:, :] # [E, I, H]
if down_fused.dim() != 3:
raise ValueError(f"[VL] Expect down fused tensor to be 3D, got shape {tuple(down_fused.shape)}")
raise ValueError(f"[Fused] Expect down fused tensor to be 3D, got shape {tuple(down_fused.shape)}")
if down_fused.shape[0] != E:
raise ValueError(
f"[VL] down_fused expert dim mismatch: {down_fused.shape[0]} vs gate_up {E}"
)
down_proj = down_fused.transpose(1, 2).contiguous() # [E, H, I]
raise ValueError(f"[Fused] down_fused expert dim mismatch: {down_fused.shape[0]} vs gate_up {E}")
down_proj = down_fused.transpose(1, 2).contiguous() # [E, H, I]
del fused_tensors
del gate_up_fused
del down_fused
@@ -838,7 +847,6 @@ class OnlineQuantConverter(ConverterBase):
down_proj = torch.stack(down_weights, dim=0).contiguous()
del gate_weights, up_weights, down_weights
print(f" Loaded weights shapes:")
print(f" gate_proj: {gate_proj.shape}")
print(f" up_proj: {up_proj.shape}")
@@ -877,7 +885,7 @@ class OnlineQuantConverter(ConverterBase):
# This triggers the quantization process and saves to disk
wrapper.load_weights_from_tensors(gate_proj, up_proj, down_proj, physical_to_logical_map)
# Clean up to free memory
# Clean up to free memory
del gate_proj, up_proj, down_proj
gc.collect()