Fix ARM_NEON

I had forgotten to guard the AVX2/Zen4 implementation against __aarch64__
This commit is contained in:
Iwan Kawrakow
2024-09-10 18:14:59 +02:00
parent 1b12b2658a
commit e3919f5f80

View File

@@ -6401,7 +6401,7 @@ inline __m256 v_tanh(__m256 x) {
#endif
} // namespace
//#ifdef HAVE_FANCY_SIMD
#ifndef __aarch64__
namespace {
@@ -7469,29 +7469,29 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
return true;
}
////#else
//// TODO
//bool iqk_flash_attn_noalibi([[maybe_unused]] int int_type_k, // type of k
// [[maybe_unused]] int int_type_v, // type of v
// [[maybe_unused]] int D, // head size
// [[maybe_unused]] int nq, // number of columns in q
// [[maybe_unused]] int nk, // number of rows in k
// [[maybe_unused]] int stride_q, // distance between q columns in bytes
// [[maybe_unused]] int stride_k, // distance between k rows in bytes
// [[maybe_unused]] int stride_v, // distance between v rows in bytes
// [[maybe_unused]] int stride_m, // distance between mask rows (in bytes
// [[maybe_unused]] int stride_qkv, // distance between rows in mask (in bytes)
// [[maybe_unused]] const float * q, // q matrix.
// [[maybe_unused]] const void * k, // k matrix. Assumed to be fp16, nq x nk elements
// [[maybe_unused]] const void * v, // v matrix. Assumed to be fp16, nq x nk elements
// [[maybe_unused]] const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
// [[maybe_unused]] float scale, // scale applied before softmax
// [[maybe_unused]] float softcap, // if > 0, a "soft-cap" operation is applied before softmax
// [[maybe_unused]] float * qkv) { // v*softmax(scale*(k*q))
// return false;
//}
//
//#endif
#else
// TODO
bool iqk_flash_attn_noalibi([[maybe_unused]] int int_type_k, // type of k
[[maybe_unused]] int int_type_v, // type of v
[[maybe_unused]] int D, // head size
[[maybe_unused]] int nq, // number of columns in q
[[maybe_unused]] int nk, // number of rows in k
[[maybe_unused]] int stride_q, // distance between q columns in bytes
[[maybe_unused]] int stride_k, // distance between k rows in bytes
[[maybe_unused]] int stride_v, // distance between v rows in bytes
[[maybe_unused]] int stride_m, // distance between mask rows (in bytes
[[maybe_unused]] int stride_qkv, // distance between rows in mask (in bytes)
[[maybe_unused]] const float * q, // q matrix.
[[maybe_unused]] const void * k, // k matrix. Assumed to be fp16, nq x nk elements
[[maybe_unused]] const void * v, // v matrix. Assumed to be fp16, nq x nk elements
[[maybe_unused]] const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
[[maybe_unused]] float scale, // scale applied before softmax
[[maybe_unused]] float softcap, // if > 0, a "soft-cap" operation is applied before softmax
[[maybe_unused]] float * qkv) { // v*softmax(scale*(k*q))
return false;
}
#endif
#else // IQK_IMPLEMENT