From d4f43f0653d23e27b95047c9749770db0dc2540f Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Thu, 21 Aug 2025 13:59:32 +0000 Subject: [PATCH] Use xor transform to implement Q/K Lds descriptor for kKpack == 8 cases --- ..._attention_fwd_pipeline_default_policy.hpp | 93 +++++++++++++++---- 1 file changed, 73 insertions(+), 20 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp index d26913deaa..d1169f7973 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -178,7 +178,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy { static_assert(kKVector == kKPack); - return kKPerBlock * kNPerBlock + kKPerBlock; + return kKPerBlock * kNPerBlock; } else { @@ -220,18 +220,47 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy { static_assert(kKVector == kKPack); + using QDataType = remove_cvref_t; + + constexpr index_t DataTypeSize = sizeof(QDataType); + + // 128 contiguous bytes mapped to 32 banks with each bank 4 contiguous bytes + constexpr auto MLdsLayer = + (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); + constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number{}), - make_tuple(number{}, number{}, number<1>{}), + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), number{}, number<1>{}); - constexpr auto q_lds_block_desc = transform_tensor_descriptor( + constexpr auto q_lds_block_desc_permuted = transform_tensor_descriptor( q_lds_block_desc_0, - make_tuple(make_pass_through_transform(number{}), - make_merge_transform( + make_tuple( + make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + constexpr auto q_lds_block_desc_k0_mldslayer_m_k1 = transform_tensor_descriptor( + q_lds_block_desc_permuted, + make_tuple(make_pass_through_transform(number{}), + make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + constexpr auto q_lds_block_desc = transform_tensor_descriptor( + q_lds_block_desc_k0_mldslayer_m_k1, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod( make_tuple(number{}, number{}))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0, 2>{}, sequence<1, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); return q_lds_block_desc; @@ -310,31 +339,55 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy { static_assert(kKVector == kKPack); - constexpr index_t KSingleSmemElementSpaceSize = kKPerBlock * kNPerBlock + kKPerBlock; + using KDataType = remove_cvref_t; - static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize()); + constexpr index_t DataTypeSize = sizeof(KDataType); - constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize(); + // 128 contiguous bytes mapped to 32 banks with each bank 4 contiguous bytes + constexpr auto NLdsLayer = + (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(make_tuple(number{}, - number{}, - number{}, + number{}, + number{}, number{}), - make_tuple(number{}, - number{}, + make_tuple(number{}, + number{}, number{}, number<1>{}), number{}, number<1>{}); - constexpr auto k_lds_block_desc = transform_tensor_descriptor( + constexpr auto k_lds_block_desc_permuted = transform_tensor_descriptor( k_lds_block_desc_0, - make_tuple(make_merge_transform( - make_tuple(number{}, number{})), - make_merge_transform( - make_tuple(number{}, number{}))), - make_tuple(sequence<0, 2>{}, sequence<1, 3>{}), + make_tuple( + make_pass_through_transform(number{}), + make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + constexpr auto k_lds_block_desc_k0_nldslayer_n_k1 = transform_tensor_descriptor( + k_lds_block_desc_permuted, + make_tuple(make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{})); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_k0_nldslayer_n_k1, + make_tuple( + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod(make_tuple(number{}, + number{}, + number{}))), + make_tuple(sequence<1, 3>{}, sequence<0, 2, 4>{}), make_tuple(sequence<0>{}, sequence<1>{})); return k_lds_block_desc;