diff --git a/kt-kernel/python/utils/amx.py b/kt-kernel/python/utils/amx.py
index 1dfd3b6..cb7fd82 100644
--- a/kt-kernel/python/utils/amx.py
+++ b/kt-kernel/python/utils/amx.py
@@ -448,6 +448,10 @@ class NativeMoEWrapper(BaseMoEWrapper):
             self.down_scales = [t.to(torch.float32).contiguous() for t in weights["down_scale"]]
             assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8"
         elif self.method == "FP8_PERCHANNEL":
+            if self.gate_scales[0].dtype != torch.float32:
+                self.gate_scales = [t.to(torch.float32).contiguous() for t in weights["gate_scale"]]
+                self.up_scales = [t.to(torch.float32).contiguous() for t in weights["up_scale"]]
+                self.down_scales = [t.to(torch.float32).contiguous() for t in weights["down_scale"]]
             assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8_PERCHANNEL"
 
         t2 = time.time()
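Note on the amx.py hunk: the FP8_PERCHANNEL branch now gets the same float32 scale normalization the FP8 branch already performs, guarded so scales that are already float32 are not copied again. A minimal sketch of that guarded conversion, using made-up stand-in tensors rather than the real NativeMoEWrapper state:

    import torch

    # Stand-in for per-channel scales loaded from a checkpoint,
    # e.g. serialized as bfloat16.
    scales = [torch.rand(128, dtype=torch.bfloat16) for _ in range(4)]

    # Convert only when the dtype is wrong, and make each tensor
    # contiguous so a native kernel can consume the raw buffer directly.
    if scales[0].dtype != torch.float32:
        scales = [t.to(torch.float32).contiguous() for t in scales]

    assert scales[0].dtype == torch.float32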
diff --git a/kt-kernel/python/utils/loader.py b/kt-kernel/python/utils/loader.py
index 250a99e..ed3ce89 100644
--- a/kt-kernel/python/utils/loader.py
+++ b/kt-kernel/python/utils/loader.py
@@ -243,6 +243,7 @@ class FP8SafeTensorLoader(SafeTensorLoader):
     Supported formats:
     - DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight
     - Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight
+    - Mistral style: {base}.experts.{id}.{w1,w3,w2}.weight
 
     Supported scale formats (auto-detected):
     - Block-wise: weight_scale_inv (DeepSeek FP8)
@@ -255,6 +256,7 @@ class FP8SafeTensorLoader(SafeTensorLoader):
     MOE_FORMATS = {
         "deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"),
         "mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"),
+        "mistral": ("{base}.experts", "w1", "w3", "w2"),
     }
 
     def __init__(self, file_path: str, scale_suffix: str = None):
@@ -297,6 +299,10 @@ class FP8SafeTensorLoader(SafeTensorLoader):
                     self._detected_format = fmt_name
                     print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}")
                     break
+                elif fmt_name == "mistral" and ".mlp.experts" not in key and ".block_sparse_moe.experts" not in key:
+                    self._detected_format = fmt_name
+                    print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}")
+                    break
             if self._detected_format:
                 break
 
@@ -321,8 +327,21 @@ class FP8SafeTensorLoader(SafeTensorLoader):
                 return
             elif f".{gate}.weight_scale" in key and "weight_scale_inv" not in key:
                 self._scale_suffix = "weight_scale"
-                self._is_per_channel = True
-                print("[FP8SafeTensorLoader] Detected scale format: per-channel (weight_scale)")
+                # Some models (e.g., Mistral) use block-wise FP8 scales but keep
+                # the key suffix as `weight_scale` (without `_inv`). Infer format
+                # from scale tensor shape instead of suffix alone:
+                # - per-channel: [N] or [N, 1]
+                # - block-wise: [N_block, K_block] (both dims > 1)
+                scale_tensor = self.load_tensor(key, device="cpu")
+                if scale_tensor.dim() == 1:
+                    self._is_per_channel = True
+                elif scale_tensor.dim() == 2 and scale_tensor.shape[1] == 1:
+                    self._is_per_channel = True
+                else:
+                    self._is_per_channel = False
+
+                scale_kind = "per-channel" if self._is_per_channel else "block-wise"
+                print(f"[FP8SafeTensorLoader] Detected scale format: {scale_kind} (weight_scale)")
                 return
         # Default to weight_scale_inv
         self._scale_suffix = "weight_scale_inv"
@@ -333,12 +352,20 @@ class FP8SafeTensorLoader(SafeTensorLoader):
             scale_type = "per-channel" if self._is_per_channel else "block-wise"
             print(f"[FP8SafeTensorLoader] Using explicit scale format: {scale_type} ({self._scale_suffix})")
 
-    def _get_experts_prefix(self, base_key: str) -> str:
-        """Get the experts prefix based on detected format."""
+    def _get_experts_prefix_candidates(self, base_key: str) -> list[str]:
+        """Get candidate experts prefixes based on detected format and base key variants."""
         path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format]
+        candidates = []
         if self._is_vl_model:
             base_key = base_key.replace("model.layers", "model.language_model.layers")
-        return path_tpl.format(base=base_key)
+        candidates.append(path_tpl.format(base=base_key))
+
+        # Some model weights (e.g., Mistral native format) do not have "model." prefix.
+        if base_key.startswith("model."):
+            candidates.append(path_tpl.format(base=base_key[len("model.") :]))
+
+        # Deduplicate while preserving order.
+        return list(dict.fromkeys(candidates))
 
     def _get_proj_names(self):
         """Get projection names (gate, up, down) based on detected format."""
@@ -363,15 +390,21 @@ class FP8SafeTensorLoader(SafeTensorLoader):
         Supports both block-wise (weight_scale_inv) and per-channel (weight_scale) formats.
         Per-channel scales are squeezed from [N, 1] to [N] if needed.
         """
-        experts_prefix = self._get_experts_prefix(base_key)
+        experts_prefix_candidates = self._get_experts_prefix_candidates(base_key)
         gate_name, up_name, down_name = self._get_proj_names()
 
         expert_count = 0
-        while self.has_tensor(f"{experts_prefix}.{expert_count}.{gate_name}.weight"):
-            expert_count += 1
+        experts_prefix = None
+        for prefix in experts_prefix_candidates:
+            expert_count = 0
+            while self.has_tensor(f"{prefix}.{expert_count}.{gate_name}.weight"):
+                expert_count += 1
+            if expert_count > 0:
+                experts_prefix = prefix
+                break
 
-        if expert_count == 0:
-            raise ValueError(f"No experts found for key {experts_prefix}")
+        if expert_count == 0 or experts_prefix is None:
+            raise ValueError(f"No experts found for keys: {experts_prefix_candidates}")
 
         gate_weights = [None] * expert_count
         up_weights = [None] * expert_count
@@ -423,13 +456,13 @@ class FP8SafeTensorLoader(SafeTensorLoader):
         return self._is_per_channel
 
 
-
 class BF16SafeTensorLoader(SafeTensorLoader):
     """Loader for native BF16 expert weights (no quantization, no scales).
 
     Supported formats:
     - DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight
     - Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight
+    - Mistral style: {base}.experts.{id}.{w1,w3,w2}.weight
 
     The format is auto-detected during initialization.
""" @@ -437,6 +470,7 @@ class BF16SafeTensorLoader(SafeTensorLoader): MOE_FORMATS = { "deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"), "mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"), + "mistral": ("{base}.experts", "w1", "w3", "w2"), } def __init__(self, file_path: str): @@ -466,14 +500,24 @@ class BF16SafeTensorLoader(SafeTensorLoader): self._detected_format = fmt_name print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}") return + elif fmt_name == "mistral" and ".mlp.experts" not in key and ".block_sparse_moe.experts" not in key: + self._detected_format = fmt_name + print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}") + return self._detected_format = "deepseek" print("[BF16SafeTensorLoader] No MoE format detected, defaulting to: deepseek") - def _get_experts_prefix(self, base_key: str) -> str: - """Get the experts prefix based on detected format.""" + def _get_experts_prefix_candidates(self, base_key: str) -> list[str]: + """Get candidate experts prefixes based on detected format and base key variants.""" path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format] - return path_tpl.format(base=base_key) + candidates = [path_tpl.format(base=base_key)] + + # Some model weights (e.g., Mistral native format) do not have "model." prefix. + if base_key.startswith("model."): + candidates.append(path_tpl.format(base=base_key[len("model.") :])) + + return list(dict.fromkeys(candidates)) def _get_proj_names(self): """Get projection names (gate, up, down) based on detected format.""" @@ -497,15 +541,21 @@ class BF16SafeTensorLoader(SafeTensorLoader): if self._detected_format == "packed": return self._load_experts_packed(base_key, device) - experts_prefix = self._get_experts_prefix(base_key) + experts_prefix_candidates = self._get_experts_prefix_candidates(base_key) gate_name, up_name, down_name = self._get_proj_names() expert_count = 0 - while self.has_tensor(f"{experts_prefix}.{expert_count}.{gate_name}.weight"): - expert_count += 1 + experts_prefix = None + for prefix in experts_prefix_candidates: + expert_count = 0 + while self.has_tensor(f"{prefix}.{expert_count}.{gate_name}.weight"): + expert_count += 1 + if expert_count > 0: + experts_prefix = prefix + break - if expert_count == 0: - raise ValueError(f"No experts found for key {experts_prefix}") + if expert_count == 0 or experts_prefix is None: + raise ValueError(f"No experts found for keys: {experts_prefix_candidates}") gate_weights = [None] * expert_count up_weights = [None] * expert_count