mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-20 14:29:22 +00:00
[feat](kt-kernel): Fix CPU instruction set variants for build & install (#1746)
* [feat]: Enhance CPU feature detection and support for AVX512 extensions - Added cmake/DetectCPU.cmake for automatic CPU feature detection. - Updated CMakeLists.txt to include auto-detection logic for AVX512 features. - Modified install.sh to include new AVX512_VBMI option for FP8 MoE. - Enhanced _cpu_detect.py to support progressive matching of CPU variants. - Created scripts/check_cpu_features.py for manual CPU feature checks. - Updated setup.py to reflect changes in CPU variant building and environment variables. * [fix](kt-kernel): Add conditional inclusion of FP8 MoE for AVX512 BF16 support * [chore](kt-kernel): update project version to 0.5.0 in CMakeLists.txt and version.py
This commit is contained in:
@@ -25,19 +25,24 @@ from pathlib import Path
|
||||
|
||||
def detect_cpu_features():
|
||||
"""
|
||||
Detect CPU features to determine the best kernel variant.
|
||||
Detect CPU features and determine the best kernel variant using progressive matching.
|
||||
|
||||
Detection hierarchy:
|
||||
1. AMX: Intel Sapphire Rapids+ with AMX support
|
||||
2. AVX512: CPUs with AVX512F support
|
||||
3. AVX2: Fallback for maximum compatibility
|
||||
Progressive variant hierarchy (from most to least advanced):
|
||||
1. AMX: amx_tile, amx_int8, amx_bf16 + full AVX512
|
||||
2. AVX512_BF16: avx512f, avx512bw, avx512_vnni, avx512_vbmi, avx512_bf16
|
||||
3. AVX512_VBMI: avx512f, avx512bw, avx512_vnni, avx512_vbmi
|
||||
4. AVX512_VNNI: avx512f, avx512bw, avx512_vnni
|
||||
5. AVX512_BASE: avx512f, avx512bw
|
||||
6. AVX2: avx2 (fallback)
|
||||
|
||||
Returns:
|
||||
str: 'amx', 'avx512', or 'avx2'
|
||||
str: Variant name - one of: 'amx', 'avx512_bf16', 'avx512_vbmi',
|
||||
'avx512_vnni', 'avx512_base', 'avx2'
|
||||
"""
|
||||
# Check environment override
|
||||
variant = os.environ.get("KT_KERNEL_CPU_VARIANT", "").lower()
|
||||
if variant in ["amx", "avx512", "avx2"]:
|
||||
valid_variants = ["amx", "avx512_bf16", "avx512_vbmi", "avx512_vnni", "avx512_base", "avx2"]
|
||||
if variant in valid_variants:
|
||||
if os.environ.get("KT_KERNEL_DEBUG") == "1":
|
||||
print(f"[kt-kernel] Using environment override: {variant}")
|
||||
return variant
|
||||
@@ -47,32 +52,57 @@ def detect_cpu_features():
|
||||
with open("/proc/cpuinfo", "r") as f:
|
||||
cpuinfo = f.read().lower()
|
||||
|
||||
# Check for AMX support (Intel Sapphire Rapids+)
|
||||
# AMX requires amx_tile, amx_int8, and amx_bf16
|
||||
amx_flags = ["amx_tile", "amx_int8", "amx_bf16"]
|
||||
has_amx = all(flag in cpuinfo for flag in amx_flags)
|
||||
# Extract CPU flags into a set for fast lookup
|
||||
cpu_flags = set()
|
||||
for line in cpuinfo.split("\n"):
|
||||
if line.startswith("flags"):
|
||||
flags_str = line.split(":", 1)[1]
|
||||
cpu_flags = set(flags_str.split())
|
||||
break
|
||||
|
||||
if has_amx:
|
||||
if os.environ.get("KT_KERNEL_DEBUG") == "1":
|
||||
print("[kt-kernel] Detected AMX support via /proc/cpuinfo")
|
||||
return "amx"
|
||||
# Define variant requirements in priority order (best to worst)
|
||||
variant_requirements = [
|
||||
(
|
||||
"amx",
|
||||
[
|
||||
"amx_tile",
|
||||
"amx_int8",
|
||||
"amx_bf16",
|
||||
"avx512f",
|
||||
"avx512bw",
|
||||
"avx512_vnni",
|
||||
"avx512_vbmi",
|
||||
"avx512_bf16",
|
||||
],
|
||||
),
|
||||
("avx512_bf16", ["avx512f", "avx512bw", "avx512_vnni", "avx512_vbmi", "avx512_bf16"]),
|
||||
("avx512_vbmi", ["avx512f", "avx512bw", "avx512_vnni", "avx512_vbmi"]),
|
||||
("avx512_vnni", ["avx512f", "avx512bw", "avx512_vnni"]),
|
||||
("avx512_base", ["avx512f", "avx512bw"]),
|
||||
("avx2", ["avx2"]),
|
||||
]
|
||||
|
||||
# Check for AVX512 support
|
||||
# AVX512F is the foundation for all AVX512 variants
|
||||
if "avx512f" in cpuinfo:
|
||||
if os.environ.get("KT_KERNEL_DEBUG") == "1":
|
||||
print("[kt-kernel] Detected AVX512 support via /proc/cpuinfo")
|
||||
return "avx512"
|
||||
# Find the best matching variant
|
||||
for variant_name, required_flags in variant_requirements:
|
||||
# Check if all required flags are present
|
||||
# Handle flag name variations (e.g., avx512_bf16 vs avx512bf16)
|
||||
has_all_flags = True
|
||||
for flag in required_flags:
|
||||
# Try exact match first, then without underscore
|
||||
flag_alt = flag.replace("_", "")
|
||||
if flag not in cpu_flags and flag_alt not in cpu_flags:
|
||||
has_all_flags = False
|
||||
break
|
||||
|
||||
# Check for AVX2 support
|
||||
if "avx2" in cpuinfo:
|
||||
if os.environ.get("KT_KERNEL_DEBUG") == "1":
|
||||
print("[kt-kernel] Detected AVX2 support via /proc/cpuinfo")
|
||||
return "avx2"
|
||||
if has_all_flags:
|
||||
if os.environ.get("KT_KERNEL_DEBUG") == "1":
|
||||
print(f"[kt-kernel] Detected {variant_name} support via /proc/cpuinfo")
|
||||
print(f"[kt-kernel] Matched flags: {', '.join(required_flags)}")
|
||||
return variant_name
|
||||
|
||||
# Fallback to AVX2 (should be rare on modern CPUs)
|
||||
if os.environ.get("KT_KERNEL_DEBUG") == "1":
|
||||
print("[kt-kernel] No AVX2/AVX512/AMX detected, using AVX2 fallback")
|
||||
print("[kt-kernel] No supported features detected, using AVX2 fallback")
|
||||
return "avx2"
|
||||
|
||||
except FileNotFoundError:
|
||||
@@ -84,17 +114,35 @@ def detect_cpu_features():
|
||||
try:
|
||||
import cpufeature
|
||||
|
||||
# Check for AMX
|
||||
if cpufeature.CPUFeature.get("AMX_TILE", False):
|
||||
if os.environ.get("KT_KERNEL_DEBUG") == "1":
|
||||
print("[kt-kernel] Detected AMX support via cpufeature")
|
||||
return "amx"
|
||||
# Define variant requirements in priority order (using cpufeature naming)
|
||||
cpufeature_requirements = [
|
||||
(
|
||||
"amx",
|
||||
[
|
||||
"AMX_TILE",
|
||||
"AMX_INT8",
|
||||
"AMX_BF16",
|
||||
"AVX512F",
|
||||
"AVX512BW",
|
||||
"AVX512_VNNI",
|
||||
"AVX512_VBMI",
|
||||
"AVX512_BF16",
|
||||
],
|
||||
),
|
||||
("avx512_bf16", ["AVX512F", "AVX512BW", "AVX512_VNNI", "AVX512_VBMI", "AVX512_BF16"]),
|
||||
("avx512_vbmi", ["AVX512F", "AVX512BW", "AVX512_VNNI", "AVX512_VBMI"]),
|
||||
("avx512_vnni", ["AVX512F", "AVX512BW", "AVX512_VNNI"]),
|
||||
("avx512_base", ["AVX512F", "AVX512BW"]),
|
||||
("avx2", ["AVX2"]),
|
||||
]
|
||||
|
||||
# Check for AVX512
|
||||
if cpufeature.CPUFeature.get("AVX512F", False):
|
||||
if os.environ.get("KT_KERNEL_DEBUG") == "1":
|
||||
print("[kt-kernel] Detected AVX512 support via cpufeature")
|
||||
return "avx512"
|
||||
# Find the best matching variant
|
||||
for variant_name, required_features in cpufeature_requirements:
|
||||
has_all_features = all(cpufeature.CPUFeature.get(feat, False) for feat in required_features)
|
||||
if has_all_features:
|
||||
if os.environ.get("KT_KERNEL_DEBUG") == "1":
|
||||
print(f"[kt-kernel] Detected {variant_name} support via cpufeature")
|
||||
return variant_name
|
||||
|
||||
# Fallback to AVX2
|
||||
if os.environ.get("KT_KERNEL_DEBUG") == "1":
|
||||
@@ -124,10 +172,11 @@ def load_extension(variant):
|
||||
Supports both multi-variant builds (_kt_kernel_ext_amx.*.so) and
|
||||
single-variant builds (kt_kernel_ext.*.so).
|
||||
|
||||
Fallback order: amx -> avx512 -> avx2 -> single-variant
|
||||
Fallback chain (each variant falls back to the next in line):
|
||||
amx -> avx512_bf16 -> avx512_vbmi -> avx512_vnni -> avx512_base -> avx2 -> single-variant
|
||||
|
||||
Args:
|
||||
variant (str): 'amx', 'avx512', or 'avx2'
|
||||
variant (str): One of 'amx', 'avx512_bf16', 'avx512_vbmi', 'avx512_vnni', 'avx512_base', 'avx2'
|
||||
|
||||
Returns:
|
||||
module: The loaded extension module
|
||||
@@ -187,15 +236,24 @@ def load_extension(variant):
|
||||
if os.environ.get("KT_KERNEL_DEBUG") == "1":
|
||||
print(f"[kt-kernel] Failed to load {variant} variant: {e}")
|
||||
|
||||
# Automatic fallback to next best variant
|
||||
if variant == "amx":
|
||||
# Define fallback chain: each variant falls back to the next lower one
|
||||
fallback_chain = {
|
||||
"amx": "avx512_bf16",
|
||||
"avx512_bf16": "avx512_vbmi",
|
||||
"avx512_vbmi": "avx512_vnni",
|
||||
"avx512_vnni": "avx512_base",
|
||||
"avx512_base": "avx2",
|
||||
"avx2": None, # No fallback - terminal variant
|
||||
}
|
||||
|
||||
# Get next fallback variant
|
||||
next_variant = fallback_chain.get(variant)
|
||||
|
||||
if next_variant:
|
||||
# Try next variant in the chain
|
||||
if os.environ.get("KT_KERNEL_DEBUG") == "1":
|
||||
print("[kt-kernel] Falling back from AMX to AVX512")
|
||||
return load_extension("avx512")
|
||||
elif variant == "avx512":
|
||||
if os.environ.get("KT_KERNEL_DEBUG") == "1":
|
||||
print("[kt-kernel] Falling back from AVX512 to AVX2")
|
||||
return load_extension("avx2")
|
||||
print(f"[kt-kernel] Falling back from {variant} to {next_variant}")
|
||||
return load_extension(next_variant)
|
||||
else:
|
||||
# AVX2 is the last fallback - if this fails, we can't continue
|
||||
raise ImportError(
|
||||
|
||||
Reference in New Issue
Block a user