diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index e4dfeb9c..ffd4be44 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -16069,11 +16069,12 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
         ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
 
-        if (v->type == GGML_TYPE_F16) {
-            memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
-        } else {
-            memset(VKQ32, 0, D*sizeof(float));
-        }
+        //memset(VKQ32, 0, D*sizeof(float));
+        //if (v->type == GGML_TYPE_F16) {
+        //    memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
+        //} else {
+        //    memset(VKQ32, 0, D*sizeof(float));
+        //}
 
         const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
 
@@ -16085,22 +16086,31 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iv3 = iq3 / rv3;
         const int iv2 = iq2 / rv2;
 
-        iqk_flash_helper(D, nek1, nbk1,
-                (const float *)((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)),
+        //iqk_flash_helper(D, nek1, nbk1,
+        //        (const float *)((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)),
+        //        (const void *)((char *) k->data + ik2*nbk2 + ik3*nbk3),
+        //        (const void *)mp,
+        //        scale, slope,
+        //        aux_kq);
+
+        //for (int64_t ic = 0; ic < nek1; ++ic) {
+        //    const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
+        //    if (v->type == GGML_TYPE_F16) {
+        //        ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, aux_kq[ic]);
+        //    } else {
+        //        v_to_float(v_data, V32, D);
+        //        ggml_vec_mad_f32(D, VKQ32, V32, aux_kq[ic]);
+        //    }
+        //}
+
+        iqk_flash_helper_2(D, nek1, nbk1, nbv1,
+                (const float *)((char *) q->data + iq1*nbq1 + iq2*nbq2 + iq3*nbq3),
                 (const void *)((char *) k->data + ik2*nbk2 + ik3*nbk3),
+                (const void *)((char *) v->data + iv2*nbv2 + iv3*nbv3),
                 (const void *)mp,
                 scale, slope,
-                aux_kq);
-
-        for (int64_t ic = 0; ic < nek1; ++ic) {
-            const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
-            if (v->type == GGML_TYPE_F16) {
-                ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, aux_kq[ic]);
-            } else {
-                v_to_float(v_data, V32, D);
-                ggml_vec_mad_f32(D, VKQ32, V32, aux_kq[ic]);
-            }
-        }
+                aux_kq, (float *)((char *) dst->data + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1));
+                //aux_kq, VKQ32);
 
         ////const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
         ////q_to_vec_dot(pq, Q_q, D);
@@ -16166,26 +16176,26 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             //    S = S*ms + vs; // scale and increment sum with partial sum
             //}
 
-        if (v->type == GGML_TYPE_F16) {
-            for (int64_t d = 0; d < D; ++d) {
-                VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
-            }
-        }
+        //if (v->type == GGML_TYPE_F16) {
+        //    for (int64_t d = 0; d < D; ++d) {
+        //        VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
+        //    }
+        //}
 
         //// V /= S
         //const float S_inv = 1.0f/S;
         //ggml_vec_scale_f32(D, VKQ32, S_inv);
 
-        // dst indices
-        const int i1 = iq1;
-        const int i2 = iq2;
-        const int i3 = iq3;
+        //// dst indices
+        //const int i1 = iq1;
+        //const int i2 = iq2;
+        //const int i3 = iq3;
 
-        // original
-        //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
+        //// original
+        ////memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
 
-        // permute(0, 2, 1, 3)
-        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
+        //// permute(0, 2, 1, 3)
+        //memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
     }
 }
 
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index d839fd44..ba1637c5 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -6244,6 +6244,65 @@ void iqk_flash_helper(int nq, // number of elements in q
     softmax_extended(nk, qk, qk, scale, slope, (const char *)mask, true);
 }
 
+// Fused flash attention helper: qk <- softmax(scale*(k*q) + slope*mask), then
+// qkv <- sum over rows ic of qk[ic] * v[ic] (fp16 V rows), written into qkv.
+void iqk_flash_helper_2(int nq,       // number of elements in q
+                        int nk,       // number of rows in k
+                        int stride_k, // distance between rows in k (in bytes)
+                        int stride_v, // distance between rows in v (in bytes)
+                        const float * q,    // q vector
+                        const void  * k,    // k matrix. Assumed to be fp16, nq x nk elements
+                        const void  * v,    // v matrix. Assumed to be fp16, nq x nk elements
+                        const void  * mask, // mask. If not null, assumed to be fp16. nk elements
+                        float scale,
+                        float slope,
+                        float * qk,    // work buffer: softmax(k*q) - nk elements
+                        float * qkv) { // result: weighted sum of v rows - nq elements
+    GGML_ASSERT(nq % 4 == 0);
+    //GGML_ASSERT(nq / 16 <= 16);
+
+    DataInfo info{qk, (const char*)q, 0, size_t(stride_k), 0, 1, nullptr, 0};
+
+    mul_mat_fX_fY_T<1, ggml_half, float>(nq, k, stride_k, info, nk);
+    softmax_extended(nk, qk, qk, scale, slope, (const char *)mask, true);
+
+    GGML_ASSERT(nq%16 == 0);
+    if (nq/16 <= 16) {
+        __m512 v_qkv[16];
+        auto v_qk = _mm512_set1_ps(qk[0]);
+        for (int j = 0; j < nq/16; ++j) {
+            auto v_v = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)v + j));
+            v_qkv[j] = _mm512_mul_ps(v_qk, v_v);
+        }
+        for (int ic = 1; ic < nk; ++ic) {
+            const ggml_half * vr = (const ggml_half *)((const char *)v + ic*stride_v);
+            v_qk = _mm512_set1_ps(qk[ic]);
+            for (int j = 0; j < nq/16; ++j) {
+                auto v_v = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr + j));
+                v_qkv[j] = _mm512_fmadd_ps(v_qk, v_v, v_qkv[j]);
+            }
+        }
+        for (int j = 0; j < nq/16; ++j) {
+            _mm512_storeu_ps(qkv + 16*j, v_qkv[j]);
+        }
+        return;
+    }
+
+    // Generic path (nq/16 > 16): accumulate directly in qkv. qkv is the caller's
+    // destination buffer and arrives uninitialized, so it must be zeroed first.
+    for (int j = 0; j < nq/16; ++j) _mm512_storeu_ps(qkv + 16*j, _mm512_setzero_ps());
+    for (int ic = 0; ic < nk; ++ic) {
+        auto v_qk = _mm512_set1_ps(qk[ic]);
+        const ggml_half * vr = (const ggml_half *)((const char *)v + ic*stride_v);
+        for (int j = 0; j < nq/16; ++j) {
+            auto v_v = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr + j));
+            auto v_qkv = _mm512_loadu_ps(qkv + 16*j);
+            v_qkv = _mm512_fmadd_ps(v_qk, v_v, v_qkv);
+            _mm512_storeu_ps(qkv + 16*j, v_qkv);
+        }
+    }
+}
+
 #else  // IQK_IMPLEMENT
 
 bool iqk_mul_mat(int, long, long, long, int, const void *, long, int, const void *, long, float *, long, int, int) {
diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h
index 71b40e41..646b3190 100644
--- a/ggml/src/iqk/iqk_mul_mat.h
+++ b/ggml/src/iqk/iqk_mul_mat.h
@@ -39,6 +39,19 @@ void iqk_flash_helper(int nq, // number of elements in q
                      float slope,
                      float * qk); // softmax(k*q) - k elements
 
+void iqk_flash_helper_2(int nq,       // number of elements in q
+                        int nk,       // number of rows in k
+                        int stride_k, // distance between rows in k (in bytes)
+                        int stride_v, // distance between rows in v (in bytes)
+                        const float * q,    // q vector
+                        const void  * k,    // k matrix. Assumed to be fp16, nq x nk elements
+                        const void  * v,    // v matrix. Assumed to be fp16, nq x nk elements
+                        const void  * mask, // mask. If not null, assumed to be fp16. nk elements
+                        float scale,
+                        float slope,
+                        float * qk,   // softmax(k*q) - nk elements
+                        float * qkv); // result: weighted sum of v rows - nq elements
+
 #ifdef __cplusplus
 }
 #endif