mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-28 00:54:09 +00:00
FA: very slightly faster for nq = 1 (TG)
This commit is contained in:
@@ -14879,10 +14879,60 @@ struct FlashQKV {
|
||||
using qkv_cache_t = float;
|
||||
#endif
|
||||
|
||||
template <typename VHelper>
|
||||
inline void accumulate_qkv_1(const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
|
||||
F16::Data vq[D/F16::block_size];
|
||||
if (fms.need_scaling[0] == 2) {
|
||||
for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::zero();
|
||||
} else {
|
||||
for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::load(qkv_cache + F16::block_size*i);
|
||||
if (fms.need_scaling[0] == 1) {
|
||||
auto vms = F16::set1(fms.vms[0]);
|
||||
for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::mul(vms, vq[i]);
|
||||
}
|
||||
}
|
||||
//F16::Data v[8];
|
||||
F16::Data v0, v1;
|
||||
for (int l = 0; l < k_step; l += 4) {
|
||||
auto vs0 = F16::set1(fms.cache[l + 0]);
|
||||
auto vs1 = F16::set1(fms.cache[l + 1]);
|
||||
auto vs2 = F16::set1(fms.cache[l + 2]);
|
||||
auto vs3 = F16::set1(fms.cache[l + 3]);
|
||||
//auto vs = F16::set4(fms.cache + l);
|
||||
for (int i = 0; i < D/F16::block_size; i += 2) {
|
||||
vh.load(l+0, i, v0, v1);
|
||||
vq[i+0] = F16::fmadd(vq[i+0], v0, vs0);
|
||||
vq[i+1] = F16::fmadd(vq[i+1], v1, vs0);
|
||||
vh.load(l+1, i, v0, v1);
|
||||
vq[i+0] = F16::fmadd(vq[i+0], v0, vs1);
|
||||
vq[i+1] = F16::fmadd(vq[i+1], v1, vs1);
|
||||
vh.load(l+2, i, v0, v1);
|
||||
vq[i+0] = F16::fmadd(vq[i+0], v0, vs2);
|
||||
vq[i+1] = F16::fmadd(vq[i+1], v1, vs2);
|
||||
vh.load(l+3, i, v0, v1);
|
||||
vq[i+0] = F16::fmadd(vq[i+0], v0, vs3);
|
||||
vq[i+1] = F16::fmadd(vq[i+1], v1, vs3);
|
||||
//vq[i+0] = F16::fmadd_lane0(vq[i+0], v[0], vs);
|
||||
//vq[i+1] = F16::fmadd_lane0(vq[i+1], v[4], vs);
|
||||
//vq[i+0] = F16::fmadd_lane1(vq[i+0], v[1], vs);
|
||||
//vq[i+1] = F16::fmadd_lane1(vq[i+1], v[5], vs);
|
||||
//vq[i+0] = F16::fmadd_lane2(vq[i+0], v[2], vs);
|
||||
//vq[i+1] = F16::fmadd_lane2(vq[i+1], v[6], vs);
|
||||
//vq[i+0] = F16::fmadd_lane3(vq[i+0], v[3], vs);
|
||||
//vq[i+1] = F16::fmadd_lane3(vq[i+1], v[7], vs);
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < D/F16::block_size; ++i) F16::store(qkv_cache + F16::block_size*i, vq[i]);
|
||||
}
|
||||
|
||||
// This fails for head sizes of 80 and 112 as D/16 is odd, so we cannot do steps of 2
|
||||
// Hence, for now, we will not handle head sizes of 80 and 112
|
||||
template <typename VHelper>
|
||||
inline void accumulate_qkv(const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
|
||||
if constexpr (q_step == 1) {
|
||||
accumulate_qkv_1(vh, fms);
|
||||
return;
|
||||
}
|
||||
F16::Data v[8];
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto R = qkv_cache + D*j;
|
||||
@@ -14924,6 +14974,10 @@ struct FlashQKV {
|
||||
|
||||
template <typename VHelper>
|
||||
inline void accumulate_qkv(int nq1, const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
|
||||
if (nq1 == 1) {
|
||||
accumulate_qkv_1(vh, fms);
|
||||
return;
|
||||
}
|
||||
F16::Data v[8];
|
||||
for (int j = 0; j < nq1; ++j) {
|
||||
auto R = qkv_cache + D*j;
|
||||
@@ -15757,7 +15811,22 @@ struct FlashQKbf16 {
|
||||
static inline void multiply_mask_kq(const KHelper& kh, int stride_m, const ggml_bf16_t * q,
|
||||
const char * mask, FlashMS<q_step, k_step>& fms) {
|
||||
#endif
|
||||
{
|
||||
if constexpr (q_step == 1) {
|
||||
__m512bh vq[D/32];
|
||||
__m512bh vk[D/32];
|
||||
__m256 sum[8];
|
||||
for (int i = 0; i < D/32; ++i) vq[i] = __m512bh(_mm512_loadu_si512((const __m512i *)q + i));
|
||||
for (int l = 0; l < k_step; l += 8) {
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
kh.load(l+k, vk);
|
||||
auto vsum = _mm512_setzero_ps();
|
||||
for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vk[i], vq[i]);
|
||||
sum[k] = _mm256_add_ps(_mm512_castps512_ps256(vsum), _mm512_extractf32x8_ps(vsum, 1));
|
||||
}
|
||||
_mm256_storeu_ps(fms.cache + l, hsum_float_8x8(sum));
|
||||
}
|
||||
}
|
||||
else {
|
||||
__m512bh qv[D/32];
|
||||
if constexpr (D <= 128) {
|
||||
__m512bh vkh[D/4];
|
||||
|
||||
Reference in New Issue
Block a user