mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Sync logits soft-capping across pipelines
This commit is contained in:
@@ -56,7 +56,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
|
||||
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
|
||||
|
||||
static_assert(CK_TILE_FMHA_FWD_FAST_EXP2 ||
|
||||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||
!kHasLogitsSoftCap)) ||
|
||||
(!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap));
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
@@ -173,8 +175,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
(void)logits_soft_cap_params;
|
||||
|
||||
// K tile in LDS
|
||||
KDataType* k_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQ<Problem>()));
|
||||
@@ -411,6 +411,19 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
#else
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
float scale_lo = scale_s * 0.6931472f;
|
||||
tile_elementwise_inout(
|
||||
[&scale_lo,
|
||||
&logits_cap = logits_soft_cap_params.logits_soft_cap,
|
||||
&logits_cap_rev = logits_soft_cap_params.logits_soft_cap_rcp](auto& x) {
|
||||
x = log2e_v<SaccDataType> * logits_cap *
|
||||
tanh_fast<SaccDataType>(x * scale_lo * logits_cap_rev);
|
||||
},
|
||||
s_acc);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
@@ -493,7 +506,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
[[maybe_unused]] auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
#endif
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
@@ -505,7 +518,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
}
|
||||
else
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
|
||||
}
|
||||
}
|
||||
#else
|
||||
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
@@ -530,8 +550,15 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
}
|
||||
else
|
||||
{
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
return exp2(scale_s * m_old[i_idx] - row_max);
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
return exp2(scale_s * m_old[i_idx] - row_max);
|
||||
}
|
||||
}
|
||||
}();
|
||||
#else
|
||||
|
||||
@@ -57,7 +57,9 @@ struct BlockFmhaPipelineQRKSVS
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
static_assert(CK_TILE_FMHA_FWD_FAST_EXP2 ||
|
||||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||
!kHasLogitsSoftCap)) ||
|
||||
(!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap));
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
@@ -171,8 +173,6 @@ struct BlockFmhaPipelineQRKSVS
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
(void)logits_soft_cap_params;
|
||||
|
||||
// K tile in LDS
|
||||
KDataType* k_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQ<Problem>()));
|
||||
@@ -390,6 +390,19 @@ struct BlockFmhaPipelineQRKSVS
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
#else
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
float scale_lo = scale_s * 0.6931472f;
|
||||
tile_elementwise_inout(
|
||||
[&scale_lo,
|
||||
&logits_cap = logits_soft_cap_params.logits_soft_cap,
|
||||
&logits_cap_rev = logits_soft_cap_params.logits_soft_cap_rcp](auto& x) {
|
||||
x = log2e_v<SaccDataType> * logits_cap *
|
||||
tanh_fast<SaccDataType>(x * scale_lo * logits_cap_rev);
|
||||
},
|
||||
s_acc);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
@@ -446,7 +459,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
[[maybe_unused]] auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
#endif
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
@@ -458,7 +471,14 @@ struct BlockFmhaPipelineQRKSVS
|
||||
}
|
||||
else
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
|
||||
}
|
||||
}
|
||||
#else
|
||||
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
@@ -483,8 +503,15 @@ struct BlockFmhaPipelineQRKSVS
|
||||
}
|
||||
else
|
||||
{
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
return exp2(scale_s * m_old[i_idx] - row_max);
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
return exp2(scale_s * m_old[i_idx] - row_max);
|
||||
}
|
||||
}
|
||||
}();
|
||||
#else
|
||||
|
||||
@@ -62,7 +62,9 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
static_assert(CK_TILE_FMHA_FWD_FAST_EXP2 ||
|
||||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||
!kHasLogitsSoftCap)) ||
|
||||
(!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap));
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
|
||||
Reference in New Issue
Block a user