FA: don't store sum scaling factor in SIMD registers

This commit is contained in:
Iwan Kawrakow
2025-01-18 09:44:32 +02:00
parent 0e8cfb3d78
commit 7efe16f715

View File

@@ -12861,7 +12861,7 @@ struct FlashMS {
if (smax > M[j]) {
if (M[j] > -INFINITY) {
float m = expf(M[j] - smax);
vms[j] = F16::set1(m);
vms[j] = m;
need_scaling[j] = 1;
S[j] *= m;
} else {
@@ -13043,7 +13043,7 @@ struct FlashMS {
cache_t cache[q_step*k_step];
float S[q_step], M[q_step];
int need_scaling[q_step];
F16::Data vms[q_step];
float vms[q_step];
const F16::Data vscale;
const float softcap;
const ggml_half h_inf;
@@ -13070,8 +13070,9 @@ struct FlashQKV {
std::memset(R, 0, D*sizeof(qkv_cache_t));
}
else if (fms.need_scaling[j] == 1) {
auto vms = F16::set1(fms.vms[j]);
for (int i = 0; i < D/F16::block_size; ++i) {
F16::store(R + F16::block_size*i, F16::mul(fms.vms[j], F16::load(R + F16::block_size*i)));
F16::store(R + F16::block_size*i, F16::mul(vms, F16::load(R + F16::block_size*i)));
}
}
}
@@ -13110,8 +13111,9 @@ struct FlashQKV {
std::memset(R, 0, D*sizeof(qkv_cache_t));
}
else if (fms.need_scaling[j] == 1) {
auto vms = F16::set1(fms.vms[j]);
for (int i = 0; i < D/F16::block_size; ++i) {
F16::store(R + F16::block_size*i, F16::mul(fms.vms[j], F16::load(R + F16::block_size*i)));
F16::store(R + F16::block_size*i, F16::mul(vms, F16::load(R + F16::block_size*i)));
}
}
}