diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 77e3b538..0bcdd269 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -17870,7 +17870,9 @@ static void ggml_compute_forward_flash_attn_ext_f16( } #if GGML_USE_IQK_MULMAT - if (false && max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) { + if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) { + //if (ith == 0) printf("k: %ld x %ld x %ld, q: %ld x %ld x %ld, v: %ld x %ld x %ld mask: %ld x %ld x %ld\n", + // k->ne[0], k->ne[1], k->ne[2], q->ne[0], q->ne[1], q->ne[2], v->ne[0], v->ne[1], v->ne[2], mask->ne[0], mask->ne[1], mask->ne[2]); // I keep changing my mind what is the best strategy to split the threads when processing // multiple heads. This is my current thinking, the commented out code below was the previous. int ntg = nth/simple_gcd(neq2*neq3, nth); diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 106f2a56..55725c15 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -17249,15 +17249,16 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k GGML_ASSERT(type_k == type_v); switch (type_k) { case GGML_TYPE_Q8_0: { + //printf("%s: nk1 = %d, nq1 = %d, k = %p, v = %p, stride_k = %d stride_v = %d, stride_m = %d\n", __func__, nk1, nq1, k, v, stride_k, stride_v, stride_m); HelperQ80<576, 32> kh((const char *)k, stride_k); HelperQ80<512, 32> vh((const char *)v, stride_v); - //if (nq1 % 8 == 0) { - // FlashAttn<576, 512, 8, 32> fa(scale, softcap); - // fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); - //} else { + if (nq1 % 8 == 0) { + FlashAttn<576, 512, 8, 32> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + } else { FlashAttn<576, 512, 1, 32> fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); - //} + } return true; } break; // Something is wrong with Q8_KV in this case. @@ -17276,26 +17277,26 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k case GGML_TYPE_F16: { HelperF16<576, 32> kh((const char *)k, stride_k); HelperF16<512, 32> vh((const char *)v, stride_v); - //if (nq1 % 8 == 0) { - // FlashAttn<576, 512, 8, 32> fa(scale, softcap); - // fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); - //} else { + if (nq1 % 8 == 0) { + FlashAttn<576, 512, 8, 32> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + } else { FlashAttn<576, 512, 1, 32> fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); - //} + } return true; } break; #ifdef __AVX512BF16__ case GGML_TYPE_BF16: { HelperBF16<576, 32> kh((const char *)k, stride_k); HelperBF16<512, 32> vh((const char *)v, stride_v); - //if (nq1 % 8 == 0) { - // FlashAttn<576, 512, 8, 32> fa(scale, softcap); - // fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); - //} else { + if (nq1 % 8 == 0) { + FlashAttn<576, 512, 8, 32> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + } else { FlashAttn<576, 512, 1, 32> fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); - //} + } return true; } break; #endif