mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 03:11:51 +00:00
Zen4 Flash Attnetion: small q8_0 performance improvement
This commit is contained in:
@@ -6057,118 +6057,6 @@ inline __m256 v_tanh(__m256 x) {
|
||||
|
||||
namespace {
|
||||
|
||||
//template <int D, int k_step>
|
||||
//struct HelperF16 {
|
||||
// HelperF16(const char * data, int stride) : data(data), stride(stride) {}
|
||||
//
|
||||
// inline void set_block(int k1) { block = data + k1*k_step*stride; }
|
||||
// inline void reset_block() { block = data; }
|
||||
// inline void next_block() { block += k_step*stride; }
|
||||
// inline const char * lblock(int l1) const { return block + l1*stride; }
|
||||
//
|
||||
// inline void load(int l1, __m512 * vk) const {
|
||||
// auto dr = lblock(l1);
|
||||
// for (int i = 0; i < D/16; ++i) vk[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dr + i));
|
||||
// }
|
||||
//
|
||||
// inline void load(int l1, int i, __m512& v1, __m512& v2) const {
|
||||
// auto dr = lblock(l1);
|
||||
// v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dr + i + 0));
|
||||
// v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dr + i + 1));
|
||||
// }
|
||||
//
|
||||
// inline void load_2(int l1, __m512 * vk) const {
|
||||
// load(l1+0, vk+0);
|
||||
// load(l1+1, vk+D/16);
|
||||
// }
|
||||
//
|
||||
// const char * data;
|
||||
// const char * block;
|
||||
// int stride;
|
||||
//
|
||||
//};
|
||||
//
|
||||
//template <int D, int k_step>
|
||||
//struct HelperQ80 {
|
||||
// static_assert(k_step == QK8_0);
|
||||
// HelperQ80(const char * data, int stride) : data(data), stride(stride) {}
|
||||
//
|
||||
// inline void set_block(int k1) { block = data + k1*k_step*stride; }
|
||||
// inline void reset_block() { block = data; }
|
||||
// inline void next_block() { block += k_step*stride; }
|
||||
// inline const char * lblock(int l1) const { return block + l1*stride; }
|
||||
//
|
||||
// inline void load(int l1, __m512 * vk) const {
|
||||
// auto dl = (const block_q8_0 *)lblock(l1);
|
||||
// for (int i = 0; i < D/32; ++i) {
|
||||
// auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl[i].d));
|
||||
// vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl[i].qs+0))));
|
||||
// vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl[i].qs+1))));
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// inline void load(int l1, int i, __m512& v1, __m512& v2) const {
|
||||
// auto dl = (const block_q8_0 *)lblock(l1) + i/2;
|
||||
// auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl->d));
|
||||
// v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+0))));
|
||||
// v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+1))));
|
||||
// }
|
||||
//
|
||||
// inline void load_2(int l1, __m512 * vk) const {
|
||||
// load(l1+0, vk+0);
|
||||
// load(l1+1, vk+D/16);
|
||||
// }
|
||||
//
|
||||
// const char * data;
|
||||
// const char * block;
|
||||
// int stride;
|
||||
//};
|
||||
//
|
||||
//template <int D, int k_step>
|
||||
//struct HelperQ40 {
|
||||
// static_assert(k_step == QK4_0);
|
||||
// HelperQ40(const char * data, int stride) : data(data), stride(stride) {}
|
||||
//
|
||||
// inline void set_block(int k1) { block = data + k1*k_step*stride; }
|
||||
// inline void reset_block() { block = data; }
|
||||
// inline void next_block() { block += k_step*stride; }
|
||||
// inline const char * lblock(int l1) const { return block + l1*stride; }
|
||||
//
|
||||
// inline void load(int l1, __m512 * vk) const {
|
||||
// auto dl = (const block_q4_0 *)lblock(l1);
|
||||
// for (int i = 0; i < D/32; ++i) {
|
||||
// auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl[i].d));
|
||||
// auto q = _mm_loadu_si128((const __m128i *)dl[i].qs);
|
||||
// auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8);
|
||||
// auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8);
|
||||
// vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
|
||||
// vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)));
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// inline void load(int l1, int i, __m512& v1, __m512& v2) const {
|
||||
// auto dl = (const block_q4_0 *)lblock(l1) + i/2;
|
||||
// auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl->d));
|
||||
// auto q = _mm_loadu_si128((const __m128i *)dl->qs);
|
||||
// auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8);
|
||||
// auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8);
|
||||
// v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
|
||||
// v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)));
|
||||
// }
|
||||
//
|
||||
// inline void load_2(int l1, __m512 * vk) const {
|
||||
// load(l1+0, vk+0);
|
||||
// load(l1+1, vk+D/16);
|
||||
// }
|
||||
//
|
||||
// const __m128i mask = _mm_set1_epi8(0xf);
|
||||
// const __m128i m8 = _mm_set1_epi8(-8);
|
||||
//
|
||||
// const char * data;
|
||||
// const char * block;
|
||||
// int stride;
|
||||
//};
|
||||
|
||||
template <int k_step>
|
||||
struct BaseHelper {
|
||||
BaseHelper(const char * data, int stride) : data(data), block(data), stride(stride) {}
|
||||
@@ -6223,12 +6111,30 @@ struct HelperQ80 final : public BaseHelper<step> {
|
||||
//}
|
||||
inline void load(int l1, __m512 * vk) const {
|
||||
auto dl = (const block_q8_0_x4 *)Base::lblock(l1);
|
||||
for (int i = 0; i < D/32; ++i) {
|
||||
const auto& b8 = dl[i/4];
|
||||
int ii = i%4;
|
||||
auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(b8.d[ii]));
|
||||
vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*ii+0))));
|
||||
vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*ii+1))));
|
||||
if constexpr (D >= 128) {
|
||||
__m512 vd[4];
|
||||
for (int ib = 0; ib < D/128; ++ib) {
|
||||
const auto& b8 = dl[ib];
|
||||
auto scales4 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)b8.d));
|
||||
auto scales8 = _mm256_insertf128_ps(_mm256_castps128_ps256(scales4), scales4, 1);
|
||||
auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales8), scales8, 1);
|
||||
vd[0] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(0, 0, 0, 0));
|
||||
vd[1] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(1, 1, 1, 1));
|
||||
vd[2] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(2, 2, 2, 2));
|
||||
vd[3] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(3, 3, 3, 3));
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
vk[8*ib+2*i+0] = _mm512_mul_ps(vd[i], _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*i+0))));
|
||||
vk[8*ib+2*i+1] = _mm512_mul_ps(vd[i], _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*i+1))));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < D/32; ++i) {
|
||||
const auto& b8 = dl[i/4];
|
||||
int ii = i%4;
|
||||
auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(b8.d[ii]));
|
||||
vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*ii+0))));
|
||||
vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*ii+1))));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user