Disallow mixing BF16 with other types for the K and V caches

This commit is contained in:
Iwan Kawrakow
2024-09-12 18:55:13 +03:00
parent cdd51579e0
commit 27fa27daf9

View File

@@ -8139,7 +8139,8 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
stride_q /= sizeof(float); // q stride as float
#ifdef __AVX512BF16__
if (type_k == GGML_TYPE_BF16 && type_v == GGML_TYPE_BF16) {
if (type_k == GGML_TYPE_BF16 || type_v == GGML_TYPE_BF16) {
if (type_k != GGML_TYPE_BF16 || type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 with other types
switch (D) {
case 64:
iqk_flash_helper_T< 64, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;