mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 15:44:10 +00:00
Flash attention: templated implementation
Needed to model different head sizes for different
LLMs, batch sizes that are not a multiple of 8, etc.
I see 2-3% performance degradation.
It is one of those things
that I don't understand, but really would like to:
I have an implementation of a function that depends on a compile time
constant. I get performance X.
I then turn the implementation into a template, where the former
compile time constant is a template parameter, and I instantiate the template
for a bunch of different values, one of which is the former compile
time constant. I observe performance c*X, where c almost always is
less than 1, and depending on how unlucky we get, it can be as low
as 0.5 or somesuch. But in my simple-minded understanding, I expect
the template instantiation with the former compile time constant
to turn into the exact same function as the former non-templated
implementation, and so I expect the exact same performance.
i.e., if I have some function
void some_function(...) {
constexpr int N = 128;
... // code that depends on N
}
and I now write
template <int N>
void some_function_T(...) {
... // same code as in some_function() that depends on N
}
and I say
void wrapper_function(int N) {
switch (N) {
case 64: some_function_T< 64>(); break;
case 128: some_function_T<128>(); break;
...
}
}
I expect wrapper_function(128) to have the exact same performance as
some_function() (run time of some_function() is long enough to have the
additional function call overhead be completely negligible).
This is the reason I'm using a template in the first place instead
of just having void some_function(int N).
But no. Tough luck.
This commit is contained in:
@@ -16226,8 +16226,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
if (nth%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
|
||||
else if (nth%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
|
||||
else if (nth%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
|
||||
//if (nth%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
|
||||
//else if (nth%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
|
||||
if ((neq2*neq3)%(nth/ntg) == 0) {
|
||||
//if (ith == 0) printf("%s: D = %d, neq2 = %d, neq1 = %d, nek1 = %d\n", __func__, (int)D, (int)neq2, (int)neq1, (int)nek1);
|
||||
int counter = 0;
|
||||
@@ -16235,32 +16233,19 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
|
||||
if (counter++ % (nth/ntg) == ith/ntg) {
|
||||
int iq1 = (ith%ntg)*neq1/ntg;
|
||||
iqk_flash_helper_3(D, neq1/ntg, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
|
||||
if (!iqk_flash_helper_3(D, neq1/ntg, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
|
||||
(const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]),
|
||||
(const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]),
|
||||
(const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]),
|
||||
(const void *)((const char *)mask->data + iq1*mask->nb[1]),
|
||||
scale,
|
||||
(float *)((char *) dst->data + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1));
|
||||
(float *)((char *) dst->data + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1))) goto IQK_Flash_Attn_NotAvailable;
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
//for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
|
||||
// for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
|
||||
// if (counter++ % nth == ith) {
|
||||
// iqk_flash_helper_3(D, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
|
||||
// (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3]),
|
||||
// (const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]),
|
||||
// (const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]),
|
||||
// (const void *)((const char *)mask->data),
|
||||
// scale,
|
||||
// //(float *)params->wdata + ith*8*nek1,
|
||||
// (float *)((char *) dst->data + (iq3*ne2*ne1 + iq2)*nb1)); // + iq1*ne1)*nb1))
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
IQK_Flash_Attn_NotAvailable:;
|
||||
}
|
||||
|
||||
const uint32_t n_head = neq2;
|
||||
|
||||
@@ -3259,23 +3259,23 @@ IQK_NOINLINE void mul_mat_Qx_Qy_1xN(int n, const char * cx, size_t bx, int ix0,
|
||||
Qy y(info);
|
||||
Qx x(cx + ix0*bx, bx);
|
||||
QFBase::Acc acc[Qx::nrc];
|
||||
if (nb <= Qx::nrc) {
|
||||
QFBase::Data yv[Qx::nrc];
|
||||
for (int i = 0; i < nb; ++i) yv[i] = y.load1(0, i);
|
||||
//for (int ix = 0; ix < Qx::nrc; ++ix) {
|
||||
// auto sum = QFBase::acc_first(yv[0], x.load1(ix, 0));
|
||||
// for (int i = 1; i < nb; ++i) {
|
||||
// sum = QFBase::acc(sum, yv[i], x.load1(ix, i));
|
||||
// }
|
||||
// info.store(ix0+ix, 0, QFBase::hsum(sum));
|
||||
//}
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) acc[ix] = QFBase::acc_first(yv[0], x.load1(ix, 0));
|
||||
for (int i = 1; i < nb; ++i) {
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) acc[ix] = QFBase::acc(acc[ix], yv[i], x.load1(ix, i));
|
||||
}
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, 0, QFBase::hsum(acc[ix]));
|
||||
return;
|
||||
}
|
||||
//if (nb <= Qx::nrc) {
|
||||
// QFBase::Data yv[Qx::nrc];
|
||||
// for (int i = 0; i < nb; ++i) yv[i] = y.load1(0, i);
|
||||
// //for (int ix = 0; ix < Qx::nrc; ++ix) {
|
||||
// // auto sum = QFBase::acc_first(yv[0], x.load1(ix, 0));
|
||||
// // for (int i = 1; i < nb; ++i) {
|
||||
// // sum = QFBase::acc(sum, yv[i], x.load1(ix, i));
|
||||
// // }
|
||||
// // info.store(ix0+ix, 0, QFBase::hsum(sum));
|
||||
// //}
|
||||
// for (int ix = 0; ix < Qx::nrc; ++ix) acc[ix] = QFBase::acc_first(yv[0], x.load1(ix, 0));
|
||||
// for (int i = 1; i < nb; ++i) {
|
||||
// for (int ix = 0; ix < Qx::nrc; ++ix) acc[ix] = QFBase::acc(acc[ix], yv[i], x.load1(ix, i));
|
||||
// }
|
||||
// for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, 0, QFBase::hsum(acc[ix]));
|
||||
// return;
|
||||
//}
|
||||
QFBase::Data xv[Qx::nrc];
|
||||
auto yv = y.load1(0, 0);
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) {
|
||||
@@ -6605,14 +6605,14 @@ void iqk_flash_helper_2(bool is_alibi,
|
||||
}
|
||||
}
|
||||
|
||||
if (mask) {
|
||||
const ggml_half * mp = (const ggml_half *)mask;
|
||||
for (int i = 0; i < nk; ++i) {
|
||||
if (GGML_FP16_TO_FP32(mp[i]) == -INFINITY) {
|
||||
nk = i; break;
|
||||
}
|
||||
}
|
||||
}
|
||||
//if (mask) {
|
||||
// const ggml_half * mp = (const ggml_half *)mask;
|
||||
// for (int i = 0; i < nk; ++i) {
|
||||
// if (GGML_FP16_TO_FP32(mp[i]) == -INFINITY) {
|
||||
// nk = i; break;
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
|
||||
DataInfo info{qk, (const char*)q, 0, size_t(stride_k), 0, 1, nullptr, 0};
|
||||
|
||||
@@ -6756,6 +6756,7 @@ bool iqk_soft_max_noalibi(int nc, int ir0, int ir1, int ne00, int ne01,
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace {
|
||||
template <int nrc_y, typename FloatX, typename FloatY>
|
||||
void mul_mat_fX_fY_fa(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n%QFBase::k_step == 0);
|
||||
@@ -6783,241 +6784,507 @@ void mul_mat_fX_fY_fa(int n, const void * vx, size_t bx, const DataInfo& info, i
|
||||
}
|
||||
}
|
||||
|
||||
void iqk_flash_helper_3(int ne00,
|
||||
int nq1, // number of elements in q
|
||||
int nk1, // number of rows in k
|
||||
int stride_q,
|
||||
int stride_k, // distance between rows in k (in bytes)
|
||||
int stride_v, // distance between rows in v (in bytes)
|
||||
int stride_m, // distance between rows in mask (in bytes)
|
||||
int stride_qkv, // distance between rows in mask (in bytes)
|
||||
const float * q, // q vector
|
||||
const void * k, // k matrix. Assumed to be fp16, nq x nk elements
|
||||
const void * v,
|
||||
const void * mask, // mask. If not null, assumed to be fp16. nk elements
|
||||
float scale,
|
||||
float * qkv) {
|
||||
constexpr int q_step = 8;
|
||||
constexpr int k_step = 32; //16;
|
||||
if (nq1%q_step != 0 || nk1%k_step != 0) {
|
||||
for (int iq1 = 0; iq1 < nq1; ++iq1) {
|
||||
iqk_flash_helper_2(false, ne00, nk1, stride_k, stride_v,
|
||||
q, k, v, (const void *)((const char *)mask + iq1*stride_m),
|
||||
scale, 1.0f, nullptr, qkv);
|
||||
q += stride_q;
|
||||
qkv += stride_qkv;
|
||||
}
|
||||
return;
|
||||
}
|
||||
stride_q /= sizeof(float);
|
||||
const ggml_half h_inf = GGML_FP32_TO_FP16(-INFINITY);
|
||||
float cache[q_step*k_step];
|
||||
float S[q_step], M[q_step];
|
||||
__m512 vk[16];
|
||||
__m512 vms[q_step];
|
||||
__m512 vals[k_step/16];
|
||||
float qkv_cache[128*q_step];
|
||||
int need_scaling[q_step];
|
||||
auto vscale = _mm512_set1_ps(scale);
|
||||
auto vinf = _mm512_set1_ps(-INFINITY);
|
||||
for (int i1 = 0; i1 < nq1/q_step; ++i1) {
|
||||
template <int D, int q_step, int k_step>
|
||||
struct FlashAttn {
|
||||
static_assert(D%16 == 0 && D <= 256);
|
||||
static_assert(k_step%16 == 0);
|
||||
static_assert(q_step <= 4 || q_step%4 == 0);
|
||||
|
||||
constexpr static int vk_size = D <= 128 ? D/8 : D/16;
|
||||
static_assert(q_step <= vk_size);
|
||||
|
||||
FlashAttn(float scale) : vscale(_mm512_set1_ps(scale)), h_inf(GGML_FP32_TO_FP16(-INFINITY)) {}
|
||||
|
||||
inline void init_qstep() {
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
S[j] = 0; M[j] = -INFINITY;
|
||||
}
|
||||
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
||||
// This is slower
|
||||
//DataInfo info{cache, (const char *)(q + q_step*i1*stride_q), k_step*sizeof(float), stride_q*sizeof(float), 0, 0, nullptr, 0};
|
||||
//mul_mat_fX_fY_T<q_step/2, ggml_half, float>(ne00, (const void *)((const char *)k + k_step*k1*stride_k), stride_k, info, k_step);
|
||||
//info.cur_y += q_step/2;
|
||||
//mul_mat_fX_fY_T<q_step/2, ggml_half, float>(ne00, (const void *)((const char *)k + k_step*k1*stride_k), stride_k, info, k_step);
|
||||
//for (int j = 0; j < q_step; ++j) {
|
||||
// const ggml_half * mp = (const ggml_half *)((const char *)mask + stride_m*q_step*i1) + k_step*k1;
|
||||
// for (int l = 0; l < k_step/16; ++l) {
|
||||
// auto val = _mm512_loadu_ps(cache + k_step*j + 16*l);
|
||||
// auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i *)mp), _mm256_setzero_si256());
|
||||
// vals[l] = _mm512_mask_mul_ps(vinf, m16, vscale, val);
|
||||
// }
|
||||
// auto smax = _mm512_reduce_max_ps(_mm512_max_ps(vals[0], vals[1]));
|
||||
// need_scaling[j] = 0;
|
||||
// if (smax > M[j]) {
|
||||
// if (M[j] > -INFINITY) {
|
||||
// float m = expf(M[j] - smax);
|
||||
// vms[j] = _mm512_set1_ps(m);
|
||||
// need_scaling[j] = 1;
|
||||
// S[j] *= m;
|
||||
// } else {
|
||||
// need_scaling[j] = 2;
|
||||
// S[j] = 0;
|
||||
// }
|
||||
// M[j] = smax;
|
||||
// }
|
||||
// auto vm = _mm512_set1_ps(M[j]);
|
||||
// for (int l = 0; l < k_step/16; ++l) {
|
||||
// vals[l] = v_expf(_mm512_sub_ps(vals[l], vm));
|
||||
// _mm512_storeu_ps(cache + k_step*j + 16*l, vals[l]);
|
||||
// }
|
||||
// S[j] += _mm512_reduce_add_ps(_mm512_add_ps(vals[0], vals[1]));
|
||||
//}
|
||||
}
|
||||
|
||||
for (int l1 = 0; l1 < k_step; l1 += 2) {
|
||||
auto kr1 = (const ggml_half *)((const char *)k + (k_step*k1 + l1 + 0)*stride_k);
|
||||
auto kr2 = (const ggml_half *)((const char *)k + (k_step*k1 + l1 + 1)*stride_k);
|
||||
for (int i = 0; i < 8; ++i) vk[i+0] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr1 + i));
|
||||
for (int i = 0; i < 8; ++i) vk[i+8] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr2 + i));
|
||||
for (int m1 = 0; m1 < q_step; ++m1) {
|
||||
// q index is q_step*i1 + m1
|
||||
// k index is k_step*k1 + l1
|
||||
const ggml_half * mp = (const ggml_half *)((const char *)mask + stride_m*(q_step*i1 + m1)) + k_step*k1;
|
||||
cache[k_step*m1 + l1 + 0] = cache[k_step*m1 + l1 + 1] = -INFINITY;
|
||||
if (mp[l1+0] == h_inf && mp[l1+1] == h_inf) {
|
||||
continue;
|
||||
}
|
||||
__m512 qv[8];
|
||||
auto qr = q + (q_step*i1 + m1)*stride_q;
|
||||
for (int i = 0; i < 8; ++i) qv[i] = _mm512_loadu_ps(qr + 16*i);
|
||||
if (mp[l1+0] != h_inf) {
|
||||
auto vsum = _mm512_mul_ps(vk[0], qv[0]);
|
||||
for (int i = 1; i < 8; ++i) vsum = _mm512_fmadd_ps(vk[i], qv[i], vsum);
|
||||
cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum);
|
||||
}
|
||||
if (mp[l1+1] != h_inf) {
|
||||
auto vsum = _mm512_mul_ps(vk[8], qv[0]);
|
||||
for (int i = 1; i < 8; ++i) vsum = _mm512_fmadd_ps(vk[i+8], qv[i], vsum);
|
||||
cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum);
|
||||
inline void multiply_mask_kq(int nq1, int stride_k, int stride_q, int stride_m,
|
||||
const char * k, const float * q, const char * mask);
|
||||
|
||||
inline void accumulate_qkv(int nq1, int stride_v, const char * v);
|
||||
|
||||
inline void normalize_and_store(int nq1, int stride_qkv, float * qkv) const {
|
||||
auto R = qkv_cache;
|
||||
for (int j = 0; j < nq1; ++j) {
|
||||
GGML_ASSERT(S[j] > 0);
|
||||
auto norm = _mm512_set1_ps(1/S[j]);
|
||||
for (int i = 0; i < D/16; ++i) {
|
||||
auto r = _mm512_loadu_ps(R + 16*i);
|
||||
_mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, r));
|
||||
}
|
||||
qkv += stride_qkv;
|
||||
R += D;
|
||||
}
|
||||
}
|
||||
|
||||
inline void multiply_mask_kq(int stride_k, int stride_q, int stride_m,
|
||||
const char * k, const float * q, const char * mask);
|
||||
|
||||
inline void accumulate_qkv(int stride_v, const char * v);
|
||||
|
||||
inline void normalize_and_store(int stride_qkv, float * qkv) const {
|
||||
auto R = qkv_cache;
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
GGML_ASSERT(S[j] > 0);
|
||||
auto norm = _mm512_set1_ps(1/S[j]);
|
||||
for (int i = 0; i < D/16; ++i) {
|
||||
auto r = _mm512_loadu_ps(R + 16*i);
|
||||
_mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, r));
|
||||
}
|
||||
qkv += stride_qkv;
|
||||
R += D;
|
||||
}
|
||||
}
|
||||
|
||||
void compute(int nq1, int nk1, int stride_k, int stride_q, int stride_m, int stride_v, int stride_qkv,
|
||||
const char * k, const float * q, const char * mask, const char * v, float * qkv) {
|
||||
for (int i1 = 0; i1 < nq1/q_step; ++i1) {
|
||||
init_qstep();
|
||||
auto kr = k;
|
||||
auto vr = v;
|
||||
auto mr = mask;
|
||||
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
||||
multiply_mask_kq(stride_k, stride_q, stride_m, kr, q, mr);
|
||||
accumulate_qkv(stride_v, vr);
|
||||
kr += k_step*stride_k;
|
||||
vr += k_step*stride_v;
|
||||
mr += k_step*sizeof(ggml_half);
|
||||
}
|
||||
normalize_and_store(stride_qkv, qkv);
|
||||
|
||||
q += q_step*stride_q;
|
||||
mask += q_step*stride_m;
|
||||
qkv += q_step*stride_qkv;
|
||||
}
|
||||
int n_left = nq1 - q_step*(nq1/q_step);
|
||||
if (n_left > 0) {
|
||||
init_qstep();
|
||||
auto kr = k;
|
||||
auto vr = v;
|
||||
auto mr = mask;
|
||||
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
||||
multiply_mask_kq(n_left, stride_k, stride_q, stride_m, kr, q, mr);
|
||||
accumulate_qkv(n_left, stride_v, vr);
|
||||
kr += k_step*stride_k;
|
||||
vr += k_step*stride_v;
|
||||
mr += k_step*sizeof(ggml_half);
|
||||
}
|
||||
normalize_and_store(n_left, stride_qkv, qkv);
|
||||
}
|
||||
}
|
||||
|
||||
float cache[q_step*k_step];
|
||||
float qkv_cache[D*q_step];
|
||||
float S[q_step], M[q_step];
|
||||
int need_scaling[q_step];
|
||||
__m512 vms[q_step];
|
||||
__m512 vk[vk_size];
|
||||
const __m512 vscale;
|
||||
const ggml_half h_inf;
|
||||
|
||||
typedef __m512 (*combine_t)(__m512, __m512);
|
||||
typedef float (*reduce_t)(__m512);
|
||||
template <reduce_t Op, combine_t Op_combine>
|
||||
static inline float reduce_T(const __m512 * vals) {
|
||||
float result;
|
||||
if constexpr (k_step/16 == 1) {
|
||||
result = Op(vals[0]);
|
||||
}
|
||||
else if constexpr (k_step/16 == 2) {
|
||||
result = Op(Op_combine(vals[0], vals[1]));
|
||||
}
|
||||
else {
|
||||
auto vmax = vals[0];
|
||||
for (int l = 1; l < k_step/16; ++l) vmax = Op_combine(vmax, vals[l]);
|
||||
result = Op(vmax);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
template <int D, int q_step, int k_step>
|
||||
void FlashAttn<D, q_step, k_step>::multiply_mask_kq(int stride_k, int stride_q, int stride_m,
|
||||
const char * k, const float * q, const char * mask) {
|
||||
if constexpr (D <= 128) {
|
||||
__m512 qv[D/16];
|
||||
for (int l1 = 0; l1 < k_step; l1 += 2) {
|
||||
auto kr1 = (const ggml_half *)(k + (l1 + 0)*stride_k);
|
||||
auto kr2 = (const ggml_half *)(k + (l1 + 1)*stride_k);
|
||||
for (int i = 0; i < D/16; ++i) vk[i+ 0] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr1 + i));
|
||||
for (int i = 0; i < D/16; ++i) vk[i+D/16] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr2 + i));
|
||||
for (int m1 = 0; m1 < q_step; ++m1) {
|
||||
// q index is q_step*i1 + m1
|
||||
// k index is k_step*k1 + l1
|
||||
const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
|
||||
cache[k_step*m1 + l1 + 0] = cache[k_step*m1 + l1 + 1] = -INFINITY;
|
||||
if (mp[l1+0] == h_inf && mp[l1+1] == h_inf) {
|
||||
continue;
|
||||
}
|
||||
auto qr = q + m1*stride_q;
|
||||
for (int i = 0; i < D/16; ++i) qv[i] = _mm512_loadu_ps(qr + 16*i);
|
||||
if (mp[l1+0] != h_inf) {
|
||||
auto vsum = _mm512_mul_ps(vk[0], qv[0]);
|
||||
for (int i = 1; i < 8; ++i) vsum = _mm512_fmadd_ps(vk[i], qv[i], vsum);
|
||||
cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum);
|
||||
}
|
||||
if (mp[l1+1] != h_inf) {
|
||||
auto vsum = _mm512_mul_ps(vk[D/16], qv[0]);
|
||||
for (int i = 1; i < 8; ++i) vsum = _mm512_fmadd_ps(vk[i+D/16], qv[i], vsum);
|
||||
cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
DataInfo info{cache, (const char *)q, k_step*sizeof(float), stride_q*sizeof(float), 0, 0, nullptr, 0};
|
||||
for (int i = 0; i < q_step/4; ++i) {
|
||||
mul_mat_fX_fY_T<4, ggml_half, float>(D, (const void *)k, stride_k, info, k_step);
|
||||
info.cur_y += 4;
|
||||
}
|
||||
int n_left = q_step - 4*(q_step/4);
|
||||
if (n_left > 0) {
|
||||
switch (n_left) {
|
||||
case 1: mul_mat_fX_fY_T<1, ggml_half, float>(D, (const void *)k, stride_k, info, k_step); break;
|
||||
case 2: mul_mat_fX_fY_T<2, ggml_half, float>(D, (const void *)k, stride_k, info, k_step); break;
|
||||
case 3: mul_mat_fX_fY_T<3, ggml_half, float>(D, (const void *)k, stride_k, info, k_step); break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
//if constexpr (q_step <= 4) {
|
||||
// mul_mat_fX_fY_T<q_step, ggml_half, float>(D, (const void *)k, stride_k, info, k_step);
|
||||
//}
|
||||
//else {
|
||||
// for (int i = 0; i < q_step/4; ++i) {
|
||||
// mul_mat_fX_fY_T<4, ggml_half, float>(D, (const void *)k, stride_k, info, k_step);
|
||||
// info.cur_y += 4;
|
||||
// }
|
||||
//}
|
||||
}
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
if constexpr (D <= 128) {
|
||||
for (int l = 0; l < k_step/16; ++l) vk[l] = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16*l));
|
||||
} else {
|
||||
auto vinf = _mm512_set1_ps(-INFINITY);
|
||||
const ggml_half * mp = (const ggml_half *)mask;
|
||||
for (int l = 0; l < k_step/16; ++l) {
|
||||
auto val = _mm512_loadu_ps(cache + k_step*j + 16*l);
|
||||
auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i *)mp), _mm256_setzero_si256());
|
||||
vk[l] = _mm512_mask_mul_ps(vinf, m16, vscale, val);
|
||||
}
|
||||
}
|
||||
float smax = reduce_T<_mm512_reduce_max_ps, _mm512_max_ps>(vk);
|
||||
need_scaling[j] = 0;
|
||||
if (smax > M[j]) {
|
||||
if (M[j] > -INFINITY) {
|
||||
float m = expf(M[j] - smax);
|
||||
vms[j] = _mm512_set1_ps(m);
|
||||
need_scaling[j] = 1;
|
||||
S[j] *= m;
|
||||
} else {
|
||||
need_scaling[j] = 2;
|
||||
S[j] = 0;
|
||||
}
|
||||
M[j] = smax;
|
||||
}
|
||||
auto vm = _mm512_set1_ps(M[j]);
|
||||
for (int l = 0; l < k_step/16; ++l) {
|
||||
vk[l] = v_expf(_mm512_sub_ps(vk[l], vm));
|
||||
_mm512_storeu_ps(cache + k_step*j + 16*l, vk[l]);
|
||||
}
|
||||
S[j] += reduce_T<_mm512_reduce_add_ps, _mm512_add_ps>(vk);
|
||||
}
|
||||
}
|
||||
|
||||
template <int D, int q_step, int k_step>
|
||||
void FlashAttn<D, q_step, k_step>::accumulate_qkv(int stride_v, const char * v) {
|
||||
if constexpr (2*q_step <= vk_size) {
|
||||
for (int i = 0; i < D/16; i += 2) {
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
if (need_scaling[j] == 2) {
|
||||
vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps();
|
||||
} else {
|
||||
auto R = qkv_cache + D*j;
|
||||
vk[2*j+0] = _mm512_loadu_ps(R + 16*i);
|
||||
vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16);
|
||||
if (need_scaling[j] == 1) {
|
||||
vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], vms[j]);
|
||||
vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], vms[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int l1 = 0; l1 < k_step; ++l1) {
|
||||
auto vr = (const ggml_half *)(v + l1*stride_v);
|
||||
auto v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+0));
|
||||
auto v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+1));
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto vs = _mm512_set1_ps(cache[k_step*j + l1]);
|
||||
vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]);
|
||||
vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]);
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
for (int l = 0; l < k_step/16; ++l) vals[l] = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16*l));
|
||||
auto smax = _mm512_reduce_max_ps(_mm512_max_ps(vals[0], vals[1]));
|
||||
need_scaling[j] = 0;
|
||||
if (smax > M[j]) {
|
||||
if (M[j] > -INFINITY) {
|
||||
float m = expf(M[j] - smax);
|
||||
vms[j] = _mm512_set1_ps(m);
|
||||
need_scaling[j] = 1;
|
||||
S[j] *= m;
|
||||
} else {
|
||||
need_scaling[j] = 2;
|
||||
S[j] = 0;
|
||||
}
|
||||
M[j] = smax;
|
||||
}
|
||||
auto vm = _mm512_set1_ps(M[j]);
|
||||
for (int l = 0; l < k_step/16; ++l) {
|
||||
vals[l] = v_expf(_mm512_sub_ps(vals[l], vm));
|
||||
_mm512_storeu_ps(cache + k_step*j + 16*l, vals[l]);
|
||||
}
|
||||
S[j] += _mm512_reduce_add_ps(_mm512_add_ps(vals[0], vals[1]));
|
||||
}
|
||||
for (int i = 0; i < 8; i += 2) {
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
if (need_scaling[j] == 2) {
|
||||
vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps();
|
||||
} else {
|
||||
auto R = qkv_cache + 128*j;
|
||||
vk[2*j+0] = _mm512_loadu_ps(R + 16*i);
|
||||
vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16);
|
||||
if (need_scaling[j] == 1) {
|
||||
vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], vms[j]);
|
||||
vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], vms[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int l1 = 0; l1 < k_step; ++l1) {
|
||||
auto vr = (const ggml_half *)((const char *)v + (k_step*k1 + l1)*stride_v);
|
||||
auto v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+0));
|
||||
auto v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+1));
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto vs = _mm512_set1_ps(cache[k_step*j + l1]);
|
||||
vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]);
|
||||
vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]);
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto R = qkv_cache + 128*j;
|
||||
_mm512_storeu_ps(R + 16*i, vk[2*j+0]);
|
||||
_mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]);
|
||||
}
|
||||
auto R = qkv_cache + D*j;
|
||||
_mm512_storeu_ps(R + 16*i, vk[2*j+0]);
|
||||
_mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]);
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
GGML_ASSERT(S[j] > 0);
|
||||
auto R = qkv_cache + 128*j;
|
||||
auto final_R = qkv + (q_step*i1 + j)*stride_qkv;
|
||||
auto norm = _mm512_set1_ps(1/S[j]);
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
auto r = _mm512_loadu_ps(R + 16*i);
|
||||
_mm512_storeu_ps(final_R + 16*i, _mm512_mul_ps(norm, r));
|
||||
} else {
|
||||
for (int i = 0; i < D/16; ++i) {
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
if (need_scaling[j] == 2) {
|
||||
vk[j] = _mm512_setzero_ps();
|
||||
} else {
|
||||
auto R = qkv_cache + D*j;
|
||||
vk[j] = _mm512_loadu_ps(R + 16*i);
|
||||
if (need_scaling[j] == 1) {
|
||||
vk[j] = _mm512_mul_ps(vk[j], vms[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int l1 = 0; l1 < k_step; ++l1) {
|
||||
auto vr = (const ggml_half *)(v + l1*stride_v);
|
||||
auto v = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i));
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto vs = _mm512_set1_ps(cache[k_step*j + l1]);
|
||||
vk[j] = _mm512_fmadd_ps(v, vs, vk[j]);
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto R = qkv_cache + D*j;
|
||||
_mm512_storeu_ps(R + 16*i, vk[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
if (nq1%16 != 0 || nk1%16 != 0) printf("Oops(%s): nq1 = %d, nk1 = %d\n", __func__, nq1, nk1);
|
||||
//GGML_ASSERT(nq1%16 == 0 && nk1%16 == 0);
|
||||
//auto vinf = _mm512_set1_ps(-INFINITY);
|
||||
for (int i1 = 0; i1 < nq1/16; ++i1) {
|
||||
//int iq1 = 16*i1;
|
||||
for (int j1 = 0; j1 < 16; ++j1) {
|
||||
S[j1] = 0; M[j1] = -INFINITY;
|
||||
std::memset(qkv + j1*stride_v, 0, ne00*sizeof(float));
|
||||
}
|
||||
for (int ik = 0; ik < nk1; ik += 16) {
|
||||
/////////////////////////////////////////////////////////////////////////////////
|
||||
const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*ik);
|
||||
DataInfo info{cache, (const char *)q, 16*sizeof(float), size_t(stride_q)*sizeof(float), 0, 0, nullptr, 0};
|
||||
mul_mat_fX_fY_T<4, ggml_half, float>(ne00, (const void *)kr, stride_k, info, 16);
|
||||
/////////////////////////////////////////////////////////////////////////////////
|
||||
float * R = qkv;
|
||||
for (int j1 = 0; j1 < 16; ++j1) {
|
||||
int iq1 = 16*i1 + j1;
|
||||
float * C = cache + 16*j1;
|
||||
auto qk = _mm512_loadu_ps(C);
|
||||
const ggml_half * mp = (const ggml_half *)((const char *)mask + stride_m*iq1);
|
||||
auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i*)mp), _mm256_setzero_si256());
|
||||
qk = _mm512_mask_blend_ps(m16, vinf, qk);
|
||||
float smax = _mm512_reduce_max_ps(qk);
|
||||
if (smax > M[j1]) {
|
||||
if (M[j1] > -INFINITY) {
|
||||
float m = expf(M[j1] - smax);
|
||||
auto ms = _mm512_set1_ps(m);
|
||||
for (int i = 0; i < ne00/16; ++i) _mm512_storeu_ps(R + 16*i, _mm512_mul_ps(ms, _mm512_loadu_ps(R + 16*i)));
|
||||
S[j1] *= m;
|
||||
} else {
|
||||
std::memset(R, 0, ne00*sizeof(float));
|
||||
S[j1] = 0;
|
||||
}
|
||||
M[j1] = smax;
|
||||
}
|
||||
|
||||
template <int D, int q_step, int k_step>
|
||||
void FlashAttn<D, q_step, k_step>::multiply_mask_kq(int nq1, int stride_k, int stride_q, int stride_m,
|
||||
const char * k, const float * q, const char * mask) {
|
||||
if constexpr (D <= 128) {
|
||||
__m512 qv[D/16];
|
||||
for (int l1 = 0; l1 < k_step; l1 += 2) {
|
||||
auto kr1 = (const ggml_half *)(k + (l1 + 0)*stride_k);
|
||||
auto kr2 = (const ggml_half *)(k + (l1 + 1)*stride_k);
|
||||
for (int i = 0; i < D/16; ++i) vk[i+ 0] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr1 + i));
|
||||
for (int i = 0; i < D/16; ++i) vk[i+D/16] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr2 + i));
|
||||
for (int m1 = 0; m1 < nq1; ++m1) {
|
||||
// q index is q_step*i1 + m1
|
||||
// k index is k_step*k1 + l1
|
||||
const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
|
||||
cache[k_step*m1 + l1 + 0] = cache[k_step*m1 + l1 + 1] = -INFINITY;
|
||||
if (mp[l1+0] == h_inf && mp[l1+1] == h_inf) {
|
||||
continue;
|
||||
}
|
||||
auto vs = v_expf(_mm512_sub_ps(qk, _mm512_set1_ps(M[j1])));
|
||||
S[j1] += _mm512_reduce_add_ps(vs);
|
||||
_mm512_storeu_ps(C, vs);
|
||||
for (int jk = 0; jk < 16; ++jk) {
|
||||
vs = _mm512_set1_ps(C[jk]);
|
||||
const ggml_half * vr = (const ggml_half *)((const char *)v + stride_v*(ik + jk));
|
||||
for (int i = 0; i < ne00/16; ++i) {
|
||||
auto v = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i));
|
||||
auto r = _mm512_loadu_ps(qkv + 16*i);
|
||||
_mm512_storeu_ps(qkv + 16*i, _mm512_fmadd_ps(vs, v, r));
|
||||
}
|
||||
auto qr = q + m1*stride_q;
|
||||
for (int i = 0; i < D/16; ++i) qv[i] = _mm512_loadu_ps(qr + 16*i);
|
||||
if (mp[l1+0] != h_inf) {
|
||||
auto vsum = _mm512_mul_ps(vk[0], qv[0]);
|
||||
for (int i = 1; i < 8; ++i) vsum = _mm512_fmadd_ps(vk[i], qv[i], vsum);
|
||||
cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum);
|
||||
}
|
||||
R += stride_qkv;
|
||||
}
|
||||
}
|
||||
for (int j1 = 0; j1 < 16; ++j1) {
|
||||
if (S[j1] > 0) {
|
||||
//GGML_ASSERT(S[j1] > 0);
|
||||
auto norm = _mm512_set1_ps(1/S[j1]);
|
||||
for (int i = 0; i < ne00/16; ++i) {
|
||||
auto r = _mm512_loadu_ps(qkv + 16*i);
|
||||
_mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, r));
|
||||
if (mp[l1+1] != h_inf) {
|
||||
auto vsum = _mm512_mul_ps(vk[8], qv[0]);
|
||||
for (int i = 1; i < 8; ++i) vsum = _mm512_fmadd_ps(vk[i+D/16], qv[i], vsum);
|
||||
cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
DataInfo info{cache, (const char *)q, k_step*sizeof(float), stride_q*sizeof(float), 0, 0, nullptr, 0};
|
||||
for (int i = 0; i < nq1/4; ++i) {
|
||||
mul_mat_fX_fY_T<4, ggml_half, float>(D, (const void *)k, stride_k, info, k_step);
|
||||
info.cur_y += 4;
|
||||
}
|
||||
int n_left = nq1 - 4*(nq1/4);
|
||||
if (n_left > 0) {
|
||||
switch (n_left) {
|
||||
case 1: mul_mat_fX_fY_T<1, ggml_half, float>(D, (const void *)k, stride_k, info, k_step); break;
|
||||
case 2: mul_mat_fX_fY_T<2, ggml_half, float>(D, (const void *)k, stride_k, info, k_step); break;
|
||||
case 3: mul_mat_fX_fY_T<3, ggml_half, float>(D, (const void *)k, stride_k, info, k_step); break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
//if constexpr (q_step <= 4) {
|
||||
// mul_mat_fX_fY_T<q_step, ggml_half, float>(D, (const void *)k, stride_k, info, k_step);
|
||||
//}
|
||||
//else {
|
||||
// for (int i = 0; i < q_step/4; ++i) {
|
||||
// mul_mat_fX_fY_T<4, ggml_half, float>(D, (const void *)k, stride_k, info, k_step);
|
||||
// info.cur_y += 4;
|
||||
// }
|
||||
//}
|
||||
}
|
||||
for (int j = 0; j < nq1; ++j) {
|
||||
if constexpr (D <= 128) {
|
||||
for (int l = 0; l < k_step/16; ++l) vk[l] = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16*l));
|
||||
} else {
|
||||
auto vinf = _mm512_set1_ps(-INFINITY);
|
||||
const ggml_half * mp = (const ggml_half *)mask;
|
||||
for (int l = 0; l < k_step/16; ++l) {
|
||||
auto val = _mm512_loadu_ps(cache + k_step*j + 16*l);
|
||||
auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i *)mp), _mm256_setzero_si256());
|
||||
vk[l] = _mm512_mask_mul_ps(vinf, m16, vscale, val);
|
||||
}
|
||||
}
|
||||
float smax = reduce_T<_mm512_reduce_max_ps, _mm512_max_ps>(vk);
|
||||
need_scaling[j] = 0;
|
||||
if (smax > M[j]) {
|
||||
if (M[j] > -INFINITY) {
|
||||
float m = expf(M[j] - smax);
|
||||
vms[j] = _mm512_set1_ps(m);
|
||||
need_scaling[j] = 1;
|
||||
S[j] *= m;
|
||||
} else {
|
||||
need_scaling[j] = 2;
|
||||
S[j] = 0;
|
||||
}
|
||||
M[j] = smax;
|
||||
}
|
||||
auto vm = _mm512_set1_ps(M[j]);
|
||||
for (int l = 0; l < k_step/16; ++l) {
|
||||
vk[l] = v_expf(_mm512_sub_ps(vk[l], vm));
|
||||
_mm512_storeu_ps(cache + k_step*j + 16*l, vk[l]);
|
||||
}
|
||||
S[j] += reduce_T<_mm512_reduce_add_ps, _mm512_add_ps>(vk);
|
||||
}
|
||||
}
|
||||
|
||||
template <int D, int q_step, int k_step>
|
||||
void FlashAttn<D, q_step, k_step>::accumulate_qkv(int nq1, int stride_v, const char * v) {
|
||||
if (2*nq1 <= vk_size) {
|
||||
for (int i = 0; i < D/16; i += 2) {
|
||||
for (int j = 0; j < nq1; ++j) {
|
||||
if (need_scaling[j] == 2) {
|
||||
vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps();
|
||||
} else {
|
||||
auto R = qkv_cache + D*j;
|
||||
vk[2*j+0] = _mm512_loadu_ps(R + 16*i);
|
||||
vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16);
|
||||
if (need_scaling[j] == 1) {
|
||||
vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], vms[j]);
|
||||
vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], vms[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int l1 = 0; l1 < k_step; ++l1) {
|
||||
auto vr = (const ggml_half *)(v + l1*stride_v);
|
||||
auto v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+0));
|
||||
auto v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+1));
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto vs = _mm512_set1_ps(cache[k_step*j + l1]);
|
||||
vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]);
|
||||
vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]);
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < nq1; ++j) {
|
||||
auto R = qkv_cache + D*j;
|
||||
_mm512_storeu_ps(R + 16*i, vk[2*j+0]);
|
||||
_mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < D/16; ++i) {
|
||||
for (int j = 0; j < nq1; ++j) {
|
||||
if (need_scaling[j] == 2) {
|
||||
vk[j] = _mm512_setzero_ps();
|
||||
} else {
|
||||
auto R = qkv_cache + D*j;
|
||||
vk[j] = _mm512_loadu_ps(R + 16*i);
|
||||
if (need_scaling[j] == 1) {
|
||||
vk[j] = _mm512_mul_ps(vk[j], vms[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int l1 = 0; l1 < k_step; ++l1) {
|
||||
auto vr = (const ggml_half *)(v + l1*stride_v);
|
||||
auto v = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i));
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto vs = _mm512_set1_ps(cache[k_step*j + l1]);
|
||||
vk[j] = _mm512_fmadd_ps(v, vs, vk[j]);
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < nq1; ++j) {
|
||||
auto R = qkv_cache + D*j;
|
||||
_mm512_storeu_ps(R + 16*i, vk[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int D, int q_step, int k_step>
|
||||
inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
|
||||
const float * q, const char * k, const char * v, const char * mask,
|
||||
float scale, float * qkv) {
|
||||
if (nq1 >= q_step) {
|
||||
FlashAttn<D, q_step, k_step> fa(scale);
|
||||
fa.compute(nq1, nk1, stride_k, stride_q, stride_m, stride_v, stride_qkv,
|
||||
(const char *)k, q, (const char *)mask, (const char *)v, qkv);
|
||||
} else {
|
||||
FlashAttn<D, 1, k_step> fa(scale);
|
||||
fa.compute(nq1, nk1, stride_k, stride_q, stride_m, stride_v, stride_qkv,
|
||||
(const char *)k, q, (const char *)mask, (const char *)v, qkv);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
bool iqk_flash_helper_3(int ne00, // attention head size
|
||||
int nq1, // number of columns in q
|
||||
int nk1, // number of rows in k
|
||||
int stride_q, // distance between q columns in bytes
|
||||
int stride_k, // distance between k rows in bytes
|
||||
int stride_v, // distance between v rows (in bytes)
|
||||
int stride_m, // distance between rows in mask (in bytes)
|
||||
int stride_qkv, // distance between qkv rows in bytes
|
||||
const float * q, // q matrix
|
||||
const void * k, // k matrix. Assumed to be fp16, ne00 x nk elements
|
||||
const void * v, // v matrix. Assumed to be fp16, ne00 x nk elements
|
||||
const void * mask, // mask. If not null, assumed to be fp16. nq*nk elements
|
||||
float scale, // the scale in softmax(scale*(k*q))
|
||||
float * qkv) { // the qkv result
|
||||
if (!mask) return false; // we assume the mask is not null in the implementation
|
||||
if (nk1%32 != 0) {
|
||||
const char * mp = (const char *)mask;
|
||||
for (int iq1 = 0; iq1 < nq1; ++iq1) {
|
||||
iqk_flash_helper_2(false, ne00, nk1, stride_k, stride_v,
|
||||
q, k, v, (const void *)mp,
|
||||
scale, 1.0f, nullptr, qkv);
|
||||
q += stride_q;
|
||||
qkv += stride_qkv;
|
||||
mp += stride_m;
|
||||
}
|
||||
q += 16*stride_q;
|
||||
return true;
|
||||
}
|
||||
|
||||
stride_q /= sizeof(float); // q stride as float
|
||||
|
||||
auto ck = (const char *)k;
|
||||
auto cv = (const char *)v;
|
||||
auto cm = (const char *)mask;
|
||||
|
||||
switch (ne00) {
|
||||
case 64:
|
||||
iqk_flash_helper_T< 64, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, qkv); break;
|
||||
case 80:
|
||||
iqk_flash_helper_T< 80, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, qkv); break;
|
||||
case 96:
|
||||
iqk_flash_helper_T< 96, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, qkv); break;
|
||||
case 112:
|
||||
iqk_flash_helper_T<112, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, qkv); break;
|
||||
case 128:
|
||||
iqk_flash_helper_T<128, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, qkv); break;
|
||||
case 256:
|
||||
iqk_flash_helper_T<256, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, qkv); break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
|
||||
}
|
||||
|
||||
//bool iqk_flash_attention_noalibi_f16(int ith, int nth,
|
||||
|
||||
@@ -58,7 +58,7 @@ void iqk_flash_helper_2(bool is_alibi,
|
||||
float * qk,
|
||||
float * qkv); // softmax(k*q) - k elements
|
||||
|
||||
void iqk_flash_helper_3(int ne00,
|
||||
bool iqk_flash_helper_3(int ne00,
|
||||
int nq, // number of elements in q
|
||||
int nk, // number of rows in k
|
||||
int stride_q,
|
||||
|
||||
Reference in New Issue
Block a user