From 125934a966afb4c341260d2e1a720661c4712cb6 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Sun, 14 Dec 2025 13:50:49 +0000 Subject: [PATCH] Simplifying the codes in defining KDram and QDram tile distribution --- ..._attention_fwd_pipeline_default_policy.hpp | 34 +++++-------------- 1 file changed, 8 insertions(+), 26 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp index 1b7065dea4..c912d76576 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -173,9 +173,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; - constexpr index_t MaxVectorSize = 16 / sizeof(VDataType); - constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); - constexpr index_t kMinVecLoad = 4 / sizeof(VDataType); + constexpr index_t kMaxVecLoad = Problem::GetVDramTileAccessMaxVectorSize(); + constexpr index_t kMinVecLoad = 4 / sizeof(VDataType); // try to avoid writing sub-dword to LDS due to poor performance constexpr index_t kVecLoad = ((ElemPerThread / kMaxVecLoad) >= kMinVecLoad) @@ -330,19 +329,13 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeQDramSingleRepMTileDistribution() { - using QKVDataType = remove_cvref_t; - constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM(); constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim; - constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType); + constexpr index_t kKVector = GetAlignmentQ(); - constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; - static_assert(0 < ElemPerThread); - constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); - - constexpr index_t KPerThread = kMaxVecLoad; + constexpr index_t KPerThread = kKVector; constexpr index_t KThreads = kKPerBlock / KPerThread; constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; constexpr index_t NumWarps = kBlockSize / get_warp_size(); @@ -362,19 +355,13 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution() { - using QKVDataType = remove_cvref_t; - constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = Problem::HstuAttentionTileSetting::kM0; constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim; - constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType); + constexpr index_t kKVector = GetAlignmentQ(); - constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; - static_assert(0 < ElemPerThread); - constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); - - constexpr index_t KPerThread = kMaxVecLoad; + constexpr index_t KPerThread = kKVector; constexpr index_t KThreads = kKPerBlock / KPerThread; constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; constexpr index_t NumWarps = kBlockSize / get_warp_size(); @@ -511,18 +498,13 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() { - using QKVDataType = remove_cvref_t; - constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN0Sub; constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim; - constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType); - constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; + constexpr index_t kKVector = GetAlignmentK(); - constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); - - constexpr index_t KPerThread = kMaxVecLoad; + constexpr index_t KPerThread = kKVector; constexpr index_t KThreads = kKPerBlock / KPerThread; constexpr index_t NThreadPerWarp = get_warp_size() / KThreads; constexpr index_t NumWarps = kBlockSize / get_warp_size();