From d85ca87d9704a770c817aba8f7779437e3c8bcf8 Mon Sep 17 00:00:00 2001 From: Haocong WANG Date: Tue, 23 Sep 2025 15:00:46 +0800 Subject: [PATCH] [CK_TILE] FMHA FWD bug fix (#2888) * tempsave debug * fix the bug in fmha fwd_kernel * Remove unnecessary changes * Fix the buggy part * remove fmha fwd known failure cases [ROCm/composable_kernel commit: b6e899438631118ff962a6be12cabc9930366267] --- .../script/fmha_fwd_known_fails_gfx950.txt | 4 -- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 39 +++++++++++-------- ...ck_fmha_pipeline_qr_ks_vs_async_trload.hpp | 1 + 3 files changed, 24 insertions(+), 20 deletions(-) diff --git a/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx950.txt b/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx950.txt index 90c5e2b7fb..e69de29bb2 100644 --- a/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx950.txt +++ b/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx950.txt @@ -1,4 +0,0 @@ -tile_example_fmha_fwd -prec=fp16 -mode=0 -b=2 -h=1 -d=128 -d_v=24 -s=3 -s_k=99 -bias=n -p_drop=0.0 -lse=0 -iperm=0 -operm=0 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -tile_example_fmha_fwd -prec=fp16 -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=1 -s_k=10 -s_kpad=32 -bias=n -p_drop=0.0 -lse=0 -iperm=0 -operm=0 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -tile_example_fmha_fwd -prec=fp16 -mode=0 -b=2 -h=1 -d=128 -d_v=24 -s=3 -s_k=99 -bias=n -p_drop=0.0 -lse=0 -iperm=1 -operm=1 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -tile_example_fmha_fwd -prec=fp16 -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=1 -s_k=10 -s_kpad=32 -bias=n -p_drop=0.0 -lse=0 -iperm=1 -operm=1 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 58fdad149a..e562f6dd5a 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1767,6 +1767,9 @@ struct FmhaFwdKernel make_tuple(number{}, number{}), sequence{}); + constexpr auto kDramTileK = + FmhaPipeline::kKLoadOnce ? FmhaPipeline::kQKHeaddim : FmhaPipeline::kK0; + #if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD constexpr index_t LDSLayerSize = 256 / sizeof(KDataType); constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim); @@ -1835,32 +1838,36 @@ struct FmhaFwdKernel { const auto k_dram_unmerged = transform_tensor_view( k_dram_pad, - make_tuple( - make_pass_through_transform(height), - make_unmerge_transform(make_tuple( - number{}, - number{}))), + make_tuple(make_pass_through_transform(height), + make_unmerge_transform( + make_tuple(number{}, + number{}, + number{}))), make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{})); + make_tuple(sequence<0>{}, sequence<1, 2, 3>{})); const auto k_dram_permuted = transform_tensor_view( k_dram_unmerged, make_tuple( make_xor_transform(make_tuple( - height, - number{})), + height, number{})), + make_pass_through_transform( + number{}), make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); return transform_tensor_view( k_dram_permuted, - make_tuple( - make_pass_through_transform(height), - make_merge_transform_v3_division_mod(make_tuple( - number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(make_pass_through_transform(height), + make_merge_transform_v3_division_mod( + make_tuple(number{}, + number{}, + number{}))), + make_tuple(sequence<0>{}, sequence<1, 2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index aafe481d2b..b2c1b06955 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -37,6 +37,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload using VLayout = remove_cvref_t; static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once static_assert(kQLoadOnce == Policy::QLoadOnce); + static constexpr bool kKLoadOnce = BlockFmhaShape::kM0 >= 64; static constexpr index_t kBlockSize = Problem::kBlockSize;