mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 15:44:10 +00:00
NEON Flash Attention - use fp32 for K*Q operations
Otherwise I get wrong results for LLaMA-3.1-8B (although keeping K*Q in fp16 works for Gemma-2b).
This commit is contained in:
@@ -6495,6 +6495,7 @@ struct F16 {
|
||||
vst1q_f32(ptr+4, vcvt_f32_f16(vget_high_f16(data)));
|
||||
}
|
||||
// Writes the half-precision vector `data` to memory at `ptr` without any conversion.
static inline void store(float16_t * ptr, Data data) { vst1q_f16(ptr, data); }
|
||||
// fp32 overload: writes the four floats of `data` to `ptr` as-is.
// Lets code that keeps its intermediate values in fp32 still go through F16::store.
static inline void store(float * ptr, float32x4_t data) { vst1q_f32(ptr, data); }
|
||||
// Lane-wise fused multiply-add in half precision: returns prev + v1*v2.
static inline Data fmadd(Data prev, Data v1, Data v2) { return vfmaq_f16(prev, v1, v2); }
|
||||
// Horizontal maximum over all lanes of `data`; the fp16 result is widened to float on return.
static inline float reduce_max(Data data) { return vmaxvq_f16(data); }
|
||||
static inline float reduce_add(Data data) {
|
||||
@@ -6707,11 +6708,16 @@ struct HelperQ41 final : public BaseHelper<step> {
|
||||
|
||||
template <int q_step, int k_step>
|
||||
struct FlashMS {
|
||||
#ifdef __aarch64__
|
||||
using cache_t = float16_t;
|
||||
#else
|
||||
// Something goes wrong when storing and manipulating K*Q as fp16.
|
||||
// It works for some models (e.g., Gemma-2), but not for others (e.g., LLaMA-3.1-8B).
|
||||
// As I wasn't able to find where we lose precision, let's comment this out
|
||||
// for now and do the K*Q part in fp32.
|
||||
//#ifdef __aarch64__
|
||||
// using cache_t = float16_t;
|
||||
//#else
|
||||
// using cache_t = float;
|
||||
//#endif
|
||||
using cache_t = float;
|
||||
#endif
|
||||
|
||||
// Broadcasts `scale` into the F16 vector vscale, stores `softcap` unchanged, and
// caches -INFINITY converted with GGML_FP32_TO_FP16 in h_inf.
// update_M_S treats softcap <= 0 as "no soft-capping".
FlashMS(float scale, float softcap) : vscale(F16::set1(scale)), softcap(softcap), h_inf(GGML_FP32_TO_FP16(-INFINITY)) {}
|
||||
|
||||
@@ -6721,6 +6727,75 @@ struct FlashMS {
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __aarch64__
|
||||
inline void update_M_S(int j, float32x4_t * vk) {
|
||||
float32x4_t vmax = vdupq_n_f32(-INFINITY);
|
||||
// Something goes wrong when storing and manipulating K*Q as fp16.
|
||||
// It works for some models (e.g., Gemma-2), but not for others (e.g., LLaMA-3.1-8B).
|
||||
// As I wasn't able to find where we lose precision, let's comment this out
|
||||
// for now and do the K*Q part in fp32.
|
||||
//if (softcap <= 0.0f) {
|
||||
// for (int l = 0; l < k_step/F16::block_size; ++l) {
|
||||
// auto val = F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l));
|
||||
// vk[2*l+0] = vcvt_f32_f16(vget_low_f16(val));
|
||||
// vk[2*l+1] = vcvt_f32_f16(vget_high_f16(val));
|
||||
// vmax = vmaxq_f32(vmax, vmaxq_f32(vk[2*l+0], vk[2*l+1]));
|
||||
// }
|
||||
//} else {
|
||||
// auto v_softcap = vdupq_n_f32(softcap);
|
||||
// for (int l = 0; l < k_step/F16::block_size; ++l) {
|
||||
// auto val = F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l));
|
||||
// vk[2*l+0] = vcvt_f32_f16(vget_low_f16(val));
|
||||
// vk[2*l+1] = vcvt_f32_f16(vget_high_f16(val));
|
||||
// vk[2*l+0] = vmulq_f32(v_softcap, v_tanh(vk[2*l+0]));
|
||||
// vk[2*l+1] = vmulq_f32(v_softcap, v_tanh(vk[2*l+1]));
|
||||
// vmax = vmaxq_f32(vmax, vmaxq_f32(vk[2*l+0], vk[2*l+1]));
|
||||
// }
|
||||
//}
|
||||
auto vscale32 = vcvt_f32_f16(vget_low_f16(vscale));
|
||||
if (softcap <= 0.0f) {
|
||||
for (int l = 0; l < k_step/4; ++l) {
|
||||
vk[l] = vmulq_f32(vscale32, vld1q_f32(cache + k_step*j + 4*l));
|
||||
vmax = vmaxq_f32(vmax, vk[l]);
|
||||
}
|
||||
} else {
|
||||
auto v_softcap = vdupq_n_f32(softcap);
|
||||
for (int l = 0; l < k_step/4; ++l) {
|
||||
vk[l] = vmulq_f32(vscale32, vld1q_f32(cache + k_step*j + 4*l));
|
||||
vk[l] = vmulq_f32(v_softcap, v_tanh(vk[l]));
|
||||
vmax = vmaxq_f32(vmax, vk[l]);
|
||||
}
|
||||
}
|
||||
|
||||
float smax = vmaxvq_f32(vmax);
|
||||
if (smax == -INFINITY) {
|
||||
std::memset(cache + k_step*j, 0, k_step*sizeof(float));
|
||||
need_scaling[j] = M[j] == -INFINITY ? 2 : 0;
|
||||
return;
|
||||
}
|
||||
need_scaling[j] = 0;
|
||||
if (smax > M[j]) {
|
||||
if (M[j] > -INFINITY) {
|
||||
float m = expf(M[j] - smax);
|
||||
vms[j] = F16::set1(m);
|
||||
need_scaling[j] = 1;
|
||||
S[j] *= m;
|
||||
} else {
|
||||
need_scaling[j] = 2;
|
||||
S[j] = 0;
|
||||
}
|
||||
M[j] = smax;
|
||||
}
|
||||
auto vm = vdupq_n_f32(M[j]);
|
||||
auto vsum = vdupq_n_f32(0);
|
||||
for (int l = 0; l < k_step/4; ++l) {
|
||||
vk[l] = v_expf(vsubq_f32(vk[l], vm));
|
||||
vsum = vaddq_f32(vsum, vk[l]);
|
||||
F16::store(cache + k_step*j + 4*l, vk[l]);
|
||||
}
|
||||
S[j] += vaddvq_f32(vsum);
|
||||
}
|
||||
#else
|
||||
inline void update_M_S(int j, F16::Data * vk) {
|
||||
if (softcap <= 0.0f) {
|
||||
for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l));
|
||||
@@ -6758,6 +6833,7 @@ struct FlashMS {
|
||||
}
|
||||
S[j] += F16::reduce_add<k_step>(vk);
|
||||
}
|
||||
#endif
|
||||
|
||||
cache_t cache[q_step*k_step];
|
||||
float S[q_step], M[q_step];
|
||||
@@ -6884,7 +6960,11 @@ struct FlashQKfp32 {
|
||||
static_assert(k_step%F16::block_size == 0);
|
||||
static_assert(q_step <= 4 || q_step%4 == 0);
|
||||
|
||||
#ifdef __aarch64__
|
||||
// aarch64: the small-head code paths (those selected through `small = is_small_head`)
// are switched off unconditionally, regardless of the head size D.
constexpr static bool is_small_head = false;
|
||||
#else
|
||||
// A head counts as "small" when D elements fit into half of the available F16 registers.
constexpr static bool is_small_head = D <= (F16::num_registers/2)*F16::block_size;
|
||||
#endif
|
||||
|
||||
template <bool small = is_small_head, class = std::enable_if<small>, typename q_float>
|
||||
static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const q_float * q, const char * mask,
|
||||
@@ -6987,10 +7067,17 @@ struct FlashQKfp32 {
|
||||
else {
|
||||
mult_mask_kq_l(kh, stride_q, stride_m, q, mask, fms);
|
||||
}
|
||||
#ifdef __aarch64__
|
||||
float32x4_t vk[k_step/4];
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
fms.update_M_S(j, vk);
|
||||
}
|
||||
#else
|
||||
F16::Data vk[k_step/F16::block_size];
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
fms.update_M_S(j, vk);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename KHelper, typename q_float>
|
||||
@@ -7002,10 +7089,17 @@ struct FlashQKfp32 {
|
||||
else {
|
||||
mult_mask_kq_l(nq, kh, stride_q, stride_m, q, mask, fms);
|
||||
}
|
||||
#ifdef __aarch64__
|
||||
float32x4_t vk[k_step/4];
|
||||
for (int j = 0; j < nq; ++j) {
|
||||
fms.update_M_S(j, vk);
|
||||
}
|
||||
#else
|
||||
F16::Data vk[k_step/F16::block_size];
|
||||
for (int j = 0; j < nq; ++j) {
|
||||
fms.update_M_S(j, vk);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef __aarch64__
|
||||
|
||||
Reference in New Issue
Block a user