mirror of https://github.com/ikawrakow/ik_llama.cpp.git
NEON Flash Attention - first working version
Simply reuse the Zen4/AVX2 implementation, but use f16 for the K*Q multiplication and the V*softmax(K*Q) accumulation. This makes the FlashMS portion somewhat awkward, because we do not have fast f16 implementations of expf (or of tanh, when softcap is enabled), so we need to convert back and forth to f32. FA is slightly faster than no-FA for the 4B TriLM model, but slightly slower for Gemma-2b.
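The pattern used throughout this change is to give the existing f32 SIMD kernels a thin f16 front end: widen each half of a float16x8_t to f32, evaluate there, narrow back. A minimal self-contained sketch of that round-trip (editor's illustration, not code from this commit; a scalar expf stands in for the repo's vectorized v_expf):

#include <arm_neon.h>   // AArch64; float16_t needs ACLE fp16 support
#include <cmath>
#include <cstdio>

// Stand-in for the repo's vectorized v_expf(float32x4_t): scalar expf per lane.
static inline float32x4_t v_expf_f32(float32x4_t x) {
    float tmp[4];
    vst1q_f32(tmp, x);
    for (int i = 0; i < 4; ++i) tmp[i] = expf(tmp[i]);
    return vld1q_f32(tmp);
}

// The commit's wrapper pattern: widen both f16 halves to f32,
// evaluate there, then narrow the two results back into one f16 vector.
static inline float16x8_t v_expf_f16(float16x8_t x) {
    float32x4_t lo = vcvt_f32_f16(vget_low_f16(x));
    float32x4_t hi = vcvt_f32_f16(vget_high_f16(x));
    return vcombine_f16(vcvt_f16_f32(v_expf_f32(lo)), vcvt_f16_f32(v_expf_f32(hi)));
}

int main() {
    float16_t in[8] = {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.5f, -0.5f, 3.0f};
    float16_t out[8];
    vst1q_f16(out, v_expf_f16(vld1q_f16(in)));
    for (int i = 0; i < 8; ++i) printf("exp(%.1f) ~ %.4f\n", (float)in[i], (float)out[i]);
    return 0;
}

This is the conversion overhead the message refers to: two vcvt widenings and two narrowings around every transcendental call in the FlashMS path.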
@@ -6293,6 +6293,11 @@ inline float32x4_t v_expf(float32x4_t x) {
     return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
                      vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
 }
+inline float16x8_t v_expf(float16x8_t x) {
+    auto val1 = v_expf(vcvt_f32_f16(vget_low_f16(x)));
+    auto val2 = v_expf(vcvt_f32_f16(vget_high_f16(x)));
+    return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
+}
 inline float32x4_t v_tanh(float32x4_t x) {
     const float32x4_t one = vdupq_n_f32(1.0f);
     const float32x4_t two_x = vmulq_f32(x, vdupq_n_f32(2.f));
@@ -6302,6 +6307,11 @@ inline float32x4_t v_tanh(float32x4_t x) {
     return vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(one), mask), vbicq_u32(vreinterpretq_u32_f32(res), mask)));
     //return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
 }
+inline float16x8_t v_tanh(float16x8_t x) {
+    auto val1 = v_tanh(vcvt_f32_f16(vget_low_f16(x)));
+    auto val2 = v_tanh(vcvt_f32_f16(vget_high_f16(x)));
+    return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
+}
 #endif
 
 #if defined(__AVX512F__) && defined(__AVX512DQ__)
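For context on why tanh shows up here at all (editor's summary): models like Gemma-2 apply a "soft cap" to the scaled K*Q logits before softmax, squashing them into (-softcap, softcap). In scalar form the transform is, assuming the usual llama.cpp convention:

#include <cmath>

// Hedged sketch; the name is illustrative, not the repo's exact call site.
inline float apply_softcap(float x, float softcap) {
    return softcap * tanhf(x / softcap);
}

On the NEON f16 path this runs through the v_tanh wrapper just added above, with the same f16 -> f32 -> f16 round-trip as v_expf.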
@@ -6401,7 +6411,7 @@ inline __m256 v_tanh(__m256 x) {
 #endif
 } // namespace
 
-#ifndef __aarch64__
+//#ifndef __aarch64__
 
 namespace {
 
@@ -6442,7 +6452,7 @@ struct F16 {
     template <int k_step> static inline float reduce_add(const Data * data) {
         return reduce_T<k_step, _mm512_add_ps, _mm512_reduce_add_ps>(data);
     }
-#else
+#elif defined __AVX2__
     using Data = __m256;
     constexpr static int block_size = 8;
     constexpr static int num_registers = 16;
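With #else replaced by #elif defined __AVX2__, the F16 abstraction becomes a three-way dispatch. Simplified shape (editor's paraphrase, not the literal guards in the file):

#if defined(__AVX512F__)   // 16 f32 lanes per register
    // using Data = __m512;
#elif defined __AVX2__     // 8 f32 lanes per register
    // using Data = __m256;
#else                      // AArch64 NEON: 8 f16 lanes, added below by this commit
    // using Data = float16x8_t;
#endif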
@@ -6463,6 +6473,40 @@ struct F16 {
     template <int k_step> static inline float reduce_add(const Data * data) {
         return reduce_T<k_step, _mm256_add_ps, &F16::reduce_add>(data);
     }
+#else
+    using Data = float16x8_t;
+    constexpr static int block_size = 8;
+    constexpr static int num_registers = 32;
+    constexpr static int q_step = 8;
+    static inline Data zero() { return vdupq_n_f16(0); }
+    static inline Data load(const char * ptr, int i) { return vld1q_f16((const float16_t *)ptr + block_size*i); }
+    static inline Data load(const float16_t * ptr, int i) { return vld1q_f16(ptr + block_size*i); }
+    static inline Data load(const float16_t * ptr) { return vld1q_f16(ptr); }
+    static inline Data load(const float * ptr) {
+        auto val1 = vld1q_f32(ptr);
+        auto val2 = vld1q_f32(ptr+4);
+        return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
+    }
+    static inline Data set1(float val) { return vdupq_n_f16(val); }
+    static inline Data mul(Data v1, Data v2) { return vmulq_f16(v1, v2); }
+    static inline Data sub(Data v1, Data v2) { return vsubq_f16(v1, v2); }
+    static inline void store(float * ptr, Data data) {
+        vst1q_f32(ptr+0, vcvt_f32_f16(vget_low_f16(data)));
+        vst1q_f32(ptr+4, vcvt_f32_f16(vget_high_f16(data)));
+    }
+    static inline void store(float16_t * ptr, Data data) { vst1q_f16(ptr, data); }
+    static inline Data fmadd(Data prev, Data v1, Data v2) { return vfmaq_f16(prev, v1, v2); }
+    static inline float reduce_max(Data data) { return vmaxvq_f16(data); }
+    static inline float reduce_add(Data data) {
+        auto sum = vadd_f16(vget_low_f16(data), vget_high_f16(data));
+        return vaddvq_f32(vcvt_f32_f16(sum));
+    }
+    template <int k_step> static inline float reduce_max(const Data * data) {
+        return reduce_T<k_step, vmaxq_f16, &F16::reduce_max>(data);
+    }
+    template <int k_step> static inline float reduce_add(const Data * data) {
+        return reduce_T<k_step, vaddq_f16, &F16::reduce_add>(data);
+    }
 #endif
     template <int k_step, Data (*Op_combine)(Data, Data), float (*Op)(Data)>
     static float reduce_T(const Data * data) {
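One portability note (editor's, not from the commit): vmulq_f16, vfmaq_f16, vmaxvq_f16 and the other f16 arithmetic intrinsics above require the ARMv8.2-A half-precision vector extension (FEAT_FP16); plain ARMv8.0 NEON only has f16 conversions. A translation unit could assert this at compile time along these lines:

#if defined(__aarch64__) && !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
#error "This f16 NEON path needs FEAT_FP16, e.g. compile with -march=armv8.2-a+fp16"
#endif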
@@ -6663,6 +6707,12 @@ struct HelperQ41 final : public BaseHelper<step> {
 
 template <int q_step, int k_step>
 struct FlashMS {
+#ifdef __aarch64__
+    using cache_t = float16_t;
+#else
+    using cache_t = float;
+#endif
+
     FlashMS(float scale, float softcap) : vscale(F16::set1(scale)), softcap(softcap), h_inf(GGML_FP32_TO_FP16(-INFINITY)) {}
 
     inline void init_qstep() {

@@ -6709,7 +6759,7 @@ struct FlashMS {
             S[j] += F16::reduce_add<k_step>(vk);
     }
 
-    float cache[q_step*k_step];
+    cache_t cache[q_step*k_step];
     float S[q_step], M[q_step];
     int need_scaling[q_step];
     F16::Data vms[q_step];
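To make the FlashMS bookkeeping concrete, here is a scalar sketch (editor's illustration, not the repo's code) of the running max/sum update flash attention performs per k_step block for one query row; M, S, and the rescale factor correspond to the M[], S[], and need_scaling[]/vms[] members above:

#include <algorithm>
#include <cmath>
#include <vector>

struct RowState {
    float M = -INFINITY;     // running max of the scaled K*Q scores
    float S = 0.0f;          // running sum of exp(score - M)
    std::vector<float> acc;  // running softmax-weighted V, head_dim wide
};

// Fold one block of k_step scores (already scaled, softcap applied) into the row.
void update_row(RowState& r, const float* scores, const float* const* v_rows,
                int k_step, int head_dim) {
    float m = r.M;
    for (int i = 0; i < k_step; ++i) m = std::max(m, scores[i]);
    if (m == -INFINITY) return;        // fully masked block: nothing to add
    float ms = std::exp(r.M - m);      // exp(-inf) == 0 covers the first block
    r.S *= ms;                         // rescale earlier contributions ...
    for (float& a : r.acc) a *= ms;    // ... this is the need_scaling path
    for (int i = 0; i < k_step; ++i) {
        float p = std::exp(scores[i] - m);
        r.S += p;
        for (int d = 0; d < head_dim; ++d) r.acc[d] += p * v_rows[i][d];
    }
    r.M = m;
}
// After the last block: output[d] = r.acc[d] / r.S  (cf. normalize_and_store below).

On AArch64 this commit keeps cache (and, below, qkv_cache) in f16, so the exp/rescale arithmetic converts to f32 and back as described in the commit message.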
@@ -6722,6 +6772,12 @@ struct FlashMS {
 template <int D, int q_step, int k_step>
 struct FlashQKV {
 
+#ifdef __aarch64__
+    using qkv_cache_t = float16_t;
+#else
+    using qkv_cache_t = float;
+#endif
+
     // This fails for head sizes of 80 and 112 as D/16 is odd, so we cannot do steps of 2
     // Hence, for now, we will not handle head sizes of 80 and 112
     template <typename VHelper>

@@ -6792,7 +6848,7 @@ struct FlashQKV {
         }
     }
 
-    inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int j, const float * R, float * qkv) const {
+    inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int j, const qkv_cache_t * R, float * qkv) const {
        GGML_ASSERT(fms.S[j] > 0);
        auto norm = F16::set1(1/fms.S[j]);
        for (int i = 0; i < D/F16::block_size; ++i) {

@@ -6819,7 +6875,7 @@ struct FlashQKV {
         }
     }
 
-    float qkv_cache[D*q_step];
+    qkv_cache_t qkv_cache[D*q_step];
 };
 
 template <int D, int q_step, int k_step>
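The head-size restriction in the comment above is pure arithmetic: the accumulation consumes the head dimension in pairs of 16-wide chunks, so D/16 must be even. As a compile-time check this would read (editor's illustration):

// D = 80 and D = 112 give D/16 = 5 and 7, so the pairwise step does not fit.
constexpr bool head_size_ok(int D) { return D % 16 == 0 && (D / 16) % 2 == 0; }
static_assert( head_size_ok(64) &&  head_size_ok(96) && head_size_ok(128), "");
static_assert(!head_size_ok(80) && !head_size_ok(112), "");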
@@ -7469,29 +7525,29 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
     return true;
 }
 
-#else
-// TODO
-bool iqk_flash_attn_noalibi([[maybe_unused]] int int_type_k,      // type of k
-                            [[maybe_unused]] int int_type_v,      // type of v
-                            [[maybe_unused]] int D,               // head size
-                            [[maybe_unused]] int nq,              // number of columns in q
-                            [[maybe_unused]] int nk,              // number of rows in k
-                            [[maybe_unused]] int stride_q,        // distance between q columns in bytes
-                            [[maybe_unused]] int stride_k,        // distance between k rows in bytes
-                            [[maybe_unused]] int stride_v,        // distance between v rows in bytes
-                            [[maybe_unused]] int stride_m,        // distance between mask rows (in bytes)
-                            [[maybe_unused]] int stride_qkv,      // distance between qkv rows (in bytes)
-                            [[maybe_unused]] const float * q,     // q matrix.
-                            [[maybe_unused]] const void  * k,     // k matrix. Assumed to be fp16, nq x nk elements
-                            [[maybe_unused]] const void  * v,     // v matrix. Assumed to be fp16, nq x nk elements
-                            [[maybe_unused]] const void  * mask,  // mask. If not null, assumed to be fp16. nq x nk elements
-                            [[maybe_unused]] float scale,         // scale applied before softmax
-                            [[maybe_unused]] float softcap,       // if > 0, a "soft-cap" operation is applied before softmax
-                            [[maybe_unused]] float * qkv) {       // v*softmax(scale*(k*q))
-    return false;
-}
-
-#endif
+//#else
+//// TODO
+//bool iqk_flash_attn_noalibi([[maybe_unused]] int int_type_k,      // type of k
+//                            [[maybe_unused]] int int_type_v,      // type of v
+//                            [[maybe_unused]] int D,               // head size
+//                            [[maybe_unused]] int nq,              // number of columns in q
+//                            [[maybe_unused]] int nk,              // number of rows in k
+//                            [[maybe_unused]] int stride_q,        // distance between q columns in bytes
+//                            [[maybe_unused]] int stride_k,        // distance between k rows in bytes
+//                            [[maybe_unused]] int stride_v,        // distance between v rows in bytes
+//                            [[maybe_unused]] int stride_m,        // distance between mask rows (in bytes)
+//                            [[maybe_unused]] int stride_qkv,      // distance between qkv rows (in bytes)
+//                            [[maybe_unused]] const float * q,     // q matrix.
+//                            [[maybe_unused]] const void  * k,     // k matrix. Assumed to be fp16, nq x nk elements
+//                            [[maybe_unused]] const void  * v,     // v matrix. Assumed to be fp16, nq x nk elements
+//                            [[maybe_unused]] const void  * mask,  // mask. If not null, assumed to be fp16. nq x nk elements
+//                            [[maybe_unused]] float scale,         // scale applied before softmax
+//                            [[maybe_unused]] float softcap,       // if > 0, a "soft-cap" operation is applied before softmax
+//                            [[maybe_unused]] float * qkv) {       // v*softmax(scale*(k*q))
+//    return false;
+//}
+//
+//#endif
 
 #else // IQK_IMPLEMENT
 