mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 23:54:10 +00:00
WIP
This commit is contained in:
@@ -6391,17 +6391,24 @@ inline void accumulate(int n, float * saux, float smax, float& M, float& S, __m5
|
||||
}
|
||||
M = smax;
|
||||
}
|
||||
//auto vs_all = v_expf(_mm256_sub_ps(_mm256_loadu_ps(saux), _mm256_set1_ps(M)));
|
||||
//_mm256_storeu_ps(saux, vs_all);
|
||||
auto vs_all = v_expf(_mm512_sub_ps(_mm512_loadu_ps(saux), _mm512_set1_ps(M)));
|
||||
_mm512_storeu_ps(saux, vs_all);
|
||||
S += _mm512_reduce_add_ps(vs_all);
|
||||
//for (int j = 0; j < n; ++j) {
|
||||
// if (saux[j] < -18.f) continue; // ignore anything less than 1.5e-8 - it won't change the single precision result.
|
||||
// auto vs = _mm512_set1_ps(saux[j]);
|
||||
// auto vr = (const ggml_half *)(v + stride_v*j);
|
||||
// for (int i = 0; i < nq/16; ++i) {
|
||||
// acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]);
|
||||
// }
|
||||
//}
|
||||
__m512 vec_v[nq/16];
|
||||
for (int j = 0; j < n; ++j) {
|
||||
S += saux[j];
|
||||
auto vs = _mm512_set1_ps(saux[j]);
|
||||
if (saux[j] < -18.f) continue; // ignore anything less than 1.5e-8 - it won't change the single precision result.
|
||||
auto vr = (const ggml_half *)(v + stride_v*j);
|
||||
for (int i = 0; i < nq/16; ++i) {
|
||||
acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]);
|
||||
}
|
||||
for (int i = 0; i < nq/16; ++i) vec_v[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i));
|
||||
auto vs = _mm512_set1_ps(saux[j]);
|
||||
for (int i = 0; i < nq/16; ++i) acc[i] = _mm512_fmadd_ps(vs, vec_v[i], acc[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6420,8 +6427,14 @@ void flash_attn_T(int nk, // number of rows in k
|
||||
__m512 vq[nq/16];
|
||||
__m512 acc[nq/16];
|
||||
float saux[kNchunk];
|
||||
//float mv32[kNchunk];
|
||||
for (int i = 0; i < nq/16; ++i) vq[i] = _mm512_loadu_ps(q + 16*i);
|
||||
const ggml_half * mp = mask ? (const ggml_half *)mask : nullptr;
|
||||
//if (!mp) std::memset(mv32, 0, kNchunk*sizeof(float));
|
||||
//else {
|
||||
// auto vmask = _mm512_mul_ps(_mm512_set1_ps(slope), _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)mp)));
|
||||
// _mm512_storeu_ps(mv32, vmask);
|
||||
//}
|
||||
float M = -INFINITY;
|
||||
float S = 0;
|
||||
float smax = -INFINITY;
|
||||
@@ -6429,12 +6442,23 @@ void flash_attn_T(int nk, // number of rows in k
|
||||
int ik = 0;
|
||||
for (; ik < nk; ++ik) {
|
||||
if (ik - last_ik == kNchunk) {
|
||||
accumulate<nq>(kNchunk, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v);
|
||||
//accumulate<nq>(kNchunk, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v);
|
||||
if (smax != -INFINITY) {
|
||||
accumulate<nq>(kNchunk, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v);
|
||||
}
|
||||
last_ik = ik;
|
||||
smax = -INFINITY;
|
||||
//if (ik + kNchunk <= nk) {
|
||||
// auto vmask = _mm512_mul_ps(_mm512_set1_ps(slope), _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)mp)));
|
||||
// _mm512_storeu_ps(mv32, vmask);
|
||||
//}
|
||||
}
|
||||
//const float mv = mv32[ik - last_ik];
|
||||
const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ik]) : 0.0f;
|
||||
if (mv == -INFINITY) break;
|
||||
if (mv == -INFINITY) {
|
||||
saux[ik - last_ik] = -INFINITY;
|
||||
continue;
|
||||
}
|
||||
const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*ik);
|
||||
auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr)));
|
||||
for (int i = 1; i < nq/16; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum);
|
||||
@@ -6442,8 +6466,10 @@ void flash_attn_T(int nk, // number of rows in k
|
||||
saux[ik - last_ik] = s;
|
||||
smax = std::max(smax, s);
|
||||
}
|
||||
if (ik > last_ik) {
|
||||
accumulate<nq>(ik - last_ik, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v);
|
||||
int n_left = ik - last_ik;
|
||||
if (n_left > 0 & smax != -INFINITY) {
|
||||
for (int j = n_left; j < kNchunk; ++j) saux[j] = -INFINITY;
|
||||
accumulate<nq>(n_left, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v);
|
||||
}
|
||||
if (S > 0) {
|
||||
auto norm = _mm512_set1_ps(1/S);
|
||||
|
||||
Reference in New Issue
Block a user