From 611f2ce1f9c77d93184cc0d750e99ec755ae3d32 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Sat, 3 May 2025 16:17:28 +0000 Subject: [PATCH] Override and fix GetAlignmentK() --- ...hstu_attention_fwd_pipeline_default_policy.hpp | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp index dca2ec3426..9ea83ec606 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -51,6 +51,21 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy return 4; } + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK() + { + using KDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + + constexpr index_t MaxVectorSize = 16 / sizeof(KDataType); + constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; + + return min(MaxVectorSize, ElemPerThread); + } + template CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize() {