mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-23 14:44:09 +00:00
ARM_NEON Flash Attention (#49)
* NEON Flash Attention - first working version Simply reuse the Zen4/AVX2 implementation, but use f16 for the K*Q multiplication and V*softmax(K*Q) accumulation. This makes the FlashMS portion somewhat awkward because we do not have fast f16 implementations for expf (and tanh when softcap is enabled), so we need to convert back-and-forth to f32. FA is slightly faster than no-FA for the 4B TriLM model, but slightly slower for Gemma-2b. * NEON Flash Attention - convert Q to f16 before computing Q*K * NEON Flash Attention - use fp32 for K*Q operations Else I get wrong results for LLaMA-3.1-8B (but it works for Gemma-2b). * Delete commented out stuff --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -6293,6 +6293,11 @@ inline float32x4_t v_expf(float32x4_t x) {
|
||||
return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
|
||||
vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
|
||||
}
|
||||
// expf for 8 fp16 lanes: widen each half to f32, evaluate the f32 v_expf,
// then narrow the two results back into a single f16x8 register.
inline float16x8_t v_expf(float16x8_t x) {
    auto lo = v_expf(vcvt_f32_f16(vget_low_f16(x)));
    auto hi = v_expf(vcvt_f32_f16(vget_high_f16(x)));
    return vcombine_f16(vcvt_f16_f32(lo), vcvt_f16_f32(hi));
}
|
||||
inline float32x4_t v_tanh(float32x4_t x) {
|
||||
const float32x4_t one = vdupq_n_f32(1.0f);
|
||||
const float32x4_t two_x = vmulq_f32(x, vdupq_n_f32(2.f));
|
||||
@@ -6302,6 +6307,11 @@ inline float32x4_t v_tanh(float32x4_t x) {
|
||||
return vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(one), mask), vbicq_u32(vreinterpretq_u32_f32(res), mask)));
|
||||
//return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
|
||||
}
|
||||
// tanh for 8 fp16 lanes: widen each half to f32, evaluate the f32 v_tanh,
// then narrow back to f16.
// NOTE(review): the declared return type is float32x4_t, but the value
// returned by vcombine_f16 is a float16x8_t (mirroring the float16x8_t
// v_expf overload above) — the declaration relies on a vector reinterpret
// and likely should read float16x8_t; confirm against callers.
inline float32x4_t v_tanh(float16x8_t x) {
    auto lo = v_tanh(vcvt_f32_f16(vget_low_f16(x)));
    auto hi = v_tanh(vcvt_f32_f16(vget_high_f16(x)));
    return vcombine_f16(vcvt_f16_f32(lo), vcvt_f16_f32(hi));
}
|
||||
#endif
|
||||
|
||||
#if defined(__AVX512F__) && defined(__AVX512DQ__)
|
||||
@@ -6401,8 +6411,6 @@ inline __m256 v_tanh(__m256 x) {
|
||||
#endif
|
||||
} // namespace
|
||||
|
||||
#ifndef __aarch64__
|
||||
|
||||
namespace {
|
||||
|
||||
template <int k_step>
|
||||
@@ -6442,7 +6450,7 @@ struct F16 {
|
||||
template <int k_step> static inline float reduce_add(const Data * data) {
|
||||
return reduce_T<k_step, _mm512_add_ps, _mm512_reduce_add_ps>(data);
|
||||
}
|
||||
#else
|
||||
#elif defined __AVX2__
|
||||
using Data = __m256;
|
||||
constexpr static int block_size = 8;
|
||||
constexpr static int num_registers = 16;
|
||||
@@ -6463,6 +6471,41 @@ struct F16 {
|
||||
template <int k_step> static inline float reduce_add(const Data * data) {
|
||||
return reduce_T<k_step, _mm256_add_ps, &F16::reduce_add>(data);
|
||||
}
|
||||
#else
|
||||
using Data = float16x8_t;
|
||||
constexpr static int block_size = 8;
|
||||
constexpr static int num_registers = 32;
|
||||
constexpr static int q_step = 8;
|
||||
// All-zero f16x8 accumulator.
static inline Data zero() { return vdupq_n_f16(0); }
|
||||
// Load block i of 8 f16 values from a raw byte pointer.
static inline Data load(const char * ptr, int i) { return vld1q_f16((const float16_t *)ptr + block_size*i); }
|
||||
// Load block i of 8 f16 values from a typed f16 pointer.
static inline Data load(const float16_t * ptr, int i) { return vld1q_f16(ptr + block_size*i); }
|
||||
// Load 8 f16 values starting at ptr.
static inline Data load(const float16_t * ptr) { return vld1q_f16(ptr); }
|
||||
// Load 8 f32 values and narrow them into one f16x8 register.
static inline Data load(const float * ptr) {
    auto lo = vld1q_f32(ptr);
    auto hi = vld1q_f32(ptr + 4);
    return vcombine_f16(vcvt_f16_f32(lo), vcvt_f16_f32(hi));
}
|
||||
// Broadcast a scalar to all 8 lanes; the float argument is narrowed to f16.
static inline Data set1(float val) { return vdupq_n_f16(val); }
|
||||
// Elementwise f16 multiply.
static inline Data mul(Data v1, Data v2) { return vmulq_f16(v1, v2); }
|
||||
// Elementwise f16 subtract (v1 - v2).
static inline Data sub(Data v1, Data v2) { return vsubq_f16(v1, v2); }
|
||||
// Widen the f16x8 register to f32 and store it as 8 floats.
static inline void store(float * ptr, Data data) {
    vst1q_f32(ptr,     vcvt_f32_f16(vget_low_f16(data)));
    vst1q_f32(ptr + 4, vcvt_f32_f16(vget_high_f16(data)));
}
|
||||
// Store 8 f16 values without conversion.
static inline void store(float16_t * ptr, Data data) { vst1q_f16(ptr, data); }
|
||||
// Store 4 f32 lanes (used by the fp32 K*Q path, which works in float32x4_t).
static inline void store(float * ptr, float32x4_t data) { vst1q_f32(ptr, data); }
|
||||
// Fused multiply-add: prev + v1*v2, per f16 lane.
static inline Data fmadd(Data prev, Data v1, Data v2) { return vfmaq_f16(prev, v1, v2); }
|
||||
// Horizontal max across the 8 f16 lanes; the f16 result widens to float on return.
static inline float reduce_max(Data data) { return vmaxvq_f16(data); }
|
||||
// Horizontal sum of 8 f16 lanes: pairwise-add the two halves in f16,
// then finish the 4-lane reduction in f32 for better accuracy.
static inline float reduce_add(Data data) {
    auto half_sum = vadd_f16(vget_low_f16(data), vget_high_f16(data));
    return vaddvq_f32(vcvt_f32_f16(half_sum));
}
|
||||
// Max over k_step/block_size cached registers: combine pairwise with
// vmaxq_f16, then reduce the final register with F16::reduce_max.
template <int k_step> static inline float reduce_max(const Data * data) {
    return reduce_T<k_step, vmaxq_f16, &F16::reduce_max>(data);
}
|
||||
// Sum over k_step/block_size cached registers: combine pairwise with
// vaddq_f16, then reduce the final register with F16::reduce_add.
template <int k_step> static inline float reduce_add(const Data * data) {
    return reduce_T<k_step, vaddq_f16, &F16::reduce_add>(data);
}
|
||||
#endif
|
||||
template <int k_step, Data (*Op_combine)(Data, Data), float (*Op)(Data)>
|
||||
static float reduce_T(const Data * data) {
|
||||
@@ -6663,6 +6706,17 @@ struct HelperQ41 final : public BaseHelper<step> {
|
||||
|
||||
template <int q_step, int k_step>
|
||||
struct FlashMS {
|
||||
// Something goes wrong when storing and manipulating K*Q as fp16.
|
||||
// It works for some models (e.g., Gemma-2), but not for others (e.g., LLaMA-3.1-8B).
|
||||
// As I wasn't able to find where we lose precision, let's comment this out
|
||||
// for now and do the K*Q part in fp32.
|
||||
//#ifdef __aarch64__
|
||||
// using cache_t = float16_t;
|
||||
//#else
|
||||
// using cache_t = float;
|
||||
//#endif
|
||||
using cache_t = float;
|
||||
|
||||
FlashMS(float scale, float softcap) : vscale(F16::set1(scale)), softcap(softcap), h_inf(GGML_FP32_TO_FP16(-INFINITY)) {}
|
||||
|
||||
inline void init_qstep() {
|
||||
@@ -6671,6 +6725,75 @@ struct FlashMS {
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __aarch64__
|
||||
// Online-softmax update for row j over one k_step block of K*Q scores
// (NEON path). Scales the raw scores (optionally applying the soft-cap
// tanh), tracks the running row maximum M[j] and running sum S[j], and
// overwrites the cache row with exp(score - M[j]). vk must provide
// k_step/4 float32x4_t registers of scratch space.
// NOTE: an earlier variant kept the scaling pass in fp16; it produced wrong
// results for some models (e.g. LLaMA-3.1-8B), so scores stay in fp32 here.
inline void update_M_S(int j, float32x4_t * vk) {
    float * row = cache + k_step*j;
    // vscale broadcasts the softmax scale in f16; widen one half to f32.
    auto scale_f32 = vcvt_f32_f16(vget_low_f16(vscale));
    float32x4_t running_max = vdupq_n_f32(-INFINITY);
    if (softcap <= 0.0f) {
        for (int i = 0; i < k_step/4; ++i) {
            vk[i] = vmulq_f32(scale_f32, vld1q_f32(row + 4*i));
            running_max = vmaxq_f32(running_max, vk[i]);
        }
    } else {
        // Soft-capped logits: score = softcap * tanh(scale * score).
        auto cap = vdupq_n_f32(softcap);
        for (int i = 0; i < k_step/4; ++i) {
            vk[i] = vmulq_f32(scale_f32, vld1q_f32(row + 4*i));
            vk[i] = vmulq_f32(cap, v_tanh(vk[i]));
            running_max = vmaxq_f32(running_max, vk[i]);
        }
    }

    float block_max = vmaxvq_f32(running_max);
    if (block_max == -INFINITY) {
        // Fully masked block: zero this row; if there is no prior maximum
        // either, mark the row for full (re)initialization downstream.
        std::memset(row, 0, k_step*sizeof(float));
        need_scaling[j] = M[j] == -INFINITY ? 2 : 0;
        return;
    }

    need_scaling[j] = 0;
    if (block_max > M[j]) {
        if (M[j] > -INFINITY) {
            // New running maximum: previously accumulated sum and values
            // must be rescaled by exp(old_max - new_max).
            float rescale = expf(M[j] - block_max);
            vms[j] = F16::set1(rescale);
            need_scaling[j] = 1;
            S[j] *= rescale;
        } else {
            // First non-masked contribution for this row.
            need_scaling[j] = 2;
            S[j] = 0;
        }
        M[j] = block_max;
    }

    // Exponentiate relative to the (possibly updated) row maximum,
    // accumulate the softmax denominator, and write back the row.
    auto vmax_cur = vdupq_n_f32(M[j]);
    auto vsum = vdupq_n_f32(0);
    for (int i = 0; i < k_step/4; ++i) {
        vk[i] = v_expf(vsubq_f32(vk[i], vmax_cur));
        vsum = vaddq_f32(vsum, vk[i]);
        F16::store(row + 4*i, vk[i]);
    }
    S[j] += vaddvq_f32(vsum);
}
|
||||
#else
|
||||
inline void update_M_S(int j, F16::Data * vk) {
|
||||
if (softcap <= 0.0f) {
|
||||
for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l));
|
||||
@@ -6708,8 +6831,9 @@ struct FlashMS {
|
||||
}
|
||||
S[j] += F16::reduce_add<k_step>(vk);
|
||||
}
|
||||
#endif
|
||||
|
||||
float cache[q_step*k_step];
|
||||
cache_t cache[q_step*k_step];
|
||||
float S[q_step], M[q_step];
|
||||
int need_scaling[q_step];
|
||||
F16::Data vms[q_step];
|
||||
@@ -6722,6 +6846,12 @@ struct FlashMS {
|
||||
template <int D, int q_step, int k_step>
|
||||
struct FlashQKV {
|
||||
|
||||
#ifdef __aarch64__
|
||||
using qkv_cache_t = float16_t;
|
||||
#else
|
||||
using qkv_cache_t = float;
|
||||
#endif
|
||||
|
||||
// This fails for head sizes of 80 and 112 as D/16 is odd, so we cannot do steps of 2
|
||||
// Hence, for now, we will not handle head sizes of 80 and 112
|
||||
template <typename VHelper>
|
||||
@@ -6792,7 +6922,7 @@ struct FlashQKV {
|
||||
}
|
||||
}
|
||||
|
||||
inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int j, const float * R, float * qkv) const {
|
||||
inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int j, const qkv_cache_t * R, float * qkv) const {
|
||||
GGML_ASSERT(fms.S[j] > 0);
|
||||
auto norm = F16::set1(1/fms.S[j]);
|
||||
for (int i = 0; i < D/F16::block_size; ++i) {
|
||||
@@ -6819,7 +6949,7 @@ struct FlashQKV {
|
||||
}
|
||||
}
|
||||
|
||||
float qkv_cache[D*q_step];
|
||||
qkv_cache_t qkv_cache[D*q_step];
|
||||
};
|
||||
|
||||
template <int D, int q_step, int k_step>
|
||||
@@ -6828,10 +6958,14 @@ struct FlashQKfp32 {
|
||||
static_assert(k_step%F16::block_size == 0);
|
||||
static_assert(q_step <= 4 || q_step%4 == 0);
|
||||
|
||||
#ifdef __aarch64__
|
||||
constexpr static bool is_small_head = false;
|
||||
#else
|
||||
constexpr static bool is_small_head = D <= (F16::num_registers/2)*F16::block_size;
|
||||
#endif
|
||||
|
||||
template <bool small = is_small_head, class = std::enable_if<small>>
|
||||
static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask,
|
||||
template <bool small = is_small_head, class = std::enable_if<small>, typename q_float>
|
||||
static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const q_float * q, const char * mask,
|
||||
F16::Data * qv, F16::Data * vk, FlashMS<q_step, k_step>& fms) {
|
||||
// q index is q_step*i1 + m1
|
||||
// k index is k_step*k1 + l1
|
||||
@@ -6854,8 +6988,8 @@ struct FlashQKfp32 {
|
||||
}
|
||||
}
|
||||
|
||||
template <bool small = is_small_head, class = std::enable_if<!small>>
|
||||
static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask,
|
||||
template <bool small = is_small_head, class = std::enable_if<!small>, typename q_float>
|
||||
static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const q_float * q, const char * mask,
|
||||
F16::Data * vk, FlashMS<q_step, k_step>& fms) {
|
||||
// q index is q_step*i1 + m1
|
||||
// k index is k_step*k1 + l1
|
||||
@@ -6872,8 +7006,8 @@ struct FlashQKfp32 {
|
||||
fms.cache[k_step*m1 + l1] = F16::reduce_add(vsum);
|
||||
}
|
||||
|
||||
template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>>
|
||||
static inline void mult_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask,
|
||||
template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>, typename q_float>
|
||||
static inline void mult_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
|
||||
FlashMS<q_step, k_step>& fms) {
|
||||
F16::Data qv[D/F16::block_size];
|
||||
F16::Data vk[D/(F16::block_size/2)];
|
||||
@@ -6885,9 +7019,9 @@ struct FlashQKfp32 {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>>
|
||||
template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>, typename q_float>
|
||||
static inline void mult_mask_kq_l(const KHelper& kh, int stride_q, int stride_m,
|
||||
const float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
|
||||
const q_float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
|
||||
F16::Data vk[D/F16::block_size];
|
||||
for (int l1 = 0; l1 < k_step; ++l1) {
|
||||
kh.load(l1, vk);
|
||||
@@ -6897,8 +7031,8 @@ struct FlashQKfp32 {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>>
|
||||
static inline void mult_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask,
|
||||
template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>, typename q_float>
|
||||
static inline void mult_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
|
||||
FlashMS<q_step, k_step>& fms) {
|
||||
F16::Data qv[D/F16::block_size];
|
||||
F16::Data vk[D/(F16::block_size/2)];
|
||||
@@ -6910,9 +7044,9 @@ struct FlashQKfp32 {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>>
|
||||
template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>, typename q_float>
|
||||
static inline void mult_mask_kq_l(int nq, const KHelper& kh, int stride_q, int stride_m,
|
||||
const float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
|
||||
const q_float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
|
||||
F16::Data vk[D/F16::block_size];
|
||||
for (int l1 = 0; l1 < k_step; ++l1) {
|
||||
kh.load(l1, vk);
|
||||
@@ -6922,8 +7056,8 @@ struct FlashQKfp32 {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename KHelper>
|
||||
static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask,
|
||||
template <typename KHelper, typename q_float>
|
||||
static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
|
||||
FlashMS<q_step, k_step>& fms) {
|
||||
if constexpr (is_small_head) {
|
||||
mult_mask_kq(kh, stride_q, stride_m, q, mask, fms);
|
||||
@@ -6931,14 +7065,21 @@ struct FlashQKfp32 {
|
||||
else {
|
||||
mult_mask_kq_l(kh, stride_q, stride_m, q, mask, fms);
|
||||
}
|
||||
#ifdef __aarch64__
|
||||
float32x4_t vk[k_step/4];
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
fms.update_M_S(j, vk);
|
||||
}
|
||||
#else
|
||||
F16::Data vk[k_step/F16::block_size];
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
fms.update_M_S(j, vk);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename KHelper>
|
||||
static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask,
|
||||
template <typename KHelper, typename q_float>
|
||||
static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
|
||||
FlashMS<q_step, k_step>& fms) {
|
||||
if constexpr (is_small_head) {
|
||||
mult_mask_kq(nq, kh, stride_q, stride_m, q, mask, fms);
|
||||
@@ -6946,11 +7087,33 @@ struct FlashQKfp32 {
|
||||
else {
|
||||
mult_mask_kq_l(nq, kh, stride_q, stride_m, q, mask, fms);
|
||||
}
|
||||
#ifdef __aarch64__
|
||||
float32x4_t vk[k_step/4];
|
||||
for (int j = 0; j < nq; ++j) {
|
||||
fms.update_M_S(j, vk);
|
||||
}
|
||||
#else
|
||||
F16::Data vk[k_step/F16::block_size];
|
||||
for (int j = 0; j < nq; ++j) {
|
||||
fms.update_M_S(j, vk);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef __aarch64__
|
||||
// Narrow nq rows of D f32 Q activations to f16, 8 lanes per step, writing
// rows contiguously (stride D) into q_f16 for the NEON K*Q kernel.
static inline void convert(int nq, int stride_q, const float * q, float16_t * q_f16) {
    for (int row = 0; row < nq; ++row) {
        for (int col = 0; col < D; col += 8) {
            auto lo = vld1q_f32(q + col);
            auto hi = vld1q_f32(q + col + 4);
            vst1q_f16(q_f16 + col, vcombine_f16(vcvt_f16_f32(lo), vcvt_f16_f32(hi)));
        }
        q     += stride_q;
        q_f16 += D;
    }
}
|
||||
#endif
|
||||
};
|
||||
|
||||
template <int D, int q_step, int k_step, typename KHelper, typename VHelper, typename KQHelper>
|
||||
@@ -6958,13 +7121,23 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
|
||||
FlashMS<q_step, k_step>& fms,
|
||||
FlashQKV<D, q_step, k_step>& fqkv,
|
||||
const float * q, const char * mask, float * qkv) {
|
||||
#ifdef __aarch64__
|
||||
float16_t q_f16[D*q_step];
|
||||
#endif
|
||||
for (int i1 = 0; i1 < nq1/q_step; ++i1) {
|
||||
fms.init_qstep();
|
||||
kh.reset_block();
|
||||
vh.reset_block();
|
||||
#ifdef __aarch64__
|
||||
KQHelper::convert(q_step, stride_q, q, q_f16);
|
||||
#endif
|
||||
auto mr = mask;
|
||||
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
||||
#ifdef __aarch64__
|
||||
KQHelper::multiply_mask_kq(kh, D, stride_m, q_f16, mr, fms);
|
||||
#else
|
||||
KQHelper::multiply_mask_kq(kh, stride_q, stride_m, q, mr, fms);
|
||||
#endif
|
||||
fqkv.accumulate_qkv(vh, fms);
|
||||
kh.next_block();
|
||||
vh.next_block();
|
||||
@@ -6981,9 +7154,16 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
|
||||
fms.init_qstep();
|
||||
kh.reset_block();
|
||||
vh.reset_block();
|
||||
#ifdef __aarch64__
|
||||
KQHelper::convert(n_left, stride_q, q, q_f16);
|
||||
#endif
|
||||
auto mr = mask;
|
||||
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
||||
#ifdef __aarch64__
|
||||
KQHelper::multiply_mask_kq(n_left, kh, D, stride_m, q_f16, mr, fms);
|
||||
#else
|
||||
KQHelper::multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr, fms);
|
||||
#endif
|
||||
fqkv.accumulate_qkv(n_left, vh, fms);
|
||||
kh.next_block();
|
||||
vh.next_block();
|
||||
@@ -7469,30 +7649,6 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
|
||||
return true;
|
||||
}
|
||||
|
||||
#else
|
||||
// TODO
|
||||
// Fallback stub for platforms without an iqk flash-attention implementation:
// always reports failure so the caller can take the generic attention path.
// TODO: implement for the remaining platforms.
bool iqk_flash_attn_noalibi([[maybe_unused]] int int_type_k,    // type of k
                            [[maybe_unused]] int int_type_v,    // type of v
                            [[maybe_unused]] int D,             // head size
                            [[maybe_unused]] int nq,            // number of columns in q
                            [[maybe_unused]] int nk,            // number of rows in k
                            [[maybe_unused]] int stride_q,      // distance between q columns in bytes
                            [[maybe_unused]] int stride_k,      // distance between k rows in bytes
                            [[maybe_unused]] int stride_v,      // distance between v rows in bytes
                            [[maybe_unused]] int stride_m,      // distance between mask rows (in bytes
                            [[maybe_unused]] int stride_qkv,    // distance between rows in mask (in bytes)
                            [[maybe_unused]] const float * q,   // q matrix.
                            [[maybe_unused]] const void  * k,   // k matrix. Assumed to be fp16, nq x nk elements
                            [[maybe_unused]] const void  * v,   // v matrix. Assumed to be fp16, nq x nk elements
                            [[maybe_unused]] const void  * mask,// mask. If not null, assumed to be fp16. nq x nk elements
                            [[maybe_unused]] float         scale,   // scale applied before softmax
                            [[maybe_unused]] float         softcap, // if > 0, a "soft-cap" operation is applied before softmax
                            [[maybe_unused]] float       * qkv) {   // v*softmax(scale*(k*q))
    return false;
}
|
||||
|
||||
#endif
|
||||
|
||||
#else // IQK_IMPLEMENT
|
||||
|
||||
bool iqk_mul_mat(int, long, long, long, int, const void *, long, int, const void *, long, float *, long, int, int) {
|
||||
|
||||
Reference in New Issue
Block a user