diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index ac37f5dd06..fe426f925e 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -847,6 +847,7 @@ struct FmhaFwdKernel window_size_left, window_size_right, mask_type, + 0, // min_seqlen_q p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp index cc532040e8..074a94613c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp @@ -28,6 +28,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; using FmhaMask = remove_cvref_t; + using AttentionVariant = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; @@ -54,6 +55,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this @@ -127,7 +129,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch typename SAccElementFunction, typename PComputeElementFunction, typename OAccElementFunction, - typename PositionEncoding> + typename PositionEncoding, + typename AttentionVariantParams, + typename BlockIndices> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile const QElementFunction& q_element_func, @@ -146,6 +150,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const AttentionVariant& /* unused */, + const AttentionVariantParams& /* unused */, + const BlockIndices& /* unused */, void* smem_ptr, DropoutType& dropout) const { @@ -890,7 +897,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, - typename PositionEncoding> + typename PositionEncoding, + typename AttentionVariantParams, + typename BlockIndices> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile @@ -901,6 +910,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, void* smem_ptr, DropoutType& dropout) const { @@ -921,6 +933,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch mask, position_encoding, scale_s, + variant, + variant_params, + block_indices, smem_ptr, dropout); }