Much better CPU TG performance at long context for GLM-4.5 (#899)

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Author: Kawrakow
Date: 2025-11-05 10:20:26 +02:00
Committed by: GitHub
parent 98357d9aa5
commit 92607d44c4


@@ -1291,8 +1291,11 @@ struct FlashQKfp32 {
     template <typename KHelper, typename block_q8>
     static inline void mul_mask_kq(const KHelper& kh, int stride_m,
             const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) {
-        constexpr int kMaxQ = 8;
-        static_assert(q_step < kMaxQ || q_step%kMaxQ == 0);
+        // As far as I can tell, this static assert is a remnant of the time when the matrix multiplications were done inline
+        // here with bespoke kernels instead of just using the regular mat mul kernels. But, just in case, leaving it in place
+        // but commented out.
+        //constexpr int kMaxQ = 8;
+        //static_assert(q_step < kMaxQ || q_step%kMaxQ == 0);
         DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr};
         if constexpr (std::is_same_v<KHelper, HelperQ8KVR8<D>> ||
                       std::is_same_v<KHelper, HelperQ8KV<D>>) {
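
Why the guard had to go: the second hunk below instantiates FlashAttn with q_step = 12, which satisfies neither arm of the old condition (12 is not below kMaxQ = 8 and not a multiple of it). A minimal standalone sketch, not part of the commit, illustrating this:

// Sketch only: reproduces the old compile-time guard to show that
// q_step = 12 fails it, while the previously used step sizes pass.
#include <cstdio>

constexpr int kMaxQ = 8;

constexpr bool old_guard(int q_step) {
    return q_step < kMaxQ || q_step % kMaxQ == 0;
}

int main() {
    static_assert(old_guard(4));    // ok: 4 < 8
    static_assert(old_guard(16));   // ok: 16 % 8 == 0
    static_assert(!old_guard(12));  // 12 >= 8 and 12 % 8 != 0: would not have compiled
    std::printf("q_step = 12 is only usable with the assert commented out\n");
    return 0;
}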
@@ -2069,6 +2072,12 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str
             if (update(16*n_step)) return;
         }
     }
+    if (nq1 == 12) {
+        // Special case: TG for GLM-4.5/4.6
+        FlashAttn<Dk, Dv, 12, k_step> fa(scale, softcap, sinkf);
+        fa.compute(kh, vh, 12, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+        return;
+    }
     if (nq1 >= 8) {
         int n_step = nq1/8;
         FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap, sinkf);
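
The dispatch order is the point of the change: the exact-12 check sits above the generic nq1 >= 8 path, so a 12-row batch (which appears to correspond to GLM-4.5/4.6's ratio of query heads to KV heads during single-token generation) runs as one 12-row tile instead of an 8-row tile plus leftover rows. A simplified sketch of that dispatch, with a hypothetical run_tile helper standing in for FlashAttn<Dk, Dv, q_step, k_step>::compute, and with the remainder handling below the 8-row path assumed rather than shown in the hunk:

// Sketch only: models the dispatch order after this commit.
#include <cstdio>

template <int q_step>
void run_tile(int rows) {
    // stands in for FlashAttn<Dk, Dv, q_step, k_step>::compute(...)
    std::printf("tile q_step=%d covering %d rows\n", q_step, rows);
}

void dispatch(int nq1) {
    if (nq1 == 12) { run_tile<12>(12); return; }   // new GLM-4.5/4.6 TG fast path
    if (nq1 >= 8) {                                // pre-existing generic path
        run_tile<8>(8 * (nq1 / 8));
        if (int rem = nq1 % 8) run_tile<4>(rem);   // assumed remainder handling, not in the hunk
        return;
    }
    run_tile<4>(nq1);
}

int main() { dispatch(12); }  // prints a single 12-row tile rather than 8 + 4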