FA: very slightly faster for nq = 1 (TG)

This commit is contained in:
Iwan Kawrakow
2025-02-10 18:25:44 +02:00
parent 10815e7ebe
commit 4066235b8f

View File

@@ -14879,10 +14879,60 @@ struct FlashQKV {
using qkv_cache_t = float;
#endif
template <typename VHelper>
inline void accumulate_qkv_1(const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
F16::Data vq[D/F16::block_size];
if (fms.need_scaling[0] == 2) {
for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::zero();
} else {
for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::load(qkv_cache + F16::block_size*i);
if (fms.need_scaling[0] == 1) {
auto vms = F16::set1(fms.vms[0]);
for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::mul(vms, vq[i]);
}
}
//F16::Data v[8];
F16::Data v0, v1;
for (int l = 0; l < k_step; l += 4) {
auto vs0 = F16::set1(fms.cache[l + 0]);
auto vs1 = F16::set1(fms.cache[l + 1]);
auto vs2 = F16::set1(fms.cache[l + 2]);
auto vs3 = F16::set1(fms.cache[l + 3]);
//auto vs = F16::set4(fms.cache + l);
for (int i = 0; i < D/F16::block_size; i += 2) {
vh.load(l+0, i, v0, v1);
vq[i+0] = F16::fmadd(vq[i+0], v0, vs0);
vq[i+1] = F16::fmadd(vq[i+1], v1, vs0);
vh.load(l+1, i, v0, v1);
vq[i+0] = F16::fmadd(vq[i+0], v0, vs1);
vq[i+1] = F16::fmadd(vq[i+1], v1, vs1);
vh.load(l+2, i, v0, v1);
vq[i+0] = F16::fmadd(vq[i+0], v0, vs2);
vq[i+1] = F16::fmadd(vq[i+1], v1, vs2);
vh.load(l+3, i, v0, v1);
vq[i+0] = F16::fmadd(vq[i+0], v0, vs3);
vq[i+1] = F16::fmadd(vq[i+1], v1, vs3);
//vq[i+0] = F16::fmadd_lane0(vq[i+0], v[0], vs);
//vq[i+1] = F16::fmadd_lane0(vq[i+1], v[4], vs);
//vq[i+0] = F16::fmadd_lane1(vq[i+0], v[1], vs);
//vq[i+1] = F16::fmadd_lane1(vq[i+1], v[5], vs);
//vq[i+0] = F16::fmadd_lane2(vq[i+0], v[2], vs);
//vq[i+1] = F16::fmadd_lane2(vq[i+1], v[6], vs);
//vq[i+0] = F16::fmadd_lane3(vq[i+0], v[3], vs);
//vq[i+1] = F16::fmadd_lane3(vq[i+1], v[7], vs);
}
}
for (int i = 0; i < D/F16::block_size; ++i) F16::store(qkv_cache + F16::block_size*i, vq[i]);
}
// This fails for head sizes of 80 and 112 as D/16 is odd, so we cannot do steps of 2
// Hence, for now, we will not handle head sizes of 80 and 112
template <typename VHelper>
inline void accumulate_qkv(const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
if constexpr (q_step == 1) {
accumulate_qkv_1(vh, fms);
return;
}
F16::Data v[8];
for (int j = 0; j < q_step; ++j) {
auto R = qkv_cache + D*j;
@@ -14924,6 +14974,10 @@ struct FlashQKV {
template <typename VHelper>
inline void accumulate_qkv(int nq1, const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
if (nq1 == 1) {
accumulate_qkv_1(vh, fms);
return;
}
F16::Data v[8];
for (int j = 0; j < nq1; ++j) {
auto R = qkv_cache + D*j;
@@ -15757,7 +15811,22 @@ struct FlashQKbf16 {
static inline void multiply_mask_kq(const KHelper& kh, int stride_m, const ggml_bf16_t * q,
const char * mask, FlashMS<q_step, k_step>& fms) {
#endif
{
if constexpr (q_step == 1) {
__m512bh vq[D/32];
__m512bh vk[D/32];
__m256 sum[8];
for (int i = 0; i < D/32; ++i) vq[i] = __m512bh(_mm512_loadu_si512((const __m512i *)q + i));
for (int l = 0; l < k_step; l += 8) {
for (int k = 0; k < 8; ++k) {
kh.load(l+k, vk);
auto vsum = _mm512_setzero_ps();
for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vk[i], vq[i]);
sum[k] = _mm256_add_ps(_mm512_castps512_ps256(vsum), _mm512_extractf32x8_ps(vsum, 1));
}
_mm256_storeu_ps(fms.cache + l, hsum_float_8x8(sum));
}
}
else {
__m512bh qv[D/32];
if constexpr (D <= 128) {
__m512bh vkh[D/4];