mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-26 08:04:09 +00:00
Experimenting with flash attention on Zen4
Slightly (~1%) better by keeping qkv accumulators on the stack and only storing the final result into the qkv tensor.
This commit is contained in:
@@ -6789,6 +6789,7 @@ void iqk_flash_helper_3(int ne00,
|
||||
__m512 vk[16];
|
||||
__m512 vms[q_step];
|
||||
__m512 vals[k_step/16];
|
||||
float qkv_cache[128*q_step];
|
||||
bool need_scaling[q_step];
|
||||
auto vscale = _mm512_set1_ps(scale);
|
||||
for (int i1 = 0; i1 < nq1/q_step; ++i1) {
|
||||
@@ -6814,7 +6815,8 @@ void iqk_flash_helper_3(int ne00,
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto R = qkv + (q_step*i1 + j)*stride_qkv;
|
||||
auto R = qkv_cache + 128*j;
|
||||
//auto R = qkv + (q_step*i1 + j)*stride_qkv;
|
||||
for (int l = 0; l < k_step/16; ++l) vals[l] = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16*l));
|
||||
auto smax = _mm512_reduce_max_ps(_mm512_max_ps(vals[0], vals[1]));
|
||||
//auto smax = _mm512_reduce_max_ps(_mm512_max_ps(_mm512_max_ps(vals[0], vals[1]), _mm512_max_ps(vals[2], vals[3])));
|
||||
@@ -6854,7 +6856,8 @@ void iqk_flash_helper_3(int ne00,
|
||||
}
|
||||
for (int i = 0; i < 8; i += 2) {
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto R = qkv + (q_step*i1 + j)*stride_qkv;
|
||||
//auto R = qkv + (q_step*i1 + j)*stride_qkv;
|
||||
auto R = qkv_cache + 128*j;
|
||||
vk[2*j+0] = _mm512_loadu_ps(R + 16*i);
|
||||
vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16);
|
||||
if (need_scaling[j]) {
|
||||
@@ -6873,83 +6876,21 @@ void iqk_flash_helper_3(int ne00,
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto R = qkv + (q_step*i1 + j)*stride_qkv;
|
||||
auto R = qkv_cache + 128*j;
|
||||
//auto R = qkv + (q_step*i1 + j)*stride_qkv;
|
||||
_mm512_storeu_ps(R + 16*i, vk[2*j+0]);
|
||||
_mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]);
|
||||
}
|
||||
}
|
||||
//for (int i = 0; i < 8; ++i) {
|
||||
// for (int j = 0; j < q_step; ++j) {
|
||||
// auto R = qkv + (q_step*i1 + j)*stride_qkv;
|
||||
// vk[j] = _mm512_loadu_ps(R + 16*i);
|
||||
// if (need_scaling[j]) vk[j] = _mm512_mul_ps(vk[j], vms[j]);
|
||||
// }
|
||||
// for (int l1 = 0; l1 < k_step; ++l1) {
|
||||
// auto vr = (const ggml_half *)((const char *)v + (k_step*k1 + l1)*stride_v);
|
||||
// auto v = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i));
|
||||
// for (int j = 0; j < q_step; ++j) {
|
||||
// vk[j] = _mm512_fmadd_ps(v, _mm512_set1_ps(cache[k_step*j + l1]), vk[j]);
|
||||
// }
|
||||
// }
|
||||
// for (int j = 0; j < q_step; ++j) {
|
||||
// auto R = qkv + (q_step*i1 + j)*stride_qkv;
|
||||
// _mm512_storeu_ps(R + 16*i, vk[j]);
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//for (int j = 0; j < q_step; ++j) {
|
||||
// auto R = qkv + (q_step*i1 + j)*stride_qkv;
|
||||
// //auto val1 = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j));
|
||||
// //auto val2 = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16));
|
||||
// //auto smax = _mm512_reduce_max_ps(_mm512_max_ps(val1, val2));
|
||||
// auto val = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j));
|
||||
// auto smax = _mm512_reduce_max_ps(val);
|
||||
// for (int i = 0; i < 8; ++i) vk[i] = _mm512_loadu_ps(R + 16*i);
|
||||
// if (smax > M[j]) {
|
||||
// if (M[j] > -INFINITY) {
|
||||
// float m = expf(M[j] - smax);
|
||||
// auto vm = _mm512_set1_ps(m);
|
||||
// for (int i = 0; i < 8; ++i) {
|
||||
// vk[i] = _mm512_mul_ps(vm, vk[i]);
|
||||
// }
|
||||
// S[j] *= m;
|
||||
// } else {
|
||||
// for (int i = 0; i < 8; ++i) vk[i] = _mm512_setzero_ps();
|
||||
// S[j] = 0;
|
||||
// }
|
||||
// M[j] = smax;
|
||||
// }
|
||||
// //auto vm = _mm512_set1_ps(M[j]);
|
||||
// //val1 = v_expf(_mm512_sub_ps(val1, vm));
|
||||
// //val2 = v_expf(_mm512_sub_ps(val2, vm));
|
||||
// //S[j] += _mm512_reduce_add_ps(_mm512_add_ps(val1, val2));
|
||||
// //_mm512_storeu_ps(cache + k_step*j, val1);
|
||||
// //_mm512_storeu_ps(cache + k_step*j + 16, val2);
|
||||
// val = v_expf(_mm512_sub_ps(val, _mm512_set1_ps(M[j])));
|
||||
// S[j] += _mm512_reduce_add_ps(val);
|
||||
// _mm512_storeu_ps(cache + k_step*j, val);
|
||||
// for (int l1 = 0; l1 < k_step; ++l1) {
|
||||
// if (cache[k_step*j + l1] < -20.0f) continue;
|
||||
// auto vr = (const ggml_half *)((const char *)v + (k_step*k1 + l1)*stride_v);
|
||||
// auto vs = _mm512_set1_ps(cache[k_step*j + l1]);
|
||||
// for (int i = 0; i < 8; ++i) {
|
||||
// vk[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i)), vk[i]);
|
||||
// }
|
||||
// }
|
||||
// for (int i = 0; i < 8; ++i) _mm512_storeu_ps(R + 16*i, vk[i]);
|
||||
//}
|
||||
}
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto R = qkv + (q_step*i1 + j)*stride_qkv;
|
||||
GGML_ASSERT(S[j] > 0);
|
||||
if (S[j] > 0) {
|
||||
auto norm = _mm512_set1_ps(1/S[j]);
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
auto r = _mm512_loadu_ps(R + 16*i);
|
||||
_mm512_storeu_ps(R + 16*i, _mm512_mul_ps(norm, r));
|
||||
}
|
||||
} else {
|
||||
std::memset(R, 0, 128*sizeof(float));
|
||||
auto R = qkv_cache + 128*j;
|
||||
auto final_R = qkv + (q_step*i1 + j)*stride_qkv;
|
||||
auto norm = _mm512_set1_ps(1/S[j]);
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
auto r = _mm512_loadu_ps(R + 16*i);
|
||||
_mm512_storeu_ps(final_R + 16*i, _mm512_mul_ps(norm, r));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user