Fix CPU FA work buffer size (#1216)

This commit is contained in:
Kawrakow
2026-02-02 12:39:41 +02:00
committed by GitHub
parent 49ba462f22
commit 589d80f677

View File

@@ -52,29 +52,56 @@ size_t iqk_fa_work_buffer_size(const struct ggml_tensor * dst, int nth) {
auto V = dst->src[2];
int rk2 = Q->ne[2]/K->ne[2];
size_t size = 0;
if (K->type == GGML_TYPE_Q8_0 && (Q->ne[1] >= 8 || (rk2 >= 8 && K->ne[2] > 1))) {
if (Q->ne[1] >= 8 && K->type == GGML_TYPE_Q8_0) {
size = ggml_row_size(GGML_TYPE_Q8_0, K->ne[0]) * K->ne[1]*K->ne[2]*K->ne[3];
}
int nstep_k = K->ne[1]/32;
if (nstep_k >= 4*nth) {
auto size_thread = (V->ne[0] + 16)*rk2*sizeof(float);
size += size_thread*nth;
return size;
}
int gcd_k = simple_gcd(nstep_k, nth);
if (gcd_k >= 1) {
int nth_k = nth/gcd_k;
int nq_per_thread = (rk2 + nth_k - 1)/nth_k;
if (nq_per_thread > 1) {
auto size_thread = (V->ne[0] + 16)*nq_per_thread*sizeof(float);
if (Q->ne[1] == 1 && Q->ne[3] == 1 && Q->ne[2]/K->ne[2] > 1 && nth >= 1 && K->ne[1]/32 > 1) {
if (K->ne[2] > 1) {
int gcd = simple_gcd(K->ne[2], nth);
int nth_k = nth/gcd;
int nek2_k = K->ne[2]/gcd;
int nchunk = nek2_k*K->ne[1]/32;
int npt = (nchunk + nth_k - 1)/nth_k;
int nk;
if (npt*nth_k == nchunk) {
nk = 32 * (K->ne[1]*K->ne[2]/(32*nth));
} else {
//int nm = std::max(1, npt/8);
int nm = 1;
while (true) {
if (nm*4 >= npt) break;
nm *= 2;
}
nk = 32*nm;
}
int nkk = (K->ne[1] + nk - 1)/nk;
int nstep_k = K->ne[2]*nkk;
size_t result_size = (V->ne[0] + 16)*Q->ne[2]/K->ne[2]*sizeof(float);
size += nstep_k*result_size;
return size;
}
int nstep_k = K->ne[1]/32;
if (nstep_k >= 4*nth) {
auto size_thread = (V->ne[0] + 16)*rk2*sizeof(float);
size += size_thread*nth;
return size;
}
}
int rv2 = Q->ne[2] / V->ne[2];
if (Q->ne[1] == 1 && Q->ne[3] == 1 && rk2 > 1 && rk2 == rv2 && K->ne[1]*K->ne[2] >= 32*nth) {
auto result_size = (V->ne[0] + 16)*rk2*sizeof(float);
size += result_size*nth;
int gcd_k = simple_gcd(nstep_k, nth);
if (gcd_k >= 1) {
int nth_k = nth/gcd_k;
int nq_per_thread = (rk2 + nth_k - 1)/nth_k;
if (nq_per_thread > 1) {
auto size_thread = (V->ne[0] + 16)*nq_per_thread*sizeof(float);
size += size_thread*nth;
return size;
}
}
int rv2 = Q->ne[2] / V->ne[2];
if (Q->ne[1] == 1 && Q->ne[3] == 1 && rk2 > 1 && rk2 == rv2 && K->ne[1]*K->ne[2] >= 32*nth) {
auto result_size = (V->ne[0] + 16)*rk2*sizeof(float);
size += result_size*nth;
}
return size;
}
return size;
}