Flash attention: templated implementation

Needed to model different head sizes for different
LLMs, batch sizes that are not a multiple of 8, etc.

I see 2-3% performance degradation.

It is one of those things
that I don't understand, but really would like to:

I have an implementation of a function that depends on a compile-time
constant. I get performance X.
I then turn the implementation into a template, where the former
compile time constant is a template parameter, and I instantiate the template
for a bunch of different values, one of which is the former compile
time constant. I observe performance c*X, where c almost always is
less than 1, and depending on how unlucky we get, it can be as low
as 0.5 or somesuch. But in my simple-minded understanding, I expect
the template instantiation with the former compile time constant
to turn into the exact same function as the former non-templated
implementation, and so I expect the exact same performance.

i.e., if I have some function
void some_function(...) {
    constexpr int N = 128;
    ... // code that depends on N
}

and I now write
template <int N>
void some_function_T(...) {
    ... // same code as in some_function() that depends on N
}

and I say
void wrapper_function(int N) {
    switch (N) {
        case  64: some_function_T< 64>(); break;
        case 128: some_function_T<128>(); break;
        ...
    }
}
I expect wrapper_function(128) to have the exact same performance as
some_function() (run time of some_function() is long enough to have the
additional function call overhead be completely negligible).
This is the reason I'm using a template in the first place instead
of just having void some_function(int N).

But no. Tough luck.
This commit is contained in:
Iwan Kawrakow
2024-08-31 13:10:36 +03:00
parent 6d9510c680
commit 1b834ac6e4
3 changed files with 510 additions and 258 deletions

View File

@@ -16226,8 +16226,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
if (nth%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
else if (nth%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
else if (nth%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
//if (nth%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
//else if (nth%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
if ((neq2*neq3)%(nth/ntg) == 0) {
//if (ith == 0) printf("%s: D = %d, neq2 = %d, neq1 = %d, nek1 = %d\n", __func__, (int)D, (int)neq2, (int)neq1, (int)nek1);
int counter = 0;
@@ -16235,32 +16233,19 @@ static void ggml_compute_forward_flash_attn_ext_f16(
for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
if (counter++ % (nth/ntg) == ith/ntg) {
int iq1 = (ith%ntg)*neq1/ntg;
iqk_flash_helper_3(D, neq1/ntg, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
if (!iqk_flash_helper_3(D, neq1/ntg, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
(const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]),
(const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]),
(const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]),
(const void *)((const char *)mask->data + iq1*mask->nb[1]),
scale,
(float *)((char *) dst->data + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1));
(float *)((char *) dst->data + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1))) goto IQK_Flash_Attn_NotAvailable;
}
}
}
return;
}
//for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
// for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
// if (counter++ % nth == ith) {
// iqk_flash_helper_3(D, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
// (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3]),
// (const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]),
// (const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]),
// (const void *)((const char *)mask->data),
// scale,
// //(float *)params->wdata + ith*8*nek1,
// (float *)((char *) dst->data + (iq3*ne2*ne1 + iq2)*nb1)); // + iq1*ne1)*nb1))
// }
// }
//}
IQK_Flash_Attn_NotAvailable:;
}
const uint32_t n_head = neq2;

View File

@@ -3259,23 +3259,23 @@ IQK_NOINLINE void mul_mat_Qx_Qy_1xN(int n, const char * cx, size_t bx, int ix0,
Qy y(info);
Qx x(cx + ix0*bx, bx);
QFBase::Acc acc[Qx::nrc];
if (nb <= Qx::nrc) {
QFBase::Data yv[Qx::nrc];
for (int i = 0; i < nb; ++i) yv[i] = y.load1(0, i);
//for (int ix = 0; ix < Qx::nrc; ++ix) {
// auto sum = QFBase::acc_first(yv[0], x.load1(ix, 0));
// for (int i = 1; i < nb; ++i) {
// sum = QFBase::acc(sum, yv[i], x.load1(ix, i));
// }
// info.store(ix0+ix, 0, QFBase::hsum(sum));
//}
for (int ix = 0; ix < Qx::nrc; ++ix) acc[ix] = QFBase::acc_first(yv[0], x.load1(ix, 0));
for (int i = 1; i < nb; ++i) {
for (int ix = 0; ix < Qx::nrc; ++ix) acc[ix] = QFBase::acc(acc[ix], yv[i], x.load1(ix, i));
}
for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, 0, QFBase::hsum(acc[ix]));
return;
}
//if (nb <= Qx::nrc) {
// QFBase::Data yv[Qx::nrc];
// for (int i = 0; i < nb; ++i) yv[i] = y.load1(0, i);
// //for (int ix = 0; ix < Qx::nrc; ++ix) {
// // auto sum = QFBase::acc_first(yv[0], x.load1(ix, 0));
// // for (int i = 1; i < nb; ++i) {
// // sum = QFBase::acc(sum, yv[i], x.load1(ix, i));
// // }
// // info.store(ix0+ix, 0, QFBase::hsum(sum));
// //}
// for (int ix = 0; ix < Qx::nrc; ++ix) acc[ix] = QFBase::acc_first(yv[0], x.load1(ix, 0));
// for (int i = 1; i < nb; ++i) {
// for (int ix = 0; ix < Qx::nrc; ++ix) acc[ix] = QFBase::acc(acc[ix], yv[i], x.load1(ix, i));
// }
// for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, 0, QFBase::hsum(acc[ix]));
// return;
//}
QFBase::Data xv[Qx::nrc];
auto yv = y.load1(0, 0);
for (int ix = 0; ix < Qx::nrc; ++ix) {
@@ -6605,14 +6605,14 @@ void iqk_flash_helper_2(bool is_alibi,
}
}
if (mask) {
const ggml_half * mp = (const ggml_half *)mask;
for (int i = 0; i < nk; ++i) {
if (GGML_FP16_TO_FP32(mp[i]) == -INFINITY) {
nk = i; break;
}
}
}
//if (mask) {
// const ggml_half * mp = (const ggml_half *)mask;
// for (int i = 0; i < nk; ++i) {
// if (GGML_FP16_TO_FP32(mp[i]) == -INFINITY) {
// nk = i; break;
// }
// }
//}
DataInfo info{qk, (const char*)q, 0, size_t(stride_k), 0, 1, nullptr, 0};
@@ -6756,6 +6756,7 @@ bool iqk_soft_max_noalibi(int nc, int ir0, int ir1, int ne00, int ne01,
return true;
}
namespace {
template <int nrc_y, typename FloatX, typename FloatY>
void mul_mat_fX_fY_fa(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n%QFBase::k_step == 0);
@@ -6783,241 +6784,507 @@ void mul_mat_fX_fY_fa(int n, const void * vx, size_t bx, const DataInfo& info, i
}
}
void iqk_flash_helper_3(int ne00,
int nq1, // number of elements in q
int nk1, // number of rows in k
int stride_q,
int stride_k, // distance between rows in k (in bytes)
int stride_v, // distance between rows in v (in bytes)
int stride_m, // distance between rows in mask (in bytes)
int stride_qkv, // distance between rows in mask (in bytes)
const float * q, // q vector
const void * k, // k matrix. Assumed to be fp16, nq x nk elements
const void * v,
const void * mask, // mask. If not null, assumed to be fp16. nk elements
float scale,
float * qkv) {
constexpr int q_step = 8;
constexpr int k_step = 32; //16;
if (nq1%q_step != 0 || nk1%k_step != 0) {
for (int iq1 = 0; iq1 < nq1; ++iq1) {
iqk_flash_helper_2(false, ne00, nk1, stride_k, stride_v,
q, k, v, (const void *)((const char *)mask + iq1*stride_m),
scale, 1.0f, nullptr, qkv);
q += stride_q;
qkv += stride_qkv;
}
return;
}
stride_q /= sizeof(float);
const ggml_half h_inf = GGML_FP32_TO_FP16(-INFINITY);
float cache[q_step*k_step];
float S[q_step], M[q_step];
__m512 vk[16];
__m512 vms[q_step];
__m512 vals[k_step/16];
float qkv_cache[128*q_step];
int need_scaling[q_step];
auto vscale = _mm512_set1_ps(scale);
auto vinf = _mm512_set1_ps(-INFINITY);
for (int i1 = 0; i1 < nq1/q_step; ++i1) {
template <int D, int q_step, int k_step>
struct FlashAttn {
static_assert(D%16 == 0 && D <= 256);
static_assert(k_step%16 == 0);
static_assert(q_step <= 4 || q_step%4 == 0);
constexpr static int vk_size = D <= 128 ? D/8 : D/16;
static_assert(q_step <= vk_size);
FlashAttn(float scale) : vscale(_mm512_set1_ps(scale)), h_inf(GGML_FP32_TO_FP16(-INFINITY)) {}
inline void init_qstep() {
for (int j = 0; j < q_step; ++j) {
S[j] = 0; M[j] = -INFINITY;
}
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
// This is slower
//DataInfo info{cache, (const char *)(q + q_step*i1*stride_q), k_step*sizeof(float), stride_q*sizeof(float), 0, 0, nullptr, 0};
//mul_mat_fX_fY_T<q_step/2, ggml_half, float>(ne00, (const void *)((const char *)k + k_step*k1*stride_k), stride_k, info, k_step);
//info.cur_y += q_step/2;
//mul_mat_fX_fY_T<q_step/2, ggml_half, float>(ne00, (const void *)((const char *)k + k_step*k1*stride_k), stride_k, info, k_step);
//for (int j = 0; j < q_step; ++j) {
// const ggml_half * mp = (const ggml_half *)((const char *)mask + stride_m*q_step*i1) + k_step*k1;
// for (int l = 0; l < k_step/16; ++l) {
// auto val = _mm512_loadu_ps(cache + k_step*j + 16*l);
// auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i *)mp), _mm256_setzero_si256());
// vals[l] = _mm512_mask_mul_ps(vinf, m16, vscale, val);
// }
// auto smax = _mm512_reduce_max_ps(_mm512_max_ps(vals[0], vals[1]));
// need_scaling[j] = 0;
// if (smax > M[j]) {
// if (M[j] > -INFINITY) {
// float m = expf(M[j] - smax);
// vms[j] = _mm512_set1_ps(m);
// need_scaling[j] = 1;
// S[j] *= m;
// } else {
// need_scaling[j] = 2;
// S[j] = 0;
// }
// M[j] = smax;
// }
// auto vm = _mm512_set1_ps(M[j]);
// for (int l = 0; l < k_step/16; ++l) {
// vals[l] = v_expf(_mm512_sub_ps(vals[l], vm));
// _mm512_storeu_ps(cache + k_step*j + 16*l, vals[l]);
// }
// S[j] += _mm512_reduce_add_ps(_mm512_add_ps(vals[0], vals[1]));
//}
}
for (int l1 = 0; l1 < k_step; l1 += 2) {
auto kr1 = (const ggml_half *)((const char *)k + (k_step*k1 + l1 + 0)*stride_k);
auto kr2 = (const ggml_half *)((const char *)k + (k_step*k1 + l1 + 1)*stride_k);
for (int i = 0; i < 8; ++i) vk[i+0] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr1 + i));
for (int i = 0; i < 8; ++i) vk[i+8] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr2 + i));
for (int m1 = 0; m1 < q_step; ++m1) {
// q index is q_step*i1 + m1
// k index is k_step*k1 + l1
const ggml_half * mp = (const ggml_half *)((const char *)mask + stride_m*(q_step*i1 + m1)) + k_step*k1;
cache[k_step*m1 + l1 + 0] = cache[k_step*m1 + l1 + 1] = -INFINITY;
if (mp[l1+0] == h_inf && mp[l1+1] == h_inf) {
continue;
}
__m512 qv[8];
auto qr = q + (q_step*i1 + m1)*stride_q;
for (int i = 0; i < 8; ++i) qv[i] = _mm512_loadu_ps(qr + 16*i);
if (mp[l1+0] != h_inf) {
auto vsum = _mm512_mul_ps(vk[0], qv[0]);
for (int i = 1; i < 8; ++i) vsum = _mm512_fmadd_ps(vk[i], qv[i], vsum);
cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum);
}
if (mp[l1+1] != h_inf) {
auto vsum = _mm512_mul_ps(vk[8], qv[0]);
for (int i = 1; i < 8; ++i) vsum = _mm512_fmadd_ps(vk[i+8], qv[i], vsum);
cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum);
inline void multiply_mask_kq(int nq1, int stride_k, int stride_q, int stride_m,
const char * k, const float * q, const char * mask);
inline void accumulate_qkv(int nq1, int stride_v, const char * v);
inline void normalize_and_store(int nq1, int stride_qkv, float * qkv) const {
auto R = qkv_cache;
for (int j = 0; j < nq1; ++j) {
GGML_ASSERT(S[j] > 0);
auto norm = _mm512_set1_ps(1/S[j]);
for (int i = 0; i < D/16; ++i) {
auto r = _mm512_loadu_ps(R + 16*i);
_mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, r));
}
qkv += stride_qkv;
R += D;
}
}
inline void multiply_mask_kq(int stride_k, int stride_q, int stride_m,
const char * k, const float * q, const char * mask);
inline void accumulate_qkv(int stride_v, const char * v);
inline void normalize_and_store(int stride_qkv, float * qkv) const {
auto R = qkv_cache;
for (int j = 0; j < q_step; ++j) {
GGML_ASSERT(S[j] > 0);
auto norm = _mm512_set1_ps(1/S[j]);
for (int i = 0; i < D/16; ++i) {
auto r = _mm512_loadu_ps(R + 16*i);
_mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, r));
}
qkv += stride_qkv;
R += D;
}
}
void compute(int nq1, int nk1, int stride_k, int stride_q, int stride_m, int stride_v, int stride_qkv,
const char * k, const float * q, const char * mask, const char * v, float * qkv) {
for (int i1 = 0; i1 < nq1/q_step; ++i1) {
init_qstep();
auto kr = k;
auto vr = v;
auto mr = mask;
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
multiply_mask_kq(stride_k, stride_q, stride_m, kr, q, mr);
accumulate_qkv(stride_v, vr);
kr += k_step*stride_k;
vr += k_step*stride_v;
mr += k_step*sizeof(ggml_half);
}
normalize_and_store(stride_qkv, qkv);
q += q_step*stride_q;
mask += q_step*stride_m;
qkv += q_step*stride_qkv;
}
int n_left = nq1 - q_step*(nq1/q_step);
if (n_left > 0) {
init_qstep();
auto kr = k;
auto vr = v;
auto mr = mask;
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
multiply_mask_kq(n_left, stride_k, stride_q, stride_m, kr, q, mr);
accumulate_qkv(n_left, stride_v, vr);
kr += k_step*stride_k;
vr += k_step*stride_v;
mr += k_step*sizeof(ggml_half);
}
normalize_and_store(n_left, stride_qkv, qkv);
}
}
float cache[q_step*k_step];
float qkv_cache[D*q_step];
float S[q_step], M[q_step];
int need_scaling[q_step];
__m512 vms[q_step];
__m512 vk[vk_size];
const __m512 vscale;
const ggml_half h_inf;
typedef __m512 (*combine_t)(__m512, __m512);
typedef float (*reduce_t)(__m512);
template <reduce_t Op, combine_t Op_combine>
static inline float reduce_T(const __m512 * vals) {
float result;
if constexpr (k_step/16 == 1) {
result = Op(vals[0]);
}
else if constexpr (k_step/16 == 2) {
result = Op(Op_combine(vals[0], vals[1]));
}
else {
auto vmax = vals[0];
for (int l = 1; l < k_step/16; ++l) vmax = Op_combine(vmax, vals[l]);
result = Op(vmax);
}
return result;
}
};
template <int D, int q_step, int k_step>
void FlashAttn<D, q_step, k_step>::multiply_mask_kq(int stride_k, int stride_q, int stride_m,
        const char * k, const float * q, const char * mask) {
    // Fills cache[j*k_step + l] with exp(scale*K[l]*Q[j] - M[j]) for q_step query rows
    // against k_step K rows, and updates the running softmax state per query row
    // (max M[j], partial sum S[j], and the rescale flags/factors used by accumulate_qkv()).
    if constexpr (D <= 128) {
        // Small head sizes: hand-rolled fp16 dot products, two K rows at a time
        // (vk_size = D/8 gives exactly enough registers for two converted K rows).
        __m512 qv[D/16];
        for (int l1 = 0; l1 < k_step; l1 += 2) {
            auto kr1 = (const ggml_half *)(k + (l1 + 0)*stride_k);
            auto kr2 = (const ggml_half *)(k + (l1 + 1)*stride_k);
            for (int i = 0; i < D/16; ++i) vk[i+   0] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr1 + i));
            for (int i = 0; i < D/16; ++i) vk[i+D/16] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr2 + i));
            for (int m1 = 0; m1 < q_step; ++m1) {
                const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
                cache[k_step*m1 + l1 + 0] = cache[k_step*m1 + l1 + 1] = -INFINITY;
                if (mp[l1+0] == h_inf && mp[l1+1] == h_inf) {
                    continue; // both K rows masked out for this query row
                }
                auto qr = q + m1*stride_q;
                for (int i = 0; i < D/16; ++i) qv[i] = _mm512_loadu_ps(qr + 16*i);
                if (mp[l1+0] != h_inf) {
                    auto vsum = _mm512_mul_ps(vk[0], qv[0]);
                    // Bound is D/16 (was hardcoded 8, which for D < 128 read past the end
                    // of qv[] and mixed in the second K row's vk registers).
                    for (int i = 1; i < D/16; ++i) vsum = _mm512_fmadd_ps(vk[i], qv[i], vsum);
                    cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum);
                }
                if (mp[l1+1] != h_inf) {
                    auto vsum = _mm512_mul_ps(vk[D/16], qv[0]);
                    for (int i = 1; i < D/16; ++i) vsum = _mm512_fmadd_ps(vk[i+D/16], qv[i], vsum);
                    cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum);
                }
            }
        }
    }
    else {
        // Large head sizes: defer to the generic fp16 x fp32 matrix multiply,
        // four query rows at a time, with a tail switch for q_step % 4 leftovers.
        DataInfo info{cache, (const char *)q, k_step*sizeof(float), stride_q*sizeof(float), 0, 0, nullptr, 0};
        for (int i = 0; i < q_step/4; ++i) {
            mul_mat_fX_fY_T<4, ggml_half, float>(D, (const void *)k, stride_k, info, k_step);
            info.cur_y += 4;
        }
        int n_left = q_step - 4*(q_step/4);
        if (n_left > 0) {
            switch (n_left) {
                case 1: mul_mat_fX_fY_T<1, ggml_half, float>(D, (const void *)k, stride_k, info, k_step); break;
                case 2: mul_mat_fX_fY_T<2, ggml_half, float>(D, (const void *)k, stride_k, info, k_step); break;
                case 3: mul_mat_fX_fY_T<3, ggml_half, float>(D, (const void *)k, stride_k, info, k_step); break;
                default: break;
            }
        }
    }
    // Scale the raw scores (and apply the mask in the D > 128 path, where the matrix
    // multiply above did not see it), then update the running softmax per query row.
    for (int j = 0; j < q_step; ++j) {
        if constexpr (D <= 128) {
            for (int l = 0; l < k_step/16; ++l) vk[l] = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16*l));
        } else {
            auto vinf = _mm512_set1_ps(-INFINITY);
            // NOTE(review): mp is not advanced by stride_m*j per query row, and the same
            // first 16 mask values are re-used for every 16-wide chunk l. Confirm this is
            // intended for the D > 128 path.
            const ggml_half * mp = (const ggml_half *)mask;
            for (int l = 0; l < k_step/16; ++l) {
                auto val = _mm512_loadu_ps(cache + k_step*j + 16*l);
                auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i *)mp), _mm256_setzero_si256());
                vk[l] = _mm512_mask_mul_ps(vinf, m16, vscale, val);
            }
        }
        float smax = reduce_T<_mm512_reduce_max_ps, _mm512_max_ps>(vk);
        need_scaling[j] = 0;
        if (smax > M[j]) {
            if (M[j] > -INFINITY) {
                float m = expf(M[j] - smax);
                vms[j] = _mm512_set1_ps(m);
                need_scaling[j] = 1; // previously accumulated qkv must be rescaled by exp(old_max - new_max)
                S[j] *= m;
            } else {
                need_scaling[j] = 2; // first unmasked contribution: accumulator will be zeroed
                S[j] = 0;
            }
            M[j] = smax;
        }
        auto vm = _mm512_set1_ps(M[j]);
        for (int l = 0; l < k_step/16; ++l) {
            vk[l] = v_expf(_mm512_sub_ps(vk[l], vm));
            _mm512_storeu_ps(cache + k_step*j + 16*l, vk[l]);
        }
        S[j] += reduce_T<_mm512_reduce_add_ps, _mm512_add_ps>(vk);
    }
}
template <int D, int q_step, int k_step>
void FlashAttn<D, q_step, k_step>::accumulate_qkv(int stride_v, const char * v) {
if constexpr (2*q_step <= vk_size) {
for (int i = 0; i < D/16; i += 2) {
for (int j = 0; j < q_step; ++j) {
if (need_scaling[j] == 2) {
vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps();
} else {
auto R = qkv_cache + D*j;
vk[2*j+0] = _mm512_loadu_ps(R + 16*i);
vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16);
if (need_scaling[j] == 1) {
vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], vms[j]);
vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], vms[j]);
}
}
}
for (int l1 = 0; l1 < k_step; ++l1) {
auto vr = (const ggml_half *)(v + l1*stride_v);
auto v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+0));
auto v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+1));
for (int j = 0; j < q_step; ++j) {
auto vs = _mm512_set1_ps(cache[k_step*j + l1]);
vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]);
vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]);
}
}
for (int j = 0; j < q_step; ++j) {
for (int l = 0; l < k_step/16; ++l) vals[l] = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16*l));
auto smax = _mm512_reduce_max_ps(_mm512_max_ps(vals[0], vals[1]));
need_scaling[j] = 0;
if (smax > M[j]) {
if (M[j] > -INFINITY) {
float m = expf(M[j] - smax);
vms[j] = _mm512_set1_ps(m);
need_scaling[j] = 1;
S[j] *= m;
} else {
need_scaling[j] = 2;
S[j] = 0;
}
M[j] = smax;
}
auto vm = _mm512_set1_ps(M[j]);
for (int l = 0; l < k_step/16; ++l) {
vals[l] = v_expf(_mm512_sub_ps(vals[l], vm));
_mm512_storeu_ps(cache + k_step*j + 16*l, vals[l]);
}
S[j] += _mm512_reduce_add_ps(_mm512_add_ps(vals[0], vals[1]));
}
for (int i = 0; i < 8; i += 2) {
for (int j = 0; j < q_step; ++j) {
if (need_scaling[j] == 2) {
vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps();
} else {
auto R = qkv_cache + 128*j;
vk[2*j+0] = _mm512_loadu_ps(R + 16*i);
vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16);
if (need_scaling[j] == 1) {
vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], vms[j]);
vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], vms[j]);
}
}
}
for (int l1 = 0; l1 < k_step; ++l1) {
auto vr = (const ggml_half *)((const char *)v + (k_step*k1 + l1)*stride_v);
auto v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+0));
auto v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+1));
for (int j = 0; j < q_step; ++j) {
auto vs = _mm512_set1_ps(cache[k_step*j + l1]);
vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]);
vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]);
}
}
for (int j = 0; j < q_step; ++j) {
auto R = qkv_cache + 128*j;
_mm512_storeu_ps(R + 16*i, vk[2*j+0]);
_mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]);
}
auto R = qkv_cache + D*j;
_mm512_storeu_ps(R + 16*i, vk[2*j+0]);
_mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]);
}
}
for (int j = 0; j < q_step; ++j) {
GGML_ASSERT(S[j] > 0);
auto R = qkv_cache + 128*j;
auto final_R = qkv + (q_step*i1 + j)*stride_qkv;
auto norm = _mm512_set1_ps(1/S[j]);
for (int i = 0; i < 8; ++i) {
auto r = _mm512_loadu_ps(R + 16*i);
_mm512_storeu_ps(final_R + 16*i, _mm512_mul_ps(norm, r));
} else {
for (int i = 0; i < D/16; ++i) {
for (int j = 0; j < q_step; ++j) {
if (need_scaling[j] == 2) {
vk[j] = _mm512_setzero_ps();
} else {
auto R = qkv_cache + D*j;
vk[j] = _mm512_loadu_ps(R + 16*i);
if (need_scaling[j] == 1) {
vk[j] = _mm512_mul_ps(vk[j], vms[j]);
}
}
}
for (int l1 = 0; l1 < k_step; ++l1) {
auto vr = (const ggml_half *)(v + l1*stride_v);
auto v = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i));
for (int j = 0; j < q_step; ++j) {
auto vs = _mm512_set1_ps(cache[k_step*j + l1]);
vk[j] = _mm512_fmadd_ps(v, vs, vk[j]);
}
}
for (int j = 0; j < q_step; ++j) {
auto R = qkv_cache + D*j;
_mm512_storeu_ps(R + 16*i, vk[j]);
}
}
}
return;
if (nq1%16 != 0 || nk1%16 != 0) printf("Oops(%s): nq1 = %d, nk1 = %d\n", __func__, nq1, nk1);
//GGML_ASSERT(nq1%16 == 0 && nk1%16 == 0);
//auto vinf = _mm512_set1_ps(-INFINITY);
for (int i1 = 0; i1 < nq1/16; ++i1) {
//int iq1 = 16*i1;
for (int j1 = 0; j1 < 16; ++j1) {
S[j1] = 0; M[j1] = -INFINITY;
std::memset(qkv + j1*stride_v, 0, ne00*sizeof(float));
}
for (int ik = 0; ik < nk1; ik += 16) {
/////////////////////////////////////////////////////////////////////////////////
const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*ik);
DataInfo info{cache, (const char *)q, 16*sizeof(float), size_t(stride_q)*sizeof(float), 0, 0, nullptr, 0};
mul_mat_fX_fY_T<4, ggml_half, float>(ne00, (const void *)kr, stride_k, info, 16);
/////////////////////////////////////////////////////////////////////////////////
float * R = qkv;
for (int j1 = 0; j1 < 16; ++j1) {
int iq1 = 16*i1 + j1;
float * C = cache + 16*j1;
auto qk = _mm512_loadu_ps(C);
const ggml_half * mp = (const ggml_half *)((const char *)mask + stride_m*iq1);
auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i*)mp), _mm256_setzero_si256());
qk = _mm512_mask_blend_ps(m16, vinf, qk);
float smax = _mm512_reduce_max_ps(qk);
if (smax > M[j1]) {
if (M[j1] > -INFINITY) {
float m = expf(M[j1] - smax);
auto ms = _mm512_set1_ps(m);
for (int i = 0; i < ne00/16; ++i) _mm512_storeu_ps(R + 16*i, _mm512_mul_ps(ms, _mm512_loadu_ps(R + 16*i)));
S[j1] *= m;
} else {
std::memset(R, 0, ne00*sizeof(float));
S[j1] = 0;
}
M[j1] = smax;
}
template <int D, int q_step, int k_step>
void FlashAttn<D, q_step, k_step>::multiply_mask_kq(int nq1, int stride_k, int stride_q, int stride_m,
const char * k, const float * q, const char * mask) {
if constexpr (D <= 128) {
__m512 qv[D/16];
for (int l1 = 0; l1 < k_step; l1 += 2) {
auto kr1 = (const ggml_half *)(k + (l1 + 0)*stride_k);
auto kr2 = (const ggml_half *)(k + (l1 + 1)*stride_k);
for (int i = 0; i < D/16; ++i) vk[i+ 0] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr1 + i));
for (int i = 0; i < D/16; ++i) vk[i+D/16] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr2 + i));
for (int m1 = 0; m1 < nq1; ++m1) {
// q index is q_step*i1 + m1
// k index is k_step*k1 + l1
const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
cache[k_step*m1 + l1 + 0] = cache[k_step*m1 + l1 + 1] = -INFINITY;
if (mp[l1+0] == h_inf && mp[l1+1] == h_inf) {
continue;
}
auto vs = v_expf(_mm512_sub_ps(qk, _mm512_set1_ps(M[j1])));
S[j1] += _mm512_reduce_add_ps(vs);
_mm512_storeu_ps(C, vs);
for (int jk = 0; jk < 16; ++jk) {
vs = _mm512_set1_ps(C[jk]);
const ggml_half * vr = (const ggml_half *)((const char *)v + stride_v*(ik + jk));
for (int i = 0; i < ne00/16; ++i) {
auto v = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i));
auto r = _mm512_loadu_ps(qkv + 16*i);
_mm512_storeu_ps(qkv + 16*i, _mm512_fmadd_ps(vs, v, r));
}
auto qr = q + m1*stride_q;
for (int i = 0; i < D/16; ++i) qv[i] = _mm512_loadu_ps(qr + 16*i);
if (mp[l1+0] != h_inf) {
auto vsum = _mm512_mul_ps(vk[0], qv[0]);
for (int i = 1; i < 8; ++i) vsum = _mm512_fmadd_ps(vk[i], qv[i], vsum);
cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum);
}
R += stride_qkv;
}
}
for (int j1 = 0; j1 < 16; ++j1) {
if (S[j1] > 0) {
//GGML_ASSERT(S[j1] > 0);
auto norm = _mm512_set1_ps(1/S[j1]);
for (int i = 0; i < ne00/16; ++i) {
auto r = _mm512_loadu_ps(qkv + 16*i);
_mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, r));
if (mp[l1+1] != h_inf) {
auto vsum = _mm512_mul_ps(vk[8], qv[0]);
for (int i = 1; i < 8; ++i) vsum = _mm512_fmadd_ps(vk[i+D/16], qv[i], vsum);
cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum);
}
}
}
}
else {
DataInfo info{cache, (const char *)q, k_step*sizeof(float), stride_q*sizeof(float), 0, 0, nullptr, 0};
for (int i = 0; i < nq1/4; ++i) {
mul_mat_fX_fY_T<4, ggml_half, float>(D, (const void *)k, stride_k, info, k_step);
info.cur_y += 4;
}
int n_left = nq1 - 4*(nq1/4);
if (n_left > 0) {
switch (n_left) {
case 1: mul_mat_fX_fY_T<1, ggml_half, float>(D, (const void *)k, stride_k, info, k_step); break;
case 2: mul_mat_fX_fY_T<2, ggml_half, float>(D, (const void *)k, stride_k, info, k_step); break;
case 3: mul_mat_fX_fY_T<3, ggml_half, float>(D, (const void *)k, stride_k, info, k_step); break;
default: break;
}
}
//if constexpr (q_step <= 4) {
// mul_mat_fX_fY_T<q_step, ggml_half, float>(D, (const void *)k, stride_k, info, k_step);
//}
//else {
// for (int i = 0; i < q_step/4; ++i) {
// mul_mat_fX_fY_T<4, ggml_half, float>(D, (const void *)k, stride_k, info, k_step);
// info.cur_y += 4;
// }
//}
}
for (int j = 0; j < nq1; ++j) {
if constexpr (D <= 128) {
for (int l = 0; l < k_step/16; ++l) vk[l] = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16*l));
} else {
auto vinf = _mm512_set1_ps(-INFINITY);
const ggml_half * mp = (const ggml_half *)mask;
for (int l = 0; l < k_step/16; ++l) {
auto val = _mm512_loadu_ps(cache + k_step*j + 16*l);
auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i *)mp), _mm256_setzero_si256());
vk[l] = _mm512_mask_mul_ps(vinf, m16, vscale, val);
}
}
float smax = reduce_T<_mm512_reduce_max_ps, _mm512_max_ps>(vk);
need_scaling[j] = 0;
if (smax > M[j]) {
if (M[j] > -INFINITY) {
float m = expf(M[j] - smax);
vms[j] = _mm512_set1_ps(m);
need_scaling[j] = 1;
S[j] *= m;
} else {
need_scaling[j] = 2;
S[j] = 0;
}
M[j] = smax;
}
auto vm = _mm512_set1_ps(M[j]);
for (int l = 0; l < k_step/16; ++l) {
vk[l] = v_expf(_mm512_sub_ps(vk[l], vm));
_mm512_storeu_ps(cache + k_step*j + 16*l, vk[l]);
}
S[j] += reduce_T<_mm512_reduce_add_ps, _mm512_add_ps>(vk);
}
}
template <int D, int q_step, int k_step>
void FlashAttn<D, q_step, k_step>::accumulate_qkv(int nq1, int stride_v, const char * v) {
    // Tail version of accumulate_qkv() for nq1 (< q_step) remaining query rows:
    // qkv_cache[j] <- rescale(qkv_cache[j]) + sum_l cache[j*k_step + l] * V[l], j < nq1,
    // where the rescale is dictated by need_scaling[j]/vms[j] set in multiply_mask_kq().
    if (2*nq1 <= vk_size) {
        // Enough vk registers to keep two 16-wide chunks of D per row in flight.
        for (int i = 0; i < D/16; i += 2) {
            for (int j = 0; j < nq1; ++j) {
                if (need_scaling[j] == 2) {
                    vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps();
                } else {
                    auto R = qkv_cache + D*j;
                    vk[2*j+0] = _mm512_loadu_ps(R + 16*i);
                    vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16);
                    if (need_scaling[j] == 1) {
                        vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], vms[j]);
                        vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], vms[j]);
                    }
                }
            }
            for (int l1 = 0; l1 < k_step; ++l1) {
                auto vr = (const ggml_half *)(v + l1*stride_v);
                auto v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+0));
                auto v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+1));
                // Accumulate only the nq1 active rows (was q_step, which performed
                // fmadds on uninitialized vk registers for j >= nq1; the results were
                // never stored, so this is wasted work on garbage, now removed).
                for (int j = 0; j < nq1; ++j) {
                    auto vs = _mm512_set1_ps(cache[k_step*j + l1]);
                    vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]);
                    vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]);
                }
            }
            for (int j = 0; j < nq1; ++j) {
                auto R = qkv_cache + D*j;
                _mm512_storeu_ps(R + 16*i, vk[2*j+0]);
                _mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]);
            }
        }
    } else {
        // Fallback: one 16-wide chunk of D per pass.
        for (int i = 0; i < D/16; ++i) {
            for (int j = 0; j < nq1; ++j) {
                if (need_scaling[j] == 2) {
                    vk[j] = _mm512_setzero_ps();
                } else {
                    auto R = qkv_cache + D*j;
                    vk[j] = _mm512_loadu_ps(R + 16*i);
                    if (need_scaling[j] == 1) {
                        vk[j] = _mm512_mul_ps(vk[j], vms[j]);
                    }
                }
            }
            for (int l1 = 0; l1 < k_step; ++l1) {
                auto vr = (const ggml_half *)(v + l1*stride_v);
                // Renamed from 'v' to avoid shadowing the function parameter.
                auto v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i));
                // Same fix as above: iterate the nq1 active rows, not q_step.
                for (int j = 0; j < nq1; ++j) {
                    auto vs = _mm512_set1_ps(cache[k_step*j + l1]);
                    vk[j] = _mm512_fmadd_ps(v1, vs, vk[j]);
                }
            }
            for (int j = 0; j < nq1; ++j) {
                auto R = qkv_cache + D*j;
                _mm512_storeu_ps(R + 16*i, vk[j]);
            }
        }
    }
}
template <int D, int q_step, int k_step>
inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
const float * q, const char * k, const char * v, const char * mask,
float scale, float * qkv) {
if (nq1 >= q_step) {
FlashAttn<D, q_step, k_step> fa(scale);
fa.compute(nq1, nk1, stride_k, stride_q, stride_m, stride_v, stride_qkv,
(const char *)k, q, (const char *)mask, (const char *)v, qkv);
} else {
FlashAttn<D, 1, k_step> fa(scale);
fa.compute(nq1, nk1, stride_k, stride_q, stride_m, stride_v, stride_qkv,
(const char *)k, q, (const char *)mask, (const char *)v, qkv);
}
}
}
bool iqk_flash_helper_3(int ne00,        // attention head size
                        int nq1,         // number of columns in q
                        int nk1,         // number of rows in k
                        int stride_q,    // distance between q columns in bytes
                        int stride_k,    // distance between k rows in bytes
                        int stride_v,    // distance between v rows (in bytes)
                        int stride_m,    // distance between rows in mask (in bytes)
                        int stride_qkv,  // distance between qkv rows in bytes
                        const float * q,    // q matrix
                        const void  * k,    // k matrix. Assumed to be fp16, ne00 x nk elements
                        const void  * v,    // v matrix. Assumed to be fp16, ne00 x nk elements
                        const void  * mask, // mask. If not null, assumed to be fp16. nq*nk elements
                        float scale,        // the scale in softmax(scale*(k*q))
                        float * qkv) {      // the qkv result
    if (!mask) return false; // we assume the mask is not null in the implementation
    if (nk1%32 != 0) {
        // nk1 not a multiple of k_step: fall back to the row-by-row helper.
        const char * mp = (const char *)mask;
        for (int iq1 = 0; iq1 < nq1; ++iq1) {
            iqk_flash_helper_2(false, ne00, nk1, stride_k, stride_v,
                    q, k, v, (const void *)mp,
                    scale, 1.0f, nullptr, qkv);
            // NOTE(review): stride_q is documented in bytes, but q is a float* so this
            // advances stride_q floats — confirm the units against the caller.
            q   += stride_q;
            qkv += stride_qkv;
            mp  += stride_m;
        }
        // (Removed a leftover dead statement `q += 16*stride_q;` here — it modified a
        // local parameter copy immediately before returning, with no effect.)
        return true;
    }
    stride_q /= sizeof(float); // q stride as float
    auto ck = (const char *)k;
    auto cv = (const char *)v;
    auto cm = (const char *)mask;
    // Dispatch on the head size; unsupported sizes report false so the caller can
    // fall back to the generic implementation.
    switch (ne00) {
        case 64:
            iqk_flash_helper_T< 64, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, qkv); break;
        case 80:
            iqk_flash_helper_T< 80, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, qkv); break;
        case 96:
            iqk_flash_helper_T< 96, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, qkv); break;
        case 112:
            iqk_flash_helper_T<112, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, qkv); break;
        case 128:
            iqk_flash_helper_T<128, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, qkv); break;
        case 256:
            iqk_flash_helper_T<256, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, qkv); break;
        default:
            return false;
    }
    return true;
}
//bool iqk_flash_attention_noalibi_f16(int ith, int nth,

View File

@@ -58,7 +58,7 @@ void iqk_flash_helper_2(bool is_alibi,
float * qk,
float * qkv); // softmax(k*q) - k elements
void iqk_flash_helper_3(int ne00,
bool iqk_flash_helper_3(int ne00,
int nq, // number of elements in q
int nk, // number of rows in k
int stride_q,