mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-03-14 18:37:23 +00:00
[feat](kt-kernel): Fix CPU instruction set variants for build & install (#1746)
* [feat]: Enhance CPU feature detection and support for AVX512 extensions
  - Added cmake/DetectCPU.cmake for automatic CPU feature detection.
  - Updated CMakeLists.txt to include auto-detection logic for AVX512 features.
  - Modified install.sh to include new AVX512_VBMI option for FP8 MoE.
  - Enhanced _cpu_detect.py to support progressive matching of CPU variants.
  - Created scripts/check_cpu_features.py for manual CPU feature checks.
  - Updated setup.py to reflect changes in CPU variant building and environment variables.
* [fix](kt-kernel): Add conditional inclusion of FP8 MoE for AVX512 BF16 support
* [chore](kt-kernel): update project version to 0.5.0 in CMakeLists.txt and version.py
This commit is contained in:
@@ -5,10 +5,8 @@ option(USE_CONDA_TOOLCHAIN "Use C/C++ compilers and libraries from active conda
|
||||
option(LLAMA_NATIVE "llama: enable -march=native flag" OFF)
|
||||
option(LLAMA_AVX "llama: enable AVX" OFF)
|
||||
option(LLAMA_AVX2 "llama: enable AVX2" OFF)
|
||||
option(LLAMA_AVX512 "llama: enable AVX512" OFF)
|
||||
option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF)
|
||||
option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF)
|
||||
option(LLAMA_AVX512_BF16 "llama: enable AVX512-BF16" OFF)
|
||||
# AVX512 options will be auto-detected by cmake/DetectCPU.cmake
|
||||
# Users can override with -DLLAMA_AVX512=OFF etc.
|
||||
option(LLAMA_FMA "llama: enable FMA" OFF)
|
||||
# in MSVC F16C is implied with AVX2/AVX512
|
||||
if(NOT MSVC)
|
||||
@@ -28,7 +26,13 @@ option(KTRANSFORMERS_CPU_MOE_AMD "ktransformers: CPU use moe kernel for amd" OFF
|
||||
# LTO control
|
||||
option(CPUINFER_ENABLE_LTO "Enable link time optimization (IPO)" OFF)
|
||||
|
||||
project(kt_kernel_ext VERSION 0.4.4)
|
||||
project(kt_kernel_ext VERSION 0.5.0)
|
||||
|
||||
# Auto-detect CPU features early (unless building with LLAMA_NATIVE)
|
||||
if(NOT LLAMA_NATIVE AND NOT MSVC)
|
||||
include(cmake/DetectCPU.cmake)
|
||||
endif()
|
||||
|
||||
# Choose compilers BEFORE project() so CMake honors them
|
||||
if(USE_CONDA_TOOLCHAIN)
|
||||
if(NOT DEFINED ENV{CONDA_PREFIX} OR NOT EXISTS "$ENV{CONDA_PREFIX}")
|
||||
@@ -378,20 +382,14 @@ if(HOST_IS_X86)
|
||||
target_link_libraries(${test_name} llama OpenMP::OpenMP_CXX numa)
|
||||
endforeach()
|
||||
endif()
|
||||
# Note: AVX512 subset flags (-mavx512vnni, -mavx512bf16) are already added
|
||||
# in the generic x86 detection block above (lines 276-289) when corresponding
|
||||
# LLAMA_AVX512_* options are enabled. No need to add them again here.
|
||||
# -mfma is already added by LLAMA_NATIVE (line 254), LLAMA_AVX*, or LLAMA_FMA blocks.
|
||||
|
||||
# AVX512 extensions are auto-detected by cmake/DetectCPU.cmake
|
||||
# Users can override with -DLLAMA_AVX512_BF16=OFF etc.
|
||||
# Only add -mf16c if LLAMA_F16C is not already enabled.
|
||||
if(NOT LLAMA_F16C)
|
||||
list(APPEND ARCH_FLAGS -mf16c)
|
||||
endif()
|
||||
if(LLAMA_AVX512_VNNI)
|
||||
message(STATUS "AVX512_VNNI enabled")
|
||||
endif()
|
||||
if(LLAMA_AVX512_BF16)
|
||||
message(STATUS "AVX512_BF16 enabled")
|
||||
endif()
|
||||
message(STATUS "AVX512 extensions: F=${LLAMA_AVX512}, BF16=${LLAMA_AVX512_BF16}, VNNI=${LLAMA_AVX512_VNNI}, VBMI=${LLAMA_AVX512_VBMI}")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
142
kt-kernel/cmake/DetectCPU.cmake
Normal file
142
kt-kernel/cmake/DetectCPU.cmake
Normal file
@@ -0,0 +1,142 @@
|
||||
# CPU Feature Detection for kt-kernel
|
||||
# Detects CPU capabilities and sets appropriate compiler flags
|
||||
|
||||
# detect_cpu_features()
#
# Probe the host CPU by parsing /proc/cpuinfo. Takes no arguments.
#
# Results are published to the caller's scope:
#   HAS_AVX2, HAS_AVX512F, HAS_AVX512_VNNI,
#   HAS_AVX512_BF16, HAS_AVX512_VBMI, HAS_AMX   - ON / OFF
#   CPU_MODEL                                   - text after "model name :"
#                                                 (only set when such a line is found)
#
# If /proc/cpuinfo does not exist, or no "flags : ..." text can be matched in
# it, a STATUS message is printed, every HAS_* result stays OFF and the
# function returns early (CPU_MODEL is not set in that case).
function(detect_cpu_features)
  # Publish "not detected" first so the results are defined on every exit path.
  foreach(result_name IN ITEMS
      HAS_AVX2 HAS_AVX512F HAS_AVX512_VNNI HAS_AVX512_BF16 HAS_AVX512_VBMI HAS_AMX)
    set(${result_name} OFF PARENT_SCOPE)
  endforeach()

  if(NOT EXISTS "/proc/cpuinfo")
    message(STATUS "CPU detection: /proc/cpuinfo not found, skipping auto-detection")
    return()
  endif()

  # Capture group 1 is the remainder of the line after "flags :" for the first
  # place in the file where the pattern matches.
  file(READ "/proc/cpuinfo" cpuinfo_text)
  string(REGEX MATCH "flags[ \t]*:[ \t]*([^\n]*)" flags_match "${cpuinfo_text}")
  if(NOT CMAKE_MATCH_1)
    message(STATUS "CPU detection: Could not parse CPU flags")
    return()
  endif()

  # Turn the space-separated flag string into a CMake list for IN_LIST lookups.
  string(REPLACE " " ";" cpu_flag_list "${CMAKE_MATCH_1}")

  if("avx2" IN_LIST cpu_flag_list)
    set(HAS_AVX2 ON PARENT_SCOPE)
  endif()

  if("avx512f" IN_LIST cpu_flag_list)
    set(HAS_AVX512F ON PARENT_SCOPE)
  endif()

  # Each AVX512 extension is accepted in either spelling: avx512_<ext> or avx512<ext>.
  foreach(extension IN ITEMS vnni bf16 vbmi)
    if("avx512_${extension}" IN_LIST cpu_flag_list OR "avx512${extension}" IN_LIST cpu_flag_list)
      string(TOUPPER "${extension}" extension_upper)
      set(HAS_AVX512_${extension_upper} ON PARENT_SCOPE)
    endif()
  endforeach()

  # AMX is reported only when all three AMX flags are present.
  set(amx_complete ON)
  foreach(amx_flag IN ITEMS amx_tile amx_int8 amx_bf16)
    if(NOT "${amx_flag}" IN_LIST cpu_flag_list)
      set(amx_complete OFF)
    endif()
  endforeach()
  if(amx_complete)
    set(HAS_AMX ON PARENT_SCOPE)
  endif()

  # CPU model name, used by the caller for display only.
  string(REGEX MATCH "model name[ \t]*:[ \t]*([^\n]*)" model_match "${cpuinfo_text}")
  if(CMAKE_MATCH_1)
    set(CPU_MODEL "${CMAKE_MATCH_1}" PARENT_SCOPE)
  endif()
endfunction()
|
||||
|
||||
# Main detection and configuration.
#
# Runs when this file is include()d. It prints a banner, then either
#   (a) keeps AVX512 extension settings that are already defined and returns, or
#   (b) calls detect_cpu_features() and turns on, as cache BOOLs, every feature
#       option that the CPU supports and that is not yet defined.
message(STATUS "")
message(STATUS "========================================")
message(STATUS "CPU Feature Detection (CMake)")
message(STATUS "========================================")

# Check if variables were already set by install.sh/setup.py.
# Any one of the three AVX512 extension options being defined is enough.
set(FROM_INSTALL_SH OFF)
foreach(preset_option IN ITEMS LLAMA_AVX512_VNNI LLAMA_AVX512_BF16 LLAMA_AVX512_VBMI)
  if(DEFINED ${preset_option})
    set(FROM_INSTALL_SH ON)
  endif()
endforeach()

if(FROM_INSTALL_SH)
  message(STATUS "Detected configuration from install.sh/setup.py")
  foreach(preset_option IN ITEMS
      LLAMA_AVX512 LLAMA_AVX512_VNNI LLAMA_AVX512_BF16 LLAMA_AVX512_VBMI)
    message(STATUS " ${preset_option}: ${${preset_option}}")
  endforeach()
  message(STATUS "")
  message(STATUS "Skipping auto-detection (using install.sh settings)")
  message(STATUS "========================================")
  message(STATUS "")
  # Leaves this included file only; the including script continues.
  return()
endif()

# Detect CPU features (only if not set by install.sh)
detect_cpu_features()

if(CPU_MODEL)
  message(STATUS "CPU Model: ${CPU_MODEL}")
endif()

message(STATUS "")
message(STATUS "Detected features:")
foreach(feature_label IN ITEMS AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX512_VBMI AMX)
  message(STATUS " ${feature_label}: ${HAS_${feature_label}}")
endforeach()
message(STATUS "")

# Auto-enable features based on detection.
# An option is only switched on when the CPU reports the feature AND the option
# is not defined yet, so explicit -D values are left untouched.
# NOTE(review): an option() declared before this file is included also counts
# as DEFINED, so such an option would never be auto-enabled here - confirm
# against the including CMakeLists.txt.
if(HAS_AVX2 AND NOT DEFINED LLAMA_AVX2)
  set(LLAMA_AVX2 ON CACHE BOOL "Enable AVX2" FORCE)
  message(STATUS "Auto-enabled: AVX2")
endif()

if(HAS_AVX512F AND NOT DEFINED LLAMA_AVX512)
  set(LLAMA_AVX512 ON CACHE BOOL "Enable AVX512F" FORCE)
  message(STATUS "Auto-enabled: AVX512F")
endif()

if(HAS_AVX512_VNNI AND NOT DEFINED LLAMA_AVX512_VNNI)
  set(LLAMA_AVX512_VNNI ON CACHE BOOL "Enable AVX512_VNNI" FORCE)
  message(STATUS "Auto-enabled: AVX512_VNNI")
endif()

if(HAS_AVX512_BF16 AND NOT DEFINED LLAMA_AVX512_BF16)
  set(LLAMA_AVX512_BF16 ON CACHE BOOL "Enable AVX512_BF16" FORCE)
  message(STATUS "Auto-enabled: AVX512_BF16")
endif()

if(HAS_AVX512_VBMI AND NOT DEFINED LLAMA_AVX512_VBMI)
  set(LLAMA_AVX512_VBMI ON CACHE BOOL "Enable AVX512_VBMI" FORCE)
  message(STATUS "Auto-enabled: AVX512_VBMI")
endif()

if(HAS_AMX AND NOT DEFINED KTRANSFORMERS_CPU_USE_AMX)
  set(KTRANSFORMERS_CPU_USE_AMX ON CACHE BOOL "Enable AMX" FORCE)
  message(STATUS "Auto-enabled: AMX")
endif()

message(STATUS "")
message(STATUS "Note: You can override by passing -DLLAMA_AVX512_BF16=OFF etc.")
message(STATUS "Note: Or use install.sh with environment variables")
message(STATUS "========================================")
message(STATUS "")
|
||||
@@ -36,7 +36,9 @@ static const bool _is_plain_ = false;
|
||||
|
||||
#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
|
||||
#include "operators/amx/awq-moe.hpp"
|
||||
#include "operators/amx/fp8-moe.hpp"
|
||||
#if defined(__AVX512BF16__)
|
||||
#include "operators/amx/fp8-moe.hpp" // FP8 MoE requires AVX512 BF16 support
|
||||
#endif
|
||||
#include "operators/amx/k2-moe.hpp"
|
||||
#include "operators/amx/la/amx_kernels.hpp"
|
||||
#include "operators/amx/moe.hpp"
|
||||
@@ -293,7 +295,9 @@ void bind_moe_module(py::module_& moe_module, const char* name) {
|
||||
py::arg("w2_weight_ptrs"), py::arg("w2_scale_ptrs"));
|
||||
}
|
||||
|
||||
#if defined(__AVX512BF16__)
|
||||
// FP8 MoE: processes one expert at a time (expert_id instead of gpu_experts_num)
|
||||
// Only available on CPUs with AVX512 BF16 support
|
||||
if constexpr (std::is_same_v<MoeTP, AMX_FP8_MOE_TP<amx::GemmKernel224FP8>>) {
|
||||
struct WriteWeightScaleToBufferBindings {
|
||||
struct Args {
|
||||
@@ -336,6 +340,7 @@ void bind_moe_module(py::module_& moe_module, const char* name) {
|
||||
py::arg("gpu_tp_count"), py::arg("expert_id"), py::arg("w13_weight_ptrs"), py::arg("w13_scale_ptrs"),
|
||||
py::arg("w2_weight_ptrs"), py::arg("w2_scale_ptrs"));
|
||||
}
|
||||
#endif // __AVX512BF16__
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -607,8 +612,10 @@ PYBIND11_MODULE(kt_kernel_ext, m) {
|
||||
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int4_1>>(moe_module, "AMXInt4_1_MOE");
|
||||
bind_moe_module<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>(moe_module, "AMXInt4_1KGroup_MOE");
|
||||
bind_moe_module<AMX_K2_MOE_TP<amx::GemmKernel224Int4SmallKGroup>>(moe_module, "AMXInt4_KGroup_MOE");
|
||||
#if defined(__AVX512BF16__)
|
||||
bind_moe_module<AMX_FP8_MOE_TP<amx::GemmKernel224FP8>>(moe_module, "AMXFP8_MOE");
|
||||
#endif
|
||||
#endif
|
||||
#if defined(USE_MOE_KERNEL)
|
||||
bind_moe_module<MOE_KERNEL_TP<moe_kernel::GemmKernelInt8, _is_plain_>>(moe_module, "Int8_KERNEL_MOE");
|
||||
#if defined(__aarch64__) && defined(CPU_USE_KML)
|
||||
|
||||
@@ -24,6 +24,7 @@ AUTO-DETECTION (Default):
|
||||
- CPUINFER_ENABLE_AMX = ON/OFF (based on detection)
|
||||
- CPUINFER_ENABLE_AVX512_VNNI = ON/OFF (with fallback if OFF)
|
||||
- CPUINFER_ENABLE_AVX512_BF16 = ON/OFF (with fallback if OFF)
|
||||
- CPUINFER_ENABLE_AVX512_VBMI = ON/OFF (required for FP8 MoE)
|
||||
|
||||
✓ Best performance on YOUR machine
|
||||
✗ Binary may NOT work on different/older CPUs
|
||||
@@ -73,6 +74,7 @@ Optional variables (with defaults):
|
||||
CPUINFER_VERBOSE=1 Verbose build output (0/1)
|
||||
CPUINFER_ENABLE_AVX512_VNNI=ON/OFF Override VNNI detection (auto if unset)
|
||||
CPUINFER_ENABLE_AVX512_BF16=ON/OFF Override BF16 detection (auto if unset)
|
||||
CPUINFER_ENABLE_AVX512_VBMI=ON/OFF Override VBMI detection (auto if unset)
|
||||
|
||||
Software Fallback Support:
|
||||
✓ If VNNI not available: Uses AVX512BW fallback (2-3x slower but works)
|
||||
@@ -144,11 +146,13 @@ install_dependencies() {
|
||||
}
|
||||
|
||||
# Function to detect CPU features
|
||||
# Returns: "has_amx has_avx512_vnni has_avx512_bf16" (space-separated 0/1 values)
|
||||
# Returns: "has_amx has_avx512f has_avx512_vnni has_avx512_bf16 has_avx512_vbmi" (space-separated 0/1 values)
|
||||
detect_cpu_features() {
|
||||
local has_amx=0
|
||||
local has_avx512f=0
|
||||
local has_avx512_vnni=0
|
||||
local has_avx512_bf16=0
|
||||
local has_avx512_vbmi=0
|
||||
|
||||
if [ -f /proc/cpuinfo ]; then
|
||||
local cpu_flags
|
||||
@@ -159,6 +163,11 @@ detect_cpu_features() {
|
||||
has_amx=1
|
||||
fi
|
||||
|
||||
# Check for AVX512F (foundation)
|
||||
if echo "$cpu_flags" | grep -qE "avx512f"; then
|
||||
has_avx512f=1
|
||||
fi
|
||||
|
||||
# Check for AVX512_VNNI support
|
||||
if echo "$cpu_flags" | grep -qE "avx512_vnni|avx512vnni"; then
|
||||
has_avx512_vnni=1
|
||||
@@ -168,14 +177,21 @@ detect_cpu_features() {
|
||||
if echo "$cpu_flags" | grep -qE "avx512_bf16|avx512bf16"; then
|
||||
has_avx512_bf16=1
|
||||
fi
|
||||
|
||||
# Check for AVX512_VBMI support
|
||||
if echo "$cpu_flags" | grep -qE "avx512_vbmi|avx512vbmi"; then
|
||||
has_avx512_vbmi=1
|
||||
fi
|
||||
elif [ "$(uname)" = "Darwin" ]; then
|
||||
# macOS doesn't have AMX (ARM or Intel without AMX)
|
||||
has_amx=0
|
||||
has_avx512f=0
|
||||
has_avx512_vnni=0
|
||||
has_avx512_bf16=0
|
||||
has_avx512_vbmi=0
|
||||
fi
|
||||
|
||||
echo "$has_amx $has_avx512_vnni $has_avx512_bf16"
|
||||
echo "$has_amx $has_avx512f $has_avx512_vnni $has_avx512_bf16 $has_avx512_vbmi"
|
||||
}
|
||||
|
||||
build_step() {
|
||||
@@ -210,11 +226,13 @@ build_step() {
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# detect_cpu_features returns "has_amx has_avx512_vnni has_avx512_bf16"
|
||||
# detect_cpu_features returns "has_amx has_avx512f has_avx512_vnni has_avx512_bf16 has_avx512_vbmi"
|
||||
CPU_FEATURES=$(detect_cpu_features)
|
||||
HAS_AMX=$(echo "$CPU_FEATURES" | cut -d' ' -f1)
|
||||
HAS_AVX512_VNNI=$(echo "$CPU_FEATURES" | cut -d' ' -f2)
|
||||
HAS_AVX512_BF16=$(echo "$CPU_FEATURES" | cut -d' ' -f3)
|
||||
HAS_AVX512F=$(echo "$CPU_FEATURES" | cut -d' ' -f2)
|
||||
HAS_AVX512_VNNI=$(echo "$CPU_FEATURES" | cut -d' ' -f3)
|
||||
HAS_AVX512_BF16=$(echo "$CPU_FEATURES" | cut -d' ' -f4)
|
||||
HAS_AVX512_VBMI=$(echo "$CPU_FEATURES" | cut -d' ' -f5)
|
||||
|
||||
export CPUINFER_CPU_INSTRUCT=NATIVE
|
||||
|
||||
@@ -244,6 +262,13 @@ build_step() {
|
||||
echo ""
|
||||
echo "AVX512 Feature Detection:"
|
||||
|
||||
# AVX512F: Foundation (required for all AVX512 variants)
|
||||
if [ "$HAS_AVX512F" = "1" ]; then
|
||||
echo " AVX512F: ✓ Detected (foundation)"
|
||||
else
|
||||
echo " AVX512F: ✗ Not detected (AVX512 not available)"
|
||||
fi
|
||||
|
||||
# VNNI: Check if user manually set it, otherwise auto-detect
|
||||
if [ -n "${CPUINFER_ENABLE_AVX512_VNNI:-}" ]; then
|
||||
echo " VNNI: User override = $CPUINFER_ENABLE_AVX512_VNNI"
|
||||
@@ -270,9 +295,23 @@ build_step() {
|
||||
fi
|
||||
fi
|
||||
|
||||
# VBMI: Check if user manually set it, otherwise auto-detect
|
||||
if [ -n "${CPUINFER_ENABLE_AVX512_VBMI:-}" ]; then
|
||||
echo " VBMI: User override = $CPUINFER_ENABLE_AVX512_VBMI"
|
||||
else
|
||||
if [ "$HAS_AVX512_VBMI" = "1" ]; then
|
||||
echo " VBMI: ✓ Detected (byte permutation enabled)"
|
||||
export CPUINFER_ENABLE_AVX512_VBMI=ON
|
||||
else
|
||||
echo " VBMI: ✗ Not detected (FP8 MoE may not work)"
|
||||
export CPUINFER_ENABLE_AVX512_VBMI=OFF
|
||||
fi
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo " Note: Software fallbacks ensure all code works on older CPUs"
|
||||
echo " Tip: Override with CPUINFER_ENABLE_AVX512_VNNI/BF16=ON/OFF"
|
||||
echo " Note: FP8 MoE requires AVX512F + BF16 + VNNI + VBMI"
|
||||
echo " Tip: Override with CPUINFER_ENABLE_AVX512_[VNNI|BF16|VBMI]=ON/OFF"
|
||||
|
||||
echo ""
|
||||
echo "To use manual configuration instead, run: $0 build --manual"
|
||||
@@ -357,6 +396,7 @@ echo " CPUINFER_CPU_INSTRUCT = $CPUINFER_CPU_INSTRUCT"
|
||||
echo " CPUINFER_ENABLE_AMX = $CPUINFER_ENABLE_AMX"
|
||||
echo " CPUINFER_ENABLE_AVX512_VNNI = ${CPUINFER_ENABLE_AVX512_VNNI:-AUTO}"
|
||||
echo " CPUINFER_ENABLE_AVX512_BF16 = ${CPUINFER_ENABLE_AVX512_BF16:-AUTO}"
|
||||
echo " CPUINFER_ENABLE_AVX512_VBMI = ${CPUINFER_ENABLE_AVX512_VBMI:-AUTO}"
|
||||
echo " CPUINFER_BUILD_TYPE = $CPUINFER_BUILD_TYPE"
|
||||
echo " CPUINFER_PARALLEL = $CPUINFER_PARALLEL"
|
||||
echo ""
|
||||
|
||||
@@ -25,19 +25,24 @@ from pathlib import Path
|
||||
|
||||
def detect_cpu_features():
|
||||
"""
|
||||
Detect CPU features to determine the best kernel variant.
|
||||
Detect CPU features and determine the best kernel variant using progressive matching.
|
||||
|
||||
Detection hierarchy:
|
||||
1. AMX: Intel Sapphire Rapids+ with AMX support
|
||||
2. AVX512: CPUs with AVX512F support
|
||||
3. AVX2: Fallback for maximum compatibility
|
||||
Progressive variant hierarchy (from most to least advanced):
|
||||
1. AMX: amx_tile, amx_int8, amx_bf16 + full AVX512
|
||||
2. AVX512_BF16: avx512f, avx512bw, avx512_vnni, avx512_vbmi, avx512_bf16
|
||||
3. AVX512_VBMI: avx512f, avx512bw, avx512_vnni, avx512_vbmi
|
||||
4. AVX512_VNNI: avx512f, avx512bw, avx512_vnni
|
||||
5. AVX512_BASE: avx512f, avx512bw
|
||||
6. AVX2: avx2 (fallback)
|
||||
|
||||
Returns:
|
||||
str: 'amx', 'avx512', or 'avx2'
|
||||
str: Variant name - one of: 'amx', 'avx512_bf16', 'avx512_vbmi',
|
||||
'avx512_vnni', 'avx512_base', 'avx2'
|
||||
"""
|
||||
# Check environment override
|
||||
variant = os.environ.get("KT_KERNEL_CPU_VARIANT", "").lower()
|
||||
if variant in ["amx", "avx512", "avx2"]:
|
||||
valid_variants = ["amx", "avx512_bf16", "avx512_vbmi", "avx512_vnni", "avx512_base", "avx2"]
|
||||
if variant in valid_variants:
|
||||
if os.environ.get("KT_KERNEL_DEBUG") == "1":
|
||||
print(f"[kt-kernel] Using environment override: {variant}")
|
||||
return variant
|
||||
@@ -47,32 +52,57 @@ def detect_cpu_features():
|
||||
with open("/proc/cpuinfo", "r") as f:
|
||||
cpuinfo = f.read().lower()
|
||||
|
||||
# Check for AMX support (Intel Sapphire Rapids+)
|
||||
# AMX requires amx_tile, amx_int8, and amx_bf16
|
||||
amx_flags = ["amx_tile", "amx_int8", "amx_bf16"]
|
||||
has_amx = all(flag in cpuinfo for flag in amx_flags)
|
||||
# Extract CPU flags into a set for fast lookup
|
||||
cpu_flags = set()
|
||||
for line in cpuinfo.split("\n"):
|
||||
if line.startswith("flags"):
|
||||
flags_str = line.split(":", 1)[1]
|
||||
cpu_flags = set(flags_str.split())
|
||||
break
|
||||
|
||||
if has_amx:
|
||||
if os.environ.get("KT_KERNEL_DEBUG") == "1":
|
||||
print("[kt-kernel] Detected AMX support via /proc/cpuinfo")
|
||||
return "amx"
|
||||
# Define variant requirements in priority order (best to worst)
|
||||
variant_requirements = [
|
||||
(
|
||||
"amx",
|
||||
[
|
||||
"amx_tile",
|
||||
"amx_int8",
|
||||
"amx_bf16",
|
||||
"avx512f",
|
||||
"avx512bw",
|
||||
"avx512_vnni",
|
||||
"avx512_vbmi",
|
||||
"avx512_bf16",
|
||||
],
|
||||
),
|
||||
("avx512_bf16", ["avx512f", "avx512bw", "avx512_vnni", "avx512_vbmi", "avx512_bf16"]),
|
||||
("avx512_vbmi", ["avx512f", "avx512bw", "avx512_vnni", "avx512_vbmi"]),
|
||||
("avx512_vnni", ["avx512f", "avx512bw", "avx512_vnni"]),
|
||||
("avx512_base", ["avx512f", "avx512bw"]),
|
||||
("avx2", ["avx2"]),
|
||||
]
|
||||
|
||||
# Check for AVX512 support
|
||||
# AVX512F is the foundation for all AVX512 variants
|
||||
if "avx512f" in cpuinfo:
|
||||
if os.environ.get("KT_KERNEL_DEBUG") == "1":
|
||||
print("[kt-kernel] Detected AVX512 support via /proc/cpuinfo")
|
||||
return "avx512"
|
||||
# Find the best matching variant
|
||||
for variant_name, required_flags in variant_requirements:
|
||||
# Check if all required flags are present
|
||||
# Handle flag name variations (e.g., avx512_bf16 vs avx512bf16)
|
||||
has_all_flags = True
|
||||
for flag in required_flags:
|
||||
# Try exact match first, then without underscore
|
||||
flag_alt = flag.replace("_", "")
|
||||
if flag not in cpu_flags and flag_alt not in cpu_flags:
|
||||
has_all_flags = False
|
||||
break
|
||||
|
||||
# Check for AVX2 support
|
||||
if "avx2" in cpuinfo:
|
||||
if os.environ.get("KT_KERNEL_DEBUG") == "1":
|
||||
print("[kt-kernel] Detected AVX2 support via /proc/cpuinfo")
|
||||
return "avx2"
|
||||
if has_all_flags:
|
||||
if os.environ.get("KT_KERNEL_DEBUG") == "1":
|
||||
print(f"[kt-kernel] Detected {variant_name} support via /proc/cpuinfo")
|
||||
print(f"[kt-kernel] Matched flags: {', '.join(required_flags)}")
|
||||
return variant_name
|
||||
|
||||
# Fallback to AVX2 (should be rare on modern CPUs)
|
||||
if os.environ.get("KT_KERNEL_DEBUG") == "1":
|
||||
print("[kt-kernel] No AVX2/AVX512/AMX detected, using AVX2 fallback")
|
||||
print("[kt-kernel] No supported features detected, using AVX2 fallback")
|
||||
return "avx2"
|
||||
|
||||
except FileNotFoundError:
|
||||
@@ -84,17 +114,35 @@ def detect_cpu_features():
|
||||
try:
|
||||
import cpufeature
|
||||
|
||||
# Check for AMX
|
||||
if cpufeature.CPUFeature.get("AMX_TILE", False):
|
||||
if os.environ.get("KT_KERNEL_DEBUG") == "1":
|
||||
print("[kt-kernel] Detected AMX support via cpufeature")
|
||||
return "amx"
|
||||
# Define variant requirements in priority order (using cpufeature naming)
|
||||
cpufeature_requirements = [
|
||||
(
|
||||
"amx",
|
||||
[
|
||||
"AMX_TILE",
|
||||
"AMX_INT8",
|
||||
"AMX_BF16",
|
||||
"AVX512F",
|
||||
"AVX512BW",
|
||||
"AVX512_VNNI",
|
||||
"AVX512_VBMI",
|
||||
"AVX512_BF16",
|
||||
],
|
||||
),
|
||||
("avx512_bf16", ["AVX512F", "AVX512BW", "AVX512_VNNI", "AVX512_VBMI", "AVX512_BF16"]),
|
||||
("avx512_vbmi", ["AVX512F", "AVX512BW", "AVX512_VNNI", "AVX512_VBMI"]),
|
||||
("avx512_vnni", ["AVX512F", "AVX512BW", "AVX512_VNNI"]),
|
||||
("avx512_base", ["AVX512F", "AVX512BW"]),
|
||||
("avx2", ["AVX2"]),
|
||||
]
|
||||
|
||||
# Check for AVX512
|
||||
if cpufeature.CPUFeature.get("AVX512F", False):
|
||||
if os.environ.get("KT_KERNEL_DEBUG") == "1":
|
||||
print("[kt-kernel] Detected AVX512 support via cpufeature")
|
||||
return "avx512"
|
||||
# Find the best matching variant
|
||||
for variant_name, required_features in cpufeature_requirements:
|
||||
has_all_features = all(cpufeature.CPUFeature.get(feat, False) for feat in required_features)
|
||||
if has_all_features:
|
||||
if os.environ.get("KT_KERNEL_DEBUG") == "1":
|
||||
print(f"[kt-kernel] Detected {variant_name} support via cpufeature")
|
||||
return variant_name
|
||||
|
||||
# Fallback to AVX2
|
||||
if os.environ.get("KT_KERNEL_DEBUG") == "1":
|
||||
@@ -124,10 +172,11 @@ def load_extension(variant):
|
||||
Supports both multi-variant builds (_kt_kernel_ext_amx.*.so) and
|
||||
single-variant builds (kt_kernel_ext.*.so).
|
||||
|
||||
Fallback order: amx -> avx512 -> avx2 -> single-variant
|
||||
Fallback chain (each variant falls back to the next in line):
|
||||
amx -> avx512_bf16 -> avx512_vbmi -> avx512_vnni -> avx512_base -> avx2 -> single-variant
|
||||
|
||||
Args:
|
||||
variant (str): 'amx', 'avx512', or 'avx2'
|
||||
variant (str): One of 'amx', 'avx512_bf16', 'avx512_vbmi', 'avx512_vnni', 'avx512_base', 'avx2'
|
||||
|
||||
Returns:
|
||||
module: The loaded extension module
|
||||
@@ -187,15 +236,24 @@ def load_extension(variant):
|
||||
if os.environ.get("KT_KERNEL_DEBUG") == "1":
|
||||
print(f"[kt-kernel] Failed to load {variant} variant: {e}")
|
||||
|
||||
# Automatic fallback to next best variant
|
||||
if variant == "amx":
|
||||
# Define fallback chain: each variant falls back to the next lower one
|
||||
fallback_chain = {
|
||||
"amx": "avx512_bf16",
|
||||
"avx512_bf16": "avx512_vbmi",
|
||||
"avx512_vbmi": "avx512_vnni",
|
||||
"avx512_vnni": "avx512_base",
|
||||
"avx512_base": "avx2",
|
||||
"avx2": None, # No fallback - terminal variant
|
||||
}
|
||||
|
||||
# Get next fallback variant
|
||||
next_variant = fallback_chain.get(variant)
|
||||
|
||||
if next_variant:
|
||||
# Try next variant in the chain
|
||||
if os.environ.get("KT_KERNEL_DEBUG") == "1":
|
||||
print("[kt-kernel] Falling back from AMX to AVX512")
|
||||
return load_extension("avx512")
|
||||
elif variant == "avx512":
|
||||
if os.environ.get("KT_KERNEL_DEBUG") == "1":
|
||||
print("[kt-kernel] Falling back from AVX512 to AVX2")
|
||||
return load_extension("avx2")
|
||||
print(f"[kt-kernel] Falling back from {variant} to {next_variant}")
|
||||
return load_extension(next_variant)
|
||||
else:
|
||||
# AVX2 is the last fallback - if this fails, we can't continue
|
||||
raise ImportError(
|
||||
|
||||
120
kt-kernel/scripts/check_cpu_features.py
Executable file
120
kt-kernel/scripts/check_cpu_features.py
Executable file
@@ -0,0 +1,120 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
CPU feature detection script for kt-kernel.
|
||||
|
||||
This script checks if your CPU supports the required instruction sets for FP8 MoE:
|
||||
- AVX512F (foundation)
|
||||
- AVX512_BF16 (BF16 dot product)
|
||||
- AVX512_VNNI (VNNI instructions)
|
||||
- AVX512_VBMI (byte permutation)
|
||||
|
||||
Usage:
|
||||
python3 scripts/check_cpu_features.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
def check_cpuinfo():
    """Read /proc/cpuinfo.

    Returns:
        The full file contents converted to lower case, or None when the
        file does not exist.
    """
    try:
        with open("/proc/cpuinfo", "r") as cpuinfo_file:
            contents = cpuinfo_file.read()
    except FileNotFoundError:
        return None
    return contents.lower()
|
||||
|
||||
|
||||
def main():
    """Print a report of the CPU features relevant to kt-kernel.

    Reads the lower-cased /proc/cpuinfo text from check_cpuinfo() and prints:
    the CPU model, the AMX flags, the AVX512 flags listed as required for
    FP8 MoE, AVX2, and a build recommendation.

    Exits the process with status 1 when /proc/cpuinfo is not available.
    """
    separator = "=" * 70

    print(separator)
    print("KT-Kernel CPU Feature Detection")
    print(separator)
    print()

    cpuinfo = check_cpuinfo()

    if cpuinfo is None:
        print("❌ /proc/cpuinfo not found (not on Linux?)")
        print(" Cannot detect CPU features automatically.")
        sys.exit(1)

    cpuinfo_lines = cpuinfo.split("\n")

    # Extract CPU model (first matching line only).
    for line in cpuinfo_lines:
        if "model name" in line:
            # partition() keeps everything after the first ':' so a model
            # string that itself contains ':' is not truncated.
            model = line.partition(":")[2].strip()
            print(f"CPU Model: {model}")
            break
    print()

    # Tokenise the first "flags" line so features are matched as whole flag
    # names instead of as substrings of the entire file.
    cpu_flags = set()
    for line in cpuinfo_lines:
        if line.startswith("flags"):
            cpu_flags = set(line.partition(":")[2].split())
            break

    def has_cpu_flag(flag):
        """Return True if `flag` is present, with or without underscores
        (e.g. both 'avx512_vnni' and 'avx512vnni' are accepted)."""
        return flag in cpu_flags or flag.replace("_", "") in cpu_flags

    # Check AMX support
    print("AMX Support (Intel Sapphire Rapids+):")
    amx_flags = ["amx_tile", "amx_int8", "amx_bf16"]
    amx_status = {}
    for flag in amx_flags:
        has_flag = has_cpu_flag(flag)
        amx_status[flag] = has_flag
        status = "✅" if has_flag else "❌"
        print(f" {status} {flag.upper()}")

    has_amx = all(amx_status.values())
    print(f"\n Overall AMX Support: {'✅ YES' if has_amx else '❌ NO'}")
    print()

    # Check AVX512 support
    print("AVX512 Support (required for FP8 MoE):")
    avx512_flags = ["avx512f", "avx512_bf16", "avx512_vnni", "avx512_vbmi"]
    # Display names for the report; built once rather than per loop iteration.
    flag_desc = {
        "avx512f": "AVX512F (foundation)",
        "avx512_bf16": "AVX512_BF16 (BF16 dot product)",
        "avx512_vnni": "AVX512_VNNI (VNNI instructions)",
        "avx512_vbmi": "AVX512_VBMI (byte permutation)",
    }
    avx512_status = {}
    for flag in avx512_flags:
        has_flag = has_cpu_flag(flag)
        avx512_status[flag] = has_flag
        status = "✅" if has_flag else "❌"
        print(f" {status} {flag_desc.get(flag, flag.upper())}")

    has_avx512_full = all(avx512_status.values())
    print(f"\n Overall AVX512 Support: {'✅ YES' if has_avx512_full else '❌ NO'}")

    # NOTE(review): the messages below say kt-kernel falls back to AVX2 when an
    # AVX512 extension is missing - confirm this still matches the runtime
    # variant selection in _cpu_detect.py.
    if not has_avx512_full and avx512_status["avx512f"]:
        missing = [f for f in avx512_flags if not avx512_status[f]]
        print(f" ⚠️ Warning: AVX512F detected but missing: {', '.join(missing)}")
        print(f" kt-kernel will fall back to AVX2 mode")
    print()

    # Check AVX2 support
    print("AVX2 Support (fallback):")
    has_avx2 = has_cpu_flag("avx2")
    status = "✅" if has_avx2 else "❌"
    print(f" {status} AVX2")
    print()

    # Recommendation: report the most capable tier the CPU qualifies for.
    print(separator)
    print("Recommendation:")
    print(separator)
    if has_amx:
        print("✅ Your CPU supports AMX - you can use the highest performance mode!")
        print(" Build with: -DKTRANSFORMERS_CPU_USE_AMX_AVX512=ON -DKTRANSFORMERS_CPU_USE_AMX=ON")
    elif has_avx512_full:
        print("✅ Your CPU supports full AVX512 (F/BF16/VNNI/VBMI) - FP8 MoE will work!")
        print(" Build with: -DKTRANSFORMERS_CPU_USE_AMX_AVX512=ON")
    elif avx512_status.get("avx512f", False):
        print("⚠️ Your CPU has AVX512F but missing required extensions.")
        print(" FP8 MoE will NOT work. kt-kernel will fall back to AVX2 mode.")
        print(" Missing extensions:", ", ".join([f for f in avx512_flags if not avx512_status.get(f, False)]))
    elif has_avx2:
        print("ℹ️ Your CPU supports AVX2 only - basic compatibility mode.")
        print(" FP8 MoE will NOT be available, but other features will work.")
    else:
        print("❌ Your CPU does not support the minimum required instruction set (AVX2).")
        print(" kt-kernel may not work on this system.")
    print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -21,6 +21,9 @@ Environment knobs (export before running pip install .):
|
||||
CPUINFER_ENABLE_BLIS=OFF ON/OFF -> -DKTRANSFORMERS_CPU_MOE_AMD
|
||||
CPUINFER_ENABLE_KML=OFF ON/OFF -> -DKTRANSFORMERS_CPU_USE_KML
|
||||
CPUINFER_ENABLE_AVX512=OFF ON/OFF -> -DKTRANSFORMERS_CPU_USE_AMX_AVX512
|
||||
CPUINFER_ENABLE_AVX512_VNNI=OFF ON/OFF -> -DLLAMA_AVX512_VNNI
|
||||
CPUINFER_ENABLE_AVX512_BF16=OFF ON/OFF -> -DLLAMA_AVX512_BF16
|
||||
CPUINFER_ENABLE_AVX512_VBMI=OFF ON/OFF -> -DLLAMA_AVX512_VBMI (required for FP8 MoE)
|
||||
CPUINFER_BLIS_ROOT=/path/to/blis Forward to -DBLIS_ROOT
|
||||
|
||||
|
||||
@@ -254,23 +257,29 @@ class CMakeBuild(build_ext):
|
||||
|
||||
def build_multi_variants(self, ext: CMakeExtension):
|
||||
"""
|
||||
Build all 3 CPU variants (AMX, AVX512, AVX2) in a single wheel.
|
||||
Build all 6 CPU variants with progressive AVX512 capabilities.
|
||||
|
||||
This creates 3 separate .so files:
|
||||
- _kt_kernel_ext_amx.cpython-311-x86_64-linux-gnu.so
|
||||
- _kt_kernel_ext_avx512.cpython-311-x86_64-linux-gnu.so
|
||||
- _kt_kernel_ext_avx2.cpython-311-x86_64-linux-gnu.so
|
||||
This creates 6 separate .so files optimized for different CPU generations:
|
||||
- _kt_kernel_ext_avx2.so (Haswell+, 2013)
|
||||
- _kt_kernel_ext_avx512_base.so (Skylake-X+, 2017)
|
||||
- _kt_kernel_ext_avx512_vnni.so (Cascade Lake+, 2019)
|
||||
- _kt_kernel_ext_avx512_vbmi.so (Ice Lake client, 2019)
|
||||
- _kt_kernel_ext_avx512_bf16.so (Ice Lake server/Zen 4+, 2021)
|
||||
- _kt_kernel_ext_amx.so (Sapphire Rapids+, 2023)
|
||||
|
||||
Runtime CPU detection (in _cpu_detect.py) will automatically load the best one.
|
||||
Runtime CPU detection (in _cpu_detect.py) will automatically select the best match.
|
||||
"""
|
||||
print("=" * 70)
|
||||
print("Building kt-kernel with ALL CPU variants (AMX, AVX512, AVX2)")
|
||||
print("Building kt-kernel with ALL 6 CPU variants")
|
||||
print("=" * 70)
|
||||
print()
|
||||
print("This will build three variants in a single wheel:")
|
||||
print(" - AMX variant (Intel Sapphire Rapids+)")
|
||||
print(" - AVX512 variant (Intel Skylake-X/Ice Lake+, AMD Zen 4+)")
|
||||
print(" - AVX2 variant (maximum compatibility, 2013+)")
|
||||
print("This will build six progressive variants in a single wheel:")
|
||||
print(" 1. AVX2 - Haswell+ (2013)")
|
||||
print(" 2. AVX512 Base - Skylake-X+ (2017)")
|
||||
print(" 3. AVX512+VNNI - Cascade Lake+ (2019)")
|
||||
print(" 4. AVX512+VBMI - Ice Lake client (2019)")
|
||||
print(" 5. AVX512+BF16 - Ice Lake server, Zen 4+ (2021)")
|
||||
print(" 6. AMX - Sapphire Rapids+ (2023)")
|
||||
print()
|
||||
print("Runtime CPU detection will automatically select the best variant.")
|
||||
print()
|
||||
@@ -278,33 +287,100 @@ class CMakeBuild(build_ext):
|
||||
extdir = Path(self.get_ext_fullpath(ext.name)).parent.resolve()
|
||||
cfg = default_build_type()
|
||||
|
||||
# Save original env vars
|
||||
orig_cpu_instruct = os.environ.get("CPUINFER_CPU_INSTRUCT")
|
||||
orig_enable_amx = os.environ.get("CPUINFER_ENABLE_AMX")
|
||||
orig_enable_avx512 = os.environ.get("CPUINFER_ENABLE_AVX512")
|
||||
# Save original env vars to restore later
|
||||
env_backup = {
|
||||
"CPUINFER_CPU_INSTRUCT": os.environ.get("CPUINFER_CPU_INSTRUCT"),
|
||||
"CPUINFER_ENABLE_AMX": os.environ.get("CPUINFER_ENABLE_AMX"),
|
||||
"CPUINFER_ENABLE_AVX512": os.environ.get("CPUINFER_ENABLE_AVX512"),
|
||||
"CPUINFER_ENABLE_AVX512_VNNI": os.environ.get("CPUINFER_ENABLE_AVX512_VNNI"),
|
||||
"CPUINFER_ENABLE_AVX512_BF16": os.environ.get("CPUINFER_ENABLE_AVX512_BF16"),
|
||||
"CPUINFER_ENABLE_AVX512_VBMI": os.environ.get("CPUINFER_ENABLE_AVX512_VBMI"),
|
||||
}
|
||||
|
||||
# Variant configurations: (name, CPUINFER_CPU_INSTRUCT, CPUINFER_ENABLE_AMX)
|
||||
# Variant configurations: (name, description, env_vars)
|
||||
# Each variant specifies exactly which features to enable
|
||||
variants = [
|
||||
("amx", "AVX512", "ON"), # AVX512 + AMX
|
||||
("avx512", "AVX512", "OFF"), # AVX512 only
|
||||
("avx2", "AVX2", "OFF"), # AVX2 only
|
||||
(
|
||||
"avx2",
|
||||
"AVX2 baseline",
|
||||
{
|
||||
"CPUINFER_CPU_INSTRUCT": "AVX2",
|
||||
"CPUINFER_ENABLE_AVX512": "OFF",
|
||||
"CPUINFER_ENABLE_AMX": "OFF",
|
||||
},
|
||||
),
|
||||
(
|
||||
"avx512_base",
|
||||
"AVX512F+BW",
|
||||
{
|
||||
"CPUINFER_CPU_INSTRUCT": "AVX512",
|
||||
"CPUINFER_ENABLE_AVX512": "ON",
|
||||
"CPUINFER_ENABLE_AVX512_VNNI": "OFF",
|
||||
"CPUINFER_ENABLE_AVX512_BF16": "OFF",
|
||||
"CPUINFER_ENABLE_AVX512_VBMI": "OFF",
|
||||
"CPUINFER_ENABLE_AMX": "OFF",
|
||||
},
|
||||
),
|
||||
(
|
||||
"avx512_vnni",
|
||||
"AVX512F+VNNI",
|
||||
{
|
||||
"CPUINFER_CPU_INSTRUCT": "AVX512",
|
||||
"CPUINFER_ENABLE_AVX512": "ON",
|
||||
"CPUINFER_ENABLE_AVX512_VNNI": "ON",
|
||||
"CPUINFER_ENABLE_AVX512_BF16": "OFF",
|
||||
"CPUINFER_ENABLE_AVX512_VBMI": "OFF",
|
||||
"CPUINFER_ENABLE_AMX": "OFF",
|
||||
},
|
||||
),
|
||||
(
|
||||
"avx512_vbmi",
|
||||
"AVX512F+VNNI+VBMI",
|
||||
{
|
||||
"CPUINFER_CPU_INSTRUCT": "AVX512",
|
||||
"CPUINFER_ENABLE_AVX512": "ON",
|
||||
"CPUINFER_ENABLE_AVX512_VNNI": "ON",
|
||||
"CPUINFER_ENABLE_AVX512_BF16": "OFF",
|
||||
"CPUINFER_ENABLE_AVX512_VBMI": "ON",
|
||||
"CPUINFER_ENABLE_AMX": "OFF",
|
||||
},
|
||||
),
|
||||
(
|
||||
"avx512_bf16",
|
||||
"AVX512 Full (F+VNNI+VBMI+BF16)",
|
||||
{
|
||||
"CPUINFER_CPU_INSTRUCT": "AVX512",
|
||||
"CPUINFER_ENABLE_AVX512": "ON",
|
||||
"CPUINFER_ENABLE_AVX512_VNNI": "ON",
|
||||
"CPUINFER_ENABLE_AVX512_BF16": "ON",
|
||||
"CPUINFER_ENABLE_AVX512_VBMI": "ON",
|
||||
"CPUINFER_ENABLE_AMX": "OFF",
|
||||
},
|
||||
),
|
||||
(
|
||||
"amx",
|
||||
"AMX + AVX512 Full",
|
||||
{
|
||||
"CPUINFER_CPU_INSTRUCT": "AVX512",
|
||||
"CPUINFER_ENABLE_AVX512": "ON",
|
||||
"CPUINFER_ENABLE_AVX512_VNNI": "ON",
|
||||
"CPUINFER_ENABLE_AVX512_BF16": "ON",
|
||||
"CPUINFER_ENABLE_AVX512_VBMI": "ON",
|
||||
"CPUINFER_ENABLE_AMX": "ON",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
for variant_name, cpu_instruct, enable_amx in variants:
|
||||
for variant_name, variant_desc, env_vars in variants:
|
||||
print("=" * 70)
|
||||
print(f"Building {variant_name.upper()} variant...")
|
||||
print(f"Building {variant_name.upper()} variant ({variant_desc})")
|
||||
print("=" * 70)
|
||||
print()
|
||||
|
||||
# Set environment variables for this variant
|
||||
os.environ["CPUINFER_CPU_INSTRUCT"] = cpu_instruct
|
||||
os.environ["CPUINFER_ENABLE_AMX"] = enable_amx
|
||||
if variant_name == "avx2":
|
||||
# For AVX2 variant, disable AVX512 umbrella to prevent AVX512 code
|
||||
os.environ["CPUINFER_ENABLE_AVX512"] = "OFF"
|
||||
else:
|
||||
# For AMX and AVX512 variants, enable AVX512 umbrella
|
||||
os.environ["CPUINFER_ENABLE_AVX512"] = "ON"
|
||||
for key, value in env_vars.items():
|
||||
os.environ[key] = value
|
||||
print(f" {key} = {value}")
|
||||
|
||||
# Use separate build directory for each variant
|
||||
build_temp = Path(self.build_temp) / f"{ext.name}_{cfg}_{variant_name}"
|
||||
@@ -338,26 +414,17 @@ class CMakeBuild(build_ext):
|
||||
print()
|
||||
|
||||
# Restore original env vars
|
||||
if orig_cpu_instruct is not None:
|
||||
os.environ["CPUINFER_CPU_INSTRUCT"] = orig_cpu_instruct
|
||||
elif "CPUINFER_CPU_INSTRUCT" in os.environ:
|
||||
del os.environ["CPUINFER_CPU_INSTRUCT"]
|
||||
|
||||
if orig_enable_amx is not None:
|
||||
os.environ["CPUINFER_ENABLE_AMX"] = orig_enable_amx
|
||||
elif "CPUINFER_ENABLE_AMX" in os.environ:
|
||||
del os.environ["CPUINFER_ENABLE_AMX"]
|
||||
|
||||
if orig_enable_avx512 is not None:
|
||||
os.environ["CPUINFER_ENABLE_AVX512"] = orig_enable_avx512
|
||||
elif "CPUINFER_ENABLE_AVX512" in os.environ:
|
||||
del os.environ["CPUINFER_ENABLE_AVX512"]
|
||||
for key, value in env_backup.items():
|
||||
if value is not None:
|
||||
os.environ[key] = value
|
||||
elif key in os.environ:
|
||||
del os.environ[key]
|
||||
|
||||
print("=" * 70)
|
||||
print("✓ All variants built successfully!")
|
||||
print("✓ All 6 variants built successfully!")
|
||||
print("=" * 70)
|
||||
print()
|
||||
print("The wheel now contains 3 CPU variants:")
|
||||
print("The wheel now contains 6 CPU variants:")
|
||||
for so_file in sorted(extdir.glob("_kt_kernel_ext_*.so")):
|
||||
print(f" - {so_file.name}")
|
||||
print()
|
||||
@@ -483,14 +550,38 @@ class CMakeBuild(build_ext):
|
||||
|
||||
# Fine-grained AVX512 subset flags: only enable if CPU actually supports them
|
||||
# These are passed to CMake to conditionally add compiler flags
|
||||
# Track if any AVX512 extension is enabled
|
||||
avx512_extension_enabled = False
|
||||
|
||||
if not _forward_bool_env(cmake_args, "CPUINFER_ENABLE_AVX512_VNNI", "LLAMA_AVX512_VNNI"):
|
||||
if "AVX512_VNNI" in d["features"]:
|
||||
cmake_args.append("-DLLAMA_AVX512_VNNI=ON")
|
||||
print("-- AVX512_VNNI detected; enabling (-DLLAMA_AVX512_VNNI=ON)")
|
||||
avx512_extension_enabled = True
|
||||
else:
|
||||
avx512_extension_enabled = True
|
||||
|
||||
if not _forward_bool_env(cmake_args, "CPUINFER_ENABLE_AVX512_BF16", "LLAMA_AVX512_BF16"):
|
||||
if "AVX512_BF16" in d["features"]:
|
||||
cmake_args.append("-DLLAMA_AVX512_BF16=ON")
|
||||
print("-- AVX512_BF16 detected; enabling (-DLLAMA_AVX512_BF16=ON)")
|
||||
avx512_extension_enabled = True
|
||||
else:
|
||||
avx512_extension_enabled = True
|
||||
|
||||
if not _forward_bool_env(cmake_args, "CPUINFER_ENABLE_AVX512_VBMI", "LLAMA_AVX512_VBMI"):
|
||||
if "AVX512_VBMI" in d["features"]:
|
||||
cmake_args.append("-DLLAMA_AVX512_VBMI=ON")
|
||||
print("-- AVX512_VBMI detected; enabling (-DLLAMA_AVX512_VBMI=ON)")
|
||||
avx512_extension_enabled = True
|
||||
else:
|
||||
avx512_extension_enabled = True
|
||||
|
||||
# If any AVX512 extension is enabled, ensure base AVX512 is also enabled
|
||||
if avx512_extension_enabled and cpu_mode == "NATIVE":
|
||||
if not any("LLAMA_AVX512=ON" in a for a in cmake_args):
|
||||
cmake_args.append("-DLLAMA_AVX512=ON")
|
||||
print("-- AVX512 extensions enabled; also enabling base AVX512F (-DLLAMA_AVX512=ON)")
|
||||
|
||||
# Auto-enable MOE kernel only when env explicitly turns on AMD or KML backend
|
||||
# (Do not enable purely on vendor auto-detection to avoid surprise behavior.)
|
||||
|
||||
Reference in New Issue
Block a user