mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 23:24:13 +00:00
FA: slightly faster V*softmax(K*Q)) on Zen4
This commit is contained in:
@@ -12354,6 +12354,15 @@ struct F16 {
|
||||
static inline float reduce_add(Data data) { return _mm512_reduce_add_ps(data); }
|
||||
static inline Data max(Data v1, Data v2) { return _mm512_max_ps(v1, v2); }
|
||||
static inline Data add(Data v1, Data v2) { return _mm512_add_ps(v1, v2); }
|
||||
static inline Data set4(const float * ptr) {
|
||||
auto v128 = _mm_loadu_ps(ptr);
|
||||
auto v256 = _mm256_set_m128(v128, v128);
|
||||
return _mm512_insertf32x8(_mm512_castps256_ps512(v256), v256, 1);
|
||||
}
|
||||
static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x00), prev); }
|
||||
static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x55), prev); }
|
||||
static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0xaa), prev); }
|
||||
static inline Data fmadd_lane3(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0xff), prev); }
|
||||
#elif defined __AVX2__
|
||||
using Data = __m256;
|
||||
constexpr static int block_size = 8;
|
||||
@@ -12371,6 +12380,14 @@ struct F16 {
|
||||
static inline float reduce_add(Data data) { return hsum_float_8(data); }
|
||||
static inline Data max(Data v1, Data v2) { return _mm256_max_ps(v1, v2); }
|
||||
static inline Data add(Data v1, Data v2) { return _mm256_add_ps(v1, v2); }
|
||||
static inline Data set4(const float * ptr) {
|
||||
auto v128 = _mm_loadu_ps(ptr);
|
||||
return _mm256_set_m128(v128, v128);
|
||||
}
|
||||
static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x00), prev); }
|
||||
static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x55), prev); }
|
||||
static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0xaa), prev); }
|
||||
static inline Data fmadd_lane3(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0xff), prev); }
|
||||
#else
|
||||
using Data = float16x8_t;
|
||||
constexpr static int block_size = 8;
|
||||
@@ -12402,6 +12419,15 @@ struct F16 {
|
||||
}
|
||||
static inline Data max(Data v1, Data v2) { return vmaxq_f16(v1, v2); }
|
||||
static inline Data add(Data v1, Data v2) { return vaddq_f16(v1, v2); }
|
||||
static inline Data set4(const float * ptr) {
|
||||
auto val32 = vld1q_f32(ptr);
|
||||
auto val16 = vcvt_f16_f32(val32);
|
||||
return vcombine_f16(val16, val16);
|
||||
}
|
||||
static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return vfmaq_laneq_f16(vfmaq_laneq_f16(prev, v1, v2, 0), v1, v2, 4); }
|
||||
static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return vfmaq_laneq_f16(vfmaq_laneq_f16(prev, v1, v2, 1), v1, v2, 5); }
|
||||
static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return vfmaq_laneq_f16(vfmaq_laneq_f16(prev, v1, v2, 2), v1, v2, 6); }
|
||||
static inline Data fmadd_lane3(Data prev, Data v1, Data v2) { return vfmaq_laneq_f16(vfmaq_laneq_f16(prev, v1, v2, 3), v1, v2, 7); }
|
||||
#endif
|
||||
template <int k_step> static inline float reduce_max(const Data * data) {
|
||||
return reduce_T<k_step, &F16::max, &F16::reduce_max>(data);
|
||||
@@ -12927,38 +12953,129 @@ struct FlashQKV {
|
||||
// Hence, for now, we will not handle head sizes of 80 and 112
|
||||
template <typename VHelper>
|
||||
inline void accumulate_qkv(const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
|
||||
F16::Data vk[2*q_step];
|
||||
for (int i = 0; i < D/F16::block_size; i += 2) {
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
if (fms.need_scaling[j] == 2) {
|
||||
vk[2*j+0] = vk[2*j+1] = F16::zero();
|
||||
} else {
|
||||
auto R = qkv_cache + D*j;
|
||||
vk[2*j+0] = F16::load(R + F16::block_size*i);
|
||||
vk[2*j+1] = F16::load(R + F16::block_size*(i + 1));
|
||||
if (fms.need_scaling[j] == 1) {
|
||||
vk[2*j+0] = F16::mul(vk[2*j+0], fms.vms[j]);
|
||||
vk[2*j+1] = F16::mul(vk[2*j+1], fms.vms[j]);
|
||||
}
|
||||
}
|
||||
F16::Data v[8];
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto R = qkv_cache + D*j;
|
||||
if (fms.need_scaling[j] == 2) {
|
||||
std::memset(R, 0, D*sizeof(qkv_cache_t));
|
||||
}
|
||||
F16::Data v1, v2, v3, v4;
|
||||
for (int l1 = 0; l1 < k_step; l1 += 2) {
|
||||
vh.load(l1+0, i, v1, v2);
|
||||
vh.load(l1+1, i, v3, v4);
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto vs1 = F16::set1(fms.cache[k_step*j + l1+0]);
|
||||
auto vs2 = F16::set1(fms.cache[k_step*j + l1+1]);
|
||||
vk[2*j+0] = F16::fmadd(F16::fmadd(vk[2*j+0], v1, vs1), v3, vs2);
|
||||
vk[2*j+1] = F16::fmadd(F16::fmadd(vk[2*j+1], v2, vs1), v4, vs2);
|
||||
else if (fms.need_scaling[j] == 1) {
|
||||
for (int i = 0; i < D/F16::block_size; ++i) {
|
||||
F16::store(R + F16::block_size*i, F16::mul(fms.vms[j], F16::load(R + F16::block_size*i)));
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto R = qkv_cache + D*j;
|
||||
F16::store(R + F16::block_size*(i + 0), vk[2*j+0]);
|
||||
F16::store(R + F16::block_size*(i + 1), vk[2*j+1]);
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < D/F16::block_size; i += 2) {
|
||||
for (int l = 0; l < k_step; l += 4) {
|
||||
vh.load(l+0, i, v[0], v[4]);
|
||||
vh.load(l+1, i, v[1], v[5]);
|
||||
vh.load(l+2, i, v[2], v[6]);
|
||||
vh.load(l+3, i, v[3], v[7]);
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto R = qkv_cache + D*j;
|
||||
auto s1 = F16::load(R + F16::block_size*(i+0));
|
||||
auto s2 = F16::load(R + F16::block_size*(i+1));
|
||||
auto vs = F16::set4(fms.cache + k_step*j + l);
|
||||
s1 = F16::fmadd_lane0(s1, v[0], vs);
|
||||
s2 = F16::fmadd_lane0(s2, v[4], vs);
|
||||
s1 = F16::fmadd_lane1(s1, v[1], vs);
|
||||
s2 = F16::fmadd_lane1(s2, v[5], vs);
|
||||
s1 = F16::fmadd_lane2(s1, v[2], vs);
|
||||
s2 = F16::fmadd_lane2(s2, v[6], vs);
|
||||
s1 = F16::fmadd_lane3(s1, v[3], vs);
|
||||
s2 = F16::fmadd_lane3(s2, v[7], vs);
|
||||
F16::store(R + F16::block_size*(i+0), s1);
|
||||
F16::store(R + F16::block_size*(i+1), s2);
|
||||
}
|
||||
}
|
||||
}
|
||||
//F16::Data vk[2*q_step];
|
||||
//for (int i = 0; i < D/F16::block_size; i += 2) {
|
||||
// for (int j = 0; j < q_step; ++j) {
|
||||
// if (fms.need_scaling[j] == 2) {
|
||||
// vk[2*j+0] = vk[2*j+1] = F16::zero();
|
||||
// } else {
|
||||
// auto R = qkv_cache + D*j;
|
||||
// vk[2*j+0] = F16::load(R + F16::block_size*i);
|
||||
// vk[2*j+1] = F16::load(R + F16::block_size*(i + 1));
|
||||
// if (fms.need_scaling[j] == 1) {
|
||||
// vk[2*j+0] = F16::mul(vk[2*j+0], fms.vms[j]);
|
||||
// vk[2*j+1] = F16::mul(vk[2*j+1], fms.vms[j]);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// // R[j][i] += sum[l, V[l][i] * C[j][l] ]
|
||||
// // If we transpose V, we can accumulate as
|
||||
// // R[i][j] += sum [l, V[i][l] * C[j][l] ]
|
||||
// // or
|
||||
// // R[j][i] += sum [l, V[i][l] * C[j][l] ]
|
||||
// // so
|
||||
// // for (int j = 0; j < q_step; ++j) {
|
||||
// // for (int i = 0; i < D; ++i) {
|
||||
// // for (int l = 0; l < k_step; ++k) {
|
||||
// // R[j][i] += V[i][l] * C[j][l];
|
||||
// // }
|
||||
// // }
|
||||
// // }
|
||||
// // so
|
||||
// // for (int j = 0; j < q_step; ++j) {
|
||||
// // for (int i = 0; i < D/F16::block_size; ++i) {
|
||||
// // auto acc = F16::load(qkv_cache + D*j + F16::block_size*i);
|
||||
// // for (int l = 0; l < k_step/F16::block_size; ++k) {
|
||||
// // auto vk = vh.load1(j, l);
|
||||
// // auto vs = F16::load(fms.cache(k_step*j + l*F16::block_size];
|
||||
// // acc = F16::fmadd(acc, vk, vs);
|
||||
// // }
|
||||
// // }
|
||||
// // }
|
||||
// // but for k_step = 32, q_step = 8, F16::bloc = 32, we need only 16 registers to load the entire fms.cache
|
||||
// // F16::Data C[4*k_step/F16::block_size];
|
||||
// // F16::Data V[4*k_step/F16::block_size];
|
||||
// // F16::Data acc[4*k_step/F16::block_size];
|
||||
// // for (int j = 0; j < q_step; j += 4) {
|
||||
// // for (int l = 0; l < k_step/F16::block_size; ++l) {
|
||||
// // C[(k_step/F16::block_size)*(j+0) + l] = F16::load(fms.cache + k_step*(j+0) + l*F16::block_size);
|
||||
// // C[(k_step/F16::block_size)*(j+1) + l] = F16::load(fms.cache + k_step*(j+1) + l*F16::block_size);
|
||||
// // C[(k_step/F16::block_size)*(j+2) + l] = F16::load(fms.cache + k_step*(j+2) + l*F16::block_size);
|
||||
// // C[(k_step/F16::block_size)*(j+3) + l] = F16::load(fms.cache + k_step*(j+3) + l*F16::block_size);
|
||||
// // V[(k_step/F16::block_size)*(j+0) + l] = vh.load(j+0, l);
|
||||
// // V[(k_step/F16::block_size)*(j+1) + l] = vh.load(j+1, l);
|
||||
// // V[(k_step/F16::block_size)*(j+2) + l] = vh.load(j+2, l);
|
||||
// // V[(k_step/F16::block_size)*(j+3) + l] = vh.load(j+3, l);
|
||||
// // }
|
||||
// // for (int i = 0; i < D/F16::block_size; ++i) {
|
||||
// // auto acc = F16::load(qkv_cache + D*(j+0) + F16::block_size*i);
|
||||
// // for (int l = 0; l < k_step/F16::block_size; ++l) {
|
||||
// // acc = F16::fmadd(acc, C[(k_step/F16::block_size)*(j+0)+l], V[(k_step/F16::block_size)*(j+0)+l]);
|
||||
// // acc1 = F16::fmadd(F16::fmadd(acc1, C[(k_step/F16::block_size)*(j+0),
|
||||
// // acc[0] = F16::load(qkv_cache + D*(j+0) + F16::block_size*(i+0));
|
||||
// // acc[1] = F16::load(qkv_cache + D*(j+0) + F16::block_size*(i+1));
|
||||
// // acc[2] = F16::load(qkv_cache + D*(j+1) + F16::block_size*(i+0));
|
||||
// // acc[3] = F16::load(qkv_cache + D*(j+1) + F16::block_size*(i+1));
|
||||
// // acc[4] = F16::load(qkv_cache + D*(j+2) + F16::block_size*(i+0));
|
||||
// // acc[5] = F16::load(qkv_cache + D*(j+2) + F16::block_size*(i+1));
|
||||
// // acc[6] = F16::load(qkv_cache + D*(j+3) + F16::block_size*(i+0));
|
||||
// // acc[7] = F16::load(qkv_cache + D*(j+3) + F16::block_size*(i+1));
|
||||
// //
|
||||
// // }
|
||||
// // }
|
||||
// F16::Data v1, v2, v3, v4;
|
||||
// for (int l1 = 0; l1 < k_step; l1 += 2) {
|
||||
// vh.load(l1+0, i, v1, v2);
|
||||
// vh.load(l1+1, i, v3, v4);
|
||||
// for (int j = 0; j < q_step; ++j) {
|
||||
// auto vs1 = F16::set1(fms.cache[k_step*j + l1+0]);
|
||||
// auto vs2 = F16::set1(fms.cache[k_step*j + l1+1]);
|
||||
// vk[2*j+0] = F16::fmadd(F16::fmadd(vk[2*j+0], v1, vs1), v3, vs2);
|
||||
// vk[2*j+1] = F16::fmadd(F16::fmadd(vk[2*j+1], v2, vs1), v4, vs2);
|
||||
// }
|
||||
// }
|
||||
// for (int j = 0; j < q_step; ++j) {
|
||||
// auto R = qkv_cache + D*j;
|
||||
// F16::store(R + F16::block_size*(i + 0), vk[2*j+0]);
|
||||
// F16::store(R + F16::block_size*(i + 1), vk[2*j+1]);
|
||||
// }
|
||||
//}
|
||||
}
|
||||
|
||||
template <typename VHelper, int Nq = q_step, class = std::enable_if<Nq >= 2>>
|
||||
|
||||
Reference in New Issue
Block a user