From 4753c861d1c30774e8d3d2139dff6e6455734cb0 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 16 Jan 2025 17:10:42 +0200 Subject: [PATCH] FA: slightly faster V*softmax(K*Q)) on Zen4 --- ggml/src/iqk/iqk_mul_mat.cpp | 173 +++++++++++++++++++++++++++++------ 1 file changed, 145 insertions(+), 28 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 5577ea99..808a86b3 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -12354,6 +12354,15 @@ struct F16 { static inline float reduce_add(Data data) { return _mm512_reduce_add_ps(data); } static inline Data max(Data v1, Data v2) { return _mm512_max_ps(v1, v2); } static inline Data add(Data v1, Data v2) { return _mm512_add_ps(v1, v2); } + static inline Data set4(const float * ptr) { + auto v128 = _mm_loadu_ps(ptr); + auto v256 = _mm256_set_m128(v128, v128); + return _mm512_insertf32x8(_mm512_castps256_ps512(v256), v256, 1); + } + static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x00), prev); } + static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x55), prev); } + static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0xaa), prev); } + static inline Data fmadd_lane3(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0xff), prev); } #elif defined __AVX2__ using Data = __m256; constexpr static int block_size = 8; @@ -12371,6 +12380,14 @@ struct F16 { static inline float reduce_add(Data data) { return hsum_float_8(data); } static inline Data max(Data v1, Data v2) { return _mm256_max_ps(v1, v2); } static inline Data add(Data v1, Data v2) { return _mm256_add_ps(v1, v2); } + static inline Data set4(const float * ptr) { + auto v128 = _mm_loadu_ps(ptr); + return _mm256_set_m128(v128, v128); + } + static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x00), prev); } + static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x55), prev); } + static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0xaa), prev); } + static inline Data fmadd_lane3(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0xff), prev); } #else using Data = float16x8_t; constexpr static int block_size = 8; @@ -12402,6 +12419,15 @@ struct F16 { } static inline Data max(Data v1, Data v2) { return vmaxq_f16(v1, v2); } static inline Data add(Data v1, Data v2) { return vaddq_f16(v1, v2); } + static inline Data set4(const float * ptr) { + auto val32 = vld1q_f32(ptr); + auto val16 = vcvt_f16_f32(val32); + return vcombine_f16(val16, val16); + } + static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return vfmaq_laneq_f16(vfmaq_laneq_f16(prev, v1, v2, 0), v1, v2, 4); } + static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return vfmaq_laneq_f16(vfmaq_laneq_f16(prev, v1, v2, 1), v1, v2, 5); } + static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return vfmaq_laneq_f16(vfmaq_laneq_f16(prev, v1, v2, 2), v1, v2, 6); } + static inline Data fmadd_lane3(Data prev, Data v1, Data v2) { return vfmaq_laneq_f16(vfmaq_laneq_f16(prev, v1, v2, 3), v1, v2, 7); } #endif template static inline float reduce_max(const Data * data) { return reduce_T(data); @@ -12927,38 +12953,129 @@ struct FlashQKV { // Hence, for now, we will not handle head sizes of 80 and 112 template inline void accumulate_qkv(const VHelper& vh, const FlashMS& fms) { - F16::Data vk[2*q_step]; - for (int i = 0; i < D/F16::block_size; i += 2) { - for (int j = 0; j < q_step; ++j) { - if (fms.need_scaling[j] == 2) { - vk[2*j+0] = vk[2*j+1] = F16::zero(); - } else { - auto R = qkv_cache + D*j; - vk[2*j+0] = F16::load(R + F16::block_size*i); - vk[2*j+1] = F16::load(R + F16::block_size*(i + 1)); - if (fms.need_scaling[j] == 1) { - vk[2*j+0] = F16::mul(vk[2*j+0], fms.vms[j]); - vk[2*j+1] = F16::mul(vk[2*j+1], fms.vms[j]); - } - } + F16::Data v[8]; + for (int j = 0; j < q_step; ++j) { + auto R = qkv_cache + D*j; + if (fms.need_scaling[j] == 2) { + std::memset(R, 0, D*sizeof(qkv_cache_t)); } - F16::Data v1, v2, v3, v4; - for (int l1 = 0; l1 < k_step; l1 += 2) { - vh.load(l1+0, i, v1, v2); - vh.load(l1+1, i, v3, v4); - for (int j = 0; j < q_step; ++j) { - auto vs1 = F16::set1(fms.cache[k_step*j + l1+0]); - auto vs2 = F16::set1(fms.cache[k_step*j + l1+1]); - vk[2*j+0] = F16::fmadd(F16::fmadd(vk[2*j+0], v1, vs1), v3, vs2); - vk[2*j+1] = F16::fmadd(F16::fmadd(vk[2*j+1], v2, vs1), v4, vs2); + else if (fms.need_scaling[j] == 1) { + for (int i = 0; i < D/F16::block_size; ++i) { + F16::store(R + F16::block_size*i, F16::mul(fms.vms[j], F16::load(R + F16::block_size*i))); } } - for (int j = 0; j < q_step; ++j) { - auto R = qkv_cache + D*j; - F16::store(R + F16::block_size*(i + 0), vk[2*j+0]); - F16::store(R + F16::block_size*(i + 1), vk[2*j+1]); - } } + for (int i = 0; i < D/F16::block_size; i += 2) { + for (int l = 0; l < k_step; l += 4) { + vh.load(l+0, i, v[0], v[4]); + vh.load(l+1, i, v[1], v[5]); + vh.load(l+2, i, v[2], v[6]); + vh.load(l+3, i, v[3], v[7]); + for (int j = 0; j < q_step; ++j) { + auto R = qkv_cache + D*j; + auto s1 = F16::load(R + F16::block_size*(i+0)); + auto s2 = F16::load(R + F16::block_size*(i+1)); + auto vs = F16::set4(fms.cache + k_step*j + l); + s1 = F16::fmadd_lane0(s1, v[0], vs); + s2 = F16::fmadd_lane0(s2, v[4], vs); + s1 = F16::fmadd_lane1(s1, v[1], vs); + s2 = F16::fmadd_lane1(s2, v[5], vs); + s1 = F16::fmadd_lane2(s1, v[2], vs); + s2 = F16::fmadd_lane2(s2, v[6], vs); + s1 = F16::fmadd_lane3(s1, v[3], vs); + s2 = F16::fmadd_lane3(s2, v[7], vs); + F16::store(R + F16::block_size*(i+0), s1); + F16::store(R + F16::block_size*(i+1), s2); + } + } + } + //F16::Data vk[2*q_step]; + //for (int i = 0; i < D/F16::block_size; i += 2) { + // for (int j = 0; j < q_step; ++j) { + // if (fms.need_scaling[j] == 2) { + // vk[2*j+0] = vk[2*j+1] = F16::zero(); + // } else { + // auto R = qkv_cache + D*j; + // vk[2*j+0] = F16::load(R + F16::block_size*i); + // vk[2*j+1] = F16::load(R + F16::block_size*(i + 1)); + // if (fms.need_scaling[j] == 1) { + // vk[2*j+0] = F16::mul(vk[2*j+0], fms.vms[j]); + // vk[2*j+1] = F16::mul(vk[2*j+1], fms.vms[j]); + // } + // } + // } + // // R[j][i] += sum[l, V[l][i] * C[j][l] ] + // // If we transpose V, we can accumulate as + // // R[i][j] += sum [l, V[i][l] * C[j][l] ] + // // or + // // R[j][i] += sum [l, V[i][l] * C[j][l] ] + // // so + // // for (int j = 0; j < q_step; ++j) { + // // for (int i = 0; i < D; ++i) { + // // for (int l = 0; l < k_step; ++k) { + // // R[j][i] += V[i][l] * C[j][l]; + // // } + // // } + // // } + // // so + // // for (int j = 0; j < q_step; ++j) { + // // for (int i = 0; i < D/F16::block_size; ++i) { + // // auto acc = F16::load(qkv_cache + D*j + F16::block_size*i); + // // for (int l = 0; l < k_step/F16::block_size; ++k) { + // // auto vk = vh.load1(j, l); + // // auto vs = F16::load(fms.cache(k_step*j + l*F16::block_size]; + // // acc = F16::fmadd(acc, vk, vs); + // // } + // // } + // // } + // // but for k_step = 32, q_step = 8, F16::bloc = 32, we need only 16 registers to load the entire fms.cache + // // F16::Data C[4*k_step/F16::block_size]; + // // F16::Data V[4*k_step/F16::block_size]; + // // F16::Data acc[4*k_step/F16::block_size]; + // // for (int j = 0; j < q_step; j += 4) { + // // for (int l = 0; l < k_step/F16::block_size; ++l) { + // // C[(k_step/F16::block_size)*(j+0) + l] = F16::load(fms.cache + k_step*(j+0) + l*F16::block_size); + // // C[(k_step/F16::block_size)*(j+1) + l] = F16::load(fms.cache + k_step*(j+1) + l*F16::block_size); + // // C[(k_step/F16::block_size)*(j+2) + l] = F16::load(fms.cache + k_step*(j+2) + l*F16::block_size); + // // C[(k_step/F16::block_size)*(j+3) + l] = F16::load(fms.cache + k_step*(j+3) + l*F16::block_size); + // // V[(k_step/F16::block_size)*(j+0) + l] = vh.load(j+0, l); + // // V[(k_step/F16::block_size)*(j+1) + l] = vh.load(j+1, l); + // // V[(k_step/F16::block_size)*(j+2) + l] = vh.load(j+2, l); + // // V[(k_step/F16::block_size)*(j+3) + l] = vh.load(j+3, l); + // // } + // // for (int i = 0; i < D/F16::block_size; ++i) { + // // auto acc = F16::load(qkv_cache + D*(j+0) + F16::block_size*i); + // // for (int l = 0; l < k_step/F16::block_size; ++l) { + // // acc = F16::fmadd(acc, C[(k_step/F16::block_size)*(j+0)+l], V[(k_step/F16::block_size)*(j+0)+l]); + // // acc1 = F16::fmadd(F16::fmadd(acc1, C[(k_step/F16::block_size)*(j+0), + // // acc[0] = F16::load(qkv_cache + D*(j+0) + F16::block_size*(i+0)); + // // acc[1] = F16::load(qkv_cache + D*(j+0) + F16::block_size*(i+1)); + // // acc[2] = F16::load(qkv_cache + D*(j+1) + F16::block_size*(i+0)); + // // acc[3] = F16::load(qkv_cache + D*(j+1) + F16::block_size*(i+1)); + // // acc[4] = F16::load(qkv_cache + D*(j+2) + F16::block_size*(i+0)); + // // acc[5] = F16::load(qkv_cache + D*(j+2) + F16::block_size*(i+1)); + // // acc[6] = F16::load(qkv_cache + D*(j+3) + F16::block_size*(i+0)); + // // acc[7] = F16::load(qkv_cache + D*(j+3) + F16::block_size*(i+1)); + // // + // // } + // // } + // F16::Data v1, v2, v3, v4; + // for (int l1 = 0; l1 < k_step; l1 += 2) { + // vh.load(l1+0, i, v1, v2); + // vh.load(l1+1, i, v3, v4); + // for (int j = 0; j < q_step; ++j) { + // auto vs1 = F16::set1(fms.cache[k_step*j + l1+0]); + // auto vs2 = F16::set1(fms.cache[k_step*j + l1+1]); + // vk[2*j+0] = F16::fmadd(F16::fmadd(vk[2*j+0], v1, vs1), v3, vs2); + // vk[2*j+1] = F16::fmadd(F16::fmadd(vk[2*j+1], v2, vs1), v4, vs2); + // } + // } + // for (int j = 0; j < q_step; ++j) { + // auto R = qkv_cache + D*j; + // F16::store(R + F16::block_size*(i + 0), vk[2*j+0]); + // F16::store(R + F16::block_size*(i + 1), vk[2*j+1]); + // } + //} } template = 2>>