WIP: plugging into ggml_compute_forward_flash_attn_ext_f16

This is now working. It is not faster, but at least it is not
massively slower as the original.
This commit is contained in:
Iwan Kawrakow
2024-08-23 15:47:08 +03:00
parent b127c6cced
commit ffeb8b40eb
3 changed files with 144 additions and 65 deletions

View File

@@ -16035,6 +16035,18 @@ static void ggml_compute_forward_flash_attn_ext_f16(
ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot;
ggml_to_float_t const v_to_float = type_traits[v->type].to_float;
size_t nek64 = CACHE_LINE_SIZE_F32*((nek1 + CACHE_LINE_SIZE_F32 - 1)/CACHE_LINE_SIZE_F32);
size_t needed_wsize = (nek64 + 3*D)*sizeof(float)*nth;
if (needed_wsize > params->wsize) {
printf("Work size is not big enough. Need %d bytes, have %d\n", (int)needed_wsize, (int)params->wsize);
GGML_ASSERT(false);
}
//if (ith == 0) {
// printf("=== %s: k->type = %s, q->type = %s, v->type = %s\n", __func__, ggml_type_name(k->type), ggml_type_name(q->type), ggml_type_name(v->type));
// printf(" D = %d, nr = %d, dr = %d, nek1 = %d wsize = %d\n", (int)D, nr, dr, (int)nek1, (int)params->wsize);
//}
// loop over n_batch and n_head
for (int ir = ir0; ir < ir1; ++ir) {
// q indices
@@ -16048,7 +16060,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
float S = 0.0f; // sum
float M = -INFINITY; // maximum KQ value
float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
int64_t nek64 = CACHE_LINE_SIZE_F32*((nek1 + CACHE_LINE_SIZE_F32 - 1)/CACHE_LINE_SIZE_F32);
float * aux_kq = (float *)params->wdata + ith*(nek64 + 3*D);
float * VKQ32 = aux_kq + nek1;
//float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer
ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
@@ -16069,78 +16085,96 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const int iv3 = iq3 / rv3;
const int iv2 = iq2 / rv2;
const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
q_to_vec_dot(pq, Q_q, D);
iqk_flash_helper(D, nek1, nbk1,
(const float *)((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)),
(const void *)((char *) k->data + ik2*nbk2 + ik3*nbk3),
(const void *)mp,
scale, slope,
aux_kq);
// online softmax / attention
// loop over n_kv and n_head_kv
// ref: https://arxiv.org/pdf/2112.05682.pdf
for (int64_t ic = 0; ic < nek1; ++ic) {
const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
if (mv == -INFINITY) {
continue;
}
float s; // KQ value
const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
s = s*scale + mv; // scale KQ value and apply mask
const float Mold = M;
float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
float vs = 1.0f; // post-softmax KQ value, expf(s - M)
const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
if (v->type== GGML_TYPE_F16) {
if (s > M) {
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
M = s;
ms = expf(Mold - M);
// V = V*expf(Mold - M)
ggml_vec_scale_f16(D, VKQ16, ms);
} else {
// no new maximum, ms == 1.0f, vs != 1.0f
vs = expf(s - M);
}
// V += v*expf(s - M)
ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
if (v->type == GGML_TYPE_F16) {
ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, aux_kq[ic]);
} else {
if (s > M) {
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
M = s;
ms = expf(Mold - M);
// V = V*expf(Mold - M)
ggml_vec_scale_f32(D, VKQ32, ms);
} else {
// no new maximum, ms == 1.0f, vs != 1.0f
vs = expf(s - M);
}
v_to_float(v_data, V32, D);
// V += v*expf(s - M)
ggml_vec_mad_f32(D, VKQ32, V32, vs);
ggml_vec_mad_f32(D, VKQ32, V32, aux_kq[ic]);
}
S = S*ms + vs; // scale and increment sum with partial sum
}
////const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
////q_to_vec_dot(pq, Q_q, D);
//// online softmax / attention
//// loop over n_kv and n_head_kv
//// ref: https://arxiv.org/pdf/2112.05682.pdf
//for (int64_t ic = 0; ic < nek1; ++ic) {
// const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
// if (mv == -INFINITY) {
// continue;
// }
// float s = aux_kq[ic];
// //float s; // KQ value
// //const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
// //kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
// s = s*scale + mv; // scale KQ value and apply mask
// const float Mold = M;
// float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
// float vs = 1.0f; // post-softmax KQ value, expf(s - M)
// const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
// if (v->type== GGML_TYPE_F16) {
// if (s > M) {
// // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
// M = s;
// ms = expf(Mold - M);
// // V = V*expf(Mold - M)
// ggml_vec_scale_f16(D, VKQ16, ms);
// } else {
// // no new maximum, ms == 1.0f, vs != 1.0f
// vs = expf(s - M);
// }
// // V += v*expf(s - M)
// ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
// } else {
// if (s > M) {
// // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
// M = s;
// ms = expf(Mold - M);
// // V = V*expf(Mold - M)
// ggml_vec_scale_f32(D, VKQ32, ms);
// } else {
// // no new maximum, ms == 1.0f, vs != 1.0f
// vs = expf(s - M);
// }
// v_to_float(v_data, V32, D);
// // V += v*expf(s - M)
// ggml_vec_mad_f32(D, VKQ32, V32, vs);
// }
// S = S*ms + vs; // scale and increment sum with partial sum
//}
if (v->type == GGML_TYPE_F16) {
for (int64_t d = 0; d < D; ++d) {
VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
}
}
// V /= S
const float S_inv = 1.0f/S;
ggml_vec_scale_f32(D, VKQ32, S_inv);
//// V /= S
//const float S_inv = 1.0f/S;
//ggml_vec_scale_f32(D, VKQ32, S_inv);
// dst indices
const int i1 = iq1;
@@ -17561,11 +17595,11 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break;
case GGML_OP_MUL_MAT:
{
if (next && next->op == GGML_OP_SOFT_MAX) {
result = ggml_fused_mul_mat_softmax(params, tensor, next);
} else {
//if (next && next->op == GGML_OP_SOFT_MAX) {
// result = ggml_fused_mul_mat_softmax(params, tensor, next);
//} else {
ggml_compute_forward_mul_mat(params, tensor);
}
//}
} break;
case GGML_OP_MUL_MAT_ID:
{
@@ -19566,8 +19600,14 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
case GGML_OP_FLASH_ATTN_EXT:
{
const int64_t ne00 = node->src[0]->ne[0]; // D
cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
int64_t ne01 = node->src[1]->ne[1];
ne01 = CACHE_LINE_SIZE_F32*((ne01 + CACHE_LINE_SIZE_F32 - 1)/CACHE_LINE_SIZE_F32);
cur = (3*ne00 + ne01)*sizeof(float)*n_tasks;
//printf("flash: ne00 = %d, ne01 = %d (%d)\n", (int)ne00, (int)ne01, (int)node->src[0]->ne[1]);
//cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
//cur += (node->src[0]->ne[1] + CACHE_LINE_SIZE_F32)*sizeof(float)*n_tasks;
//printf("flash: need %g bytes (up from %g) ne00 = %d, ne01 = %d ne02 = %d, ne03 = %d\n", 1.*cur, 3.*sizeof(float)*ne00*n_tasks, (int)ne00, (int)node->src[0]->ne[1], (int)node->src[0]->ne[2], (int)node->src[0]->ne[3]);
////cur = MAX(cur, (node->src[0]->ne[1] + CACHE_LINE_SIZE_F32)*sizeof(float)*n_tasks);
} break;
case GGML_OP_FLASH_ATTN_BACK:
{

View File

@@ -6215,6 +6215,35 @@ bool iqk_fused_mul_mat_softmax(long Nx, long Ny, long ne00,
return true;
}
//struct DataInfo {
// float * s;
// const char * cy;
// size_t bs;
// size_t by;
// int cur_y = 0;
// int ne11;
// const mmid_row_mapping * row_mapping = nullptr;
// size_t bs2 = 0;
void iqk_flash_helper(int nq, // number of elements in q
int nk, // number of rows in k
int stride_k, // distance between rows in k (in bytes)
const float * q, // q vector
const void * k, // k matrix. Assumed to be fp16, nq x nk elements
const void * mask, // mask. If not null, assumed to be fp16. nk elements
float scale,
float slope,
float * qk) {
GGML_ASSERT(nq % 4 == 0);
//GGML_ASSERT(nq / 16 <= 16);
DataInfo info{qk, (const char*)q, 0, size_t(stride_k), 0, 1, nullptr, 0};
mul_mat_fX_fY_T<1, ggml_half, float>(nq, k, stride_k, info, nk);
softmax_extended(nk, qk, qk, scale, slope, (const char *)mask, true);
}
#else // IQK_IMPLEMENT
bool iqk_mul_mat(int, long, long, long, int, const void *, long, int, const void *, long, float *, long, int, int) {

View File

@@ -29,6 +29,16 @@ bool iqk_fused_mul_mat_softmax(long Nx, long Ny, long ne00,
const char * mask, float scale, float slope,
int ith, int nth);
void iqk_flash_helper(int nq, // number of elements in q
int nk, // number of rows in k
int stride_k, // distance between rows in k (in bytes)
const float * q, // q vector
const void * k, // k matrix. Assumed to be fp16, nq x nk elements
const void * mask, // mask. If not null, assumed to be fp16. nk elements
float scale,
float slope,
float * qk); // softmax(k*q) - k elements
#ifdef __cplusplus
}
#endif