mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-03-14 18:37:23 +00:00
[fix]: fix wrapper import issue (#1819)
This commit is contained in:
@@ -7,24 +7,21 @@ from typing import Optional
|
||||
from ..experts_base import BaseMoEWrapper
|
||||
from .loader import SafeTensorLoader, CompressedSafeTensorLoader, FP8SafeTensorLoader, BF16SafeTensorLoader
|
||||
from kt_kernel_ext.moe import MOEConfig
|
||||
import kt_kernel_ext.moe as _moe_mod
|
||||
|
||||
try:
|
||||
from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE, AMXBF16_MOE
|
||||
AMXInt4_MOE = getattr(_moe_mod, "AMXInt4_MOE", None)
|
||||
AMXInt8_MOE = getattr(_moe_mod, "AMXInt8_MOE", None)
|
||||
AMXInt4_KGroup_MOE = getattr(_moe_mod, "AMXInt4_KGroup_MOE", None)
|
||||
AMXFP8_MOE = getattr(_moe_mod, "AMXFP8_MOE", None)
|
||||
AMXBF16_MOE = getattr(_moe_mod, "AMXBF16_MOE", None)
|
||||
AMXFP8PerChannel_MOE = getattr(_moe_mod, "AMXFP8PerChannel_MOE", None)
|
||||
|
||||
_HAS_AMX_SUPPORT = True
|
||||
except (ImportError, AttributeError):
|
||||
_HAS_AMX_SUPPORT = False
|
||||
AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE, AMXBF16_MOE = None, None, None, None, None
|
||||
|
||||
try:
|
||||
from kt_kernel_ext.moe import AMXFP8PerChannel_MOE
|
||||
|
||||
_HAS_FP8_PERCHANNEL_SUPPORT = True
|
||||
except (ImportError, AttributeError):
|
||||
_HAS_FP8_PERCHANNEL_SUPPORT = False
|
||||
AMXFP8PerChannel_MOE = None
|
||||
|
||||
from typing import Optional
|
||||
_HAS_AMXINT4_SUPPORT = AMXInt4_MOE is not None
|
||||
_HAS_AMXINT8_SUPPORT = AMXInt8_MOE is not None
|
||||
_HAS_RAWINT4_SUPPORT = AMXInt4_KGroup_MOE is not None
|
||||
_HAS_FP8_SUPPORT = AMXFP8_MOE is not None
|
||||
_HAS_BF16_SUPPORT = AMXBF16_MOE is not None
|
||||
_HAS_FP8_PERCHANNEL_SUPPORT = AMXFP8PerChannel_MOE is not None
|
||||
|
||||
|
||||
class AMXMoEWrapper(BaseMoEWrapper):
|
||||
@@ -72,10 +69,17 @@ class AMXMoEWrapper(BaseMoEWrapper):
|
||||
max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0.
|
||||
method: AMX quantization method ("AMXINT4" or "AMXINT8")
|
||||
"""
|
||||
if not _HAS_AMX_SUPPORT:
|
||||
if method == "AMXINT4" and not _HAS_AMXINT4_SUPPORT:
|
||||
raise RuntimeError(
|
||||
"AMX backend not available. kt_kernel_ext was not compiled with AMX support.\n"
|
||||
"Please recompile with AMX enabled."
|
||||
"AMXINT4 backend not available. Required ISA:\n"
|
||||
" - AVX512F + AVX512BW (VNNI optional)\n"
|
||||
"Please recompile kt_kernel_ext with AVX512 enabled."
|
||||
)
|
||||
if method == "AMXINT8" and not _HAS_AMXINT8_SUPPORT:
|
||||
raise RuntimeError(
|
||||
"AMXINT8 backend not available. Required ISA:\n"
|
||||
" - AVX512F + AVX512BW (VNNI optional)\n"
|
||||
"Please recompile kt_kernel_ext with AVX512 enabled."
|
||||
)
|
||||
|
||||
# Initialize base class
|
||||
@@ -336,16 +340,30 @@ class NativeMoEWrapper(BaseMoEWrapper):
|
||||
max_deferred_experts_per_token: Optional[int] = None,
|
||||
method: str = "RAWINT4",
|
||||
):
|
||||
if not _HAS_AMX_SUPPORT:
|
||||
raise RuntimeError("AMX backend is not available.")
|
||||
if method == "RAWINT4" and AMXInt4_KGroup_MOE is None:
|
||||
raise RuntimeError("AMX backend with RAWINT4 support is not available.")
|
||||
if method == "FP8" and AMXFP8_MOE is None:
|
||||
raise RuntimeError("AMX backend with FP8 support is not available.")
|
||||
if method == "RAWINT4" and not _HAS_RAWINT4_SUPPORT:
|
||||
raise RuntimeError(
|
||||
"RAWINT4 backend not available. Required ISA:\n"
|
||||
" - AVX512F + AVX512BW (VNNI optional)\n"
|
||||
"Please recompile kt_kernel_ext with AVX512 enabled."
|
||||
)
|
||||
if method == "FP8" and not _HAS_FP8_SUPPORT:
|
||||
raise RuntimeError(
|
||||
"FP8 backend not available. Required ISA:\n"
|
||||
" - AVX512F + AVX512BW + AVX512_BF16 + AVX512_VBMI\n"
|
||||
"Please recompile kt_kernel_ext with AVX512 + BF16 + VBMI enabled."
|
||||
)
|
||||
if method == "FP8_PERCHANNEL" and not _HAS_FP8_PERCHANNEL_SUPPORT:
|
||||
raise RuntimeError("AMX backend with FP8 per-channel support is not available.")
|
||||
if method == "BF16" and AMXBF16_MOE is None:
|
||||
raise RuntimeError("AMX backend with BF16 support is not available.")
|
||||
raise RuntimeError(
|
||||
"FP8_PERCHANNEL backend not available. Required ISA:\n"
|
||||
" - AVX512F + AVX512BW + AVX512_BF16 + AVX512_VBMI\n"
|
||||
"Please recompile kt_kernel_ext with AVX512 + BF16 + VBMI enabled."
|
||||
)
|
||||
if method == "BF16" and not _HAS_BF16_SUPPORT:
|
||||
raise RuntimeError(
|
||||
"BF16 backend not available. Required ISA:\n"
|
||||
" - AVX512F + AVX512BW + AVX512_BF16\n"
|
||||
"Please recompile kt_kernel_ext with AVX512 + BF16 enabled."
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
layer_idx=layer_idx,
|
||||
|
||||
@@ -552,9 +552,10 @@ class CMakeBuild(build_ext):
|
||||
# These are passed to CMake to conditionally add compiler flags
|
||||
# Track if any AVX512 extension is enabled
|
||||
avx512_extension_enabled = False
|
||||
allow_avx512_ext_auto = cpu_mode in ("NATIVE", "FANCY", "AVX512")
|
||||
|
||||
if not _forward_bool_env(cmake_args, "CPUINFER_ENABLE_AVX512_VNNI", "LLAMA_AVX512_VNNI"):
|
||||
if "AVX512_VNNI" in d["features"]:
|
||||
if allow_avx512_ext_auto and "AVX512_VNNI" in d["features"]:
|
||||
cmake_args.append("-DLLAMA_AVX512_VNNI=ON")
|
||||
print("-- AVX512_VNNI detected; enabling (-DLLAMA_AVX512_VNNI=ON)")
|
||||
avx512_extension_enabled = True
|
||||
@@ -562,7 +563,7 @@ class CMakeBuild(build_ext):
|
||||
avx512_extension_enabled = True
|
||||
|
||||
if not _forward_bool_env(cmake_args, "CPUINFER_ENABLE_AVX512_BF16", "LLAMA_AVX512_BF16"):
|
||||
if "AVX512_BF16" in d["features"]:
|
||||
if allow_avx512_ext_auto and "AVX512_BF16" in d["features"]:
|
||||
cmake_args.append("-DLLAMA_AVX512_BF16=ON")
|
||||
print("-- AVX512_BF16 detected; enabling (-DLLAMA_AVX512_BF16=ON)")
|
||||
avx512_extension_enabled = True
|
||||
@@ -570,7 +571,7 @@ class CMakeBuild(build_ext):
|
||||
avx512_extension_enabled = True
|
||||
|
||||
if not _forward_bool_env(cmake_args, "CPUINFER_ENABLE_AVX512_VBMI", "LLAMA_AVX512_VBMI"):
|
||||
if "AVX512_VBMI" in d["features"]:
|
||||
if allow_avx512_ext_auto and "AVX512_VBMI" in d["features"]:
|
||||
cmake_args.append("-DLLAMA_AVX512_VBMI=ON")
|
||||
print("-- AVX512_VBMI detected; enabling (-DLLAMA_AVX512_VBMI=ON)")
|
||||
avx512_extension_enabled = True
|
||||
@@ -578,7 +579,7 @@ class CMakeBuild(build_ext):
|
||||
avx512_extension_enabled = True
|
||||
|
||||
# If any AVX512 extension is enabled, ensure base AVX512 is also enabled
|
||||
if avx512_extension_enabled and cpu_mode == "NATIVE":
|
||||
if avx512_extension_enabled and cpu_mode in ("NATIVE", "FANCY", "AVX512"):
|
||||
if not any("LLAMA_AVX512=ON" in a for a in cmake_args):
|
||||
cmake_args.append("-DLLAMA_AVX512=ON")
|
||||
print("-- AVX512 extensions enabled; also enabling base AVX512F (-DLLAMA_AVX512=ON)")
|
||||
|
||||
Reference in New Issue
Block a user