From e841981ddd3532f1c908cd3737d635f3f9876d0f Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Sat, 23 May 2026 07:22:29 +0000 Subject: [PATCH] Update to MakeLSEaccDramTileDistribution trying to assign more threads to MThreadPerWarp so that block_tile_reduce_sync() work on less KThreadPerWarp --- ...on_fwd_splitkv_combine_pipeline_policy.hpp | 51 ++++++------------- .../hstu_attention_pipeline_problem.hpp | 2 + 2 files changed, 17 insertions(+), 36 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_combine_pipeline_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_combine_pipeline_policy.hpp index c0e5140951..f92725339a 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_combine_pipeline_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_combine_pipeline_policy.hpp @@ -95,44 +95,23 @@ struct HstuAttentionFwdSplitKVCombinePipelinePolicy static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!"); - constexpr index_t KPerThread = KVector; + static_assert(kMPerBlock <= NumWarps * get_warp_size(), "Check failed!"); - if constexpr(OtherK < get_warp_size()) - { - // try to assign more consecutive threads on dim-K - constexpr index_t KThreads = OtherK; + constexpr index_t MThreadPerWarp = kMPerBlock / NumWarps; + constexpr index_t MPerThread = 1; + constexpr index_t KThreadPerWarp = get_warp_size() / MThreadPerWarp; + constexpr index_t KRepPerThread = OtherK / KThreadPerWarp; - constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; - constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); - - // 32/64 Threads should be in lay-out [kThreads, MThreadPerWarp] since the tile - // distribution will be used by block_tile_reduce_sync(..., bool_constant<0>{}) - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, - sequence>, - tuple, sequence<2, 1>>, - tuple, sequence<0, 2>>, - sequence<1, 2>, - sequence<0, 1>>{}); - } - else - { - // all threads in the warp are assigned on dim-K - constexpr index_t KThreads = get_warp_size(); - constexpr index_t KRepPerThread = OtherK / KThreads; - - constexpr index_t MPerThread = kMPerBlock / NumWarps; - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, - sequence>, - tuple, sequence<2>>, - tuple, sequence<1>>, - sequence<1, 2, 2>, - sequence<0, 0, 2>>{}); - }; + // 32/64 Threads should be in lay-out [kThreads, MThreadPerWarp] since the tile + // distribution will be used by block_tile_reduce_sync(..., bool_constant<0>{}) + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 0>>, + sequence<1, 2, 2>, + sequence<2, 0, 2>>{}); } template diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp index c32b972ffa..10688b5cf2 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp @@ -152,6 +152,8 @@ struct HstuAttentionFwdSplitKVCombinePipelineProblem static constexpr index_t kBlockSize = CombineTileSetting_::NumWarps * get_warp_size(); static constexpr index_t kMaxSplits = kMaxSplits_; + static_assert((kMaxSplits == 0) || (kM * kMaxSplits >= kBlockSize), "Check failed!"); + CK_TILE_HOST_DEVICE static constexpr auto GetOaccDramTileAccessMaxVectorSize() { constexpr index_t kMPerBlock = kM;