FA: Add option to build all FA kernels

Similar to the CUDA situation.
It is OFF by default.
If OFF, only F16, Q8_0, Q6_0, and, if the CPU provides native
BF16 support, BF16 FA kernels will be included.
To enable all, cmake -DGGML_IQK_FA_ALL_QUANTS=1 ...
This cuts compilation time for iqk_mul_mat.cpp by almost half
(45 seconds vs 81 seconds on my Ryzen-7950X).
This commit is contained in:
Iwan Kawrakow
2025-02-09 18:50:50 +02:00
parent 33390c4b74
commit 01e2b0c2ce
3 changed files with 39 additions and 33 deletions

View File

@@ -130,6 +130,8 @@ option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM"
option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
option(GGML_CUDA_USE_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" OFF)
option(GGML_IQK_FA_ALL_QUANTS "ggml: compile all quants for IQK FlashAttention" OFF)
option(GGML_CURL "ggml: use libcurl to download model from an URL" OFF)
option(GGML_HIPBLAS "ggml: use hipBLAS" OFF)
option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF)

View File

@@ -259,6 +259,10 @@ if (GGML_IQK_MUL_MAT)
add_compile_definitions(GGML_USE_IQK_MULMAT)
set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp)
set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h)
if (GGML_IQK_FA_ALL_QUANTS)
message(STATUS "Including all IQK FA kernels")
add_compile_definitions(GGML_IQK_FA_ALL_QUANTS)
endif()
endif()
if (GGML_LLAMAFILE)

View File

@@ -15239,14 +15239,7 @@ struct FlashQKfp32 {
case 7: return std::make_pair(mul_mat<7>, 7);\
}\
}
if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) {
#ifdef __aarch64__
MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ40, nq);
#else
MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, nq);
#endif
}
else if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
#ifdef __aarch64__
MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq);
#else
@@ -15262,6 +15255,21 @@ struct FlashQKfp32 {
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_0, nq);
#else
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_1, nq);
#endif
}
else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
#ifdef __aarch64__
MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ60, nq);
#else
MAKE_FUNCS(mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, nq);
#endif
}
#if GGML_IQK_FA_ALL_QUANTS
else if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) {
#ifdef __aarch64__
MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ40, nq);
#else
MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, nq);
#endif
}
else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
@@ -15278,13 +15286,7 @@ struct FlashQKfp32 {
MAKE_FUNCS(mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, nq);
#endif
}
else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
#ifdef __aarch64__
MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ60, nq);
#else
MAKE_FUNCS(mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, nq);
#endif
}
else {
GGML_ASSERT(false);
}
@@ -15493,17 +15495,6 @@ struct FlashAttn {
template <typename KHelper, typename VHelper>
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float * qkv) {
// if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> ||
// std::is_same_v<KHelper, HelperIQ4nl<D, k_step>> ||
// std::is_same_v<KHelper, HelperQ80<D, k_step>> ||
// std::is_same_v<KHelper, HelperQ80R4<D, k_step>> ||
// std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
// compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
// kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
// } else {
// compute_helper<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
// kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
// }
if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> ||
std::is_same_v<KHelper, HelperIQ4nl<D, k_step>> ||
std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
@@ -16027,6 +16018,11 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
HelperQ80<D, k_step> vh(v, stride_v);
iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
case GGML_TYPE_Q6_0: {
HelperQ60<D, k_step> vh(v, stride_v);
iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
#if GGML_IQK_FA_ALL_QUANTS
case GGML_TYPE_Q4_0: {
HelperQ40<D, k_step> vh(v, stride_v);
iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
@@ -16039,10 +16035,7 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
HelperIQ4nl<D, k_step> vh(v, stride_v);
iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
case GGML_TYPE_Q6_0: {
HelperQ60<D, k_step> vh(v, stride_v);
iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
#endif
default: break;
}
}
@@ -16062,6 +16055,11 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
HelperQ80<D, k_step> kh(k, stride_k);
iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
case GGML_TYPE_Q6_0: {
HelperQ60<D, k_step> kh(k, stride_k);
iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
#if GGML_IQK_FA_ALL_QUANTS
case GGML_TYPE_Q4_0: {
HelperQ40<D, k_step> kh(k, stride_k);
iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
@@ -16074,10 +16072,7 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
HelperIQ4nl<D, k_step> kh(k, stride_k);
iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
case GGML_TYPE_Q6_0: {
HelperQ60<D, k_step> kh(k, stride_k);
iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
#endif
default: break;
}
@@ -16087,8 +16082,12 @@ inline bool flash_attn_is_supported(ggml_type type) {
#ifdef __AVX512BF16__
if (type == GGML_TYPE_BF16) return true;
#endif
#if GGML_IQK_FA_ALL_QUANTS
if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 ||
type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL) return true;
#else
if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0) return true;
#endif
return false;
}
}
@@ -16115,6 +16114,7 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
auto type_v = ggml_type(int_type_v);
if (!flash_attn_is_supported(type_k) || !flash_attn_is_supported(type_v)) return false;
if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32
if (D != 64 && D != 96 && D != 128 && D != 256) return false;
auto ck = (const char *)k;
auto cv = (const char *)v;