mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-22 14:14:32 +00:00
Fix FA bug on AVX2 (#364)
* Fix FA bug on AVX2 * Also this was wrong --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -17120,11 +17120,12 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
|
|||||||
vh.reset_block();
|
vh.reset_block();
|
||||||
block_q8_0_r8 q8r8[Dk/QK8_0 * k_step/8];
|
block_q8_0_r8 q8r8[Dk/QK8_0 * k_step/8];
|
||||||
HelperQ80R8<Dk, k_step> khr8((const char *)q8r8, Dk/QK8_0*sizeof(block_q8_0));
|
HelperQ80R8<Dk, k_step> khr8((const char *)q8r8, Dk/QK8_0*sizeof(block_q8_0));
|
||||||
HelperQ80<Dk, QK8_0>::convert(q_step, stride_q, q, q8);
|
auto q8r = (typename HelperQ80R8<Dk, k_step>::block_q8 *)qptr;
|
||||||
|
HelperQ80<Dk, QK8_0>::convert(q_step, stride_q, q, q8r);
|
||||||
auto mr = mask;
|
auto mr = mask;
|
||||||
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
||||||
HelperQ80R8<Dk, k_step>::repack(k_step, kh.data, kh.stride, q8r8);
|
HelperQ80R8<Dk, k_step>::repack(k_step, kh.block, kh.stride, q8r8);
|
||||||
KQHelper::mul_mask_kq(khr8, stride_m, q8, mr, fms);
|
KQHelper::mul_mask_kq(khr8, stride_m, q8r, mr, fms);
|
||||||
fqkv.accumulate_qkv(vh, fms);
|
fqkv.accumulate_qkv(vh, fms);
|
||||||
kh.next_block();
|
kh.next_block();
|
||||||
vh.next_block();
|
vh.next_block();
|
||||||
@@ -17236,7 +17237,8 @@ struct FlashAttn {
|
|||||||
std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>> ||
|
std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>> ||
|
||||||
std::is_same_v<KHelper, HelperQ8KVR8<Dk, k_step>>) {
|
std::is_same_v<KHelper, HelperQ8KVR8<Dk, k_step>>) {
|
||||||
constexpr size_t kMaxOnStackSize = 576;
|
constexpr size_t kMaxOnStackSize = 576;
|
||||||
auto q_size = q_step*(Dk/KHelper::block_size_q)*sizeof(typename KHelper::block_q8);
|
//auto q_size = q_step*(Dk/KHelper::block_size_q)*sizeof(typename KHelper::block_q8);
|
||||||
|
auto q_size = q_step*(Dk/QK8_2*sizeof(block_q8_2));
|
||||||
q_size = GGML_PAD(q_size, 64);
|
q_size = GGML_PAD(q_size, 64);
|
||||||
if (q_size > kMaxOnStackSize) {
|
if (q_size > kMaxOnStackSize) {
|
||||||
auto qptr = get_q_storage(q_size);
|
auto qptr = get_q_storage(q_size);
|
||||||
|
|||||||
Reference in New Issue
Block a user