From f53be61a746e320bb8765ff1f50bb9a48c6188c5 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Mon, 28 Apr 2025 06:41:37 +0000 Subject: [PATCH] Put two gemms call inside one n0loop unroll --- .../hstu_attention_fwd_pipeline.hpp | 167 +++++------------- ..._attention_fwd_pipeline_default_policy.hpp | 51 +----- 2 files changed, 54 insertions(+), 164 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index 1d6504475a..ecbadfe1ac 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -155,17 +155,10 @@ struct HstuAttentionFwdPipelineQRKSVS kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); - constexpr auto I0 = number<0>{}; - constexpr index_t k1_loops = kN0 / kK1; static_assert(2 <= k1_loops); - constexpr auto NumKLdsBuffers = Policy::template GetNumKLdsBuffers(); - constexpr auto NumVLdsBuffers = Policy::template GetNumVLdsBuffers(); - constexpr auto NumPrefetchV = Policy::template GetNumPrefetchV(); - - static_assert(NumKLdsBuffers >= 2); - static_assert(NumPrefetchV >= 2); + constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers(); auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), q_dram_block_window_tmp.get_window_lengths(), @@ -199,9 +192,9 @@ struct HstuAttentionFwdPipelineQRKSVS using k_lds_window_type = decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence{})); - statically_indexed_array k_lds_windows; + statically_indexed_array k_lds_windows; - static_for<0, NumKLdsBuffers, 1>{}([&](auto i_buf) { + static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { k_lds_windows[i_buf] = get_slice_tile(k_lds_window, sequence{}, sequence<(i_buf + 1) * kK1, kQKHeaddim>{}); @@ -219,16 +212,12 @@ struct HstuAttentionFwdPipelineQRKSVS auto v_lds_window = make_tile_window( v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); - using v_tile_type = decltype(load_tile(v_dram_window)); - - statically_indexed_array v_tiles; - using v_lds_window_type = decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence{})); - statically_indexed_array v_lds_windows; + statically_indexed_array v_lds_windows; - static_for<0, NumVLdsBuffers, 1>{}([&](auto i_buf) { + static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { v_lds_windows[i_buf] = get_slice_tile( v_lds_window, sequence{}, sequence<(i_buf + 1) * kN1, kK1>{}); }); @@ -307,6 +296,9 @@ struct HstuAttentionFwdPipelineQRKSVS q_tile = tile_elementwise_in(q_element_func, q_tile); + auto v_tile = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + auto seqlen_k_curr = seqlen_k_start; index_t i_loop = 0; @@ -314,25 +306,18 @@ struct HstuAttentionFwdPipelineQRKSVS do { static_for<0, k1_loops, 1>{}([&](auto i_k1) { - store_tile(k_lds_windows[number{}], + store_tile(k_lds_windows[number{}], tile_elementwise_in(k_element_func, k_tile)); - if constexpr(i_k1 < k1_loops - 1) - { - k_tile = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); - } - else - { - v_tiles[I0] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); - }; + // for i_k1 = k1_loop-1, the loading is for next iteration + k_tile = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kK1, 0}); clear_tile(sacc_tiles[i_k1]); block_sync_lds(); // execute current unroll of gemm_0 - gemm_0(sacc_tiles[i_k1], q_tile, k_lds_windows[number{}]); + gemm_0(sacc_tiles[i_k1], q_tile, k_lds_windows[number{}]); sacc_tiles[i_k1] = tile_elementwise_in(s_acc_element_func, sacc_tiles[i_k1]); @@ -426,95 +411,33 @@ struct HstuAttentionFwdPipelineQRKSVS pcomp_tiles[i_k1] = cast_tile(sacc_tiles[i_k1]); - seqlen_k_curr += kK1; - }); - - static_for<1, NumPrefetchV, 1>{}([&](auto i_buf) { - v_tiles[i_buf] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); - }); - - if constexpr(std::is_same_v) - { - auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_tile(v_shuffle_tmp, v_tiles[I0]); - - // ensure gemm_0 has finished access of k-Lds for all warps - // the over-lap only occurs when k0_loops is 3/5/7, NumKLdsBuffers is 2 - if constexpr(Policy::template IsFirstVLdsBufferOverlapLastKLdsBuffer()) - __builtin_amdgcn_s_barrier(); - - store_tile( - v_lds_windows[I0], - tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch - } - else - { - // ensure gemm_0 has finished access of k-Lds for all warps - if constexpr(Policy::template IsFirstVLdsBufferOverlapLastKLdsBuffer()) - __builtin_amdgcn_s_barrier(); - - store_tile(v_lds_windows[I0], - tile_elementwise_in(v_element_func, v_tiles[I0])); // store the prefetch - }; - - tile_elementwise_inout(f_silu, pcomp_tiles[I0]); - - if constexpr(kHasDropout) - { - auto randval_lds_ptr = - reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); - - dropout.template Run( - randval_lds_ptr, seqlen_k_curr - kN0, pcomp_tiles[I0], null_randval_window); - } - - auto p = [&]() { - if constexpr(std::is_same_v) - return impl::cast_tile_pk_fp16_fp32( - tile_elementwise_in(p_compute_element_func, pcomp_tiles[I0])); - else - return cast_tile( - tile_elementwise_in(p_compute_element_func, pcomp_tiles[I0])); - }(); - - static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { - if constexpr(i_k1 < k1_loops - NumPrefetchV) - { - v_tiles[number{}] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); - } - else if constexpr(i_k1 == k1_loops - NumPrefetchV) - { - // load one k_tile for next iteration - k_tile = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); - }; - - block_sync_lds(); - - gemm_1(o_acc, p, v_lds_windows[number{}]); - tile_elementwise_inout(f_silu, pcomp_tiles[number{}]); - if constexpr(std::is_same_v) { auto v_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_tile(v_shuffle_tmp, v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]); + shuffle_tile(v_shuffle_tmp, v_tile); - store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], - tile_elementwise_in(v_element_func, - v_shuffle_tmp)); // store the prefetch + // if K in this unroll uses Lds-buffer i, then V in this uroll uses Lds-buffer + // i+1, No overlap occurs between V and K in the same unroll, and V in current + // unroll and K in next unroll or first unrool in next iteration + store_tile( + v_lds_windows[number<(i_k1 + 1) % NumKVLdsBuffers>{}], + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch } else { - store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], - tile_elementwise_in( - v_element_func, - v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}])); // store the - // prefetch - } + // if K in this unroll uses Lds-buffer i, then V in this uroll uses Lds-buffer + // i+1, No overlap occurs between V and K in the same unroll, and V in current + // unroll and K in next unroll or first unrool in next iteration + store_tile(v_lds_windows[number<(i_k1 + 1) % NumKVLdsBuffers>{}], + tile_elementwise_in(v_element_func, v_tile)); // store the prefetch + }; + + // for i_k1 = k1_loops-1, the loading is for next iteration + v_tile = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + + tile_elementwise_inout(f_silu, pcomp_tiles[i_k1]); if constexpr(kHasDropout) { @@ -522,26 +445,26 @@ struct HstuAttentionFwdPipelineQRKSVS Policy::template GetSmemSizeKV(); dropout.template Run( - randval_lds_ptr, - seqlen_k_curr - kN0 + (i_k1 + 1) * kK1, - pcomp_tiles[number{}], - null_randval_window); + randval_lds_ptr, seqlen_k_curr, pcomp_tiles[i_k1], null_randval_window); } - p = [&]() { + auto p = [&]() { if constexpr(std::is_same_v) - return impl::cast_tile_pk_fp16_fp32(tile_elementwise_in( - p_compute_element_func, pcomp_tiles[number{}])); + return impl::cast_tile_pk_fp16_fp32( + tile_elementwise_in(p_compute_element_func, pcomp_tiles[i_k1])); else - return cast_tile(tile_elementwise_in( - p_compute_element_func, pcomp_tiles[number{}])); + return cast_tile( + tile_elementwise_in(p_compute_element_func, pcomp_tiles[i_k1])); }(); + + block_sync_lds(); + + gemm_1(o_acc, p, v_lds_windows[number<(i_k1 + 1) % NumKVLdsBuffers>{}]); + + seqlen_k_curr += kK1; }); - block_sync_lds(); - gemm_1(o_acc, p, v_lds_windows[number<(k1_loops - 1) % NumVLdsBuffers>{}]); - - // the over-lap only occurs when k1_loops is 3/5/7, NumVLdsBuffers is 2 + // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 3 if constexpr(Policy::template IsFirstKLdsBufferOverlapLastVLdsBuffer()) __builtin_amdgcn_s_barrier(); } while(++i_loop < num_loops); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp index df1a5cb9c8..0794699c5a 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -12,35 +12,14 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy : BlockFmhaPipelineQXKSVSCustomPolicy + /* NumPrefetchV = */ 1> { - static constexpr index_t NumPrefetchV = 2; - template - CK_TILE_DEVICE static constexpr auto GetNumKLdsBuffers() + CK_TILE_DEVICE static constexpr auto GetNumKVLdsBuffers() { - return 2; + return 3; } - template - CK_TILE_DEVICE static constexpr auto GetNumPrefetchV() - { - using BlockFmhaShape = remove_cvref_t; - - constexpr index_t kN0 = BlockFmhaShape::kN0; - constexpr index_t kK1 = BlockFmhaShape::kK1; - - constexpr index_t k1_loops = kN0 / kK1; - - return min(NumPrefetchV, k1_loops); - } - - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetNumVLdsBuffers() - { - return 2; - }; - template CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution() { @@ -120,7 +99,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() { - constexpr index_t NumKLdsBuffers = GetNumKLdsBuffers(); + constexpr index_t NumKLdsBuffers = GetNumKVLdsBuffers(); constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kKPack = GetSmemKPackK(); @@ -234,7 +213,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy { using QKVDataType = remove_cvref_t; - constexpr index_t NumVLdsBuffers = GetNumVLdsBuffers(); + constexpr index_t NumVLdsBuffers = GetNumKVLdsBuffers(); constexpr index_t Banks = 32; // TODO: need change based on arch constexpr index_t PixelsPerRow = Banks * 4 / sizeof(QKVDataType); @@ -422,33 +401,21 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy return BlockGemmARegBSmemCRegOneWarpV1{}; } - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t IsFirstVLdsBufferOverlapLastKLdsBuffer() - { - using BlockFmhaShape = remove_cvref_t; - - constexpr index_t k0_loops = BlockFmhaShape::kQKHeaddim / BlockFmhaShape::kK0; - constexpr index_t num_k_lds_buffers = GetNumKLdsBuffers(); - - return (k0_loops - 1) % num_k_lds_buffers == 0; - } - template CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t IsFirstKLdsBufferOverlapLastVLdsBuffer() { using BlockFmhaShape = remove_cvref_t; - constexpr index_t k1_loops = BlockFmhaShape::kN0 / BlockFmhaShape::kK1; - constexpr index_t num_v_lds_buffers = GetNumVLdsBuffers(); + constexpr index_t k1_loops = BlockFmhaShape::kN0 / BlockFmhaShape::kK1; + constexpr index_t num_kv_lds_buffers = GetNumKVLdsBuffers(); - return (k1_loops - 1) % num_v_lds_buffers == 0; + return (k1_loops - 1 + 1) % num_kv_lds_buffers == 0; }; template CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV() { - constexpr index_t num_kv_lds_buffers = - max(GetNumKLdsBuffers(), GetNumVLdsBuffers()); + constexpr index_t num_kv_lds_buffers = GetNumKVLdsBuffers(); return num_kv_lds_buffers * GetSingleSmemElementSpaceSize() * sizeof(typename Problem::QKVDataType);