NEON Flash Attention - first working version

Simply reuse the Zen4/AVX2 implementation, but use
f16 for the K*Q multiplication and V*softmax(K*Q) accumulation.
This makes the FlashMS portion somewhat awkward because we
do not have fast f16 implementations for expf (and tanh when
softcap is enabled), so we need to convert back-and-forth
to f32.

FA is slightly faster than no-FA for the 4B TriLM model,
but slightly slower for Gemma-2b.
This commit is contained in:
Iwan Kawrakow
2024-09-10 19:01:03 +02:00
parent 72f5dfe12a
commit 67bf083f9d

View File

@@ -6293,6 +6293,11 @@ inline float32x4_t v_expf(float32x4_t x) {
return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
}
// f16 overload of v_expf: there is no fast native f16 expf, so each
// 4-lane half is widened to f32, exponentiated with the f32 v_expf,
// and narrowed back to f16 before recombining.
inline float16x8_t v_expf(float16x8_t x) {
    float32x4_t lo = v_expf(vcvt_f32_f16(vget_low_f16(x)));
    float32x4_t hi = v_expf(vcvt_f32_f16(vget_high_f16(x)));
    return vcombine_f16(vcvt_f16_f32(lo), vcvt_f16_f32(hi));
}
inline float32x4_t v_tanh(float32x4_t x) {
const float32x4_t one = vdupq_n_f32(1.0f);
const float32x4_t two_x = vmulq_f32(x, vdupq_n_f32(2.f));
@@ -6302,6 +6307,11 @@ inline float32x4_t v_tanh(float32x4_t x) {
return vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(one), mask), vbicq_u32(vreinterpretq_u32_f32(res), mask)));
//return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
}
// f16 overload of v_tanh: there is no fast native f16 tanh, so each
// 4-lane half is widened to f32, evaluated with the f32 v_tanh, and
// narrowed back to f16.
// Fix: the return type must be float16x8_t -- vcombine_f16 produces a
// float16x8_t, so the previously declared float32x4_t return type did
// not match the returned expression (and diverged from the parallel
// v_expf(float16x8_t) overload above).
inline float16x8_t v_tanh(float16x8_t x) {
    auto val1 = v_tanh(vcvt_f32_f16(vget_low_f16(x)));
    auto val2 = v_tanh(vcvt_f32_f16(vget_high_f16(x)));
    return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
}
#endif
#if defined(__AVX512F__) && defined(__AVX512DQ__)
@@ -6401,7 +6411,7 @@ inline __m256 v_tanh(__m256 x) {
#endif
} // namespace
#ifndef __aarch64__
//#ifndef __aarch64__
namespace {
@@ -6442,7 +6452,7 @@ struct F16 {
// Reduce k_step AVX-512 registers to a single float: pairwise adds via
// _mm512_add_ps, then a horizontal _mm512_reduce_add_ps on the result.
template <int k_step> static inline float reduce_add(const Data * data) {
return reduce_T<k_step, _mm512_add_ps, _mm512_reduce_add_ps>(data);
}
#else
#elif defined __AVX2__
using Data = __m256;
constexpr static int block_size = 8;
constexpr static int num_registers = 16;
@@ -6463,6 +6473,40 @@ struct F16 {
// Reduce k_step AVX2 registers to a single float: pairwise adds via
// _mm256_add_ps, then the scalar horizontal reduce_add on the result.
template <int k_step> static inline float reduce_add(const Data * data) {
return reduce_T<k_step, _mm256_add_ps, &F16::reduce_add>(data);
}
#else
// NEON (aarch64) branch: Data is a vector of 8 fp16 lanes, so the K*Q
// products and the V*softmax(K*Q) accumulation run natively in f16;
// f32 conversions happen only at the load/store boundaries below.
using Data = float16x8_t;
constexpr static int block_size = 8;
constexpr static int num_registers = 32;
constexpr static int q_step = 8;
static inline Data zero() { return vdupq_n_f16(0); }
// Load the i-th 8-element fp16 block from a raw byte pointer.
static inline Data load(const char * ptr, int i) { return vld1q_f16((const float16_t *)ptr + block_size*i); }
static inline Data load(const float16_t * ptr, int i) { return vld1q_f16(ptr + block_size*i); }
static inline Data load(const float16_t * ptr) { return vld1q_f16(ptr); }
// Load 8 f32 values and narrow them to fp16, one 4-lane half at a time.
static inline Data load(const float * ptr) {
auto val1 = vld1q_f32(ptr);
auto val2 = vld1q_f32(ptr+4);
return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
}
static inline Data set1(float val) { return vdupq_n_f16(val); }
static inline Data mul(Data v1, Data v2) { return vmulq_f16(v1, v2); }
static inline Data sub(Data v1, Data v2) { return vsubq_f16(v1, v2); }
// Widen the fp16 register to f32 and store as two 4-lane halves.
static inline void store(float * ptr, Data data) {
vst1q_f32(ptr+0, vcvt_f32_f16(vget_low_f16(data)));
vst1q_f32(ptr+4, vcvt_f32_f16(vget_high_f16(data)));
}
static inline void store(float16_t * ptr, Data data) { vst1q_f16(ptr, data); }
static inline Data fmadd(Data prev, Data v1, Data v2) { return vfmaq_f16(prev, v1, v2); }
// Horizontal max over the 8 fp16 lanes (widened to float on return).
static inline float reduce_max(Data data) { return vmaxvq_f16(data); }
// Horizontal add: fold the two halves in fp16, then finish the
// reduction in f32 after widening the 4-lane partial sum.
static inline float reduce_add(Data data) {
auto sum = vadd_f16(vget_low_f16(data), vget_high_f16(data));
return vaddvq_f32(vcvt_f32_f16(sum));
}
// Reduce k_step fp16 registers via pairwise vmaxq_f16 / vaddq_f16, then
// the scalar horizontal reduce above on the surviving register.
template <int k_step> static inline float reduce_max(const Data * data) {
return reduce_T<k_step, vmaxq_f16, &F16::reduce_max>(data);
}
template <int k_step> static inline float reduce_add(const Data * data) {
return reduce_T<k_step, vaddq_f16, &F16::reduce_add>(data);
}
#endif
template <int k_step, Data (*Op_combine)(Data, Data), float (*Op)(Data)>
static float reduce_T(const Data * data) {
@@ -6663,6 +6707,12 @@ struct HelperQ41 final : public BaseHelper<step> {
template <int q_step, int k_step>
struct FlashMS {
#ifdef __aarch64__
using cache_t = float16_t;
#else
using cache_t = float;
#endif
FlashMS(float scale, float softcap) : vscale(F16::set1(scale)), softcap(softcap), h_inf(GGML_FP32_TO_FP16(-INFINITY)) {}
inline void init_qstep() {
@@ -6709,7 +6759,7 @@ struct FlashMS {
S[j] += F16::reduce_add<k_step>(vk);
}
float cache[q_step*k_step];
cache_t cache[q_step*k_step];
float S[q_step], M[q_step];
int need_scaling[q_step];
F16::Data vms[q_step];
@@ -6722,6 +6772,12 @@ struct FlashMS {
template <int D, int q_step, int k_step>
struct FlashQKV {
#ifdef __aarch64__
using qkv_cache_t = float16_t;
#else
using qkv_cache_t = float;
#endif
// This fails for head sizes of 80 and 112 as D/16 is odd, so we cannot do steps of 2
// Hence, for now, we will not handle head sizes of 80 and 112
template <typename VHelper>
@@ -6792,7 +6848,7 @@ struct FlashQKV {
}
}
inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int j, const float * R, float * qkv) const {
inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int j, const qkv_cache_t * R, float * qkv) const {
GGML_ASSERT(fms.S[j] > 0);
auto norm = F16::set1(1/fms.S[j]);
for (int i = 0; i < D/F16::block_size; ++i) {
@@ -6819,7 +6875,7 @@ struct FlashQKV {
}
}
float qkv_cache[D*q_step];
qkv_cache_t qkv_cache[D*q_step];
};
template <int D, int q_step, int k_step>
@@ -7469,29 +7525,29 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
return true;
}
#else
// TODO
// Fallback stub: flash attention is not implemented for this build
// configuration, so report failure and let the caller take the non-FA path.
bool iqk_flash_attn_noalibi([[maybe_unused]] int int_type_k, // type of k
[[maybe_unused]] int int_type_v, // type of v
[[maybe_unused]] int D, // head size
[[maybe_unused]] int nq, // number of columns in q
[[maybe_unused]] int nk, // number of rows in k
[[maybe_unused]] int stride_q, // distance between q columns in bytes
[[maybe_unused]] int stride_k, // distance between k rows in bytes
[[maybe_unused]] int stride_v, // distance between v rows in bytes
[[maybe_unused]] int stride_m, // distance between mask rows (in bytes)
[[maybe_unused]] int stride_qkv, // distance between qkv rows (in bytes)
[[maybe_unused]] const float * q, // q matrix.
[[maybe_unused]] const void * k, // k matrix. Assumed to be fp16, nq x nk elements
[[maybe_unused]] const void * v, // v matrix. Assumed to be fp16, nq x nk elements
[[maybe_unused]] const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
[[maybe_unused]] float scale, // scale applied before softmax
[[maybe_unused]] float softcap, // if > 0, a "soft-cap" operation is applied before softmax
[[maybe_unused]] float * qkv) { // v*softmax(scale*(k*q))
return false;
}
#endif
//#else
//// TODO
//bool iqk_flash_attn_noalibi([[maybe_unused]] int int_type_k, // type of k
// [[maybe_unused]] int int_type_v, // type of v
// [[maybe_unused]] int D, // head size
// [[maybe_unused]] int nq, // number of columns in q
// [[maybe_unused]] int nk, // number of rows in k
// [[maybe_unused]] int stride_q, // distance between q columns in bytes
// [[maybe_unused]] int stride_k, // distance between k rows in bytes
// [[maybe_unused]] int stride_v, // distance between v rows in bytes
// [[maybe_unused]] int stride_m, // distance between mask rows (in bytes
// [[maybe_unused]] int stride_qkv, // distance between rows in mask (in bytes)
// [[maybe_unused]] const float * q, // q matrix.
// [[maybe_unused]] const void * k, // k matrix. Assumed to be fp16, nq x nk elements
// [[maybe_unused]] const void * v, // v matrix. Assumed to be fp16, nq x nk elements
// [[maybe_unused]] const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
// [[maybe_unused]] float scale, // scale applied before softmax
// [[maybe_unused]] float softcap, // if > 0, a "soft-cap" operation is applied before softmax
// [[maybe_unused]] float * qkv) { // v*softmax(scale*(k*q))
// return false;
//}
//
//#endif
#else // IQK_IMPLEMENT