From 303818a85171793d938e97d6a8bfc15199a72329 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Mon, 10 Nov 2025 15:27:34 +0000 Subject: [PATCH] Simplify the codes in block_gemm_areg_bsmem_trload_creg --- ..._gemm_areg_bsmem_trload_creg_v2_hack_1.hpp | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_trload_creg_v2_hack_1.hpp b/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_trload_creg_v2_hack_1.hpp index 07e52f1bc9..f2ca10f4e8 100644 --- a/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_trload_creg_v2_hack_1.hpp +++ b/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_trload_creg_v2_hack_1.hpp @@ -146,8 +146,15 @@ struct BlockGemmARegBSmemTrLoadCRegV2Hack_1 __builtin_amdgcn_sched_barrier(0); - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { // read A warp tensor from A block tensor AWarpTensor a_warp_tensor; @@ -155,22 +162,15 @@ struct BlockGemmARegBSmemTrLoadCRegV2Hack_1 merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter][kIter]); - - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); }); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); }); }