[feat](kt-kernel): support qwen3-vl weights convert (#1648)

This commit is contained in:
mrhaoxx
2025-11-27 22:29:09 +08:00
committed by GitHub
parent c256150e08
commit 637c49c83f

View File

@@ -73,22 +73,30 @@ def load_model_config(input_path: str, input_type: str = None) -> Dict:
with open(config_path, "r") as f:
config = json.load(f)
if "text_config" in config:
text_cfg = config["text_config"]
kt_cvt_type = "vl"
else:
text_cfg = config
kt_cvt_type = "base"
# Extract required fields with fallbacks
model_config = {
"num_experts": config.get("n_routed_experts", config.get("num_experts")),
"num_experts_per_tok": config.get("num_experts_per_tok", 2),
"hidden_size": config.get("hidden_size"),
"moe_intermediate_size": config.get("moe_intermediate_size", config.get("intermediate_size")),
"num_experts": text_cfg.get("n_routed_experts", text_cfg.get("num_experts")),
"num_experts_per_tok": text_cfg.get("num_experts_per_tok", 2),
"hidden_size": text_cfg.get("hidden_size"),
"moe_intermediate_size": text_cfg.get("moe_intermediate_size", text_cfg.get("intermediate_size")),
"_kt_cvt_type": kt_cvt_type,
}
# Validate required fields
missing_fields = [k for k, v in model_config.items() if v is None]
missing_fields = [k for k, v in model_config.items() if k != "_kt_cvt_type" and v is None]
if missing_fields:
raise ValueError(f"Missing required config fields: {missing_fields}")
# For FP8 input, extract and validate quantization_config
if input_type == "fp8":
quant_config = config.get("quantization_config")
quant_config = config.get("quantization_config") or text_cfg.get("quantization_config")
if quant_config is None:
raise ValueError(
"FP8 input type specified but 'quantization_config' not found in config.json. "
@@ -113,6 +121,7 @@ def load_model_config(input_path: str, input_type: str = None) -> Dict:
print(f" format: {quant_config.get('fmt', 'unknown')}")
print(f" weight_block_size: {weight_block_size}")
print(f"Model Type: {model_config['_kt_cvt_type']}")
return model_config
@@ -260,6 +269,7 @@ class ConverterBase:
self.num_experts_per_tok = model_config["num_experts_per_tok"]
self.hidden_size = model_config["hidden_size"]
self.moe_intermediate_size = model_config["moe_intermediate_size"]
self.kt_cvt_type = model_config.get("_kt_cvt_type", "base")
# Load input safetensors files
self._load_input_files()
@@ -302,6 +312,24 @@ class ConverterBase:
def _find_expert_layers(self) -> Dict[int, List[int]]:
"""Find all layers and experts in the model"""
layers = defaultdict(set)
# vl weights have a fused layout
# Pattern: model.language_model.layers.{layer}.mlp.experts.{proj}
if self.kt_cvt_type == "vl":
layers = set()
for key in self.tensor_file_map.keys():
if "model.language_model.layers." in key and ".mlp.experts." in key:
parts = key.split(".")
if len(parts) >= 7:
layer_idx = int(parts[3])
layers.add(layer_idx)
result: Dict[int, List[int]] = {}
for layer_idx in sorted(layers):
result[layer_idx] = [-1]
print(f"Found {len(result)} layers with fused MoE experts")
return result
# Pattern: model.layers.{layer}.mlp.experts.{expert}.{proj}.{type}
for key in self.tensor_file_map.keys():
@@ -675,76 +703,141 @@ class OnlineQuantConverter(ConverterBase):
def _convert_layer_experts(self, layer_idx: int, expert_ids: List[int]) -> Dict[str, torch.Tensor]:
"""Convert all experts in a layer using online quantization via AMXMoEWrapper"""
start_time = time.time()
print(f"Converting layer {layer_idx} with {len(expert_ids)} experts via online quantization...")
print(f"Converting layer {layer_idx} with {len(expert_ids) if self.kt_cvt_type == 'base' else 'fused'} experts via online quantization...")
# Load all expert weights for this layer
gate_weights = []
up_weights = []
down_weights = []
if self.kt_cvt_type == "vl":
if self.input_type not in ["bf16", "fp16"]:
raise ValueError(f"VL path currently supports bf16/fp16 only, got input_type={self.input_type}")
proj_set = set()
prefix = f"model.language_model.layers.{layer_idx}.mlp.experts."
for key in self.tensor_file_map.keys():
if key.startswith(prefix):
parts = key.split(".")
if len(parts) >= 7:
proj_set.add(parts[6])
for expert_id in expert_ids:
gate_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.gate_proj.weight"
up_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.up_proj.weight"
down_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.down_proj.weight"
if not proj_set:
raise ValueError(
f"[VL] No fused MoE experts found for layer {layer_idx} under 'model.language_model.layers'"
)
if gate_key not in self.tensor_file_map:
raise KeyError(f"Missing gate weight for layer {layer_idx}, expert {expert_id}")
if up_key not in self.tensor_file_map:
raise KeyError(f"Missing up weight for layer {layer_idx}, expert {expert_id}")
if down_key not in self.tensor_file_map:
raise KeyError(f"Missing down weight for layer {layer_idx}, expert {expert_id}")
projs = sorted(proj_set)
print(f" [VL] layer {layer_idx} fused proj keys: {projs}")
# Load weights based on input type
if self.input_type == "fp8":
# Load FP8 weights and their scale_inv tensors
gate_scale_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.gate_proj.weight_scale_inv"
up_scale_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.up_proj.weight_scale_inv"
down_scale_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.down_proj.weight_scale_inv"
if len(projs) < 2:
raise ValueError(
f"[VL] Expect at least 2 fused tensors (down & gate_up) in layer {layer_idx}, got {len(projs)}"
)
if gate_scale_key not in self.tensor_file_map:
raise KeyError(f"Missing gate weight_scale_inv for layer {layer_idx}, expert {expert_id}")
if up_scale_key not in self.tensor_file_map:
raise KeyError(f"Missing up weight_scale_inv for layer {layer_idx}, expert {expert_id}")
if down_scale_key not in self.tensor_file_map:
raise KeyError(f"Missing down weight_scale_inv for layer {layer_idx}, expert {expert_id}")
fused_tensors = []
for p in projs:
key = f"model.language_model.layers.{layer_idx}.mlp.experts.{p}"
if key not in self.tensor_file_map:
raise KeyError(f"[VL] Missing fused tensor {key} for layer {layer_idx}")
w = self._load_tensor(key)
if self.input_type == "fp16":
w = w.to(torch.bfloat16)
print(f" [VL] tensor {p} shape: {tuple(w.shape)}")
fused_tensors.append(w)
# Load FP8 weights and scales
gate_fp8 = self._load_tensor(gate_key).to("cuda")
up_fp8 = self._load_tensor(up_key).to("cuda")
down_fp8 = self._load_tensor(down_key).to("cuda")
# fused_tensors[0] : down-like, [E, I, H]
# fused_tensors[1] : gate_up-like, [E, H, 2I]
down_fused = fused_tensors[0]
gate_up_fused = fused_tensors[1]
gate_scale_inv = self._load_tensor(gate_scale_key).to("cuda")
up_scale_inv = self._load_tensor(up_scale_key).to("cuda")
down_scale_inv = self._load_tensor(down_scale_key).to("cuda")
# gate_up_fused: [E, H, 2I] -> [E, 2I, H] -> gate / up
if gate_up_fused.dim() != 3:
raise ValueError(f"[VL] Expect gate_up fused tensor to be 3D, got shape {tuple(gate_up_fused.shape)}")
E, H, twoI = gate_up_fused.shape
if twoI % 2 != 0:
raise ValueError(f"[VL] gate_up last dim (2I) not even: {twoI}")
I = twoI // 2
# Dequantize FP8 to BF16 using block-wise scaling
gate_weight = weight_dequant(gate_fp8, gate_scale_inv).to("cpu").to(torch.bfloat16).contiguous()
up_weight = weight_dequant(up_fp8, up_scale_inv).to("cpu").to(torch.bfloat16).contiguous()
down_weight = weight_dequant(down_fp8, down_scale_inv).to("cpu").to(torch.bfloat16).contiguous()
gate_up_T = gate_up_fused.transpose(1, 2).contiguous() # [E, 2I, H]
gate_proj = gate_up_T[:, :I, :] # [E, I, H]
up_proj = gate_up_T[:, I:, :] # [E, I, H]
elif self.input_type == "fp16":
# Load FP16 and convert to BF16
gate_weight = self._load_tensor(gate_key).to(torch.bfloat16)
up_weight = self._load_tensor(up_key).to(torch.bfloat16)
down_weight = self._load_tensor(down_key).to(torch.bfloat16)
if down_fused.dim() != 3:
raise ValueError(f"[VL] Expect down fused tensor to be 3D, got shape {tuple(down_fused.shape)}")
if down_fused.shape[0] != E:
raise ValueError(
f"[VL] down_fused expert dim mismatch: {down_fused.shape[0]} vs gate_up {E}"
)
down_proj = down_fused.transpose(1, 2).contiguous() # [E, H, I]
del fused_tensors
del gate_up_fused
del down_fused
else:
gate_weights = []
up_weights = []
down_weights = []
elif self.input_type == "bf16":
# Load BF16 directly
gate_weight = self._load_tensor(gate_key)
up_weight = self._load_tensor(up_key)
down_weight = self._load_tensor(down_key)
for expert_id in expert_ids:
gate_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.gate_proj.weight"
up_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.up_proj.weight"
down_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.down_proj.weight"
else:
raise ValueError(f"Unsupported input_type for INT4 conversion: {self.input_type}")
if gate_key not in self.tensor_file_map:
raise KeyError(f"Missing gate weight for layer {layer_idx}, expert {expert_id}")
if up_key not in self.tensor_file_map:
raise KeyError(f"Missing up weight for layer {layer_idx}, expert {expert_id}")
if down_key not in self.tensor_file_map:
raise KeyError(f"Missing down weight for layer {layer_idx}, expert {expert_id}")
gate_weights.append(gate_weight)
up_weights.append(up_weight)
down_weights.append(down_weight)
# Load weights based on input type
if self.input_type == "fp8":
# Load FP8 weights and their scale_inv tensors
gate_scale_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.gate_proj.weight_scale_inv"
up_scale_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.up_proj.weight_scale_inv"
down_scale_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.down_proj.weight_scale_inv"
if gate_scale_key not in self.tensor_file_map:
raise KeyError(f"Missing gate weight_scale_inv for layer {layer_idx}, expert {expert_id}")
if up_scale_key not in self.tensor_file_map:
raise KeyError(f"Missing up weight_scale_inv for layer {layer_idx}, expert {expert_id}")
if down_scale_key not in self.tensor_file_map:
raise KeyError(f"Missing down weight_scale_inv for layer {layer_idx}, expert {expert_id}")
# Load FP8 weights and scales
gate_fp8 = self._load_tensor(gate_key).to("cuda")
up_fp8 = self._load_tensor(up_key).to("cuda")
down_fp8 = self._load_tensor(down_key).to("cuda")
gate_scale_inv = self._load_tensor(gate_scale_key).to("cuda")
up_scale_inv = self._load_tensor(up_scale_key).to("cuda")
down_scale_inv = self._load_tensor(down_scale_key).to("cuda")
# Dequantize FP8 to BF16 using block-wise scaling
gate_weight = weight_dequant(gate_fp8, gate_scale_inv).to("cpu").to(torch.bfloat16).contiguous()
up_weight = weight_dequant(up_fp8, up_scale_inv).to("cpu").to(torch.bfloat16).contiguous()
down_weight = weight_dequant(down_fp8, down_scale_inv).to("cpu").to(torch.bfloat16).contiguous()
elif self.input_type == "fp16":
# Load FP16 and convert to BF16
gate_weight = self._load_tensor(gate_key).to(torch.bfloat16)
up_weight = self._load_tensor(up_key).to(torch.bfloat16)
down_weight = self._load_tensor(down_key).to(torch.bfloat16)
elif self.input_type == "bf16":
# Load BF16 directly
gate_weight = self._load_tensor(gate_key)
up_weight = self._load_tensor(up_key)
down_weight = self._load_tensor(down_key)
else:
raise ValueError(f"Unsupported input_type for INT4 conversion: {self.input_type}")
gate_weights.append(gate_weight)
up_weights.append(up_weight)
down_weights.append(down_weight)
# Stack weights into single tensors: [num_experts, ...]
gate_proj = torch.stack(gate_weights, dim=0).contiguous()
up_proj = torch.stack(up_weights, dim=0).contiguous()
down_proj = torch.stack(down_weights, dim=0).contiguous()
del gate_weights, up_weights, down_weights
# Stack weights into single tensors: [num_experts, ...]
gate_proj = torch.stack(gate_weights, dim=0).contiguous()
up_proj = torch.stack(up_weights, dim=0).contiguous()
down_proj = torch.stack(down_weights, dim=0).contiguous()
print(f" Loaded weights shapes:")
print(f" gate_proj: {gate_proj.shape}")
@@ -784,8 +877,7 @@ class OnlineQuantConverter(ConverterBase):
# This triggers the quantization process and saves to disk
wrapper.load_weights_from_tensors(gate_proj, up_proj, down_proj, physical_to_logical_map)
# Clean up to free memory
del gate_weights, up_weights, down_weights
# Clean up to free memory
del gate_proj, up_proj, down_proj
gc.collect()