diff --git a/ggml/src/iqk/fa/iqk_fa_templates.h b/ggml/src/iqk/fa/iqk_fa_templates.h index 481515cb..3a0b7248 100644 --- a/ggml/src/iqk/fa/iqk_fa_templates.h +++ b/ggml/src/iqk/fa/iqk_fa_templates.h @@ -1291,8 +1291,11 @@ struct FlashQKfp32 { template static inline void mul_mask_kq(const KHelper& kh, int stride_m, const block_q8 * q, const char * mask, FlashMS& fms) { - constexpr int kMaxQ = 8; - static_assert(q_step < kMaxQ || q_step%kMaxQ == 0); + // As far as I can tell, this static assert is a remnant of the times when the matrix multiplications were done inline + // here with bespoke kernels instead of just using the regular mat mul kernels. But, just in case, leaving it in place + // but commented out. + //constexpr int kMaxQ = 8; + //static_assert(q_step < kMaxQ || q_step%kMaxQ == 0); DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr}; if constexpr (std::is_same_v> || std::is_same_v>) { @@ -2069,6 +2072,12 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str if (update(16*n_step)) return; } } + if (nq1 == 12) { + // Special case: TG for GLM-4.5/4.6 + FlashAttn fa(scale, softcap, sinkf); + fa.compute(kh, vh, 12, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + return; + } if (nq1 >= 8) { int n_step = nq1/8; FlashAttn fa(scale, softcap, sinkf);