From 2546e905ce75543392d485ef7a5f660ae9a0152c Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Sat, 19 Apr 2025 15:52:51 +0000 Subject: [PATCH] Change gemm0 to iterate along kN0 so that BlockGemm can overlap with maksing and siLu --- .../hstu_attention_fwd_kernel.hpp | 47 ++- .../hstu_attention_fwd_pipeline.hpp | 277 ++++++++---------- ..._attention_fwd_pipeline_default_policy.hpp | 16 +- .../hstu_attention_fwd_setting.hpp | 2 +- 4 files changed, 155 insertions(+), 187 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index 792d9ed44e..8bd68d7fd3 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -579,20 +579,10 @@ struct HstuAttentionFwdKernel make_tuple(kargs.seq_stride_q, 1), number{}, number<1>{}); - if constexpr(HstuAttentionPipeline::kQLoadOnce) - { - return pad_tensor_view(q_dram_naive, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(q_dram_naive, - make_tuple(number{}, - number{}), - sequence{}); - } + return pad_tensor_view(q_dram_naive, + make_tuple(number{}, + number{}), + sequence{}); }(); const auto k_dram = [&]() { const auto k_dram_naive = make_naive_tensor_view( @@ -604,7 +594,7 @@ struct HstuAttentionFwdKernel return pad_tensor_view(k_dram_naive, make_tuple(number{}, - number{}), + number{}), sequence{}); }(); const auto v_dram = [&]() { @@ -645,22 +635,19 @@ struct HstuAttentionFwdKernel } }(); - auto q_dram_window = make_tile_window( - q_dram, - [&]() { - if constexpr(HstuAttentionPipeline::kQLoadOnce) - return make_tuple(number{}, - number{}); - else - return make_tuple(number{}, - number{}); - }(), - {i_m0, 0}); + auto q_dram_window = + make_tile_window(q_dram, + [&]() { + return make_tuple(number{}, + number{}); + }(), + {i_m0, 0}); - auto k_dram_window = make_tile_window( - k_dram, - make_tuple(number{}, number{}), - {0, 0}); + auto k_dram_window = + make_tile_window(k_dram, + make_tuple(number{}, + number{}), + {0, 0}); auto v_dram_window = make_tile_window( v_dram, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index c458054010..315379365e 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -148,7 +148,7 @@ struct HstuAttentionFwdPipelineQRKSVS static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && @@ -157,9 +157,7 @@ struct HstuAttentionFwdPipelineQRKSVS constexpr auto I0 = number<0>{}; - constexpr index_t k0_loops = kQKHeaddim / kK0; constexpr index_t k1_loops = kN0 / kK1; - static_assert(2 <= k0_loops); static_assert(2 <= k1_loops); constexpr auto NumKLdsBuffers = Policy::template GetNumKLdsBuffers(); @@ -178,19 +176,14 @@ struct HstuAttentionFwdPipelineQRKSVS const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); - auto k_dram_block_window = - make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), - k_dram_block_window_tmp.get_window_lengths(), - {seqlen_k_start, 0}); - auto k_dram_window = - make_tile_window(k_dram_block_window.get_bottom_tensor_view(), - k_dram_block_window.get_window_lengths(), - k_dram_block_window.get_window_origin(), + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {seqlen_k_start, 0}, Policy::template MakeKDramTileDistribution()); auto k_tile = load_tile(k_dram_window); - move_tile_window(k_dram_window, {0, kK0}); + move_tile_window(k_dram_window, {kK1, 0}); auto q_tile = load_tile(q_dram_window); @@ -204,13 +197,14 @@ struct HstuAttentionFwdPipelineQRKSVS k_lds, Policy::template MakeKLdsBlockDescriptor().get_lengths(), {0, 0}); using k_lds_window_type = - decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence{})); + decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence{})); statically_indexed_array k_lds_windows; static_for<0, NumKLdsBuffers, 1>{}([&](auto i_buf) { - k_lds_windows[i_buf] = get_slice_tile( - k_lds_window, sequence{}, sequence<(i_buf + 1) * kN0, kK0>{}); + k_lds_windows[i_buf] = get_slice_tile(k_lds_window, + sequence{}, + sequence<(i_buf + 1) * kK1, kQKHeaddim>{}); }); auto v_dram_window = @@ -243,8 +237,11 @@ struct HstuAttentionFwdPipelineQRKSVS constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); - using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); - auto s_acc = SaccBlockTileType{}; + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + using PcompBlockTileType = decltype(cast_tile(SaccBlockTileType{})); + + statically_indexed_array sacc_tiles; + statically_indexed_array pcomp_tiles; // reduction function for softmax const auto f_silu = [](CompDataType& x) { @@ -274,7 +271,7 @@ struct HstuAttentionFwdPipelineQRKSVS const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), - bias_dram_block_window_tmp.get_window_lengths(), + make_tuple(number{}, number{}), {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N Policy::template MakeBiasDramTileDistribution()); @@ -303,105 +300,98 @@ struct HstuAttentionFwdPipelineQRKSVS q_tile = tile_elementwise_in(q_element_func, q_tile); + auto seqlen_k_curr = seqlen_k_start; + index_t i_loop = 0; do { - static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { - store_tile(k_lds_windows[number{}], + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + store_tile(k_lds_windows[number{}], tile_elementwise_in(k_element_func, k_tile)); - if constexpr(i_k0 == 0) - clear_tile(s_acc); - k_tile = load_tile(k_dram_window); - if constexpr(i_k0 < k0_loops - 2) - move_tile_window(k_dram_window, {0, kK0}); + clear_tile(sacc_tiles[i_k1]); + + if constexpr(i_k1 < k1_loops - 1) + { + k_tile = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kK1, 0}); + } + else + { + static_for<0, NumPrefetchV, 1>{}([&](auto i_buf) { + v_tiles[i_buf] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + }); + }; block_sync_lds(); // execute current unroll of gemm_0 - gemm_0(s_acc, - get_slice_tile( - q_tile, sequence<0, i_k0 * kK0>{}, sequence{}), - k_lds_windows[number{}]); + gemm_0(sacc_tiles[i_k1], q_tile, k_lds_windows[number{}]); + + sacc_tiles[i_k1] = tile_elementwise_in(s_acc_element_func, sacc_tiles[i_k1]); + + // STAGE 2, scale_s, add bias, mask, siLU + if constexpr(kHasBias) + { + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + + tile_elementwise_inout( + [&scale_s, &bias_element_func](auto& x, const auto& y) { + x = x * scale_s + type_convert(bias_element_func(y)); + }, + sacc_tiles[i_k1], + bias_tile); + + move_tile_window(bias_dram_window, {0, kK1}); + } + else + { + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, + sacc_tiles[i_k1]); + } + + if constexpr(HstuMask::IsMasking) + { + set_tile_if( + sacc_tiles[i_k1], type_convert(0), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + return !mask.IsTokenPairInsideMask(row, col); + }); + } + else if constexpr(kPadSeqLenK) + { + set_tile_if( + sacc_tiles[i_k1], type_convert(0), [&](auto tile_idx) { + if(q_origin.at(number<0>{}) + kM0 <= mask.max_uih_len && + i_loop < num_loops - 1) + return false; + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + return !mask.IsTokenPairInsideMask(row, col); + }); + } + + pcomp_tiles[i_k1] = cast_tile(sacc_tiles[i_k1]); + + tile_elementwise_inout(f_silu, pcomp_tiles[i_k1]); + + if constexpr(kHasDropout) + { + auto randval_lds_ptr = reinterpret_cast(smem_ptr) + + Policy::template GetSmemSizeKV(); + + dropout.template Run( + randval_lds_ptr, seqlen_k_curr, pcomp_tiles[i_k1], null_randval_window); + } + + seqlen_k_curr += kK1; }); - store_tile(k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}], - tile_elementwise_in(k_element_func, k_tile)); - - // prefetch first v_tile - v_tiles[I0] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); - - block_sync_lds(); - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, (k0_loops - 1) * kK0>{}, - sequence{}), - k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]); - - __builtin_amdgcn_sched_barrier(0); - - const auto bias_tile = load_tile(bias_dram_window); // load bias tile - - static_for<1, NumPrefetchV, 1>{}([&](auto i_buf) { - v_tiles[i_buf] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); - }); - - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); - - // STAGE 2, scale_s, add bias, mask, siLU - if constexpr(kHasBias) - { - tile_elementwise_inout( - [&scale_s, &bias_element_func](auto& x, const auto& y) { - x = x * scale_s + type_convert(bias_element_func(y)); - }, - s_acc, - bias_tile); - - move_tile_window(bias_dram_window, {0, kN0}); - } - else - { - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); - } - - if constexpr(HstuMask::IsMasking) - { - const auto k_origin = k_dram_block_window.get_window_origin(); - set_tile_if(s_acc, type_convert(0), [&](auto tile_idx) { - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return !mask.IsTokenPairInsideMask(row, col); - }); - } - else if constexpr(kPadSeqLenK) - { - const auto k_origin = k_dram_block_window.get_window_origin(); - set_tile_if(s_acc, type_convert(0), [&](auto tile_idx) { - if(q_origin.at(number<0>{}) + kM0 <= mask.max_uih_len && i_loop < num_loops - 1) - return false; - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return !mask.IsTokenPairInsideMask(row, col); - }); - } - - auto s = cast_tile(s_acc); - - tile_elementwise_inout(f_silu, s); - - if constexpr(kHasDropout) - { - auto randval_lds_ptr = - reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); - - dropout.template Run( - randval_lds_ptr, seqlen_k_start + i_loop * kN0, s, null_randval_window); - } - - __builtin_amdgcn_sched_barrier(0x7f); + // load one k_tile for next iteration + k_tile = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kK1, 0}); if constexpr(std::is_same_v) { @@ -426,59 +416,50 @@ struct HstuAttentionFwdPipelineQRKSVS store_tile(v_lds_windows[I0], tile_elementwise_in(v_element_func, v_tiles[I0])); // store the prefetch - } + }; - const auto p = [&]() { - if constexpr(std::is_same_v) - return impl::cast_tile_pk_fp16_fp32( - tile_elementwise_in(p_compute_element_func, s)); - else - return cast_tile(tile_elementwise_in(p_compute_element_func, s)); - }(); + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + const auto p = [&]() { + if constexpr(std::is_same_v) + return impl::cast_tile_pk_fp16_fp32( + tile_elementwise_in(p_compute_element_func, pcomp_tiles[i_k1])); + else + return cast_tile( + tile_elementwise_in(p_compute_element_func, pcomp_tiles[i_k1])); + }(); - move_tile_window(k_dram_window, {kN0, -(k0_loops - 1) * kK0}); - k_tile = load_tile(k_dram_window); - move_tile_window(k_dram_window, {0, kK0}); - - __builtin_amdgcn_sched_barrier(0); - - static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { if constexpr(i_k1 < k1_loops - NumPrefetchV) + { v_tiles[number{}] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + }; block_sync_lds(); - gemm_1( - o_acc, - get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence{}), - v_lds_windows[number{}]); + gemm_1(o_acc, p, v_lds_windows[number{}]); - if constexpr(std::is_same_v) + if constexpr(i_k1 < k1_loops - 1) { - auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_tile(v_shuffle_tmp, v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]); - store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], - tile_elementwise_in(v_element_func, v_shuffle_tmp)); - } - else - { - store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], - tile_elementwise_in(v_element_func, - v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}])); - } + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]); - if constexpr(i_k1 < k1_loops - NumPrefetchV) - move_tile_window(v_dram_window, {0, kK1}); + store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], + tile_elementwise_in(v_element_func, + v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], + tile_elementwise_in( + v_element_func, + v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}])); // store the + // prefetch + } + }; }); - // move K tile windows - move_tile_window(k_dram_block_window, {kN0, 0}); - - block_sync_lds(); - gemm_1(o_acc, - get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), - v_lds_windows[number<(k1_loops - 1) % NumVLdsBuffers>{}]); - // the over-lap only occurs when k1_loops is 3/5/7, NumVLdsBuffers is 2 if constexpr(Policy::template IsFirstKLdsBufferOverlapLastVLdsBuffer()) __builtin_amdgcn_s_barrier(); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp index c4740a7b79..b702763bd0 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -61,8 +61,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize() { - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kKPack = GetSmemKPackK(); constexpr index_t kKVector = GetAlignmentK(); @@ -100,8 +100,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() { constexpr index_t NumKLdsBuffers = GetNumKLdsBuffers(); - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kKPack = GetSmemKPackK(); constexpr index_t kKVector = GetAlignmentK(); @@ -147,8 +147,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy using QKVDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType); @@ -300,8 +300,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy typename Problem::GemmAccDataType, Problem::kNumGemm0Warps * get_warp_size(), TileGemmShape, + Problem::BlockFmhaShape::kK1, + Problem::BlockFmhaShape::kQKHeaddim>, typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename Problem::BlockFmhaShape::Gemm0WarpTile>>; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp index 5d82a3680c..c842df341b 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp @@ -35,7 +35,7 @@ struct HstuAttentionFwdBlockTile<64> template <> struct HstuAttentionFwdBlockTile<128> { - using type = ck_tile::sequence<128, 64, 32, 128, 32, 128>; + using type = ck_tile::sequence<128, 128, 32, 128, 32, 128>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; };