mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 15:14:10 +00:00
Better CPU FA implementation for TG when GQA
This commit is contained in:
@@ -21781,19 +21781,27 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
|
||||
const struct ggml_tensor * q = node->src[0];
|
||||
const struct ggml_tensor * k = node->src[1];
|
||||
if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) {
|
||||
int nstep_k = k->ne[1]/32;
|
||||
int gcd_k = simple_gcd(nstep_k, n_tasks);
|
||||
if (gcd_k > 1) {
|
||||
int nth_k = n_tasks/gcd_k;
|
||||
int rk2 = q->ne[2]/k->ne[2];
|
||||
int nq_per_thread = (rk2 + nth_k - 1)/nth_k;
|
||||
size_t size = (Dv + 16)*nq_per_thread*sizeof(float)*n_tasks;
|
||||
if (ggml_is_quantized(k->type)) {
|
||||
enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type;
|
||||
size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
|
||||
size += q->ne[2]*row_size;
|
||||
}
|
||||
if (k->ne[2] > 1) {
|
||||
int nk = 32 * (k->ne[2]*k->ne[1]/(32*n_tasks));
|
||||
int nstep_k = k->ne[2]*k->ne[1]/nk;
|
||||
size_t result_size = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float);
|
||||
size_t size = nstep_k*result_size;
|
||||
cur = MAX(cur, size);
|
||||
} else {
|
||||
int nstep_k = k->ne[1]/32;
|
||||
int gcd_k = simple_gcd(nstep_k, n_tasks);
|
||||
if (gcd_k > 1) {
|
||||
int nth_k = n_tasks/gcd_k;
|
||||
int rk2 = q->ne[2]/k->ne[2];
|
||||
int nq_per_thread = (rk2 + nth_k - 1)/nth_k;
|
||||
size_t size = (Dv + 16)*nq_per_thread*sizeof(float)*n_tasks;
|
||||
if (ggml_is_quantized(k->type)) {
|
||||
enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type;
|
||||
size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
|
||||
size += q->ne[2]*row_size;
|
||||
}
|
||||
cur = MAX(cur, size);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -153,6 +153,65 @@ bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
|
||||
}
|
||||
}
|
||||
|
||||
if (neq3 == 1 && rk2 > 1 && rk2 == rv2 && neq1 == 1 && nth >= 1 && nek2*nek1 >= 32*nth) {
|
||||
int nk = 32 * (nek2*nek1/(32*nth));
|
||||
int nkk = nek1/nk;
|
||||
int nstep_k = nek2*nkk;
|
||||
auto result_size = (Dv + 16)*rk2*sizeof(float);
|
||||
//if (ith == 0) printf("rk2 = %d, nek1 = %d, nek2 = %d, nk = %d, nkk = %d, nstep_k = %d\n", (int)rk2, (int)nek1, (int)nek2, nk, nkk, nstep_k);
|
||||
for (int istep_k = ith; istep_k < nstep_k; istep_k += nth) {
|
||||
int ik02 = istep_k/nkk;
|
||||
int ik01 = nk*(istep_k - ik02*nkk);
|
||||
auto this_result = (float *)((char *)work_buffer + istep_k*result_size);
|
||||
auto this_q = (const float *)((const char *)q + ik02*rk2*nbq2);
|
||||
auto this_k = (const char *)k + ik01*stride_k + ik02*nbk2;
|
||||
auto this_v = (const char *)v + ik01*stride_v + ik02*nbv2;
|
||||
auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here
|
||||
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
|
||||
Dk, Dv, rk2, nk, nbq2, stride_k, stride_v, 0, Dv,
|
||||
this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m,
|
||||
scale, softcap, this_result, this_result + (Dv+0)*rk2, this_result + (Dv+1)*rk2)) return false;
|
||||
}
|
||||
|
||||
barrier(barrier_data);
|
||||
|
||||
// We have nkk results for each head
|
||||
for (int iq2 = ith; iq2 < neq2; iq2 += nth) {
|
||||
// ik02*rk2 + il = iq2 (il = 0...rk2-1) => ik02 = iq2/rk2, il = iq2%rk2;
|
||||
int ik02 = iq2/rk2;
|
||||
int il = iq2 - ik02*rk2;
|
||||
auto Racc = qkv + iq2*nb1/sizeof(float);
|
||||
std::memset(Racc, 0, Dv*sizeof(float));
|
||||
float M = -INFINITY, S = 0;
|
||||
for (int ikk = 0; ikk < nkk; ++ikk) {
|
||||
int istep_k = ik02*nkk + ikk;
|
||||
auto this_result = (float *)((char *)work_buffer + istep_k*result_size);
|
||||
const float * R = this_result + il*Dv;
|
||||
const float * Mj = this_result + Dv*rk2;
|
||||
const float * Sj = Mj + rk2;
|
||||
if (Mj[il] == -INFINITY) continue;
|
||||
if (Mj[il] > M) {
|
||||
if (M == -INFINITY) {
|
||||
std::memcpy(Racc, R, Dv*sizeof(float));
|
||||
S = Sj[il];
|
||||
} else {
|
||||
float c = exp(M - Mj[il]);
|
||||
S = c*S + Sj[il];
|
||||
for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i];
|
||||
}
|
||||
M = Mj[il];
|
||||
} else {
|
||||
float c = exp(Mj[il] - M);
|
||||
S += c*Sj[il];
|
||||
for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i];
|
||||
}
|
||||
}
|
||||
float norm = S > 0 ? 1/S : 1;
|
||||
for (int i = 0; i < Dv; ++i) Racc[i] *= norm;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// I keep changing my mind what is the best strategy to split the threads when processing
|
||||
// multiple heads. This is my current thinking, the commented out code below was the previous.
|
||||
int ntg = nth/simple_gcd(neq2*neq3, nth);
|
||||
|
||||
Reference in New Issue
Block a user