Mirror of https://github.com/ikawrakow/ik_llama.cpp.git, synced 2026-02-25 23:54:10 +00:00
WIP: plugging into ggml_compute_forward_flash_attn_ext_f16
This is now working. It is not faster, but at least it is not massively slower than the original.
This commit is contained in:
ggml/src/ggml.c | 170
@@ -16035,6 +16035,18 @@ static void ggml_compute_forward_flash_attn_ext_f16(
ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot;
ggml_to_float_t const v_to_float = type_traits[v->type].to_float;

size_t nek64 = CACHE_LINE_SIZE_F32*((nek1 + CACHE_LINE_SIZE_F32 - 1)/CACHE_LINE_SIZE_F32);
size_t needed_wsize = (nek64 + 3*D)*sizeof(float)*nth;
if (needed_wsize > params->wsize) {
printf("Work size is not big enough. Need %d bytes, have %d\n", (int)needed_wsize, (int)params->wsize);
GGML_ASSERT(false);
}

//if (ith == 0) {
// printf("=== %s: k->type = %s, q->type = %s, v->type = %s\n", __func__, ggml_type_name(k->type), ggml_type_name(q->type), ggml_type_name(v->type));
// printf(" D = %d, nr = %d, dr = %d, nek1 = %d wsize = %d\n", (int)D, nr, dr, (int)nek1, (int)params->wsize);
//}

// loop over n_batch and n_head
for (int ir = ir0; ir < ir1; ++ir) {
// q indices
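The check above guards the enlarged per-thread scratch area: besides the 3*D floats the kernel already used, each thread now needs room for a full row of K*Q weights, padded up to whole cache lines (nek64). A back-of-the-envelope check of that sizing, with hypothetical values not taken from the commit (D = 128, nek1 = 2050 cached K rows, nth = 8 threads, CACHE_LINE_SIZE_F32 = 16):

// hypothetical example of the work-size arithmetic above (values assumed)
size_t nek64        = 16 * ((2050 + 16 - 1) / 16);            // = 2064 floats for the padded K*Q row
size_t needed_wsize = (nek64 + 3*128) * sizeof(float) * 8;    // = (2064 + 384) * 4 * 8 = 78336 bytes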
@@ -16048,7 +16060,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
float S = 0.0f; // sum
float M = -INFINITY; // maximum KQ value

float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
int64_t nek64 = CACHE_LINE_SIZE_F32*((nek1 + CACHE_LINE_SIZE_F32 - 1)/CACHE_LINE_SIZE_F32);
float * aux_kq = (float *)params->wdata + ith*(nek64 + 3*D);
float * VKQ32 = aux_kq + nek1;

//float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer
ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
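With aux_kq placed first, the per-thread slice of params->wdata is laid out as follows. This is a schematic reading of the pointer arithmetic in the hunk above, not text from the commit:

// per-thread slice of params->wdata, base = (float *) params->wdata + ith*(nek64 + 3*D)
//   aux_kq = base              : nek1 softmax(K*Q) weights (buffer padded to nek64 floats)
//   VKQ32  = aux_kq + nek1     : D floats, FP32 V*KQ accumulator
//   V32 / VKQ16 = VKQ32 + D    : D elements, temporary V row (FP32 view or aliased FP16 view)
//   Q_q    = VKQ32 + 2*D       : q converted to the K vec_dot type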
@@ -16069,78 +16085,96 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const int iv3 = iq3 / rv3;
const int iv2 = iq2 / rv2;

const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
q_to_vec_dot(pq, Q_q, D);
iqk_flash_helper(D, nek1, nbk1,
(const float *)((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)),
(const void *)((char *) k->data + ik2*nbk2 + ik3*nbk3),
(const void *)mp,
scale, slope,
aux_kq);

// online softmax / attention
// loop over n_kv and n_head_kv
// ref: https://arxiv.org/pdf/2112.05682.pdf
for (int64_t ic = 0; ic < nek1; ++ic) {
const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
if (mv == -INFINITY) {
continue;
}

float s; // KQ value

const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);

s = s*scale + mv; // scale KQ value and apply mask

const float Mold = M;

float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
float vs = 1.0f; // post-softmax KQ value, expf(s - M)

const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));

if (v->type== GGML_TYPE_F16) {
if (s > M) {
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
M = s;
ms = expf(Mold - M);

// V = V*expf(Mold - M)
ggml_vec_scale_f16(D, VKQ16, ms);
} else {
// no new maximum, ms == 1.0f, vs != 1.0f
vs = expf(s - M);
}

// V += v*expf(s - M)
ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
if (v->type == GGML_TYPE_F16) {
ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, aux_kq[ic]);
} else {
if (s > M) {
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
M = s;
ms = expf(Mold - M);

// V = V*expf(Mold - M)
ggml_vec_scale_f32(D, VKQ32, ms);
} else {
// no new maximum, ms == 1.0f, vs != 1.0f
vs = expf(s - M);
}

v_to_float(v_data, V32, D);

// V += v*expf(s - M)
ggml_vec_mad_f32(D, VKQ32, V32, vs);
ggml_vec_mad_f32(D, VKQ32, V32, aux_kq[ic]);
}

S = S*ms + vs; // scale and increment sum with partial sum
}

////const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
////q_to_vec_dot(pq, Q_q, D);

//// online softmax / attention
//// loop over n_kv and n_head_kv
//// ref: https://arxiv.org/pdf/2112.05682.pdf
//for (int64_t ic = 0; ic < nek1; ++ic) {
// const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
// if (mv == -INFINITY) {
// continue;
// }

// float s = aux_kq[ic];
// //float s; // KQ value

// //const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
// //kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);

// s = s*scale + mv; // scale KQ value and apply mask

// const float Mold = M;

// float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
// float vs = 1.0f; // post-softmax KQ value, expf(s - M)

// const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));

// if (v->type== GGML_TYPE_F16) {
// if (s > M) {
// // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
// M = s;
// ms = expf(Mold - M);

// // V = V*expf(Mold - M)
// ggml_vec_scale_f16(D, VKQ16, ms);
// } else {
// // no new maximum, ms == 1.0f, vs != 1.0f
// vs = expf(s - M);
// }

// // V += v*expf(s - M)
// ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
// } else {
// if (s > M) {
// // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
// M = s;
// ms = expf(Mold - M);

// // V = V*expf(Mold - M)
// ggml_vec_scale_f32(D, VKQ32, ms);
// } else {
// // no new maximum, ms == 1.0f, vs != 1.0f
// vs = expf(s - M);
// }

// v_to_float(v_data, V32, D);

// // V += v*expf(s - M)
// ggml_vec_mad_f32(D, VKQ32, V32, vs);
// }

// S = S*ms + vs; // scale and increment sum with partial sum
//}

if (v->type == GGML_TYPE_F16) {
for (int64_t d = 0; d < D; ++d) {
VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
}
}

// V /= S
const float S_inv = 1.0f/S;
ggml_vec_scale_f32(D, VKQ32, S_inv);
//// V /= S
//const float S_inv = 1.0f/S;
//ggml_vec_scale_f32(D, VKQ32, S_inv);

// dst indices
const int i1 = iq1;
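The net effect of the hunk above: the per-element online-softmax bookkeeping (M, Mold, ms, vs, S) is bypassed in the new path, because iqk_flash_helper has already written the scaled, masked softmax weights for the whole row into aux_kq; the V accumulation then reduces to one mad per cache entry, and the final V /= S division appears to be commented out, consistent with the helper producing already-normalized weights. Stripped of the commented-out old code, the new inner flow reads roughly as follows (a sketch reconstructed from the added lines, not the verbatim final code):

iqk_flash_helper(D, nek1, nbk1,
        (const float *)((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)),
        (const void *)((char *) k->data + ik2*nbk2 + ik3*nbk3),
        (const void *)mp, scale, slope,
        aux_kq);                                   // aux_kq[ic] = softmax(scale*K*q + slope*mask)[ic]

for (int64_t ic = 0; ic < nek1; ++ic) {
    const char * v_data = (const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3);
    if (v->type == GGML_TYPE_F16) {
        // accumulate aux_kq[ic] * v directly into the FP16 accumulator
        ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, aux_kq[ic]);
    } else {
        // dequantize the V row, then accumulate into the FP32 accumulator
        v_to_float(v_data, V32, D);
        ggml_vec_mad_f32(D, VKQ32, V32, aux_kq[ic]);
    }
}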
@@ -17561,11 +17595,11 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break;
case GGML_OP_MUL_MAT:
{
if (next && next->op == GGML_OP_SOFT_MAX) {
result = ggml_fused_mul_mat_softmax(params, tensor, next);
} else {
//if (next && next->op == GGML_OP_SOFT_MAX) {
// result = ggml_fused_mul_mat_softmax(params, tensor, next);
//} else {
ggml_compute_forward_mul_mat(params, tensor);
}
//}
} break;
case GGML_OP_MUL_MAT_ID:
{
@@ -19566,8 +19600,14 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
case GGML_OP_FLASH_ATTN_EXT:
{
const int64_t ne00 = node->src[0]->ne[0]; // D

cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
int64_t ne01 = node->src[1]->ne[1];
ne01 = CACHE_LINE_SIZE_F32*((ne01 + CACHE_LINE_SIZE_F32 - 1)/CACHE_LINE_SIZE_F32);
cur = (3*ne00 + ne01)*sizeof(float)*n_tasks;
//printf("flash: ne00 = %d, ne01 = %d (%d)\n", (int)ne00, (int)ne01, (int)node->src[0]->ne[1]);
//cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
//cur += (node->src[0]->ne[1] + CACHE_LINE_SIZE_F32)*sizeof(float)*n_tasks;
//printf("flash: need %g bytes (up from %g) ne00 = %d, ne01 = %d ne02 = %d, ne03 = %d\n", 1.*cur, 3.*sizeof(float)*ne00*n_tasks, (int)ne00, (int)node->src[0]->ne[1], (int)node->src[0]->ne[2], (int)node->src[0]->ne[3]);
////cur = MAX(cur, (node->src[0]->ne[1] + CACHE_LINE_SIZE_F32)*sizeof(float)*n_tasks);
} break;
case GGML_OP_FLASH_ATTN_BACK:
{

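The planner change mirrors the runtime check in ggml_compute_forward_flash_attn_ext_f16: the scratch estimate for GGML_OP_FLASH_ATTN_EXT grows from 3*ne00 floats per thread to 3*ne00 plus one cache-line-padded row of K*Q weights per thread. With the same hypothetical numbers as in the earlier example (ne00 = D = 128, ne01 = 2050 K rows, CACHE_LINE_SIZE_F32 = 16, 8 tasks):

// hypothetical sizing example (values assumed, not from the commit)
int64_t ne01 = 16 * ((2050 + 16 - 1) / 16);            // = 2064, padded row of K*Q weights
size_t  cur  = (3*128 + ne01) * sizeof(float) * 8;     // = (384 + 2064) * 4 * 8 = 78336 bytes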
@@ -6215,6 +6215,35 @@ bool iqk_fused_mul_mat_softmax(long Nx, long Ny, long ne00,
return true;
}

//struct DataInfo {
// float * s;
// const char * cy;
// size_t bs;
// size_t by;
// int cur_y = 0;
// int ne11;
// const mmid_row_mapping * row_mapping = nullptr;
// size_t bs2 = 0;


void iqk_flash_helper(int nq, // number of elements in q
int nk, // number of rows in k
int stride_k, // distance between rows in k (in bytes)
const float * q, // q vector
const void * k, // k matrix. Assumed to be fp16, nq x nk elements
const void * mask, // mask. If not null, assumed to be fp16. nk elements
float scale,
float slope,
float * qk) {
GGML_ASSERT(nq % 4 == 0);
//GGML_ASSERT(nq / 16 <= 16);

DataInfo info{qk, (const char*)q, 0, size_t(stride_k), 0, 1, nullptr, 0};

mul_mat_fX_fY_T<1, ggml_half, float>(nq, k, stride_k, info, nk);
softmax_extended(nk, qk, qk, scale, slope, (const char *)mask, true);
}

#else // IQK_IMPLEMENT

bool iqk_mul_mat(int, long, long, long, int, const void *, long, int, const void *, long, float *, long, int, int) {
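iqk_flash_helper fuses the two steps the scalar loop used to do per cache entry: a 1 x nk product of the FP16 K rows against the FP32 q row (via the mul_mat_fX_fY_T template instantiated above), and then what appears to be a scaled, masked softmax over the nk results via softmax_extended, written into qk. From the caller's side in ggml.c it is used roughly like this; the names q_row, k_block and mask_row are illustrative stand-ins for the pointer expressions in the ggml.c hunk, not identifiers from the commit:

// caller-side sketch (argument names are hypothetical)
iqk_flash_helper(/*nq=*/D, /*nk=*/nek1, /*stride_k=*/nbk1,
                 q_row,          // FP32 q vector, D elements
                 k_block,        // FP16 K: nek1 rows of D elements, nbk1 bytes apart
                 mask_row,       // FP16 mask row, or NULL
                 scale, slope,
                 aux_kq);        // out: nek1 softmax(scale*K*q + slope*mask) weights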
@@ -29,6 +29,16 @@ bool iqk_fused_mul_mat_softmax(long Nx, long Ny, long ne00,
const char * mask, float scale, float slope,
int ith, int nth);

void iqk_flash_helper(int nq, // number of elements in q
int nk, // number of rows in k
int stride_k, // distance between rows in k (in bytes)
const float * q, // q vector
const void * k, // k matrix. Assumed to be fp16, nq x nk elements
const void * mask, // mask. If not null, assumed to be fp16. nk elements
float scale,
float slope,
float * qk); // softmax(k*q) - k elements

#ifdef __cplusplus
}
#endif