From d32851e15ccda84f95929e1e3d575f2a43db4757 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Wed, 7 May 2025 10:09:23 +0000 Subject: [PATCH] Simplification in the static iterations of block_gemm_areg_bsmem_creg_v2_hack --- .../block_gemm_areg_bsmem_creg_v2_hack.hpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack.hpp b/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack.hpp index 78e8bfad4b..f0145e7b85 100644 --- a/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack.hpp +++ b/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack.hpp @@ -139,7 +139,7 @@ struct BlockGemmARegBSmemCRegV2Hack // hot loop: static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { // read B warp tensor from B Block window - const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(I0)); + const auto b_warp_tensor_0 = load_tile(b_warp_windows(nIter)(I0)); static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { // read A warp tensor from A block tensor @@ -150,7 +150,7 @@ struct BlockGemmARegBSmemCRegV2Hack merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); // warp GEMM - auto c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensor); + auto c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensor_0); // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); // write C warp tensor into C block tensor @@ -159,10 +159,8 @@ struct BlockGemmARegBSmemCRegV2Hack merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), c_warp_tensor.get_thread_buffer()); }); - }); - static_for<1, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<1, KIterPerWarp, 1>{}([&](auto kIter) { // read B warp tensor from B Block window const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));