Tidy up FlashMS

This commit is contained in:
Iwan Kawrakow
2024-09-12 14:50:50 +03:00
parent 7c4bc981dc
commit 0a17ff156f

View File

@@ -7209,8 +7209,51 @@ struct FlashMS {
}
}
inline void update_M(int j, float smax) {
if (smax == -INFINITY) {
std::memset(cache + k_step*j, 0, k_step*sizeof(float));
need_scaling[j] = M[j] == -INFINITY ? 2 : 0;
return;
}
need_scaling[j] = 0;
if (smax > M[j]) {
if (M[j] > -INFINITY) {
float m = expf(M[j] - smax);
vms[j] = F16::set1(m);
need_scaling[j] = 1;
S[j] *= m;
} else {
need_scaling[j] = 2;
S[j] = 0;
}
M[j] = smax;
}
}
#ifdef __aarch64__
inline void update_M_S(int j, float32x4_t * vk) {
inline void update_S(int j, float32x4_t * vk) {
auto vm = vdupq_n_f32(M[j]);
auto vsum = vdupq_n_f32(0);
for (int l = 0; l < k_step/4; ++l) {
vk[l] = v_expf(vsubq_f32(vk[l], vm));
vsum = vaddq_f32(vsum, vk[l]);
F16::store(cache + k_step*j + 4*l, vk[l]);
}
S[j] += vaddvq_f32(vsum);
}
#else
inline void update_S(int j, F16::Data * vk) {
auto vm = F16::set1(M[j]);
for (int l = 0; l < k_step/F16::block_size; ++l) {
vk[l] = v_expf(F16::sub(vk[l], vm));
F16::store(cache + k_step*j + F16::block_size*l, vk[l]);
}
S[j] += F16::reduce_add<k_step>(vk);
}
#endif
#ifdef __aarch64__
inline float load_and_scale(int j, float32x4_t * vk) {
float32x4_t vmax = vdupq_n_f32(-INFINITY);
// Something goes wrong when storing and manipulating K*Q as fp16.
// It works for some models (e.g., Gemma-2), but not for others (e.g., LLaMA-3.1-8B).
@@ -7248,37 +7291,9 @@ struct FlashMS {
vmax = vmaxq_f32(vmax, vk[l]);
}
}
float smax = vmaxvq_f32(vmax);
if (smax == -INFINITY) {
std::memset(cache + k_step*j, 0, k_step*sizeof(float));
need_scaling[j] = M[j] == -INFINITY ? 2 : 0;
return;
}
need_scaling[j] = 0;
if (smax > M[j]) {
if (M[j] > -INFINITY) {
float m = expf(M[j] - smax);
vms[j] = F16::set1(m);
need_scaling[j] = 1;
S[j] *= m;
} else {
need_scaling[j] = 2;
S[j] = 0;
}
M[j] = smax;
}
auto vm = vdupq_n_f32(M[j]);
auto vsum = vdupq_n_f32(0);
for (int l = 0; l < k_step/4; ++l) {
vk[l] = v_expf(vsubq_f32(vk[l], vm));
vsum = vaddq_f32(vsum, vk[l]);
F16::store(cache + k_step*j + 4*l, vk[l]);
}
S[j] += vaddvq_f32(vsum);
return vmaxvq_f32(vmax);
}
inline void update_M_S(int j, float32x4_t * vk, const char * mask) {
{
inline float load_apply_mask_and_scale(int j, float32x4_t * vk, const char * mask) {
auto vzero = vdupq_n_f32(0);
auto vinf = vdupq_n_f32(-INFINITY);
for (int l = 0; l < k_step/8; ++l) {
@@ -7291,7 +7306,6 @@ struct FlashMS {
vk[2*l+1] = vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(kq.val[1]), vm2),
vbicq_u32(vinf, vm2)));
}
}
float32x4_t vmax = vdupq_n_f32(-INFINITY);
auto vscale32 = vcvt_f32_f16(vget_low_f16(vscale));
if (softcap <= 0.0f) {
@@ -7307,37 +7321,10 @@ struct FlashMS {
vmax = vmaxq_f32(vmax, vk[l]);
}
}
float smax = vmaxvq_f32(vmax);
if (smax == -INFINITY) {
std::memset(cache + k_step*j, 0, k_step*sizeof(float));
need_scaling[j] = M[j] == -INFINITY ? 2 : 0;
return;
}
need_scaling[j] = 0;
if (smax > M[j]) {
if (M[j] > -INFINITY) {
float m = expf(M[j] - smax);
vms[j] = F16::set1(m);
need_scaling[j] = 1;
S[j] *= m;
} else {
need_scaling[j] = 2;
S[j] = 0;
}
M[j] = smax;
}
auto vm = vdupq_n_f32(M[j]);
auto vsum = vdupq_n_f32(0);
for (int l = 0; l < k_step/4; ++l) {
vk[l] = v_expf(vsubq_f32(vk[l], vm));
vsum = vaddq_f32(vsum, vk[l]);
F16::store(cache + k_step*j + 4*l, vk[l]);
}
S[j] += vaddvq_f32(vsum);
return vmaxvq_f32(vmax);
}
#else
inline void update_M_S(int j, F16::Data * vk) {
inline float load_and_scale(int j, F16::Data * vk) {
if (softcap <= 0.0f) {
for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l));
} else {
@@ -7347,34 +7334,9 @@ struct FlashMS {
vk[l] = F16::mul(v_softcap, v_tanh(F16::mul(vscale, val)));
}
}
float smax = F16::reduce_max<k_step>(vk);
if (smax == -INFINITY) {
std::memset(cache + k_step*j, 0, k_step*sizeof(float));
need_scaling[j] = M[j] == -INFINITY ? 2 : 0;
return;
}
need_scaling[j] = 0;
if (smax > M[j]) {
if (M[j] > -INFINITY) {
float m = expf(M[j] - smax);
vms[j] = F16::set1(m);
need_scaling[j] = 1;
S[j] *= m;
} else {
need_scaling[j] = 2;
S[j] = 0;
}
M[j] = smax;
}
auto vm = F16::set1(M[j]);
for (int l = 0; l < k_step/F16::block_size; ++l) {
vk[l] = v_expf(F16::sub(vk[l], vm));
F16::store(cache + k_step*j + F16::block_size*l, vk[l]);
}
S[j] += F16::reduce_add<k_step>(vk);
return F16::reduce_max<k_step>(vk);
}
inline void update_M_S(int j, F16::Data * vk, const char * mask) {
inline float load_apply_mask_and_scale(int j, F16::Data * vk, const char * mask) {
#ifdef HAVE_FANCY_SIMD
auto vzero = _mm256_set1_epi16(0);
auto vinf = _mm512_set1_ps(-INFINITY);
@@ -7408,32 +7370,31 @@ struct FlashMS {
for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(v_softcap, v_tanh(F16::mul(vscale, vk[l])));
}
#endif
return F16::reduce_max<k_step>(vk);
}
#endif
float smax = F16::reduce_max<k_step>(vk);
if (smax == -INFINITY) {
std::memset(cache + k_step*j, 0, k_step*sizeof(float));
need_scaling[j] = M[j] == -INFINITY ? 2 : 0;
return;
}
need_scaling[j] = 0;
if (smax > M[j]) {
if (M[j] > -INFINITY) {
float m = expf(M[j] - smax);
vms[j] = F16::set1(m);
need_scaling[j] = 1;
S[j] *= m;
} else {
need_scaling[j] = 2;
S[j] = 0;
}
M[j] = smax;
}
auto vm = F16::set1(M[j]);
for (int l = 0; l < k_step/F16::block_size; ++l) {
vk[l] = v_expf(F16::sub(vk[l], vm));
F16::store(cache + k_step*j + F16::block_size*l, vk[l]);
}
S[j] += F16::reduce_add<k_step>(vk);
#ifdef __aarch64__
inline void update_M_S(int j, float32x4_t * vk) {
float smax = load_and_scale(j, vk);
update_M(j, smax);
update_S(j, vk);
}
inline void update_M_S(int j, float32x4_t * vk, const char * mask) {
float smax = load_apply_mask_and_scale(j, vk, mask);
update_M(j, smax);
update_S(j, vk);
}
#else
inline void update_M_S(int j, F16::Data * vk) {
float smax = load_and_scale(j, vk);
update_M(j, smax);
update_S(j, vk);
}
inline void update_M_S(int j, F16::Data * vk, const char * mask) {
float smax = load_apply_mask_and_scale(j, vk, mask);
update_M(j, smax);
update_S(j, vk);
}
#endif