This commit is contained in:
Iwan Kawrakow
2024-08-29 10:06:29 +03:00
parent ec0ae14aee
commit a02e78d5c1

View File

@@ -6391,17 +6391,24 @@ inline void accumulate(int n, float * saux, float smax, float& M, float& S, __m5
}
M = smax;
}
//auto vs_all = v_expf(_mm256_sub_ps(_mm256_loadu_ps(saux), _mm256_set1_ps(M)));
//_mm256_storeu_ps(saux, vs_all);
auto vs_all = v_expf(_mm512_sub_ps(_mm512_loadu_ps(saux), _mm512_set1_ps(M)));
_mm512_storeu_ps(saux, vs_all);
S += _mm512_reduce_add_ps(vs_all);
//for (int j = 0; j < n; ++j) {
// if (saux[j] < -18.f) continue; // ignore anything less than 1.5e-8 - it want't change the single precision result.
// auto vs = _mm512_set1_ps(saux[j]);
// auto vr = (const ggml_half *)(v + stride_v*j);
// for (int i = 0; i < nq/16; ++i) {
// acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]);
// }
//}
__m512 vec_v[nq/16];
for (int j = 0; j < n; ++j) {
S += saux[j];
auto vs = _mm512_set1_ps(saux[j]);
if (saux[j] < -18.f) continue; // ignore anything less than 1.5e-8 - it want't change the single precision result.
auto vr = (const ggml_half *)(v + stride_v*j);
for (int i = 0; i < nq/16; ++i) {
acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]);
}
for (int i = 0; i < nq/16; ++i) vec_v[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i));
auto vs = _mm512_set1_ps(saux[j]);
for (int i = 0; i < nq/16; ++i) acc[i] = _mm512_fmadd_ps(vs, vec_v[i], acc[i]);
}
}
@@ -6420,8 +6427,14 @@ void flash_attn_T(int nk, // number of rows in k
__m512 vq[nq/16];
__m512 acc[nq/16];
float saux[kNchunk];
//float mv32[kNchunk];
for (int i = 0; i < nq/16; ++i) vq[i] = _mm512_loadu_ps(q + 16*i);
const ggml_half * mp = mask ? (const ggml_half *)mask : nullptr;
//if (!mp) std::memset(mv32, 0, kNchunk*sizeof(float));
//else {
// auto vmask = _mm512_mul_ps(_mm512_set1_ps(slope), _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)mp)));
// _mm512_storeu_ps(mv32, vmask);
//}
float M = -INFINITY;
float S = 0;
float smax = -INFINITY;
@@ -6429,12 +6442,23 @@ void flash_attn_T(int nk, // number of rows in k
int ik = 0;
for (; ik < nk; ++ik) {
if (ik - last_ik == kNchunk) {
accumulate<nq>(kNchunk, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v);
//accumulate<nq>(kNchunk, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v);
if (smax != -INFINITY) {
accumulate<nq>(kNchunk, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v);
}
last_ik = ik;
smax = -INFINITY;
//if (ik + kNchunk <= nk) {
// auto vmask = _mm512_mul_ps(_mm512_set1_ps(slope), _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)mp)));
// _mm512_storeu_ps(mv32, vmask);
//}
}
//const float mv = mv32[ik - last_ik];
const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ik]) : 0.0f;
if (mv == -INFINITY) break;
if (mv == -INFINITY) {
saux[ik - last_ik] = -INFINITY;
continue;
}
const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*ik);
auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr)));
for (int i = 1; i < nq/16; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum);
@@ -6442,8 +6466,10 @@ void flash_attn_T(int nk, // number of rows in k
saux[ik - last_ik] = s;
smax = std::max(smax, s);
}
if (ik > last_ik) {
accumulate<nq>(ik - last_ik, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v);
int n_left = ik - last_ik;
if (n_left > 0 & smax != -INFINITY) {
for (int j = n_left; j < kNchunk; ++j) saux[j] = -INFINITY;
accumulate<nq>(n_left, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v);
}
if (S > 0) {
auto norm = _mm512_set1_ps(1/S);