mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-07 23:10:10 +00:00
Work buffer size
This commit is contained in:
@@ -25253,68 +25253,8 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
|
||||
|
||||
cur = 3*sizeof(float)*D*n_tasks; // 3x head size/thread
|
||||
#if GGML_USE_IQK_MULMAT
|
||||
size_t qsize = 0;
|
||||
const struct ggml_tensor * q = node->src[0];
|
||||
const struct ggml_tensor * k = node->src[1];
|
||||
if (k->type == GGML_TYPE_Q8_0) {
|
||||
qsize = ggml_nrows(k)*ggml_row_size(k->type, k->ne[0]);
|
||||
}
|
||||
int nstep_k = k->ne[1]/32;
|
||||
if (nstep_k >= 4*n_tasks && q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1) {
|
||||
size_t size_thread = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float);
|
||||
size_t size = size_thread*n_tasks;
|
||||
cur = MAX(cur, size+qsize);
|
||||
} else {
|
||||
if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) {
|
||||
if (k->ne[2] > 1) {
|
||||
int gcd = simple_gcd(k->ne[2], n_tasks);
|
||||
int nth_k = n_tasks/gcd;
|
||||
int nek2_k = k->ne[2]/gcd;
|
||||
int nchunk = nek2_k*k->ne[1]/32;
|
||||
int npt = (nchunk + nth_k - 1)/nth_k;
|
||||
int nk;
|
||||
if (npt*nth_k == nchunk) {
|
||||
nk = 32 * (k->ne[1]*k->ne[2]/(32*n_tasks));
|
||||
} else {
|
||||
//int nm = std::max(1, npt/8);
|
||||
int nm = 1;
|
||||
while (true) {
|
||||
if (nm*4 >= npt) break;
|
||||
nm *= 2;
|
||||
}
|
||||
nk = 32*nm;
|
||||
}
|
||||
//int nk = 32 * (k->ne[2]*k->ne[1]/(32*n_tasks));
|
||||
nstep_k = k->ne[2]*k->ne[1]/nk;
|
||||
size_t result_size = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float);
|
||||
size_t size = nstep_k*result_size;
|
||||
cur = MAX(cur, size+qsize);
|
||||
} else {
|
||||
nstep_k = k->ne[1]/32;
|
||||
if (nstep_k >= n_tasks) {
|
||||
size_t size_thread = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float);
|
||||
size_t size = size_thread*n_tasks;
|
||||
cur = MAX(cur, size+qsize);
|
||||
} else {
|
||||
int gcd_k = simple_gcd(nstep_k, n_tasks);
|
||||
if (gcd_k > 1) {
|
||||
int nth_k = n_tasks/gcd_k;
|
||||
int rk2 = q->ne[2]/k->ne[2];
|
||||
int nq_per_thread = (rk2 + nth_k - 1)/nth_k;
|
||||
size_t size = (Dv + 16)*nq_per_thread*sizeof(float)*n_tasks;
|
||||
if (ggml_is_quantized(k->type)) {
|
||||
enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type;
|
||||
size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
|
||||
size += q->ne[2]*row_size;
|
||||
}
|
||||
cur = MAX(cur, size+qsize);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
cur = MAX(cur, qsize);
|
||||
}
|
||||
}
|
||||
size_t size = iqk_fa_work_buffer_size(node, n_tasks);
|
||||
cur = MAX(cur, size);
|
||||
#endif
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN_BACK:
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "iqk_config.h"
|
||||
#include "iqk_mul_mat.h"
|
||||
#include "iqk_flash_impl.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
|
||||
|
||||
@@ -45,6 +46,39 @@ inline void accumulate_qkv(int Dv, float& M, float& S, float Mj, float Sj, float
|
||||
}
|
||||
}
|
||||
|
||||
size_t iqk_fa_work_buffer_size(const struct ggml_tensor * dst, int nth) {
|
||||
auto Q = dst->src[0];
|
||||
auto K = dst->src[1];
|
||||
auto V = dst->src[2];
|
||||
int rk2 = Q->ne[2]/K->ne[2];
|
||||
size_t size = 0;
|
||||
if (K->type == GGML_TYPE_Q8_0 && (Q->ne[1] >= 8 || (rk2 >= 8 && K->ne[2] > 1))) {
|
||||
size = ggml_row_size(GGML_TYPE_Q8_0, K->ne[0]) * K->ne[1]*K->ne[2]*K->ne[3];
|
||||
}
|
||||
int nstep_k = K->ne[1]/32;
|
||||
if (nstep_k >= 4*nth) {
|
||||
auto size_thread = (V->ne[0] + 16)*rk2*sizeof(float);
|
||||
size += size_thread*nth;
|
||||
return size;
|
||||
}
|
||||
int gcd_k = simple_gcd(nstep_k, nth);
|
||||
if (gcd_k >= 1) {
|
||||
int nth_k = nth/gcd_k;
|
||||
int nq_per_thread = (rk2 + nth_k - 1)/nth_k;
|
||||
if (nq_per_thread > 1) {
|
||||
auto size_thread = (V->ne[0] + 16)*nq_per_thread*sizeof(float);
|
||||
size += size_thread*nth;
|
||||
return size;
|
||||
}
|
||||
}
|
||||
int rv2 = Q->ne[2] / V->ne[2];
|
||||
if (Q->ne[1] == 1 && Q->ne[3] == 1 && rk2 > 1 && rk2 == rv2 && K->ne[1]*K->ne[2] >= 32*nth) {
|
||||
auto result_size = (V->ne[0] + 16)*rk2*sizeof(float);
|
||||
size += result_size*nth;
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
// TODO: get the ggml_type enum here without pollution
|
||||
//
|
||||
extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
|
||||
@@ -145,8 +179,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
||||
// I think it would also speed up things for GQA, but I'm leaving this for another day.
|
||||
if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nth >= 1 && nek1/32 > 1 && nek2 == 1) {
|
||||
int nstep_k = nek1/32;
|
||||
//if (ith >= nstep_k && ith >= rk2) return true;
|
||||
if (nstep_k >= nth) { //4*nth) {
|
||||
if (nstep_k >= 4*nth) {
|
||||
int nstep_k_per_thread = (nstep_k + nth - 1)/nth;
|
||||
int ith_mid = nth;
|
||||
int nstep_k_this_thread = nstep_k_per_thread;
|
||||
@@ -169,22 +202,18 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
||||
auto size_thread = (Dv + 16)*rk2*sizeof(float);
|
||||
auto result_buffer = work;
|
||||
auto work_this_thread = (float *)(result_buffer + ith*size_thread);
|
||||
//if (nstep_k_this_thread > 0) {
|
||||
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
|
||||
Dk, Dv, rk2, nstep_k_this_thread, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv,
|
||||
(const float *)qth, (const void *)kth, (const void *)vth, (const void *)mth, nullptr, 0,
|
||||
scale, softcap,
|
||||
work_this_thread, work_this_thread + (Dv+0)*rk2, work_this_thread + (Dv+1)*rk2)) return false;
|
||||
//}
|
||||
|
||||
barrier(barrier_data);
|
||||
|
||||
//int nhave = std::min(nstep_k, nth);
|
||||
for (int j = ith; j < rk2; j += nth) {
|
||||
auto Racc = qkv + j*nb1/sizeof(float);
|
||||
float M = -INFINITY, S = 0;
|
||||
for (int jth = 0; jth < nth; ++jth) {
|
||||
//for (int jth = 0; jth < nhave; ++jth) {
|
||||
auto R = (const float *)(result_buffer + jth*size_thread);
|
||||
auto Mj = R + Dv*rk2;
|
||||
auto Sj = Mj + rk2;
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#pragma once
|
||||
#include <stdint.h>
|
||||
#include <stdbool.h>
|
||||
#include <stddef.h>
|
||||
#include "iqk_config.h"
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
@@ -37,6 +38,10 @@ IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int un
|
||||
|
||||
IQK_API int iqk_dequant_type(int type, int Ny);
|
||||
|
||||
struct ggml_tensor;
|
||||
|
||||
// Returns the number of bytes of work-buffer memory to reserve for the graph node `dst`
// (its src[0], src[1], src[2] are read as Q, K, V) when it is run with `nthread` threads.
IQK_API size_t iqk_fa_work_buffer_size(const struct ggml_tensor * dst, int nthread);
|
||||
|
||||
typedef void (*barrier_t) (void *);
|
||||
|
||||
IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
|
||||
|
||||
Reference in New Issue
Block a user