Experimenting with flash attention on Zen4

This variant performs better for long contexts, but is still
not as fast as running without flash attention (FA).
This commit is contained in:
Iwan Kawrakow
2024-08-29 17:52:56 +03:00
parent a02e78d5c1
commit b5df88b120
3 changed files with 325 additions and 13 deletions

View File

@@ -16217,6 +16217,25 @@ static void ggml_compute_forward_flash_attn_ext_f16(
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
if (nr%nth == 0 && max_bias <= 0.0f && q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 && v->type == GGML_TYPE_F16 &&
mask && mask->type == GGML_TYPE_F16) {
int counter = 0;
for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
if (counter++ % nth == ith) {
iqk_flash_helper_3(D, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
(const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3]),
(const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]),
(const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]),
(const void *)((const char *)mask->data),
scale,
(float *)((char *) dst->data + (iq3*ne2*ne1 + iq2)*nb1)); // + iq1*ne1)*nb1))
}
}
}
return;
}
const uint32_t n_head = neq2;
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
@@ -16296,7 +16315,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
// }
//}
iqk_flash_helper_2(D, nek1, nbk1, nbv1,
iqk_flash_helper_2(max_bias > 0, D, nek1, nbk1, nbv1,
(const float *)((char *) q->data + iq1*nbq1 + iq2*nbq2 + iq3*nbq3),
(const void *)((char *) k->data + ik2*nbk2 + ik3*nbk3),
(const void *)((char *) v->data + iv2*nbv2 + iv3*nbv3),

View File

@@ -6377,6 +6377,7 @@ void iqk_flash_helper(int nq, // number of elements in q
}
namespace {
template <int nq>
inline void accumulate(int n, float * saux, float smax, float& M, float& S, __m512 * acc, const char * v, int stride_v) {
if (smax > M) {
@@ -6412,6 +6413,91 @@ inline void accumulate(int n, float * saux, float smax, float& M, float& S, __m5
}
}
// One step of the online ("streaming") softmax accumulation used by flash attention:
// given up to 16 raw K*Q dot products in saux, rescale the running accumulators
// when a new row maximum appears, then add exp(scale*s - M) * V_row for each score s.
//
// n        : number of valid entries in saux (<= 16); entries past n must be -INFINITY
// scale    : softmax scale applied to the raw dot products
// saux     : 16 raw scores on input; overwritten with the weights exp(scale*s - M)
// smax     : maximum raw (unscaled) score of this chunk
// M, S     : running maximum and running sum of exponentials (updated in place)
// acc      : nq/16 AVX-512 accumulators holding the running weighted sum of V rows
// v        : first V row of this chunk (fp16), stride_v bytes between rows
template <int nq>
inline void accumulate(int n, float scale, float * saux, float smax, float& M, float& S, __m512 * acc, const char * v, int stride_v) {
    smax *= scale;
    if (smax > M) {
        if (M > -INFINITY) {
            // New maximum: rescale previous accumulators and sum by exp(M_old - M_new)
            float ms = expf(M - smax);
            auto vms = _mm512_set1_ps(ms);
            for (int i = 0; i < nq/16; ++i) acc[i] = _mm512_mul_ps(vms, acc[i]);
            S *= ms;
        } else {
            // First non-masked chunk: start accumulation from zero
            for (int i = 0; i < nq/16; ++i) acc[i] = _mm512_setzero_ps();
            S = 0;
        }
        M = smax;
    }
    // saux now holds the softmax weights exp(scale*s - M); masked entries (-inf) become 0
    auto vs_all = v_expf(_mm512_fmsub_ps(_mm512_set1_ps(scale), _mm512_loadu_ps(saux), _mm512_set1_ps(M)));
    _mm512_storeu_ps(saux, vs_all);
    S += _mm512_reduce_add_ps(vs_all);
    for (int j = 0; j < n; ++j) {
        // Ignore weights below exp(-18) ~ 1.5e-8 - they won't change the single precision result.
        // (Fixed: the original compared the already-exponentiated weight, which is always >= 0,
        //  against -18.f, so the skip never fired.)
        if (saux[j] < 1.5e-8f) continue;
        auto vs = _mm512_set1_ps(saux[j]);
        auto vr = (const ggml_half *)(v + stride_v*j);
        for (int i = 0; i < nq/16; ++i) {
            acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]);
        }
    }
}
// Flash attention for a single query row without ALiBi (mask required):
// computes qkv = V^T * softmax(scale * K * q) using an online softmax over
// chunks of kNchunk = 16 K rows. nq is the head size (multiple of 16).
// K, V and the mask are fp16; -inf in the mask marks masked positions.
template <int nq>
void flash_attn_noalibi_T(int nk,            // number of rows in k
                          int stride_k,      // distance between rows in k (in bytes)
                          int stride_v,      // distance between rows in v (in bytes)
                          const float * q,   // q vector, nq elements
                          const void * k,    // k matrix. Assumed to be fp16, nq x nk elements
                          const void * v,    // v matrix. Assumed to be fp16, nq x nk elements
                          const void * mask, // mask. Assumed to be fp16, nk elements. Must not be null
                          float scale,       // softmax scale
                          float * qkv) {     // result, nq elements
    GGML_ASSERT(mask);
    constexpr int kNchunk = 16;
    __m512 vq[nq/16];
    __m512 acc[nq/16];
    float saux[kNchunk];
    // Load the whole q row into registers once; it is reused for every K row
    for (int i = 0; i < nq/16; ++i) vq[i] = _mm512_loadu_ps(q + 16*i);
    const ggml_half * mp = (const ggml_half *)mask;
    ggml_half h_inf = GGML_FP32_TO_FP16(-INFINITY);
    float M = -INFINITY;     // running maximum of the scaled scores
    float S = 0;             // running sum of exp(scale*s - M)
    float smax = -INFINITY;  // maximum raw score within the current chunk
    int last_ik = 0;         // first K row of the current chunk
    int ik = 0;
    for (; ik < nk; ++ik) {
        if (ik - last_ik == kNchunk) {
            // Chunk complete -> fold it into the accumulators.
            // smax == -INFINITY means every row of the chunk was masked; skip it.
            if (smax != -INFINITY) {
                accumulate<nq>(kNchunk, scale, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v);
            }
            last_ik = ik;
            smax = -INFINITY;
        }
        if (mp[ik] == h_inf) {
            saux[ik - last_ik] = -INFINITY;
            continue;
        }
        // Raw dot product of q with K row ik
        const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*ik);
        auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr)));
        for (int i = 1; i < nq/16; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum);
        float s = _mm512_reduce_add_ps(sum);
        saux[ik - last_ik] = s;
        smax = std::max(smax, s);
    }
    // Fold in the final, possibly partial, chunk
    int n_left = ik - last_ik;
    if (n_left > 0 && smax != -INFINITY) {
        for (int j = n_left; j < kNchunk; ++j) saux[j] = -INFINITY;
        accumulate<nq>(n_left, scale, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v);
    }
    if (S > 0) {
        auto norm = _mm512_set1_ps(1/S);
        for (int i = 0; i < nq/16; ++i) _mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, acc[i]));
    } else {
        // Should not happen unless every K row was masked
        printf("Oops: S = %g. ik = %d, last_ik = %d, nk = %d, M = %g\n", S, ik, last_ik, nk, M);
        GGML_ASSERT(false);
        std::memset(qkv, 0, nq*sizeof(float)); // fixed: was hardcoded 128, wrong for nq != 128
    }
}
template <int nq>
void flash_attn_T(int nk, // number of rows in k
int stride_k, // distance between rows in k (in bytes)
@@ -6467,7 +6553,7 @@ void flash_attn_T(int nk, // number of rows in k
smax = std::max(smax, s);
}
int n_left = ik - last_ik;
if (n_left > 0 & smax != -INFINITY) {
if (n_left > 0 && smax != -INFINITY) {
for (int j = n_left; j < kNchunk; ++j) saux[j] = -INFINITY;
accumulate<nq>(n_left, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v);
}
@@ -6482,7 +6568,8 @@ void flash_attn_T(int nk, // number of rows in k
}
}
void iqk_flash_helper_2(int nq, // number of elements in q
void iqk_flash_helper_2(bool is_alibi,
int nq, // number of elements in q
int nk, // number of rows in k
int stride_k, // distance between rows in k (in bytes)
int stride_v,
@@ -6496,15 +6583,26 @@ void iqk_flash_helper_2(int nq, // number of elements in q
float * qkv) {
GGML_ASSERT(nq % 4 == 0);
switch (nq) {
case 64: flash_attn_T< 64>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return;
case 80: flash_attn_T< 80>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return;
case 96: flash_attn_T< 96>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return;
case 112: flash_attn_T<112>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return;
case 128: flash_attn_T<128>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return;
case 256: flash_attn_T<256>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return;
default: break;
//default: GGML_ABORT("unhandled head size -> fatal error");
if (is_alibi) {
switch (nq) {
case 64: flash_attn_T< 64>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return;
case 80: flash_attn_T< 80>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return;
case 96: flash_attn_T< 96>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return;
case 112: flash_attn_T<112>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return;
case 128: flash_attn_T<128>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return;
case 256: flash_attn_T<256>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return;
default: break;
}
} else {
switch (nq) {
case 64: flash_attn_noalibi_T< 64>(nk, stride_k, stride_v, q, k, v, mask, scale, qkv); return;
case 80: flash_attn_noalibi_T< 80>(nk, stride_k, stride_v, q, k, v, mask, scale, qkv); return;
case 96: flash_attn_noalibi_T< 96>(nk, stride_k, stride_v, q, k, v, mask, scale, qkv); return;
case 112: flash_attn_noalibi_T<112>(nk, stride_k, stride_v, q, k, v, mask, scale, qkv); return;
case 128: flash_attn_noalibi_T<128>(nk, stride_k, stride_v, q, k, v, mask, scale, qkv); return;
case 256: flash_attn_noalibi_T<256>(nk, stride_k, stride_v, q, k, v, mask, scale, qkv); return;
default: break;
}
}
if (mask) {
@@ -6658,6 +6756,185 @@ bool iqk_soft_max_noalibi(int nc, int ir0, int ir1, int ne00, int ne01,
return true;
}
// Tiled flash attention over a whole block of query rows:
// computes qkv = V^T * softmax(scale * K * Q) for nq1 query rows against nk1 K/V
// rows, processing tiles of 16 q rows x 16 k rows with an online softmax
// (per-row running maximum M[j] and running sum of exponentials S[j]).
// No ALiBi; the fp16 mask is mandatory, with -INFINITY marking masked positions.
// NOTE(review): the active path below uses 8 __m512 accumulators (= 128 floats)
// and memset's 128 floats per row, i.e. it assumes ne00 == 128 — confirm callers
// only reach this with head size 128.
void iqk_flash_helper_3(int ne00,          // head size: number of floats per q/k/v row
                        int nq1,           // number of q rows; the tiled path assumes nq1%16 == 0
                        int nk1,           // number of rows in k; the tiled path assumes nk1%16 == 0
                        int stride_q,      // distance between rows in q (in bytes; converted to floats below)
                        int stride_k,      // distance between rows in k (in bytes)
                        int stride_v,      // distance between rows in v (in bytes)
                        int stride_m,      // distance between rows in mask (in bytes)
                        int stride_qkv,    // distance between rows in qkv (in floats)
                        const float * q,   // q matrix, fp32
                        const void * k,    // k matrix. Assumed to be fp16
                        const void * v,    // v matrix. Assumed to be fp16
                        const void * mask, // mask. Assumed to be fp16; one row of nk1 elements per q row
                        float scale,       // softmax scale
                        float * qkv) {     // result: nq1 rows of ne00 floats
    stride_q /= sizeof(float); // bytes -> floats, q is indexed as a float pointer below
    // Reference row-by-row implementation via iqk_flash_helper_2, kept for debugging:
    //for (int iq1 = 0; iq1 < nq1; ++iq1) {
    //    iqk_flash_helper_2(false,
    //            ne00,
    //            nk1,
    //            stride_k,
    //            stride_v,
    //            q,
    //            k,
    //            v,
    //            (const void *)((const char *)mask + iq1*stride_m),
    //            scale,
    //            1.0f,
    //            nullptr,
    //            qkv);
    //    q   += stride_q;
    //    qkv += stride_qkv;
    //}
    float cache[256];   // one 16x16 tile of scores, later overwritten with softmax weights
    float S[16], M[16]; // per-q-row running softmax sum and maximum
    __m512 vk[8];       // 8 x 16 = 128 floats: one K row, then reused as one accumulator row
    for (int i1 = 0; i1 < nq1/16; ++i1) {
        // Reset the 16 output rows and softmax state for this block of q rows
        for (int j = 0; j < 16; ++j) {
            auto R = qkv + (16*i1 + j)*stride_qkv;
            std::memset(R, 0, 128*sizeof(float)); // NOTE(review): assumes ne00 == 128
            S[j] = 0; M[j] = -INFINITY;
        }
        for (int k1 = 0; k1 < nk1/16; ++k1) {
            // Pass 1: compute the 16x16 tile of scaled Q*K scores into cache
            for (int l1 = 0; l1 < 16; ++l1) {
                auto kr = (const ggml_half *)((const char *)k + (16*k1 + l1)*stride_k);
                for (int i = 0; i < 8; ++i) vk[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i));
                for (int m1 = 0; m1 < 16; ++m1) {
                    // q index is 16*i1 + m1
                    // k index is 16*k1 + l1
                    const ggml_half * mp = (const ggml_half *)((const char *)mask + stride_m*(16*i1 + m1)) + 16*k1;
                    if (GGML_FP16_TO_FP32(mp[l1]) == -INFINITY) {
                        cache[16*m1 + l1] = -INFINITY; // masked position
                        continue;
                    }
                    auto qr = q + (16*i1 + m1)*stride_q;
                    auto vsum = _mm512_mul_ps(vk[0], _mm512_loadu_ps(qr));
                    for (int i = 1; i < 8; ++i) vsum = _mm512_fmadd_ps(vk[i], _mm512_loadu_ps(qr + 16*i), vsum);
                    cache[16*m1 + l1] = scale*_mm512_reduce_add_ps(vsum); // scale applied here, not in the exp below
                }
            }
            // Pass 2: online softmax update + V accumulation, one q row at a time
            for (int j = 0; j < 16; ++j) {
                auto R = qkv + (16*i1 + j)*stride_qkv;
                auto val = _mm512_loadu_ps(cache + 16*j);
                auto smax = _mm512_reduce_max_ps(val);
                // Keep the accumulator row in registers (vk is reused here)
                for (int i = 0; i < 8; ++i) vk[i] = _mm512_loadu_ps(R + 16*i);
                if (smax > M[j]) {
                    if (M[j] > -INFINITY) {
                        // New maximum: rescale accumulators and running sum by exp(M_old - M_new)
                        float m = expf(M[j] - smax);
                        auto vm = _mm512_set1_ps(m);
                        for (int i = 0; i < 8; ++i) {
                            vk[i] = _mm512_mul_ps(vm, vk[i]);
                            //auto r = _mm512_loadu_ps(R + 16*i);
                            //_mm512_storeu_ps(R + 16*i, _mm512_mul_ps(vm, r));
                        }
                        S[j] *= m;
                    } else {
                        // First non-masked tile for this row: start from zero
                        for (int i = 0; i < 8; ++i) vk[i] = _mm512_setzero_ps();
                        //std::memset(R, 0, 128*sizeof(float));
                        S[j] = 0;
                    }
                    M[j] = smax;
                }
                val = v_expf(_mm512_sub_ps(val, _mm512_set1_ps(M[j]))); // cache already holds scaled scores
                S[j] += _mm512_reduce_add_ps(val);
                _mm512_storeu_ps(cache + 16*j, val);
                for (int l1 = 0; l1 < 16; ++l1) {
                    // NOTE(review): cache now holds exp(...) >= 0, so this "< -20" skip can
                    // never fire — presumably meant to skip weights below exp(-20); confirm.
                    if (cache[16*j + l1] < -20.0f) continue;
                    auto vr = (const ggml_half *)((const char *)v + (16*k1 + l1)*stride_v);
                    auto vs = _mm512_set1_ps(cache[16*j + l1]);
                    for (int i = 0; i < 8; ++i) {
                        vk[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i)), vk[i]);
                    }
                }
                for (int i = 0; i < 8; ++i) _mm512_storeu_ps(R + 16*i, vk[i]);
            }
        }
        // Final normalization: divide each accumulated row by its softmax sum
        for (int j = 0; j < 16; ++j) {
            auto R = qkv + (16*i1 + j)*stride_qkv;
            GGML_ASSERT(S[j] > 0); // NOTE(review): aborts on S[j] == 0, making the else below unreachable
            if (S[j] > 0) {
                auto norm = _mm512_set1_ps(1/S[j]);
                for (int i = 0; i < 8; ++i) {
                    auto r = _mm512_loadu_ps(R + 16*i);
                    _mm512_storeu_ps(R + 16*i, _mm512_mul_ps(norm, r));
                }
            } else {
                std::memset(R, 0, 128*sizeof(float));
            }
        }
    }
    return;
    // ----------------------------------------------------------------------
    // Everything below is an earlier experimental variant kept for reference;
    // it is unreachable because of the return above.
    // ----------------------------------------------------------------------
    if (nq1%16 != 0 || nk1%16 != 0) printf("Oops(%s): nq1 = %d, nk1 = %d\n", __func__, nq1, nk1);
    //GGML_ASSERT(nq1%16 == 0 && nk1%16 == 0);
    auto vinf = _mm512_set1_ps(-INFINITY);
    for (int i1 = 0; i1 < nq1/16; ++i1) {
        //int iq1 = 16*i1;
        for (int j1 = 0; j1 < 16; ++j1) {
            S[j1] = 0; M[j1] = -INFINITY;
            // NOTE(review): uses stride_v (bytes) on a float*, not stride_qkv — looks
            // wrong, but this code is dead.
            std::memset(qkv + j1*stride_v, 0, ne00*sizeof(float));
        }
        for (int ik = 0; ik < nk1; ik += 16) {
            /////////////////////////////////////////////////////////////////////////////////
            // Compute 16 q rows x 16 k rows of scores at once via the matmul kernel
            const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*ik);
            DataInfo info{cache, (const char *)q, 16*sizeof(float), size_t(stride_q)*sizeof(float), 0, 0, nullptr, 0};
            mul_mat_fX_fY_T<4, ggml_half, float>(ne00, (const void *)kr, stride_k, info, 16);
            /////////////////////////////////////////////////////////////////////////////////
            float * R = qkv;
            for (int j1 = 0; j1 < 16; ++j1) {
                int iq1 = 16*i1 + j1;
                float * C = cache + 16*j1;
                auto qk = _mm512_loadu_ps(C);
                // Blend -inf into masked lanes; assumes non-masked fp16 mask values are exactly 0
                const ggml_half * mp = (const ggml_half *)((const char *)mask + stride_m*iq1);
                auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i*)mp), _mm256_setzero_si256());
                qk = _mm512_mask_blend_ps(m16, vinf, qk);
                float smax = _mm512_reduce_max_ps(qk);
                if (smax > M[j1]) {
                    if (M[j1] > -INFINITY) {
                        float m = expf(M[j1] - smax);
                        auto ms = _mm512_set1_ps(m);
                        for (int i = 0; i < ne00/16; ++i) _mm512_storeu_ps(R + 16*i, _mm512_mul_ps(ms, _mm512_loadu_ps(R + 16*i)));
                        S[j1] *= m;
                    } else {
                        std::memset(R, 0, ne00*sizeof(float));
                        S[j1] = 0;
                    }
                    M[j1] = smax;
                }
                auto vs = v_expf(_mm512_sub_ps(qk, _mm512_set1_ps(M[j1])));
                S[j1] += _mm512_reduce_add_ps(vs);
                _mm512_storeu_ps(C, vs);
                for (int jk = 0; jk < 16; ++jk) {
                    vs = _mm512_set1_ps(C[jk]);
                    const ggml_half * vr = (const ggml_half *)((const char *)v + stride_v*(ik + jk));
                    // NOTE(review): accumulates into qkv (row 0), not R — presumably a bug; dead code.
                    for (int i = 0; i < ne00/16; ++i) {
                        auto v = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)); // shadows parameter v
                        auto r = _mm512_loadu_ps(qkv + 16*i);
                        _mm512_storeu_ps(qkv + 16*i, _mm512_fmadd_ps(vs, v, r));
                    }
                }
                R += stride_qkv;
            }
        }
        for (int j1 = 0; j1 < 16; ++j1) {
            if (S[j1] > 0) {
                //GGML_ASSERT(S[j1] > 0);
                auto norm = _mm512_set1_ps(1/S[j1]);
                for (int i = 0; i < ne00/16; ++i) {
                    auto r = _mm512_loadu_ps(qkv + 16*i);
                    _mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, r));
                }
            }
            qkv += stride_qkv;
        }
        q += 16*stride_q;
    }
}
#else // IQK_IMPLEMENT
bool iqk_mul_mat(int, long, long, long, int, const void *, long, int, const void *, long, float *, long, int, int) {

View File

@@ -44,7 +44,8 @@ void iqk_flash_helper(int nq, // number of elements in q
float slope,
float * qk); // softmax(k*q) - k elements
void iqk_flash_helper_2(int nq, // number of elements in q
void iqk_flash_helper_2(bool is_alibi,
int nq, // number of elements in q
int nk, // number of rows in k
int stride_k, // distance between rows in k (in bytes)
int stride_v, // distance between rows in k (in bytes)
@@ -57,6 +58,21 @@ void iqk_flash_helper_2(int nq, // number of elements in q
float * qk,
float * qkv); // softmax(k*q) - k elements
// Tiled flash attention: qkv = v*softmax(scale*k*q) for a block of q rows.
void iqk_flash_helper_3(int ne00,          // head size: number of floats per q/k/v row
                        int nq,            // number of rows in q
                        int nk,            // number of rows in k
                        int stride_q,      // distance between rows in q (in bytes)
                        int stride_k,      // distance between rows in k (in bytes)
                        int stride_v,      // distance between rows in v (in bytes)
                        int stride_m,      // distance between rows in mask (in bytes)
                        int stride_qkv,    // distance between rows in qkv (in floats)
                        const float * q,   // q matrix, fp32
                        const void * k,    // k matrix. Assumed to be fp16
                        const void * v,    // v matrix. Assumed to be fp16
                        const void * mask, // mask. Assumed to be fp16; nk elements per q row
                        float scale,       // softmax scale
                        float * qkv);      // v*softmax(scale*k*q)
#ifdef __cplusplus
}
#endif