mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
Support turn on/off logits_soft_cap in async pipeline
This commit is contained in:
@@ -198,11 +198,6 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
|
||||
constexpr auto LdsSeq = Policy::template GetLdsBufferSequence<Problem>();
|
||||
|
||||
(void)logits_soft_cap_params;
|
||||
const float logits_cap = 30.0f;
|
||||
const float logits_cap_rev = 0.0333333f;
|
||||
// const float logits_cap_scale = scale_s * rcp<float>(logits_cap * log2e_v<float>);
|
||||
|
||||
// K tile in LDS
|
||||
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
|
||||
auto k_lds_store = generate_tuple(
|
||||
@@ -446,16 +441,22 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
else
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
// #if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
float scale_lo = scale_s * 0.6931472f;
|
||||
// #if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout(
|
||||
[&scale_lo, &logits_cap, &logits_cap_rev](auto& x) {
|
||||
x = log2e_v<SaccDataType> * logits_cap *
|
||||
tanh_fast<SaccDataType>(x * scale_lo * logits_cap_rev);
|
||||
},
|
||||
s_acc);
|
||||
// #endif
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
#else
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
float scale_lo = scale_s * 0.6931472f;
|
||||
tile_elementwise_inout(
|
||||
[&scale_lo,
|
||||
&logits_cap = logits_soft_cap_params.logits_soft_cap,
|
||||
&logits_cap_rev = logits_soft_cap_params.logits_soft_cap_rcp](auto& x) {
|
||||
x = log2e_v<SaccDataType> * logits_cap *
|
||||
tanh_fast<SaccDataType>(x * scale_lo * logits_cap_rev);
|
||||
},
|
||||
s_acc);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
@@ -548,9 +549,9 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
// #if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
// auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
// #endif
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
[[maybe_unused]] auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
#endif
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
@@ -561,8 +562,14 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
}
|
||||
else
|
||||
{
|
||||
// p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
|
||||
}
|
||||
}
|
||||
#else
|
||||
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
@@ -587,9 +594,15 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
}
|
||||
else
|
||||
{
|
||||
// auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
// return exp2(scale_s * m_old[i_idx] - row_max);
|
||||
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
return exp2(scale_s * m_old[i_idx] - row_max);
|
||||
}
|
||||
}
|
||||
}();
|
||||
#else
|
||||
|
||||
Reference in New Issue
Block a user