mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-29 19:01:47 +00:00
CPU: faster FA (#797)
* Avoid computing FA chunks where the mask is -infinity
* Avoid computing FA chunks where the mask is -infinity also for f16/bf16

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -1355,7 +1355,11 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
|
|||||||
KQHelper::convert(q_step, stride_q, q, q_f16);
|
KQHelper::convert(q_step, stride_q, q, q_f16);
|
||||||
#endif
|
#endif
|
||||||
auto mr = mask;
|
auto mr = mask;
|
||||||
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
|
||||||
|
int ik = nk1 - k_step;
|
||||||
|
for (; ik >=0 && Mc[ik] != 0; ik -= k_step);
|
||||||
|
ik += k_step;
|
||||||
|
for (int k1 = 0; k1 < ik/k_step; ++k1) {
|
||||||
#ifdef __aarch64__
|
#ifdef __aarch64__
|
||||||
KQHelper::multiply_mask_kq(kh, Dk, stride_m, q_f16, mr, fms);
|
KQHelper::multiply_mask_kq(kh, Dk, stride_m, q_f16, mr, fms);
|
||||||
#else
|
#else
|
||||||
@@ -1415,6 +1419,8 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
|
|||||||
HelperQ80::convert<Dk>(q_step, stride_q, q, q8r);
|
HelperQ80::convert<Dk>(q_step, stride_q, q, q8r);
|
||||||
auto mr = mask;
|
auto mr = mask;
|
||||||
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
||||||
|
auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
|
||||||
|
if (Mc[0] != 0) break;
|
||||||
HelperQ80R8<Dk>::repack(k_step, kh.block, kh.stride, q8r8);
|
HelperQ80R8<Dk>::repack(k_step, kh.block, kh.stride, q8r8);
|
||||||
KQHelper::mul_mask_kq(khr8, stride_m, q8r, mr, fms);
|
KQHelper::mul_mask_kq(khr8, stride_m, q8r, mr, fms);
|
||||||
fqkv.accumulate_qkv(vh, fms);
|
fqkv.accumulate_qkv(vh, fms);
|
||||||
@@ -1441,7 +1447,11 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
|
|||||||
perf.accum_nolock(0, t1);
|
perf.accum_nolock(0, t1);
|
||||||
#endif
|
#endif
|
||||||
auto mr = mask;
|
auto mr = mask;
|
||||||
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
|
||||||
|
int ik = nk1 - k_step;
|
||||||
|
for (; ik >=0 && Mc[ik] != 0; ik -= k_step);
|
||||||
|
ik += k_step;
|
||||||
|
for (int k1 = 0; k1 < ik/k_step; ++k1) {
|
||||||
#if FA_TIMING
|
#if FA_TIMING
|
||||||
t1 = Perf::cur_time();
|
t1 = Perf::cur_time();
|
||||||
KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms);
|
KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms);
|
||||||
@@ -1959,7 +1969,11 @@ struct FlashAttnBF16 {
|
|||||||
perf.accum_nolock(0, t1);
|
perf.accum_nolock(0, t1);
|
||||||
#endif
|
#endif
|
||||||
auto mr = mask;
|
auto mr = mask;
|
||||||
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
|
||||||
|
int ik = nk1 - k_step;
|
||||||
|
for (; ik >=0 && Mc[ik] != 0; ik -= k_step);
|
||||||
|
ik += k_step;
|
||||||
|
for (int k1 = 0; k1 < ik/k_step; ++k1) {
|
||||||
#if FA_TIMING
|
#if FA_TIMING
|
||||||
//t1 = Perf::cur_time();
|
//t1 = Perf::cur_time();
|
||||||
FlashQKbf16<Dk, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf);
|
FlashQKbf16<Dk, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf);
|
||||||
|
|||||||
Reference in New Issue
Block a user