mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-10 00:10:13 +00:00
Fix Zen4 Flash Attention (#35)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -16154,11 +16154,12 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
mask && mask->type == GGML_TYPE_F16) {
|
||||
int64_t work_per_slice = D*nek1*neq1;
|
||||
int ntg = 1;
|
||||
if (nth%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
|
||||
else if (nth%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
|
||||
else if (nth%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
|
||||
if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
|
||||
else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
|
||||
else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
|
||||
if ((neq2*neq3)%(nth/ntg) == 0) {
|
||||
//if (ith == 0) printf("%s: D = %d, neq2 = %d, neq1 = %d, nek1 = %d\n", __func__, (int)D, (int)neq2, (int)neq1, (int)nek1);
|
||||
//if (ith == 0) printf("%s: D = %d, neq2 = %d, neq1 = %d, nek1 = %d, ntg = %d, neq1/ntg = %d\n", __func__,
|
||||
// (int)D, (int)neq2, (int)neq1, (int)nek1, ntg, (int)(neq1/ntg));
|
||||
int counter = 0;
|
||||
for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
|
||||
for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
|
||||
|
||||
@@ -6120,7 +6120,7 @@ struct FlashAttn {
|
||||
}
|
||||
auto qr = q + m1*stride_q;
|
||||
auto vsum = _mm512_mul_ps(vk[0], _mm512_loadu_ps(qr));
|
||||
for (int i = 0; i < D/16; ++i) {
|
||||
for (int i = 1; i < D/16; ++i) {
|
||||
vsum = _mm512_fmadd_ps(vk[i], _mm512_loadu_ps(qr + 16*i), vsum);
|
||||
}
|
||||
cache[k_step*m1 + l1] = _mm512_reduce_add_ps(vsum);
|
||||
@@ -6138,6 +6138,11 @@ struct FlashAttn {
|
||||
}
|
||||
|
||||
float smax = reduce_T<_mm512_reduce_max_ps, _mm512_max_ps>(vk);
|
||||
if (smax == -INFINITY) {
|
||||
std::memset(cache + k_step*j, 0, k_step*sizeof(float));
|
||||
need_scaling[j] = M[j] == -INFINITY ? 2 : 0;
|
||||
return;
|
||||
}
|
||||
need_scaling[j] = 0;
|
||||
if (smax > M[j]) {
|
||||
if (M[j] > -INFINITY) {
|
||||
@@ -6404,6 +6409,7 @@ template <int D, int q_step, int k_step>
|
||||
inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
|
||||
const float * q, const char * k, const char * v, const char * mask,
|
||||
float scale, float softcap, float * qkv) {
|
||||
|
||||
if (nq1 >= q_step) {
|
||||
FlashAttn<D, q_step, k_step> fa(scale, softcap);
|
||||
fa.compute(nq1, nk1, stride_k, stride_q, stride_m, stride_v, stride_qkv,
|
||||
|
||||
Reference in New Issue
Block a user