NEON Flash Attention - use fp32 for K*Q operations

Else I get wrong results for LLaMA-3.1-8B (but it works for
Gemma-2b).
This commit is contained in:
Iwan Kawrakow
2024-09-11 08:37:26 +02:00
parent 2eb9e212be
commit 01b7a3a981

View File

@@ -6495,6 +6495,7 @@ struct F16 {
vst1q_f32(ptr+4, vcvt_f32_f16(vget_high_f16(data)));
}
static inline void store(float16_t * ptr, Data data) { vst1q_f16(ptr, data); }
static inline void store(float * ptr, float32x4_t data) { vst1q_f32(ptr, data); }
static inline Data fmadd(Data prev, Data v1, Data v2) { return vfmaq_f16(prev, v1, v2); }
static inline float reduce_max(Data data) { return vmaxvq_f16(data); }
static inline float reduce_add(Data data) {
@@ -6707,11 +6708,16 @@ struct HelperQ41 final : public BaseHelper<step> {
template <int q_step, int k_step>
struct FlashMS {
#ifdef __aarch64__
using cache_t = float16_t;
#else
// Something goes wrong when storing and manipulating K*Q as fp16.
// It works for some models (e.g., Gemma-2), but not for others (e.g., LLaMA-3.1-8B).
// As I wasn't able to find where we lose precision, let's comment this out
// for now and do the K*Q part in fp32.
//#ifdef __aarch64__
// using cache_t = float16_t;
//#else
// using cache_t = float;
//#endif
using cache_t = float;
#endif
FlashMS(float scale, float softcap) : vscale(F16::set1(scale)), softcap(softcap), h_inf(GGML_FP32_TO_FP16(-INFINITY)) {}
@@ -6721,6 +6727,75 @@ struct FlashMS {
}
}
#ifdef __aarch64__
inline void update_M_S(int j, float32x4_t * vk) {
float32x4_t vmax = vdupq_n_f32(-INFINITY);
// Something goes wrong when storing and manipulating K*Q as fp16.
// It works for some models (e.g., Gemma-2), but not for others (e.g., LLaMA-3.1-8B).
// As I wasn't able to find where we lose precision, let's comment this out
// for now and do the K*Q part in fp32.
//if (softcap <= 0.0f) {
// for (int l = 0; l < k_step/F16::block_size; ++l) {
// auto val = F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l));
// vk[2*l+0] = vcvt_f32_f16(vget_low_f16(val));
// vk[2*l+1] = vcvt_f32_f16(vget_high_f16(val));
// vmax = vmaxq_f32(vmax, vmaxq_f32(vk[2*l+0], vk[2*l+1]));
// }
//} else {
// auto v_softcap = vdupq_n_f32(softcap);
// for (int l = 0; l < k_step/F16::block_size; ++l) {
// auto val = F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l));
// vk[2*l+0] = vcvt_f32_f16(vget_low_f16(val));
// vk[2*l+1] = vcvt_f32_f16(vget_high_f16(val));
// vk[2*l+0] = vmulq_f32(v_softcap, v_tanh(vk[2*l+0]));
// vk[2*l+1] = vmulq_f32(v_softcap, v_tanh(vk[2*l+1]));
// vmax = vmaxq_f32(vmax, vmaxq_f32(vk[2*l+0], vk[2*l+1]));
// }
//}
auto vscale32 = vcvt_f32_f16(vget_low_f16(vscale));
if (softcap <= 0.0f) {
for (int l = 0; l < k_step/4; ++l) {
vk[l] = vmulq_f32(vscale32, vld1q_f32(cache + k_step*j + 4*l));
vmax = vmaxq_f32(vmax, vk[l]);
}
} else {
auto v_softcap = vdupq_n_f32(softcap);
for (int l = 0; l < k_step/4; ++l) {
vk[l] = vmulq_f32(vscale32, vld1q_f32(cache + k_step*j + 4*l));
vk[l] = vmulq_f32(v_softcap, v_tanh(vk[l]));
vmax = vmaxq_f32(vmax, vk[l]);
}
}
float smax = vmaxvq_f32(vmax);
if (smax == -INFINITY) {
std::memset(cache + k_step*j, 0, k_step*sizeof(float));
need_scaling[j] = M[j] == -INFINITY ? 2 : 0;
return;
}
need_scaling[j] = 0;
if (smax > M[j]) {
if (M[j] > -INFINITY) {
float m = expf(M[j] - smax);
vms[j] = F16::set1(m);
need_scaling[j] = 1;
S[j] *= m;
} else {
need_scaling[j] = 2;
S[j] = 0;
}
M[j] = smax;
}
auto vm = vdupq_n_f32(M[j]);
auto vsum = vdupq_n_f32(0);
for (int l = 0; l < k_step/4; ++l) {
vk[l] = v_expf(vsubq_f32(vk[l], vm));
vsum = vaddq_f32(vsum, vk[l]);
F16::store(cache + k_step*j + 4*l, vk[l]);
}
S[j] += vaddvq_f32(vsum);
}
#else
inline void update_M_S(int j, F16::Data * vk) {
if (softcap <= 0.0f) {
for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l));
@@ -6758,6 +6833,7 @@ struct FlashMS {
}
S[j] += F16::reduce_add<k_step>(vk);
}
#endif
cache_t cache[q_step*k_step];
float S[q_step], M[q_step];
@@ -6884,7 +6960,11 @@ struct FlashQKfp32 {
static_assert(k_step%F16::block_size == 0);
static_assert(q_step <= 4 || q_step%4 == 0);
#ifdef __aarch64__
constexpr static bool is_small_head = false;
#else
constexpr static bool is_small_head = D <= (F16::num_registers/2)*F16::block_size;
#endif
template <bool small = is_small_head, class = std::enable_if<small>, typename q_float>
static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const q_float * q, const char * mask,
@@ -6987,10 +7067,17 @@ struct FlashQKfp32 {
else {
mult_mask_kq_l(kh, stride_q, stride_m, q, mask, fms);
}
#ifdef __aarch64__
float32x4_t vk[k_step/4];
for (int j = 0; j < q_step; ++j) {
fms.update_M_S(j, vk);
}
#else
F16::Data vk[k_step/F16::block_size];
for (int j = 0; j < q_step; ++j) {
fms.update_M_S(j, vk);
}
#endif
}
template <typename KHelper, typename q_float>
@@ -7002,10 +7089,17 @@ struct FlashQKfp32 {
else {
mult_mask_kq_l(nq, kh, stride_q, stride_m, q, mask, fms);
}
#ifdef __aarch64__
float32x4_t vk[k_step/4];
for (int j = 0; j < nq; ++j) {
fms.update_M_S(j, vk);
}
#else
F16::Data vk[k_step/F16::block_size];
for (int j = 0; j < nq; ++j) {
fms.update_M_S(j, vk);
}
#endif
}
#ifdef __aarch64__