diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 907f1271..801aa111 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -6495,6 +6495,7 @@ struct F16 {
         vst1q_f32(ptr+4, vcvt_f32_f16(vget_high_f16(data)));
     }
     static inline void store(float16_t * ptr, Data data) { vst1q_f16(ptr, data); }
+    static inline void store(float * ptr, float32x4_t data) { vst1q_f32(ptr, data); }
     static inline Data fmadd(Data prev, Data v1, Data v2) { return vfmaq_f16(prev, v1, v2); }
     static inline float reduce_max(Data data) { return vmaxvq_f16(data); }
     static inline float reduce_add(Data data) {
@@ -6707,11 +6708,16 @@ struct HelperQ41 final : public BaseHelper {
 
 template <int q_step, int k_step>
 struct FlashMS {
-#ifdef __aarch64__
-    using cache_t = float16_t;
-#else
+// Something goes wrong when storing and manipulating K*Q as fp16.
+// It works for some models (e.g., Gemma-2), but not for others (e.g., LLaMA-3.1-8B).
+// As I wasn't able to find where we lose precision, let's comment this out
+// for now and do the K*Q part in fp32.
+//#ifdef __aarch64__
+//    using cache_t = float16_t;
+//#else
+//    using cache_t = float;
+//#endif
     using cache_t = float;
-#endif
 
     FlashMS(float scale, float softcap) : vscale(F16::set1(scale)), softcap(softcap), h_inf(GGML_FP32_TO_FP16(-INFINITY)) {}
 
@@ -6721,6 +6727,75 @@ struct FlashMS {
         }
     }
 
+#ifdef __aarch64__
+    inline void update_M_S(int j, float32x4_t * vk) {
+        float32x4_t vmax = vdupq_n_f32(-INFINITY);
+        // Something goes wrong when storing and manipulating K*Q as fp16.
+        // It works for some models (e.g., Gemma-2), but not for others (e.g., LLaMA-3.1-8B).
+        // As I wasn't able to find where we lose precision, let's comment this out
+        // for now and do the K*Q part in fp32.
+        //if (softcap <= 0.0f) {
+        //    for (int l = 0; l < k_step/F16::block_size; ++l) {
+        //        auto val = F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l));
+        //        vk[2*l+0] = vcvt_f32_f16(vget_low_f16(val));
+        //        vk[2*l+1] = vcvt_f32_f16(vget_high_f16(val));
+        //        vmax = vmaxq_f32(vmax, vmaxq_f32(vk[2*l+0], vk[2*l+1]));
+        //    }
+        //} else {
+        //    auto v_softcap = vdupq_n_f32(softcap);
+        //    for (int l = 0; l < k_step/F16::block_size; ++l) {
+        //        auto val = F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l));
+        //        vk[2*l+0] = vcvt_f32_f16(vget_low_f16(val));
+        //        vk[2*l+1] = vcvt_f32_f16(vget_high_f16(val));
+        //        vk[2*l+0] = vmulq_f32(v_softcap, v_tanh(vk[2*l+0]));
+        //        vk[2*l+1] = vmulq_f32(v_softcap, v_tanh(vk[2*l+1]));
+        //        vmax = vmaxq_f32(vmax, vmaxq_f32(vk[2*l+0], vk[2*l+1]));
+        //    }
+        //}
+        auto vscale32 = vcvt_f32_f16(vget_low_f16(vscale));
+        if (softcap <= 0.0f) {
+            for (int l = 0; l < k_step/4; ++l) {
+                vk[l] = vmulq_f32(vscale32, vld1q_f32(cache + k_step*j + 4*l));
+                vmax = vmaxq_f32(vmax, vk[l]);
+            }
+        } else {
+            auto v_softcap = vdupq_n_f32(softcap);
+            for (int l = 0; l < k_step/4; ++l) {
+                vk[l] = vmulq_f32(vscale32, vld1q_f32(cache + k_step*j + 4*l));
+                vk[l] = vmulq_f32(v_softcap, v_tanh(vk[l]));
+                vmax = vmaxq_f32(vmax, vk[l]);
+            }
+        }
+
+        float smax = vmaxvq_f32(vmax);
+        if (smax == -INFINITY) {
+            std::memset(cache + k_step*j, 0, k_step*sizeof(float));
+            need_scaling[j] = M[j] == -INFINITY ? 2 : 0;
+            return;
+        }
+        need_scaling[j] = 0;
+        if (smax > M[j]) {
+            if (M[j] > -INFINITY) {
+                float m = expf(M[j] - smax);
+                vms[j] = F16::set1(m);
+                need_scaling[j] = 1;
+                S[j] *= m;
+            } else {
+                need_scaling[j] = 2;
+                S[j] = 0;
+            }
+            M[j] = smax;
+        }
+        auto vm = vdupq_n_f32(M[j]);
+        auto vsum = vdupq_n_f32(0);
+        for (int l = 0; l < k_step/4; ++l) {
+            vk[l] = v_expf(vsubq_f32(vk[l], vm));
+            vsum = vaddq_f32(vsum, vk[l]);
+            F16::store(cache + k_step*j + 4*l, vk[l]);
+        }
+        S[j] += vaddvq_f32(vsum);
+    }
+#else
     inline void update_M_S(int j, F16::Data * vk) {
         if (softcap <= 0.0f) {
             for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l));
@@ -6758,6 +6833,7 @@ struct FlashMS {
         }
         S[j] += F16::reduce_add(vk);
     }
+#endif
 
     cache_t cache[q_step*k_step];
     float S[q_step], M[q_step];
@@ -6884,7 +6960,11 @@ struct FlashQKfp32 {
     static_assert(k_step%F16::block_size == 0);
     static_assert(q_step <= 4 || q_step%4 == 0);
 
+#ifdef __aarch64__
+    constexpr static bool is_small_head = false;
+#else
     constexpr static bool is_small_head = D <= (F16::num_registers/2)*F16::block_size;
+#endif
 
     template <bool small = is_small_head, class = std::enable_if<small>, typename q_float>
     static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const q_float * q, const char * mask,
@@ -6987,10 +7067,17 @@
         else {
             mult_mask_kq_l(kh, stride_q, stride_m, q, mask, fms);
         }
+#ifdef __aarch64__
+        float32x4_t vk[k_step/4];
+        for (int j = 0; j < q_step; ++j) {
+            fms.update_M_S(j, vk);
+        }
+#else
         F16::Data vk[k_step/F16::block_size];
         for (int j = 0; j < q_step; ++j) {
            fms.update_M_S(j, vk);
         }
+#endif
     }
 
     template <typename KHelper, typename q_float>
@@ -7002,10 +7089,17 @@
         else {
             mult_mask_kq_l(nq, kh, stride_q, stride_m, q, mask, fms);
         }
+#ifdef __aarch64__
+        float32x4_t vk[k_step/4];
+        for (int j = 0; j < nq; ++j) {
+            fms.update_M_S(j, vk);
+        }
+#else
         F16::Data vk[k_step/F16::block_size];
         for (int j = 0; j < nq; ++j) {
            fms.update_M_S(j, vk);
         }
+#endif
     }
 
 #ifdef __aarch64__
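
Note (not part of the patch): the new aarch64 update_M_S above is the standard online-softmax
update used in flash attention, just vectorized with NEON and now reading the K*Q cache in fp32.
For reference, a scalar sketch of the same logic follows. The member names mirror FlashMS, but
vms_scale stands in for the vms SIMD register array and the whole FlashMSRef struct is
illustrative, not the actual kernel code.

#include <algorithm>
#include <cmath>
#include <cstring>

template <int q_step, int k_step>
struct FlashMSRef {
    float cache[q_step*k_step];  // K*Q products for q_step query rows, kept in fp32
    float S[q_step];             // running softmax normalizer per query row
    float M[q_step];             // running maximum per query row
    float vms_scale[q_step];     // factor for the V accumulator when need_scaling == 1
    int   need_scaling[q_step];  // 0: leave accumulator, 1: rescale it, 2: zero it
    float scale, softcap;

    void update_M_S(int j) {
        float * row = cache + k_step*j;
        float smax = -INFINITY;
        for (int l = 0; l < k_step; ++l) {
            // scale, and optionally soft-cap, the raw K*Q products
            row[l] = softcap <= 0.0f ? scale*row[l] : softcap*std::tanh(scale*row[l]);
            smax = std::max(smax, row[l]);
        }
        if (smax == -INFINITY) {   // fully masked row: contributes nothing
            std::memset(row, 0, k_step*sizeof(float));
            need_scaling[j] = M[j] == -INFINITY ? 2 : 0;
            return;
        }
        need_scaling[j] = 0;
        if (smax > M[j]) {         // a new maximum invalidates earlier partial sums
            if (M[j] > -INFINITY) {
                float m = std::exp(M[j] - smax);
                vms_scale[j] = m;
                need_scaling[j] = 1;
                S[j] *= m;         // bring the old normalizer to the new maximum
            } else {
                need_scaling[j] = 2;
                S[j] = 0;
            }
            M[j] = smax;
        }
        for (int l = 0; l < k_step; ++l) {
            row[l] = std::exp(row[l] - M[j]);  // unnormalized softmax weights
            S[j] += row[l];
        }
    }
};

Keeping cache in fp32 also sidesteps the precision issue described in the patch comments: fp16
carries only an 11-bit significand, so storing and rescaling K*Q products as float16_t can lose
enough precision to matter for some models.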
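The need_scaling codes only make sense together with the caller that accumulates
softmax(K*Q)*V across k_step blocks. A hypothetical scalar caller, sketched below, shows how
they would presumably be applied; the names accumulate_v, acc, and v are invented for
illustration, and the real kernel does this with the SIMD helpers instead.

template <int D, int q_step, int k_step>
void accumulate_v(const FlashMSRef<q_step, k_step>& fms, const float * v, float * acc) {
    // acc holds the running rows of softmax(K*Q)*V from the previous k_step blocks
    for (int j = 0; j < q_step; ++j) {
        if (fms.need_scaling[j] == 2) {
            // the row restarted (previous max was -inf): discard old partial sums
            for (int i = 0; i < D; ++i) acc[D*j + i] = 0;
        } else if (fms.need_scaling[j] == 1) {
            // a new maximum was found: rescale old partial sums by exp(M_old - M_new)
            for (int i = 0; i < D; ++i) acc[D*j + i] *= fms.vms_scale[j];
        }
        for (int l = 0; l < k_step; ++l) {
            float w = fms.cache[k_step*j + l];   // unnormalized softmax weight
            for (int i = 0; i < D; ++i) acc[D*j + i] += w * v[D*l + i];
        }
    }
    // after the last block, each row is divided by S[j] to finish the softmax
}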