WIP
https://github.com/ikawrakow/ik_llama.cpp.git

@@ -17219,6 +17219,77 @@ inline bool flash_attn_is_supported(ggml_type type) {
 #endif
     return false;
 }
+
+template <int step_k>
+inline bool iqk_deepseek_helper(ggml_type type_k,
+        int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
+        const float * q, const char * k, const char * v, const char * mask,
+        float scale, float softcap, float * qkv) {
+    if (type_k == GGML_TYPE_Q8_0) {
+        HelperQ80<576, step_k> kh((const char *)k, stride_k);
+        HelperQ80<512, step_k> vh((const char *)v, stride_v);
+        if (nq1 % 8 == 0) {
+            FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+        } else {
+            FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+        }
+        return true;
+    }
+    if (type_k == GGML_TYPE_Q6_0) {
+        HelperQ60<576, step_k> kh((const char *)k, stride_k);
+        HelperQ60<512, step_k> vh((const char *)v, stride_v);
+        if (nq1 % 8 == 0) {
+            FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+        } else {
+            FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+        }
+        return true;
+    }
+    if (type_k == GGML_TYPE_Q8_KV) {
+        HelperQ8KV<576, step_k> kh((const char *)k, stride_k);
+        HelperQ8KV<512, step_k> vh((const char *)v, stride_v);
+        if (nq1 % 8 == 0) {
+            FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+        } else {
+            FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+        }
+        return true;
+    }
+    if (type_k == GGML_TYPE_F16) {
+        HelperF16<576, step_k> kh((const char *)k, stride_k);
+        HelperF16<512, step_k> vh((const char *)v, stride_v);
+        if (nq1 % 8 == 0) {
+            FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+        } else {
+            FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+        }
+        return true;
+    }
+#ifdef __AVX512BF16__
+    if (type_k == GGML_TYPE_BF16) {
+        HelperBF16<576, step_k> kh((const char *)k, stride_k);
+        HelperBF16<512, step_k> vh((const char *)v, stride_v);
+        if (nq1 % 8 == 0) {
+            FlashAttnBF16<576, 512, 8, step_k> fa(scale, softcap);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+        } else {
+            FlashAttnBF16<576, 512, 1, step_k> fa(scale, softcap);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+        }
+        return true;
+    }
+#endif
+    return false;
+}
+
 }

 bool iqk_flash_attn_noalibi(int int_type_k, // type of k
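The first hunk factors the per-type dispatch out of iqk_flash_attn_noalibi into a reusable template: each supported cache type binds its Helper*<576/512, step_k> views, and the query count selects an 8-wide FlashAttn instantiation when nq1 is a multiple of 8, with a 1-wide fallback otherwise. Below is a minimal stand-alone sketch of that dispatch shape; CacheType, HelperA, HelperB, Kernel, and dispatch are hypothetical stand-ins, not the repo's identifiers.

// Hypothetical, self-contained illustration of the dispatch pattern used by
// iqk_deepseek_helper; none of these names are taken from ik_llama.cpp.
#include <cstdio>

enum class CacheType { Q8_0, F16 };

template <int D, int step> struct HelperA { };   // stand-in for a quantized cache view
template <int D, int step> struct HelperB { };   // stand-in for an f16 cache view

template <int Dk, int Dv, int q_step, int k_step> struct Kernel {
    void compute(int nq1) const {
        std::printf("Dk=%d Dv=%d q_step=%d k_step=%d nq1=%d\n", Dk, Dv, q_step, k_step, nq1);
    }
};

template <int k_step>
bool dispatch(CacheType t, int nq1) {
    // Shared body: use the 8-wide kernel when the query count allows it,
    // otherwise fall back to processing one query row at a time.
    auto run = [&](auto kh, auto vh) {
        (void)kh; (void)vh;
        if (nq1 % 8 == 0) Kernel<576, 512, 8, k_step>{}.compute(nq1);
        else              Kernel<576, 512, 1, k_step>{}.compute(nq1);
        return true;
    };
    if (t == CacheType::Q8_0) return run(HelperA<576, k_step>{}, HelperA<512, k_step>{});
    if (t == CacheType::F16)  return run(HelperB<576, k_step>{}, HelperB<512, k_step>{});
    return false;   // unsupported cache type, as in the helper's final return
}

int main() { return dispatch<32>(CacheType::Q8_0, 16) ? 0 : 1; }

The committed helper spells each branch out rather than sharing a lambda, but the control flow is the same: one type test per branch, one 8-vs-1 query-width split, and false for unsupported types.
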
@@ -17246,80 +17317,10 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
     auto type_v = ggml_type(int_type_v);

     if (Dk == 576 && Dv == 512) {
-        //if (type_k == GGML_TYPE_Q8_0 && type_v == GGML_TYPE_Q8_0 && Dk == 576 && Dv == 512) {
-        // This is a DeepSeek model with MLA. In that case we only have one cache (and k and v are different views of the cache),
-        // so type_k must be the same as type_v
         GGML_ASSERT(type_k == type_v);
         stride_q /= sizeof(float); // q stride as float
-        switch (type_k) {
-            case GGML_TYPE_Q8_0: {
-                //printf("%s: nk1 = %d, nq1 = %d, k = %p, v = %p, stride_k = %d stride_v = %d, stride_m = %d\n", __func__, nk1, nq1, k, v, stride_k, stride_v, stride_m);
-                HelperQ80<576, 32> kh((const char *)k, stride_k);
-                HelperQ80<512, 32> vh((const char *)v, stride_v);
-                if (nq1 % 8 == 0) {
-                    FlashAttn<576, 512, 8, 32> fa(scale, softcap);
-                    fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
-                } else {
-                    FlashAttn<576, 512, 1, 32> fa(scale, softcap);
-                    fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
-                }
-                return true;
-            } break;
-            case GGML_TYPE_Q6_0: {
-                //printf("%s: nk1 = %d, nq1 = %d, k = %p, v = %p, stride_k = %d stride_v = %d, stride_m = %d\n", __func__, nk1, nq1, k, v, stride_k, stride_v, stride_m);
-                HelperQ60<576, 32> kh((const char *)k, stride_k);
-                HelperQ60<512, 32> vh((const char *)v, stride_v);
-                if (nq1 % 8 == 0) {
-                    FlashAttn<576, 512, 8, 32> fa(scale, softcap);
-                    fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
-                } else {
-                    FlashAttn<576, 512, 1, 32> fa(scale, softcap);
-                    fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
-                }
-                return true;
-            } break;
-            // Something is wrong with Q8_KV in this case.
-            case GGML_TYPE_Q8_KV: {
-                HelperQ8KV<576, 32> kh((const char *)k, stride_k);
-                HelperQ8KV<512, 32> vh((const char *)v, stride_v);
-                if (nq1 % 8 == 0) {
-                    FlashAttn<576, 512, 8, 32> fa(scale, softcap);
-                    fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
-                } else {
-                    FlashAttn<576, 512, 1, 32> fa(scale, softcap);
-                    fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
-                }
-                return true;
-            } break;
-            case GGML_TYPE_F16: {
-                HelperF16<576, 32> kh((const char *)k, stride_k);
-                HelperF16<512, 32> vh((const char *)v, stride_v);
-                if (nq1 % 8 == 0) {
-                    FlashAttn<576, 512, 8, 32> fa(scale, softcap);
-                    fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
-                } else {
-                    FlashAttn<576, 512, 1, 32> fa(scale, softcap);
-                    fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
-                }
-                return true;
-            } break;
-#ifdef __AVX512BF16__
-            case GGML_TYPE_BF16: {
-                HelperBF16<576, 32> kh((const char *)k, stride_k);
-                HelperBF16<512, 32> vh((const char *)v, stride_v);
-                if (nq1 % 8 == 0) {
-                    FlashAttnBF16<576, 512, 8, 32> fa(scale, softcap);
-                    fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
-                } else {
-                    FlashAttnBF16<576, 512, 1, 32> fa(scale, softcap);
-                    fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
-                }
-                return true;
-            } break;
-#endif
-            default: break;
-        }
-        return false;
+        return iqk_deepseek_helper<32>(type_k, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+                q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv);
     }

     if (!flash_attn_is_supported(type_k) || !flash_attn_is_supported(type_v)) return false;
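
The GGML_ASSERT(type_k == type_v) that survives the second hunk encodes the constraint the deleted comment described: with DeepSeek MLA there is only one cache, and k and v are two views of it, so both views must share one quantization type. The 576/512 head sizes suggest V reads a prefix of each 576-entry K row; below is a stand-alone sketch of that single-cache idea, using hypothetical float rows instead of the real quantized cells.

// Hypothetical illustration of one MLA cache seen through two views; not repo code.
#include <cstdio>
#include <vector>

int main() {
    constexpr int Dk = 576, Dv = 512, n_rows = 4;
    std::vector<float> cache(n_rows * Dk, 0.0f);
    for (int r = 0; r < n_rows; ++r) cache[r * Dk] = 1.0f + r;   // tag each row start

    // Both views share the same base pointer and the same row stride (Dk floats);
    // the V view simply uses only the first Dv entries of each row.
    const float * k = cache.data();
    const float * v = cache.data();

    for (int r = 0; r < n_rows; ++r) {
        std::printf("row %d: k[0]=%g  v[0]=%g  v[Dv-1]=%g\n",
                    r, k[r * Dk], v[r * Dk], v[r * Dk + Dv - 1]);
    }
    return 0;
}

Because the two views alias one allocation, a per-view cache type cannot diverge, which is why iqk_deepseek_helper dispatches on type_k alone.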