This commit is contained in:
coderfeli
2025-04-18 10:45:53 +00:00
parent 838ff7034d
commit d2c653f177

View File

@@ -193,6 +193,7 @@ struct BlockFmhaPipelineQRKSVSAsync
constexpr auto LdsSeq = Policy::template GetLdsBufferSequence<Problem>();
const float logits_cap = 30.0f;
const float logits_cap_rev = 0.0333333f;
// const float logits_cap_scale = scale_s * rcp<float>(logits_cap * log2e_v<float>);
// K tile in LDS
@@ -439,10 +440,11 @@ struct BlockFmhaPipelineQRKSVSAsync
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
// #if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
float scale_lo = scale_s * 0.6931472f;
// #if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout(
[&logits_cap](auto& x) {
x = log2e_v<float> * logits_cap * tanh_fast<float>(x * rcp<float>(logits_cap));
[&scale_lo, &logits_cap, &logits_cap_rev](auto& x) {
x = log2e_v<SaccDataType> * logits_cap * tanh_fast<SaccDataType>(x * scale_lo * logits_cap_rev);
},
s_acc
);