mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-20 22:49:31 +00:00
CPU FA improvements (#351)
* FA: provide work buffer for K repacking * Add header to avoid comp0iler warnings * WIP * WIP * WIP * WIP * Slightly better * WIP (Zen4) * WIP * Try to improve for unusual number of heads/number of threads * Use mul_mat_qX_0_q8_2_Tx for q6_0 in FA * Use mul_mat_qX_0_q8_2_Tx for q4_0 in FA * Use Sum4q4 for q4_0 * WIP * WIP * Much better FA TG with q8_0 KV cache Just repack it even for TG. But do the repacking for k_step rows, not the whole K tensor. --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -21786,15 +21786,36 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
|
||||
|
||||
cur = 3*sizeof(float)*D*n_tasks; // 3x head size/thread
|
||||
#if GGML_USE_IQK_MULMAT
|
||||
size_t qsize = 0;
|
||||
const struct ggml_tensor * q = node->src[0];
|
||||
const struct ggml_tensor * k = node->src[1];
|
||||
if (k->type == GGML_TYPE_Q8_0) {
|
||||
qsize = ggml_nrows(k)*ggml_row_size(k->type, k->ne[0]);
|
||||
}
|
||||
if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) {
|
||||
if (k->ne[2] > 1) {
|
||||
int nk = MAX(1, 32 * (k->ne[2]*k->ne[1]/(32*n_tasks)));
|
||||
int gcd = simple_gcd(k->ne[2], n_tasks);
|
||||
int nth_k = n_tasks/gcd;
|
||||
int nek2_k = k->ne[2]/gcd;
|
||||
int nchunk = nek2_k*k->ne[1]/32;
|
||||
int npt = (nchunk + nth_k - 1)/nth_k;
|
||||
int nk;
|
||||
if (npt*nth_k == nchunk) {
|
||||
nk = 32 * (k->ne[1]*k->ne[2]/(32*n_tasks));
|
||||
} else {
|
||||
//int nm = std::max(1, npt/8);
|
||||
int nm = 1;
|
||||
while (true) {
|
||||
if (nm*4 >= npt) break;
|
||||
nm *= 2;
|
||||
}
|
||||
nk = 32*nm;
|
||||
}
|
||||
//int nk = 32 * (k->ne[2]*k->ne[1]/(32*n_tasks));
|
||||
int nstep_k = k->ne[2]*k->ne[1]/nk;
|
||||
size_t result_size = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float);
|
||||
size_t size = nstep_k*result_size;
|
||||
cur = MAX(cur, size);
|
||||
cur = MAX(cur, size+qsize);
|
||||
} else {
|
||||
int nstep_k = k->ne[1]/32;
|
||||
int gcd_k = simple_gcd(nstep_k, n_tasks);
|
||||
@@ -21808,9 +21829,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
|
||||
size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
|
||||
size += q->ne[2]*row_size;
|
||||
}
|
||||
cur = MAX(cur, size);
|
||||
cur = MAX(cur, size+qsize);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
cur = MAX(cur, qsize);
|
||||
}
|
||||
#endif
|
||||
} break;
|
||||
|
||||
@@ -25,6 +25,24 @@ inline uint32_t simple_gcd(uint32_t a, uint32_t b) {
|
||||
}
|
||||
return a;
|
||||
}
|
||||
// Merge one partial flash-attention result (R, Mj, Sj) into the running
// accumulator (Racc, M, S) using the streaming ("online") softmax combine:
// whichever side holds the smaller running max gets rescaled by exp(delta)
// so both partials are expressed relative to the same maximum.
//
//   Dv   - V head size, i.e. number of floats in R and Racc
//   M    - running max logit (in/out); -INFINITY means "empty so far"
//   S    - running softmax denominator (in/out)
//   Mj   - max logit of the incoming partial; -INFINITY means "no data"
//   Sj   - softmax denominator of the incoming partial
//   Racc - accumulated, not yet normalized output row of Dv floats (in/out)
//   R    - incoming, not yet normalized partial output row of Dv floats
inline void accumulate_qkv(int Dv, float& M, float& S, float Mj, float Sj, float * Racc, const float * R) {
    if (Mj == -INFINITY) return;   // chunk saw only masked-out positions -> nothing to add
    if (Mj > M) {
        if (M == -INFINITY) {
            // First real contribution: adopt it wholesale.
            std::memcpy(Racc, R, Dv*sizeof(float));
            S = Sj;
        } else {
            // Incoming partial has the larger max: rescale the accumulator.
            // std::exp keeps the computation in float (bare exp would go through double).
            float c = std::exp(M - Mj);
            S = c*S + Sj;
            for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i];
        }
        M = Mj;
    } else {
        // Accumulator keeps its max: rescale the incoming partial instead.
        float c = std::exp(Mj - M);
        S += c*Sj;
        for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i];
    }
}
|
||||
}
|
||||
|
||||
// TODO: get the ggml_type enum here without polution
|
||||
@@ -34,7 +52,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
||||
int nek3, int nek2, long nbk3, long nbk2,
|
||||
int nev3, int nev2, long nbv3, long nbv2,
|
||||
int ne2, int ne1, long nb1,
|
||||
int int_type_k, // type of k
|
||||
int int_type_k_in, // type of k
|
||||
int int_type_v, // type of v
|
||||
int Dk, // K head size
|
||||
int Dv, // V head size
|
||||
@@ -51,7 +69,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
||||
float scale, // scale applied before softmax
|
||||
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
|
||||
float * qkv, // v*softmax(scale*(k*q))
|
||||
[[maybe_unused]] void * work_buffer, [[maybe_unused]] barrier_t barrier, [[maybe_unused]] void * barrier_data,
|
||||
[[maybe_unused]] void * work_buffer_in, [[maybe_unused]] barrier_t barrier, [[maybe_unused]] void * barrier_data,
|
||||
int ith, int nth) {
|
||||
|
||||
if (type_q != 0 || type_mask != 1 || max_bias > 0) return false;
|
||||
@@ -61,6 +79,29 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
||||
int rk3 = neq3/nek3;
|
||||
int rv3 = neq3/nev3;
|
||||
|
||||
int int_type_k = int_type_k_in;
|
||||
auto work_buffer = work_buffer_in;
|
||||
if (neq1 >= 8 || rk2 >= 8) {
|
||||
uint64_t row_size = 0;
|
||||
work_buffer = iqk_repack_k(int_type_k, Dk, nek1, nek2, nek3, stride_k, nbk2, nbk3, k, work_buffer_in, ith, nth, int_type_k, row_size);
|
||||
if (int_type_k != int_type_k_in) {
|
||||
stride_k = row_size;
|
||||
nbk2 = stride_k*nek1;
|
||||
nbk3 = nbk2*nek2;
|
||||
k = work_buffer_in;
|
||||
barrier(barrier_data);
|
||||
}
|
||||
}
|
||||
//uint64_t row_size = 0;
|
||||
//auto work_buffer = iqk_repack_k(int_type_k, Dk, nek1, nek2, nek3, stride_k, nbk2, nbk3, k, work_buffer_in, ith, nth, int_type_k, row_size);
|
||||
//if (int_type_k != int_type_k_in) {
|
||||
// stride_k = row_size;
|
||||
// nbk2 = stride_k*nek1;
|
||||
// nbk3 = nbk2*nek2;
|
||||
// k = work_buffer_in;
|
||||
// barrier(barrier_data);
|
||||
//}
|
||||
|
||||
// Getting confused all the time about where to load data from and store the results to
|
||||
// (especially when combining the results from the threads).
|
||||
// So, for now, making it work just for MLA (nek2 = 1).
|
||||
@@ -128,22 +169,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
||||
auto Mj = R + Dv*nq_this_j;
|
||||
auto Sj = Mj + nq_this_j;
|
||||
R += jj*Dv;
|
||||
if (Mj[jj] == -INFINITY) continue;
|
||||
if (Mj[jj] > M) {
|
||||
if (M == -INFINITY) {
|
||||
std::memcpy(Racc, R, Dv*sizeof(float));
|
||||
S = Sj[jj];
|
||||
} else {
|
||||
float c = exp(M - Mj[jj]);
|
||||
S = c*S + Sj[jj];
|
||||
for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i];
|
||||
}
|
||||
M = Mj[jj];
|
||||
} else {
|
||||
float c = exp(Mj[jj] - M);
|
||||
S += c*Sj[jj];
|
||||
for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i];
|
||||
}
|
||||
accumulate_qkv(Dv, M, S, Mj[jj], Sj[jj], Racc, R);
|
||||
}
|
||||
float norm = S > 0 ? 1/S : 1;
|
||||
for (int i = 0; i < Dv; ++i) Racc[i] *= norm;
|
||||
@@ -154,10 +180,72 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
||||
}
|
||||
|
||||
if (neq3 == 1 && rk2 > 1 && rk2 == rv2 && neq1 == 1 && nth >= 1 && nek2*nek1 >= 32*nth) {
|
||||
int nk = std::max(1, 32 * (nek2*nek1/(32*nth)));
|
||||
auto result_size = (Dv + 16)*rk2*sizeof(float);
|
||||
int gcd = simple_gcd(nek2, nth);
|
||||
if (false && gcd > 1) {
|
||||
int nth_g = nth/gcd;
|
||||
int ith_g = ith%nth_g;
|
||||
int nek1_32 = nek1/32;
|
||||
int nek1_pt = (nek1_32 + nth_g - 1)/nth_g;
|
||||
int ith_mid = nth_g;
|
||||
if (nek1_pt*nth_g > nek1_32) {
|
||||
ith_mid = nek1_32 - nth_g*(nek1_pt - 1);
|
||||
}
|
||||
nek1_pt *= 32;
|
||||
int nek1_mid = ith_mid*nek1_pt;
|
||||
int nek1_thread = ith_g < ith_mid ? nek1_pt : nek1_pt - 32;
|
||||
for (int ik02 = ith/nth_g; ik02 < nek2; ik02 += gcd) {
|
||||
int ik01 = ith_g < ith_mid ? ith_g*nek1_pt : nek1_mid + (ith_g - ith_mid)*nek1_thread;
|
||||
auto this_result = (float *)((char *)work_buffer + (ik02*nth_g + ith_g)*result_size);
|
||||
auto this_q = (const float *)((const char *)q + ik02*rk2*nbq2);
|
||||
auto this_k = (const char *)k + ik01*stride_k + ik02*nbk2;
|
||||
auto this_v = (const char *)v + ik01*stride_v + ik02*nbv2;
|
||||
auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here
|
||||
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
|
||||
Dk, Dv, rk2, nek1_thread, nbq2, stride_k, stride_v, 0, Dv,
|
||||
this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m,
|
||||
scale, softcap, this_result, this_result + (Dv+0)*rk2, this_result + (Dv+1)*rk2)) return false;
|
||||
}
|
||||
|
||||
barrier(barrier_data);
|
||||
|
||||
for (int iq2 = ith; iq2 < neq2; iq2 += nth) {
|
||||
int ik02 = iq2/rk2;
|
||||
int il = iq2 - ik02*rk2;
|
||||
auto Racc = qkv + iq2*nb1/sizeof(float);
|
||||
float M = -INFINITY, S = 0;
|
||||
for (int ig = 0; ig < nth_g; ++ig) {
|
||||
int istep_k = ik02*nth_g + ig;
|
||||
auto this_result = (float *)((char *)work_buffer + istep_k*result_size);
|
||||
const float * R = this_result + il*Dv;
|
||||
const float * Mj = this_result + Dv*rk2;
|
||||
const float * Sj = Mj + rk2;
|
||||
accumulate_qkv(Dv, M, S, Mj[il], Sj[il], Racc, R);
|
||||
}
|
||||
float norm = S > 0 ? 1/S : 1;
|
||||
for (int i = 0; i < Dv; ++i) Racc[i] *= norm;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
int nth_k = nth/gcd;
|
||||
int nek2_k = nek2/gcd;
|
||||
int nchunk = nek2_k*nek1/32;
|
||||
int npt = (nchunk + nth_k - 1)/nth_k;
|
||||
int nk;
|
||||
if (npt*nth_k == nchunk) {
|
||||
nk = 32 * (nek2*nek1/(32*nth));
|
||||
} else {
|
||||
//int nm = std::max(1, npt/8);
|
||||
int nm = 1;
|
||||
while (true) {
|
||||
if (nm*4 >= npt) break;
|
||||
nm *= 2;
|
||||
}
|
||||
nk = 32*nm;
|
||||
}
|
||||
//int nk = 32 * (nek2*nek1/(32*nth));
|
||||
int nkk = (nek1 + nk - 1)/nk;
|
||||
int nstep_k = nek2*nkk;
|
||||
auto result_size = (Dv + 16)*rk2*sizeof(float);
|
||||
//if (ith == 0) printf("rk2 = %d, nek1 = %d, nek2 = %d, nk = %d, nkk = %d, nstep_k = %d\n", (int)rk2, (int)nek1, (int)nek2, nk, nkk, nstep_k);
|
||||
for (int istep_k = ith; istep_k < nstep_k; istep_k += nth) {
|
||||
int ik02 = istep_k/nkk;
|
||||
@@ -183,7 +271,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
||||
int ik02 = iq2/rk2;
|
||||
int il = iq2 - ik02*rk2;
|
||||
auto Racc = qkv + iq2*nb1/sizeof(float);
|
||||
std::memset(Racc, 0, Dv*sizeof(float));
|
||||
//std::memset(Racc, 0, Dv*sizeof(float));
|
||||
float M = -INFINITY, S = 0;
|
||||
for (int ikk = 0; ikk < nkk; ++ikk) {
|
||||
int istep_k = ik02*nkk + ikk;
|
||||
@@ -191,22 +279,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
||||
const float * R = this_result + il*Dv;
|
||||
const float * Mj = this_result + Dv*rk2;
|
||||
const float * Sj = Mj + rk2;
|
||||
if (Mj[il] == -INFINITY) continue;
|
||||
if (Mj[il] > M) {
|
||||
if (M == -INFINITY) {
|
||||
std::memcpy(Racc, R, Dv*sizeof(float));
|
||||
S = Sj[il];
|
||||
} else {
|
||||
float c = exp(M - Mj[il]);
|
||||
S = c*S + Sj[il];
|
||||
for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i];
|
||||
}
|
||||
M = Mj[il];
|
||||
} else {
|
||||
float c = exp(Mj[il] - M);
|
||||
S += c*Sj[il];
|
||||
for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i];
|
||||
}
|
||||
accumulate_qkv(Dv, M, S, Mj[il], Sj[il], Racc, R);
|
||||
}
|
||||
float norm = S > 0 ? 1/S : 1;
|
||||
for (int i = 0; i < Dv; ++i) Racc[i] *= norm;
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
bool iqk_flash_attn_impl(int type_k, // type of k
|
||||
int type_v, // type of v
|
||||
int Dk, // K head size
|
||||
@@ -27,3 +29,5 @@ bool iqk_flash_attn_impl(int type_k, // type of k
|
||||
float * M,
|
||||
float * S);
|
||||
|
||||
void * iqk_repack_k(int type_k, int nek0, int nek1, int nek2, int nek3, long nbk1, long nbk2, long nbk3,
|
||||
const void * k, void * work, int ith, int nth, int& repacked_type, uint64_t& row_size);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user