mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 03:11:51 +00:00
Zen4 Flash Attention: add q4_1
This commit is contained in:
@@ -6161,13 +6161,40 @@ struct HelperQ40 final : public BaseHelper<step> {
|
||||
|
||||
inline void load(int l1, __m512 * vk) const {
|
||||
auto dl = (const block_q4_0 *)Base::lblock(l1);
|
||||
for (int i = 0; i < D/32; ++i) {
|
||||
auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl[i].d));
|
||||
auto q = _mm_loadu_si128((const __m128i *)dl[i].qs);
|
||||
auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8);
|
||||
auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8);
|
||||
vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
|
||||
vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)));
|
||||
if constexpr (D >= 128) {
|
||||
ggml_half aux[4];
|
||||
__m512 vd[4];
|
||||
for (int ib = 0; ib < D/128; ++ib) {
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
auto& b4 = dl[4*ib+i];
|
||||
aux[i] = b4.d;
|
||||
auto q = _mm_loadu_si128((const __m128i *)b4.qs);
|
||||
auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8);
|
||||
auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8);
|
||||
vk[8*ib+2*i+0] = _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql));
|
||||
vk[8*ib+2*i+1] = _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh));
|
||||
}
|
||||
auto scales4 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)aux));
|
||||
auto scales8 = _mm256_insertf128_ps(_mm256_castps128_ps256(scales4), scales4, 1);
|
||||
auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales8), scales8, 1);
|
||||
vd[0] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(0, 0, 0, 0));
|
||||
vd[1] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(1, 1, 1, 1));
|
||||
vd[2] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(2, 2, 2, 2));
|
||||
vd[3] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(3, 3, 3, 3));
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
vk[8*ib+2*i+0] = _mm512_mul_ps(vd[i], vk[8*ib+2*i+0]);
|
||||
vk[8*ib+2*i+1] = _mm512_mul_ps(vd[i], vk[8*ib+2*i+1]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < D/32; ++i) {
|
||||
auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl[i].d));
|
||||
auto q = _mm_loadu_si128((const __m128i *)dl[i].qs);
|
||||
auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8);
|
||||
auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8);
|
||||
vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
|
||||
vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6190,69 +6217,45 @@ struct HelperQ40 final : public BaseHelper<step> {
|
||||
const __m128i m8 = _mm_set1_epi8(-8);
|
||||
};
|
||||
|
||||
//template <int D, int k_step>
|
||||
//struct KHelperF16 {
|
||||
// KHelperF16(const char * k, int stride_k) : k(k), kb(k), stride_k(stride_k) {}
|
||||
//
|
||||
// inline void set_block(int k1) { kb = k + k1*k_step*stride_k; }
|
||||
// inline void reset_block() { kb = k; }
|
||||
// inline void next_block() { kb += k_step*stride_k; }
|
||||
//
|
||||
// inline void load(int l1, __m512 * vk) const {
|
||||
// auto kr = (const ggml_half *)(kb + l1*stride_k);
|
||||
// for (int i = 0; i < D/16; ++i) vk[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i));
|
||||
// }
|
||||
//
|
||||
// inline void load(int l1, int i, __m512& v1, __m512& v2) const {
|
||||
// auto kr = (const ggml_half *)(kb + l1*stride_k);
|
||||
// v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i + 0));
|
||||
// v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i + 1));
|
||||
// }
|
||||
//
|
||||
// inline void load_2(int l1, __m512 * vk) const {
|
||||
// load(l1+0, vk+0);
|
||||
// load(l1+1, vk+D/16);
|
||||
// }
|
||||
//
|
||||
// const char * k;
|
||||
// const char * kb;
|
||||
// int stride_k;
|
||||
//};
|
||||
//
|
||||
//template <int D, int k_step>
|
||||
//struct KHelperQ80 {
|
||||
// static_assert(k_step == QK8_0);
|
||||
// KHelperQ80(const char * k, int stride_k) : k(k), kb(k), stride_k(stride_k) {}
|
||||
//
|
||||
// inline void set_block(int k1) { kb = k + k1*k_step*stride_k; }
|
||||
// inline void reset_block() { kb = k; }
|
||||
// inline void next_block() { kb += k_step*stride_k; }
|
||||
//
|
||||
// inline void load(int l1, __m512 * vk) const {
|
||||
// auto kr = (const block_q8_0 *)(kb + l1*stride_k);
|
||||
// for (int i = 0; i < D/32; ++i) {
|
||||
// auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(kr[i].d));
|
||||
// vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)kr[i].qs+0))));
|
||||
// vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)kr[i].qs+1))));
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// inline void load(int l1, int i, __m512& v1, __m512& v2) const {
|
||||
// auto kr = (const block_q8_0 *)(kb + l1*stride_k) + i/2;
|
||||
// auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(kr->d));
|
||||
// v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)kr->qs+0))));
|
||||
// v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)kr->qs+1))));
|
||||
// }
|
||||
//
|
||||
// inline void load_2(int l1, __m512 * vk) const {
|
||||
// load(l1+0, vk+0);
|
||||
// load(l1+1, vk+D/16);
|
||||
// }
|
||||
//
|
||||
// const char * k;
|
||||
// const char * kb;
|
||||
// int stride_k;
|
||||
//};
|
||||
template <int D, int step>
|
||||
struct HelperQ41 final : public BaseHelper<step> {
|
||||
static_assert(step == QK4_1);
|
||||
using Base = BaseHelper<step>;
|
||||
HelperQ41(const char * data, int stride) : Base(data, stride) {}
|
||||
|
||||
|
||||
inline void load(int l1, __m512 * vk) const {
|
||||
auto dl = (const block_q4_1 *)Base::lblock(l1);
|
||||
for (int i = 0; i < D/32; ++i) {
|
||||
auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl[i].d));
|
||||
auto vm = _mm512_set1_ps(GGML_FP16_TO_FP32(dl[i].m));
|
||||
auto q = _mm_loadu_si128((const __m128i *)dl[i].qs);
|
||||
auto ql = _mm_and_si128(q, mask);
|
||||
auto qh = _mm_and_si128(_mm_srli_epi16(q, 4), mask);
|
||||
vk[2*i+0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)), vm);
|
||||
vk[2*i+1] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)), vm);
|
||||
}
|
||||
}
|
||||
|
||||
inline void load(int l1, int i, __m512& v1, __m512& v2) const {
|
||||
auto dl = (const block_q4_1 *)Base::lblock(l1) + i/2;
|
||||
auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl->d));
|
||||
auto vm = _mm512_set1_ps(GGML_FP16_TO_FP32(dl->m));
|
||||
auto q = _mm_loadu_si128((const __m128i *)dl->qs);
|
||||
auto ql = _mm_and_si128(q, mask);
|
||||
auto qh = _mm_and_si128(_mm_srli_epi16(q, 4), mask);
|
||||
v1 = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)), vm);
|
||||
v2 = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)), vm);
|
||||
}
|
||||
|
||||
inline void load_2(int l1, __m512 * vk) const {
|
||||
load(l1+0, vk+0);
|
||||
load(l1+1, vk+D/16);
|
||||
}
|
||||
|
||||
const __m128i mask = _mm_set1_epi8(0xf);
|
||||
};
|
||||
|
||||
|
||||
// Some of the methods in FlashAttn have two identical implementations that only differ by
|
||||
// one version using a loop over the template parameter q_step, while the other using a loop
|
||||
@@ -6665,6 +6668,10 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
|
||||
HelperQ40<D, k_step> vh(v, stride_v);
|
||||
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
|
||||
} break;
|
||||
case GGML_TYPE_Q4_1: {
|
||||
HelperQ41<D, k_step> vh(v, stride_v);
|
||||
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
|
||||
} break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
@@ -6688,28 +6695,17 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
|
||||
HelperQ40<D, k_step> kh(k, stride_k);
|
||||
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
|
||||
} break;
|
||||
case GGML_TYPE_Q4_1: {
|
||||
HelperQ41<D, k_step> kh(k, stride_k);
|
||||
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
|
||||
} break;
|
||||
default: break;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
//template <int D, int q_step, int k_step>
|
||||
//inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
|
||||
// const float * q, const char * k, const char * v, const char * mask,
|
||||
// float scale, float softcap, float * qkv) {
|
||||
//
|
||||
// KHelperF16<D, k_step> kh((const char *)k, stride_k);
|
||||
// KHelperF16<D, k_step> vh((const char *)v, stride_v);
|
||||
// if (nq1 >= q_step) {
|
||||
// FlashAttn<D, q_step, k_step> fa(scale, softcap);
|
||||
// fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
// } else {
|
||||
// FlashAttn<D, 1, k_step> fa(scale, softcap);
|
||||
// fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
// }
|
||||
//}
|
||||
// Returns true when `type` is one of the tensor types this flash-attention code
// accepts: F16, Q8_0, Q4_0 or Q4_1.
// A stale first return that omitted Q4_1 used to make the complete check
// unreachable; only the complete check is kept.
inline bool flash_attn_is_supported(ggml_type type) {
    return type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1;
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user