mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-09 13:30:17 +00:00
WIP
This commit is contained in:
@@ -17870,7 +17870,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
}
|
||||
|
||||
#if GGML_USE_IQK_MULMAT
|
||||
if (false && max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) {
|
||||
if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) {
|
||||
//if (ith == 0) printf("k: %ld x %ld x %ld, q: %ld x %ld x %ld, v: %ld x %ld x %ld mask: %ld x %ld x %ld\n",
|
||||
// k->ne[0], k->ne[1], k->ne[2], q->ne[0], q->ne[1], q->ne[2], v->ne[0], v->ne[1], v->ne[2], mask->ne[0], mask->ne[1], mask->ne[2]);
|
||||
// I keep changing my mind what is the best strategy to split the threads when processing
|
||||
// multiple heads. This is my current thinking, the commented out code below was the previous.
|
||||
int ntg = nth/simple_gcd(neq2*neq3, nth);
|
||||
|
||||
@@ -17249,15 +17249,16 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
|
||||
GGML_ASSERT(type_k == type_v);
|
||||
switch (type_k) {
|
||||
case GGML_TYPE_Q8_0: {
|
||||
//printf("%s: nk1 = %d, nq1 = %d, k = %p, v = %p, stride_k = %d stride_v = %d, stride_m = %d\n", __func__, nk1, nq1, k, v, stride_k, stride_v, stride_m);
|
||||
HelperQ80<576, 32> kh((const char *)k, stride_k);
|
||||
HelperQ80<512, 32> vh((const char *)v, stride_v);
|
||||
//if (nq1 % 8 == 0) {
|
||||
// FlashAttn<576, 512, 8, 32> fa(scale, softcap);
|
||||
// fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
//} else {
|
||||
if (nq1 % 8 == 0) {
|
||||
FlashAttn<576, 512, 8, 32> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
} else {
|
||||
FlashAttn<576, 512, 1, 32> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
//}
|
||||
}
|
||||
return true;
|
||||
} break;
|
||||
// Something is wrong with Q8_KV in this case.
|
||||
@@ -17276,26 +17277,26 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
|
||||
case GGML_TYPE_F16: {
|
||||
HelperF16<576, 32> kh((const char *)k, stride_k);
|
||||
HelperF16<512, 32> vh((const char *)v, stride_v);
|
||||
//if (nq1 % 8 == 0) {
|
||||
// FlashAttn<576, 512, 8, 32> fa(scale, softcap);
|
||||
// fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
//} else {
|
||||
if (nq1 % 8 == 0) {
|
||||
FlashAttn<576, 512, 8, 32> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
} else {
|
||||
FlashAttn<576, 512, 1, 32> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
//}
|
||||
}
|
||||
return true;
|
||||
} break;
|
||||
#ifdef __AVX512BF16__
|
||||
case GGML_TYPE_BF16: {
|
||||
HelperBF16<576, 32> kh((const char *)k, stride_k);
|
||||
HelperBF16<512, 32> vh((const char *)v, stride_v);
|
||||
//if (nq1 % 8 == 0) {
|
||||
// FlashAttn<576, 512, 8, 32> fa(scale, softcap);
|
||||
// fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
//} else {
|
||||
if (nq1 % 8 == 0) {
|
||||
FlashAttn<576, 512, 8, 32> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
} else {
|
||||
FlashAttn<576, 512, 1, 32> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
//}
|
||||
}
|
||||
return true;
|
||||
} break;
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user