Support turn on/off logits_soft_cap in async pipeline

This commit is contained in:
Po Yen Chen
2025-04-20 06:56:20 +00:00
committed by felix
parent b3829c11b5
commit 2632baa05b

View File

@@ -198,11 +198,6 @@ struct BlockFmhaPipelineQRKSVSAsync
constexpr auto LdsSeq = Policy::template GetLdsBufferSequence<Problem>();
(void)logits_soft_cap_params;
const float logits_cap = 30.0f;
const float logits_cap_rev = 0.0333333f;
// const float logits_cap_scale = scale_s * rcp<float>(logits_cap * log2e_v<float>);
// K tile in LDS
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
auto k_lds_store = generate_tuple(
@@ -446,16 +441,22 @@ struct BlockFmhaPipelineQRKSVSAsync
else
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
// #if !CK_TILE_FMHA_FWD_FAST_EXP2
float scale_lo = scale_s * 0.6931472f;
// #if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout(
[&scale_lo, &logits_cap, &logits_cap_rev](auto& x) {
x = log2e_v<SaccDataType> * logits_cap *
tanh_fast<SaccDataType>(x * scale_lo * logits_cap_rev);
},
s_acc);
// #endif
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#else
if constexpr(kHasLogitsSoftCap)
{
float scale_lo = scale_s * 0.6931472f;
tile_elementwise_inout(
[&scale_lo,
&logits_cap = logits_soft_cap_params.logits_soft_cap,
&logits_cap_rev = logits_soft_cap_params.logits_soft_cap_rcp](auto& x) {
x = log2e_v<SaccDataType> * logits_cap *
tanh_fast<SaccDataType>(x * scale_lo * logits_cap_rev);
},
s_acc);
}
#endif
}
move_tile_window(bias_dram_window, {0, kN0});
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
@@ -548,9 +549,9 @@ struct BlockFmhaPipelineQRKSVSAsync
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
// #if CK_TILE_FMHA_FWD_FAST_EXP2
// auto row_max = scale_s * get_validated_m(m[i_idx]);
// #endif
#if CK_TILE_FMHA_FWD_FAST_EXP2
[[maybe_unused]] auto row_max = scale_s * get_validated_m(m[i_idx]);
#endif
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
@@ -561,8 +562,14 @@ struct BlockFmhaPipelineQRKSVSAsync
}
else
{
// p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
if constexpr(kHasLogitsSoftCap)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
}
else
{
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
}
}
#else
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
@@ -587,9 +594,15 @@ struct BlockFmhaPipelineQRKSVSAsync
}
else
{
// auto row_max = scale_s * get_validated_m(m[i_idx]);
// return exp2(scale_s * m_old[i_idx] - row_max);
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
if constexpr(kHasLogitsSoftCap)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
else
{
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
}
}
}();
#else