Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2026-04-20 14:29:22 +00:00
[feat](kt-kernel): refactor convert_cpu_weights.py to support conversion for GLM-4.6V (#1687)
Signed-off-by: mrhaoxx <mr.haoxx@gmail.com>
@@ -75,10 +75,8 @@ def load_model_config(input_path: str, input_type: str = None) -> Dict:
 
     if "text_config" in config:
         text_cfg = config["text_config"]
-        kt_cvt_type = "vl"
     else:
         text_cfg = config
-        kt_cvt_type = "base"
 
     # Extract required fields with fallbacks
     model_config = {
@@ -86,11 +84,10 @@ def load_model_config(input_path: str, input_type: str = None) -> Dict:
         "num_experts_per_tok": text_cfg.get("num_experts_per_tok", 2),
         "hidden_size": text_cfg.get("hidden_size"),
         "moe_intermediate_size": text_cfg.get("moe_intermediate_size", text_cfg.get("intermediate_size")),
-        "_kt_cvt_type": kt_cvt_type,
     }
 
     # Validate required fields
-    missing_fields = [k for k, v in model_config.items() if k != "_kt_cvt_type" and v is None]
+    missing_fields = [k for k, v in model_config.items() if v is None]
     if missing_fields:
         raise ValueError(f"Missing required config fields: {missing_fields}")
 
@@ -120,8 +117,6 @@ def load_model_config(input_path: str, input_type: str = None) -> Dict:
         print(f"FP8 quantization config detected:")
         print(f" format: {quant_config.get('fmt', 'unknown')}")
         print(f" weight_block_size: {weight_block_size}")
-
-    print(f"Model Type: {model_config['_kt_cvt_type']}")
     return model_config
 
 
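
For orientation, a minimal standalone sketch of the config handling above: VL-style checkpoints (such as GLM-4.6V) nest the language-model fields under "text_config", while base checkpoints keep them at the top level; required fields are then pulled with fallbacks and validated. The sample configs and values below are made up for illustration.

# Hypothetical configs; the field names mirror the ones used above, the values are invented.
vl_config = {"text_config": {"hidden_size": 4096, "num_experts_per_tok": 8, "moe_intermediate_size": 1408}}
base_config = {"hidden_size": 4096, "num_experts_per_tok": 8, "intermediate_size": 11008}

for config in (vl_config, base_config):
    # VL checkpoints nest the language model under "text_config"; base checkpoints do not.
    text_cfg = config.get("text_config", config)
    model_config = {
        "num_experts_per_tok": text_cfg.get("num_experts_per_tok", 2),
        "hidden_size": text_cfg.get("hidden_size"),
        "moe_intermediate_size": text_cfg.get("moe_intermediate_size", text_cfg.get("intermediate_size")),
    }
    missing_fields = [k for k, v in model_config.items() if v is None]
    if missing_fields:
        raise ValueError(f"Missing required config fields: {missing_fields}")
    print(model_config)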
@@ -262,6 +257,7 @@ class ConverterBase:
         self.input_type = input_type
         self.merge_to_safetensor = merge_to_safetensor
         self.tensor_file_map: Dict[str, str] = {} # key -> filename
+        self.tensor_key_map: Dict[str, str] = {} # old key -> new key
         self.file_handle_map: Dict[str, any] = {} # filename -> file
 
         # Extract commonly used config values for convenience
@@ -269,7 +265,7 @@ class ConverterBase:
         self.num_experts_per_tok = model_config["num_experts_per_tok"]
         self.hidden_size = model_config["hidden_size"]
         self.moe_intermediate_size = model_config["moe_intermediate_size"]
-        self.kt_cvt_type = model_config.get("_kt_cvt_type", "base")
+        self.layout = "base"
 
         # Load input safetensors files
         self._load_input_files()
@@ -288,9 +284,19 @@ class ConverterBase:
             try:
                 handle = safe_open(file_path, framework="pt")
                 self.file_handle_map[file] = handle
+                renamed = False
                 for key in handle.keys():
-                    self.tensor_file_map[key] = file
-                print(f" Loaded: {file} ({len(list(handle.keys()))} tensors)")
+                    if "language_model" in key:
+                        key_ = key.replace("language_model.", "")
+                        # print(" Renaming key:", key, "->", key_)
+                        renamed = True
+                    else:
+                        key_ = key
+                    self.tensor_key_map[key_] = key
+                    self.tensor_file_map[key_] = file
+                print(
+                    f" Loaded: {file} ({len(list(handle.keys()))} tensors){' (renamed keys)' if renamed else ''}"
+                )
             except Exception as e:
                 print(f" Error loading {file}: {e}")
 
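
A small self-contained sketch of the key renaming this loop performs: keys carrying the "language_model." prefix are also exposed under a stripped name, and tensor_key_map remembers the original key so a later lookup by the stripped name can still read the right tensor from the file. The keys and filename below are hypothetical.

# Hypothetical keys as they might appear in a VL checkpoint shard.
raw_keys = [
    "model.language_model.layers.0.mlp.experts.gate_up_proj",
    "model.language_model.layers.0.mlp.experts.down_proj",
    "model.visual.patch_embed.proj.weight",
]

tensor_key_map = {}   # stripped key -> original key inside the safetensors file
tensor_file_map = {}  # stripped key -> filename

for key in raw_keys:
    key_ = key.replace("language_model.", "") if "language_model" in key else key
    tensor_key_map[key_] = key
    tensor_file_map[key_] = "model-00001-of-0000N.safetensors"  # placeholder filename

# A lookup by the stripped key resolves back to the original key before reading:
lookup = "model.layers.0.mlp.experts.gate_up_proj"
print(tensor_key_map.get(lookup, lookup))
# -> model.language_model.layers.0.mlp.experts.gate_up_proj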
@@ -306,22 +312,26 @@ class ConverterBase:
 
         file = self.tensor_file_map[key]
         handle = self.file_handle_map[file]
-        return handle.get_tensor(key)
+        return handle.get_tensor(self.tensor_key_map.get(key, key))
 
     # layers_id -> list[experts_id]
     def _find_expert_layers(self) -> Dict[int, List[int]]:
         """Find all layers and experts in the model"""
         layers = defaultdict(set)
-
-        # vl weights have a fused layout
-        # Pattern: model.language_model.layers.{layer}.mlp.experts.{proj}
-        if self.kt_cvt_type == "vl":
+
+        # detect layout
+        for key in self.tensor_file_map.keys():
+            if "mlp.experts" in key and "gate_up" in key:
+                self.layout = "fused"
+                break
+
+        if self.layout == "fused": # Pattern: model.layers.{layer}.mlp.experts.{proj}
             layers = set()
             for key in self.tensor_file_map.keys():
-                if "model.language_model.layers." in key and ".mlp.experts." in key:
+                if "model.layers." in key and ".mlp.experts." in key:
                     parts = key.split(".")
-                    if len(parts) >= 7:
-                        layer_idx = int(parts[3])
+                    if len(parts) >= 6:
+                        layer_idx = int(parts[2])
                         layers.add(layer_idx)
 
             result: Dict[int, List[int]] = {}
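
A condensed sketch of the layout detection and layer discovery above, run against hypothetical (already renamed) tensor keys: a fused "gate_up" expert tensor under "mlp.experts" switches the converter to the fused layout, and the layer index then sits at position 2 of the dotted key.

# Hypothetical renamed keys following model.layers.{layer}.mlp.experts.{proj}
keys = [
    "model.layers.0.mlp.experts.gate_up_proj",
    "model.layers.0.mlp.experts.down_proj",
    "model.layers.1.mlp.experts.gate_up_proj",
    "model.layers.1.mlp.experts.down_proj",
]

layout = "base"
for key in keys:
    if "mlp.experts" in key and "gate_up" in key:
        layout = "fused"
        break

layers = set()
if layout == "fused":
    for key in keys:
        if "model.layers." in key and ".mlp.experts." in key:
            parts = key.split(".")
            if len(parts) >= 6:
                layers.add(int(parts[2]))  # "model", "layers", "<layer>", "mlp", "experts", "<proj>"

print(layout, sorted(layers))  # fused [0, 1]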
@@ -703,42 +713,41 @@ class OnlineQuantConverter(ConverterBase):
     def _convert_layer_experts(self, layer_idx: int, expert_ids: List[int]) -> Dict[str, torch.Tensor]:
         """Convert all experts in a layer using online quantization via AMXMoEWrapper"""
         start_time = time.time()
-        print(f"Converting layer {layer_idx} with {len(expert_ids) if self.kt_cvt_type == 'base' else 'fused'} experts via online quantization...")
+        print(
+            f"Converting layer {layer_idx} with {len(expert_ids) if self.layout == 'base' else 'fused'} experts via online quantization..."
+        )
         # Load all expert weights for this layer
-        if self.kt_cvt_type == "vl":
+        if self.layout == "fused":
             if self.input_type not in ["bf16", "fp16"]:
-                raise ValueError(f"VL path currently supports bf16/fp16 only, got input_type={self.input_type}")
-
+                raise ValueError(f"Fused path currently supports bf16/fp16 only, got input_type={self.input_type}")
+
             proj_set = set()
-            prefix = f"model.language_model.layers.{layer_idx}.mlp.experts."
+            prefix = f"model.layers.{layer_idx}.mlp.experts."
             for key in self.tensor_file_map.keys():
                 if key.startswith(prefix):
                     parts = key.split(".")
-                    if len(parts) >= 7:
-                        proj_set.add(parts[6])
+                    if len(parts) >= 6:
+                        proj_set.add(parts[5])
 
             if not proj_set:
-                raise ValueError(
-                    f"[VL] No fused MoE experts found for layer {layer_idx} under 'model.language_model.layers'"
-                )
+                raise ValueError(f"[Fused] No fused MoE experts found for layer {layer_idx} under 'model.layers'")
 
             projs = sorted(proj_set)
-            print(f" [VL] layer {layer_idx} fused proj keys: {projs}")
-
+            print(f" [Fused] layer {layer_idx} fused proj keys: {projs}")
             if len(projs) < 2:
                 raise ValueError(
-                    f"[VL] Expect at least 2 fused tensors (down & gate_up) in layer {layer_idx}, got {len(projs)}"
+                    f"[Fused] Expect at least 2 fused tensors (down & gate_up) in layer {layer_idx}, got {len(projs)}"
                 )
 
             fused_tensors = []
             for p in projs:
-                key = f"model.language_model.layers.{layer_idx}.mlp.experts.{p}"
+                key = f"model.layers.{layer_idx}.mlp.experts.{p}"
                 if key not in self.tensor_file_map:
-                    raise KeyError(f"[VL] Missing fused tensor {key} for layer {layer_idx}")
+                    raise KeyError(f"[Fused] Missing fused tensor {key} for layer {layer_idx}")
                 w = self._load_tensor(key)
                 if self.input_type == "fp16":
                     w = w.to(torch.bfloat16)
-                print(f" [VL] tensor {p} shape: {tuple(w.shape)}")
+                print(f" [Fused] tensor {p} shape: {tuple(w.shape)}")
                 fused_tensors.append(w)
 
             # fused_tensors[0] : down-like, [E, I, H]
@@ -748,23 +757,23 @@ class OnlineQuantConverter(ConverterBase):
 
             # gate_up_fused: [E, H, 2I] -> [E, 2I, H] -> gate / up
             if gate_up_fused.dim() != 3:
-                raise ValueError(f"[VL] Expect gate_up fused tensor to be 3D, got shape {tuple(gate_up_fused.shape)}")
+                raise ValueError(
+                    f"[Fused] Expect gate_up fused tensor to be 3D, got shape {tuple(gate_up_fused.shape)}"
+                )
             E, H, twoI = gate_up_fused.shape
             if twoI % 2 != 0:
-                raise ValueError(f"[VL] gate_up last dim (2I) not even: {twoI}")
+                raise ValueError(f"[Fused] gate_up last dim (2I) not even: {twoI}")
             I = twoI // 2
 
-            gate_up_T = gate_up_fused.transpose(1, 2).contiguous() # [E, 2I, H]
-            gate_proj = gate_up_T[:, :I, :] # [E, I, H]
-            up_proj = gate_up_T[:, I:, :] # [E, I, H]
+            gate_up_T = gate_up_fused.transpose(1, 2).contiguous() # [E, 2I, H]
+            gate_proj = gate_up_T[:, :I, :] # [E, I, H]
+            up_proj = gate_up_T[:, I:, :] # [E, I, H]
 
             if down_fused.dim() != 3:
-                raise ValueError(f"[VL] Expect down fused tensor to be 3D, got shape {tuple(down_fused.shape)}")
+                raise ValueError(f"[Fused] Expect down fused tensor to be 3D, got shape {tuple(down_fused.shape)}")
             if down_fused.shape[0] != E:
-                raise ValueError(
-                    f"[VL] down_fused expert dim mismatch: {down_fused.shape[0]} vs gate_up {E}"
-                )
-            down_proj = down_fused.transpose(1, 2).contiguous() # [E, H, I]
+                raise ValueError(f"[Fused] down_fused expert dim mismatch: {down_fused.shape[0]} vs gate_up {E}")
+            down_proj = down_fused.transpose(1, 2).contiguous() # [E, H, I]
             del fused_tensors
             del gate_up_fused
             del down_fused
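
A minimal torch sketch of the fused-tensor reshaping in this hunk, using tiny made-up dimensions: gate_up is stored as [E, H, 2I] and down as [E, I, H], and both are transposed so gate/up/down come out as [experts, out_features, in_features].

import torch

E, H, I = 2, 8, 4  # toy sizes: experts, hidden_size, moe_intermediate_size

gate_up_fused = torch.randn(E, H, 2 * I)  # fused gate+up, [E, H, 2I]
down_fused = torch.randn(E, I, H)         # down, [E, I, H]

gate_up_T = gate_up_fused.transpose(1, 2).contiguous()  # [E, 2I, H]
gate_proj = gate_up_T[:, :I, :]                         # [E, I, H]
up_proj = gate_up_T[:, I:, :]                           # [E, I, H]
down_proj = down_fused.transpose(1, 2).contiguous()     # [E, H, I]

assert gate_proj.shape == (E, I, H)
assert up_proj.shape == (E, I, H)
assert down_proj.shape == (E, H, I)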
@@ -838,7 +847,6 @@ class OnlineQuantConverter(ConverterBase):
             down_proj = torch.stack(down_weights, dim=0).contiguous()
             del gate_weights, up_weights, down_weights
 
-
         print(f" Loaded weights shapes:")
         print(f" gate_proj: {gate_proj.shape}")
         print(f" up_proj: {up_proj.shape}")
@@ -877,7 +885,7 @@ class OnlineQuantConverter(ConverterBase):
         # This triggers the quantization process and saves to disk
         wrapper.load_weights_from_tensors(gate_proj, up_proj, down_proj, physical_to_logical_map)
 
-        # Clean up to free memory
+        # Clean up to free memory
        del gate_proj, up_proj, down_proj
         gc.collect()
 