From 01a5cf111156757dce524cd09cc485e94894f96c Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Tue, 23 Sep 2025 07:13:16 +0000 Subject: [PATCH] Merge commit 'b6e899438631118ff962a6be12cabc9930366267' into develop --- example/ck_tile/01_fmha/fmha_bwd_runner.hpp | 10 ++++- .../script/fmha_fwd_known_fails_gfx950.txt | 4 -- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 39 +++++++++++-------- ...ck_fmha_pipeline_qr_ks_vs_async_trload.hpp | 1 + test/ck_tile/fmha/test_fmha_bwd_bf16.cpp | 2 +- test/ck_tile/fmha/test_fmha_bwd_fp16.cpp | 2 +- 6 files changed, 34 insertions(+), 24 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp index 3a5b5b4603..d861b351d4 100644 --- a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp @@ -763,15 +763,21 @@ bwd_result fmha_bwd_run(mode_enum mode, ck_tile::FillConstant{ck_tile::numeric::infinity()}(dq_host); ck_tile::FillConstant{ck_tile::numeric::infinity()}(dk_host); ck_tile::FillConstant{ck_tile::numeric::infinity()}(dv_host); + ck_tile::FillConstant{ck_tile::numeric::infinity()}(dq_acc_host); dq_buf.ToDevice(dq_host.data()); dk_buf.ToDevice(dk_host.data()); dv_buf.ToDevice(dv_host.data()); + dq_acc_buf.ToDevice(dq_acc_host.data()); o_buf.ToDevice(o_host.data()); lse_buf.ToDevice(lse_host.data()); - dq_buf.SetZero(); dbias_buf.SetZero(); - dq_acc_buf.SetZero(); + + // non-deterministic kernels use atomic add to write dq + // Some blocks may be skipped when a mask is applied, so their dq is never written + // In either case dq_acc must be zeroed first; otherwise it keeps the infinity fill + if(!deterministic || mask.type != mask_enum::no_mask) + dq_acc_buf.SetZero(); ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1}; fmha_bwd(fmha_traits, fmha_args, stream_config_v); diff --git a/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx950.txt b/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx950.txt index 90c5e2b7fb..e69de29bb2 100644 --- 
a/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx950.txt +++ b/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx950.txt @@ -1,4 +0,0 @@ -tile_example_fmha_fwd -prec=fp16 -mode=0 -b=2 -h=1 -d=128 -d_v=24 -s=3 -s_k=99 -bias=n -p_drop=0.0 -lse=0 -iperm=0 -operm=0 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -tile_example_fmha_fwd -prec=fp16 -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=1 -s_k=10 -s_kpad=32 -bias=n -p_drop=0.0 -lse=0 -iperm=0 -operm=0 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -tile_example_fmha_fwd -prec=fp16 -mode=0 -b=2 -h=1 -d=128 -d_v=24 -s=3 -s_k=99 -bias=n -p_drop=0.0 -lse=0 -iperm=1 -operm=1 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -tile_example_fmha_fwd -prec=fp16 -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=1 -s_k=10 -s_kpad=32 -bias=n -p_drop=0.0 -lse=0 -iperm=1 -operm=1 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 58fdad149a..e562f6dd5a 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1767,6 +1767,9 @@ struct FmhaFwdKernel make_tuple(number{}, number{}), sequence{}); + constexpr auto kDramTileK = + FmhaPipeline::kKLoadOnce ? 
FmhaPipeline::kQKHeaddim : FmhaPipeline::kK0; + #if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD constexpr index_t LDSLayerSize = 256 / sizeof(KDataType); constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim); @@ -1835,32 +1838,36 @@ struct FmhaFwdKernel { const auto k_dram_unmerged = transform_tensor_view( k_dram_pad, - make_tuple( - make_pass_through_transform(height), - make_unmerge_transform(make_tuple( - number{}, - number{}))), + make_tuple(make_pass_through_transform(height), + make_unmerge_transform( + make_tuple(number{}, + number{}, + number{}))), make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{})); + make_tuple(sequence<0>{}, sequence<1, 2, 3>{})); const auto k_dram_permuted = transform_tensor_view( k_dram_unmerged, make_tuple( make_xor_transform(make_tuple( - height, - number{})), + height, number{})), + make_pass_through_transform( + number{}), make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); return transform_tensor_view( k_dram_permuted, - make_tuple( - make_pass_through_transform(height), - make_merge_transform_v3_division_mod(make_tuple( - number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(make_pass_through_transform(height), + make_merge_transform_v3_division_mod( + make_tuple(number{}, + number{}, + number{}))), + make_tuple(sequence<0>{}, sequence<1, 2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index aafe481d2b..b2c1b06955 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ 
b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -37,6 +37,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload using VLayout = remove_cvref_t; static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once static_assert(kQLoadOnce == Policy::QLoadOnce); + static constexpr bool kKLoadOnce = BlockFmhaShape::kM0 >= 64; static constexpr index_t kBlockSize = Problem::kBlockSize; diff --git a/test/ck_tile/fmha/test_fmha_bwd_bf16.cpp b/test/ck_tile/fmha/test_fmha_bwd_bf16.cpp index cd143e8e83..077e45a10d 100644 --- a/test/ck_tile/fmha/test_fmha_bwd_bf16.cpp +++ b/test/ck_tile/fmha/test_fmha_bwd_bf16.cpp @@ -16,6 +16,6 @@ const auto HDimValues = const auto ModeValues = Values(mode_enum::batch, mode_enum::group); -constexpr std::string init_method = "uf"; +constexpr auto init_method = "uf"; #include "test_fmha_bwd.inc" diff --git a/test/ck_tile/fmha/test_fmha_bwd_fp16.cpp b/test/ck_tile/fmha/test_fmha_bwd_fp16.cpp index 4bb1e04ad0..86621b0494 100644 --- a/test/ck_tile/fmha/test_fmha_bwd_fp16.cpp +++ b/test/ck_tile/fmha/test_fmha_bwd_fp16.cpp @@ -16,6 +16,6 @@ const auto HDimValues = const auto ModeValues = Values(mode_enum::batch, mode_enum::group); -constexpr std::string init_method = "uf"; +constexpr auto init_method = "uf"; #include "test_fmha_bwd.inc"