From ffeb8b40eb4bb1eb1259ba6842cedd36f8f12665 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 23 Aug 2024 15:47:08 +0300 Subject: [PATCH] WIP: plugging into ggml_compute_forward_flash_attn_ext_f16 This is now working. It is not faster, but at least it is not massively slower as the original. --- ggml/src/ggml.c | 170 +++++++++++++++++++++-------------- ggml/src/iqk/iqk_mul_mat.cpp | 29 ++++++ ggml/src/iqk/iqk_mul_mat.h | 10 +++ 3 files changed, 144 insertions(+), 65 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 98fb4fa1..e4dfeb9c 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -16035,6 +16035,18 @@ static void ggml_compute_forward_flash_attn_ext_f16( ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot; ggml_to_float_t const v_to_float = type_traits[v->type].to_float; + size_t nek64 = CACHE_LINE_SIZE_F32*((nek1 + CACHE_LINE_SIZE_F32 - 1)/CACHE_LINE_SIZE_F32); + size_t needed_wsize = (nek64 + 3*D)*sizeof(float)*nth; + if (needed_wsize > params->wsize) { + printf("Work size is not big enough. 
Need %d bytes, have %d\n", (int)needed_wsize, (int)params->wsize); + GGML_ASSERT(false); + } + + //if (ith == 0) { + // printf("=== %s: k->type = %s, q->type = %s, v->type = %s\n", __func__, ggml_type_name(k->type), ggml_type_name(q->type), ggml_type_name(v->type)); + // printf(" D = %d, nr = %d, dr = %d, nek1 = %d wsize = %d\n", (int)D, nr, dr, (int)nek1, (int)params->wsize); + //} + // loop over n_batch and n_head for (int ir = ir0; ir < ir1; ++ir) { // q indices @@ -16048,7 +16060,11 @@ static void ggml_compute_forward_flash_attn_ext_f16( float S = 0.0f; // sum float M = -INFINITY; // maximum KQ value - float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator + int64_t nek64 = CACHE_LINE_SIZE_F32*((nek1 + CACHE_LINE_SIZE_F32 - 1)/CACHE_LINE_SIZE_F32); + float * aux_kq = (float *)params->wdata + ith*(nek64 + 3*D); + float * VKQ32 = aux_kq + nek1; + + //float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16 @@ -16069,78 +16085,96 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int iv3 = iq3 / rv3; const int iv2 = iq2 / rv2; - const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); - q_to_vec_dot(pq, Q_q, D); + iqk_flash_helper(D, nek1, nbk1, + (const float *)((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), + (const void *)((char *) k->data + ik2*nbk2 + ik3*nbk3), + (const void *)mp, + scale, slope, + aux_kq); - // online softmax / attention - // loop over n_kv and n_head_kv - // ref: https://arxiv.org/pdf/2112.05682.pdf for (int64_t ic = 0; ic < nek1; ++ic) { - const float mv = mp ? 
slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f; - if (mv == -INFINITY) { - continue; - } - - float s; // KQ value - - const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3); - kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1); - - s = s*scale + mv; // scale KQ value and apply mask - - const float Mold = M; - - float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value - float vs = 1.0f; // post-softmax KQ value, expf(s - M) - const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)); - - if (v->type== GGML_TYPE_F16) { - if (s > M) { - // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f - M = s; - ms = expf(Mold - M); - - // V = V*expf(Mold - M) - ggml_vec_scale_f16(D, VKQ16, ms); - } else { - // no new maximum, ms == 1.0f, vs != 1.0f - vs = expf(s - M); - } - - // V += v*expf(s - M) - ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs); + if (v->type == GGML_TYPE_F16) { + ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, aux_kq[ic]); } else { - if (s > M) { - // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f - M = s; - ms = expf(Mold - M); - - // V = V*expf(Mold - M) - ggml_vec_scale_f32(D, VKQ32, ms); - } else { - // no new maximum, ms == 1.0f, vs != 1.0f - vs = expf(s - M); - } - v_to_float(v_data, V32, D); - - // V += v*expf(s - M) - ggml_vec_mad_f32(D, VKQ32, V32, vs); + ggml_vec_mad_f32(D, VKQ32, V32, aux_kq[ic]); } - - S = S*ms + vs; // scale and increment sum with partial sum } + ////const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); + ////q_to_vec_dot(pq, Q_q, D); + + //// online softmax / attention + //// loop over n_kv and n_head_kv + //// ref: https://arxiv.org/pdf/2112.05682.pdf + //for (int64_t ic = 0; ic < nek1; ++ic) { + // const float mv = mp ? 
slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f; + // if (mv == -INFINITY) { + // continue; + // } + + // float s = aux_kq[ic]; + // //float s; // KQ value + + // //const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3); + // //kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1); + + // s = s*scale + mv; // scale KQ value and apply mask + + // const float Mold = M; + + // float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value + // float vs = 1.0f; // post-softmax KQ value, expf(s - M) + + // const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)); + + // if (v->type== GGML_TYPE_F16) { + // if (s > M) { + // // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f + // M = s; + // ms = expf(Mold - M); + + // // V = V*expf(Mold - M) + // ggml_vec_scale_f16(D, VKQ16, ms); + // } else { + // // no new maximum, ms == 1.0f, vs != 1.0f + // vs = expf(s - M); + // } + + // // V += v*expf(s - M) + // ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs); + // } else { + // if (s > M) { + // // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f + // M = s; + // ms = expf(Mold - M); + + // // V = V*expf(Mold - M) + // ggml_vec_scale_f32(D, VKQ32, ms); + // } else { + // // no new maximum, ms == 1.0f, vs != 1.0f + // vs = expf(s - M); + // } + + // v_to_float(v_data, V32, D); + + // // V += v*expf(s - M) + // ggml_vec_mad_f32(D, VKQ32, V32, vs); + // } + + // S = S*ms + vs; // scale and increment sum with partial sum + //} + if (v->type == GGML_TYPE_F16) { for (int64_t d = 0; d < D; ++d) { VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]); } } - // V /= S - const float S_inv = 1.0f/S; - ggml_vec_scale_f32(D, VKQ32, S_inv); + //// V /= S + //const float S_inv = 1.0f/S; + //ggml_vec_scale_f32(D, VKQ32, S_inv); // dst indices const int i1 = iq1; @@ -17561,11 +17595,11 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_MUL_MAT: { - if (next && next->op == 
GGML_OP_SOFT_MAX) { - result = ggml_fused_mul_mat_softmax(params, tensor, next); - } else { + //if (next && next->op == GGML_OP_SOFT_MAX) { + // result = ggml_fused_mul_mat_softmax(params, tensor, next); + //} else { ggml_compute_forward_mul_mat(params, tensor); - } + //} } break; case GGML_OP_MUL_MAT_ID: { @@ -19566,8 +19600,14 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa case GGML_OP_FLASH_ATTN_EXT: { const int64_t ne00 = node->src[0]->ne[0]; // D - - cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread + int64_t ne01 = node->src[1]->ne[1]; + ne01 = CACHE_LINE_SIZE_F32*((ne01 + CACHE_LINE_SIZE_F32 - 1)/CACHE_LINE_SIZE_F32); + cur = (3*ne00 + ne01)*sizeof(float)*n_tasks; + //printf("flash: ne00 = %d, ne01 = %d (%d)\n", (int)ne00, (int)ne01, (int)node->src[0]->ne[1]); + //cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread + //cur += (node->src[0]->ne[1] + CACHE_LINE_SIZE_F32)*sizeof(float)*n_tasks; + //printf("flash: need %g bytes (up from %g) ne00 = %d, ne01 = %d ne02 = %d, ne03 = %d\n", 1.*cur, 3.*sizeof(float)*ne00*n_tasks, (int)ne00, (int)node->src[0]->ne[1], (int)node->src[0]->ne[2], (int)node->src[0]->ne[3]); + ////cur = MAX(cur, (node->src[0]->ne[1] + CACHE_LINE_SIZE_F32)*sizeof(float)*n_tasks); } break; case GGML_OP_FLASH_ATTN_BACK: { diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 51156a74..d839fd44 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -6215,6 +6215,35 @@ bool iqk_fused_mul_mat_softmax(long Nx, long Ny, long ne00, return true; } +//struct DataInfo { +// float * s; +// const char * cy; +// size_t bs; +// size_t by; +// int cur_y = 0; +// int ne11; +// const mmid_row_mapping * row_mapping = nullptr; +// size_t bs2 = 0; + + +void iqk_flash_helper(int nq, // number of elements in q + int nk, // number of rows in k + int stride_k, // distance between rows in k (in bytes) + const float * q, // q vector + const void * k, // k matrix. 
Assumed to be fp16, nq x nk elements + const void * mask, // mask. If not null, assumed to be fp16. nk elements + float scale, + float slope, + float * qk) { + GGML_ASSERT(nq % 4 == 0); + //GGML_ASSERT(nq / 16 <= 16); + + DataInfo info{qk, (const char*)q, 0, size_t(stride_k), 0, 1, nullptr, 0}; + + mul_mat_fX_fY_T<1, ggml_half, float>(nq, k, stride_k, info, nk); + softmax_extended(nk, qk, qk, scale, slope, (const char *)mask, true); +} + #else // IQK_IMPLEMENT bool iqk_mul_mat(int, long, long, long, int, const void *, long, int, const void *, long, float *, long, int, int) { diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h index 956ee8f7..71b40e41 100644 --- a/ggml/src/iqk/iqk_mul_mat.h +++ b/ggml/src/iqk/iqk_mul_mat.h @@ -29,6 +29,16 @@ bool iqk_fused_mul_mat_softmax(long Nx, long Ny, long ne00, const char * mask, float scale, float slope, int ith, int nth); +void iqk_flash_helper(int nq, // number of elements in q + int nk, // number of rows in k + int stride_k, // distance between rows in k (in bytes) + const float * q, // q vector + const void * k, // k matrix. Assumed to be fp16, nq x nk elements + const void * mask, // mask. If not null, assumed to be fp16. nk elements + float scale, + float slope, + float * qk); // softmax(k*q) - nk elements + #ifdef __cplusplus } #endif