mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-20 21:24:08 +00:00
Fix CPU FA work buffer size (#1216)
This commit is contained in:
@@ -52,29 +52,56 @@ size_t iqk_fa_work_buffer_size(const struct ggml_tensor * dst, int nth) {
|
||||
auto V = dst->src[2];
|
||||
int rk2 = Q->ne[2]/K->ne[2];
|
||||
size_t size = 0;
|
||||
if (K->type == GGML_TYPE_Q8_0 && (Q->ne[1] >= 8 || (rk2 >= 8 && K->ne[2] > 1))) {
|
||||
if (Q->ne[1] >= 8 && K->type == GGML_TYPE_Q8_0) {
|
||||
size = ggml_row_size(GGML_TYPE_Q8_0, K->ne[0]) * K->ne[1]*K->ne[2]*K->ne[3];
|
||||
}
|
||||
int nstep_k = K->ne[1]/32;
|
||||
if (nstep_k >= 4*nth) {
|
||||
auto size_thread = (V->ne[0] + 16)*rk2*sizeof(float);
|
||||
size += size_thread*nth;
|
||||
return size;
|
||||
}
|
||||
int gcd_k = simple_gcd(nstep_k, nth);
|
||||
if (gcd_k >= 1) {
|
||||
int nth_k = nth/gcd_k;
|
||||
int nq_per_thread = (rk2 + nth_k - 1)/nth_k;
|
||||
if (nq_per_thread > 1) {
|
||||
auto size_thread = (V->ne[0] + 16)*nq_per_thread*sizeof(float);
|
||||
if (Q->ne[1] == 1 && Q->ne[3] == 1 && Q->ne[2]/K->ne[2] > 1 && nth >= 1 && K->ne[1]/32 > 1) {
|
||||
if (K->ne[2] > 1) {
|
||||
int gcd = simple_gcd(K->ne[2], nth);
|
||||
int nth_k = nth/gcd;
|
||||
int nek2_k = K->ne[2]/gcd;
|
||||
int nchunk = nek2_k*K->ne[1]/32;
|
||||
int npt = (nchunk + nth_k - 1)/nth_k;
|
||||
int nk;
|
||||
if (npt*nth_k == nchunk) {
|
||||
nk = 32 * (K->ne[1]*K->ne[2]/(32*nth));
|
||||
} else {
|
||||
//int nm = std::max(1, npt/8);
|
||||
int nm = 1;
|
||||
while (true) {
|
||||
if (nm*4 >= npt) break;
|
||||
nm *= 2;
|
||||
}
|
||||
nk = 32*nm;
|
||||
}
|
||||
int nkk = (K->ne[1] + nk - 1)/nk;
|
||||
int nstep_k = K->ne[2]*nkk;
|
||||
size_t result_size = (V->ne[0] + 16)*Q->ne[2]/K->ne[2]*sizeof(float);
|
||||
size += nstep_k*result_size;
|
||||
return size;
|
||||
}
|
||||
int nstep_k = K->ne[1]/32;
|
||||
if (nstep_k >= 4*nth) {
|
||||
auto size_thread = (V->ne[0] + 16)*rk2*sizeof(float);
|
||||
size += size_thread*nth;
|
||||
return size;
|
||||
}
|
||||
}
|
||||
int rv2 = Q->ne[2] / V->ne[2];
|
||||
if (Q->ne[1] == 1 && Q->ne[3] == 1 && rk2 > 1 && rk2 == rv2 && K->ne[1]*K->ne[2] >= 32*nth) {
|
||||
auto result_size = (V->ne[0] + 16)*rk2*sizeof(float);
|
||||
size += result_size*nth;
|
||||
int gcd_k = simple_gcd(nstep_k, nth);
|
||||
if (gcd_k >= 1) {
|
||||
int nth_k = nth/gcd_k;
|
||||
int nq_per_thread = (rk2 + nth_k - 1)/nth_k;
|
||||
if (nq_per_thread > 1) {
|
||||
auto size_thread = (V->ne[0] + 16)*nq_per_thread*sizeof(float);
|
||||
size += size_thread*nth;
|
||||
return size;
|
||||
}
|
||||
}
|
||||
int rv2 = Q->ne[2] / V->ne[2];
|
||||
if (Q->ne[1] == 1 && Q->ne[3] == 1 && rk2 > 1 && rk2 == rv2 && K->ne[1]*K->ne[2] >= 32*nth) {
|
||||
auto result_size = (V->ne[0] + 16)*rk2*sizeof(float);
|
||||
size += result_size*nth;
|
||||
}
|
||||
return size;
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user