mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-05 19:40:19 +00:00
Zen4 Flash Attnetion: WIP generalize to other types
Now loading of data from K and V is done via a template parameter, so this should make it easy to generalize to typ[es other than F16 for the K and V cache.
This commit is contained in:
@@ -6057,6 +6057,35 @@ inline __m256 v_tanh(__m256 x) {
|
||||
|
||||
namespace {
|
||||
|
||||
template <int D, int k_step>
|
||||
struct KHelperF16 {
|
||||
KHelperF16(const char * k, int stride_k) : k(k), kb(k), stride_k(stride_k) {}
|
||||
|
||||
inline void set_block(int k1) { kb = k + k1*k_step*stride_k; }
|
||||
inline void reset_block() { kb = k; }
|
||||
inline void next_block() { kb += k_step*stride_k; }
|
||||
|
||||
inline void load(int l1, __m512 * vk) const {
|
||||
auto kr = (const ggml_half *)(kb + l1*stride_k);
|
||||
for (int i = 0; i < D/16; ++i) vk[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i));
|
||||
}
|
||||
|
||||
inline void load(int l1, int i, __m512& v1, __m512& v2) const {
|
||||
auto kr = (const ggml_half *)(kb + l1*stride_k);
|
||||
v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i + 0));
|
||||
v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i + 1));
|
||||
}
|
||||
|
||||
inline void load_2(int l1, __m512 * vk) const {
|
||||
load(l1+0, vk+0);
|
||||
load(l1+1, vk+D/16);
|
||||
}
|
||||
|
||||
const char * k;
|
||||
const char * kb;
|
||||
int stride_k;
|
||||
};
|
||||
|
||||
// Some of the methods in FlashAttn have two identical implementations that only differ by
|
||||
// one version using a loop over the template parameter q_step, while the other using a loop
|
||||
// over an input parameter nq (these are loops over the rows of q^T). I dislike this a lot,
|
||||
@@ -6098,13 +6127,13 @@ struct FlashAttn {
|
||||
auto qr = q + m1*stride_q;
|
||||
for (int i = 0; i < D/16; ++i) qv[i] = _mm512_loadu_ps(qr + 16*i);
|
||||
if (mp[l1+0] != h_inf) {
|
||||
auto vsum = _mm512_mul_ps(vk[0], qv[0]);
|
||||
for (int i = 1; i < D/16; ++i) vsum = _mm512_fmadd_ps(vk[i], qv[i], vsum);
|
||||
auto vsum = _mm512_setzero_ps();
|
||||
for (int i = 0; i < D/16; ++i) vsum = _mm512_fmadd_ps(vk[i], qv[i], vsum);
|
||||
cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum);
|
||||
}
|
||||
if (mp[l1+1] != h_inf) {
|
||||
auto vsum = _mm512_mul_ps(vk[D/16], qv[0]);
|
||||
for (int i = 1; i < D/16; ++i) vsum = _mm512_fmadd_ps(vk[i+D/16], qv[i], vsum);
|
||||
auto vsum = _mm512_setzero_ps();
|
||||
for (int i = 0; i < D/16; ++i) vsum = _mm512_fmadd_ps(vk[i+D/16], qv[i], vsum);
|
||||
cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum);
|
||||
}
|
||||
}
|
||||
@@ -6119,8 +6148,8 @@ struct FlashAttn {
|
||||
return;
|
||||
}
|
||||
auto qr = q + m1*stride_q;
|
||||
auto vsum = _mm512_mul_ps(vk[0], _mm512_loadu_ps(qr));
|
||||
for (int i = 1; i < D/16; ++i) {
|
||||
auto vsum = _mm512_setzero_ps();
|
||||
for (int i = 0; i < D/16; ++i) {
|
||||
vsum = _mm512_fmadd_ps(vk[i], _mm512_loadu_ps(qr + 16*i), vsum);
|
||||
}
|
||||
cache[k_step*m1 + l1] = _mm512_reduce_add_ps(vsum);
|
||||
@@ -6191,78 +6220,70 @@ struct FlashAttn {
|
||||
}
|
||||
}
|
||||
|
||||
template <bool small = is_small_head, class = std::enable_if<small>>
|
||||
inline void mult_mask_kq(int stride_k, int stride_q, int stride_m,
|
||||
const char * k, const float * q, const char * mask) {
|
||||
template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>>
|
||||
inline void mult_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask) {
|
||||
__m512 qv[D/16];
|
||||
for (int l1 = 0; l1 < k_step; l1 += 2) {
|
||||
auto kr1 = (const ggml_half *)(k + (l1 + 0)*stride_k);
|
||||
auto kr2 = (const ggml_half *)(k + (l1 + 1)*stride_k);
|
||||
for (int i = 0; i < D/16; ++i) vk[i+ 0] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr1 + i));
|
||||
for (int i = 0; i < D/16; ++i) vk[i+D/16] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr2 + i));
|
||||
kh.load_2(l1, vk);
|
||||
for (int m1 = 0; m1 < q_step; ++m1) {
|
||||
mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <bool small = is_small_head, class = std::enable_if<!small>>
|
||||
inline void mult_mask_kq_l(int stride_k, int stride_q, int stride_m,
|
||||
const char * k, const float * q, const char * mask) {
|
||||
template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>>
|
||||
inline void mult_mask_kq_l(const KHelper& kh, int stride_q, int stride_m,
|
||||
const float * q, const char * mask) {
|
||||
for (int l1 = 0; l1 < k_step; ++l1) {
|
||||
auto kr = (const ggml_half *)(k + l1*stride_k);
|
||||
for (int i = 0; i < D/16; ++i) vk[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i));
|
||||
kh.load(l1, vk);
|
||||
for (int m1 = 0; m1 < q_step; ++m1) {
|
||||
mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <bool small = is_small_head, class = std::enable_if<small>>
|
||||
inline void mult_mask_kq(int nq, int stride_k, int stride_q, int stride_m,
|
||||
const char * k, const float * q, const char * mask) {
|
||||
template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>>
|
||||
inline void mult_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask) {
|
||||
__m512 qv[D/16];
|
||||
for (int l1 = 0; l1 < k_step; l1 += 2) {
|
||||
auto kr1 = (const ggml_half *)(k + (l1 + 0)*stride_k);
|
||||
auto kr2 = (const ggml_half *)(k + (l1 + 1)*stride_k);
|
||||
for (int i = 0; i < D/16; ++i) vk[i+ 0] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr1 + i));
|
||||
for (int i = 0; i < D/16; ++i) vk[i+D/16] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr2 + i));
|
||||
kh.load_2(l1, vk);
|
||||
for (int m1 = 0; m1 < nq; ++m1) {
|
||||
mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <bool small = is_small_head, class = std::enable_if<!small>>
|
||||
inline void mult_mask_kq_l(int nq, int stride_k, int stride_q, int stride_m,
|
||||
const char * k, const float * q, const char * mask) {
|
||||
template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>>
|
||||
inline void mult_mask_kq_l(int nq, const KHelper& kh, int stride_q, int stride_m,
|
||||
const float * q, const char * mask) {
|
||||
for (int l1 = 0; l1 < k_step; ++l1) {
|
||||
auto kr = (const ggml_half *)(k + l1*stride_k);
|
||||
for (int i = 0; i < D/16; ++i) vk[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i));
|
||||
kh.load(l1, vk);
|
||||
for (int m1 = 0; m1 < nq; ++m1) {
|
||||
mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void multiply_mask_kq(int stride_k, int stride_q, int stride_m, const char * k, const float * q, const char * mask) {
|
||||
template <typename KHelper>
|
||||
inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask) {
|
||||
if constexpr (is_small_head) {
|
||||
mult_mask_kq(stride_k, stride_q, stride_m, k, q, mask);
|
||||
mult_mask_kq(kh, stride_q, stride_m, q, mask);
|
||||
}
|
||||
else {
|
||||
mult_mask_kq_l(stride_k, stride_q, stride_m, k, q, mask);
|
||||
mult_mask_kq_l(kh, stride_q, stride_m, q, mask);
|
||||
}
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
update_M_S(j);
|
||||
}
|
||||
}
|
||||
|
||||
inline void multiply_mask_kq(int nq, int stride_k, int stride_q, int stride_m, const char * k, const float * q, const char * mask) {
|
||||
template <typename KHelper>
|
||||
inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask) {
|
||||
if constexpr (is_small_head) {
|
||||
mult_mask_kq(nq, stride_k, stride_q, stride_m, k, q, mask);
|
||||
mult_mask_kq(nq, kh, stride_q, stride_m, q, mask);
|
||||
}
|
||||
else {
|
||||
mult_mask_kq_l(nq, stride_k, stride_q, stride_m, k, q, mask);
|
||||
mult_mask_kq_l(nq, kh, stride_q, stride_m, q, mask);
|
||||
}
|
||||
for (int j = 0; j < nq; ++j) {
|
||||
update_M_S(j);
|
||||
@@ -6271,7 +6292,8 @@ struct FlashAttn {
|
||||
|
||||
// This fails for head sizes of 80 and 112 as D/16 is odd, so we cannot do steps of 2
|
||||
// Hence, for now, we will not handle head sizes of 80 and 112
|
||||
inline void accumulate_qkv(int stride_v, const char * v) {
|
||||
template <typename VHelper>
|
||||
inline void accumulate_qkv(const VHelper& vh) {
|
||||
for (int i = 0; i < D/16; i += 2) {
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
if (need_scaling[j] == 2) {
|
||||
@@ -6286,10 +6308,9 @@ struct FlashAttn {
|
||||
}
|
||||
}
|
||||
}
|
||||
__m512 v1, v2;
|
||||
for (int l1 = 0; l1 < k_step; ++l1) {
|
||||
auto vr = (const ggml_half *)(v + l1*stride_v);
|
||||
auto v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+0));
|
||||
auto v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+1));
|
||||
vh.load(l1, i, v1, v2);
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto vs = _mm512_set1_ps(cache[k_step*j + l1]);
|
||||
vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]);
|
||||
@@ -6304,8 +6325,8 @@ struct FlashAttn {
|
||||
}
|
||||
}
|
||||
|
||||
template <int Nq = q_step, class = std::enable_if<Nq >= 2>>
|
||||
inline void accumulate_qkv(int nq1, int stride_v, const char * v) {
|
||||
template <typename VHelper, int Nq = q_step, class = std::enable_if<Nq >= 2>>
|
||||
inline void accumulate_qkv(int nq1, const VHelper& vh) {
|
||||
for (int i = 0; i < D/16; i += 2) {
|
||||
for (int j = 0; j < nq1; ++j) {
|
||||
if (need_scaling[j] == 2) {
|
||||
@@ -6320,10 +6341,9 @@ struct FlashAttn {
|
||||
}
|
||||
}
|
||||
}
|
||||
__m512 v1, v2;
|
||||
for (int l1 = 0; l1 < k_step; ++l1) {
|
||||
auto vr = (const ggml_half *)(v + l1*stride_v);
|
||||
auto v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+0));
|
||||
auto v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+1));
|
||||
vh.load(l1, i, v1, v2);
|
||||
for (int j = 0; j < nq1; ++j) {
|
||||
auto vs = _mm512_set1_ps(cache[k_step*j + l1]);
|
||||
vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]);
|
||||
@@ -6340,16 +6360,18 @@ struct FlashAttn {
|
||||
|
||||
void compute(int nq1, int nk1, int stride_k, int stride_q, int stride_m, int stride_v, int stride_qkv,
|
||||
const char * k, const float * q, const char * mask, const char * v, float * qkv) {
|
||||
KHelperF16<D, k_step> kh(k, stride_k);
|
||||
KHelperF16<D, k_step> vh(v, stride_v);
|
||||
for (int i1 = 0; i1 < nq1/q_step; ++i1) {
|
||||
init_qstep();
|
||||
auto kr = k;
|
||||
auto vr = v;
|
||||
kh.reset_block();
|
||||
vh.reset_block();
|
||||
auto mr = mask;
|
||||
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
||||
multiply_mask_kq(stride_k, stride_q, stride_m, kr, q, mr);
|
||||
accumulate_qkv(stride_v, vr);
|
||||
kr += k_step*stride_k;
|
||||
vr += k_step*stride_v;
|
||||
multiply_mask_kq(kh, stride_q, stride_m, q, mr);
|
||||
accumulate_qkv(vh);
|
||||
kh.next_block();
|
||||
vh.next_block();
|
||||
mr += k_step*sizeof(ggml_half);
|
||||
}
|
||||
normalize_and_store(stride_qkv, qkv);
|
||||
@@ -6361,14 +6383,14 @@ struct FlashAttn {
|
||||
int n_left = nq1 - q_step*(nq1/q_step);
|
||||
if (n_left > 0) {
|
||||
init_qstep();
|
||||
auto kr = k;
|
||||
auto vr = v;
|
||||
kh.reset_block();
|
||||
vh.reset_block();
|
||||
auto mr = mask;
|
||||
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
||||
multiply_mask_kq(n_left, stride_k, stride_q, stride_m, kr, q, mr);
|
||||
accumulate_qkv(n_left, stride_v, vr);
|
||||
kr += k_step*stride_k;
|
||||
vr += k_step*stride_v;
|
||||
multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr);
|
||||
accumulate_qkv(n_left, vh);
|
||||
kh.next_block();
|
||||
vh.next_block();
|
||||
mr += k_step*sizeof(ggml_half);
|
||||
}
|
||||
normalize_and_store(n_left, stride_qkv, qkv);
|
||||
|
||||
Reference in New Issue
Block a user