mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 23:54:10 +00:00
WIP: plugging into ggml_compute_forward_flash_attn_ext_f16
OK, if we take into account that the mask is diagonal and skip further computations once we encounter -INFINITY, we can speed it up and make it on par with no-FA. Better than nothing, but still no luck.
This commit is contained in:
@@ -6394,63 +6394,96 @@ void iqk_flash_helper_2(int nq, // number of elements in q
|
||||
float M = -INFINITY;
|
||||
float S = 0;
|
||||
int ik = 0;
|
||||
if (nk >= 8) {
|
||||
float s8[8];
|
||||
for (int ik8 = 0; ik8 < nk/8; ++ik8) {
|
||||
float smax = -INFINITY;
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*(8*ik8 + j));
|
||||
auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr)));
|
||||
for (int i = 1; i < 8; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum);
|
||||
float s = scale * _mm512_reduce_add_ps(sum);
|
||||
if (mp) s += slope*GGML_FP16_TO_FP32(mp[8*ik8+j]);
|
||||
s8[j] = s;
|
||||
smax = std::max(smax, s);
|
||||
}
|
||||
if (smax > M) {
|
||||
float ms = expf(M - smax);
|
||||
auto scale = _mm512_set1_ps(ms);
|
||||
for (int i = 0; i < 8; ++i) acc[i] = _mm512_mul_ps(acc[i], scale);
|
||||
S *= ms;
|
||||
M = smax;
|
||||
}
|
||||
_mm256_storeu_ps(s8, v_expf(_mm256_sub_ps(_mm256_loadu_ps(s8), _mm256_set1_ps(M))));
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
const ggml_half * vr = (const ggml_half *)((const char *)v + stride_v*(8*ik8+j));
|
||||
auto vs = _mm512_set1_ps(s8[j]);
|
||||
S += s8[j];
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
ik = 8*(nk/8);
|
||||
}
|
||||
//if (nk >= 8) {
|
||||
// float s8[8];
|
||||
// for (int ik8 = 0; ik8 < nk/8; ++ik8) {
|
||||
// float smax = -INFINITY;
|
||||
// for (int j = 0; j < 8; ++j) {
|
||||
// const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[8*ik+j]) : 0.0f;
|
||||
// if (mv == -INFINITY) {
|
||||
// s8[j] = -INFINITY; continue;
|
||||
// }
|
||||
// const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*(8*ik8 + j));
|
||||
// auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr)));
|
||||
// for (int i = 1; i < 8; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum);
|
||||
// float s = scale * _mm512_reduce_add_ps(sum) + mv;
|
||||
// s8[j] = s;
|
||||
// smax = std::max(smax, s);
|
||||
// }
|
||||
// if (smax > M) {
|
||||
// if (M > -INFINITY) {
|
||||
// float ms = expf(M - smax);
|
||||
// auto scale = _mm512_set1_ps(ms);
|
||||
// for (int i = 0; i < 8; ++i) acc[i] = _mm512_mul_ps(acc[i], scale);
|
||||
// S *= ms;
|
||||
// } else {
|
||||
// for (int i = 0; i < 8; ++i) acc[i] = _mm512_setzero_ps();
|
||||
// }
|
||||
// M = smax;
|
||||
// }
|
||||
// if (smax == -INFINITY) break;
|
||||
// _mm256_storeu_ps(s8, v_expf(_mm256_sub_ps(_mm256_loadu_ps(s8), _mm256_set1_ps(M))));
|
||||
// for (int j = 0; j < 8; ++j) {
|
||||
// if (s8[j] <= 0.0f) continue;
|
||||
// const ggml_half * vr = (const ggml_half *)((const char *)v + stride_v*(8*ik8+j));
|
||||
// auto vs = _mm512_set1_ps(s8[j]);
|
||||
// S += s8[j];
|
||||
// for (int i = 0; i < 8; ++i) {
|
||||
// acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// ik = 8*(nk/8);
|
||||
//}
|
||||
//int last_i = -1;
|
||||
for (; ik < nk; ++ik) {
|
||||
const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ik]) : 0.0f;
|
||||
if (mv == -INFINITY) continue;
|
||||
if (mv == -INFINITY) break; //continue;
|
||||
//last_i = ik;
|
||||
const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*ik);
|
||||
const ggml_half * vr = (const ggml_half *)((const char *)v + stride_v*ik);
|
||||
//auto sum1 = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr+0)));
|
||||
//auto sum2 = _mm512_mul_ps(vq[1], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr+1)));
|
||||
//for (int i = 2; i < 8; i += 2) {
|
||||
// sum1 = _mm512_fmadd_ps(vq[i+0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i + 0)), sum1);
|
||||
// sum2 = _mm512_fmadd_ps(vq[i+1], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i + 1)), sum2);
|
||||
//}
|
||||
//float s = scale * _mm512_reduce_add_ps(_mm512_add_ps(sum1, sum2)) + mv;
|
||||
auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr)));
|
||||
for (int i = 1; i < 8; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum);
|
||||
float s = scale * _mm512_reduce_add_ps(sum) + mv;
|
||||
if (s > M) {
|
||||
float ms = expf(M - s);
|
||||
auto vms = _mm512_set1_ps(ms);
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
acc[i] = _mm512_fmadd_ps(vms, acc[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)));
|
||||
if (M > -INFINITY) {
|
||||
//if (M - s > -20.f) {
|
||||
float ms = expf(M - s);
|
||||
auto vms = _mm512_set1_ps(ms);
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
acc[i] = _mm512_fmadd_ps(vms, acc[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)));
|
||||
}
|
||||
S = ms*S + 1;
|
||||
} else {
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
acc[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i));
|
||||
}
|
||||
S = 1;
|
||||
}
|
||||
S = ms*S + 1;
|
||||
M = s;
|
||||
} else {
|
||||
float vs = expf(s - M);
|
||||
auto vvs = _mm512_set1_ps(vs);
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
acc[i] = _mm512_fmadd_ps(vvs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]);
|
||||
}
|
||||
S += vs;
|
||||
//if (s - M > -20.f) {
|
||||
float vs = expf(s - M);
|
||||
auto vvs = _mm512_set1_ps(vs);
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
acc[i] = _mm512_fmadd_ps(vvs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]);
|
||||
}
|
||||
S += vs;
|
||||
//}
|
||||
}
|
||||
}
|
||||
//if (last_i < 0) {
|
||||
// std::memset(qkv, 0, 128*sizeof(float));
|
||||
// return;
|
||||
//}
|
||||
//printf("%s: nk = %d, last_i = %d\n", __func__, nk, last_i);
|
||||
auto norm = _mm512_set1_ps(1/S);
|
||||
for (int i = 0; i < 8; ++i) _mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, acc[i]));
|
||||
return;
|
||||
|
||||
Reference in New Issue
Block a user