Much better CPU TG performance at long context for GLM-4.5 (#899)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
@@ -1291,8 +1291,11 @@ struct FlashQKfp32 {
     template <typename KHelper, typename block_q8>
     static inline void mul_mask_kq(const KHelper& kh, int stride_m,
             const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) {
-        constexpr int kMaxQ = 8;
-        static_assert(q_step < kMaxQ || q_step%kMaxQ == 0);
+        // As far as I can tell, this static assert is a remnant of the times when the matrix multiplications were done inline
+        // here with bespoke kernels instead of just using the regular mat mul kernels. But, just in case, leaving it in place
+        // but commented out.
+        //constexpr int kMaxQ = 8;
+        //static_assert(q_step < kMaxQ || q_step%kMaxQ == 0);
         DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr};
         if constexpr (std::is_same_v<KHelper, HelperQ8KVR8<D>> ||
                       std::is_same_v<KHelper, HelperQ8KV<D>>) {
@@ -2069,6 +2072,12 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str
             if (update(16*n_step)) return;
         }
     }
+    if (nq1 == 12) {
+        // Special case: TG for GLM-4.5/4.6
+        FlashAttn<Dk, Dv, 12, k_step> fa(scale, softcap, sinkf);
+        fa.compute(kh, vh, 12, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+        return;
+    }
     if (nq1 >= 8) {
         int n_step = nq1/8;
         FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap, sinkf);
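Why the new nq1 == 12 branch helps at long context (an illustrative sketch only: passes_over_kv() and the tile lists below are hypothetical, not code from ik_llama.cpp): each FlashAttn<Dk, Dv, q_step, k_step> tile streams the K/V cache once, and before this change a 12-row query batch (token generation for GLM-4.5/4.6, presumably 12 query heads sharing each KV head) would hit the nq1 >= 8 path and need a second pass over K/V for the remaining 4 rows. The dedicated 12-row kernel handles the whole batch in one pass:

// Illustrative sketch only: passes_over_kv() and the tile lists are hypothetical,
// not part of ik_llama.cpp. Assumption: every FlashAttn tile of q_step query rows
// streams the entire K/V cache once, so at long context the pass count dominates
// the cost of token generation (TG).
#include <cstdio>
#include <vector>

static int passes_over_kv(int nq1, const std::vector<int>& q_steps) {
    int passes = 0;
    for (int qs : q_steps) {
        while (nq1 >= qs) { nq1 -= qs; ++passes; }  // one full pass over K/V per tile
    }
    if (nq1 > 0) ++passes;                          // leftover rows still need a pass
    return passes;
}

int main() {
    // Before: largest tile is 8 rows, so nq1 = 12 (GLM-4.5 TG) splits into 8 + 4 -> 2 passes.
    printf("tiles {8,4}:    %d passes\n", passes_over_kv(12, {8, 4}));
    // After: the dedicated 12-row tile covers the whole batch -> 1 pass.
    printf("tiles {12,8,4}: %d passes\n", passes_over_kv(12, {12, 8, 4}));
    return 0;
}

At long context the K/V cache is by far the largest operand, so going from two passes to one roughly halves the memory traffic of this step, which is presumably the source of the CPU TG speedup named in the commit title.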