mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-03 18:40:14 +00:00
NEON Flash Attention - first working version
Simply reuse the Zen4/AVX2 implementation, but use f16 for the K*Q multiplication and V*softmax(K*Q) accumulation. This makes the FlashMS portion somewhat awkward because we do not have fast f16 implementations for expf (and tanh when softcap is enabled), so we need to convert back-and-fort to f32. FA is slightly faster than no-FA for the 4B TriLM model, but lightly slower for Gemma-2b.
This commit is contained in:
@@ -6293,6 +6293,11 @@ inline float32x4_t v_expf(float32x4_t x) {
|
||||
return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
|
||||
vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
|
||||
}
|
||||
inline float16x8_t v_expf(float16x8_t x) {
|
||||
auto val1 = v_expf(vcvt_f32_f16(vget_low_f16(x)));
|
||||
auto val2 = v_expf(vcvt_f32_f16(vget_high_f16(x)));
|
||||
return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
|
||||
}
|
||||
inline float32x4_t v_tanh(float32x4_t x) {
|
||||
const float32x4_t one = vdupq_n_f32(1.0f);
|
||||
const float32x4_t two_x = vmulq_f32(x, vdupq_n_f32(2.f));
|
||||
@@ -6302,6 +6307,11 @@ inline float32x4_t v_tanh(float32x4_t x) {
|
||||
return vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(one), mask), vbicq_u32(vreinterpretq_u32_f32(res), mask)));
|
||||
//return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
|
||||
}
|
||||
inline float32x4_t v_tanh(float16x8_t x) {
|
||||
auto val1 = v_tanh(vcvt_f32_f16(vget_low_f16(x)));
|
||||
auto val2 = v_tanh(vcvt_f32_f16(vget_high_f16(x)));
|
||||
return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(__AVX512F__) && defined(__AVX512DQ__)
|
||||
@@ -6401,7 +6411,7 @@ inline __m256 v_tanh(__m256 x) {
|
||||
#endif
|
||||
} // namespace
|
||||
|
||||
#ifndef __aarch64__
|
||||
//#ifndef __aarch64__
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -6442,7 +6452,7 @@ struct F16 {
|
||||
template <int k_step> static inline float reduce_add(const Data * data) {
|
||||
return reduce_T<k_step, _mm512_add_ps, _mm512_reduce_add_ps>(data);
|
||||
}
|
||||
#else
|
||||
#elif defined __AVX2__
|
||||
using Data = __m256;
|
||||
constexpr static int block_size = 8;
|
||||
constexpr static int num_registers = 16;
|
||||
@@ -6463,6 +6473,40 @@ struct F16 {
|
||||
template <int k_step> static inline float reduce_add(const Data * data) {
|
||||
return reduce_T<k_step, _mm256_add_ps, &F16::reduce_add>(data);
|
||||
}
|
||||
#else
|
||||
using Data = float16x8_t;
|
||||
constexpr static int block_size = 8;
|
||||
constexpr static int num_registers = 32;
|
||||
constexpr static int q_step = 8;
|
||||
static inline Data zero() { return vdupq_n_f16(0); }
|
||||
static inline Data load(const char * ptr, int i) { return vld1q_f16((const float16_t *)ptr + block_size*i); }
|
||||
static inline Data load(const float16_t * ptr, int i) { return vld1q_f16(ptr + block_size*i); }
|
||||
static inline Data load(const float16_t * ptr) { return vld1q_f16(ptr); }
|
||||
static inline Data load(const float * ptr) {
|
||||
auto val1 = vld1q_f32(ptr);
|
||||
auto val2 = vld1q_f32(ptr+4);
|
||||
return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
|
||||
}
|
||||
static inline Data set1(float val) { return vdupq_n_f16(val); }
|
||||
static inline Data mul(Data v1, Data v2) { return vmulq_f16(v1, v2); }
|
||||
static inline Data sub(Data v1, Data v2) { return vsubq_f16(v1, v2); }
|
||||
static inline void store(float * ptr, Data data) {
|
||||
vst1q_f32(ptr+0, vcvt_f32_f16(vget_low_f16(data)));
|
||||
vst1q_f32(ptr+4, vcvt_f32_f16(vget_high_f16(data)));
|
||||
}
|
||||
static inline void store(float16_t * ptr, Data data) { vst1q_f16(ptr, data); }
|
||||
static inline Data fmadd(Data prev, Data v1, Data v2) { return vfmaq_f16(prev, v1, v2); }
|
||||
static inline float reduce_max(Data data) { return vmaxvq_f16(data); }
|
||||
static inline float reduce_add(Data data) {
|
||||
auto sum = vadd_f16(vget_low_f16(data), vget_high_f16(data));
|
||||
return vaddvq_f32(vcvt_f32_f16(sum));
|
||||
}
|
||||
template <int k_step> static inline float reduce_max(const Data * data) {
|
||||
return reduce_T<k_step, vmaxq_f16, &F16::reduce_max>(data);
|
||||
}
|
||||
template <int k_step> static inline float reduce_add(const Data * data) {
|
||||
return reduce_T<k_step, vaddq_f16, &F16::reduce_add>(data);
|
||||
}
|
||||
#endif
|
||||
template <int k_step, Data (*Op_combine)(Data, Data), float (*Op)(Data)>
|
||||
static float reduce_T(const Data * data) {
|
||||
@@ -6663,6 +6707,12 @@ struct HelperQ41 final : public BaseHelper<step> {
|
||||
|
||||
template <int q_step, int k_step>
|
||||
struct FlashMS {
|
||||
#ifdef __aarch64__
|
||||
using cache_t = float16_t;
|
||||
#else
|
||||
using cache_t = float;
|
||||
#endif
|
||||
|
||||
FlashMS(float scale, float softcap) : vscale(F16::set1(scale)), softcap(softcap), h_inf(GGML_FP32_TO_FP16(-INFINITY)) {}
|
||||
|
||||
inline void init_qstep() {
|
||||
@@ -6709,7 +6759,7 @@ struct FlashMS {
|
||||
S[j] += F16::reduce_add<k_step>(vk);
|
||||
}
|
||||
|
||||
float cache[q_step*k_step];
|
||||
cache_t cache[q_step*k_step];
|
||||
float S[q_step], M[q_step];
|
||||
int need_scaling[q_step];
|
||||
F16::Data vms[q_step];
|
||||
@@ -6722,6 +6772,12 @@ struct FlashMS {
|
||||
template <int D, int q_step, int k_step>
|
||||
struct FlashQKV {
|
||||
|
||||
#ifdef __aarch64__
|
||||
using qkv_cache_t = float16_t;
|
||||
#else
|
||||
using qkv_cache_t = float;
|
||||
#endif
|
||||
|
||||
// This fails for head sizes of 80 and 112 as D/16 is odd, so we cannot do steps of 2
|
||||
// Hence, for now, we will not handle head sizes of 80 and 112
|
||||
template <typename VHelper>
|
||||
@@ -6792,7 +6848,7 @@ struct FlashQKV {
|
||||
}
|
||||
}
|
||||
|
||||
inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int j, const float * R, float * qkv) const {
|
||||
inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int j, const qkv_cache_t * R, float * qkv) const {
|
||||
GGML_ASSERT(fms.S[j] > 0);
|
||||
auto norm = F16::set1(1/fms.S[j]);
|
||||
for (int i = 0; i < D/F16::block_size; ++i) {
|
||||
@@ -6819,7 +6875,7 @@ struct FlashQKV {
|
||||
}
|
||||
}
|
||||
|
||||
float qkv_cache[D*q_step];
|
||||
qkv_cache_t qkv_cache[D*q_step];
|
||||
};
|
||||
|
||||
template <int D, int q_step, int k_step>
|
||||
@@ -7469,29 +7525,29 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
|
||||
return true;
|
||||
}
|
||||
|
||||
#else
|
||||
// TODO
|
||||
bool iqk_flash_attn_noalibi([[maybe_unused]] int int_type_k, // type of k
|
||||
[[maybe_unused]] int int_type_v, // type of v
|
||||
[[maybe_unused]] int D, // head size
|
||||
[[maybe_unused]] int nq, // number of columns in q
|
||||
[[maybe_unused]] int nk, // number of rows in k
|
||||
[[maybe_unused]] int stride_q, // distance between q columns in bytes
|
||||
[[maybe_unused]] int stride_k, // distance between k rows in bytes
|
||||
[[maybe_unused]] int stride_v, // distance between v rows in bytes
|
||||
[[maybe_unused]] int stride_m, // distance between mask rows (in bytes
|
||||
[[maybe_unused]] int stride_qkv, // distance between rows in mask (in bytes)
|
||||
[[maybe_unused]] const float * q, // q matrix.
|
||||
[[maybe_unused]] const void * k, // k matrix. Assumed to be fp16, nq x nk elements
|
||||
[[maybe_unused]] const void * v, // v matrix. Assumed to be fp16, nq x nk elements
|
||||
[[maybe_unused]] const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
|
||||
[[maybe_unused]] float scale, // scale applied before softmax
|
||||
[[maybe_unused]] float softcap, // if > 0, a "soft-cap" operation is applied before softmax
|
||||
[[maybe_unused]] float * qkv) { // v*softmax(scale*(k*q))
|
||||
return false;
|
||||
}
|
||||
|
||||
#endif
|
||||
//#else
|
||||
//// TODO
|
||||
//bool iqk_flash_attn_noalibi([[maybe_unused]] int int_type_k, // type of k
|
||||
// [[maybe_unused]] int int_type_v, // type of v
|
||||
// [[maybe_unused]] int D, // head size
|
||||
// [[maybe_unused]] int nq, // number of columns in q
|
||||
// [[maybe_unused]] int nk, // number of rows in k
|
||||
// [[maybe_unused]] int stride_q, // distance between q columns in bytes
|
||||
// [[maybe_unused]] int stride_k, // distance between k rows in bytes
|
||||
// [[maybe_unused]] int stride_v, // distance between v rows in bytes
|
||||
// [[maybe_unused]] int stride_m, // distance between mask rows (in bytes
|
||||
// [[maybe_unused]] int stride_qkv, // distance between rows in mask (in bytes)
|
||||
// [[maybe_unused]] const float * q, // q matrix.
|
||||
// [[maybe_unused]] const void * k, // k matrix. Assumed to be fp16, nq x nk elements
|
||||
// [[maybe_unused]] const void * v, // v matrix. Assumed to be fp16, nq x nk elements
|
||||
// [[maybe_unused]] const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
|
||||
// [[maybe_unused]] float scale, // scale applied before softmax
|
||||
// [[maybe_unused]] float softcap, // if > 0, a "soft-cap" operation is applied before softmax
|
||||
// [[maybe_unused]] float * qkv) { // v*softmax(scale*(k*q))
|
||||
// return false;
|
||||
//}
|
||||
//
|
||||
//#endif
|
||||
|
||||
#else // IQK_IMPLEMENT
|
||||
|
||||
|
||||
Reference in New Issue
Block a user