diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 3a3b9eba..cae02ed1 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -6293,6 +6293,11 @@ inline float32x4_t v_expf(float32x4_t x) {
     return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
                      vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
 }
+inline float16x8_t v_expf(float16x8_t x) {
+    auto val1 = v_expf(vcvt_f32_f16(vget_low_f16(x)));
+    auto val2 = v_expf(vcvt_f32_f16(vget_high_f16(x)));
+    return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
+}
 inline float32x4_t v_tanh(float32x4_t x) {
     const float32x4_t one = vdupq_n_f32(1.0f);
     const float32x4_t two_x = vmulq_f32(x, vdupq_n_f32(2.f));
@@ -6302,6 +6307,11 @@ inline float32x4_t v_tanh(float32x4_t x) {
     return vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(one), mask), vbicq_u32(vreinterpretq_u32_f32(res), mask)));
     //return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
 }
+inline float16x8_t v_tanh(float16x8_t x) {
+    auto val1 = v_tanh(vcvt_f32_f16(vget_low_f16(x)));
+    auto val2 = v_tanh(vcvt_f32_f16(vget_high_f16(x)));
+    return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
+}
 #endif
 
 #if defined(__AVX512F__) && defined(__AVX512DQ__)
@@ -6401,7 +6411,7 @@ inline __m256 v_tanh(__m256 x) {
 #endif
 } // namespace
 
-#ifndef __aarch64__
+//#ifndef __aarch64__
 
 namespace {
 
@@ -6442,7 +6452,7 @@ struct F16 {
     template static inline float reduce_add(const Data * data) {
         return reduce_T(data);
     }
-#else
+#elif defined __AVX2__
     using Data = __m256;
     constexpr static int block_size = 8;
    constexpr static int num_registers = 16;
@@ -6463,6 +6473,40 @@ struct F16 {
     template static inline float reduce_add(const Data * data) {
         return reduce_T(data);
     }
+#else
+    using Data = float16x8_t;
+    constexpr static int block_size = 8;
+    constexpr static int num_registers = 32;
+    constexpr static int q_step = 8;
+    static inline Data zero() { return vdupq_n_f16(0); }
+    static inline Data load(const char * ptr, int i) { return vld1q_f16((const float16_t *)ptr + block_size*i); }
+    static inline Data load(const float16_t * ptr, int i) { return vld1q_f16(ptr + block_size*i); }
+    static inline Data load(const float16_t * ptr) { return vld1q_f16(ptr); }
+    static inline Data load(const float * ptr) {
+        auto val1 = vld1q_f32(ptr);
+        auto val2 = vld1q_f32(ptr+4);
+        return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
+    }
+    static inline Data set1(float val) { return vdupq_n_f16(val); }
+    static inline Data mul(Data v1, Data v2) { return vmulq_f16(v1, v2); }
+    static inline Data sub(Data v1, Data v2) { return vsubq_f16(v1, v2); }
+    static inline void store(float * ptr, Data data) {
+        vst1q_f32(ptr+0, vcvt_f32_f16(vget_low_f16(data)));
+        vst1q_f32(ptr+4, vcvt_f32_f16(vget_high_f16(data)));
+    }
+    static inline void store(float16_t * ptr, Data data) { vst1q_f16(ptr, data); }
+    static inline Data fmadd(Data prev, Data v1, Data v2) { return vfmaq_f16(prev, v1, v2); }
+    static inline float reduce_max(Data data) { return vmaxvq_f16(data); }
+    static inline float reduce_add(Data data) {
+        auto sum = vadd_f16(vget_low_f16(data), vget_high_f16(data));
+        return vaddvq_f32(vcvt_f32_f16(sum));
+    }
+    template static inline float reduce_max(const Data * data) {
+        return reduce_T(data);
+    }
+    template static inline float reduce_add(const Data * data) {
+        return reduce_T(data);
+    }
 #endif
 
     template static float reduce_T(const Data * data) {
@@ -6663,6 +6707,12 @@ struct HelperQ41 final : public BaseHelper<step> {
 
 template <int q_step, int k_step>
 struct FlashMS {
+#ifdef __aarch64__
+    using cache_t = float16_t;
+#else
+    using cache_t = float;
+#endif
+
     FlashMS(float scale, float softcap) : vscale(F16::set1(scale)), softcap(softcap), h_inf(GGML_FP32_TO_FP16(-INFINITY)) {}
 
     inline void init_qstep() {
@@ -6709,7 +6759,7 @@ struct FlashMS {
         S[j] += F16::reduce_add(vk);
     }
 
-    float cache[q_step*k_step];
+    cache_t cache[q_step*k_step];
     float S[q_step], M[q_step];
     int need_scaling[q_step];
     F16::Data vms[q_step];
@@ -6722,6 +6772,12 @@ struct FlashMS {
 
 template <int D, int q_step, int k_step>
 struct FlashQKV {
+#ifdef __aarch64__
+    using qkv_cache_t = float16_t;
+#else
+    using qkv_cache_t = float;
+#endif
+
     // This fails for head sizes of 80 and 112 as D/16 is odd, so we cannot do steps of 2
     // Hence, for now, we will not handle head sizes of 80 and 112
     template
@@ -6792,7 +6848,7 @@ struct FlashQKV {
         }
     }
 
-    inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int j, const float * R, float * qkv) const {
+    inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int j, const qkv_cache_t * R, float * qkv) const {
         GGML_ASSERT(fms.S[j] > 0);
         auto norm = F16::set1(1/fms.S[j]);
         for (int i = 0; i < D/F16::block_size; ++i) {
@@ -6819,7 +6875,7 @@ struct FlashQKV {
         }
     }
 
-    float qkv_cache[D*q_step];
+    qkv_cache_t qkv_cache[D*q_step];
 };
 
 template
@@ -7469,29 +7525,29 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
     return true;
 }
 
-#else
-// TODO
-bool iqk_flash_attn_noalibi([[maybe_unused]] int int_type_k,         // type of k
-                            [[maybe_unused]] int int_type_v,         // type of v
-                            [[maybe_unused]] int D,                  // head size
-                            [[maybe_unused]] int nq,                 // number of columns in q
-                            [[maybe_unused]] int nk,                 // number of rows in k
-                            [[maybe_unused]] int stride_q,           // distance between q columns in bytes
-                            [[maybe_unused]] int stride_k,           // distance between k rows in bytes
-                            [[maybe_unused]] int stride_v,           // distance between v rows in bytes
-                            [[maybe_unused]] int stride_m,           // distance between mask rows (in bytes
-                            [[maybe_unused]] int stride_qkv,         // distance between rows in mask (in bytes)
-                            [[maybe_unused]] const float * q,        // q matrix.
-                            [[maybe_unused]] const void  * k,        // k matrix. Assumed to be fp16, nq x nk elements
-                            [[maybe_unused]] const void  * v,        // v matrix. Assumed to be fp16, nq x nk elements
-                            [[maybe_unused]] const void  * mask,     // mask. If not null, assumed to be fp16. nq x nk elements
-                            [[maybe_unused]] float         scale,    // scale applied before softmax
-                            [[maybe_unused]] float         softcap,  // if > 0, a "soft-cap" operation is applied before softmax
-                            [[maybe_unused]] float       * qkv) {    // v*softmax(scale*(k*q))
-    return false;
-}
-
-#endif
+//#else
+//// TODO
+//bool iqk_flash_attn_noalibi([[maybe_unused]] int int_type_k,         // type of k
+//                            [[maybe_unused]] int int_type_v,         // type of v
+//                            [[maybe_unused]] int D,                  // head size
+//                            [[maybe_unused]] int nq,                 // number of columns in q
+//                            [[maybe_unused]] int nk,                 // number of rows in k
+//                            [[maybe_unused]] int stride_q,           // distance between q columns in bytes
+//                            [[maybe_unused]] int stride_k,           // distance between k rows in bytes
+//                            [[maybe_unused]] int stride_v,           // distance between v rows in bytes
+//                            [[maybe_unused]] int stride_m,           // distance between mask rows (in bytes
+//                            [[maybe_unused]] int stride_qkv,         // distance between rows in mask (in bytes)
+//                            [[maybe_unused]] const float * q,        // q matrix.
+//                            [[maybe_unused]] const void  * k,        // k matrix. Assumed to be fp16, nq x nk elements
+//                            [[maybe_unused]] const void  * v,        // v matrix. Assumed to be fp16, nq x nk elements
+//                            [[maybe_unused]] const void  * mask,     // mask. If not null, assumed to be fp16. nq x nk elements
+//                            [[maybe_unused]] float         scale,    // scale applied before softmax
+//                            [[maybe_unused]] float         softcap,  // if > 0, a "soft-cap" operation is applied before softmax
+//                            [[maybe_unused]] float       * qkv) {    // v*softmax(scale*(k*q))
+//    return false;
+//}
+//
+//#endif
 
 #else  // IQK_IMPLEMENT