From db39b44bab5ff3e666dd3d11be619ea35e14aa8d Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Thu, 11 Dec 2025 10:47:54 +0000 Subject: [PATCH] Update in the implementation of GetAlignmentQ/GetAlignmentK/GetAlignmentV --- ..._attention_fwd_pipeline_default_policy.hpp | 56 ++++++++-------- .../hstu_attention_pipeline_problem.hpp | 64 +++++++++++++++++++ 2 files changed, 89 insertions(+), 31 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp index a69c0fe394..7c9241ebfb 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -116,18 +116,23 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() { - using QDataType = remove_cvref_t; + if constexpr(Problem::kLoadWholeQTileOnceThroughLds) + { + return Problem::GetQDramTileAccessMaxVectorSize(); + } + else + { + using QDataType = remove_cvref_t; - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kMPerBlock = Problem::kLoadWholeQTileOnceThroughLds - ? Problem::HstuAttentionTileSetting::kM0 - : GetQKBlockGemmSingleRepM(); - constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM(); + constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim; - constexpr index_t MaxVectorSize = 16 / sizeof(QDataType); - constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; - - return min(MaxVectorSize, ElemPerThread); + return Problem::template GetDramTileAccessMaxVectorSize(); + }; } template @@ -142,16 +147,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK() { - using KDataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN0Sub; - constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim; - - constexpr index_t MaxVectorSize = 16 / sizeof(KDataType); - constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; - - return min(MaxVectorSize, ElemPerThread); + return Problem::GetKDramTileAccessMaxVectorSize(); } template @@ -166,21 +162,22 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() { - - using VDataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1; - constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1; - + // special consideration when shuffling is required before storing V to LDS if constexpr(!Problem::kUseTrLoad) { + using VDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1; + constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1; + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; constexpr index_t MaxVectorSize = 16 / sizeof(VDataType); constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); constexpr index_t kMinVecLoad = 4 / sizeof(VDataType); + // try to avoid writing sub-dword to LDS due to poor performance constexpr index_t kVecLoad = ((ElemPerThread / kMaxVecLoad) >= kMinVecLoad) ? kMaxVecLoad : (ElemPerThread / kMinVecLoad); @@ -189,10 +186,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy } else { - constexpr index_t MaxVectorSize = 16 / sizeof(VDataType); - constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; - - return min(MaxVectorSize, ElemPerThread); + return Problem::GetVDramTileAccessMaxVectorSize(); }; } diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp index a90d7eef69..ac09975aa5 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp @@ -54,6 +54,70 @@ struct HstuAttentionFwdPipelineProblem static constexpr index_t kNumGemm0Warps = AttentionTileSetting_::NumGemm0Warps; static constexpr index_t kNumGemm1Warps = AttentionTileSetting_::NumGemm1Warps; static constexpr index_t kBlockSize = AttentionTileSetting_::NumWarps * get_warp_size(); + + template + CK_TILE_HOST_DEVICE static constexpr auto GetMaxVectorSize() + { + if constexpr(std::is_same_v || std::is_same_v) + { + if constexpr(ElemPerThread % 8 == 0) + return 8; + else if constexpr(ElemPerThread % 6 == 0) + return 6; + else if constexpr(ElemPerThread % 4 == 0) + return 4; + else if constexpr(ElemPerThread % 2 == 0) + return 2; + return 1; + } + else if constexpr(std::is_same_v) + { + if constexpr(ElemPerThread % 4 == 0) + return 4; + else if constexpr(ElemPerThread % 3 == 0) + return 3; + else if constexpr(ElemPerThread % 2 == 0) + return 2; + return 1; + } + else + static_assert(false, "The data type is not supported!"); + }; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetDramTileAccessMaxVectorSize() + { + constexpr index_t ElemPerThread = (kHigherDimSize * kLowerDimSize) / kThreadBlockSize; + + return GetMaxVectorSize(); + } + + CK_TILE_HOST_DEVICE static constexpr auto GetQDramTileAccessMaxVectorSize() + { + constexpr index_t kMPerBlock = HstuAttentionTileSetting::kM0; + constexpr index_t kKPerBlock = HstuAttentionTileSetting::kSubQKHeaddim; + + return GetDramTileAccessMaxVectorSize(); + } + + CK_TILE_HOST_DEVICE static constexpr auto GetKDramTileAccessMaxVectorSize() + { + constexpr index_t kNPerBlock = HstuAttentionTileSetting::kN0Sub; + constexpr index_t kKPerBlock = HstuAttentionTileSetting::kSubQKHeaddim; + + return GetDramTileAccessMaxVectorSize(); + } + + CK_TILE_HOST_DEVICE static constexpr auto GetVDramTileAccessMaxVectorSize() + { + constexpr index_t kNPerBlock = HstuAttentionTileSetting::kN1; + constexpr index_t kKPerBlock = HstuAttentionTileSetting::kK1; + + return GetDramTileAccessMaxVectorSize(); + }; }; } // namespace ck_tile