This commit is contained in:
Iwan Kawrakow
2025-04-18 13:50:30 +03:00
parent 74a21d48d6
commit b498633203

View File

@@ -15169,6 +15169,13 @@ struct F16 {
auto v256 = _mm256_set_m128(v128, v128);
return _mm512_insertf32x8(_mm512_castps256_ps512(v256), v256, 1);
}
// Two-argument set4: obtains a vector from the single-argument set4(ptr) and fills
// vs[0..3] with four shuffled copies of it. The shuffle immediates 0x00, 0x55, 0xaa
// and 0xff pick element 0, 1, 2 and 3 (respectively) of every 128-bit lane, i.e. the
// same selections the fmadd_lane0..3 helpers perform on the fly.
static inline void set4(const float * ptr, Data * vs) {
    const auto quad = set4(ptr);
    vs[0] = _mm512_shuffle_ps(quad, quad, 0x00);  // element 0 of each 128-bit lane
    vs[1] = _mm512_shuffle_ps(quad, quad, 0x55);  // element 1
    vs[2] = _mm512_shuffle_ps(quad, quad, 0xaa);  // element 2
    vs[3] = _mm512_shuffle_ps(quad, quad, 0xff);  // element 3
}
static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x00), prev); }
static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x55), prev); }
static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0xaa), prev); }
@@ -15194,6 +15201,13 @@ struct F16 {
auto v128 = _mm_loadu_ps(ptr);
return _mm256_set_m128(v128, v128);
}
// Two-argument set4: takes the vector produced by the single-argument set4(ptr)
// (four floats loaded from ptr, duplicated into both 128-bit halves) and writes four
// shuffled copies to vs[0..3]. Immediates 0x00, 0x55, 0xaa and 0xff select element
// 0, 1, 2 and 3 of each 128-bit lane, so vs[k] holds ptr[k] in every position.
static inline void set4(const float * ptr, Data * vs) {
    const auto quad = set4(ptr);
    vs[0] = _mm256_shuffle_ps(quad, quad, 0x00);  // ptr[0] everywhere
    vs[1] = _mm256_shuffle_ps(quad, quad, 0x55);  // ptr[1] everywhere
    vs[2] = _mm256_shuffle_ps(quad, quad, 0xaa);  // ptr[2] everywhere
    vs[3] = _mm256_shuffle_ps(quad, quad, 0xff);  // ptr[3] everywhere
}
static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x00), prev); }
static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x55), prev); }
static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0xaa), prev); }
@@ -16214,6 +16228,9 @@ struct FlashQKV {
}
}
}
#ifdef __AVX2__
F16::Data vs[4];
#endif
for (int i = 0; i < D/F16::block_size; i += 2) {
for (int l = 0; l < k_step; l += 4) {
vh.load(l+0, i, v[0], v[4]);
@@ -16224,6 +16241,13 @@ struct FlashQKV {
auto R = qkv_cache + D*j;
auto s1 = F16::load(R + F16::block_size*(i+0));
auto s2 = F16::load(R + F16::block_size*(i+1));
#ifdef __AVX2__
F16::set4(fms.cache + k_step*j + l, vs);
for (int k = 0; k < 4; ++k) {
s1 = F16::fmadd(s1, v[k+0], vs[k]);
s2 = F16::fmadd(s2, v[k+4], vs[k]);
}
#else
auto vs = F16::set4(fms.cache + k_step*j + l);
s1 = F16::fmadd_lane0(s1, v[0], vs);
s2 = F16::fmadd_lane0(s2, v[4], vs);
@@ -16233,6 +16257,7 @@ struct FlashQKV {
s2 = F16::fmadd_lane2(s2, v[6], vs);
s1 = F16::fmadd_lane3(s1, v[3], vs);
s2 = F16::fmadd_lane3(s2, v[7], vs);
#endif
F16::store(R + F16::block_size*(i+0), s1);
F16::store(R + F16::block_size*(i+1), s2);
}
@@ -16778,8 +16803,9 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
FlashMS<q_step, k_step>& fms,
FlashQKV<Dv, q_step, k_step>& fqkv,
const float * q, const char * mask, float * qkv,
float * M, float * S) {
typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)];
float * M, float * S, char * qptr) {
//typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)];
auto q8 = (typename KHelper::block_q8 *)qptr;
#if FA_TIMING
Perf perf(false);
#endif
@@ -16845,6 +16871,12 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
#endif
}
// Returns a pointer to a per-thread byte buffer that holds at least `size` bytes.
// The buffer is kept between calls and only ever grows; a call that has to grow it
// may reallocate, so pointers obtained from earlier calls on the same thread must
// not be used after requesting a larger size.
char * get_q_storage(size_t size) {
    thread_local std::vector<char> scratch;
    const bool too_small = scratch.size() < size;
    if (too_small) {
        scratch.resize(size);
    }
    return scratch.data();
}
// Some of the methods in FlashAttn have two identical implementations that only differ by
// one version using a loop over the template parameter q_step, while the other using a loop
// over an input parameter nq (these are loops over the rows of q^T). I dislike this a lot,
@@ -16867,45 +16899,116 @@ struct FlashAttn {
template <typename KHelper, typename VHelper>
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float * qkv, [[maybe_unused]] float * M, [[maybe_unused]] float * S) {
if constexpr (std::is_same_v<KHelper, HelperQ40<Dk, k_step>> || std::is_same_v<KHelper, HelperQ41<Dk, k_step>> ||
if constexpr (std::is_same_v<KHelper, HelperQ40<Dk, k_step>> ||
std::is_same_v<KHelper, HelperQ41<Dk, k_step>> ||
std::is_same_v<KHelper, HelperIQ4nl<Dk, k_step>> ||
std::is_same_v<KHelper, HelperQ60<Dk, k_step>> ||
std::is_same_v<KHelper, HelperQ80R8<Dk, k_step>>) {
compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
}
else if constexpr (std::is_same_v<KHelper, HelperQ80<Dk, k_step>>) {
if (nq1 >= 8) {
std::is_same_v<KHelper, HelperQ80R8<Dk, k_step>> ||
std::is_same_v<KHelper, HelperQ80<Dk, k_step>> ||
std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>> ||
std::is_same_v<KHelper, HelperQ8KVR8<Dk, k_step>>) {
constexpr size_t kMaxOnStackSize = 576;
auto q_size = q_step*(Dk/KHelper::block_size_q)*sizeof(typename KHelper::block_q8);
q_size = GGML_PAD(q_size, 64);
if (q_size*q_step > kMaxOnStackSize) {
auto qptr = get_q_storage(q_size);
if (nq1 >= 8) {
if constexpr (std::is_same_v<KHelper, HelperQ80<Dk, k_step>>) {
#if FA_TIMING
auto t1 = Perf::cur_time();
HelperQ80R8<Dk, k_step> khr4(nk1, kh);
Perf::instance().accum(4, t1);
auto t1 = Perf::cur_time();
HelperQ80R8<Dk, k_step> khr4(nk1, kh);
Perf::instance().accum(4, t1);
#else
HelperQ80R8<Dk, k_step> khr4(nk1, kh);
HelperQ80R8<Dk, k_step> khr4(nk1, kh);
#endif
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
} else{
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
return;
}
if constexpr (std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>>) {
#if FA_TIMING
auto t1 = Perf::cur_time();
HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
Perf::instance().accum(4, t1);
#else
HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
#endif
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
return;
}
}
compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
}
else {
typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)];
auto qptr = (char *)q8;
if (nq1 >= 8) {
if constexpr (std::is_same_v<KHelper, HelperQ80<Dk, k_step>>) {
#if FA_TIMING
auto t1 = Perf::cur_time();
HelperQ80R8<Dk, k_step> khr4(nk1, kh);
Perf::instance().accum(4, t1);
#else
HelperQ80R8<Dk, k_step> khr4(nk1, kh);
#endif
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
return;
}
if constexpr (std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>>) {
#if FA_TIMING
auto t1 = Perf::cur_time();
HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
Perf::instance().accum(4, t1);
#else
HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
#endif
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
return;
}
}
compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
}
}
else if constexpr (std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>>) {
if (nq1 >= 8) {
#if FA_TIMING
auto t1 = Perf::cur_time();
HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
Perf::instance().accum(4, t1);
#else
HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
#endif
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
} else{
compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
}
} else {
// else if constexpr (std::is_same_v<KHelper, HelperQ80<Dk, k_step>>) {
// if (nq1 >= 8) {
//#if FA_TIMING
// auto t1 = Perf::cur_time();
// HelperQ80R8<Dk, k_step> khr4(nk1, kh);
// Perf::instance().accum(4, t1);
//#else
// HelperQ80R8<Dk, k_step> khr4(nk1, kh);
//#endif
// compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
// khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
// } else{
// compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
// kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
// }
// }
// else if constexpr (std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>>) {
// if (nq1 >= 8) {
//#if FA_TIMING
// auto t1 = Perf::cur_time();
// HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
// Perf::instance().accum(4, t1);
//#else
// HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
//#endif
// compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
// khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
// } else{
// compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
// kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
// }
else {
compute_helper<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
}
@@ -17349,39 +17452,61 @@ template <int Dk, int Dv, int k_step, typename KHelper, typename VHelper>
inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
if (nk1 >= 256) { //4096) {
if (nq1 >= 64) {
// Bookkeeping after a compute() call that handled n entries: subtracts n from nq1
// and, if anything is left, advances q, mask and qkv by n of their respective
// strides (and M/S by n when both are non-null). Returns true when nq1 reached zero.
auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) {
    nq1 -= n;
    const bool finished = nq1 == 0;
    if (!finished) {
        q    += n*stride_q;
        mask += n*stride_m;
        qkv  += n*stride_qkv;
        if (M && S) {
            M += n;
            S += n;
        }
    }
    return finished;
};
if (nk1 >= 512) {
if (nq1 >= 128) {
int n_step = nq1/128;
FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
return;
fa.compute(kh, vh, 128*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
if (update(128*n_step)) return;
}
if (nq1 >= 64) {
int n_step = nq1/64;
FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
fa.compute(kh, vh, 64*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
if (update(64*n_step)) return;
}
if (nq1 >= 32) {
int n_step = nq1/32;
FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
return;
fa.compute(kh, vh, 32*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
if (update(32*n_step)) return;
}
if (nq1 >= 16) {
int n_step = nq1/16;
FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
return;
fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
if (update(16*n_step)) return;
}
}
if (nq1 >= 8) {
int n_step = nq1/8;
FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
if (update(8*n_step)) return;
}
else if (nq1 >= 4) {
int n_step = nq1/4;
FlashAttn<Dk, Dv, 4, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
if (update(4*n_step)) return;
}
else if (nq1 >= 2) {
int n_step = nq1/2;
FlashAttn<Dk, Dv, 2, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
}
else {
FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
if (update(2*n_step)) return;
}
FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
}
#ifdef __AVX512BF16__
@@ -17523,25 +17648,35 @@ template <int step_k, typename KHelper, typename VHelper>
inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
// Bookkeeping after a compute() call that handled n entries: subtracts n from nq1
// and, if anything is left, advances q, mask and qkv by n of their respective
// strides (and M/S by n when both are non-null). Returns true when nq1 reached zero.
auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) {
    nq1 -= n;
    const bool finished = nq1 == 0;
    if (!finished) {
        q    += n*stride_q;
        mask += n*stride_m;
        qkv  += n*stride_qkv;
        if (M && S) {
            M += n;
            S += n;
        }
    }
    return finished;
};
if (nq1 >= 8) {
int n_step = nq1/8;
FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
if (update(8*n_step)) return;
}
else if (nq1 >= 4) {
if (nq1 >= 4) {
int n_step = nq1/4;
FlashAttn<576, 512, 4, step_k> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
if (update(4*n_step)) return;
}
else {
FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
if (nq1 >= 2) {
int n_step = nq1/2;
FlashAttn<576, 512, 2, step_k> fa(scale, softcap);
fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
if (update(2*n_step)) return;
}
//if (nq1 % 8 == 0) {
// FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
// fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
//} else {
// FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
// fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
//}
FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
}
template <int step_k>
@@ -17683,6 +17818,23 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k
}
#endif
if (nk1%128 == 0) {
switch (Dk) {
case 64:
iqk_flash_helper_T< 64, 64, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
case 96:
iqk_flash_helper_T< 96, 96, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
case 128:
iqk_flash_helper_T<128, 128, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
case 192:
iqk_flash_helper_T<192, 128, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
case 256:
iqk_flash_helper_T<256, 256, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
default:
return false;
}
return true;
}
if (nk1%64 == 0) {
switch (Dk) {
case 64: