Work buffer size

Kawrakow
2026-01-31 16:10:23 +00:00
parent 2bf2fa8ba4
commit 685df0e69d
3 changed files with 42 additions and 68 deletions


@@ -25253,68 +25253,8 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
     cur = 3*sizeof(float)*D*n_tasks; // 3x head size/thread
 #if GGML_USE_IQK_MULMAT
-    size_t qsize = 0;
-    const struct ggml_tensor * q = node->src[0];
-    const struct ggml_tensor * k = node->src[1];
-    if (k->type == GGML_TYPE_Q8_0) {
-        qsize = ggml_nrows(k)*ggml_row_size(k->type, k->ne[0]);
-    }
-    int nstep_k = k->ne[1]/32;
-    if (nstep_k >= 4*n_tasks && q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1) {
-        size_t size_thread = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float);
-        size_t size = size_thread*n_tasks;
-        cur = MAX(cur, size+qsize);
-    } else {
-        if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) {
-            if (k->ne[2] > 1) {
-                int gcd = simple_gcd(k->ne[2], n_tasks);
-                int nth_k = n_tasks/gcd;
-                int nek2_k = k->ne[2]/gcd;
-                int nchunk = nek2_k*k->ne[1]/32;
-                int npt = (nchunk + nth_k - 1)/nth_k;
-                int nk;
-                if (npt*nth_k == nchunk) {
-                    nk = 32 * (k->ne[1]*k->ne[2]/(32*n_tasks));
-                } else {
-                    //int nm = std::max(1, npt/8);
-                    int nm = 1;
-                    while (true) {
-                        if (nm*4 >= npt) break;
-                        nm *= 2;
-                    }
-                    nk = 32*nm;
-                }
-                //int nk = 32 * (k->ne[2]*k->ne[1]/(32*n_tasks));
-                nstep_k = k->ne[2]*k->ne[1]/nk;
-                size_t result_size = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float);
-                size_t size = nstep_k*result_size;
-                cur = MAX(cur, size+qsize);
-            } else {
-                nstep_k = k->ne[1]/32;
-                if (nstep_k >= n_tasks) {
-                    size_t size_thread = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float);
-                    size_t size = size_thread*n_tasks;
-                    cur = MAX(cur, size+qsize);
-                } else {
-                    int gcd_k = simple_gcd(nstep_k, n_tasks);
-                    if (gcd_k > 1) {
-                        int nth_k = n_tasks/gcd_k;
-                        int rk2 = q->ne[2]/k->ne[2];
-                        int nq_per_thread = (rk2 + nth_k - 1)/nth_k;
-                        size_t size = (Dv + 16)*nq_per_thread*sizeof(float)*n_tasks;
-                        if (ggml_is_quantized(k->type)) {
-                            enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type;
-                            size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
-                            size += q->ne[2]*row_size;
-                        }
-                        cur = MAX(cur, size+qsize);
-                    }
-                }
-            }
-        } else {
-            cur = MAX(cur, qsize);
-        }
-    }
+    size_t size = iqk_fa_work_buffer_size(node, n_tasks);
+    cur = MAX(cur, size);
 #endif
 } break;
 case GGML_OP_FLASH_ATTN_BACK:
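The per-thread scratch term that appears both in the removed planner logic and in the new helper is (Dv + 16)*rk2*sizeof(float): Dv floats of partial output per grouped Q head, plus slack for the per-head running max and sum used later when the partials are merged. A minimal standalone sketch of that arithmetic with made-up shapes (head size 128, 32 Q heads over 8 KV heads, 16 threads), not values taken from the commit:

#include <cstddef>
#include <cstdio>

int main() {
    // Hypothetical shapes, for illustration only.
    const int Dv      = 128;     // V head size
    const int rk2     = 32 / 8;  // Q heads per KV head (GQA ratio)
    const int n_tasks = 16;      // number of threads
    // Dv floats of partial output per grouped head, plus 16 floats of
    // slack that covers the per-head running max M and sum S.
    const size_t size_thread = (size_t)(Dv + 16) * rk2 * sizeof(float);
    std::printf("per thread: %zu bytes, all threads: %zu bytes\n",
                size_thread, size_thread * (size_t)n_tasks);
    return 0;
}

Moving this arithmetic into iqk_fa_work_buffer_size lets ggml_graph_plan and the flash attention kernel derive the size from a single place instead of two copies of the branching logic.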


@@ -7,6 +7,7 @@
#include "iqk_config.h"
#include "iqk_mul_mat.h"
#include "iqk_flash_impl.h"
#include "ggml.h"
#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
@@ -45,6 +46,39 @@ inline void accumulate_qkv(int Dv, float& M, float& S, float Mj, float Sj, float
     }
 }
+size_t iqk_fa_work_buffer_size(const struct ggml_tensor * dst, int nth) {
+    auto Q = dst->src[0];
+    auto K = dst->src[1];
+    auto V = dst->src[2];
+    int rk2 = Q->ne[2]/K->ne[2];
+    size_t size = 0;
+    if (K->type == GGML_TYPE_Q8_0 && (Q->ne[1] >= 8 || (rk2 >= 8 && K->ne[2] > 1))) {
+        size = ggml_row_size(GGML_TYPE_Q8_0, K->ne[0]) * K->ne[1]*K->ne[2]*K->ne[3];
+    }
+    int nstep_k = K->ne[1]/32;
+    if (nstep_k >= 4*nth) {
+        auto size_thread = (V->ne[0] + 16)*rk2*sizeof(float);
+        size += size_thread*nth;
+        return size;
+    }
+    int gcd_k = simple_gcd(nstep_k, nth);
+    if (gcd_k >= 1) {
+        int nth_k = nth/gcd_k;
+        int nq_per_thread = (rk2 + nth_k - 1)/nth_k;
+        if (nq_per_thread > 1) {
+            auto size_thread = (V->ne[0] + 16)*nq_per_thread*sizeof(float);
+            size += size_thread*nth;
+            return size;
+        }
+    }
+    int rv2 = Q->ne[2] / V->ne[2];
+    if (Q->ne[1] == 1 && Q->ne[3] == 1 && rk2 > 1 && rk2 == rv2 && K->ne[1]*K->ne[2] >= 32*nth) {
+        auto result_size = (V->ne[0] + 16)*rk2*sizeof(float);
+        size += result_size*nth;
+    }
+    return size;
+}
 // TODO: get the ggml_type enum here without polution
 //
 extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
@@ -145,8 +179,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
     // I think it would also speed up things for GQA, but I'm leaving this for another day.
     if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nth >= 1 && nek1/32 > 1 && nek2 == 1) {
         int nstep_k = nek1/32;
-        //if (ith >= nstep_k && ith >= rk2) return true;
-        if (nstep_k >= nth) { //4*nth) {
+        if (nstep_k >= 4*nth) {
             int nstep_k_per_thread = (nstep_k + nth - 1)/nth;
             int ith_mid = nth;
             int nstep_k_this_thread = nstep_k_per_thread;
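Restoring the 4*nth threshold means the split along K is only taken when every thread gets at least four 32-row chunks of K to work on. A quick numeric check under assumed decode shapes (4096 rows of K, 16 threads), again not values taken from the commit:

#include <cstdio>

int main() {
    // Hypothetical shapes, for illustration only.
    const int nek1    = 4096;       // rows of K (context length)
    const int nth     = 16;         // number of threads
    const int nstep_k = nek1 / 32;  // 32-row chunks of K -> 128
    // Only split along K if every thread can take at least 4 chunks.
    const bool split_k = nstep_k >= 4*nth;           // 128 >= 64 -> true
    const int  per_thr = (nstep_k + nth - 1) / nth;  // 8 chunks per thread
    std::printf("split_k = %d, chunks per thread = %d\n", split_k, per_thr);
    return 0;
}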
@@ -169,22 +202,18 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
             auto size_thread = (Dv + 16)*rk2*sizeof(float);
             auto result_buffer = work;
             auto work_this_thread = (float *)(result_buffer + ith*size_thread);
-            //if (nstep_k_this_thread > 0) {
             if (!iqk_flash_attn_impl(int_type_k, int_type_v,
                     Dk, Dv, rk2, nstep_k_this_thread, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv,
                     (const float *)qth, (const void *)kth, (const void *)vth, (const void *)mth, nullptr, 0,
                     scale, softcap,
                     work_this_thread, work_this_thread + (Dv+0)*rk2, work_this_thread + (Dv+1)*rk2)) return false;
-            //}
             barrier(barrier_data);
-            //int nhave = std::min(nstep_k, nth);
             for (int j = ith; j < rk2; j += nth) {
                 auto Racc = qkv + j*nb1/sizeof(float);
                 float M = -INFINITY, S = 0;
                 for (int jth = 0; jth < nth; ++jth) {
-                //for (int jth = 0; jth < nhave; ++jth) {
                     auto R = (const float *)(result_buffer + jth*size_thread);
                     auto Mj = R + Dv*rk2;
                     auto Sj = Mj + rk2;
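Each thread writes its partial output (Dv*rk2 floats) followed by the per-head running max M and sum S, which is what the work_this_thread + (Dv+0)*rk2 and + (Dv+1)*rk2 pointers address, and the final loop folds those partials together. A standalone sketch of that style of merge for one head; the function name and exact layout here are illustrative, not the commit's accumulate_qkv:

#include <cmath>

// Fold one thread's partial result (Rj, Mj, Sj) into the running
// accumulator (Racc, M, S) for a single head of size Dv.
// Racc holds the exp-weighted sum of V rows, M the running max logit,
// S the running sum of exp(logit - M); the caller divides by S at the end.
static void merge_partial(int Dv, float * Racc, float & M, float & S,
                          const float * Rj, float Mj, float Sj) {
    if (Mj == -INFINITY) return;           // this thread saw no data
    if (Mj > M) {
        const float c = std::exp(M - Mj);  // rescale the old accumulator
        for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + Rj[i];
        S = c*S + Sj;
        M = Mj;
    } else {
        const float c = std::exp(Mj - M);  // rescale the incoming partial
        for (int i = 0; i < Dv; ++i) Racc[i] += c*Rj[i];
        S += c*Sj;
    }
}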


@@ -7,6 +7,7 @@
 #pragma once
 #include <stdint.h>
 #include <stdbool.h>
+#include <stddef.h>
 #include "iqk_config.h"
 #ifdef __cplusplus
 extern "C" {
@@ -37,6 +38,10 @@ IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int un
 IQK_API int iqk_dequant_type(int type, int Ny);
+struct ggml_tensor;
+IQK_API size_t iqk_fa_work_buffer_size(const struct ggml_tensor * dst, int nthread);
 typedef void (*barrier_t) (void *);
 IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
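For completeness, a hedged sketch of consuming the new declaration outside of ggml_graph_plan; the header name is the one included by the implementation file above, and the allocation pattern is an assumption, not part of the commit:

#include <cstddef>
#include <vector>
#include "iqk_mul_mat.h"   // assumed to be the header carrying this declaration

// Reserve a scratch buffer large enough for one flash attention node before
// the threads run; fa_node is the attention dst tensor, error handling omitted.
std::vector<char> make_fa_scratch(const struct ggml_tensor * fa_node, int n_threads) {
    const size_t need = iqk_fa_work_buffer_size(fa_node, n_threads);
    return std::vector<char>(need);
}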