From 8f876f094e760ed645e647916308b470c16e44db Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Mon, 10 Nov 2025 15:52:12 +0000 Subject: [PATCH] Simplify the codes in block_gemm_areg_bsmem_creg_v2_hack_1 --- .../block_gemm_areg_bsmem_creg_v2_hack_1.hpp | 43 ++++++------------- 1 file changed, 14 insertions(+), 29 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_1.hpp b/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_1.hpp index 4d35ad030d..0038d725a3 100644 --- a/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_1.hpp +++ b/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_1.hpp @@ -85,24 +85,10 @@ struct BlockGemmARegBSmemCRegV2Hack_1 b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0}, make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); -#if 0 // FIXME: using array will cause register spill - array, NIterPerWarp> b_warp_windows{ - {b_warp_window_tmp}}; - - for(index_t nIter = 0; nIter < NIterPerWarp; nIter++) - { - for(index_t kIter = 0; kIter < KIterPerWarp; kIter++) - { - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); - } - } -#else statically_indexed_array< statically_indexed_array, NIterPerWarp> b_warp_windows; -#endif // check C-block-distribution static_assert( @@ -156,8 +142,15 @@ struct BlockGemmARegBSmemCRegV2Hack_1 __builtin_amdgcn_sched_barrier(0); - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { // read A warp tensor from A block tensor AWarpTensor a_warp_tensor; @@ -165,22 +158,14 @@ struct BlockGemmARegBSmemCRegV2Hack_1 merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter][kIter]); - - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); }); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); }); }