From edc48aba37a0289a794e02ca4e72731735af956a Mon Sep 17 00:00:00 2001 From: Jiaqi Liao <30439460+SkqLiao@users.noreply.github.com> Date: Wed, 28 Jan 2026 16:31:56 +0800 Subject: [PATCH] [fix]: fix wrapper import issue (#1819) --- kt-kernel/python/utils/amx.py | 74 ++++++++++++++++++++++------------- kt-kernel/setup.py | 9 +++-- 2 files changed, 51 insertions(+), 32 deletions(-) diff --git a/kt-kernel/python/utils/amx.py b/kt-kernel/python/utils/amx.py index e4eec46..1dfd3b6 100644 --- a/kt-kernel/python/utils/amx.py +++ b/kt-kernel/python/utils/amx.py @@ -7,24 +7,21 @@ from typing import Optional from ..experts_base import BaseMoEWrapper from .loader import SafeTensorLoader, CompressedSafeTensorLoader, FP8SafeTensorLoader, BF16SafeTensorLoader from kt_kernel_ext.moe import MOEConfig +import kt_kernel_ext.moe as _moe_mod -try: - from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE, AMXBF16_MOE +AMXInt4_MOE = getattr(_moe_mod, "AMXInt4_MOE", None) +AMXInt8_MOE = getattr(_moe_mod, "AMXInt8_MOE", None) +AMXInt4_KGroup_MOE = getattr(_moe_mod, "AMXInt4_KGroup_MOE", None) +AMXFP8_MOE = getattr(_moe_mod, "AMXFP8_MOE", None) +AMXBF16_MOE = getattr(_moe_mod, "AMXBF16_MOE", None) +AMXFP8PerChannel_MOE = getattr(_moe_mod, "AMXFP8PerChannel_MOE", None) - _HAS_AMX_SUPPORT = True -except (ImportError, AttributeError): - _HAS_AMX_SUPPORT = False - AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE, AMXBF16_MOE = None, None, None, None, None - -try: - from kt_kernel_ext.moe import AMXFP8PerChannel_MOE - - _HAS_FP8_PERCHANNEL_SUPPORT = True -except (ImportError, AttributeError): - _HAS_FP8_PERCHANNEL_SUPPORT = False - AMXFP8PerChannel_MOE = None - -from typing import Optional +_HAS_AMXINT4_SUPPORT = AMXInt4_MOE is not None +_HAS_AMXINT8_SUPPORT = AMXInt8_MOE is not None +_HAS_RAWINT4_SUPPORT = AMXInt4_KGroup_MOE is not None +_HAS_FP8_SUPPORT = AMXFP8_MOE is not None +_HAS_BF16_SUPPORT = AMXBF16_MOE is not None +_HAS_FP8_PERCHANNEL_SUPPORT = AMXFP8PerChannel_MOE is not None class AMXMoEWrapper(BaseMoEWrapper): @@ -72,10 +69,17 @@ class AMXMoEWrapper(BaseMoEWrapper): max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0. method: AMX quantization method ("AMXINT4" or "AMXINT8") """ - if not _HAS_AMX_SUPPORT: + if method == "AMXINT4" and not _HAS_AMXINT4_SUPPORT: raise RuntimeError( - "AMX backend not available. kt_kernel_ext was not compiled with AMX support.\n" - "Please recompile with AMX enabled." + "AMXINT4 backend not available. Required ISA:\n" + " - AVX512F + AVX512BW (VNNI optional)\n" + "Please recompile kt_kernel_ext with AVX512 enabled." + ) + if method == "AMXINT8" and not _HAS_AMXINT8_SUPPORT: + raise RuntimeError( + "AMXINT8 backend not available. Required ISA:\n" + " - AVX512F + AVX512BW (VNNI optional)\n" + "Please recompile kt_kernel_ext with AVX512 enabled." ) # Initialize base class @@ -336,16 +340,30 @@ class NativeMoEWrapper(BaseMoEWrapper): max_deferred_experts_per_token: Optional[int] = None, method: str = "RAWINT4", ): - if not _HAS_AMX_SUPPORT: - raise RuntimeError("AMX backend is not available.") - if method == "RAWINT4" and AMXInt4_KGroup_MOE is None: - raise RuntimeError("AMX backend with RAWINT4 support is not available.") - if method == "FP8" and AMXFP8_MOE is None: - raise RuntimeError("AMX backend with FP8 support is not available.") + if method == "RAWINT4" and not _HAS_RAWINT4_SUPPORT: + raise RuntimeError( + "RAWINT4 backend not available. Required ISA:\n" + " - AVX512F + AVX512BW (VNNI optional)\n" + "Please recompile kt_kernel_ext with AVX512 enabled." + ) + if method == "FP8" and not _HAS_FP8_SUPPORT: + raise RuntimeError( + "FP8 backend not available. Required ISA:\n" + " - AVX512F + AVX512BW + AVX512_BF16 + AVX512_VBMI\n" + "Please recompile kt_kernel_ext with AVX512 + BF16 + VBMI enabled." + ) if method == "FP8_PERCHANNEL" and not _HAS_FP8_PERCHANNEL_SUPPORT: - raise RuntimeError("AMX backend with FP8 per-channel support is not available.") - if method == "BF16" and AMXBF16_MOE is None: - raise RuntimeError("AMX backend with BF16 support is not available.") + raise RuntimeError( + "FP8_PERCHANNEL backend not available. Required ISA:\n" + " - AVX512F + AVX512BW + AVX512_BF16 + AVX512_VBMI\n" + "Please recompile kt_kernel_ext with AVX512 + BF16 + VBMI enabled." + ) + if method == "BF16" and not _HAS_BF16_SUPPORT: + raise RuntimeError( + "BF16 backend not available. Required ISA:\n" + " - AVX512F + AVX512BW + AVX512_BF16\n" + "Please recompile kt_kernel_ext with AVX512 + BF16 enabled." + ) super().__init__( layer_idx=layer_idx, diff --git a/kt-kernel/setup.py b/kt-kernel/setup.py index 82f3c4c..97654cc 100644 --- a/kt-kernel/setup.py +++ b/kt-kernel/setup.py @@ -552,9 +552,10 @@ class CMakeBuild(build_ext): # These are passed to CMake to conditionally add compiler flags # Track if any AVX512 extension is enabled avx512_extension_enabled = False + allow_avx512_ext_auto = cpu_mode in ("NATIVE", "FANCY", "AVX512") if not _forward_bool_env(cmake_args, "CPUINFER_ENABLE_AVX512_VNNI", "LLAMA_AVX512_VNNI"): - if "AVX512_VNNI" in d["features"]: + if allow_avx512_ext_auto and "AVX512_VNNI" in d["features"]: cmake_args.append("-DLLAMA_AVX512_VNNI=ON") print("-- AVX512_VNNI detected; enabling (-DLLAMA_AVX512_VNNI=ON)") avx512_extension_enabled = True @@ -562,7 +563,7 @@ class CMakeBuild(build_ext): avx512_extension_enabled = True if not _forward_bool_env(cmake_args, "CPUINFER_ENABLE_AVX512_BF16", "LLAMA_AVX512_BF16"): - if "AVX512_BF16" in d["features"]: + if allow_avx512_ext_auto and "AVX512_BF16" in d["features"]: cmake_args.append("-DLLAMA_AVX512_BF16=ON") print("-- AVX512_BF16 detected; enabling (-DLLAMA_AVX512_BF16=ON)") avx512_extension_enabled = True @@ -570,7 +571,7 @@ class CMakeBuild(build_ext): avx512_extension_enabled = True if not _forward_bool_env(cmake_args, "CPUINFER_ENABLE_AVX512_VBMI", "LLAMA_AVX512_VBMI"): - if "AVX512_VBMI" in d["features"]: + if allow_avx512_ext_auto and "AVX512_VBMI" in d["features"]: cmake_args.append("-DLLAMA_AVX512_VBMI=ON") print("-- AVX512_VBMI detected; enabling (-DLLAMA_AVX512_VBMI=ON)") avx512_extension_enabled = True @@ -578,7 +579,7 @@ class CMakeBuild(build_ext): avx512_extension_enabled = True # If any AVX512 extension is enabled, ensure base AVX512 is also enabled - if avx512_extension_enabled and cpu_mode == "NATIVE": + if avx512_extension_enabled and cpu_mode in ("NATIVE", "FANCY", "AVX512"): if not any("LLAMA_AVX512=ON" in a for a in cmake_args): cmake_args.append("-DLLAMA_AVX512=ON") print("-- AVX512 extensions enabled; also enabling base AVX512F (-DLLAMA_AVX512=ON)")