FA: slightly faster V*softmax(K*Q)) on Zen4

This commit is contained in:
Iwan Kawrakow
2025-01-16 17:10:42 +02:00
parent 0b74397d59
commit 4753c861d1

View File

@@ -12354,6 +12354,15 @@ struct F16 {
static inline float reduce_add(Data data) { return _mm512_reduce_add_ps(data); }
static inline Data max(Data v1, Data v2) { return _mm512_max_ps(v1, v2); }
static inline Data add(Data v1, Data v2) { return _mm512_add_ps(v1, v2); }
static inline Data set4(const float * ptr) {
auto v128 = _mm_loadu_ps(ptr);
auto v256 = _mm256_set_m128(v128, v128);
return _mm512_insertf32x8(_mm512_castps256_ps512(v256), v256, 1);
}
static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x00), prev); }
static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x55), prev); }
static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0xaa), prev); }
static inline Data fmadd_lane3(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0xff), prev); }
#elif defined __AVX2__
using Data = __m256;
constexpr static int block_size = 8;
@@ -12371,6 +12380,14 @@ struct F16 {
static inline float reduce_add(Data data) { return hsum_float_8(data); }
static inline Data max(Data v1, Data v2) { return _mm256_max_ps(v1, v2); }
static inline Data add(Data v1, Data v2) { return _mm256_add_ps(v1, v2); }
static inline Data set4(const float * ptr) {
auto v128 = _mm_loadu_ps(ptr);
return _mm256_set_m128(v128, v128);
}
static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x00), prev); }
static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x55), prev); }
static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0xaa), prev); }
static inline Data fmadd_lane3(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0xff), prev); }
#else
using Data = float16x8_t;
constexpr static int block_size = 8;
@@ -12402,6 +12419,15 @@ struct F16 {
}
static inline Data max(Data v1, Data v2) { return vmaxq_f16(v1, v2); }
static inline Data add(Data v1, Data v2) { return vaddq_f16(v1, v2); }
static inline Data set4(const float * ptr) {
auto val32 = vld1q_f32(ptr);
auto val16 = vcvt_f16_f32(val32);
return vcombine_f16(val16, val16);
}
static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return vfmaq_laneq_f16(vfmaq_laneq_f16(prev, v1, v2, 0), v1, v2, 4); }
static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return vfmaq_laneq_f16(vfmaq_laneq_f16(prev, v1, v2, 1), v1, v2, 5); }
static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return vfmaq_laneq_f16(vfmaq_laneq_f16(prev, v1, v2, 2), v1, v2, 6); }
static inline Data fmadd_lane3(Data prev, Data v1, Data v2) { return vfmaq_laneq_f16(vfmaq_laneq_f16(prev, v1, v2, 3), v1, v2, 7); }
#endif
template <int k_step> static inline float reduce_max(const Data * data) {
return reduce_T<k_step, &F16::max, &F16::reduce_max>(data);
@@ -12927,38 +12953,129 @@ struct FlashQKV {
// Hence, for now, we will not handle head sizes of 80 and 112
template <typename VHelper>
inline void accumulate_qkv(const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
F16::Data vk[2*q_step];
for (int i = 0; i < D/F16::block_size; i += 2) {
for (int j = 0; j < q_step; ++j) {
if (fms.need_scaling[j] == 2) {
vk[2*j+0] = vk[2*j+1] = F16::zero();
} else {
auto R = qkv_cache + D*j;
vk[2*j+0] = F16::load(R + F16::block_size*i);
vk[2*j+1] = F16::load(R + F16::block_size*(i + 1));
if (fms.need_scaling[j] == 1) {
vk[2*j+0] = F16::mul(vk[2*j+0], fms.vms[j]);
vk[2*j+1] = F16::mul(vk[2*j+1], fms.vms[j]);
}
}
F16::Data v[8];
for (int j = 0; j < q_step; ++j) {
auto R = qkv_cache + D*j;
if (fms.need_scaling[j] == 2) {
std::memset(R, 0, D*sizeof(qkv_cache_t));
}
F16::Data v1, v2, v3, v4;
for (int l1 = 0; l1 < k_step; l1 += 2) {
vh.load(l1+0, i, v1, v2);
vh.load(l1+1, i, v3, v4);
for (int j = 0; j < q_step; ++j) {
auto vs1 = F16::set1(fms.cache[k_step*j + l1+0]);
auto vs2 = F16::set1(fms.cache[k_step*j + l1+1]);
vk[2*j+0] = F16::fmadd(F16::fmadd(vk[2*j+0], v1, vs1), v3, vs2);
vk[2*j+1] = F16::fmadd(F16::fmadd(vk[2*j+1], v2, vs1), v4, vs2);
else if (fms.need_scaling[j] == 1) {
for (int i = 0; i < D/F16::block_size; ++i) {
F16::store(R + F16::block_size*i, F16::mul(fms.vms[j], F16::load(R + F16::block_size*i)));
}
}
for (int j = 0; j < q_step; ++j) {
auto R = qkv_cache + D*j;
F16::store(R + F16::block_size*(i + 0), vk[2*j+0]);
F16::store(R + F16::block_size*(i + 1), vk[2*j+1]);
}
}
for (int i = 0; i < D/F16::block_size; i += 2) {
for (int l = 0; l < k_step; l += 4) {
vh.load(l+0, i, v[0], v[4]);
vh.load(l+1, i, v[1], v[5]);
vh.load(l+2, i, v[2], v[6]);
vh.load(l+3, i, v[3], v[7]);
for (int j = 0; j < q_step; ++j) {
auto R = qkv_cache + D*j;
auto s1 = F16::load(R + F16::block_size*(i+0));
auto s2 = F16::load(R + F16::block_size*(i+1));
auto vs = F16::set4(fms.cache + k_step*j + l);
s1 = F16::fmadd_lane0(s1, v[0], vs);
s2 = F16::fmadd_lane0(s2, v[4], vs);
s1 = F16::fmadd_lane1(s1, v[1], vs);
s2 = F16::fmadd_lane1(s2, v[5], vs);
s1 = F16::fmadd_lane2(s1, v[2], vs);
s2 = F16::fmadd_lane2(s2, v[6], vs);
s1 = F16::fmadd_lane3(s1, v[3], vs);
s2 = F16::fmadd_lane3(s2, v[7], vs);
F16::store(R + F16::block_size*(i+0), s1);
F16::store(R + F16::block_size*(i+1), s2);
}
}
}
//F16::Data vk[2*q_step];
//for (int i = 0; i < D/F16::block_size; i += 2) {
// for (int j = 0; j < q_step; ++j) {
// if (fms.need_scaling[j] == 2) {
// vk[2*j+0] = vk[2*j+1] = F16::zero();
// } else {
// auto R = qkv_cache + D*j;
// vk[2*j+0] = F16::load(R + F16::block_size*i);
// vk[2*j+1] = F16::load(R + F16::block_size*(i + 1));
// if (fms.need_scaling[j] == 1) {
// vk[2*j+0] = F16::mul(vk[2*j+0], fms.vms[j]);
// vk[2*j+1] = F16::mul(vk[2*j+1], fms.vms[j]);
// }
// }
// }
// // R[j][i] += sum[l, V[l][i] * C[j][l] ]
// // If we transpose V, we can accumulate as
// // R[i][j] += sum [l, V[i][l] * C[j][l] ]
// // or
// // R[j][i] += sum [l, V[i][l] * C[j][l] ]
// // so
// // for (int j = 0; j < q_step; ++j) {
// // for (int i = 0; i < D; ++i) {
// // for (int l = 0; l < k_step; ++k) {
// // R[j][i] += V[i][l] * C[j][l];
// // }
// // }
// // }
// // so
// // for (int j = 0; j < q_step; ++j) {
// // for (int i = 0; i < D/F16::block_size; ++i) {
// // auto acc = F16::load(qkv_cache + D*j + F16::block_size*i);
// // for (int l = 0; l < k_step/F16::block_size; ++k) {
// // auto vk = vh.load1(j, l);
// // auto vs = F16::load(fms.cache(k_step*j + l*F16::block_size];
// // acc = F16::fmadd(acc, vk, vs);
// // }
// // }
// // }
// // but for k_step = 32, q_step = 8, F16::bloc = 32, we need only 16 registers to load the entire fms.cache
// // F16::Data C[4*k_step/F16::block_size];
// // F16::Data V[4*k_step/F16::block_size];
// // F16::Data acc[4*k_step/F16::block_size];
// // for (int j = 0; j < q_step; j += 4) {
// // for (int l = 0; l < k_step/F16::block_size; ++l) {
// // C[(k_step/F16::block_size)*(j+0) + l] = F16::load(fms.cache + k_step*(j+0) + l*F16::block_size);
// // C[(k_step/F16::block_size)*(j+1) + l] = F16::load(fms.cache + k_step*(j+1) + l*F16::block_size);
// // C[(k_step/F16::block_size)*(j+2) + l] = F16::load(fms.cache + k_step*(j+2) + l*F16::block_size);
// // C[(k_step/F16::block_size)*(j+3) + l] = F16::load(fms.cache + k_step*(j+3) + l*F16::block_size);
// // V[(k_step/F16::block_size)*(j+0) + l] = vh.load(j+0, l);
// // V[(k_step/F16::block_size)*(j+1) + l] = vh.load(j+1, l);
// // V[(k_step/F16::block_size)*(j+2) + l] = vh.load(j+2, l);
// // V[(k_step/F16::block_size)*(j+3) + l] = vh.load(j+3, l);
// // }
// // for (int i = 0; i < D/F16::block_size; ++i) {
// // auto acc = F16::load(qkv_cache + D*(j+0) + F16::block_size*i);
// // for (int l = 0; l < k_step/F16::block_size; ++l) {
// // acc = F16::fmadd(acc, C[(k_step/F16::block_size)*(j+0)+l], V[(k_step/F16::block_size)*(j+0)+l]);
// // acc1 = F16::fmadd(F16::fmadd(acc1, C[(k_step/F16::block_size)*(j+0),
// // acc[0] = F16::load(qkv_cache + D*(j+0) + F16::block_size*(i+0));
// // acc[1] = F16::load(qkv_cache + D*(j+0) + F16::block_size*(i+1));
// // acc[2] = F16::load(qkv_cache + D*(j+1) + F16::block_size*(i+0));
// // acc[3] = F16::load(qkv_cache + D*(j+1) + F16::block_size*(i+1));
// // acc[4] = F16::load(qkv_cache + D*(j+2) + F16::block_size*(i+0));
// // acc[5] = F16::load(qkv_cache + D*(j+2) + F16::block_size*(i+1));
// // acc[6] = F16::load(qkv_cache + D*(j+3) + F16::block_size*(i+0));
// // acc[7] = F16::load(qkv_cache + D*(j+3) + F16::block_size*(i+1));
// //
// // }
// // }
// F16::Data v1, v2, v3, v4;
// for (int l1 = 0; l1 < k_step; l1 += 2) {
// vh.load(l1+0, i, v1, v2);
// vh.load(l1+1, i, v3, v4);
// for (int j = 0; j < q_step; ++j) {
// auto vs1 = F16::set1(fms.cache[k_step*j + l1+0]);
// auto vs2 = F16::set1(fms.cache[k_step*j + l1+1]);
// vk[2*j+0] = F16::fmadd(F16::fmadd(vk[2*j+0], v1, vs1), v3, vs2);
// vk[2*j+1] = F16::fmadd(F16::fmadd(vk[2*j+1], v2, vs1), v4, vs2);
// }
// }
// for (int j = 0; j < q_step; ++j) {
// auto R = qkv_cache + D*j;
// F16::store(R + F16::block_size*(i + 0), vk[2*j+0]);
// F16::store(R + F16::block_size*(i + 1), vk[2*j+1]);
// }
//}
}
template <typename VHelper, int Nq = q_step, class = std::enable_if<Nq >= 2>>