mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 20:27:42 +00:00
fix bug
This commit is contained in:
@@ -193,6 +193,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
constexpr auto LdsSeq = Policy::template GetLdsBufferSequence<Problem>();
|
||||
|
||||
const float logits_cap = 30.0f;
|
||||
const float logits_cap_rev = 0.0333333f;
|
||||
// const float logits_cap_scale = scale_s * rcp<float>(logits_cap * log2e_v<float>);
|
||||
|
||||
// K tile in LDS
|
||||
@@ -439,10 +440,11 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
// #if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
float scale_lo = scale_s * 0.6931472f;
|
||||
// #if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout(
|
||||
[&logits_cap](auto& x) {
|
||||
x = log2e_v<float> * logits_cap * tanh_fast<float>(x * rcp<float>(logits_cap));
|
||||
[&scale_lo, &logits_cap, &logits_cap_rev](auto& x) {
|
||||
x = log2e_v<SaccDataType> * logits_cap * tanh_fast<SaccDataType>(x * scale_lo * logits_cap_rev);
|
||||
},
|
||||
s_acc
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user