[fix]: fix wrapper import issue (#1819)

This commit is contained in:
Jiaqi Liao
2026-01-28 16:31:56 +08:00
committed by GitHub
parent 8321d00cc5
commit edc48aba37
2 changed files with 51 additions and 32 deletions

View File

@@ -7,24 +7,21 @@ from typing import Optional
from ..experts_base import BaseMoEWrapper
from .loader import SafeTensorLoader, CompressedSafeTensorLoader, FP8SafeTensorLoader, BF16SafeTensorLoader
from kt_kernel_ext.moe import MOEConfig
import kt_kernel_ext.moe as _moe_mod
try:
from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE, AMXBF16_MOE
AMXInt4_MOE = getattr(_moe_mod, "AMXInt4_MOE", None)
AMXInt8_MOE = getattr(_moe_mod, "AMXInt8_MOE", None)
AMXInt4_KGroup_MOE = getattr(_moe_mod, "AMXInt4_KGroup_MOE", None)
AMXFP8_MOE = getattr(_moe_mod, "AMXFP8_MOE", None)
AMXBF16_MOE = getattr(_moe_mod, "AMXBF16_MOE", None)
AMXFP8PerChannel_MOE = getattr(_moe_mod, "AMXFP8PerChannel_MOE", None)
_HAS_AMX_SUPPORT = True
except (ImportError, AttributeError):
_HAS_AMX_SUPPORT = False
AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE, AMXBF16_MOE = None, None, None, None, None
try:
from kt_kernel_ext.moe import AMXFP8PerChannel_MOE
_HAS_FP8_PERCHANNEL_SUPPORT = True
except (ImportError, AttributeError):
_HAS_FP8_PERCHANNEL_SUPPORT = False
AMXFP8PerChannel_MOE = None
from typing import Optional
_HAS_AMXINT4_SUPPORT = AMXInt4_MOE is not None
_HAS_AMXINT8_SUPPORT = AMXInt8_MOE is not None
_HAS_RAWINT4_SUPPORT = AMXInt4_KGroup_MOE is not None
_HAS_FP8_SUPPORT = AMXFP8_MOE is not None
_HAS_BF16_SUPPORT = AMXBF16_MOE is not None
_HAS_FP8_PERCHANNEL_SUPPORT = AMXFP8PerChannel_MOE is not None
class AMXMoEWrapper(BaseMoEWrapper):
@@ -72,10 +69,17 @@ class AMXMoEWrapper(BaseMoEWrapper):
max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0.
method: AMX quantization method ("AMXINT4" or "AMXINT8")
"""
if not _HAS_AMX_SUPPORT:
if method == "AMXINT4" and not _HAS_AMXINT4_SUPPORT:
raise RuntimeError(
"AMX backend not available. kt_kernel_ext was not compiled with AMX support.\n"
"Please recompile with AMX enabled."
"AMXINT4 backend not available. Required ISA:\n"
" - AVX512F + AVX512BW (VNNI optional)\n"
"Please recompile kt_kernel_ext with AVX512 enabled."
)
if method == "AMXINT8" and not _HAS_AMXINT8_SUPPORT:
raise RuntimeError(
"AMXINT8 backend not available. Required ISA:\n"
" - AVX512F + AVX512BW (VNNI optional)\n"
"Please recompile kt_kernel_ext with AVX512 enabled."
)
# Initialize base class
@@ -336,16 +340,30 @@ class NativeMoEWrapper(BaseMoEWrapper):
max_deferred_experts_per_token: Optional[int] = None,
method: str = "RAWINT4",
):
if not _HAS_AMX_SUPPORT:
raise RuntimeError("AMX backend is not available.")
if method == "RAWINT4" and AMXInt4_KGroup_MOE is None:
raise RuntimeError("AMX backend with RAWINT4 support is not available.")
if method == "FP8" and AMXFP8_MOE is None:
raise RuntimeError("AMX backend with FP8 support is not available.")
if method == "RAWINT4" and not _HAS_RAWINT4_SUPPORT:
raise RuntimeError(
"RAWINT4 backend not available. Required ISA:\n"
" - AVX512F + AVX512BW (VNNI optional)\n"
"Please recompile kt_kernel_ext with AVX512 enabled."
)
if method == "FP8" and not _HAS_FP8_SUPPORT:
raise RuntimeError(
"FP8 backend not available. Required ISA:\n"
" - AVX512F + AVX512BW + AVX512_BF16 + AVX512_VBMI\n"
"Please recompile kt_kernel_ext with AVX512 + BF16 + VBMI enabled."
)
if method == "FP8_PERCHANNEL" and not _HAS_FP8_PERCHANNEL_SUPPORT:
raise RuntimeError("AMX backend with FP8 per-channel support is not available.")
if method == "BF16" and AMXBF16_MOE is None:
raise RuntimeError("AMX backend with BF16 support is not available.")
raise RuntimeError(
"FP8_PERCHANNEL backend not available. Required ISA:\n"
" - AVX512F + AVX512BW + AVX512_BF16 + AVX512_VBMI\n"
"Please recompile kt_kernel_ext with AVX512 + BF16 + VBMI enabled."
)
if method == "BF16" and not _HAS_BF16_SUPPORT:
raise RuntimeError(
"BF16 backend not available. Required ISA:\n"
" - AVX512F + AVX512BW + AVX512_BF16\n"
"Please recompile kt_kernel_ext with AVX512 + BF16 enabled."
)
super().__init__(
layer_idx=layer_idx,

View File

@@ -552,9 +552,10 @@ class CMakeBuild(build_ext):
# These are passed to CMake to conditionally add compiler flags
# Track if any AVX512 extension is enabled
avx512_extension_enabled = False
allow_avx512_ext_auto = cpu_mode in ("NATIVE", "FANCY", "AVX512")
if not _forward_bool_env(cmake_args, "CPUINFER_ENABLE_AVX512_VNNI", "LLAMA_AVX512_VNNI"):
if "AVX512_VNNI" in d["features"]:
if allow_avx512_ext_auto and "AVX512_VNNI" in d["features"]:
cmake_args.append("-DLLAMA_AVX512_VNNI=ON")
print("-- AVX512_VNNI detected; enabling (-DLLAMA_AVX512_VNNI=ON)")
avx512_extension_enabled = True
@@ -562,7 +563,7 @@ class CMakeBuild(build_ext):
avx512_extension_enabled = True
if not _forward_bool_env(cmake_args, "CPUINFER_ENABLE_AVX512_BF16", "LLAMA_AVX512_BF16"):
if "AVX512_BF16" in d["features"]:
if allow_avx512_ext_auto and "AVX512_BF16" in d["features"]:
cmake_args.append("-DLLAMA_AVX512_BF16=ON")
print("-- AVX512_BF16 detected; enabling (-DLLAMA_AVX512_BF16=ON)")
avx512_extension_enabled = True
@@ -570,7 +571,7 @@ class CMakeBuild(build_ext):
avx512_extension_enabled = True
if not _forward_bool_env(cmake_args, "CPUINFER_ENABLE_AVX512_VBMI", "LLAMA_AVX512_VBMI"):
if "AVX512_VBMI" in d["features"]:
if allow_avx512_ext_auto and "AVX512_VBMI" in d["features"]:
cmake_args.append("-DLLAMA_AVX512_VBMI=ON")
print("-- AVX512_VBMI detected; enabling (-DLLAMA_AVX512_VBMI=ON)")
avx512_extension_enabled = True
@@ -578,7 +579,7 @@ class CMakeBuild(build_ext):
avx512_extension_enabled = True
# If any AVX512 extension is enabled, ensure base AVX512 is also enabled
if avx512_extension_enabled and cpu_mode == "NATIVE":
if avx512_extension_enabled and cpu_mode in ("NATIVE", "FANCY", "AVX512"):
if not any("LLAMA_AVX512=ON" in a for a in cmake_args):
cmake_args.append("-DLLAMA_AVX512=ON")
print("-- AVX512 extensions enabled; also enabling base AVX512F (-DLLAMA_AVX512=ON)")