Sync logits soft-capping across pipelines

This commit is contained in:
Po Yen Chen
2025-04-20 07:55:36 +00:00
parent e0e9040f88
commit 889b2d33fd
3 changed files with 71 additions and 15 deletions

View File

@@ -56,7 +56,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
static_assert(CK_TILE_FMHA_FWD_FAST_EXP2 ||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
!kHasLogitsSoftCap)) ||
(!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap));
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
@@ -173,8 +175,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
(void)logits_soft_cap_params;
// K tile in LDS
KDataType* k_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQ<Problem>()));
@@ -411,6 +411,19 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#else
if constexpr(kHasLogitsSoftCap)
{
float scale_lo = scale_s * 0.6931472f;
tile_elementwise_inout(
[&scale_lo,
&logits_cap = logits_soft_cap_params.logits_soft_cap,
&logits_cap_rev = logits_soft_cap_params.logits_soft_cap_rcp](auto& x) {
x = log2e_v<SaccDataType> * logits_cap *
tanh_fast<SaccDataType>(x * scale_lo * logits_cap_rev);
},
s_acc);
}
#endif
}
move_tile_window(bias_dram_window, {0, kN0});
@@ -493,7 +506,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_max = scale_s * get_validated_m(m[i_idx]);
[[maybe_unused]] auto row_max = scale_s * get_validated_m(m[i_idx]);
#endif
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
@@ -505,7 +518,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
else
{
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
if constexpr(kHasLogitsSoftCap)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
}
else
{
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
}
}
#else
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
@@ -530,8 +550,15 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
else
{
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
if constexpr(kHasLogitsSoftCap)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
else
{
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
}
}
}();
#else

View File

@@ -57,7 +57,9 @@ struct BlockFmhaPipelineQRKSVS
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static_assert(CK_TILE_FMHA_FWD_FAST_EXP2 ||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
!kHasLogitsSoftCap)) ||
(!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap));
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
@@ -171,8 +173,6 @@ struct BlockFmhaPipelineQRKSVS
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
(void)logits_soft_cap_params;
// K tile in LDS
KDataType* k_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQ<Problem>()));
@@ -390,6 +390,19 @@ struct BlockFmhaPipelineQRKSVS
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#else
if constexpr(kHasLogitsSoftCap)
{
float scale_lo = scale_s * 0.6931472f;
tile_elementwise_inout(
[&scale_lo,
&logits_cap = logits_soft_cap_params.logits_soft_cap,
&logits_cap_rev = logits_soft_cap_params.logits_soft_cap_rcp](auto& x) {
x = log2e_v<SaccDataType> * logits_cap *
tanh_fast<SaccDataType>(x * scale_lo * logits_cap_rev);
},
s_acc);
}
#endif
}
move_tile_window(bias_dram_window, {0, kN0});
@@ -446,7 +459,7 @@ struct BlockFmhaPipelineQRKSVS
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_max = scale_s * get_validated_m(m[i_idx]);
[[maybe_unused]] auto row_max = scale_s * get_validated_m(m[i_idx]);
#endif
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
@@ -458,7 +471,14 @@ struct BlockFmhaPipelineQRKSVS
}
else
{
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
if constexpr(kHasLogitsSoftCap)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
}
else
{
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
}
}
#else
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
@@ -483,8 +503,15 @@ struct BlockFmhaPipelineQRKSVS
}
else
{
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
if constexpr(kHasLogitsSoftCap)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
else
{
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
}
}
}();
#else

View File

@@ -62,7 +62,9 @@ struct BlockFmhaPipelineQRKSVSAsync
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static_assert(CK_TILE_FMHA_FWD_FAST_EXP2 ||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
!kHasLogitsSoftCap)) ||
(!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap));
// last dimension vector length used to create tensor view(and decide buffer_load vector length)