mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-26 08:04:09 +00:00
Most helpers don't need to be templates
Also hide Q4_0 and Q8_KV behind IQK_FA_ALL_QUANTS. Compilation time drops to 14 second on the Ryzen-5975WX
This commit is contained in:
@@ -53,32 +53,34 @@ inline bool iqk_deepseek_helper(ggml_type type_k,
|
||||
const float * q, const char * k, const char * v, const char * mask,
|
||||
float scale, float softcap, float * qkv, float * M, float * S) {
|
||||
if (type_k == GGML_TYPE_Q8_0) {
|
||||
HelperQ80<576, step_k> kh((const char *)k, stride_k);
|
||||
HelperQ80<512, step_k> vh((const char *)v, stride_v);
|
||||
HelperQ80 kh((const char *)k, stride_k);
|
||||
HelperQ80 vh((const char *)v, stride_v);
|
||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
if (type_k == GGML_TYPE_Q8_0_R8) {
|
||||
HelperQ80R8<576, step_k> kh((const char *)k, stride_k);
|
||||
HelperQ80<512, step_k> vh((const char *)v, stride_v);
|
||||
HelperQ80R8<576> kh((const char *)k, stride_k);
|
||||
HelperQ80 vh((const char *)v, stride_v);
|
||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
if (type_k == GGML_TYPE_Q6_0) {
|
||||
HelperQ60<576, step_k> kh((const char *)k, stride_k);
|
||||
HelperQ60<512, step_k> vh((const char *)v, stride_v);
|
||||
HelperQ60 kh((const char *)k, stride_k);
|
||||
HelperQ60 vh((const char *)v, stride_v);
|
||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
#if GGML_IQK_FA_ALL_QUANTS
|
||||
if (type_k == GGML_TYPE_Q8_KV) {
|
||||
HelperQ8KV<576, step_k> kh((const char *)k, stride_k);
|
||||
HelperQ8KV<512, step_k> vh((const char *)v, stride_v);
|
||||
HelperQ8KV<576> kh((const char *)k, stride_k);
|
||||
HelperQ8KV<512> vh((const char *)v, stride_v);
|
||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
if (type_k == GGML_TYPE_F16) {
|
||||
HelperF16<576, step_k> kh((const char *)k, stride_k);
|
||||
HelperF16<512, step_k> vh((const char *)v, stride_v);
|
||||
HelperF16 kh((const char *)k, stride_k);
|
||||
HelperF16 vh((const char *)v, stride_v);
|
||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -32,13 +32,12 @@
|
||||
|
||||
namespace {
|
||||
|
||||
template <int k_step>
|
||||
struct BaseHelper {
|
||||
BaseHelper(const char * data, int stride) : data(data), block(data), stride(stride) {}
|
||||
|
||||
//inline void set_block(int k1) { block = data + k1*k_step*stride; }
|
||||
inline void reset_block() { block = data; }
|
||||
inline void next_block() { block += k_step*stride; }
|
||||
inline void next_block(int step) { block += step*stride; }
|
||||
inline const char * lblock(int l1) const { return block + l1*stride; }
|
||||
|
||||
const char * data;
|
||||
@@ -177,27 +176,16 @@ struct F16 {
|
||||
}
|
||||
};
|
||||
|
||||
template <int D, int step>
|
||||
struct HelperF16 final : public BaseHelper<step> {
|
||||
using Base = BaseHelper<step>;
|
||||
struct HelperF16 final : public BaseHelper {
|
||||
using Base = BaseHelper;
|
||||
HelperF16(const char * data, int stride) : Base(data, stride) {}
|
||||
|
||||
inline void load(int l1, F16::Data * vk) const {
|
||||
auto dr = Base::lblock(l1);
|
||||
for (int i = 0; i < D/F16::block_size; ++i) vk[i] = F16::load(dr, i);
|
||||
}
|
||||
|
||||
inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
|
||||
//auto dr = (const ggml_half *)Base::lblock(l1);
|
||||
auto dr = Base::lblock(l1);
|
||||
v1 = F16::load(dr, i + 0);
|
||||
v2 = F16::load(dr, i + 1);
|
||||
}
|
||||
|
||||
inline void load_2(int l1, F16::Data* vk) const {
|
||||
load(l1+0, vk+0);
|
||||
load(l1+1, vk+D/16);
|
||||
}
|
||||
};
|
||||
|
||||
template <int D> struct block_q8_KV {
|
||||
@@ -206,9 +194,9 @@ template <int D> struct block_q8_KV {
|
||||
int8_t qs[D];
|
||||
};
|
||||
|
||||
template <int D, int step>
|
||||
struct HelperQ8KV final : public BaseHelper<step> {
|
||||
using Base = BaseHelper<step>;
|
||||
template <int D>
|
||||
struct HelperQ8KV final : public BaseHelper {
|
||||
using Base = BaseHelper;
|
||||
using block_q8 = block_q8_KV<D>;
|
||||
constexpr static ggml_type type = GGML_TYPE_Q8_KV;
|
||||
constexpr static int block_size_q = D;
|
||||
@@ -235,9 +223,8 @@ struct HelperQ8KV final : public BaseHelper<step> {
|
||||
}
|
||||
};
|
||||
|
||||
template <int D, int step>
|
||||
struct HelperQ80 final : public BaseHelper<step> {
|
||||
using Base = BaseHelper<step>;
|
||||
struct HelperQ80 final : public BaseHelper {
|
||||
using Base = BaseHelper;
|
||||
constexpr static ggml_type type = GGML_TYPE_Q8_0;
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
using block_q8 = block_q8_2;
|
||||
@@ -271,8 +258,8 @@ struct HelperQ80 final : public BaseHelper<step> {
|
||||
#endif
|
||||
}
|
||||
|
||||
template <int D>
|
||||
static inline void convert(int nq, int stride_q, const float * q, block_q8_0 * y) {
|
||||
//GGML_ASSERT(nq <= step); Why did I have this assert?
|
||||
for (int i = 0; i < nq; ++i) {
|
||||
quantize_row_q8_0_x4(q, y, D);
|
||||
q += stride_q;
|
||||
@@ -280,8 +267,8 @@ struct HelperQ80 final : public BaseHelper<step> {
|
||||
}
|
||||
}
|
||||
|
||||
template <int D>
|
||||
static inline void convert(int nq, int stride_q, const float * q, block_q8_1 * y) {
|
||||
//GGML_ASSERT(nq <= step); Why did I have this assert?
|
||||
for (int i = 0; i < nq; ++i) {
|
||||
quantize_row_q8_1_x4(q, y, D);
|
||||
q += stride_q;
|
||||
@@ -289,8 +276,8 @@ struct HelperQ80 final : public BaseHelper<step> {
|
||||
}
|
||||
}
|
||||
|
||||
template <int D>
|
||||
static inline void convert(int nq, int stride_q, const float * q, block_q8_2 * y) {
|
||||
//GGML_ASSERT(nq <= step); Why did I have this assert?
|
||||
for (int i = 0; i < nq; ++i) {
|
||||
quantize_row_q8_2_x4(q, y, D);
|
||||
q += stride_q;
|
||||
@@ -298,6 +285,7 @@ struct HelperQ80 final : public BaseHelper<step> {
|
||||
}
|
||||
}
|
||||
|
||||
template <int D>
|
||||
static inline void convert(int nq, int stride_q, const float * q, block_q8_KV<D> * y) {
|
||||
for (int i = 0; i < nq; ++i) {
|
||||
quantize_row_q8_KV(q, y, D);
|
||||
@@ -307,9 +295,9 @@ struct HelperQ80 final : public BaseHelper<step> {
|
||||
}
|
||||
};
|
||||
|
||||
template <int D, int step>
|
||||
struct HelperQ80R8 : public BaseHelper<step> {
|
||||
using Base = BaseHelper<step>;
|
||||
template <int D>
|
||||
struct HelperQ80R8 : public BaseHelper {
|
||||
using Base = BaseHelper;
|
||||
constexpr static ggml_type type = GGML_TYPE_Q8_0_R8;
|
||||
#ifdef __AVX2__
|
||||
constexpr static int block_size_q = QK8_2;
|
||||
@@ -319,7 +307,7 @@ struct HelperQ80R8 : public BaseHelper<step> {
|
||||
using block_q8 = block_q8_0;
|
||||
#endif
|
||||
HelperQ80R8(const char * data, int stride) : Base(data, stride) {}
|
||||
HelperQ80R8(int nk, const HelperQ80<D, step>& q8) : Base(q8.data, q8.stride) {
|
||||
HelperQ80R8(int nk, const HelperQ80& q8) : Base(q8.data, q8.stride) {
|
||||
r4 = repack(nk, q8);
|
||||
Base::data = (const char *)r4.data();
|
||||
Base::stride = (D/QK8_0)*sizeof(block_q8_0);
|
||||
@@ -416,7 +404,7 @@ struct HelperQ80R8 : public BaseHelper<step> {
|
||||
}
|
||||
}
|
||||
|
||||
static std::vector<block_q8_0_r8> repack(int nk, const HelperQ80<D, step>& q8) {
|
||||
static std::vector<block_q8_0_r8> repack(int nk, const HelperQ80& q8) {
|
||||
static_assert(D%QK8_0 == 0);
|
||||
GGML_ASSERT(nk%8 == 0);
|
||||
constexpr int nblock = D/QK8_0;
|
||||
@@ -430,9 +418,9 @@ struct HelperQ80R8 : public BaseHelper<step> {
|
||||
};
|
||||
|
||||
// TODO: unite this with the above
|
||||
template <int D, int step>
|
||||
struct HelperQ8KVR8 : public BaseHelper<step> {
|
||||
using Base = BaseHelper<step>;
|
||||
template <int D>
|
||||
struct HelperQ8KVR8 : public BaseHelper {
|
||||
using Base = BaseHelper;
|
||||
constexpr static ggml_type type = GGML_TYPE_Q8_KV_R8;
|
||||
constexpr static int block_size_q = D;
|
||||
using block_q8 = block_q8_KV<D>;
|
||||
@@ -442,13 +430,13 @@ struct HelperQ8KVR8 : public BaseHelper<step> {
|
||||
int8_t qs[8*D];
|
||||
};
|
||||
|
||||
HelperQ8KVR8(int nk, const HelperQ8KV<D, step>& q8) : Base(q8.data, q8.stride) {
|
||||
HelperQ8KVR8(int nk, const HelperQ8KV<D>& q8) : Base(q8.data, q8.stride) {
|
||||
r4 = repack(nk, q8);
|
||||
Base::data = (const char *)r4.data();
|
||||
Base::stride = sizeof(block_q8_KV_r8)/8;
|
||||
}
|
||||
|
||||
static std::vector<block_q8_KV_r8> repack(int nk, const HelperQ8KV<D, step>& q8) {
|
||||
static std::vector<block_q8_KV_r8> repack(int nk, const HelperQ8KV<D>& q8) {
|
||||
static_assert(D%32 == 0);
|
||||
GGML_ASSERT(nk%8 == 0);
|
||||
std::vector<block_q8_KV_r8> result(nk/8);
|
||||
@@ -526,9 +514,8 @@ struct HelperQ8KVR8 : public BaseHelper<step> {
|
||||
std::vector<block_q8_KV_r8> r4;
|
||||
};
|
||||
|
||||
template <int D, int step>
|
||||
struct HelperQ40 final : public BaseHelper<step> {
|
||||
using Base = BaseHelper<step>;
|
||||
struct HelperQ40 final : public BaseHelper {
|
||||
using Base = BaseHelper;
|
||||
constexpr static ggml_type type = GGML_TYPE_Q4_0;
|
||||
#if defined __AVX2__
|
||||
using block_q8 = block_q8_2;
|
||||
@@ -576,9 +563,8 @@ struct HelperQ40 final : public BaseHelper<step> {
|
||||
#endif
|
||||
};
|
||||
|
||||
template <int D, int step>
|
||||
struct HelperQ41 final : public BaseHelper<step> {
|
||||
using Base = BaseHelper<step>;
|
||||
struct HelperQ41 final : public BaseHelper {
|
||||
using Base = BaseHelper;
|
||||
using block_q8 = block_q8_2;
|
||||
constexpr static ggml_type type = GGML_TYPE_Q4_1;
|
||||
constexpr static int block_size_q = QK8_2;
|
||||
@@ -620,9 +606,8 @@ struct HelperQ41 final : public BaseHelper<step> {
|
||||
#endif
|
||||
};
|
||||
|
||||
template <int D, int step>
|
||||
struct HelperIQ4nl final : public BaseHelper<step> {
|
||||
using Base = BaseHelper<step>;
|
||||
struct HelperIQ4nl final : public BaseHelper {
|
||||
using Base = BaseHelper;
|
||||
constexpr static ggml_type type = GGML_TYPE_IQ4_NL;
|
||||
#ifdef __aarch64__
|
||||
using block_q8 = block_q8_0;
|
||||
@@ -676,8 +661,7 @@ struct HelperIQ4nl final : public BaseHelper<step> {
|
||||
#endif
|
||||
};
|
||||
|
||||
template <int D, int step>
|
||||
struct HelperQ60 final : public BaseHelper<step> {
|
||||
struct HelperQ60 final : public BaseHelper {
|
||||
constexpr static ggml_type type = GGML_TYPE_Q6_0;
|
||||
#ifdef __aarch64__
|
||||
using block_q8 = block_q8_0;
|
||||
@@ -686,7 +670,7 @@ struct HelperQ60 final : public BaseHelper<step> {
|
||||
using block_q8 = block_q8_2;
|
||||
constexpr static int block_size_q = QK8_2;
|
||||
#endif
|
||||
using Base = BaseHelper<step>;
|
||||
using Base = BaseHelper;
|
||||
HelperQ60(const char * data, int stride) : Base(data, stride) {}
|
||||
|
||||
// Needed for v * softmax(k * q)
|
||||
@@ -739,8 +723,10 @@ struct HelperQ60 final : public BaseHelper<step> {
|
||||
#endif
|
||||
};
|
||||
|
||||
template <int q_step, int k_step>
|
||||
template <int q_step_in, int k_step_in>
|
||||
struct FlashMS {
|
||||
constexpr static int q_step = q_step_in;
|
||||
constexpr static int k_step = k_step_in;
|
||||
// Something goes wrong when storing and manipulating K*Q as fp16.
|
||||
// It works for some models (e.g., Gemma-2), but not for others (e.g., LLaMA-3.1-8B).
|
||||
// As I wasn't able to find where we lose precision, let's comment this out
|
||||
@@ -979,8 +965,9 @@ struct FlashQKV {
|
||||
using qkv_cache_t = float;
|
||||
#endif
|
||||
|
||||
template <typename VHelper>
|
||||
inline void accumulate_qkv_1(const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
|
||||
template <typename VHelper, typename FMS>
|
||||
inline void accumulate_qkv_1(const VHelper& vh, const FMS& fms) {
|
||||
static_assert(q_step == FMS::q_step);
|
||||
F16::Data vq[D/F16::block_size];
|
||||
if (fms.need_scaling[0] == 2) {
|
||||
for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::zero();
|
||||
@@ -1017,8 +1004,9 @@ struct FlashQKV {
|
||||
|
||||
// This fails for head sizes of 80 and 112 as D/16 is odd, so we cannot do steps of 2
|
||||
// Hence, for now, we will not handle head sizes of 80 and 112
|
||||
template <typename VHelper>
|
||||
inline void accumulate_qkv(const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
|
||||
template <typename VHelper, typename FMS>
|
||||
inline void accumulate_qkv(const VHelper& vh, const FMS& fms) {
|
||||
static_assert(q_step == FMS::q_step);
|
||||
if constexpr (q_step == 1) {
|
||||
accumulate_qkv_1(vh, fms);
|
||||
return;
|
||||
@@ -1106,8 +1094,9 @@ struct FlashQKV {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename VHelper>
|
||||
inline void accumulate_qkv(int nq1, const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
|
||||
template <typename VHelper, typename FMS>
|
||||
inline void accumulate_qkv(int nq1, const VHelper& vh, const FMS& fms) {
|
||||
static_assert(q_step == FMS::q_step);
|
||||
if (nq1 == 1) {
|
||||
accumulate_qkv_1(vh, fms);
|
||||
return;
|
||||
@@ -1151,7 +1140,9 @@ struct FlashQKV {
|
||||
}
|
||||
}
|
||||
|
||||
inline void normalize_and_store_1row(const FlashMS<q_step, k_step>& fms, int j, const qkv_cache_t * R, float * qkv) const {
|
||||
template <typename FMS>
|
||||
inline void normalize_and_store_1row(const FMS& fms, int j, const qkv_cache_t * R, float * qkv) const {
|
||||
static_assert(q_step == FMS::q_step);
|
||||
GGML_ASSERT(fms.S[j] > 0);
|
||||
auto norm = F16::set1(1/fms.S[j]);
|
||||
//auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f);
|
||||
@@ -1161,7 +1152,9 @@ struct FlashQKV {
|
||||
}
|
||||
}
|
||||
|
||||
inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int nq1, int stride_qkv, float * qkv, float * M, float * S) const {
|
||||
template <typename FMS>
|
||||
inline void normalize_and_store(const FMS& fms, int nq1, int stride_qkv, float * qkv, float * M, float * S) const {
|
||||
static_assert(q_step == FMS::q_step);
|
||||
if (M && S) {
|
||||
std::memcpy(M, fms.M, nq1*sizeof(float));
|
||||
std::memcpy(S, fms.S, nq1*sizeof(float));
|
||||
@@ -1187,7 +1180,9 @@ struct FlashQKV {
|
||||
}
|
||||
}
|
||||
|
||||
inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int stride_qkv, float * qkv, float * M, float * S) const {
|
||||
template <typename FMS>
|
||||
inline void normalize_and_store(const FMS& fms, int stride_qkv, float * qkv, float * M, float * S) const {
|
||||
static_assert(q_step == FMS::q_step);
|
||||
if (M && S) {
|
||||
std::memcpy(M, fms.M, q_step*sizeof(float));
|
||||
std::memcpy(S, fms.S, q_step*sizeof(float));
|
||||
@@ -1365,8 +1360,8 @@ struct FlashQKfp32 {
|
||||
constexpr int kMaxQ = 8;
|
||||
static_assert(q_step < kMaxQ || q_step%kMaxQ == 0);
|
||||
DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr};
|
||||
if constexpr (std::is_same_v<KHelper, HelperQ8KVR8<D, k_step>> ||
|
||||
std::is_same_v<KHelper, HelperQ8KV<D, k_step>>) {
|
||||
if constexpr (std::is_same_v<KHelper, HelperQ8KVR8<D>> ||
|
||||
std::is_same_v<KHelper, HelperQ8KV<D>>) {
|
||||
iqk_gemm_q8kv_fa(D, q_step, kh.type, kh.block, kh.stride, info, k_step);
|
||||
} else {
|
||||
iqk_gemm_legacy_fa(D, q_step, kh.type, kh.block, kh.stride, info, k_step);
|
||||
@@ -1389,8 +1384,8 @@ struct FlashQKfp32 {
|
||||
const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) {
|
||||
GGML_ASSERT(nq < q_step);
|
||||
DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr};
|
||||
if constexpr (std::is_same_v<KHelper, HelperQ8KVR8<D, k_step>> ||
|
||||
std::is_same_v<KHelper, HelperQ8KV<D, k_step>>) {
|
||||
if constexpr (std::is_same_v<KHelper, HelperQ8KVR8<D>> ||
|
||||
std::is_same_v<KHelper, HelperQ8KV<D>>) {
|
||||
iqk_gemm_q8kv_fa(D, nq, kh.type, kh.block, kh.stride, info, k_step);
|
||||
} else {
|
||||
iqk_gemm_legacy_fa(D, nq, kh.type, kh.block, kh.stride, info, k_step);
|
||||
@@ -1434,8 +1429,8 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
|
||||
KQHelper::multiply_mask_kq(kh, stride_q, stride_m, q, mr, fms);
|
||||
#endif
|
||||
fqkv.accumulate_qkv(vh, fms);
|
||||
kh.next_block();
|
||||
vh.next_block();
|
||||
kh.next_block(k_step);
|
||||
vh.next_block(k_step);
|
||||
mr += k_step*sizeof(ggml_half);
|
||||
}
|
||||
fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
|
||||
@@ -1461,8 +1456,8 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
|
||||
KQHelper::multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr, fms);
|
||||
#endif
|
||||
fqkv.accumulate_qkv(n_left, vh, fms);
|
||||
kh.next_block();
|
||||
vh.next_block();
|
||||
kh.next_block(k_step);
|
||||
vh.next_block(k_step);
|
||||
mr += k_step*sizeof(ggml_half);
|
||||
}
|
||||
fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S);
|
||||
@@ -1476,22 +1471,22 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
|
||||
const float * q, const char * mask, float * qkv,
|
||||
float * M, float * S, char * qptr) {
|
||||
auto q8 = (typename KHelper::block_q8 *)qptr;
|
||||
if constexpr (q_step > 1 && std::is_same_v<KHelper, HelperQ80<Dk, k_step>>) {
|
||||
if constexpr (q_step > 1 && std::is_same_v<KHelper, HelperQ80>) {
|
||||
if (nq1 == q_step) {
|
||||
fms.init_qstep();
|
||||
kh.reset_block();
|
||||
vh.reset_block();
|
||||
block_q8_0_r8 q8r8[Dk/QK8_0 * k_step/8];
|
||||
HelperQ80R8<Dk, k_step> khr8((const char *)q8r8, Dk/QK8_0*sizeof(block_q8_0));
|
||||
auto q8r = (typename HelperQ80R8<Dk, k_step>::block_q8 *)qptr;
|
||||
HelperQ80<Dk, QK8_0>::convert(q_step, stride_q, q, q8r);
|
||||
HelperQ80R8<Dk> khr8((const char *)q8r8, Dk/QK8_0*sizeof(block_q8_0));
|
||||
auto q8r = (typename HelperQ80R8<Dk>::block_q8 *)qptr;
|
||||
HelperQ80::convert<Dk>(q_step, stride_q, q, q8r);
|
||||
auto mr = mask;
|
||||
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
||||
HelperQ80R8<Dk, k_step>::repack(k_step, kh.block, kh.stride, q8r8);
|
||||
HelperQ80R8<Dk>::repack(k_step, kh.block, kh.stride, q8r8);
|
||||
KQHelper::mul_mask_kq(khr8, stride_m, q8r, mr, fms);
|
||||
fqkv.accumulate_qkv(vh, fms);
|
||||
kh.next_block();
|
||||
vh.next_block();
|
||||
kh.next_block(k_step);
|
||||
vh.next_block(k_step);
|
||||
mr += k_step*sizeof(ggml_half);
|
||||
}
|
||||
fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
|
||||
@@ -1508,7 +1503,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
|
||||
fms.init_qstep();
|
||||
kh.reset_block();
|
||||
vh.reset_block();
|
||||
HelperQ80<Dk, QK8_0>::convert(q_step, stride_q, q, q8);
|
||||
HelperQ80::convert<Dk>(q_step, stride_q, q, q8);
|
||||
#if FA_TIMING
|
||||
perf.accum_nolock(0, t1);
|
||||
#endif
|
||||
@@ -1525,8 +1520,8 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
|
||||
KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms);
|
||||
fqkv.accumulate_qkv(vh, fms);
|
||||
#endif
|
||||
kh.next_block();
|
||||
vh.next_block();
|
||||
kh.next_block(k_step);
|
||||
vh.next_block(k_step);
|
||||
mr += k_step*sizeof(ggml_half);
|
||||
}
|
||||
#if FA_TIMING
|
||||
@@ -1547,13 +1542,13 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
|
||||
fms.init_qstep();
|
||||
kh.reset_block();
|
||||
vh.reset_block();
|
||||
HelperQ80<Dk, QK8_0>::convert(n_left, stride_q, q, q8);
|
||||
HelperQ80::convert<Dk>(n_left, stride_q, q, q8);
|
||||
auto mr = mask;
|
||||
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
||||
KQHelper::mul_mask_kq(n_left, kh, stride_m, q8, mr, fms);
|
||||
fqkv.accumulate_qkv(n_left, vh, fms);
|
||||
kh.next_block();
|
||||
vh.next_block();
|
||||
kh.next_block(k_step);
|
||||
vh.next_block(k_step);
|
||||
mr += k_step*sizeof(ggml_half);
|
||||
}
|
||||
fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S);
|
||||
@@ -1591,14 +1586,14 @@ struct FlashAttn {
|
||||
template <typename KHelper, typename VHelper>
|
||||
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
|
||||
const float * q, const char * mask, float * qkv, [[maybe_unused]] float * M, [[maybe_unused]] float * S) {
|
||||
if constexpr (std::is_same_v<KHelper, HelperQ40<Dk, k_step>> ||
|
||||
std::is_same_v<KHelper, HelperQ41<Dk, k_step>> ||
|
||||
std::is_same_v<KHelper, HelperIQ4nl<Dk, k_step>> ||
|
||||
std::is_same_v<KHelper, HelperQ60<Dk, k_step>> ||
|
||||
std::is_same_v<KHelper, HelperQ80R8<Dk, k_step>> ||
|
||||
std::is_same_v<KHelper, HelperQ80<Dk, k_step>> ||
|
||||
std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>> ||
|
||||
std::is_same_v<KHelper, HelperQ8KVR8<Dk, k_step>>) {
|
||||
if constexpr (std::is_same_v<KHelper, HelperQ40> ||
|
||||
std::is_same_v<KHelper, HelperQ41> ||
|
||||
std::is_same_v<KHelper, HelperIQ4nl> ||
|
||||
std::is_same_v<KHelper, HelperQ60> ||
|
||||
std::is_same_v<KHelper, HelperQ80R8<Dk>> ||
|
||||
std::is_same_v<KHelper, HelperQ80> ||
|
||||
std::is_same_v<KHelper, HelperQ8KV<Dk>> ||
|
||||
std::is_same_v<KHelper, HelperQ8KVR8<Dk>>) {
|
||||
constexpr size_t kMaxOnStackSize = 576;
|
||||
//auto q_size = q_step*(Dk/KHelper::block_size_q)*sizeof(typename KHelper::block_q8);
|
||||
auto q_size = q_step*(Dk/QK8_2*sizeof(block_q8_2));
|
||||
@@ -1606,31 +1601,33 @@ struct FlashAttn {
|
||||
if (q_size > kMaxOnStackSize) {
|
||||
auto qptr = get_q_storage(q_size);
|
||||
if (false && nq1 >= 8) {
|
||||
if constexpr (std::is_same_v<KHelper, HelperQ80<Dk, k_step>>) {
|
||||
if constexpr (std::is_same_v<KHelper, HelperQ80>) {
|
||||
#if FA_TIMING
|
||||
auto t1 = Perf::cur_time();
|
||||
HelperQ80R8<Dk, k_step> khr4(nk1, kh);
|
||||
Perf::instance().accum(4, t1);
|
||||
#else
|
||||
HelperQ80R8<Dk, k_step> khr4(nk1, kh);
|
||||
HelperQ80R8<Dk> khr4(nk1, kh);
|
||||
#endif
|
||||
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
|
||||
return;
|
||||
|
||||
}
|
||||
if constexpr (std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>>) {
|
||||
#if GGML_IQK_FA_ALL_QUANTS
|
||||
if constexpr (std::is_same_v<KHelper, HelperQ8KV<Dk>>) {
|
||||
#if FA_TIMING
|
||||
auto t1 = Perf::cur_time();
|
||||
HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
|
||||
Perf::instance().accum(4, t1);
|
||||
#else
|
||||
HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
|
||||
HelperQ8KVR8<Dk> khr4(nk1, kh);
|
||||
#endif
|
||||
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
|
||||
@@ -2040,8 +2037,8 @@ struct FlashAttnBF16 {
|
||||
FlashQKbf16<Dk, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms);
|
||||
fqkv.accumulate_qkv(vh, fms);
|
||||
#endif
|
||||
kh.next_block();
|
||||
vh.next_block();
|
||||
kh.next_block(k_step);
|
||||
vh.next_block(k_step);
|
||||
mr += k_step*sizeof(ggml_half);
|
||||
}
|
||||
#if FA_TIMING
|
||||
@@ -2066,8 +2063,8 @@ struct FlashAttnBF16 {
|
||||
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
||||
FlashQKbf16<Dk, q_step, k_step>::multiply_mask_kq(n_left, kh, stride_m, q_bf16, mr, fms);
|
||||
fqkv.accumulate_qkv(n_left, vh, fms);
|
||||
kh.next_block();
|
||||
vh.next_block();
|
||||
kh.next_block(k_step);
|
||||
vh.next_block(k_step);
|
||||
mr += k_step*sizeof(ggml_half);
|
||||
}
|
||||
fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S);
|
||||
@@ -2180,7 +2177,7 @@ inline bool iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
|
||||
|
||||
switch (type_v) {
|
||||
case GGML_TYPE_F16: {
|
||||
HelperF16<Dv, k_step> vh(v, stride_v);
|
||||
HelperF16 vh(v, stride_v);
|
||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
} break;
|
||||
#ifdef __AVX512BF16__
|
||||
@@ -2190,28 +2187,28 @@ inline bool iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
|
||||
} break;
|
||||
#endif
|
||||
case GGML_TYPE_Q8_0: {
|
||||
HelperQ80<Dv, k_step> vh(v, stride_v);
|
||||
HelperQ80 vh(v, stride_v);
|
||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
} break;
|
||||
case GGML_TYPE_Q8_KV: {
|
||||
HelperQ8KV<Dv, k_step> vh(v, stride_v);
|
||||
HelperQ8KV<Dv> vh(v, stride_v);
|
||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
} break;
|
||||
case GGML_TYPE_Q6_0: {
|
||||
HelperQ60<Dv, k_step> vh(v, stride_v);
|
||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0: {
|
||||
HelperQ40<Dv, k_step> vh(v, stride_v);
|
||||
HelperQ60 vh(v, stride_v);
|
||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
} break;
|
||||
#if GGML_IQK_FA_ALL_QUANTS
|
||||
case GGML_TYPE_Q4_0: {
|
||||
HelperQ40 vh(v, stride_v);
|
||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
} break;
|
||||
case GGML_TYPE_Q4_1: {
|
||||
HelperQ41<Dv, k_step> vh(v, stride_v);
|
||||
HelperQ41 vh(v, stride_v);
|
||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
} break;
|
||||
case GGML_TYPE_IQ4_NL: {
|
||||
HelperIQ4nl<Dv, k_step> vh(v, stride_v);
|
||||
HelperIQ4nl vh(v, stride_v);
|
||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
} break;
|
||||
#endif
|
||||
@@ -2229,36 +2226,36 @@ inline bool iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
|
||||
bool result = false;
|
||||
switch (type_k) {
|
||||
case GGML_TYPE_F16: {
|
||||
HelperF16<Dk, k_step> kh(k, stride_k);
|
||||
HelperF16 kh(k, stride_k);
|
||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
|
||||
} break;
|
||||
case GGML_TYPE_Q8_0: {
|
||||
HelperQ80<Dk, k_step> kh(k, stride_k);
|
||||
HelperQ80 kh(k, stride_k);
|
||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
|
||||
} break;
|
||||
case GGML_TYPE_Q8_0_R8: {
|
||||
HelperQ80R8<Dk, k_step> kh(k, stride_k);
|
||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
|
||||
} break;
|
||||
case GGML_TYPE_Q8_KV: {
|
||||
HelperQ8KV<Dk, k_step> kh(k, stride_k);
|
||||
HelperQ80R8<Dk> kh(k, stride_k);
|
||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
|
||||
} break;
|
||||
case GGML_TYPE_Q6_0: {
|
||||
HelperQ60<Dk, k_step> kh(k, stride_k);
|
||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0: {
|
||||
HelperQ40<Dk, k_step> kh(k, stride_k);
|
||||
HelperQ60 kh(k, stride_k);
|
||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
|
||||
} break;
|
||||
#if GGML_IQK_FA_ALL_QUANTS
|
||||
case GGML_TYPE_Q8_KV: {
|
||||
HelperQ8KV<Dk> kh(k, stride_k);
|
||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0: {
|
||||
HelperQ40 kh(k, stride_k);
|
||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
|
||||
} break;
|
||||
case GGML_TYPE_Q4_1: {
|
||||
HelperQ41<Dk, k_step> kh(k, stride_k);
|
||||
HelperQ41 kh(k, stride_k);
|
||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
|
||||
} break;
|
||||
case GGML_TYPE_IQ4_NL: {
|
||||
HelperIQ4nl<Dk, k_step> kh(k, stride_k);
|
||||
HelperIQ4nl kh(k, stride_k);
|
||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
|
||||
} break;
|
||||
#endif
|
||||
|
||||
@@ -3039,7 +3039,7 @@ static void mul_mat_q8_KV_q8_KV_8(int n, const void * vx, size_t bx, const DataI
|
||||
#endif
|
||||
|
||||
template <int k_step>
|
||||
inline std::pair<mul_mat_t, int> mul_mat_kernel(int D, int int_typeA, int nq) {
|
||||
inline std::pair<mul_mat_t, int> mul_mat_kernel([[maybe_unused]] int D, int int_typeA, int nq) {
|
||||
auto typeA = ggml_type(int_typeA);
|
||||
constexpr int kMaxQ = 8;
|
||||
#define MAKE_FUNCS(mul_mat, n) \
|
||||
|
||||
Reference in New Issue
Block a user