From 13fdb382b2ba53c993d1d71a2039d61bfef91d73 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Tue, 16 Dec 2025 14:31:11 +0000 Subject: [PATCH] Change to the Q/K DramTile encoding and renaming in V/VShuffled DramTile --- ..._attention_fwd_pipeline_default_policy.hpp | 65 +++++++++---------- 1 file changed, 31 insertions(+), 34 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp index c912d76576..3f66f9b7a9 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -341,15 +341,15 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy constexpr index_t NumWarps = kBlockSize / get_warp_size(); constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); - // for Q-Tile [64, 128], the encoding is [4W * 4T * 4E, 16T * 8E] + // for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E] return make_static_tile_distribution( tile_distribution_encoding, - tuple, + tuple, sequence>, tuple, sequence<1, 2>>, - tuple, sequence<1, 0>>, + tuple, sequence<2, 0>>, sequence<1, 2>, - sequence<2, 1>>{}); + sequence<1, 1>>{}); } template @@ -367,15 +367,15 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy constexpr index_t NumWarps = kBlockSize / get_warp_size(); constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); - // for Q-Tile [64, 128], the encoding is [4W * 4T * 4E, 16T * 8E] + // for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E] return make_static_tile_distribution( tile_distribution_encoding, - tuple, + tuple, sequence>, tuple, sequence<1, 2>>, - tuple, sequence<1, 0>>, + tuple, sequence<2, 0>>, sequence<1, 2>, - sequence<2, 1>>{}); + sequence<1, 1>>{}); } template @@ -512,10 +512,10 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy return make_static_tile_distribution( tile_distribution_encoding, - tuple, + tuple, sequence>, tuple, sequence<1, 2>>, - tuple, sequence<1, 0>>, + tuple, sequence<2, 0>>, sequence<1, 2>, sequence<0, 1>>{}); } @@ -621,20 +621,19 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy if constexpr(!Problem::kUseTrLoad) { - constexpr index_t N1 = GetAlignmentV(); - constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t NPerThread = GetAlignmentV(); + constexpr index_t NThreads = kNPerBlock / NPerThread; constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; - static_assert(ElemPerThread % N1 == 0); - - constexpr index_t K2 = ElemPerThread / N1; - constexpr index_t K1 = get_warp_size() / N0; - constexpr index_t K0 = kBlockSize / get_warp_size(); + constexpr index_t KPerThread = ElemPerThread / NPerThread; + constexpr index_t KThreadPerWarp = get_warp_size() / NThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, + tuple, + sequence>, tuple, sequence<2, 1>>, tuple, sequence<1, 0>>, sequence<2, 1>, @@ -642,20 +641,19 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy } else { - constexpr index_t N1 = GetAlignmentV(); - constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t NPerThread = GetAlignmentV(); + constexpr index_t NThreads = kNPerBlock / NPerThread; constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; - static_assert(ElemPerThread % N1 == 0); - - constexpr index_t K2 = ElemPerThread / N1; - constexpr index_t K1 = get_warp_size() / N0; - constexpr index_t K0 = kBlockSize / get_warp_size(); + constexpr index_t KPerThread = ElemPerThread / NPerThread; + constexpr index_t KThreadPerWarp = get_warp_size() / NThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, + tuple, + sequence>, tuple, sequence<1, 2>>, tuple, sequence<1, 0>>, sequence<1, 2>, @@ -671,20 +669,19 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1; constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1; - constexpr index_t N1 = GetAlignmentV(); - constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t NPerThread = GetAlignmentV(); + constexpr index_t NThreads = kNPerBlock / NPerThread; constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; - static_assert(ElemPerThread % N1 == 0); - - constexpr index_t K2 = ElemPerThread / N1; - constexpr index_t K1 = get_warp_size() / N0; - constexpr index_t K0 = kBlockSize / get_warp_size(); + constexpr index_t KPerThread = ElemPerThread / NPerThread; + constexpr index_t KThreadPerWarp = get_warp_size() / NThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, + tuple, + sequence>, tuple, sequence<2, 1>>, tuple, sequence<1, 0>>, sequence<1, 2>,