Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-02-24 23:24:13 +00:00)
FA: don't store sum scaling factor in SIMD registers
This commit is contained in:
@@ -12861,7 +12861,7 @@ struct FlashMS {
|
||||
if (smax > M[j]) {
|
||||
if (M[j] > -INFINITY) {
|
||||
float m = expf(M[j] - smax);
|
||||
vms[j] = F16::set1(m);
|
||||
vms[j] = m;
|
||||
need_scaling[j] = 1;
|
||||
S[j] *= m;
|
||||
} else {
|
||||
@@ -13043,7 +13043,7 @@ struct FlashMS {
|
||||
cache_t cache[q_step*k_step];
|
||||
float S[q_step], M[q_step];
|
||||
int need_scaling[q_step];
|
||||
F16::Data vms[q_step];
|
||||
float vms[q_step];
|
||||
const F16::Data vscale;
|
||||
const float softcap;
|
||||
const ggml_half h_inf;
|
||||
@@ -13070,8 +13070,9 @@ struct FlashQKV {
|
||||
std::memset(R, 0, D*sizeof(qkv_cache_t));
|
||||
}
|
||||
else if (fms.need_scaling[j] == 1) {
|
||||
auto vms = F16::set1(fms.vms[j]);
|
||||
for (int i = 0; i < D/F16::block_size; ++i) {
|
||||
F16::store(R + F16::block_size*i, F16::mul(fms.vms[j], F16::load(R + F16::block_size*i)));
|
||||
F16::store(R + F16::block_size*i, F16::mul(vms, F16::load(R + F16::block_size*i)));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -13110,8 +13111,9 @@ struct FlashQKV {
|
||||
std::memset(R, 0, D*sizeof(qkv_cache_t));
|
||||
}
|
||||
else if (fms.need_scaling[j] == 1) {
|
||||
auto vms = F16::set1(fms.vms[j]);
|
||||
for (int i = 0; i < D/F16::block_size; ++i) {
|
||||
F16::store(R + F16::block_size*i, F16::mul(fms.vms[j], F16::load(R + F16::block_size*i)));
|
||||
F16::store(R + F16::block_size*i, F16::mul(vms, F16::load(R + F16::block_size*i)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user