diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp index 2c4d3d50b1..a8e0c3b3e4 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -101,6 +101,54 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy return min(MaxVectorSize, ElemPerThread); } + /* GetSmemKPackV: compile-time helper returning min(ElemPerThread, MaxVectorSize), where ElemPerThread = kN1 * kK1 / kBlockSize and MaxVectorSize = 16 / sizeof(VDataType). NOTE(review): the template parameter list and the remove_cvref_t argument are not visible in this view (they look stripped by extraction) - confirm against the real header; presumably VDataType comes from Problem. */ template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV() + { + using VDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + /* elements of the kN1 x kK1 tile divided evenly over kBlockSize (integer division) */ constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + + /* how many VDataType elements fit in 16 bytes */ constexpr index_t MaxVectorSize = 16 / sizeof(VDataType); + return min(ElemPerThread, MaxVectorSize); + } + + /* GetAlignmentV: compile-time helper returning an element count selected by an `if constexpr` on a type comparison. First branch: kMaxVecLoad = min(ElemPerThread, 16 / sizeof(VDataType)); if ElemPerThread / kMaxVecLoad >= kMinVecLoad (= 4 / sizeof(VDataType)) the result is kMaxVecLoad, otherwise ElemPerThread / kMinVecLoad. Other branch: min(ElemPerThread, MaxVectorSize), the same expression GetSmemKPackV returns. NOTE(review): the std::is_same_v arguments and both remove_cvref_t arguments are not visible in this view; presumably the test is on the V layout - confirm against the real header. */ template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() + { + using VLayout = remove_cvref_t; + using VDataType = remove_cvref_t; + + if constexpr(std::is_same_v) + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + constexpr index_t MaxVectorSize = 16 / sizeof(VDataType); + /* upper bound on the returned width: never more than ElemPerThread or 16 bytes' worth of elements */ constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); + /* how many VDataType elements fit in 4 bytes */ constexpr index_t kMinVecLoad = 4 / sizeof(VDataType); + + /* keep kMaxVecLoad only when ElemPerThread / kMaxVecLoad is at least kMinVecLoad */ constexpr index_t kVecLoad = ((ElemPerThread / kMaxVecLoad) >= kMinVecLoad) + ? 
kMaxVecLoad + : (ElemPerThread / kMinVecLoad); /* NOTE(review): integer division - this fallback evaluates to 0 if ElemPerThread < kMinVecLoad; confirm the supported tile shapes rule that out */ + + return kVecLoad; + } + else + { /* other branch: no kMinVecLoad adjustment, just the 16-byte cap on ElemPerThread */ + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + + constexpr index_t MaxVectorSize = 16 / sizeof(VDataType); + return min(ElemPerThread, MaxVectorSize); + } + } + template CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize() {