From a41371f7348992b7fad873d1213785c2e06e1dcd Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Thu, 24 Apr 2025 15:02:57 +0000 Subject: [PATCH] Update in K-Lds laying-out to consider for both WarpGemm-32x32x16 and WarpGemm-16x16x16 --- ..._attention_fwd_pipeline_default_policy.hpp | 121 +++++++++++++----- 1 file changed, 89 insertions(+), 32 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp index b702763bd0..6ee53005a0 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -51,11 +51,23 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy Problem::BlockFmhaShape::kQKHeaddim>(); } + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetQKWarpGemmKPerThreadSize() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + return WG::WarpGemmAttribute::kKPerThread; + }; + template CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK() { - using QKVDataType = remove_cvref_t; - return 8 / sizeof(QKVDataType); + if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) + return 8; + else + return 4; } template @@ -66,9 +78,18 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy constexpr index_t kKPack = GetSmemKPackK(); constexpr index_t kKVector = GetAlignmentK(); - static_assert(kKVector % kKPack == 0); + if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) + { + static_assert(kKVector == kKPack); - return kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector; + return kKPerBlock * kNPerBlock + kKPerBlock; + } + else + { + static_assert(kKVector % kKPack == 0); + + return kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector; + }; }; template @@ -105,40 +126,76 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy constexpr index_t kKPack = GetSmemKPackK(); constexpr index_t kKVector = GetAlignmentK(); - static_assert(kKVector % kKPack == 0); + if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) + { + static_assert(kKVector == kKPack); - constexpr index_t KSingleSmemElementSpaceSize = - kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector; + constexpr index_t KSingleSmemElementSpaceSize = kKPerBlock * kNPerBlock + kKPerBlock; - static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize()); + static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize()); - constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize(); + constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize(); - constexpr auto k_lds_block_desc_0 = - make_naive_tensor_descriptor(make_tuple(number{}, - number{}, - number{}, - number{}, - number{}), - make_tuple(number{}, - number{}, - number{}, - number{}, - number<1>{}), - number{}, - number<1>{}); + constexpr auto k_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(number{}, + number{}, + number{}, + number{}), + make_tuple(number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); - constexpr auto k_lds_block_desc = transform_tensor_descriptor( - k_lds_block_desc_0, - make_tuple( - make_merge_transform(make_tuple(number{}, number{})), - make_merge_transform(make_tuple(number{}, - number{}, - number{}))), - make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple(make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform( + make_tuple(number{}, number{}))), + make_tuple(sequence<0, 2>{}, sequence<1, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - return k_lds_block_desc; + return k_lds_block_desc; + } + else + { + static_assert(kKVector % kKPack == 0); + + constexpr index_t KSingleSmemElementSpaceSize = + kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector; + + static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize()); + + constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize(); + + constexpr auto k_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(number{}, + number{}, + number{}, + number{}, + number{}), + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple(make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, + number{}, + number{}))), + make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return k_lds_block_desc; + }; } template