diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 54e53d230c..8d33184442 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -198,11 +198,6 @@ struct BlockFmhaPipelineQRKSVSAsync constexpr auto LdsSeq = Policy::template GetLdsBufferSequence(); - (void)logits_soft_cap_params; - const float logits_cap = 30.0f; - const float logits_cap_rev = 0.0333333f; - // const float logits_cap_scale = scale_s * rcp(logits_cap * log2e_v); - // K tile in LDS auto k_lds_ptr = reinterpret_cast(smem_ptr); auto k_lds_store = generate_tuple( @@ -446,16 +441,22 @@ struct BlockFmhaPipelineQRKSVSAsync else { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); - // #if !CK_TILE_FMHA_FWD_FAST_EXP2 - float scale_lo = scale_s * 0.6931472f; - // #if !CK_TILE_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout( - [&scale_lo, &logits_cap, &logits_cap_rev](auto& x) { - x = log2e_v * logits_cap * - tanh_fast(x * scale_lo * logits_cap_rev); - }, - s_acc); - // #endif +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#else + if constexpr(kHasLogitsSoftCap) + { + float scale_lo = scale_s * 0.6931472f; + tile_elementwise_inout( + [&scale_lo, + &logits_cap = logits_soft_cap_params.logits_soft_cap, + &logits_cap_rev = logits_soft_cap_params.logits_soft_cap_rcp](auto& x) { + x = log2e_v * logits_cap * + tanh_fast(x * scale_lo * logits_cap_rev); + }, + s_acc); + } +#endif } move_tile_window(bias_dram_window, {0, kN0}); if constexpr(kPadSeqLenK || FmhaMask::IsMasking) @@ -548,9 +549,9 @@ struct BlockFmhaPipelineQRKSVSAsync constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); - // #if CK_TILE_FMHA_FWD_FAST_EXP2 - // auto row_max = scale_s * get_validated_m(m[i_idx]); - // #endif +#if CK_TILE_FMHA_FWD_FAST_EXP2 + [[maybe_unused]] auto row_max = scale_s * get_validated_m(m[i_idx]); +#endif sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); #if CK_TILE_FMHA_FWD_FAST_EXP2 @@ -561,8 +562,14 @@ struct BlockFmhaPipelineQRKSVSAsync } else { - // p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + if constexpr(kHasLogitsSoftCap) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + } } #else p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); @@ -587,9 +594,15 @@ struct BlockFmhaPipelineQRKSVSAsync } else { - // auto row_max = scale_s * get_validated_m(m[i_idx]); - // return exp2(scale_s * m_old[i_idx] - row_max); - return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + if constexpr(kHasLogitsSoftCap) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } } }(); #else