mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 11:21:56 +00:00
Much better CPU TG performance at long context for GLM-4.5
This commit is contained in:
@@ -1291,8 +1291,11 @@ struct FlashQKfp32 {
|
|||||||
template <typename KHelper, typename block_q8>
|
template <typename KHelper, typename block_q8>
|
||||||
static inline void mul_mask_kq(const KHelper& kh, int stride_m,
|
static inline void mul_mask_kq(const KHelper& kh, int stride_m,
|
||||||
const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) {
|
const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) {
|
||||||
constexpr int kMaxQ = 8;
|
// As far as I can tell, this static assert is a remnant of the time when the matrix multiplications were done inline
|
||||||
static_assert(q_step < kMaxQ || q_step%kMaxQ == 0);
|
// here with bespoke kernels instead of just using the regular mat mul kernels. But, just in case, leaving it in place
|
||||||
|
// but commented out.
|
||||||
|
//constexpr int kMaxQ = 8;
|
||||||
|
//static_assert(q_step < kMaxQ || q_step%kMaxQ == 0);
|
||||||
DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr};
|
DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr};
|
||||||
if constexpr (std::is_same_v<KHelper, HelperQ8KVR8<D>> ||
|
if constexpr (std::is_same_v<KHelper, HelperQ8KVR8<D>> ||
|
||||||
std::is_same_v<KHelper, HelperQ8KV<D>>) {
|
std::is_same_v<KHelper, HelperQ8KV<D>>) {
|
||||||
@@ -2069,6 +2072,12 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str
|
|||||||
if (update(16*n_step)) return;
|
if (update(16*n_step)) return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (nq1 == 12) {
|
||||||
|
// Special case: TG for GLM-4.5/4.6
|
||||||
|
FlashAttn<Dk, Dv, 12, k_step> fa(scale, softcap, sinkf);
|
||||||
|
fa.compute(kh, vh, 12, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||||
|
return;
|
||||||
|
}
|
||||||
if (nq1 >= 8) {
|
if (nq1 >= 8) {
|
||||||
int n_step = nq1/8;
|
int n_step = nq1/8;
|
||||||
FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap, sinkf);
|
FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap, sinkf);
|
||||||
|
|||||||
Reference in New Issue
Block a user