From 4632d30cc0c98e38742fa591ce9883a57cfddc30 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Sun, 8 Jun 2025 11:22:21 +0000 Subject: [PATCH] Improve the VDramTileDistribution and VLds layout for better device loading and reduce bank-conflict --- .../hstu_attention_fwd_pipeline.hpp | 2 +- ..._attention_fwd_pipeline_default_policy.hpp | 347 ++++++++++++------ 2 files changed, 237 insertions(+), 112 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index dacdea67e1..014929bd08 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -480,7 +480,7 @@ struct HstuAttentionFwdPipelineQRKSVS if constexpr(std::is_same_v) { auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegBlockDescriptor()); + Policy::template MakeShuffledVRegTileDistribution()); shuffle_tile(v_shuffle_tmp, v_tile); // if K in this unroll uses Lds-buffer i, then V in this uroll uses Lds-buffer diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp index a8e0c3b3e4..feb458673f 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -53,6 +53,16 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy return WG::WarpGemmAttribute::kKPerThread; }; + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetKVWarpGemmKPerThreadSize() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + return WG::WarpGemmAttribute::kKPerThread; + }; + template CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ() { @@ -104,15 +114,10 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV() { - using VDataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; - - constexpr index_t MaxVectorSize = 16 / sizeof(VDataType); - return min(ElemPerThread, MaxVectorSize); + if constexpr(GetKVWarpGemmKPerThreadSize() >= 8) + return 8; + else + return 4; } template @@ -121,12 +126,15 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy using VLayout = remove_cvref_t; using VDataType = remove_cvref_t; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + + // Need special consideration for RowMajor since shuffling is needed to write LDS in dwords if constexpr(std::is_same_v) { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; constexpr index_t MaxVectorSize = 16 / sizeof(VDataType); constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); constexpr index_t kMinVecLoad = 4 / sizeof(VDataType); @@ -137,13 +145,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy return kVecLoad; } - else + else // Similar to GetAlignmentK() { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; - constexpr index_t MaxVectorSize = 16 / sizeof(VDataType); return min(ElemPerThread, MaxVectorSize); } @@ -174,19 +177,38 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetVSingleSmemElementSpaceSize() { - using QKVDataType = remove_cvref_t; + using VLayout = remove_cvref_t; - constexpr index_t Banks = 32; // TODO: need change based on arch - constexpr index_t PixelsPerRow = Banks * 4 / sizeof(QKVDataType); - constexpr index_t kKPack = GetSmemKPackV(); - static_assert(PixelsPerRow % kKPack == 0); - constexpr index_t NPerRow = PixelsPerRow / kKPack; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - static_assert(kNPerBlock % NPerRow == 0); - static_assert(kKPerBlock % kKPack == 0); - return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack); + // Need special consideration for RowMajor since shuffling is needed to write LDS in dwords + if constexpr(std::is_same_v) + { + constexpr index_t N1 = GetAlignmentV(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t kKPack = GetKVWarpGemmKPerThreadSize(); + + return N0 * (N1 * kKPerBlock + kKPack); + } + else // similar to GetKSingleSmemElementSpaceSize() + { + constexpr index_t kKPack = GetSmemKPackV(); + constexpr index_t kKVector = GetAlignmentV(); + + if constexpr(GetKVWarpGemmKPerThreadSize() >= 8) + { + static_assert(kKVector == kKPack); + + return kKPerBlock * kNPerBlock + kKPerBlock; + } + else + { + static_assert(kKVector % kKPack == 0); + + return kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector; + }; + }; }; template @@ -376,9 +398,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType); - constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; - static_assert(0 < ElemPerThread); + constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); constexpr index_t KPerThread = kMaxVecLoad; @@ -400,51 +421,136 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() { - using QKVDataType = remove_cvref_t; + using VLayout = remove_cvref_t; constexpr index_t NumVLdsBuffers = GetNumKVLdsBuffers(); + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t Banks = 32; // TODO: need change based on arch - constexpr index_t PixelsPerRow = Banks * 4 / sizeof(QKVDataType); - constexpr index_t kKPack = GetSmemKPackV(); - static_assert(PixelsPerRow % kKPack == 0); - constexpr index_t NPerRow = PixelsPerRow / kKPack; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - static_assert(kNPerBlock % NPerRow == 0); - static_assert(kKPerBlock % kKPack == 0); + // Need special consideration for RowMajor since shuffling is needed to write LDS in dwords + if constexpr(std::is_same_v) + { + constexpr index_t N1 = GetAlignmentV(); + constexpr index_t N0 = kNPerBlock / N1; - constexpr index_t VSingleSmemElementSpaceSize = - (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack); + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; - static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize()); + // K2 is the vector size for storing shuffled tile to LDS + constexpr index_t K2 = ElemPerThread / N1; - constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize(); + // GetSmemKPackV() is the vector size for loading from LDS by BlockGemm + constexpr index_t kKPack = GetSmemKPackV(); - constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}, - number{}, - number{}), - make_tuple(number{}, - number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{}, - number{}, - number{}, - number<1>{}), - number{}, - number<1>{}); + static_assert(kKPack >= K2, "Check failed!"); - constexpr auto v_lds_block_desc = transform_tensor_descriptor( - v_lds_block_desc_0, - make_tuple( - make_merge_transform(make_tuple( - number{}, number{}, number{})), - make_merge_transform(make_tuple(number{}, number{}))), - make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + constexpr index_t VSingleSmemElementSpaceSize = N0 * (N1 * kKPerBlock + kKPack); - return v_lds_block_desc; + static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize()); + + constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize(); + + constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple( + number{}, number{}, number{}, number{}), + make_tuple(number{}, + number{}, + number{}, + number<1>{}), + number<8>{}, + number<1>{}); + + constexpr auto v_lds_block_desc = transform_tensor_descriptor( + v_lds_block_desc_0, + make_tuple(make_merge_transform( + make_tuple(number{}, number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return v_lds_block_desc; + } + else // Similar to MakeKLdsBlockDescriptor() + { + constexpr index_t kKPack = GetSmemKPackV(); + constexpr index_t kKVector = GetAlignmentV(); + + if constexpr(GetKVWarpGemmKPerThreadSize() >= 8) + { + static_assert(kKVector == kKPack); + + constexpr index_t VSingleSmemElementSpaceSize = + kKPerBlock * kNPerBlock + kKPerBlock; + + static_assert(VSingleSmemElementSpaceSize == + GetVSingleSmemElementSpaceSize()); + + constexpr index_t SingleSmemElementSpaceSize = + GetSingleSmemElementSpaceSize(); + + constexpr auto v_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(number{}, + number{}, + number{}, + number{}), + make_tuple(number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto v_lds_block_desc = transform_tensor_descriptor( + v_lds_block_desc_0, + make_tuple(make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform( + make_tuple(number{}, number{}))), + make_tuple(sequence<0, 2>{}, sequence<1, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return v_lds_block_desc; + } + else + { + static_assert(kKVector % kKPack == 0); + + constexpr index_t VSingleSmemElementSpaceSize = + kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector; + + static_assert(VSingleSmemElementSpaceSize == + GetVSingleSmemElementSpaceSize()); + + constexpr index_t SingleSmemElementSpaceSize = + GetSingleSmemElementSpaceSize(); + + constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}), + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto v_lds_block_desc = transform_tensor_descriptor( + v_lds_block_desc_0, + make_tuple(make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, + number{}, + number{}))), + make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return v_lds_block_desc; + }; + } } template @@ -456,66 +562,85 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + // Need special consideration for RowMajor since shuffling is needed to write LDS in dwords if constexpr(std::is_same_v) { constexpr index_t N1 = GetAlignmentV(); - constexpr index_t N0 = kNPerBlock / N1; // P + constexpr index_t N0 = kNPerBlock / N1; constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(ElemPerThread % N1 == 0); - constexpr index_t K3 = ElemPerThread / N1; - constexpr index_t kKPack = GetSmemKPackV(); - static_assert(kKPack % K3 == 0); - constexpr index_t K2 = kKPack / K3; - if constexpr(get_warp_size() % (K2 * N0) == 0) - { - constexpr index_t K1 = get_warp_size() / (K2 * N0); - constexpr index_t K0 = kBlockSize / get_warp_size(); - static_assert(kKPerBlock == K0 * K1 * K2 * K3); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1, 2>>, - tuple, sequence<1, 0, 2>>, - sequence<2, 1>, - sequence<3, 1>>{}); - } - else - { - constexpr index_t K1 = (K2 * N0) / get_warp_size(); - constexpr index_t K2_m = K2 / K1; - constexpr index_t K0 = kBlockSize / get_warp_size() / K1; - static_assert(kKPerBlock == K0 * K1 * K2_m * K3); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<0, 2>>, - sequence<2, 1>, - sequence<3, 1>>{}); - } - } - else - { - constexpr index_t K1 = GetAlignmentV(); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; - constexpr index_t N1 = kBlockSize / get_warp_size(); - static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error."); - static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error."); - constexpr index_t N0 = kNPerBlock / (N2 * N1); - static_assert(N0 != 0); + + constexpr index_t K2 = ElemPerThread / N1; + constexpr index_t K1 = get_warp_size() / N0; + constexpr index_t K0 = kBlockSize / get_warp_size(); return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 0>>, + sequence<2, 1>, + sequence<2, 1>>{}); + } + else // Similar to MakeKDramTileDistribution() + { + using QKVDataType = remove_cvref_t; + + constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType); + constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; + + constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); + + constexpr index_t KPerThread = kMaxVecLoad; + constexpr index_t KThreads = kKPerBlock / KPerThread; + constexpr index_t NThreadPerWarp = get_warp_size() / KThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, + tuple, sequence<1, 0>>, sequence<1, 2>, sequence<0, 1>>{}); } } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledVRegTileDistribution() + { + // This tile-distribuiton only used when V layout is RowMajor + using VLayout = remove_cvref_t; + static_assert(std::is_same_v); + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr index_t N1 = GetAlignmentV(); + constexpr index_t N0 = kNPerBlock / N1; + + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + + static_assert(ElemPerThread % N1 == 0); + + constexpr index_t K2 = ElemPerThread / N1; + constexpr index_t K1 = get_warp_size() / N0; + constexpr index_t K0 = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<1, 2>>{}); + } + template CK_TILE_HOST_DEVICE static constexpr index_t GetQKBlockGemmSingleRepM() {