mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 19:31:48 +00:00
Zen4 Flash Attnetion: it works for q4_0 and q8_0
This commit is contained in:
@@ -16150,8 +16150,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
}
|
||||
|
||||
#if GGML_USE_IQK_MULMAT
|
||||
if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 && v->type == GGML_TYPE_F16 &&
|
||||
mask && mask->type == GGML_TYPE_F16) {
|
||||
if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) {
|
||||
int64_t work_per_slice = D*nek1*neq1;
|
||||
int ntg = 1;
|
||||
if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
|
||||
@@ -16165,7 +16164,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
|
||||
if (counter++ % (nth/ntg) == ith/ntg) {
|
||||
int iq1 = (ith%ntg)*neq1/ntg;
|
||||
if (!iqk_flash_attn_noalibi(D, neq1/ntg, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
|
||||
if (!iqk_flash_attn_noalibi(k->type, v->type,
|
||||
D, neq1/ntg, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
|
||||
(const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]),
|
||||
(const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]),
|
||||
(const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]),
|
||||
|
||||
@@ -6057,23 +6057,147 @@ inline __m256 v_tanh(__m256 x) {
|
||||
|
||||
namespace {
|
||||
|
||||
template <int D, int k_step>
|
||||
struct KHelperF16 {
|
||||
KHelperF16(const char * k, int stride_k) : k(k), kb(k), stride_k(stride_k) {}
|
||||
//template <int D, int k_step>
|
||||
//struct HelperF16 {
|
||||
// HelperF16(const char * data, int stride) : data(data), stride(stride) {}
|
||||
//
|
||||
// inline void set_block(int k1) { block = data + k1*k_step*stride; }
|
||||
// inline void reset_block() { block = data; }
|
||||
// inline void next_block() { block += k_step*stride; }
|
||||
// inline const char * lblock(int l1) const { return block + l1*stride; }
|
||||
//
|
||||
// inline void load(int l1, __m512 * vk) const {
|
||||
// auto dr = lblock(l1);
|
||||
// for (int i = 0; i < D/16; ++i) vk[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dr + i));
|
||||
// }
|
||||
//
|
||||
// inline void load(int l1, int i, __m512& v1, __m512& v2) const {
|
||||
// auto dr = lblock(l1);
|
||||
// v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dr + i + 0));
|
||||
// v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dr + i + 1));
|
||||
// }
|
||||
//
|
||||
// inline void load_2(int l1, __m512 * vk) const {
|
||||
// load(l1+0, vk+0);
|
||||
// load(l1+1, vk+D/16);
|
||||
// }
|
||||
//
|
||||
// const char * data;
|
||||
// const char * block;
|
||||
// int stride;
|
||||
//
|
||||
//};
|
||||
//
|
||||
//template <int D, int k_step>
|
||||
//struct HelperQ80 {
|
||||
// static_assert(k_step == QK8_0);
|
||||
// HelperQ80(const char * data, int stride) : data(data), stride(stride) {}
|
||||
//
|
||||
// inline void set_block(int k1) { block = data + k1*k_step*stride; }
|
||||
// inline void reset_block() { block = data; }
|
||||
// inline void next_block() { block += k_step*stride; }
|
||||
// inline const char * lblock(int l1) const { return block + l1*stride; }
|
||||
//
|
||||
// inline void load(int l1, __m512 * vk) const {
|
||||
// auto dl = (const block_q8_0 *)lblock(l1);
|
||||
// for (int i = 0; i < D/32; ++i) {
|
||||
// auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl[i].d));
|
||||
// vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl[i].qs+0))));
|
||||
// vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl[i].qs+1))));
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// inline void load(int l1, int i, __m512& v1, __m512& v2) const {
|
||||
// auto dl = (const block_q8_0 *)lblock(l1) + i/2;
|
||||
// auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl->d));
|
||||
// v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+0))));
|
||||
// v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+1))));
|
||||
// }
|
||||
//
|
||||
// inline void load_2(int l1, __m512 * vk) const {
|
||||
// load(l1+0, vk+0);
|
||||
// load(l1+1, vk+D/16);
|
||||
// }
|
||||
//
|
||||
// const char * data;
|
||||
// const char * block;
|
||||
// int stride;
|
||||
//};
|
||||
//
|
||||
//template <int D, int k_step>
|
||||
//struct HelperQ40 {
|
||||
// static_assert(k_step == QK4_0);
|
||||
// HelperQ40(const char * data, int stride) : data(data), stride(stride) {}
|
||||
//
|
||||
// inline void set_block(int k1) { block = data + k1*k_step*stride; }
|
||||
// inline void reset_block() { block = data; }
|
||||
// inline void next_block() { block += k_step*stride; }
|
||||
// inline const char * lblock(int l1) const { return block + l1*stride; }
|
||||
//
|
||||
// inline void load(int l1, __m512 * vk) const {
|
||||
// auto dl = (const block_q4_0 *)lblock(l1);
|
||||
// for (int i = 0; i < D/32; ++i) {
|
||||
// auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl[i].d));
|
||||
// auto q = _mm_loadu_si128((const __m128i *)dl[i].qs);
|
||||
// auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8);
|
||||
// auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8);
|
||||
// vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
|
||||
// vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)));
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// inline void load(int l1, int i, __m512& v1, __m512& v2) const {
|
||||
// auto dl = (const block_q4_0 *)lblock(l1) + i/2;
|
||||
// auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl->d));
|
||||
// auto q = _mm_loadu_si128((const __m128i *)dl->qs);
|
||||
// auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8);
|
||||
// auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8);
|
||||
// v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
|
||||
// v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)));
|
||||
// }
|
||||
//
|
||||
// inline void load_2(int l1, __m512 * vk) const {
|
||||
// load(l1+0, vk+0);
|
||||
// load(l1+1, vk+D/16);
|
||||
// }
|
||||
//
|
||||
// const __m128i mask = _mm_set1_epi8(0xf);
|
||||
// const __m128i m8 = _mm_set1_epi8(-8);
|
||||
//
|
||||
// const char * data;
|
||||
// const char * block;
|
||||
// int stride;
|
||||
//};
|
||||
|
||||
inline void set_block(int k1) { kb = k + k1*k_step*stride_k; }
|
||||
inline void reset_block() { kb = k; }
|
||||
inline void next_block() { kb += k_step*stride_k; }
|
||||
template <int k_step>
|
||||
struct BaseHelper {
|
||||
BaseHelper(const char * data, int stride) : data(data), block(data), stride(stride) {}
|
||||
|
||||
inline void set_block(int k1) { block = data + k1*k_step*stride; }
|
||||
inline void reset_block() { block = data; }
|
||||
inline void next_block() { block += k_step*stride; }
|
||||
inline const char * lblock(int l1) const { return block + l1*stride; }
|
||||
|
||||
const char * data;
|
||||
const char * block;
|
||||
int stride;
|
||||
|
||||
};
|
||||
|
||||
template <int D, int step>
|
||||
struct HelperF16 final : public BaseHelper<step> {
|
||||
using Base = BaseHelper<step>;
|
||||
HelperF16(const char * data, int stride) : Base(data, stride) {}
|
||||
|
||||
inline void load(int l1, __m512 * vk) const {
|
||||
auto kr = (const ggml_half *)(kb + l1*stride_k);
|
||||
for (int i = 0; i < D/16; ++i) vk[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i));
|
||||
auto dr = Base::lblock(l1);
|
||||
for (int i = 0; i < D/16; ++i) vk[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dr + i));
|
||||
}
|
||||
|
||||
inline void load(int l1, int i, __m512& v1, __m512& v2) const {
|
||||
auto kr = (const ggml_half *)(kb + l1*stride_k);
|
||||
v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i + 0));
|
||||
v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i + 1));
|
||||
auto dr = (const ggml_half *)Base::lblock(l1);
|
||||
v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dr + i + 0));
|
||||
v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dr + i + 1));
|
||||
}
|
||||
|
||||
inline void load_2(int l1, __m512 * vk) const {
|
||||
@@ -6081,11 +6205,149 @@ struct KHelperF16 {
|
||||
load(l1+1, vk+D/16);
|
||||
}
|
||||
|
||||
const char * k;
|
||||
const char * kb;
|
||||
int stride_k;
|
||||
};
|
||||
|
||||
template <int D, int step>
|
||||
struct HelperQ80 final : public BaseHelper<step> {
|
||||
static_assert(step == QK8_0);
|
||||
using Base = BaseHelper<step>;
|
||||
HelperQ80(const char * data, int stride) : Base(data, stride) {}
|
||||
|
||||
//inline void load(int l1, __m512 * vk) const {
|
||||
// auto dl = (const block_q8_0 *)Base::lblock(l1);
|
||||
// for (int i = 0; i < D/32; ++i) {
|
||||
// auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl[i].d));
|
||||
// vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl[i].qs+0))));
|
||||
// vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl[i].qs+1))));
|
||||
// }
|
||||
//}
|
||||
inline void load(int l1, __m512 * vk) const {
|
||||
auto dl = (const block_q8_0_x4 *)Base::lblock(l1);
|
||||
for (int i = 0; i < D/32; ++i) {
|
||||
const auto& b8 = dl[i/4];
|
||||
int ii = i%4;
|
||||
auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(b8.d[ii]));
|
||||
vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*ii+0))));
|
||||
vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*ii+1))));
|
||||
}
|
||||
}
|
||||
|
||||
inline void load(int l1, int i, __m512& v1, __m512& v2) const {
|
||||
auto dl = (const block_q8_0_x4 *)Base::lblock(l1) + i/8;
|
||||
int ii = (i/2)%4;
|
||||
auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl->d[ii]));
|
||||
v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+2*ii+0))));
|
||||
v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+2*ii+1))));
|
||||
}
|
||||
|
||||
inline void load_2(int l1, __m512 * vk) const {
|
||||
load(l1+0, vk+0);
|
||||
load(l1+1, vk+D/16);
|
||||
}
|
||||
};
|
||||
|
||||
template <int D, int step>
|
||||
struct HelperQ40 final : public BaseHelper<step> {
|
||||
static_assert(step == QK4_0);
|
||||
using Base = BaseHelper<step>;
|
||||
HelperQ40(const char * data, int stride) : Base(data, stride) {}
|
||||
|
||||
|
||||
inline void load(int l1, __m512 * vk) const {
|
||||
auto dl = (const block_q4_0 *)Base::lblock(l1);
|
||||
for (int i = 0; i < D/32; ++i) {
|
||||
auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl[i].d));
|
||||
auto q = _mm_loadu_si128((const __m128i *)dl[i].qs);
|
||||
auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8);
|
||||
auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8);
|
||||
vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
|
||||
vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)));
|
||||
}
|
||||
}
|
||||
|
||||
inline void load(int l1, int i, __m512& v1, __m512& v2) const {
|
||||
auto dl = (const block_q4_0 *)Base::lblock(l1) + i/2;
|
||||
auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl->d));
|
||||
auto q = _mm_loadu_si128((const __m128i *)dl->qs);
|
||||
auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8);
|
||||
auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8);
|
||||
v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
|
||||
v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)));
|
||||
}
|
||||
|
||||
inline void load_2(int l1, __m512 * vk) const {
|
||||
load(l1+0, vk+0);
|
||||
load(l1+1, vk+D/16);
|
||||
}
|
||||
|
||||
const __m128i mask = _mm_set1_epi8(0xf);
|
||||
const __m128i m8 = _mm_set1_epi8(-8);
|
||||
};
|
||||
|
||||
//template <int D, int k_step>
|
||||
//struct KHelperF16 {
|
||||
// KHelperF16(const char * k, int stride_k) : k(k), kb(k), stride_k(stride_k) {}
|
||||
//
|
||||
// inline void set_block(int k1) { kb = k + k1*k_step*stride_k; }
|
||||
// inline void reset_block() { kb = k; }
|
||||
// inline void next_block() { kb += k_step*stride_k; }
|
||||
//
|
||||
// inline void load(int l1, __m512 * vk) const {
|
||||
// auto kr = (const ggml_half *)(kb + l1*stride_k);
|
||||
// for (int i = 0; i < D/16; ++i) vk[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i));
|
||||
// }
|
||||
//
|
||||
// inline void load(int l1, int i, __m512& v1, __m512& v2) const {
|
||||
// auto kr = (const ggml_half *)(kb + l1*stride_k);
|
||||
// v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i + 0));
|
||||
// v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i + 1));
|
||||
// }
|
||||
//
|
||||
// inline void load_2(int l1, __m512 * vk) const {
|
||||
// load(l1+0, vk+0);
|
||||
// load(l1+1, vk+D/16);
|
||||
// }
|
||||
//
|
||||
// const char * k;
|
||||
// const char * kb;
|
||||
// int stride_k;
|
||||
//};
|
||||
//
|
||||
//template <int D, int k_step>
|
||||
//struct KHelperQ80 {
|
||||
// static_assert(k_step == QK8_0);
|
||||
// KHelperQ80(const char * k, int stride_k) : k(k), kb(k), stride_k(stride_k) {}
|
||||
//
|
||||
// inline void set_block(int k1) { kb = k + k1*k_step*stride_k; }
|
||||
// inline void reset_block() { kb = k; }
|
||||
// inline void next_block() { kb += k_step*stride_k; }
|
||||
//
|
||||
// inline void load(int l1, __m512 * vk) const {
|
||||
// auto kr = (const block_q8_0 *)(kb + l1*stride_k);
|
||||
// for (int i = 0; i < D/32; ++i) {
|
||||
// auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(kr[i].d));
|
||||
// vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)kr[i].qs+0))));
|
||||
// vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)kr[i].qs+1))));
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// inline void load(int l1, int i, __m512& v1, __m512& v2) const {
|
||||
// auto kr = (const block_q8_0 *)(kb + l1*stride_k) + i/2;
|
||||
// auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(kr->d));
|
||||
// v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)kr->qs+0))));
|
||||
// v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)kr->qs+1))));
|
||||
// }
|
||||
//
|
||||
// inline void load_2(int l1, __m512 * vk) const {
|
||||
// load(l1+0, vk+0);
|
||||
// load(l1+1, vk+D/16);
|
||||
// }
|
||||
//
|
||||
// const char * k;
|
||||
// const char * kb;
|
||||
// int stride_k;
|
||||
//};
|
||||
|
||||
// Some of the methods in FlashAttn have two identical implementations that only differ by
|
||||
// one version using a loop over the template parameter q_step, while the other using a loop
|
||||
// over an input parameter nq (these are loops over the rows of q^T). I dislike this a lot,
|
||||
@@ -6358,10 +6620,48 @@ struct FlashAttn {
|
||||
}
|
||||
}
|
||||
|
||||
void compute(int nq1, int nk1, int stride_k, int stride_q, int stride_m, int stride_v, int stride_qkv,
|
||||
const char * k, const float * q, const char * mask, const char * v, float * qkv) {
|
||||
KHelperF16<D, k_step> kh(k, stride_k);
|
||||
KHelperF16<D, k_step> vh(v, stride_v);
|
||||
//void compute(int nq1, int nk1, int stride_k, int stride_q, int stride_m, int stride_v, int stride_qkv,
|
||||
// const char * k, const float * q, const char * mask, const char * v, float * qkv) {
|
||||
// KHelperF16<D, k_step> kh(k, stride_k);
|
||||
// KHelperF16<D, k_step> vh(v, stride_v);
|
||||
// for (int i1 = 0; i1 < nq1/q_step; ++i1) {
|
||||
// init_qstep();
|
||||
// kh.reset_block();
|
||||
// vh.reset_block();
|
||||
// auto mr = mask;
|
||||
// for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
||||
// multiply_mask_kq(kh, stride_q, stride_m, q, mr);
|
||||
// accumulate_qkv(vh);
|
||||
// kh.next_block();
|
||||
// vh.next_block();
|
||||
// mr += k_step*sizeof(ggml_half);
|
||||
// }
|
||||
// normalize_and_store(stride_qkv, qkv);
|
||||
|
||||
// q += q_step*stride_q;
|
||||
// mask += q_step*stride_m;
|
||||
// qkv += q_step*stride_qkv;
|
||||
// }
|
||||
// int n_left = nq1 - q_step*(nq1/q_step);
|
||||
// if (n_left > 0) {
|
||||
// init_qstep();
|
||||
// kh.reset_block();
|
||||
// vh.reset_block();
|
||||
// auto mr = mask;
|
||||
// for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
||||
// multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr);
|
||||
// accumulate_qkv(n_left, vh);
|
||||
// kh.next_block();
|
||||
// vh.next_block();
|
||||
// mr += k_step*sizeof(ggml_half);
|
||||
// }
|
||||
// normalize_and_store(n_left, stride_qkv, qkv);
|
||||
// }
|
||||
//}
|
||||
|
||||
template <typename KHelper, typename VHelper>
|
||||
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
|
||||
const float * q, const char * mask, float * qkv) {
|
||||
for (int i1 = 0; i1 < nq1/q_step; ++i1) {
|
||||
init_qstep();
|
||||
kh.reset_block();
|
||||
@@ -6427,24 +6727,89 @@ struct FlashAttn {
|
||||
}
|
||||
};
|
||||
|
||||
template <int D, int q_step, int k_step>
|
||||
inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
|
||||
const float * q, const char * k, const char * v, const char * mask,
|
||||
float scale, float softcap, float * qkv) {
|
||||
template <int D, int q_step, int k_step, typename KHelper, typename VHelper>
|
||||
inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
|
||||
const float * q, const char * mask, float scale, float softcap, float * qkv) {
|
||||
|
||||
if (nq1 >= q_step) {
|
||||
FlashAttn<D, q_step, k_step> fa(scale, softcap);
|
||||
fa.compute(nq1, nk1, stride_k, stride_q, stride_m, stride_v, stride_qkv,
|
||||
(const char *)k, q, (const char *)mask, (const char *)v, qkv);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
} else {
|
||||
FlashAttn<D, 1, k_step> fa(scale, softcap);
|
||||
fa.compute(nq1, nk1, stride_k, stride_q, stride_m, stride_v, stride_qkv,
|
||||
(const char *)k, q, (const char *)mask, (const char *)v, qkv);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
}
|
||||
}
|
||||
|
||||
template <int D, int q_step, int k_step, typename KHelper>
|
||||
inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
|
||||
int nq1, int nk1, int stride_q, int stride_v, int stride_m, int stride_qkv,
|
||||
const float * q, const char * v, const char * mask,
|
||||
float scale, float softcap, float * qkv) {
|
||||
|
||||
switch (type_v) {
|
||||
case GGML_TYPE_F16: {
|
||||
HelperF16<D, k_step> vh(v, stride_v);
|
||||
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
|
||||
} break;
|
||||
case GGML_TYPE_Q8_0: {
|
||||
HelperQ80<D, k_step> vh(v, stride_v);
|
||||
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0: {
|
||||
HelperQ40<D, k_step> vh(v, stride_v);
|
||||
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
|
||||
} break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
|
||||
template <int D, int q_step, int k_step>
|
||||
inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
|
||||
int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
|
||||
const float * q, const char * k, const char * v, const char * mask,
|
||||
float scale, float softcap, float * qkv) {
|
||||
|
||||
switch (type_k) {
|
||||
case GGML_TYPE_F16: {
|
||||
HelperF16<D, k_step> kh(k, stride_k);
|
||||
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
|
||||
} break;
|
||||
case GGML_TYPE_Q8_0: {
|
||||
HelperQ80<D, k_step> kh(k, stride_k);
|
||||
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0: {
|
||||
HelperQ40<D, k_step> kh(k, stride_k);
|
||||
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
|
||||
} break;
|
||||
default: break;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
//template <int D, int q_step, int k_step>
|
||||
//inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
|
||||
// const float * q, const char * k, const char * v, const char * mask,
|
||||
// float scale, float softcap, float * qkv) {
|
||||
//
|
||||
// KHelperF16<D, k_step> kh((const char *)k, stride_k);
|
||||
// KHelperF16<D, k_step> vh((const char *)v, stride_v);
|
||||
// if (nq1 >= q_step) {
|
||||
// FlashAttn<D, q_step, k_step> fa(scale, softcap);
|
||||
// fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
// } else {
|
||||
// FlashAttn<D, 1, k_step> fa(scale, softcap);
|
||||
// fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
// }
|
||||
//}
|
||||
inline bool flash_attn_is_supported(ggml_type type) {
|
||||
return type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0;
|
||||
}
|
||||
}
|
||||
|
||||
bool iqk_flash_attn_noalibi(int D, // head size
|
||||
bool iqk_flash_attn_noalibi(int int_type_k, // type of k
|
||||
int int_type_v, // type of v
|
||||
int D, // head size
|
||||
int nq1, // number of columns in q
|
||||
int nk1, // number of rows in k
|
||||
int stride_q, // distance between q columns in bytes
|
||||
@@ -6460,6 +6825,9 @@ bool iqk_flash_attn_noalibi(int D, // head size
|
||||
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
|
||||
float * qkv) { // v*softmax(scale*(k*q))
|
||||
|
||||
auto type_k = ggml_type(int_type_k);
|
||||
auto type_v = ggml_type(int_type_v);
|
||||
if (!flash_attn_is_supported(type_k) || !flash_attn_is_supported(type_v)) return false;
|
||||
if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32
|
||||
|
||||
auto ck = (const char *)k;
|
||||
@@ -6470,19 +6838,19 @@ bool iqk_flash_attn_noalibi(int D, // head size
|
||||
|
||||
switch (D) {
|
||||
case 64:
|
||||
iqk_flash_helper_T< 64, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
iqk_flash_helper_T< 64, 4, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
// Disable until we fix accumulate_qkv for odd D/16
|
||||
//case 80:
|
||||
// iqk_flash_helper_T< 80, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
case 96:
|
||||
iqk_flash_helper_T< 96, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
iqk_flash_helper_T< 96, 4, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
// Disable until we fix accumulate_qkv for odd D/16
|
||||
//case 112:
|
||||
// iqk_flash_helper_T<112, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
case 128:
|
||||
iqk_flash_helper_T<128, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
iqk_flash_helper_T<128, 8, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
case 256:
|
||||
iqk_flash_helper_T<256, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
iqk_flash_helper_T<256, 8, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -6492,7 +6860,9 @@ bool iqk_flash_attn_noalibi(int D, // head size
|
||||
|
||||
#else
|
||||
// TODO
|
||||
bool iqk_flash_attn_noalibi([[maybe_unused]] int D, // head size
|
||||
bool iqk_flash_attn_noalibi([[maybe_unused]] int int_type_k, // type of k
|
||||
[[maybe_unused]] int int_type_v, // type of v
|
||||
[[maybe_unused]] int D, // head size
|
||||
[[maybe_unused]] int nq, // number of columns in q
|
||||
[[maybe_unused]] int nk, // number of rows in k
|
||||
[[maybe_unused]] int stride_q, // distance between q columns in bytes
|
||||
@@ -6523,7 +6893,9 @@ bool iqk_mul_mat_moe(long, long, long, int, int, const void *, long, int, const
|
||||
return false;
|
||||
}
|
||||
|
||||
bool iqk_flash_attn_noalibi([[maybe_unused]] int D, // head size
|
||||
bool iqk_flash_attn_noalibi([[maybe_unused]] int int_type_k, // type of k
|
||||
[[maybe_unused]] int int_type_v, // type of v
|
||||
[[maybe_unused]] int D, // head size
|
||||
[[maybe_unused]] int nq, // number of columns in q
|
||||
[[maybe_unused]] int nk, // number of rows in k
|
||||
[[maybe_unused]] int stride_q, // distance between q columns in bytes
|
||||
|
||||
@@ -21,7 +21,9 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
|
||||
int typeB, const void * B, long strideB,
|
||||
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth);
|
||||
|
||||
bool iqk_flash_attn_noalibi(int D, // head size
|
||||
bool iqk_flash_attn_noalibi(int type_k, // type of k
|
||||
int type_v, // type of v
|
||||
int D, // head size
|
||||
int nq, // number of columns in q
|
||||
int nk, // number of rows in k
|
||||
int stride_q, // distance between q columns in bytes
|
||||
|
||||
Reference in New Issue
Block a user