From 4066235b8f15cf3be3a7eeb683a883f4545c0b75 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Mon, 10 Feb 2025 18:25:44 +0200
Subject: [PATCH] FA: very slightly faster for nq = 1 (TG)

---
 ggml/src/iqk/iqk_mul_mat.cpp | 71 +++++++++++++++++++++++++++++++++++-
 1 file changed, 70 insertions(+), 1 deletion(-)

diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 72bff532..3b58495e 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -14879,10 +14879,60 @@ struct FlashQKV {
     using qkv_cache_t = float;
 #endif
 
+    template <typename VHelper>
+    inline void accumulate_qkv_1(const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
+        F16::Data vq[D/F16::block_size];
+        if (fms.need_scaling[0] == 2) {
+            for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::zero();
+        } else {
+            for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::load(qkv_cache + F16::block_size*i);
+            if (fms.need_scaling[0] == 1) {
+                auto vms = F16::set1(fms.vms[0]);
+                for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::mul(vms, vq[i]);
+            }
+        }
+        //F16::Data v[8];
+        F16::Data v0, v1;
+        for (int l = 0; l < k_step; l += 4) {
+            auto vs0 = F16::set1(fms.cache[l + 0]);
+            auto vs1 = F16::set1(fms.cache[l + 1]);
+            auto vs2 = F16::set1(fms.cache[l + 2]);
+            auto vs3 = F16::set1(fms.cache[l + 3]);
+            //auto vs = F16::set4(fms.cache + l);
+            for (int i = 0; i < D/F16::block_size; i += 2) {
+                vh.load(l+0, i, v0, v1);
+                vq[i+0] = F16::fmadd(vq[i+0], v0, vs0);
+                vq[i+1] = F16::fmadd(vq[i+1], v1, vs0);
+                vh.load(l+1, i, v0, v1);
+                vq[i+0] = F16::fmadd(vq[i+0], v0, vs1);
+                vq[i+1] = F16::fmadd(vq[i+1], v1, vs1);
+                vh.load(l+2, i, v0, v1);
+                vq[i+0] = F16::fmadd(vq[i+0], v0, vs2);
+                vq[i+1] = F16::fmadd(vq[i+1], v1, vs2);
+                vh.load(l+3, i, v0, v1);
+                vq[i+0] = F16::fmadd(vq[i+0], v0, vs3);
+                vq[i+1] = F16::fmadd(vq[i+1], v1, vs3);
+                //vq[i+0] = F16::fmadd_lane0(vq[i+0], v[0], vs);
+                //vq[i+1] = F16::fmadd_lane0(vq[i+1], v[4], vs);
+                //vq[i+0] = F16::fmadd_lane1(vq[i+0], v[1], vs);
+                //vq[i+1] = F16::fmadd_lane1(vq[i+1], v[5], vs);
+                //vq[i+0] = F16::fmadd_lane2(vq[i+0], v[2], vs);
+                //vq[i+1] = F16::fmadd_lane2(vq[i+1], v[6], vs);
+                //vq[i+0] = F16::fmadd_lane3(vq[i+0], v[3], vs);
+                //vq[i+1] = F16::fmadd_lane3(vq[i+1], v[7], vs);
+            }
+        }
+        for (int i = 0; i < D/F16::block_size; ++i) F16::store(qkv_cache + F16::block_size*i, vq[i]);
+    }
+
     // This fails for head sizes of 80 and 112 as D/16 is odd, so we cannot do steps of 2
     // Hence, for now, we will not handle head sizes of 80 and 112
     template <typename VHelper>
     inline void accumulate_qkv(const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
+        if constexpr (q_step == 1) {
+            accumulate_qkv_1(vh, fms);
+            return;
+        }
         F16::Data v[8];
         for (int j = 0; j < q_step; ++j) {
             auto R = qkv_cache + D*j;
@@ -14924,6 +14974,10 @@ struct FlashQKV {
 
     template <typename VHelper>
     inline void accumulate_qkv(int nq1, const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
+        if (nq1 == 1) {
+            accumulate_qkv_1(vh, fms);
+            return;
+        }
         F16::Data v[8];
         for (int j = 0; j < nq1; ++j) {
             auto R = qkv_cache + D*j;
@@ -15757,7 +15811,22 @@ struct FlashQKbf16 {
     static inline void multiply_mask_kq(const KHelper& kh, int stride_m, const ggml_bf16_t * q,
             const char * mask, FlashMS<q_step, k_step>& fms) {
 #endif
-    {
+    if constexpr (q_step == 1) {
+        __m512bh vq[D/32];
+        __m512bh vk[D/32];
+        __m256 sum[8];
+        for (int i = 0; i < D/32; ++i) vq[i] = __m512bh(_mm512_loadu_si512((const __m512i *)q + i));
+        for (int l = 0; l < k_step; l += 8) {
+            for (int k = 0; k < 8; ++k) {
+                kh.load(l+k, vk);
+                auto vsum = _mm512_setzero_ps();
+                for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vk[i], vq[i]);
+                sum[k] = _mm256_add_ps(_mm512_castps512_ps256(vsum), _mm512_extractf32x8_ps(vsum, 1));
+            }
+            _mm256_storeu_ps(fms.cache + l, hsum_float_8x8(sum));
+        }
+    }
+    else {
         __m512bh qv[D/32];
         if constexpr (D <= 128) {
             __m512bh vkh[D/4];
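
For reference, a minimal scalar sketch of what the new nq = 1 path computes (not part of the patch; D, K_STEP, v_row, and the function name below are illustrative): with a single query row, the output accumulator for one head fits in registers, so the specialization loads qkv_cache once, resets or rescales it according to the online-softmax flags (need_scaling == 2 starts fresh, need_scaling == 1 multiplies the previous partial result by vms), folds in all k_step V rows weighted by the cached softmax values, and stores once.

    // Scalar model of accumulate_qkv_1 for a single query row (assumptions:
    // D is the head size, K_STEP matches k_step, v points to a K_STEP x D
    // block of V, cache holds the K_STEP softmax weights, and qkv_cache is
    // the D-element running accumulator).
    constexpr int D      = 128;
    constexpr int K_STEP = 32;

    void accumulate_qkv_1_scalar(const float* v, const float* cache,
                                 int need_scaling, float vms, float* qkv_cache) {
        if (need_scaling == 2) {
            for (int i = 0; i < D; ++i) qkv_cache[i] = 0.0f;  // new row: reset
        } else if (need_scaling == 1) {
            for (int i = 0; i < D; ++i) qkv_cache[i] *= vms;  // rescale old partial sum
        }
        for (int l = 0; l < K_STEP; ++l) {
            const float* v_row = v + l*D;
            for (int i = 0; i < D; ++i) qkv_cache[i] += cache[l] * v_row[i];  // O += w_l * V_l
        }
    }

The SIMD version keeps vq[] live in registers across the entire k_step block, which is what removes the per-group load/store round trips to qkv_cache that the generic accumulate_qkv performs for each group of V rows.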