mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 23:54:10 +00:00
WIP: plugging into ggml_compute_forward_flash_attn_ext_f16
Now everything is done in iqk_flash_helper_2. It is slower than no FA: at 2048 tokens we have 167 t/s vs 176 t/s. This is better than Georgi's FA (138 t/s), but at 8192 tokens we degrade to 93 t/s vs 134 t/s without FA.
This commit is contained in:
@@ -16069,11 +16069,12 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
|
||||
ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
|
||||
|
||||
if (v->type == GGML_TYPE_F16) {
|
||||
memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
|
||||
} else {
|
||||
memset(VKQ32, 0, D*sizeof(float));
|
||||
}
|
||||
//memset(VKQ32, 0, D*sizeof(float));
|
||||
//if (v->type == GGML_TYPE_F16) {
|
||||
// memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
|
||||
//} else {
|
||||
// memset(VKQ32, 0, D*sizeof(float));
|
||||
//}
|
||||
|
||||
const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
|
||||
|
||||
@@ -16085,22 +16086,31 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
const int iv3 = iq3 / rv3;
|
||||
const int iv2 = iq2 / rv2;
|
||||
|
||||
iqk_flash_helper(D, nek1, nbk1,
|
||||
(const float *)((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)),
|
||||
//iqk_flash_helper(D, nek1, nbk1,
|
||||
// (const float *)((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)),
|
||||
// (const void *)((char *) k->data + ik2*nbk2 + ik3*nbk3),
|
||||
// (const void *)mp,
|
||||
// scale, slope,
|
||||
// aux_kq);
|
||||
|
||||
//for (int64_t ic = 0; ic < nek1; ++ic) {
|
||||
// const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
|
||||
// if (v->type == GGML_TYPE_F16) {
|
||||
// ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, aux_kq[ic]);
|
||||
// } else {
|
||||
// v_to_float(v_data, V32, D);
|
||||
// ggml_vec_mad_f32(D, VKQ32, V32, aux_kq[ic]);
|
||||
// }
|
||||
//}
|
||||
|
||||
iqk_flash_helper_2(D, nek1, nbk1, nbv1,
|
||||
(const float *)((char *) q->data + iq1*nbq1 + iq2*nbq2 + iq3*nbq3),
|
||||
(const void *)((char *) k->data + ik2*nbk2 + ik3*nbk3),
|
||||
(const void *)((char *) v->data + iv2*nbv2 + iv3*nbv3),
|
||||
(const void *)mp,
|
||||
scale, slope,
|
||||
aux_kq);
|
||||
|
||||
for (int64_t ic = 0; ic < nek1; ++ic) {
|
||||
const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
|
||||
if (v->type == GGML_TYPE_F16) {
|
||||
ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, aux_kq[ic]);
|
||||
} else {
|
||||
v_to_float(v_data, V32, D);
|
||||
ggml_vec_mad_f32(D, VKQ32, V32, aux_kq[ic]);
|
||||
}
|
||||
}
|
||||
aux_kq, (float *)((char *) dst->data + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1));
|
||||
//aux_kq, VKQ32);
|
||||
|
||||
////const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
|
||||
////q_to_vec_dot(pq, Q_q, D);
|
||||
@@ -16166,26 +16176,26 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
// S = S*ms + vs; // scale and increment sum with partial sum
|
||||
//}
|
||||
|
||||
if (v->type == GGML_TYPE_F16) {
|
||||
for (int64_t d = 0; d < D; ++d) {
|
||||
VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
|
||||
}
|
||||
}
|
||||
//if (v->type == GGML_TYPE_F16) {
|
||||
// for (int64_t d = 0; d < D; ++d) {
|
||||
// VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
|
||||
// }
|
||||
//}
|
||||
|
||||
//// V /= S
|
||||
//const float S_inv = 1.0f/S;
|
||||
//ggml_vec_scale_f32(D, VKQ32, S_inv);
|
||||
|
||||
// dst indices
|
||||
const int i1 = iq1;
|
||||
const int i2 = iq2;
|
||||
const int i3 = iq3;
|
||||
//// dst indices
|
||||
//const int i1 = iq1;
|
||||
//const int i2 = iq2;
|
||||
//const int i3 = iq3;
|
||||
|
||||
// original
|
||||
//memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
|
||||
//// original
|
||||
////memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
|
||||
|
||||
// permute(0, 2, 1, 3)
|
||||
memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
|
||||
//// permute(0, 2, 1, 3)
|
||||
//memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6244,6 +6244,60 @@ void iqk_flash_helper(int nq, // number of elements in q
|
||||
softmax_extended(nk, qk, qk, scale, slope, (const char *)mask, true);
|
||||
}
|
||||
|
||||
void iqk_flash_helper_2(int nq, // number of elements in q
|
||||
int nk, // number of rows in k
|
||||
int stride_k, // distance between rows in k (in bytes)
|
||||
int stride_v,
|
||||
const float * q, // q vector
|
||||
const void * k, // k matrix. Assumed to be fp16, nq x nk elements
|
||||
const void * v, // k matrix. Assumed to be fp16, nq x nk elements
|
||||
const void * mask, // mask. If not null, assumed to be fp16. nk elements
|
||||
float scale,
|
||||
float slope,
|
||||
float * qk,
|
||||
float * qkv) {
|
||||
GGML_ASSERT(nq % 4 == 0);
|
||||
//GGML_ASSERT(nq / 16 <= 16);
|
||||
|
||||
DataInfo info{qk, (const char*)q, 0, size_t(stride_k), 0, 1, nullptr, 0};
|
||||
|
||||
mul_mat_fX_fY_T<1, ggml_half, float>(nq, k, stride_k, info, nk);
|
||||
softmax_extended(nk, qk, qk, scale, slope, (const char *)mask, true);
|
||||
|
||||
GGML_ASSERT(nq%16 == 0);
|
||||
if (nq/16 <= 16) {
|
||||
__m512 v_qkv[16];
|
||||
auto v_qk = _mm512_set1_ps(qk[0]);
|
||||
for (int j = 0; j < nq/16; ++j) {
|
||||
auto v_v = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)v + j));
|
||||
v_qkv[j] = _mm512_mul_ps(v_qk, v_v);
|
||||
}
|
||||
for (int ic = 1; ic < nk; ++ic) {
|
||||
const ggml_half * vr = (const ggml_half *)((const char *)v + ic*stride_v);
|
||||
v_qk = _mm512_set1_ps(qk[ic]);
|
||||
for (int j = 0; j < nq/16; ++j) {
|
||||
auto v_v = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr + j));
|
||||
v_qkv[j] = _mm512_fmadd_ps(v_qk, v_v, v_qkv[j]);
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < nq/16; ++j) {
|
||||
_mm512_storeu_ps(qkv + 16*j, v_qkv[j]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (int ic = 0; ic < nk; ++ic) {
|
||||
auto v_qk = _mm512_set1_ps(qk[ic]);
|
||||
const ggml_half * vr = (const ggml_half *)((const char *)v + ic*stride_v);
|
||||
for (int j = 0; j < nq/16; ++j) {
|
||||
auto v_v = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr + j));
|
||||
auto v_qkv = _mm512_loadu_ps(qkv + 16*j);
|
||||
v_qkv = _mm512_fmadd_ps(v_qk, v_v, v_qkv);
|
||||
_mm512_storeu_ps(qkv + 16*j, v_qkv);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#else // IQK_IMPLEMENT
|
||||
|
||||
bool iqk_mul_mat(int, long, long, long, int, const void *, long, int, const void *, long, float *, long, int, int) {
|
||||
|
||||
@@ -39,6 +39,19 @@ void iqk_flash_helper(int nq, // number of elements in q
|
||||
float slope,
|
||||
float * qk); // softmax(k*q) - k elements
|
||||
|
||||
// Combined flash-attention helper: computes qk = softmax(scale*(K*q) + slope*mask)
// for one query row and then qkv = V^T * qk in a single call.
void iqk_flash_helper_2(int nq,            // number of elements in q
                        int nk,            // number of rows in k
                        int stride_k,      // distance between rows in k (in bytes)
                        int stride_v,      // distance between rows in v (in bytes)
                        const float * q,   // q vector, nq elements
                        const void * k,    // k matrix. Assumed to be fp16, nq x nk elements
                        const void * v,    // v matrix. Assumed to be fp16, nq x nk elements
                        const void * mask, // mask. If not null, assumed to be fp16, nk elements
                        float scale,
                        float slope,
                        float * qk,        // work buffer: softmax(k*q), nk elements
                        float * qkv);      // output: attention result, nq elements
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user