mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 23:54:10 +00:00
Experimenting with flash attention on Zen4
This version outperforms no-FA up to 16k tokens, but still becomes slower at 32k. Here the t/s for LLaMA-3.1-8B on a Ryzen-7950X | test | t/s no FA | Georgi FA | This commit FA | | --------: | ---------------: | -------------: | --------------: | | pp256 | 193.46 ± 2.40 | 193.19 ± 5.07 | 197.73 ± 0.72 | | pp512 | 192.23 ± 1.83 | 188.14 ± 0.63 | 194.38 ± 0.69 | | pp1024 | 189.06 ± 0.72 | 170.81 ± 4.82 | 191.12 ± 1.47 | | pp2048 | 181.92 ± 1.21 | 140.36 ± 1.77 | 184.57 ± 1.20 | | pp4096 | 165.10 ± 0.95 | 117.50 ± 0.35 | 168.79 ± 0.50 | | pp8192 | 137.48 ± 0.75 | 68.54 ± 1.00 | 148.21 ± 0.64 | | pp16384 | 100.35 ± 0.93 | | 105.14 ± 0.00 | | pp32768 | 64.44 | | 57.36 | Didn't have the patience to run Georgi's FA at 16k tokens. No error estimate on the 32k result as I only ran 1 sample.
This commit is contained in:
@@ -6772,7 +6772,6 @@ void iqk_flash_helper_3(int ne00,
|
||||
float * qkv) {
|
||||
constexpr int q_step = 8;
|
||||
constexpr int k_step = 32; //16;
|
||||
//if (nq1%q_step != 0 || nk1%k_step != 0) printf("Oops(%s): nq1 = %d, nk1 = %d\n", __func__, nq1, nk1);
|
||||
if (nq1%q_step != 0 || nk1%k_step != 0) {
|
||||
for (int iq1 = 0; iq1 < nq1; ++iq1) {
|
||||
iqk_flash_helper_2(false, ne00, nk1, stride_k, stride_v,
|
||||
@@ -6784,49 +6783,19 @@ void iqk_flash_helper_3(int ne00,
|
||||
return;
|
||||
}
|
||||
stride_q /= sizeof(float);
|
||||
// The following works
|
||||
//for (int iq1 = 0; iq1 < nq1; ++iq1) {
|
||||
// iqk_flash_helper_2(false,
|
||||
// ne00,
|
||||
// nk1,
|
||||
// stride_k,
|
||||
// stride_v,
|
||||
// q,
|
||||
// k,
|
||||
// v,
|
||||
// (const void *)((const char *)mask + iq1*stride_m),
|
||||
// scale,
|
||||
// 1.0f,
|
||||
// nullptr,
|
||||
// qkv);
|
||||
// q += stride_q;
|
||||
// qkv += stride_qkv;
|
||||
//}
|
||||
const ggml_half h_inf = GGML_FP32_TO_FP16(-INFINITY);
|
||||
float cache[q_step*k_step];
|
||||
float S[q_step], M[q_step];
|
||||
__m512 vk[8];
|
||||
__m512 vk[16];
|
||||
__m512 vms[q_step];
|
||||
__m512 vals[k_step/16];
|
||||
bool need_scaling[q_step];
|
||||
auto vscale = _mm512_set1_ps(scale);
|
||||
for (int i1 = 0; i1 < nq1/q_step; ++i1) {
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
//auto R = qkv + (q_step*i1 + j)*stride_qkv;
|
||||
//std::memset(R, 0, 128*sizeof(float));
|
||||
S[j] = 0; M[j] = -INFINITY;
|
||||
}
|
||||
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
||||
///////////////////////////////////////////////////////////////////////////////////
|
||||
//const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*16*k1);
|
||||
//DataInfo info{cache, (const char *)(q + 16*i1*stride_q), 16*sizeof(float), size_t(stride_q)*sizeof(float), 0, 0, nullptr, 0};
|
||||
//mul_mat_fX_fY_T<4, ggml_half, float>(ne00, (const void *)kr, stride_k, info, 16);
|
||||
///////////////////////////////////////////////////////////////////////////////////
|
||||
//for (int l1 = 0; l1 < 16; ++l1) {
|
||||
// for (int m1 = 0; m1 < 16; ++m1) {
|
||||
// const ggml_half * mp = (const ggml_half *)((const char *)mask + stride_m*(16*i1 + m1)) + 16*k1;
|
||||
// cache[16*m1 + l1] = GGML_FP16_TO_FP32(mp[l1]) == -INFINITY ? -INFINITY : scale*cache[16*m1 + l1];
|
||||
// }
|
||||
//}
|
||||
for (int l1 = 0; l1 < k_step; ++l1) {
|
||||
auto kr = (const ggml_half *)((const char *)k + (k_step*k1 + l1)*stride_k);
|
||||
for (int i = 0; i < 8; ++i) vk[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i));
|
||||
@@ -6844,14 +6813,16 @@ void iqk_flash_helper_3(int ne00,
|
||||
cache[k_step*m1 + l1] = _mm512_reduce_add_ps(vsum);
|
||||
}
|
||||
}
|
||||
// This variant is much slower than the one below
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto R = qkv + (q_step*i1 + j)*stride_qkv;
|
||||
auto val1 = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j));
|
||||
auto val2 = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16));
|
||||
auto smax = _mm512_reduce_max_ps(_mm512_max_ps(val1, val2));
|
||||
//auto val = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j));
|
||||
//auto smax = _mm512_reduce_max_ps(val);
|
||||
for (int l = 0; l < k_step/16; ++l) vals[l] = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16*l));
|
||||
auto smax = _mm512_reduce_max_ps(_mm512_max_ps(vals[0], vals[1]));
|
||||
//auto smax = _mm512_reduce_max_ps(_mm512_max_ps(_mm512_max_ps(vals[0], vals[1]), _mm512_max_ps(vals[2], vals[3])));
|
||||
//auto val1 = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j));
|
||||
//auto val2 = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16));
|
||||
//auto smax = _mm512_reduce_max_ps(_mm512_max_ps(val1, val2));
|
||||
////auto val = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j));
|
||||
////auto smax = _mm512_reduce_max_ps(val);
|
||||
need_scaling[j] = false;
|
||||
if (smax > M[j]) {
|
||||
if (M[j] > -INFINITY) {
|
||||
@@ -6865,33 +6836,67 @@ void iqk_flash_helper_3(int ne00,
|
||||
}
|
||||
M[j] = smax;
|
||||
}
|
||||
val1 = v_expf(_mm512_sub_ps(val1, _mm512_set1_ps(M[j])));
|
||||
val2 = v_expf(_mm512_sub_ps(val2, _mm512_set1_ps(M[j])));
|
||||
S[j] += _mm512_reduce_add_ps(_mm512_add_ps(val1, val2));
|
||||
_mm512_storeu_ps(cache + k_step*j, val1);
|
||||
_mm512_storeu_ps(cache + k_step*j + 16, val2);
|
||||
//val = v_expf(_mm512_sub_ps(val, _mm512_set1_ps(M[j])));
|
||||
//S[j] += _mm512_reduce_add_ps(val);
|
||||
//_mm512_storeu_ps(cache + k_step*j, val);
|
||||
auto vm = _mm512_set1_ps(M[j]);
|
||||
for (int l = 0; l < k_step/16; ++l) {
|
||||
vals[l] = v_expf(_mm512_sub_ps(vals[l], vm));
|
||||
_mm512_storeu_ps(cache + k_step*j + 16*l, vals[l]);
|
||||
}
|
||||
S[j] += _mm512_reduce_add_ps(_mm512_add_ps(vals[0], vals[1]));
|
||||
//S[j] += _mm512_reduce_add_ps(_mm512_add_ps(_mm512_add_ps(vals[0], vals[1]), _mm512_add_ps(vals[2], vals[3])));
|
||||
//val1 = v_expf(_mm512_sub_ps(val1, _mm512_set1_ps(M[j])));
|
||||
//val2 = v_expf(_mm512_sub_ps(val2, _mm512_set1_ps(M[j])));
|
||||
//S[j] += _mm512_reduce_add_ps(_mm512_add_ps(val1, val2));
|
||||
//_mm512_storeu_ps(cache + k_step*j, val1);
|
||||
//_mm512_storeu_ps(cache + k_step*j + 16, val2);
|
||||
////val = v_expf(_mm512_sub_ps(val, _mm512_set1_ps(M[j])));
|
||||
////S[j] += _mm512_reduce_add_ps(val);
|
||||
////_mm512_storeu_ps(cache + k_step*j, val);
|
||||
}
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
for (int i = 0; i < 8; i += 2) {
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto R = qkv + (q_step*i1 + j)*stride_qkv;
|
||||
vk[j] = _mm512_loadu_ps(R + 16*i);
|
||||
if (need_scaling[j]) vk[j] = _mm512_mul_ps(vk[j], vms[j]);
|
||||
vk[2*j+0] = _mm512_loadu_ps(R + 16*i);
|
||||
vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16);
|
||||
if (need_scaling[j]) {
|
||||
vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], vms[j]);
|
||||
vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], vms[j]);
|
||||
}
|
||||
}
|
||||
for (int l1 = 0; l1 < k_step; ++l1) {
|
||||
auto vr = (const ggml_half *)((const char *)v + (k_step*k1 + l1)*stride_v);
|
||||
auto v = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i));
|
||||
auto v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+0));
|
||||
auto v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+1));
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
vk[j] = _mm512_fmadd_ps(v, _mm512_set1_ps(cache[k_step*j + l1]), vk[j]);
|
||||
auto vs = _mm512_set1_ps(cache[k_step*j + l1]);
|
||||
vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]);
|
||||
vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]);
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto R = qkv + (q_step*i1 + j)*stride_qkv;
|
||||
_mm512_storeu_ps(R + 16*i, vk[j]);
|
||||
_mm512_storeu_ps(R + 16*i, vk[2*j+0]);
|
||||
_mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]);
|
||||
}
|
||||
}
|
||||
//for (int i = 0; i < 8; ++i) {
|
||||
// for (int j = 0; j < q_step; ++j) {
|
||||
// auto R = qkv + (q_step*i1 + j)*stride_qkv;
|
||||
// vk[j] = _mm512_loadu_ps(R + 16*i);
|
||||
// if (need_scaling[j]) vk[j] = _mm512_mul_ps(vk[j], vms[j]);
|
||||
// }
|
||||
// for (int l1 = 0; l1 < k_step; ++l1) {
|
||||
// auto vr = (const ggml_half *)((const char *)v + (k_step*k1 + l1)*stride_v);
|
||||
// auto v = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i));
|
||||
// for (int j = 0; j < q_step; ++j) {
|
||||
// vk[j] = _mm512_fmadd_ps(v, _mm512_set1_ps(cache[k_step*j + l1]), vk[j]);
|
||||
// }
|
||||
// }
|
||||
// for (int j = 0; j < q_step; ++j) {
|
||||
// auto R = qkv + (q_step*i1 + j)*stride_qkv;
|
||||
// _mm512_storeu_ps(R + 16*i, vk[j]);
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//for (int j = 0; j < q_step; ++j) {
|
||||
// auto R = qkv + (q_step*i1 + j)*stride_qkv;
|
||||
// //auto val1 = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j));
|
||||
|
||||
Reference in New Issue
Block a user