FlashMLA: allow for f16 and bf16 cache in addition to q8_0

This commit is contained in:
Iwan Kawrakow
2025-03-02 14:22:35 +02:00
parent 16569c670c
commit af91231f93

View File

@@ -17237,25 +17237,73 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
float * qkv) { // v*softmax(scale*(k*q))
if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32
auto type_k = ggml_type(int_type_k);
auto type_v = ggml_type(int_type_v);
if (type_k == GGML_TYPE_Q8_0 && type_v == GGML_TYPE_Q8_0 && Dk == 576 && Dv == 512) {
//printf("Using DeepSeek FA with nq1 = %d, nk1 = %d\n", (int)nq1, (int)nk1);
HelperQ80<576, 32> kh((const char *)k, stride_k);
HelperQ80<512, 32> vh((const char *)v, stride_v);
if (nq1 % 8 == 0) {
FlashAttn<576, 512, 8, 32> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
} else {
FlashAttn<576, 512, 1, 32> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
// This is a DeepSeek model with MLA. In that case we only have one cache (and k and v are different views of the cache),
// so type_k must be the same as type_v
GGML_ASSERT(type_k == type_v);
switch (type_k) {
case GGML_TYPE_Q8_0: {
HelperQ80<576, 32> kh((const char *)k, stride_k);
HelperQ80<512, 32> vh((const char *)v, stride_v);
if (nq1 % 8 == 0) {
FlashAttn<576, 512, 8, 32> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
} else {
FlashAttn<576, 512, 1, 32> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
}
return true;
} break;
// NOTE: the Q8_KV case below is intentionally left disabled (commented out) because something is wrong with Q8_KV in this case.
//case GGML_TYPE_Q8_KV: {
// HelperQ8KV<576, 32> kh((const char *)k, stride_k);
// HelperQ8KV<512, 32> vh((const char *)v, stride_v);
// if (nq1 % 8 == 0) {
// FlashAttn<576, 512, 8, 32> fa(scale, softcap);
// fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
// } else {
// FlashAttn<576, 512, 1, 32> fa(scale, softcap);
// fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
// }
// return true;
//} break;
case GGML_TYPE_F16: {
HelperF16<576, 32> kh((const char *)k, stride_k);
HelperF16<512, 32> vh((const char *)v, stride_v);
if (nq1 % 8 == 0) {
FlashAttn<576, 512, 8, 32> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
} else {
FlashAttn<576, 512, 1, 32> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
}
return true;
} break;
#ifdef __AVX512BF16__
case GGML_TYPE_BF16: {
HelperBF16<576, 32> kh((const char *)k, stride_k);
HelperBF16<512, 32> vh((const char *)v, stride_v);
if (nq1 % 8 == 0) {
FlashAttn<576, 512, 8, 32> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
} else {
FlashAttn<576, 512, 1, 32> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
}
return true;
} break;
#endif
default: break;
}
return true;
return false;
}
if (!flash_attn_is_supported(type_k) || !flash_attn_is_supported(type_v)) return false;
if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32
if (Dk != Dv && Dk != 192 && Dv != 128) return false;
if (Dv != 64 && Dv != 96 && Dv != 128 && Dv != 256) return false;
if (Dk != 64 && Dk != 96 && Dk != 128 && Dk != 192 && Dv != 256) return false;