From afd7793e924b5131ad366934eb8d79f3beafdf81 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Sun, 18 May 2025 13:40:38 +0000 Subject: [PATCH] Prefetch K for next iteration from LDS in block_gemm_areg_bsmem_creg for gemm-0 --- .../block_gemm_areg_bsmem_creg_v2_hack_0.hpp | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_0.hpp b/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_0.hpp index 1b895c81a9..ada4407bcd 100644 --- a/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_0.hpp +++ b/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_0.hpp @@ -135,11 +135,17 @@ struct BlockGemmARegBSmemCRegV2Hack_0 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; // hot loop: static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B Block window - const auto b_warp_tensor_0 = load_tile(b_warp_windows(nIter)(I0)); + using b_warp_tensor_type = decltype(load_tile(b_warp_windows(I0)(I0))); + + statically_indexed_array b_warp_tensors; + + b_warp_tensors[I0] = load_tile(b_warp_windows(nIter)(I0)); + + b_warp_tensors[I1] = load_tile(b_warp_windows(nIter)(I1)); static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { // read A warp tensor from A block tensor @@ -150,7 +156,7 @@ struct BlockGemmARegBSmemCRegV2Hack_0 merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); // warp GEMM - auto c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensor_0); + auto c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensors[I0]); // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); // write C warp tensor into C block tensor @@ -162,7 +168,9 @@ struct BlockGemmARegBSmemCRegV2Hack_0 static_for<1, KIterPerWarp, 1>{}([&](auto kIter) { // read B warp tensor from B Block window - const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + if constexpr(kIter < KIterPerWarp - 1) + b_warp_tensors[number{}] = + load_tile(b_warp_windows(nIter)(number{})); static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { // read A warp tensor from A block tensor @@ -180,7 +188,7 @@ struct BlockGemmARegBSmemCRegV2Hack_0 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[kIter]); // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); // write C warp tensor into C block tensor