Zen4 Flash Attnetion: small q8_0 performance improvement

This commit is contained in:
Iwan Kawrakow
2024-09-03 10:56:20 +03:00
parent a4256004a8
commit e73835de95

View File

@@ -6057,118 +6057,6 @@ inline __m256 v_tanh(__m256 x) {
namespace {
//template <int D, int k_step>
//struct HelperF16 {
// HelperF16(const char * data, int stride) : data(data), stride(stride) {}
//
// inline void set_block(int k1) { block = data + k1*k_step*stride; }
// inline void reset_block() { block = data; }
// inline void next_block() { block += k_step*stride; }
// inline const char * lblock(int l1) const { return block + l1*stride; }
//
// inline void load(int l1, __m512 * vk) const {
// auto dr = lblock(l1);
// for (int i = 0; i < D/16; ++i) vk[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dr + i));
// }
//
// inline void load(int l1, int i, __m512& v1, __m512& v2) const {
// auto dr = lblock(l1);
// v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dr + i + 0));
// v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dr + i + 1));
// }
//
// inline void load_2(int l1, __m512 * vk) const {
// load(l1+0, vk+0);
// load(l1+1, vk+D/16);
// }
//
// const char * data;
// const char * block;
// int stride;
//
//};
//
//template <int D, int k_step>
//struct HelperQ80 {
// static_assert(k_step == QK8_0);
// HelperQ80(const char * data, int stride) : data(data), stride(stride) {}
//
// inline void set_block(int k1) { block = data + k1*k_step*stride; }
// inline void reset_block() { block = data; }
// inline void next_block() { block += k_step*stride; }
// inline const char * lblock(int l1) const { return block + l1*stride; }
//
// inline void load(int l1, __m512 * vk) const {
// auto dl = (const block_q8_0 *)lblock(l1);
// for (int i = 0; i < D/32; ++i) {
// auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl[i].d));
// vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl[i].qs+0))));
// vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl[i].qs+1))));
// }
// }
//
// inline void load(int l1, int i, __m512& v1, __m512& v2) const {
// auto dl = (const block_q8_0 *)lblock(l1) + i/2;
// auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl->d));
// v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+0))));
// v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+1))));
// }
//
// inline void load_2(int l1, __m512 * vk) const {
// load(l1+0, vk+0);
// load(l1+1, vk+D/16);
// }
//
// const char * data;
// const char * block;
// int stride;
//};
//
//template <int D, int k_step>
//struct HelperQ40 {
// static_assert(k_step == QK4_0);
// HelperQ40(const char * data, int stride) : data(data), stride(stride) {}
//
// inline void set_block(int k1) { block = data + k1*k_step*stride; }
// inline void reset_block() { block = data; }
// inline void next_block() { block += k_step*stride; }
// inline const char * lblock(int l1) const { return block + l1*stride; }
//
// inline void load(int l1, __m512 * vk) const {
// auto dl = (const block_q4_0 *)lblock(l1);
// for (int i = 0; i < D/32; ++i) {
// auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl[i].d));
// auto q = _mm_loadu_si128((const __m128i *)dl[i].qs);
// auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8);
// auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8);
// vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
// vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)));
// }
// }
//
// inline void load(int l1, int i, __m512& v1, __m512& v2) const {
// auto dl = (const block_q4_0 *)lblock(l1) + i/2;
// auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl->d));
// auto q = _mm_loadu_si128((const __m128i *)dl->qs);
// auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8);
// auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8);
// v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
// v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)));
// }
//
// inline void load_2(int l1, __m512 * vk) const {
// load(l1+0, vk+0);
// load(l1+1, vk+D/16);
// }
//
// const __m128i mask = _mm_set1_epi8(0xf);
// const __m128i m8 = _mm_set1_epi8(-8);
//
// const char * data;
// const char * block;
// int stride;
//};
template <int k_step>
struct BaseHelper {
BaseHelper(const char * data, int stride) : data(data), block(data), stride(stride) {}
@@ -6223,12 +6111,30 @@ struct HelperQ80 final : public BaseHelper<step> {
//}
inline void load(int l1, __m512 * vk) const {
auto dl = (const block_q8_0_x4 *)Base::lblock(l1);
for (int i = 0; i < D/32; ++i) {
const auto& b8 = dl[i/4];
int ii = i%4;
auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(b8.d[ii]));
vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*ii+0))));
vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*ii+1))));
if constexpr (D >= 128) {
__m512 vd[4];
for (int ib = 0; ib < D/128; ++ib) {
const auto& b8 = dl[ib];
auto scales4 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)b8.d));
auto scales8 = _mm256_insertf128_ps(_mm256_castps128_ps256(scales4), scales4, 1);
auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales8), scales8, 1);
vd[0] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(0, 0, 0, 0));
vd[1] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(1, 1, 1, 1));
vd[2] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(2, 2, 2, 2));
vd[3] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(3, 3, 3, 3));
for (int i = 0; i < 4; ++i) {
vk[8*ib+2*i+0] = _mm512_mul_ps(vd[i], _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*i+0))));
vk[8*ib+2*i+1] = _mm512_mul_ps(vd[i], _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*i+1))));
}
}
} else {
for (int i = 0; i < D/32; ++i) {
const auto& b8 = dl[i/4];
int ii = i%4;
auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(b8.d[ii]));
vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*ii+0))));
vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*ii+1))));
}
}
}