mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 15:44:10 +00:00
Zen4 Flash Attention: quantized K*Q for q4_0, q4_1, q8_0
This commit is contained in:
@@ -3045,7 +3045,6 @@ template <typename Minus, int nrc_y, bool is_multiple_of_4> struct AccumT {
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix, iy, accm.result(acc[iy], iy));
|
||||
//s[iy*bs] = accm.result(acc[iy], iy);
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -3212,6 +3211,35 @@ struct Q_Unpacker {
|
||||
}
|
||||
};
|
||||
|
||||
struct Q8_0_x4_Unpacker {
|
||||
using Sum4T = Sum4TypeQ80;
|
||||
inline static int block_size() { return QK8_0; }
|
||||
Q8_0_x4_Unpacker(const void * vx, size_t bx) : cx_0((const char *)vx), x((const block_q8_0_x4 *)cx_0), bx(bx) {}
|
||||
|
||||
const char * cx_0;
|
||||
const block_q8_0_x4 * x;
|
||||
size_t bx;
|
||||
|
||||
__m256i qx[4];
|
||||
|
||||
inline const __m256i* quants() const { return qx; }
|
||||
|
||||
inline void set_row(int ix) { x = (const block_q8_0_x4 *)(cx_0 + ix*bx); }
|
||||
|
||||
inline auto set_block_4(int i) {
|
||||
auto scales = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)x[i].d));
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
qx[j] = _mm256_loadu_si256((const __m256i *)x[i].qs + j);
|
||||
}
|
||||
return scales;
|
||||
}
|
||||
inline auto set_block(int i) {
|
||||
auto q8 = (const block_q8_0 *)(x + i);
|
||||
qx[0] = _mm256_loadu_si256((const __m256i *)q8->qs);
|
||||
return GGML_FP16_TO_FP32(q8->d);
|
||||
}
|
||||
};
|
||||
|
||||
struct Q8_0_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0, Q8_0_Dequantizer> {
|
||||
Q8_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
|
||||
using Sum4T = Sum4TypeQ80;
|
||||
@@ -7320,6 +7348,60 @@ struct FlashMS {
|
||||
}
|
||||
}
|
||||
|
||||
float smax = F16::reduce_max<k_step>(vk);
|
||||
if (smax == -INFINITY) {
|
||||
std::memset(cache + k_step*j, 0, k_step*sizeof(float));
|
||||
need_scaling[j] = M[j] == -INFINITY ? 2 : 0;
|
||||
return;
|
||||
}
|
||||
need_scaling[j] = 0;
|
||||
if (smax > M[j]) {
|
||||
if (M[j] > -INFINITY) {
|
||||
float m = expf(M[j] - smax);
|
||||
vms[j] = F16::set1(m);
|
||||
need_scaling[j] = 1;
|
||||
S[j] *= m;
|
||||
} else {
|
||||
need_scaling[j] = 2;
|
||||
S[j] = 0;
|
||||
}
|
||||
M[j] = smax;
|
||||
}
|
||||
auto vm = F16::set1(M[j]);
|
||||
for (int l = 0; l < k_step/F16::block_size; ++l) {
|
||||
vk[l] = v_expf(F16::sub(vk[l], vm));
|
||||
F16::store(cache + k_step*j + F16::block_size*l, vk[l]);
|
||||
}
|
||||
S[j] += F16::reduce_add<k_step>(vk);
|
||||
}
|
||||
inline void update_M_S(int j, F16::Data * vk, const char * mask) {
|
||||
auto vzero = _mm256_set1_epi16(0);
|
||||
auto vinf = _mm512_set1_ps(-INFINITY);
|
||||
//for (int l = 0; l < k_step/F16::block_size; ++l) {
|
||||
// auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i *)mask + l), vzero);
|
||||
// vk[l] = _mm512_mask_blend_ps(m16, vinf, F16::load(cache + k_step*j + F16::block_size*l));
|
||||
//}
|
||||
//if (softcap <= 0.0f) {
|
||||
// for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(vscale, vk[l]);
|
||||
//} else {
|
||||
// auto v_softcap = F16::set1(softcap);
|
||||
// for (int l = 0; l < k_step/F16::block_size; ++l) {
|
||||
// vk[l] = F16::mul(v_softcap, v_tanh(F16::mul(vscale, vk[l])));
|
||||
// }
|
||||
//}
|
||||
if (softcap <= 0) {
|
||||
for (int l = 0; l < k_step/F16::block_size; ++l) {
|
||||
auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i *)mask + l), vzero);
|
||||
vk[l] = _mm512_mask_mul_ps(vinf, m16, vscale, F16::load(cache + k_step*j + F16::block_size*l));
|
||||
}
|
||||
} else {
|
||||
auto v_softcap = F16::set1(softcap);
|
||||
for (int l = 0; l < k_step/F16::block_size; ++l) {
|
||||
auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i *)mask + l), vzero);
|
||||
vk[l] = _mm512_mask_mul_ps(vinf, m16, v_softcap, v_tanh(F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l))));
|
||||
}
|
||||
}
|
||||
|
||||
float smax = F16::reduce_max<k_step>(vk);
|
||||
if (smax == -INFINITY) {
|
||||
std::memset(cache + k_step*j, 0, k_step*sizeof(float));
|
||||
@@ -7636,15 +7718,27 @@ struct FlashQKfp32 {
|
||||
static_assert(q_step <= 8);
|
||||
if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) {
|
||||
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
|
||||
#ifdef __aarch64__
|
||||
mul_mat_qX_0_q8_0<DequantizerQ40, q_step>(D, kh.block, kh.stride, info, k_step);
|
||||
#else
|
||||
mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
|
||||
#endif
|
||||
}
|
||||
else if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
|
||||
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
|
||||
#ifdef __aarch64__
|
||||
mul_mat_qX_0_q8_0<DequantizerQ80_x4, q_step>(D, kh.block, kh.stride, info, k_step);
|
||||
#else
|
||||
mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
|
||||
#endif
|
||||
}
|
||||
else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
|
||||
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
|
||||
#ifdef __aarch64__
|
||||
mul_mat_qX_1_q8_1<DequantizerQ41, q_step>(D, kh.block, kh.stride, info, k_step);
|
||||
#else
|
||||
mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
|
||||
#endif
|
||||
}
|
||||
else {
|
||||
GGML_ASSERT(false);
|
||||
@@ -7657,7 +7751,7 @@ struct FlashQKfp32 {
|
||||
#else
|
||||
F16::Data vk[k_step/F16::block_size];
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
fms.update_M_S(j, vk);
|
||||
fms.update_M_S(j, vk, mask + stride_m*j);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -7668,6 +7762,7 @@ struct FlashQKfp32 {
|
||||
if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) {
|
||||
DataInfo info{fms.cache, (const char *)q, D*sizeof(float), (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
|
||||
switch (nq) {
|
||||
#ifdef __aarch64__
|
||||
case 1: mul_mat_qX_0_q8_0<DequantizerQ40, 1>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 2: mul_mat_qX_0_q8_0<DequantizerQ40, 2>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 3: mul_mat_qX_0_q8_0<DequantizerQ40, 3>(D, kh.block, kh.stride, info, k_step); break;
|
||||
@@ -7675,11 +7770,21 @@ struct FlashQKfp32 {
|
||||
case 5: mul_mat_qX_0_q8_0<DequantizerQ40, 5>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 6: mul_mat_qX_0_q8_0<DequantizerQ40, 6>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 7: mul_mat_qX_0_q8_0<DequantizerQ40, 7>(D, kh.block, kh.stride, info, k_step); break;
|
||||
#else
|
||||
case 1: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 2: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 3: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 4: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 5: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 6: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 7: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
|
||||
DataInfo info{fms.cache, (const char *)q, D*sizeof(float), (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
|
||||
switch (nq) {
|
||||
#ifdef __aarch64__
|
||||
case 1: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 1>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 2: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 2>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 3: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 3>(D, kh.block, kh.stride, info, k_step); break;
|
||||
@@ -7687,11 +7792,21 @@ struct FlashQKfp32 {
|
||||
case 5: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 5>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 6: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 6>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 7: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 7>(D, kh.block, kh.stride, info, k_step); break;
|
||||
#else
|
||||
case 1: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 2: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 3: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 4: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 5: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 6: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 7: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
|
||||
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
|
||||
switch (nq) {
|
||||
#ifdef __aarch64__
|
||||
case 1: mul_mat_qX_1_q8_1<DequantizerQ41, 1>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 2: mul_mat_qX_1_q8_1<DequantizerQ41, 2>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 3: mul_mat_qX_1_q8_1<DequantizerQ41, 3>(D, kh.block, kh.stride, info, k_step); break;
|
||||
@@ -7699,6 +7814,15 @@ struct FlashQKfp32 {
|
||||
case 5: mul_mat_qX_1_q8_1<DequantizerQ41, 5>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 6: mul_mat_qX_1_q8_1<DequantizerQ41, 6>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 7: mul_mat_qX_1_q8_1<DequantizerQ41, 7>(D, kh.block, kh.stride, info, k_step); break;
|
||||
#else
|
||||
case 1: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 2: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 3: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 4: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 5: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 6: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 7: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else {
|
||||
@@ -7712,7 +7836,7 @@ struct FlashQKfp32 {
|
||||
#else
|
||||
F16::Data vk[k_step/F16::block_size];
|
||||
for (int j = 0; j < nq; ++j) {
|
||||
fms.update_M_S(j, vk);
|
||||
fms.update_M_S(j, vk, mask + stride_m*j);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user