Experimenting with flash attention on Zen4

This version outperforms the no-FA baseline for contexts up to 16k tokens,
but is still slower at 32k.

Here are the t/s results for LLaMA-3.1-8B on a Ryzen 7950X:

|      test |       t/s no FA  |   Georgi FA    |  This commit FA |
| --------: | ---------------: | -------------: | --------------: |
|     pp256 |    193.46 ± 2.40 |  193.19 ± 5.07 |   197.73 ± 0.72 |
|     pp512 |    192.23 ± 1.83 |  188.14 ± 0.63 |   194.38 ± 0.69 |
|    pp1024 |    189.06 ± 0.72 |  170.81 ± 4.82 |   191.12 ± 1.47 |
|    pp2048 |    181.92 ± 1.21 |  140.36 ± 1.77 |   184.57 ± 1.20 |
|    pp4096 |    165.10 ± 0.95 |  117.50 ± 0.35 |   168.79 ± 0.50 |
|    pp8192 |    137.48 ± 0.75 |   68.54 ± 1.00 |   148.21 ± 0.64 |
|   pp16384 |    100.35 ± 0.93 |                |   105.14 ± 0.00 |
|   pp32768 |     64.44        |                |    57.36        |

I didn't have the patience to run Georgi's FA at 16k tokens and beyond.
There is no error estimate on the 32k result because I ran only 1 sample.
This commit is contained in:
Iwan Kawrakow
2024-08-30 08:42:34 +03:00
parent 92adf7e6df
commit e4959f9e46

View File

@@ -6772,7 +6772,6 @@ void iqk_flash_helper_3(int ne00,
float * qkv) {
constexpr int q_step = 8;
constexpr int k_step = 32; //16;
//if (nq1%q_step != 0 || nk1%k_step != 0) printf("Oops(%s): nq1 = %d, nk1 = %d\n", __func__, nq1, nk1);
if (nq1%q_step != 0 || nk1%k_step != 0) {
for (int iq1 = 0; iq1 < nq1; ++iq1) {
iqk_flash_helper_2(false, ne00, nk1, stride_k, stride_v,
@@ -6784,49 +6783,19 @@ void iqk_flash_helper_3(int ne00,
return;
}
stride_q /= sizeof(float);
// The following works
//for (int iq1 = 0; iq1 < nq1; ++iq1) {
// iqk_flash_helper_2(false,
// ne00,
// nk1,
// stride_k,
// stride_v,
// q,
// k,
// v,
// (const void *)((const char *)mask + iq1*stride_m),
// scale,
// 1.0f,
// nullptr,
// qkv);
// q += stride_q;
// qkv += stride_qkv;
//}
const ggml_half h_inf = GGML_FP32_TO_FP16(-INFINITY);
float cache[q_step*k_step];
float S[q_step], M[q_step];
__m512 vk[8];
__m512 vk[16];
__m512 vms[q_step];
__m512 vals[k_step/16];
bool need_scaling[q_step];
auto vscale = _mm512_set1_ps(scale);
for (int i1 = 0; i1 < nq1/q_step; ++i1) {
for (int j = 0; j < q_step; ++j) {
//auto R = qkv + (q_step*i1 + j)*stride_qkv;
//std::memset(R, 0, 128*sizeof(float));
S[j] = 0; M[j] = -INFINITY;
}
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
///////////////////////////////////////////////////////////////////////////////////
//const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*16*k1);
//DataInfo info{cache, (const char *)(q + 16*i1*stride_q), 16*sizeof(float), size_t(stride_q)*sizeof(float), 0, 0, nullptr, 0};
//mul_mat_fX_fY_T<4, ggml_half, float>(ne00, (const void *)kr, stride_k, info, 16);
///////////////////////////////////////////////////////////////////////////////////
//for (int l1 = 0; l1 < 16; ++l1) {
// for (int m1 = 0; m1 < 16; ++m1) {
// const ggml_half * mp = (const ggml_half *)((const char *)mask + stride_m*(16*i1 + m1)) + 16*k1;
// cache[16*m1 + l1] = GGML_FP16_TO_FP32(mp[l1]) == -INFINITY ? -INFINITY : scale*cache[16*m1 + l1];
// }
//}
for (int l1 = 0; l1 < k_step; ++l1) {
auto kr = (const ggml_half *)((const char *)k + (k_step*k1 + l1)*stride_k);
for (int i = 0; i < 8; ++i) vk[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i));
@@ -6844,14 +6813,16 @@ void iqk_flash_helper_3(int ne00,
cache[k_step*m1 + l1] = _mm512_reduce_add_ps(vsum);
}
}
// This variant is much slower than the one below
for (int j = 0; j < q_step; ++j) {
auto R = qkv + (q_step*i1 + j)*stride_qkv;
auto val1 = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j));
auto val2 = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16));
auto smax = _mm512_reduce_max_ps(_mm512_max_ps(val1, val2));
//auto val = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j));
//auto smax = _mm512_reduce_max_ps(val);
for (int l = 0; l < k_step/16; ++l) vals[l] = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16*l));
auto smax = _mm512_reduce_max_ps(_mm512_max_ps(vals[0], vals[1]));
//auto smax = _mm512_reduce_max_ps(_mm512_max_ps(_mm512_max_ps(vals[0], vals[1]), _mm512_max_ps(vals[2], vals[3])));
//auto val1 = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j));
//auto val2 = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16));
//auto smax = _mm512_reduce_max_ps(_mm512_max_ps(val1, val2));
////auto val = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j));
////auto smax = _mm512_reduce_max_ps(val);
need_scaling[j] = false;
if (smax > M[j]) {
if (M[j] > -INFINITY) {
@@ -6865,33 +6836,67 @@ void iqk_flash_helper_3(int ne00,
}
M[j] = smax;
}
val1 = v_expf(_mm512_sub_ps(val1, _mm512_set1_ps(M[j])));
val2 = v_expf(_mm512_sub_ps(val2, _mm512_set1_ps(M[j])));
S[j] += _mm512_reduce_add_ps(_mm512_add_ps(val1, val2));
_mm512_storeu_ps(cache + k_step*j, val1);
_mm512_storeu_ps(cache + k_step*j + 16, val2);
//val = v_expf(_mm512_sub_ps(val, _mm512_set1_ps(M[j])));
//S[j] += _mm512_reduce_add_ps(val);
//_mm512_storeu_ps(cache + k_step*j, val);
auto vm = _mm512_set1_ps(M[j]);
for (int l = 0; l < k_step/16; ++l) {
vals[l] = v_expf(_mm512_sub_ps(vals[l], vm));
_mm512_storeu_ps(cache + k_step*j + 16*l, vals[l]);
}
S[j] += _mm512_reduce_add_ps(_mm512_add_ps(vals[0], vals[1]));
//S[j] += _mm512_reduce_add_ps(_mm512_add_ps(_mm512_add_ps(vals[0], vals[1]), _mm512_add_ps(vals[2], vals[3])));
//val1 = v_expf(_mm512_sub_ps(val1, _mm512_set1_ps(M[j])));
//val2 = v_expf(_mm512_sub_ps(val2, _mm512_set1_ps(M[j])));
//S[j] += _mm512_reduce_add_ps(_mm512_add_ps(val1, val2));
//_mm512_storeu_ps(cache + k_step*j, val1);
//_mm512_storeu_ps(cache + k_step*j + 16, val2);
////val = v_expf(_mm512_sub_ps(val, _mm512_set1_ps(M[j])));
////S[j] += _mm512_reduce_add_ps(val);
////_mm512_storeu_ps(cache + k_step*j, val);
}
for (int i = 0; i < 8; ++i) {
for (int i = 0; i < 8; i += 2) {
for (int j = 0; j < q_step; ++j) {
auto R = qkv + (q_step*i1 + j)*stride_qkv;
vk[j] = _mm512_loadu_ps(R + 16*i);
if (need_scaling[j]) vk[j] = _mm512_mul_ps(vk[j], vms[j]);
vk[2*j+0] = _mm512_loadu_ps(R + 16*i);
vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16);
if (need_scaling[j]) {
vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], vms[j]);
vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], vms[j]);
}
}
for (int l1 = 0; l1 < k_step; ++l1) {
auto vr = (const ggml_half *)((const char *)v + (k_step*k1 + l1)*stride_v);
auto v = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i));
auto v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+0));
auto v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+1));
for (int j = 0; j < q_step; ++j) {
vk[j] = _mm512_fmadd_ps(v, _mm512_set1_ps(cache[k_step*j + l1]), vk[j]);
auto vs = _mm512_set1_ps(cache[k_step*j + l1]);
vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]);
vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]);
}
}
for (int j = 0; j < q_step; ++j) {
auto R = qkv + (q_step*i1 + j)*stride_qkv;
_mm512_storeu_ps(R + 16*i, vk[j]);
_mm512_storeu_ps(R + 16*i, vk[2*j+0]);
_mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]);
}
}
//for (int i = 0; i < 8; ++i) {
// for (int j = 0; j < q_step; ++j) {
// auto R = qkv + (q_step*i1 + j)*stride_qkv;
// vk[j] = _mm512_loadu_ps(R + 16*i);
// if (need_scaling[j]) vk[j] = _mm512_mul_ps(vk[j], vms[j]);
// }
// for (int l1 = 0; l1 < k_step; ++l1) {
// auto vr = (const ggml_half *)((const char *)v + (k_step*k1 + l1)*stride_v);
// auto v = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i));
// for (int j = 0; j < q_step; ++j) {
// vk[j] = _mm512_fmadd_ps(v, _mm512_set1_ps(cache[k_step*j + l1]), vk[j]);
// }
// }
// for (int j = 0; j < q_step; ++j) {
// auto R = qkv + (q_step*i1 + j)*stride_qkv;
// _mm512_storeu_ps(R + 16*i, vk[j]);
// }
//}
//
//for (int j = 0; j < q_step; ++j) {
// auto R = qkv + (q_step*i1 + j)*stride_qkv;
// //auto val1 = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j));