From 88e54a8989d8c02ef1c68337cee01bb8405d347d Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Fri, 18 Apr 2025 09:47:43 +0000 Subject: [PATCH] Use shared ring Lds buffers for K/V to avoid overlapping between first-K/last-V or last-K/first-V --- .../hstu_attention_fwd_pipeline.hpp | 7 +- ..._attention_fwd_pipeline_default_policy.hpp | 145 +++++++++--------- 2 files changed, 75 insertions(+), 77 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index f92fdc6825..9f5a890002 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -220,8 +220,7 @@ struct HstuAttentionFwdPipelineQRKSVS Policy::template MakeVDramTileDistribution()); // V tile in LDS auto v_lds = make_tensor_view( - reinterpret_cast(static_cast(smem_ptr) + - Policy::template GetExclusiveKLdsBytes()), + reinterpret_cast(smem_ptr), Policy::template MakeVLdsBlockDescriptor()); auto v_lds_window = make_tile_window( v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); @@ -396,7 +395,7 @@ struct HstuAttentionFwdPipelineQRKSVS if constexpr(kHasDropout) { auto randval_lds_ptr = - reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeK(); + reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); dropout.template Run( randval_lds_ptr, seqlen_k_start + i_loop * kN0, s, null_randval_window); @@ -411,6 +410,7 @@ struct HstuAttentionFwdPipelineQRKSVS shuffle_tile(v_shuffle_tmp, v_tiles[I0]); // ensure gemm_0 has finished access of k-Lds for all warps + // the overlap only occurs when k0_loops is 3/5/7 and NumKLdsBuffers is 2 if constexpr(Policy::template IsFirstVLdsBufferOverlapLastKLdsBuffer()) __builtin_amdgcn_s_barrier(); @@ -479,6 +479,7 @@ struct HstuAttentionFwdPipelineQRKSVS get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), v_lds_windows[number<(k1_loops - 
1) % NumVLdsBuffers>{}]); + // the overlap only occurs when k1_loops is 3/5/7 and NumVLdsBuffers is 2 if constexpr(Policy::template IsFirstKLdsBufferOverlapLastVLdsBuffer()) { __builtin_amdgcn_sched_barrier(0); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp index 0b40e353f0..c4740a7b79 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -58,6 +58,44 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy return 8 / sizeof(QKVDataType); } + template + CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPack = GetSmemKPackK(); + constexpr index_t kKVector = GetAlignmentK(); + + static_assert(kKVector % kKPack == 0); + + return kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector; + }; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetVSingleSmemElementSpaceSize() + { + using QKVDataType = remove_cvref_t; + + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(QKVDataType); + constexpr index_t kKPack = GetSmemKPackV(); + static_assert(PixelsPerRow % kKPack == 0); + constexpr index_t NPerRow = PixelsPerRow / kKPack; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + static_assert(kNPerBlock % NPerRow == 0); + static_assert(kKPerBlock % kKPack == 0); + + return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack); + }; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSingleSmemElementSpaceSize() + { + return max(GetKSingleSmemElementSpaceSize(), + 
GetVSingleSmemElementSpaceSize()); + }; + template CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() { @@ -69,19 +107,26 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy static_assert(kKVector % kKPack == 0); - constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}, - number{}, - number{}), - make_tuple(number{}, - number{}, - number{}, - number{}, - number<1>{}), - number{}, - number<1>{}); + constexpr index_t KSingleSmemElementSpaceSize = + kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector; + + static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize()); + + constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize(); + + constexpr auto k_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(number{}, + number{}, + number{}, + number{}, + number{}), + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); constexpr auto k_lds_block_desc = transform_tensor_descriptor( k_lds_block_desc_0, @@ -147,13 +192,17 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy constexpr index_t VSingleSmemElementSpaceSize = (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack); + static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize()); + + constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize(); + constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}, number{}, number{}), - make_tuple(number{}, + make_tuple(number{}, number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{}, number{}, number{}, @@ -305,22 +354,6 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy return BlockGemmARegBSmemCRegOneWarpV1{}; } - // leave some exclusive space so that the second v_lds buffer will never overlap with the first - // k_lds bufffer - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t 
GetExclusiveKLdsBytes() - { - constexpr index_t single_k_lds_buffer_size = - GetSmemSizeK() / GetNumKLdsBuffers(); - constexpr index_t single_v_lds_buffer_size = - GetSmemSizeV() / GetNumVLdsBuffers(); - - if constexpr(single_k_lds_buffer_size <= single_v_lds_buffer_size) - return 0; - else - return integer_least_multiple(single_k_lds_buffer_size - single_v_lds_buffer_size, 64); - }; - template CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t IsFirstVLdsBufferOverlapLastKLdsBuffer() { @@ -328,26 +361,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy constexpr index_t k0_loops = BlockFmhaShape::kQKHeaddim / BlockFmhaShape::kK0; constexpr index_t num_k_lds_buffers = GetNumKLdsBuffers(); - constexpr index_t num_v_lds_buffers = GetNumVLdsBuffers(); - constexpr index_t last_k_lds_buffer_offset = - MakeKLdsBlockDescriptor().get_element_space_size() / num_k_lds_buffers * - ((k0_loops - 1) % num_k_lds_buffers) * sizeof(typename Problem::KDataType); - - constexpr index_t last_k_lds_buffer_end = - last_k_lds_buffer_offset + MakeKLdsBlockDescriptor().get_element_space_size() / - num_k_lds_buffers * sizeof(typename Problem::KDataType); - - constexpr index_t first_v_lds_buffer_size = - MakeVLdsBlockDescriptor().get_element_space_size() / num_v_lds_buffers * - sizeof(typename Problem::VDataType); - - constexpr index_t first_v_lds_buffer_offset = GetExclusiveKLdsBytes(); - constexpr index_t first_v_lds_buffer_end = - first_v_lds_buffer_offset + first_v_lds_buffer_size; - - return !((first_v_lds_buffer_offset >= last_k_lds_buffer_end) || - (first_v_lds_buffer_end <= last_k_lds_buffer_offset)); + return (k0_loops - 1) % num_k_lds_buffers == 0; } template @@ -356,43 +371,25 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy using BlockFmhaShape = remove_cvref_t; constexpr index_t k1_loops = BlockFmhaShape::kN0 / BlockFmhaShape::kK1; - constexpr index_t num_k_lds_buffers = GetNumKLdsBuffers(); constexpr index_t num_v_lds_buffers = GetNumVLdsBuffers(); - constexpr 
index_t last_v_lds_buffer_offset = - MakeVLdsBlockDescriptor().get_element_space_size() / num_v_lds_buffers * - ((k1_loops - 1) % num_v_lds_buffers) * sizeof(typename Problem::VDataType); - - constexpr index_t first_k_lds_buffer_size = - MakeKLdsBlockDescriptor().get_element_space_size() / num_k_lds_buffers * - sizeof(typename Problem::QKVDataType); - - return GetExclusiveKLdsBytes() + last_v_lds_buffer_offset < - first_k_lds_buffer_size; + return (k1_loops - 1) % num_v_lds_buffers == 0; }; template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK() + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV() { - return MakeKLdsBlockDescriptor().get_element_space_size() * - sizeof(typename Problem::QKVDataType); - } + constexpr index_t num_kv_lds_buffers = + max(GetNumKLdsBuffers(), GetNumVLdsBuffers()); - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV() - { - return MakeVLdsBlockDescriptor().get_element_space_size() * + return num_kv_lds_buffers * GetSingleSmemElementSpaceSize() * sizeof(typename Problem::QKVDataType); - } + }; template CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { - // assume V can reuse the other shared memory by K except the first - // assume Dropout can reuse the shared memory by V - return GetExclusiveKLdsBytes() + - max(GetSmemSizeK() - GetExclusiveKLdsBytes(), - max(GetSmemSizeV(), GetSmemSizeDropout(0))); + return GetSmemSizeKV() + GetSmemSizeDropout(0); } };