mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-04 19:10:03 +00:00
NEON Flash Attention: quantized K*Q for q4_1
This commit is contained in:
@@ -6662,6 +6662,142 @@ void quantize_row_q8_0(const float * x, block_q8_0 * y, int k) {
|
||||
#endif
|
||||
}
|
||||
|
||||
void quantize_row_q8_1(const float * x, block_q8_1 * y, int k) {
|
||||
assert(k % QK8_1 == 0);
|
||||
const int nb = k / QK8_1;
|
||||
|
||||
const int nb4 = 4*(nb/4);
|
||||
block_q8_1_x4 * y4 = (block_q8_1_x4 *)y;
|
||||
#if defined(__aarch64__)
|
||||
for (int i = 0; i < nb; i++) {
|
||||
int i4 = i/4, ir = i%4;
|
||||
float32x4_t srcv [8];
|
||||
float32x4_t asrcv[8];
|
||||
float32x4_t amaxv[8];
|
||||
|
||||
for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j);
|
||||
for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
|
||||
|
||||
for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
|
||||
for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
|
||||
for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
|
||||
|
||||
const float amax = vmaxvq_f32(amaxv[0]);
|
||||
|
||||
const float d = amax / ((1 << 7) - 1);
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
if (i < nb4) {
|
||||
y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
|
||||
} else {
|
||||
y[i].d = GGML_FP32_TO_FP16(d);
|
||||
}
|
||||
|
||||
int32x4_t accv = vdupq_n_s32(0);
|
||||
|
||||
for (int j = 0; j < 8; j++) {
|
||||
const float32x4_t v = vmulq_n_f32(srcv[j], id);
|
||||
const int32x4_t vi = vcvtnq_s32_f32(v);
|
||||
|
||||
if (i < nb4) {
|
||||
y4[i4].qs[QK8_1*ir + 4*j + 0] = vgetq_lane_s32(vi, 0);
|
||||
y4[i4].qs[QK8_1*ir + 4*j + 1] = vgetq_lane_s32(vi, 1);
|
||||
y4[i4].qs[QK8_1*ir + 4*j + 2] = vgetq_lane_s32(vi, 2);
|
||||
y4[i4].qs[QK8_1*ir + 4*j + 3] = vgetq_lane_s32(vi, 3);
|
||||
} else {
|
||||
y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
|
||||
y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
|
||||
y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
|
||||
y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
|
||||
}
|
||||
|
||||
accv = vaddq_s32(accv, vi);
|
||||
}
|
||||
|
||||
if (i < nb4) {
|
||||
y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
|
||||
} else {
|
||||
y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
|
||||
}
|
||||
}
|
||||
#else
|
||||
for (int i = 0; i < nb; i++) {
|
||||
int i4 = i/4, ir = i%4;
|
||||
// Load elements into 4 AVX vectors
|
||||
__m256 v0 = _mm256_loadu_ps( x );
|
||||
__m256 v1 = _mm256_loadu_ps( x + 8 );
|
||||
__m256 v2 = _mm256_loadu_ps( x + 16 );
|
||||
__m256 v3 = _mm256_loadu_ps( x + 24 );
|
||||
x += 32;
|
||||
|
||||
// Compute max(abs(e)) for the block
|
||||
const __m256 signBit = _mm256_set1_ps( -0.0f );
|
||||
__m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
|
||||
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
|
||||
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
|
||||
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
|
||||
|
||||
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
|
||||
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
|
||||
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
|
||||
const float max_scalar = _mm_cvtss_f32( max4 );
|
||||
|
||||
// Quantize these floats
|
||||
const float d = max_scalar / 127.f;
|
||||
if (i < nb4) {
|
||||
y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
|
||||
} else {
|
||||
y[i].d = GGML_FP32_TO_FP16(d);
|
||||
}
|
||||
const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
|
||||
const __m256 mul = _mm256_set1_ps( id );
|
||||
|
||||
// Apply the multiplier
|
||||
v0 = _mm256_mul_ps( v0, mul );
|
||||
v1 = _mm256_mul_ps( v1, mul );
|
||||
v2 = _mm256_mul_ps( v2, mul );
|
||||
v3 = _mm256_mul_ps( v3, mul );
|
||||
|
||||
// Round to nearest integer
|
||||
v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
|
||||
v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
|
||||
v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
|
||||
v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
|
||||
|
||||
// Convert floats to integers
|
||||
__m256i i0 = _mm256_cvtps_epi32( v0 );
|
||||
__m256i i1 = _mm256_cvtps_epi32( v1 );
|
||||
__m256i i2 = _mm256_cvtps_epi32( v2 );
|
||||
__m256i i3 = _mm256_cvtps_epi32( v3 );
|
||||
|
||||
// Compute the sum of the quants and set y[i].s
|
||||
if (i < nb4) {
|
||||
y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
|
||||
} else {
|
||||
y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
|
||||
}
|
||||
|
||||
// Convert int32 to int16
|
||||
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
|
||||
i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
|
||||
// Convert int16 to int8
|
||||
i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
|
||||
|
||||
// We got our precious signed bytes, but the order is now wrong
|
||||
// These AVX2 pack instructions process 16-byte pieces independently
|
||||
// The following instruction is fixing the order
|
||||
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
|
||||
i0 = _mm256_permutevar8x32_epi32( i0, perm );
|
||||
|
||||
if (i < nb4) {
|
||||
_mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0);
|
||||
} else {
|
||||
_mm256_storeu_si256((__m256i *)y[i].qs, i0);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <int D, int step>
|
||||
struct HelperQ80 final : public BaseHelper<step> {
|
||||
static_assert(step == QK8_0);
|
||||
@@ -6776,12 +6912,22 @@ struct HelperQ80 final : public BaseHelper<step> {
|
||||
y += D/QK8_0;
|
||||
}
|
||||
}
|
||||
|
||||
static inline void convert(int nq, int stride_q, const float * q, block_q8_1 * y) {
|
||||
GGML_ASSERT(nq <= step);
|
||||
for (int i = 0; i < nq; ++i) {
|
||||
quantize_row_q8_1(q, y, D);
|
||||
q += stride_q;
|
||||
y += D/QK8_1;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <int D, int step>
|
||||
struct HelperQ40 final : public BaseHelper<step> {
|
||||
static_assert(step == QK4_0);
|
||||
using Base = BaseHelper<step>;
|
||||
using block_q8 = block_q8_0;
|
||||
HelperQ40(const char * data, int stride) : Base(data, stride) {}
|
||||
|
||||
|
||||
@@ -6914,6 +7060,7 @@ template <int D, int step>
|
||||
struct HelperQ41 final : public BaseHelper<step> {
|
||||
static_assert(step == QK4_1);
|
||||
using Base = BaseHelper<step>;
|
||||
using block_q8 = block_q8_1;
|
||||
HelperQ41(const char * data, int stride) : Base(data, stride) {}
|
||||
|
||||
inline void load(int l1, F16::Data * vk) const {
|
||||
@@ -7461,30 +7608,21 @@ struct FlashQKfp32 {
|
||||
}
|
||||
#endif
|
||||
|
||||
static inline void mul_mask_kq(const HelperQ40<D, k_step>& kh, int stride_m,
|
||||
const block_q8_0 * q, const char * mask, FlashMS<q_step, k_step>& fms) {
|
||||
template <typename KHelper, typename block_q8>
|
||||
static inline void mul_mask_kq(const KHelper& kh, int stride_m,
|
||||
const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) {
|
||||
static_assert(q_step <= 8);
|
||||
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8_0), 0, 1, nullptr};
|
||||
mul_mat_qX_0_q8_0<DequantizerQ40, q_step>(D, kh.block, kh.stride, info, k_step);
|
||||
//auto vinf = vdupq_n_f32(-INFINITY);
|
||||
//auto vzero = vdupq_n_f16(0);
|
||||
//for (int j = 0; j < q_step; ++j) {
|
||||
// const ggml_half * mp = (const ggml_half *)(mask + stride_m*j);
|
||||
// for (int l = 0; l < k_step/8; ++l) {
|
||||
// auto vm = vceqq_f16(vzero, vld1q_f16((const float16_t *)mp + 8*l));
|
||||
// auto vm1 = vzip1q_u16(vm, vm);
|
||||
// auto vm2 = vzip2q_u16(vm, vm);
|
||||
// auto kq = vld1q_f32_x2(fms.cache + k_step*j + 8*l);
|
||||
// kq.val[0] = vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(kq.val[0]), vm1),
|
||||
// vbicq_u32(vinf, vm1)));
|
||||
// kq.val[1] = vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(kq.val[1]), vm2),
|
||||
// vbicq_u32(vinf, vm2)));
|
||||
// vst1q_f32_x2(fms.cache + k_step*j + 8*l, kq);
|
||||
// }
|
||||
// //for (int l = 0; l < k_step; ++l) {
|
||||
// // if (mp[l] == fms.h_inf) fms.cache[k_step*j + l] = -INFINITY;
|
||||
// //}
|
||||
//}
|
||||
if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) {
|
||||
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
|
||||
mul_mat_qX_0_q8_0<DequantizerQ40, q_step>(D, kh.block, kh.stride, info, k_step);
|
||||
}
|
||||
else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
|
||||
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
|
||||
mul_mat_qX_1_q8_1<DequantizerQ41, q_step>(D, kh.block, kh.stride, info, k_step);
|
||||
}
|
||||
else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
#ifdef __aarch64__
|
||||
float32x4_t vk[k_step/4];
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
@@ -7497,25 +7635,37 @@ struct FlashQKfp32 {
|
||||
}
|
||||
#endif
|
||||
}
|
||||
static inline void mul_mask_kq(int nq, const HelperQ40<D, k_step>& kh, int stride_m,
|
||||
const block_q8_0 * q, const char * mask, FlashMS<q_step, k_step>& fms) {
|
||||
template <typename KHelper, typename block_q8>
|
||||
static inline void mul_mask_kq(int nq, const KHelper& kh, int stride_m,
|
||||
const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) {
|
||||
GGML_ASSERT(nq < 8);
|
||||
DataInfo info{fms.cache, (const char *)q, D*sizeof(float), (D/QK8_0)*sizeof(block_q8_0), 0, 1, nullptr};
|
||||
switch (nq) {
|
||||
case 1: mul_mat_qX_0_q8_0<DequantizerQ40, 1>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 2: mul_mat_qX_0_q8_0<DequantizerQ40, 2>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 3: mul_mat_qX_0_q8_0<DequantizerQ40, 3>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 4: mul_mat_qX_0_q8_0<DequantizerQ40, 4>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 5: mul_mat_qX_0_q8_0<DequantizerQ40, 5>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 6: mul_mat_qX_0_q8_0<DequantizerQ40, 6>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 7: mul_mat_qX_0_q8_0<DequantizerQ40, 7>(D, kh.block, kh.stride, info, k_step); break;
|
||||
if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) {
|
||||
DataInfo info{fms.cache, (const char *)q, D*sizeof(float), (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
|
||||
switch (nq) {
|
||||
case 1: mul_mat_qX_0_q8_0<DequantizerQ40, 1>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 2: mul_mat_qX_0_q8_0<DequantizerQ40, 2>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 3: mul_mat_qX_0_q8_0<DequantizerQ40, 3>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 4: mul_mat_qX_0_q8_0<DequantizerQ40, 4>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 5: mul_mat_qX_0_q8_0<DequantizerQ40, 5>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 6: mul_mat_qX_0_q8_0<DequantizerQ40, 6>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 7: mul_mat_qX_0_q8_0<DequantizerQ40, 7>(D, kh.block, kh.stride, info, k_step); break;
|
||||
}
|
||||
}
|
||||
else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
|
||||
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
|
||||
switch (nq) {
|
||||
case 1: mul_mat_qX_1_q8_1<DequantizerQ41, 1>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 2: mul_mat_qX_1_q8_1<DequantizerQ41, 2>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 3: mul_mat_qX_1_q8_1<DequantizerQ41, 3>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 4: mul_mat_qX_1_q8_1<DequantizerQ41, 4>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 5: mul_mat_qX_1_q8_1<DequantizerQ41, 5>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 6: mul_mat_qX_1_q8_1<DequantizerQ41, 6>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 7: mul_mat_qX_1_q8_1<DequantizerQ41, 7>(D, kh.block, kh.stride, info, k_step); break;
|
||||
}
|
||||
}
|
||||
else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
//for (int j = 0; j < nq; ++j) {
|
||||
// const ggml_half * mp = (const ggml_half *)(mask + stride_m*j);
|
||||
// for (int l = 0; l < k_step; ++l) {
|
||||
// if (mp[l] == fms.h_inf) fms.cache[k_step*j + l] = -INFINITY;
|
||||
// }
|
||||
//}
|
||||
#ifdef __aarch64__
|
||||
float32x4_t vk[k_step/4];
|
||||
for (int j = 0; j < nq; ++j) {
|
||||
@@ -7587,20 +7737,20 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
|
||||
}
|
||||
}
|
||||
|
||||
template <int D, int q_step, int k_step, typename VHelper, typename KQHelper>
|
||||
void compute_helper_q(HelperQ40<D, k_step>& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
|
||||
template <int D, int q_step, int k_step, typename KHelper, typename VHelper, typename KQHelper>
|
||||
void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
|
||||
FlashMS<q_step, k_step>& fms,
|
||||
FlashQKV<D, q_step, k_step>& fqkv,
|
||||
const float * q, const char * mask, float * qkv) {
|
||||
block_q8_0 q80[q_step*(D/QK8_0)];
|
||||
typename KHelper::block_q8 q8[q_step*(D/QK8_0)];
|
||||
for (int i1 = 0; i1 < nq1/q_step; ++i1) {
|
||||
fms.init_qstep();
|
||||
kh.reset_block();
|
||||
vh.reset_block();
|
||||
HelperQ80<D, QK8_0>::convert(q_step, stride_q, q, q80);
|
||||
HelperQ80<D, QK8_0>::convert(q_step, stride_q, q, q8);
|
||||
auto mr = mask;
|
||||
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
||||
KQHelper::mul_mask_kq(kh, stride_m, q80, mr, fms);
|
||||
KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms);
|
||||
fqkv.accumulate_qkv(vh, fms);
|
||||
kh.next_block();
|
||||
vh.next_block();
|
||||
@@ -7617,10 +7767,10 @@ void compute_helper_q(HelperQ40<D, k_step>& kh, VHelper& vh, int nq1, int nk1, i
|
||||
fms.init_qstep();
|
||||
kh.reset_block();
|
||||
vh.reset_block();
|
||||
HelperQ80<D, QK8_0>::convert(n_left, stride_q, q, q80);
|
||||
HelperQ80<D, QK8_0>::convert(n_left, stride_q, q, q8);
|
||||
auto mr = mask;
|
||||
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
||||
KQHelper::mul_mask_kq(n_left, kh, stride_m, q80, mr, fms);
|
||||
KQHelper::mul_mask_kq(n_left, kh, stride_m, q8, mr, fms);
|
||||
fqkv.accumulate_qkv(n_left, vh, fms);
|
||||
kh.next_block();
|
||||
vh.next_block();
|
||||
@@ -7651,8 +7801,8 @@ struct FlashAttn {
|
||||
template <typename KHelper, typename VHelper>
|
||||
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
|
||||
const float * q, const char * mask, float * qkv) {
|
||||
if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) {
|
||||
compute_helper_q<D, q_step, k_step, VHelper, FlashQKfp32<D, q_step, k_step>>(
|
||||
if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
|
||||
compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
|
||||
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
|
||||
} else {
|
||||
compute_helper<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
|
||||
|
||||
Reference in New Issue
Block a user