diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 2102fe768f..55f2354ff4 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -196,7 +196,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const index_t* page_idx, const index_t stride_k, const index_t stride_v, - DropoutType& dropout) const + DropoutType& dropout, + const float sink_v) const { static_assert( std::is_same_v> && @@ -282,8 +283,16 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync auto l = MLBlockTileType{}; clear_tile(o_acc); - set_tile(m, -numeric::infinity()); - clear_tile(l); + if(!__builtin_isinf_sign(sink_v) || __builtin_isinf_sign(sink_v) > 0) + { + set_tile(m, sink_v); + set_tile(l, SMPLComputeDataType{1.0f}); + } + else + { + set_tile(m, -numeric::infinity()); + clear_tile(l); + } __builtin_amdgcn_sched_barrier(0); const auto q_origin = q_dram_window.get_window_origin(); @@ -887,7 +896,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const index_t* page_idx, const index_t stride_k, const index_t stride_v, - DropoutType& dropout) const + DropoutType& dropout, + const float sink_v) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -913,7 +923,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync page_idx, stride_k, stride_v, - dropout); + dropout, + sink_v); } }; diff --git a/test/ck_tile/fmha/test_fmha_fwd.cpp b/test/ck_tile/fmha/test_fmha_fwd.cpp index b81fa88aa2..e9bd2549b1 100644 --- a/test/ck_tile/fmha/test_fmha_fwd.cpp +++ b/test/ck_tile/fmha/test_fmha_fwd.cpp @@ -120,8 +120,8 @@ const ck_tile::stream_config stream_config{ 1, // rotating_count_ }; -#define COMMON_ARGS \ - init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, \ +#define COMMON_ARGS \ + init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, 0, \ stream_config auto EnableTestIf(bool condition) @@ -255,6 +255,7 @@ TEST(TestCkTileFmhaFwd, AppendKvWithBatchEffLensShouldFail) init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 0, + 1, //init_sink stream_config); ASSERT_EQ(result, fwd_result::invalid_args); } @@ -299,6 +300,7 @@ TEST(TestCkTileFmhaFwd, SplitKvWithGroupPaddingShouldFail) init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 0, + 1, //init_sink stream_config); ASSERT_EQ(result, fwd_result::invalid_args); } @@ -342,6 +344,7 @@ TEST(TestCkTileFmhaFwd, PagedKvWithGroupPaddingShouldFail) init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 0, + 1, //init_sink stream_config); ASSERT_EQ(result, fwd_result::invalid_args); }