From ff3415d97dca63410e5848bbfa26e96bf69821b1 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Sun, 18 May 2025 15:02:36 +0000 Subject: [PATCH] Prefetch b_warp_tensor for next nIter and move b_warp_windows construction into n-iteration in block_gemm_areg_bsmem_creg for gemm-1 --- .../block_gemm_areg_bsmem_creg_v2_hack_1.hpp | 91 +++++++------------ 1 file changed, 31 insertions(+), 60 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_1.hpp b/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_1.hpp index a34fde90f4..0198704dbe 100644 --- a/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_1.hpp +++ b/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_1.hpp @@ -102,15 +102,6 @@ struct BlockGemmARegBSmemCRegV2Hack_1 statically_indexed_array, NIterPerWarp> b_warp_windows; - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_warp_windows(nIter)(kIter) = b_warp_window_tmp; - - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); - }); - }); #endif // check C-block-distribution @@ -134,63 +125,44 @@ struct BlockGemmARegBSmemCRegV2Hack_1 constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - // hot loop: - if constexpr(KIterPerWarp > 1) - { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + constexpr auto I0 = number<0>{}; + + using b_warp_tensor_type = decltype(load_tile(b_warp_windows(I0)(I0))); + + statically_indexed_array, + NIterPerWarp> + b_warp_tensors; + + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(I0)(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(I0)(kIter), + {0 * NPerBlockPerIter, kIter * KPerBlockPerIter}); + b_warp_tensors(I0)(kIter) = load_tile(b_warp_windows(I0)(kIter)); + }); + + __builtin_amdgcn_sched_barrier(0); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + if constexpr(nIter < NIterPerWarp - 1) + { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - // read B warp tensor from B Block window - const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); - - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block tensor - AWarpTensor a_warp_tensor; - - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); - - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); + b_warp_windows(number{})(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(number{})(kIter), + {(nIter + 1) * NPerBlockPerIter, kIter * KPerBlockPerIter}); + b_warp_tensors(number{})(kIter) = + load_tile(b_warp_windows(number{})(kIter)); }); - }); - } - else - { - constexpr auto I0 = number<0>{}; + }; - using b_warp_tensor_type = decltype(load_tile(b_warp_windows(I0)(I0))); - - statically_indexed_array b_warp_tensors; - - b_warp_tensors[I0] = load_tile(b_warp_windows(I0)(I0)); - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - if constexpr(nIter < NIterPerWarp - 1) - b_warp_tensors[number<(nIter + 1) % 2>{}] = - load_tile(b_warp_windows(number{})(I0)); + __builtin_amdgcn_sched_barrier(0); + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { // read A warp tensor from A block tensor AWarpTensor a_warp_tensor; a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); // read C warp tensor from C block tensor @@ -201,8 +173,7 @@ struct BlockGemmARegBSmemCRegV2Hack_1 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[number{}]); - // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter][kIter]); // write C warp tensor into C block tensor c_block_tensor.set_y_sliced_thread_data( @@ -211,7 +182,7 @@ struct BlockGemmARegBSmemCRegV2Hack_1 c_warp_tensor.get_thread_buffer()); }); }); - } + }); } template