mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-27 01:49:28 +00:00
WIP
This commit is contained in:
@@ -15169,6 +15169,13 @@ struct F16 {
|
||||
auto v256 = _mm256_set_m128(v128, v128);
|
||||
return _mm512_insertf32x8(_mm512_castps256_ps512(v256), v256, 1);
|
||||
}
|
||||
static inline void set4(const float * ptr, Data * vs) {
|
||||
auto v = set4(ptr);
|
||||
vs[0] = _mm512_shuffle_ps(v, v, 0x00);
|
||||
vs[1] = _mm512_shuffle_ps(v, v, 0x55);
|
||||
vs[2] = _mm512_shuffle_ps(v, v, 0xaa);
|
||||
vs[3] = _mm512_shuffle_ps(v, v, 0xff);
|
||||
}
|
||||
static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x00), prev); }
|
||||
static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x55), prev); }
|
||||
static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0xaa), prev); }
|
||||
@@ -15194,6 +15201,13 @@ struct F16 {
|
||||
auto v128 = _mm_loadu_ps(ptr);
|
||||
return _mm256_set_m128(v128, v128);
|
||||
}
|
||||
static inline void set4(const float * ptr, Data * vs) {
|
||||
auto v = set4(ptr);
|
||||
vs[0] = _mm256_shuffle_ps(v, v, 0x00);
|
||||
vs[1] = _mm256_shuffle_ps(v, v, 0x55);
|
||||
vs[2] = _mm256_shuffle_ps(v, v, 0xaa);
|
||||
vs[3] = _mm256_shuffle_ps(v, v, 0xff);
|
||||
}
|
||||
static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x00), prev); }
|
||||
static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x55), prev); }
|
||||
static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0xaa), prev); }
|
||||
@@ -16214,6 +16228,9 @@ struct FlashQKV {
|
||||
}
|
||||
}
|
||||
}
|
||||
#ifdef __AVX2__
|
||||
F16::Data vs[4];
|
||||
#endif
|
||||
for (int i = 0; i < D/F16::block_size; i += 2) {
|
||||
for (int l = 0; l < k_step; l += 4) {
|
||||
vh.load(l+0, i, v[0], v[4]);
|
||||
@@ -16224,6 +16241,13 @@ struct FlashQKV {
|
||||
auto R = qkv_cache + D*j;
|
||||
auto s1 = F16::load(R + F16::block_size*(i+0));
|
||||
auto s2 = F16::load(R + F16::block_size*(i+1));
|
||||
#ifdef __AVX2__
|
||||
F16::set4(fms.cache + k_step*j + l, vs);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
s1 = F16::fmadd(s1, v[k+0], vs[k]);
|
||||
s2 = F16::fmadd(s2, v[k+4], vs[k]);
|
||||
}
|
||||
#else
|
||||
auto vs = F16::set4(fms.cache + k_step*j + l);
|
||||
s1 = F16::fmadd_lane0(s1, v[0], vs);
|
||||
s2 = F16::fmadd_lane0(s2, v[4], vs);
|
||||
@@ -16233,6 +16257,7 @@ struct FlashQKV {
|
||||
s2 = F16::fmadd_lane2(s2, v[6], vs);
|
||||
s1 = F16::fmadd_lane3(s1, v[3], vs);
|
||||
s2 = F16::fmadd_lane3(s2, v[7], vs);
|
||||
#endif
|
||||
F16::store(R + F16::block_size*(i+0), s1);
|
||||
F16::store(R + F16::block_size*(i+1), s2);
|
||||
}
|
||||
@@ -16778,8 +16803,9 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
|
||||
FlashMS<q_step, k_step>& fms,
|
||||
FlashQKV<Dv, q_step, k_step>& fqkv,
|
||||
const float * q, const char * mask, float * qkv,
|
||||
float * M, float * S) {
|
||||
typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)];
|
||||
float * M, float * S, char * qptr) {
|
||||
//typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)];
|
||||
auto q8 = (typename KHelper::block_q8 *)qptr;
|
||||
#if FA_TIMING
|
||||
Perf perf(false);
|
||||
#endif
|
||||
@@ -16845,6 +16871,12 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
|
||||
#endif
|
||||
}
|
||||
|
||||
char * get_q_storage(size_t size) {
|
||||
thread_local std::vector<char> q_storage;
|
||||
if (q_storage.size() < size) q_storage.resize(size);
|
||||
return q_storage.data();
|
||||
}
|
||||
|
||||
// Some of the methods in FlashAttn have two identical implementations that only differ by
|
||||
// one version using a loop over the template parameter q_step, while the other using a loop
|
||||
// over an input parameter nq (these are loops over the rows of q^T). I dislike this a lot,
|
||||
@@ -16867,45 +16899,116 @@ struct FlashAttn {
|
||||
template <typename KHelper, typename VHelper>
|
||||
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
|
||||
const float * q, const char * mask, float * qkv, [[maybe_unused]] float * M, [[maybe_unused]] float * S) {
|
||||
if constexpr (std::is_same_v<KHelper, HelperQ40<Dk, k_step>> || std::is_same_v<KHelper, HelperQ41<Dk, k_step>> ||
|
||||
if constexpr (std::is_same_v<KHelper, HelperQ40<Dk, k_step>> ||
|
||||
std::is_same_v<KHelper, HelperQ41<Dk, k_step>> ||
|
||||
std::is_same_v<KHelper, HelperIQ4nl<Dk, k_step>> ||
|
||||
std::is_same_v<KHelper, HelperQ60<Dk, k_step>> ||
|
||||
std::is_same_v<KHelper, HelperQ80R8<Dk, k_step>>) {
|
||||
compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
|
||||
}
|
||||
else if constexpr (std::is_same_v<KHelper, HelperQ80<Dk, k_step>>) {
|
||||
if (nq1 >= 8) {
|
||||
std::is_same_v<KHelper, HelperQ80R8<Dk, k_step>> ||
|
||||
std::is_same_v<KHelper, HelperQ80<Dk, k_step>> ||
|
||||
std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>> ||
|
||||
std::is_same_v<KHelper, HelperQ8KVR8<Dk, k_step>>) {
|
||||
constexpr size_t kMaxOnStackSize = 576;
|
||||
auto q_size = q_step*(Dk/KHelper::block_size_q)*sizeof(typename KHelper::block_q8);
|
||||
q_size = GGML_PAD(q_size, 64);
|
||||
if (q_size*q_step > kMaxOnStackSize) {
|
||||
auto qptr = get_q_storage(q_size);
|
||||
if (nq1 >= 8) {
|
||||
if constexpr (std::is_same_v<KHelper, HelperQ80<Dk, k_step>>) {
|
||||
#if FA_TIMING
|
||||
auto t1 = Perf::cur_time();
|
||||
HelperQ80R8<Dk, k_step> khr4(nk1, kh);
|
||||
Perf::instance().accum(4, t1);
|
||||
auto t1 = Perf::cur_time();
|
||||
HelperQ80R8<Dk, k_step> khr4(nk1, kh);
|
||||
Perf::instance().accum(4, t1);
|
||||
#else
|
||||
HelperQ80R8<Dk, k_step> khr4(nk1, kh);
|
||||
HelperQ80R8<Dk, k_step> khr4(nk1, kh);
|
||||
#endif
|
||||
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
|
||||
} else{
|
||||
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
|
||||
return;
|
||||
|
||||
}
|
||||
if constexpr (std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>>) {
|
||||
#if FA_TIMING
|
||||
auto t1 = Perf::cur_time();
|
||||
HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
|
||||
Perf::instance().accum(4, t1);
|
||||
#else
|
||||
HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
|
||||
#endif
|
||||
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
|
||||
return;
|
||||
}
|
||||
}
|
||||
compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
|
||||
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
|
||||
|
||||
}
|
||||
else {
|
||||
typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)];
|
||||
auto qptr = (char *)q8;
|
||||
if (nq1 >= 8) {
|
||||
if constexpr (std::is_same_v<KHelper, HelperQ80<Dk, k_step>>) {
|
||||
#if FA_TIMING
|
||||
auto t1 = Perf::cur_time();
|
||||
HelperQ80R8<Dk, k_step> khr4(nk1, kh);
|
||||
Perf::instance().accum(4, t1);
|
||||
#else
|
||||
HelperQ80R8<Dk, k_step> khr4(nk1, kh);
|
||||
#endif
|
||||
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
|
||||
return;
|
||||
|
||||
}
|
||||
if constexpr (std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>>) {
|
||||
#if FA_TIMING
|
||||
auto t1 = Perf::cur_time();
|
||||
HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
|
||||
Perf::instance().accum(4, t1);
|
||||
#else
|
||||
HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
|
||||
#endif
|
||||
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
|
||||
return;
|
||||
}
|
||||
}
|
||||
compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
|
||||
}
|
||||
}
|
||||
else if constexpr (std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>>) {
|
||||
if (nq1 >= 8) {
|
||||
#if FA_TIMING
|
||||
auto t1 = Perf::cur_time();
|
||||
HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
|
||||
Perf::instance().accum(4, t1);
|
||||
#else
|
||||
HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
|
||||
#endif
|
||||
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
|
||||
} else{
|
||||
compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
|
||||
}
|
||||
} else {
|
||||
// else if constexpr (std::is_same_v<KHelper, HelperQ80<Dk, k_step>>) {
|
||||
// if (nq1 >= 8) {
|
||||
//#if FA_TIMING
|
||||
// auto t1 = Perf::cur_time();
|
||||
// HelperQ80R8<Dk, k_step> khr4(nk1, kh);
|
||||
// Perf::instance().accum(4, t1);
|
||||
//#else
|
||||
// HelperQ80R8<Dk, k_step> khr4(nk1, kh);
|
||||
//#endif
|
||||
// compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||
// khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
|
||||
// } else{
|
||||
// compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||
// kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
|
||||
// }
|
||||
// }
|
||||
// else if constexpr (std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>>) {
|
||||
// if (nq1 >= 8) {
|
||||
//#if FA_TIMING
|
||||
// auto t1 = Perf::cur_time();
|
||||
// HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
|
||||
// Perf::instance().accum(4, t1);
|
||||
//#else
|
||||
// HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
|
||||
//#endif
|
||||
// compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||
// khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
|
||||
// } else{
|
||||
// compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||
// kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
|
||||
// }
|
||||
else {
|
||||
compute_helper<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
|
||||
}
|
||||
@@ -17349,39 +17452,61 @@ template <int Dk, int Dv, int k_step, typename KHelper, typename VHelper>
|
||||
inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
|
||||
const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
|
||||
|
||||
if (nk1 >= 256) { //4096) {
|
||||
if (nq1 >= 64) {
|
||||
auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) {
|
||||
nq1 -= n;
|
||||
if (nq1 == 0) return true;
|
||||
q += n*stride_q;
|
||||
mask += n*stride_m;
|
||||
qkv += n*stride_qkv;
|
||||
if (M && S) { M += n; S += n; }
|
||||
return false;
|
||||
};
|
||||
if (nk1 >= 512) {
|
||||
if (nq1 >= 128) {
|
||||
int n_step = nq1/128;
|
||||
FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
return;
|
||||
fa.compute(kh, vh, 128*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
if (update(128*n_step)) return;
|
||||
}
|
||||
if (nq1 >= 64) {
|
||||
int n_step = nq1/64;
|
||||
FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
|
||||
fa.compute(kh, vh, 64*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
if (update(64*n_step)) return;
|
||||
}
|
||||
if (nq1 >= 32) {
|
||||
int n_step = nq1/32;
|
||||
FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
return;
|
||||
fa.compute(kh, vh, 32*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
if (update(32*n_step)) return;
|
||||
}
|
||||
if (nq1 >= 16) {
|
||||
int n_step = nq1/16;
|
||||
FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
return;
|
||||
fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
if (update(16*n_step)) return;
|
||||
}
|
||||
}
|
||||
if (nq1 >= 8) {
|
||||
int n_step = nq1/8;
|
||||
FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
if (update(8*n_step)) return;
|
||||
}
|
||||
else if (nq1 >= 4) {
|
||||
int n_step = nq1/4;
|
||||
FlashAttn<Dk, Dv, 4, k_step> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
if (update(4*n_step)) return;
|
||||
}
|
||||
else if (nq1 >= 2) {
|
||||
int n_step = nq1/2;
|
||||
FlashAttn<Dk, Dv, 2, k_step> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
}
|
||||
else {
|
||||
FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
if (update(2*n_step)) return;
|
||||
}
|
||||
FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
}
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
@@ -17523,25 +17648,35 @@ template <int step_k, typename KHelper, typename VHelper>
|
||||
inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
|
||||
int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
|
||||
const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
|
||||
auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) {
|
||||
nq1 -= n;
|
||||
if (nq1 == 0) return true;
|
||||
q += n*stride_q;
|
||||
mask += n*stride_m;
|
||||
qkv += n*stride_qkv;
|
||||
if (M && S) { M += n; S += n; }
|
||||
return false;
|
||||
};
|
||||
if (nq1 >= 8) {
|
||||
int n_step = nq1/8;
|
||||
FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
if (update(8*n_step)) return;
|
||||
}
|
||||
else if (nq1 >= 4) {
|
||||
if (nq1 >= 4) {
|
||||
int n_step = nq1/4;
|
||||
FlashAttn<576, 512, 4, step_k> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
if (update(4*n_step)) return;
|
||||
}
|
||||
else {
|
||||
FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
if (nq1 >= 2) {
|
||||
int n_step = nq1/2;
|
||||
FlashAttn<576, 512, 2, step_k> fa(scale, softcap);
|
||||
fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
if (update(2*n_step)) return;
|
||||
}
|
||||
//if (nq1 % 8 == 0) {
|
||||
// FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
|
||||
// fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
//} else {
|
||||
// FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
|
||||
// fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
//}
|
||||
FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
}
|
||||
|
||||
template <int step_k>
|
||||
@@ -17683,6 +17818,23 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k
|
||||
}
|
||||
#endif
|
||||
|
||||
if (nk1%128 == 0) {
|
||||
switch (Dk) {
|
||||
case 64:
|
||||
iqk_flash_helper_T< 64, 64, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
|
||||
case 96:
|
||||
iqk_flash_helper_T< 96, 96, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
|
||||
case 128:
|
||||
iqk_flash_helper_T<128, 128, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
|
||||
case 192:
|
||||
iqk_flash_helper_T<192, 128, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
|
||||
case 256:
|
||||
iqk_flash_helper_T<256, 256, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
if (nk1%64 == 0) {
|
||||
switch (Dk) {
|
||||
case 64:
|
||||
|
||||
Reference in New Issue
Block a user