mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 23:24:13 +00:00
FlashMLA: allow for f16 and bf16 cache in addition to q8_0
This commit is contained in:
@@ -17237,25 +17237,73 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
|
||||
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
|
||||
float * qkv) { // v*softmax(scale*(k*q))
|
||||
|
||||
if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32
|
||||
|
||||
auto type_k = ggml_type(int_type_k);
|
||||
auto type_v = ggml_type(int_type_v);
|
||||
|
||||
if (type_k == GGML_TYPE_Q8_0 && type_v == GGML_TYPE_Q8_0 && Dk == 576 && Dv == 512) {
|
||||
//printf("Using DeepSeek FA with nq1 = %d, nk1 = %d\n", (int)nq1, (int)nk1);
|
||||
HelperQ80<576, 32> kh((const char *)k, stride_k);
|
||||
HelperQ80<512, 32> vh((const char *)v, stride_v);
|
||||
if (nq1 % 8 == 0) {
|
||||
FlashAttn<576, 512, 8, 32> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
} else {
|
||||
FlashAttn<576, 512, 1, 32> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
// This is a DeepSeek model with MLA. In that case we only have one cache (and k and v are different views of the cache),
|
||||
// so type_k must be the same as type_v
|
||||
GGML_ASSERT(type_k == type_v);
|
||||
switch (type_k) {
|
||||
case GGML_TYPE_Q8_0: {
|
||||
HelperQ80<576, 32> kh((const char *)k, stride_k);
|
||||
HelperQ80<512, 32> vh((const char *)v, stride_v);
|
||||
if (nq1 % 8 == 0) {
|
||||
FlashAttn<576, 512, 8, 32> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
} else {
|
||||
FlashAttn<576, 512, 1, 32> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
}
|
||||
return true;
|
||||
} break;
|
||||
// Something is wrong with Q8_KV in this case.
|
||||
//case GGML_TYPE_Q8_KV: {
|
||||
// HelperQ8KV<576, 32> kh((const char *)k, stride_k);
|
||||
// HelperQ8KV<512, 32> vh((const char *)v, stride_v);
|
||||
// if (nq1 % 8 == 0) {
|
||||
// FlashAttn<576, 512, 8, 32> fa(scale, softcap);
|
||||
// fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
// } else {
|
||||
// FlashAttn<576, 512, 1, 32> fa(scale, softcap);
|
||||
// fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
// }
|
||||
// return true;
|
||||
//} break;
|
||||
case GGML_TYPE_F16: {
|
||||
HelperF16<576, 32> kh((const char *)k, stride_k);
|
||||
HelperF16<512, 32> vh((const char *)v, stride_v);
|
||||
if (nq1 % 8 == 0) {
|
||||
FlashAttn<576, 512, 8, 32> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
} else {
|
||||
FlashAttn<576, 512, 1, 32> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
}
|
||||
return true;
|
||||
} break;
|
||||
#ifdef __AVX512BF16__
|
||||
case GGML_TYPE_BF16: {
|
||||
HelperBF16<576, 32> kh((const char *)k, stride_k);
|
||||
HelperBF16<512, 32> vh((const char *)v, stride_v);
|
||||
if (nq1 % 8 == 0) {
|
||||
FlashAttn<576, 512, 8, 32> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
} else {
|
||||
FlashAttn<576, 512, 1, 32> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
}
|
||||
return true;
|
||||
} break;
|
||||
#endif
|
||||
default: break;
|
||||
}
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!flash_attn_is_supported(type_k) || !flash_attn_is_supported(type_v)) return false;
|
||||
if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32
|
||||
if (Dk != Dv && Dk != 192 && Dv != 128) return false;
|
||||
if (Dv != 64 && Dv != 96 && Dv != 128 && Dv != 256) return false;
|
||||
if (Dk != 64 && Dk != 96 && Dk != 128 && Dk != 192 && Dv != 256) return false;
|
||||
|
||||
Reference in New Issue
Block a user