From 9512f78616b06232e1b605488fee42dc8d689320 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Mon, 21 Apr 2025 23:33:52 +0000 Subject: [PATCH] Support logits_soft_cap in batch_decode() --- .../fmha/kernel/fmha_batch_decode_kernel.hpp | 2 + ...ck_fmha_batch_decode_pipeline_qr_ks_vs.hpp | 42 ++++++++++++++++--- ...k_fmha_batch_prefill_pipeline_qr_ks_vs.hpp | 2 +- 3 files changed, 40 insertions(+), 6 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_decode_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_decode_kernel.hpp index 16e0fce7c0..236b3c8369 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_decode_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_decode_kernel.hpp @@ -915,6 +915,7 @@ struct FmhaBatchDecodeWithPagedKVCacheKernel mask, position_encoding, kargs.scale_s, + kargs, smem_ptr, kargs.kv_page_indices, kargs.stride_k, @@ -932,6 +933,7 @@ struct FmhaBatchDecodeWithPagedKVCacheKernel mask, position_encoding, kargs.scale_s, + kargs, smem_ptr, kargs.kv_page_indices, kargs.stride_k, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_decode_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_decode_pipeline_qr_ks_vs.hpp index 8fdf6287d2..0d91b4cdbe 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_decode_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_decode_pipeline_qr_ks_vs.hpp @@ -136,7 +136,8 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS typename SAccElementFunction, typename PComputeElementFunction, typename OAccElementFunction, - typename PositionEncoding> + typename PositionEncoding, + typename LogitsSoftCapParams> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -156,6 +157,7 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const LogitsSoftCapParams& logits_soft_cap_params, void* smem_ptr, const int32_t* kv_page_indices, const index_t stride_k, @@ -472,6 +474,19 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS s_acc = tile_elementwise_in(s_acc_element_func, s_acc); #if !CK_TILE_FMHA_FWD_FAST_EXP2 tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#else + if constexpr(kHasLogitsSoftCap) + { + float scale_lo = + scale_s * 0.6931472f * logits_soft_cap_params.logits_soft_cap_rcp; + float logits_cap = + log2e_v * logits_soft_cap_params.logits_soft_cap; + tile_elementwise_inout( + [&scale_lo, &logits_cap](auto& x) { + x = logits_cap * tanh_fast(x * scale_lo); + }, + s_acc); + } #endif } move_tile_window(bias_dram_window, {0, kN0}); @@ -579,7 +594,14 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS } else { - p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max); + } } #else p_compute(i_j_idx) = exp(s_new[i_j_idx] - get_validated_m(m[i_idx])); @@ -608,8 +630,15 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS } else { - auto row_max = scale_s * get_validated_m(m[i_idx]); - return exp2(scale_s * m_old[i_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } } }(); #else @@ -759,7 +788,8 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename LSEaccDramBlockWindowTmp, - typename PositionEncoding> + typename PositionEncoding, + typename LogitsSoftCapParams> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile @@ -771,6 +801,7 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const LogitsSoftCapParams& logits_soft_cap_params, void* smem_ptr, const int32_t* kv_page_indices, const index_t stride_k, @@ -794,6 +825,7 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS mask, position_encoding, scale_s, + logits_soft_cap_params, smem_ptr, kv_page_indices, stride_k, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs.hpp index 26dcb611e0..882b2ad57a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs.hpp @@ -112,7 +112,7 @@ struct BlockFmhaBatchPrefillWithPagedKVCachePipelineQRKSVS else { return 1; - }; + } } }();