mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-27 16:44:21 +00:00
Tidy up FlashMS
This commit is contained in:
@@ -7209,8 +7209,51 @@ struct FlashMS {
|
||||
}
|
||||
}
|
||||
|
||||
inline void update_M(int j, float smax) {
|
||||
if (smax == -INFINITY) {
|
||||
std::memset(cache + k_step*j, 0, k_step*sizeof(float));
|
||||
need_scaling[j] = M[j] == -INFINITY ? 2 : 0;
|
||||
return;
|
||||
}
|
||||
need_scaling[j] = 0;
|
||||
if (smax > M[j]) {
|
||||
if (M[j] > -INFINITY) {
|
||||
float m = expf(M[j] - smax);
|
||||
vms[j] = F16::set1(m);
|
||||
need_scaling[j] = 1;
|
||||
S[j] *= m;
|
||||
} else {
|
||||
need_scaling[j] = 2;
|
||||
S[j] = 0;
|
||||
}
|
||||
M[j] = smax;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __aarch64__
|
||||
inline void update_M_S(int j, float32x4_t * vk) {
|
||||
inline void update_S(int j, float32x4_t * vk) {
|
||||
auto vm = vdupq_n_f32(M[j]);
|
||||
auto vsum = vdupq_n_f32(0);
|
||||
for (int l = 0; l < k_step/4; ++l) {
|
||||
vk[l] = v_expf(vsubq_f32(vk[l], vm));
|
||||
vsum = vaddq_f32(vsum, vk[l]);
|
||||
F16::store(cache + k_step*j + 4*l, vk[l]);
|
||||
}
|
||||
S[j] += vaddvq_f32(vsum);
|
||||
}
|
||||
#else
|
||||
inline void update_S(int j, F16::Data * vk) {
|
||||
auto vm = F16::set1(M[j]);
|
||||
for (int l = 0; l < k_step/F16::block_size; ++l) {
|
||||
vk[l] = v_expf(F16::sub(vk[l], vm));
|
||||
F16::store(cache + k_step*j + F16::block_size*l, vk[l]);
|
||||
}
|
||||
S[j] += F16::reduce_add<k_step>(vk);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef __aarch64__
|
||||
inline float load_and_scale(int j, float32x4_t * vk) {
|
||||
float32x4_t vmax = vdupq_n_f32(-INFINITY);
|
||||
// Something goes wrong when storing and manipulating K*Q as fp16.
|
||||
// It works for some models (e.g., Gemma-2), but not for others (e.g., LLaMA-3.1-8B).
|
||||
@@ -7248,37 +7291,9 @@ struct FlashMS {
|
||||
vmax = vmaxq_f32(vmax, vk[l]);
|
||||
}
|
||||
}
|
||||
|
||||
float smax = vmaxvq_f32(vmax);
|
||||
if (smax == -INFINITY) {
|
||||
std::memset(cache + k_step*j, 0, k_step*sizeof(float));
|
||||
need_scaling[j] = M[j] == -INFINITY ? 2 : 0;
|
||||
return;
|
||||
}
|
||||
need_scaling[j] = 0;
|
||||
if (smax > M[j]) {
|
||||
if (M[j] > -INFINITY) {
|
||||
float m = expf(M[j] - smax);
|
||||
vms[j] = F16::set1(m);
|
||||
need_scaling[j] = 1;
|
||||
S[j] *= m;
|
||||
} else {
|
||||
need_scaling[j] = 2;
|
||||
S[j] = 0;
|
||||
}
|
||||
M[j] = smax;
|
||||
}
|
||||
auto vm = vdupq_n_f32(M[j]);
|
||||
auto vsum = vdupq_n_f32(0);
|
||||
for (int l = 0; l < k_step/4; ++l) {
|
||||
vk[l] = v_expf(vsubq_f32(vk[l], vm));
|
||||
vsum = vaddq_f32(vsum, vk[l]);
|
||||
F16::store(cache + k_step*j + 4*l, vk[l]);
|
||||
}
|
||||
S[j] += vaddvq_f32(vsum);
|
||||
return vmaxvq_f32(vmax);
|
||||
}
|
||||
inline void update_M_S(int j, float32x4_t * vk, const char * mask) {
|
||||
{
|
||||
inline float load_apply_mask_and_scale(int j, float32x4_t * vk, const char * mask) {
|
||||
auto vzero = vdupq_n_f32(0);
|
||||
auto vinf = vdupq_n_f32(-INFINITY);
|
||||
for (int l = 0; l < k_step/8; ++l) {
|
||||
@@ -7291,7 +7306,6 @@ struct FlashMS {
|
||||
vk[2*l+1] = vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(kq.val[1]), vm2),
|
||||
vbicq_u32(vinf, vm2)));
|
||||
}
|
||||
}
|
||||
float32x4_t vmax = vdupq_n_f32(-INFINITY);
|
||||
auto vscale32 = vcvt_f32_f16(vget_low_f16(vscale));
|
||||
if (softcap <= 0.0f) {
|
||||
@@ -7307,37 +7321,10 @@ struct FlashMS {
|
||||
vmax = vmaxq_f32(vmax, vk[l]);
|
||||
}
|
||||
}
|
||||
|
||||
float smax = vmaxvq_f32(vmax);
|
||||
if (smax == -INFINITY) {
|
||||
std::memset(cache + k_step*j, 0, k_step*sizeof(float));
|
||||
need_scaling[j] = M[j] == -INFINITY ? 2 : 0;
|
||||
return;
|
||||
}
|
||||
need_scaling[j] = 0;
|
||||
if (smax > M[j]) {
|
||||
if (M[j] > -INFINITY) {
|
||||
float m = expf(M[j] - smax);
|
||||
vms[j] = F16::set1(m);
|
||||
need_scaling[j] = 1;
|
||||
S[j] *= m;
|
||||
} else {
|
||||
need_scaling[j] = 2;
|
||||
S[j] = 0;
|
||||
}
|
||||
M[j] = smax;
|
||||
}
|
||||
auto vm = vdupq_n_f32(M[j]);
|
||||
auto vsum = vdupq_n_f32(0);
|
||||
for (int l = 0; l < k_step/4; ++l) {
|
||||
vk[l] = v_expf(vsubq_f32(vk[l], vm));
|
||||
vsum = vaddq_f32(vsum, vk[l]);
|
||||
F16::store(cache + k_step*j + 4*l, vk[l]);
|
||||
}
|
||||
S[j] += vaddvq_f32(vsum);
|
||||
return vmaxvq_f32(vmax);
|
||||
}
|
||||
#else
|
||||
inline void update_M_S(int j, F16::Data * vk) {
|
||||
inline float load_and_scale(int j, F16::Data * vk) {
|
||||
if (softcap <= 0.0f) {
|
||||
for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l));
|
||||
} else {
|
||||
@@ -7347,34 +7334,9 @@ struct FlashMS {
|
||||
vk[l] = F16::mul(v_softcap, v_tanh(F16::mul(vscale, val)));
|
||||
}
|
||||
}
|
||||
|
||||
float smax = F16::reduce_max<k_step>(vk);
|
||||
if (smax == -INFINITY) {
|
||||
std::memset(cache + k_step*j, 0, k_step*sizeof(float));
|
||||
need_scaling[j] = M[j] == -INFINITY ? 2 : 0;
|
||||
return;
|
||||
}
|
||||
need_scaling[j] = 0;
|
||||
if (smax > M[j]) {
|
||||
if (M[j] > -INFINITY) {
|
||||
float m = expf(M[j] - smax);
|
||||
vms[j] = F16::set1(m);
|
||||
need_scaling[j] = 1;
|
||||
S[j] *= m;
|
||||
} else {
|
||||
need_scaling[j] = 2;
|
||||
S[j] = 0;
|
||||
}
|
||||
M[j] = smax;
|
||||
}
|
||||
auto vm = F16::set1(M[j]);
|
||||
for (int l = 0; l < k_step/F16::block_size; ++l) {
|
||||
vk[l] = v_expf(F16::sub(vk[l], vm));
|
||||
F16::store(cache + k_step*j + F16::block_size*l, vk[l]);
|
||||
}
|
||||
S[j] += F16::reduce_add<k_step>(vk);
|
||||
return F16::reduce_max<k_step>(vk);
|
||||
}
|
||||
inline void update_M_S(int j, F16::Data * vk, const char * mask) {
|
||||
inline float load_apply_mask_and_scale(int j, F16::Data * vk, const char * mask) {
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
auto vzero = _mm256_set1_epi16(0);
|
||||
auto vinf = _mm512_set1_ps(-INFINITY);
|
||||
@@ -7408,32 +7370,31 @@ struct FlashMS {
|
||||
for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(v_softcap, v_tanh(F16::mul(vscale, vk[l])));
|
||||
}
|
||||
#endif
|
||||
return F16::reduce_max<k_step>(vk);
|
||||
}
|
||||
#endif
|
||||
|
||||
float smax = F16::reduce_max<k_step>(vk);
|
||||
if (smax == -INFINITY) {
|
||||
std::memset(cache + k_step*j, 0, k_step*sizeof(float));
|
||||
need_scaling[j] = M[j] == -INFINITY ? 2 : 0;
|
||||
return;
|
||||
}
|
||||
need_scaling[j] = 0;
|
||||
if (smax > M[j]) {
|
||||
if (M[j] > -INFINITY) {
|
||||
float m = expf(M[j] - smax);
|
||||
vms[j] = F16::set1(m);
|
||||
need_scaling[j] = 1;
|
||||
S[j] *= m;
|
||||
} else {
|
||||
need_scaling[j] = 2;
|
||||
S[j] = 0;
|
||||
}
|
||||
M[j] = smax;
|
||||
}
|
||||
auto vm = F16::set1(M[j]);
|
||||
for (int l = 0; l < k_step/F16::block_size; ++l) {
|
||||
vk[l] = v_expf(F16::sub(vk[l], vm));
|
||||
F16::store(cache + k_step*j + F16::block_size*l, vk[l]);
|
||||
}
|
||||
S[j] += F16::reduce_add<k_step>(vk);
|
||||
#ifdef __aarch64__
|
||||
inline void update_M_S(int j, float32x4_t * vk) {
|
||||
float smax = load_and_scale(j, vk);
|
||||
update_M(j, smax);
|
||||
update_S(j, vk);
|
||||
}
|
||||
inline void update_M_S(int j, float32x4_t * vk, const char * mask) {
|
||||
float smax = load_apply_mask_and_scale(j, vk, mask);
|
||||
update_M(j, smax);
|
||||
update_S(j, vk);
|
||||
}
|
||||
#else
|
||||
inline void update_M_S(int j, F16::Data * vk) {
|
||||
float smax = load_and_scale(j, vk);
|
||||
update_M(j, smax);
|
||||
update_S(j, vk);
|
||||
}
|
||||
inline void update_M_S(int j, F16::Data * vk, const char * mask) {
|
||||
float smax = load_apply_mask_and_scale(j, vk, mask);
|
||||
update_M(j, smax);
|
||||
update_S(j, vk);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
Reference in New Issue
Block a user