From 889b2d33fd55d3d19014f99697d85dd4c433f208 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Sun, 20 Apr 2025 07:55:36 +0000 Subject: [PATCH] Sync logits soft-capping across pipelines --- ...ock_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp | 41 +++++++++++++++---- .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 41 +++++++++++++++---- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 4 +- 3 files changed, 71 insertions(+), 15 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index ce9d388693..384d327bbf 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -56,7 +56,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS static constexpr bool kIsPagedKV = Problem::kIsPagedKV; static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; - static_assert(CK_TILE_FMHA_FWD_FAST_EXP2 || + static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && + (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || + !kHasLogitsSoftCap)) || (!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap)); // last dimension vector length used to create tensor view(and decide buffer_load vector length) @@ -173,8 +175,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); - (void)logits_soft_cap_params; - // K tile in LDS KDataType* k_lds_ptr = static_cast(static_cast( static_cast(smem_ptr) + Policy::template GetSmemSizeQ())); @@ -411,6 +411,19 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS s_acc = tile_elementwise_in(s_acc_element_func, s_acc); #if !CK_TILE_FMHA_FWD_FAST_EXP2 tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#else + if constexpr(kHasLogitsSoftCap) + { + float scale_lo = scale_s * 0.6931472f; + tile_elementwise_inout( + [&scale_lo, + &logits_cap = logits_soft_cap_params.logits_soft_cap, + &logits_cap_rev = logits_soft_cap_params.logits_soft_cap_rcp](auto& x) { + x = log2e_v * logits_cap * + tanh_fast(x * scale_lo * logits_cap_rev); + }, + s_acc); + } #endif } move_tile_window(bias_dram_window, {0, kN0}); @@ -493,7 +506,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 - auto row_max = scale_s * get_validated_m(m[i_idx]); + [[maybe_unused]] auto row_max = scale_s * get_validated_m(m[i_idx]); #endif sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); @@ -505,7 +518,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS } else { - p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + } } #else p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); @@ -530,8 +550,15 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS } else { - auto row_max = scale_s * get_validated_m(m[i_idx]); - return exp2(scale_s * m_old[i_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } } }(); #else diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 31788f5853..72f3da119c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -57,7 +57,9 @@ struct BlockFmhaPipelineQRKSVS static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasDropout = Problem::kHasDropout; - static_assert(CK_TILE_FMHA_FWD_FAST_EXP2 || + static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && + (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || + !kHasLogitsSoftCap)) || (!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap)); // last dimension vector length used to create tensor view(and decide buffer_load vector length) @@ -171,8 +173,6 @@ struct BlockFmhaPipelineQRKSVS kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); - (void)logits_soft_cap_params; - // K tile in LDS KDataType* k_lds_ptr = static_cast(static_cast( static_cast(smem_ptr) + Policy::template GetSmemSizeQ())); @@ -390,6 +390,19 @@ struct BlockFmhaPipelineQRKSVS s_acc = tile_elementwise_in(s_acc_element_func, s_acc); #if !CK_TILE_FMHA_FWD_FAST_EXP2 tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#else + if constexpr(kHasLogitsSoftCap) + { + float scale_lo = scale_s * 0.6931472f; + tile_elementwise_inout( + [&scale_lo, + &logits_cap = logits_soft_cap_params.logits_soft_cap, + &logits_cap_rev = logits_soft_cap_params.logits_soft_cap_rcp](auto& x) { + x = log2e_v * logits_cap * + tanh_fast(x * scale_lo * logits_cap_rev); + }, + s_acc); + } #endif } move_tile_window(bias_dram_window, {0, kN0}); @@ -446,7 +459,7 @@ struct BlockFmhaPipelineQRKSVS sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 - auto row_max = scale_s * get_validated_m(m[i_idx]); + [[maybe_unused]] auto row_max = scale_s * get_validated_m(m[i_idx]); #endif sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); @@ -458,7 +471,14 @@ struct BlockFmhaPipelineQRKSVS } else { - p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + } } #else p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); @@ -483,8 +503,15 @@ struct BlockFmhaPipelineQRKSVS } else { - auto row_max = scale_s * get_validated_m(m[i_idx]); - return exp2(scale_s * m_old[i_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } } }(); #else diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 8d33184442..2bc374fbe9 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -62,7 +62,9 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasDropout = Problem::kHasDropout; - static_assert(CK_TILE_FMHA_FWD_FAST_EXP2 || + static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && + (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || + !kHasLogitsSoftCap)) || (!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap)); // last dimension vector length used to create tensor view(and decide buffer_load vector length)