[feat](kt-kernel): Fix CPU instruction set variants for build & install (#1746)

* [feat]: Enhance CPU feature detection and support for AVX512 extensions

- Added cmake/DetectCPU.cmake for automatic CPU feature detection.
- Updated CMakeLists.txt to include auto-detection logic for AVX512 features.
- Modified install.sh to include new AVX512_VBMI option for FP8 MoE.
- Enhanced _cpu_detect.py to support progressive matching of CPU variants.
- Created scripts/check_cpu_features.py for manual CPU feature checks.
- Updated setup.py to reflect changes in CPU variant building and environment variables.

* [fix](kt-kernel): Add conditional inclusion of FP8 MoE for AVX512 BF16 support

* [chore](kt-kernel): update project version to 0.5.0 in CMakeLists.txt and version.py
This commit is contained in:
Jiaqi Liao
2025-12-24 18:57:45 +08:00
committed by GitHub
parent dc5feece8f
commit 46b0f36980
7 changed files with 570 additions and 114 deletions

View File

@@ -25,19 +25,24 @@ from pathlib import Path
def detect_cpu_features():
"""
Detect CPU features to determine the best kernel variant.
Detect CPU features and determine the best kernel variant using progressive matching.
Detection hierarchy:
1. AMX: Intel Sapphire Rapids+ with AMX support
2. AVX512: CPUs with AVX512F support
3. AVX2: Fallback for maximum compatibility
Progressive variant hierarchy (from most to least advanced):
1. AMX: amx_tile, amx_int8, amx_bf16 + avx512f, avx512bw, avx512_vnni, avx512_vbmi, avx512_bf16
2. AVX512_BF16: avx512f, avx512bw, avx512_vnni, avx512_vbmi, avx512_bf16
3. AVX512_VBMI: avx512f, avx512bw, avx512_vnni, avx512_vbmi
4. AVX512_VNNI: avx512f, avx512bw, avx512_vnni
5. AVX512_BASE: avx512f, avx512bw
6. AVX2: avx2 (fallback)
Returns:
str: 'amx', 'avx512', or 'avx2'
str: Variant name - one of: 'amx', 'avx512_bf16', 'avx512_vbmi',
'avx512_vnni', 'avx512_base', 'avx2'
"""
# Check environment override
variant = os.environ.get("KT_KERNEL_CPU_VARIANT", "").lower()
if variant in ["amx", "avx512", "avx2"]:
valid_variants = ["amx", "avx512_bf16", "avx512_vbmi", "avx512_vnni", "avx512_base", "avx2"]
if variant in valid_variants:
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print(f"[kt-kernel] Using environment override: {variant}")
return variant
@@ -47,32 +52,57 @@ def detect_cpu_features():
with open("/proc/cpuinfo", "r") as f:
cpuinfo = f.read().lower()
# Check for AMX support (Intel Sapphire Rapids+)
# AMX requires amx_tile, amx_int8, and amx_bf16
amx_flags = ["amx_tile", "amx_int8", "amx_bf16"]
has_amx = all(flag in cpuinfo for flag in amx_flags)
# Extract CPU flags into a set for fast lookup
cpu_flags = set()
for line in cpuinfo.split("\n"):
if line.startswith("flags"):
flags_str = line.split(":", 1)[1]
cpu_flags = set(flags_str.split())
break
if has_amx:
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] Detected AMX support via /proc/cpuinfo")
return "amx"
# Define variant requirements in priority order (best to worst)
variant_requirements = [
(
"amx",
[
"amx_tile",
"amx_int8",
"amx_bf16",
"avx512f",
"avx512bw",
"avx512_vnni",
"avx512_vbmi",
"avx512_bf16",
],
),
("avx512_bf16", ["avx512f", "avx512bw", "avx512_vnni", "avx512_vbmi", "avx512_bf16"]),
("avx512_vbmi", ["avx512f", "avx512bw", "avx512_vnni", "avx512_vbmi"]),
("avx512_vnni", ["avx512f", "avx512bw", "avx512_vnni"]),
("avx512_base", ["avx512f", "avx512bw"]),
("avx2", ["avx2"]),
]
# Check for AVX512 support
# AVX512F is the foundation for all AVX512 variants
if "avx512f" in cpuinfo:
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] Detected AVX512 support via /proc/cpuinfo")
return "avx512"
# Find the best matching variant
for variant_name, required_flags in variant_requirements:
# Check if all required flags are present
# Handle flag name variations (e.g., avx512_bf16 vs avx512bf16)
has_all_flags = True
for flag in required_flags:
# Try exact match first, then without underscore
flag_alt = flag.replace("_", "")
if flag not in cpu_flags and flag_alt not in cpu_flags:
has_all_flags = False
break
# Check for AVX2 support
if "avx2" in cpuinfo:
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] Detected AVX2 support via /proc/cpuinfo")
return "avx2"
if has_all_flags:
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print(f"[kt-kernel] Detected {variant_name} support via /proc/cpuinfo")
print(f"[kt-kernel] Matched flags: {', '.join(required_flags)}")
return variant_name
# Fallback to AVX2 (should be rare on modern CPUs)
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] No AVX2/AVX512/AMX detected, using AVX2 fallback")
print("[kt-kernel] No supported features detected, using AVX2 fallback")
return "avx2"
except FileNotFoundError:
@@ -84,17 +114,35 @@ def detect_cpu_features():
try:
import cpufeature
# Check for AMX
if cpufeature.CPUFeature.get("AMX_TILE", False):
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] Detected AMX support via cpufeature")
return "amx"
# Define variant requirements in priority order (using cpufeature naming)
cpufeature_requirements = [
(
"amx",
[
"AMX_TILE",
"AMX_INT8",
"AMX_BF16",
"AVX512F",
"AVX512BW",
"AVX512_VNNI",
"AVX512_VBMI",
"AVX512_BF16",
],
),
("avx512_bf16", ["AVX512F", "AVX512BW", "AVX512_VNNI", "AVX512_VBMI", "AVX512_BF16"]),
("avx512_vbmi", ["AVX512F", "AVX512BW", "AVX512_VNNI", "AVX512_VBMI"]),
("avx512_vnni", ["AVX512F", "AVX512BW", "AVX512_VNNI"]),
("avx512_base", ["AVX512F", "AVX512BW"]),
("avx2", ["AVX2"]),
]
# Check for AVX512
if cpufeature.CPUFeature.get("AVX512F", False):
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] Detected AVX512 support via cpufeature")
return "avx512"
# Find the best matching variant
for variant_name, required_features in cpufeature_requirements:
has_all_features = all(cpufeature.CPUFeature.get(feat, False) for feat in required_features)
if has_all_features:
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print(f"[kt-kernel] Detected {variant_name} support via cpufeature")
return variant_name
# Fallback to AVX2
if os.environ.get("KT_KERNEL_DEBUG") == "1":
@@ -124,10 +172,11 @@ def load_extension(variant):
Supports both multi-variant builds (_kt_kernel_ext_amx.*.so) and
single-variant builds (kt_kernel_ext.*.so).
Fallback order: amx -> avx512 -> avx2 -> single-variant
Fallback chain (each variant falls back to the next in line):
amx -> avx512_bf16 -> avx512_vbmi -> avx512_vnni -> avx512_base -> avx2 -> single-variant
Args:
variant (str): 'amx', 'avx512', or 'avx2'
variant (str): One of 'amx', 'avx512_bf16', 'avx512_vbmi', 'avx512_vnni', 'avx512_base', 'avx2'
Returns:
module: The loaded extension module
@@ -187,15 +236,24 @@ def load_extension(variant):
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print(f"[kt-kernel] Failed to load {variant} variant: {e}")
# Automatic fallback to next best variant
if variant == "amx":
# Define fallback chain: each variant falls back to the next lower one
fallback_chain = {
"amx": "avx512_bf16",
"avx512_bf16": "avx512_vbmi",
"avx512_vbmi": "avx512_vnni",
"avx512_vnni": "avx512_base",
"avx512_base": "avx2",
"avx2": None, # No fallback - terminal variant
}
# Get next fallback variant
next_variant = fallback_chain.get(variant)
if next_variant:
# Try next variant in the chain
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] Falling back from AMX to AVX512")
return load_extension("avx512")
elif variant == "avx512":
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] Falling back from AVX512 to AVX2")
return load_extension("avx2")
print(f"[kt-kernel] Falling back from {variant} to {next_variant}")
return load_extension(next_variant)
else:
# AVX2 is the last fallback - if this fails, we can't continue
raise ImportError(