diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_k.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_k.hpp index 3f21c44207..ffd6bb18bf 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_k.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_k.hpp @@ -128,7 +128,6 @@ struct BlockGemmARegBSmemCRegV2PrefetchK constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; constexpr auto I0 = number<0>{}; - constexpr auto I1 = number<1>{}; // hot loop: static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { @@ -136,6 +135,7 @@ struct BlockGemmARegBSmemCRegV2PrefetchK statically_indexed_array b_warp_tensors; + // read B warp tensor from B Block window b_warp_windows(nIter)(I0) = b_warp_window_tmp; move_tile_window(b_warp_windows(nIter)(I0), {nIter * NPerBlockPerIter, 0 * KPerBlockPerIter}); @@ -143,36 +143,10 @@ struct BlockGemmARegBSmemCRegV2PrefetchK __builtin_amdgcn_sched_barrier(0); - b_warp_windows(nIter)(I1) = b_warp_window_tmp; - move_tile_window(b_warp_windows(nIter)(I1), - {nIter * NPerBlockPerIter, 1 * KPerBlockPerIter}); - b_warp_tensors[I1] = load_tile(b_warp_windows(nIter)(I1)); - - __builtin_amdgcn_sched_barrier(0); - - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block tensor - AWarpTensor a_warp_tensor; - - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - - // warp GEMM - auto c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensors[I0]); - // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); - - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - - static_for<1, KIterPerWarp, 1>{}([&](auto kIter) { - // read B warp tensor from B Block window + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { if constexpr(kIter < KIterPerWarp - 1) { + // read B warp tensor from B Block window b_warp_windows(nIter)(number{}) = b_warp_window_tmp; move_tile_window(b_warp_windows(nIter)(number{}), {nIter * NPerBlockPerIter, (kIter + 1) * KPerBlockPerIter}); @@ -193,13 +167,22 @@ struct BlockGemmARegBSmemCRegV2PrefetchK // read C warp tensor from C block tensor CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + if constexpr(kIter == 0) + { + // warp GEMM + c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensors[kIter]); + // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); + } + else + { + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[kIter]); - // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[kIter]); + // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); + }; // write C warp tensor into C block tensor c_block_tensor.set_y_sliced_thread_data( diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_n.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_n.hpp index 3ad4037926..89a00053c7 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_n.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_n.hpp @@ -115,42 +115,31 @@ struct BlockGemmARegBSmemCRegV2PrefetchN using b_warp_tensor_type = decltype(load_tile(b_warp_windows(I0)(I0))); - statically_indexed_array, - NIterPerWarp> - b_warp_tensors; - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + statically_indexed_array b_warp_tensors; + + // read B warp tensor from B Block window b_warp_windows(I0)(kIter) = b_warp_window_tmp; move_tile_window(b_warp_windows(I0)(kIter), {0 * NPerBlockPerIter, kIter * KPerBlockPerIter}); - b_warp_tensors(I0)(kIter) = load_tile(b_warp_windows(I0)(kIter)); - }); + b_warp_tensors(I0) = load_tile(b_warp_windows(I0)(kIter)); - __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_sched_barrier(0x00000001); - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - if constexpr(nIter < NIterPerWarp - 1) - { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + if constexpr(nIter < NIterPerWarp - 1) + { + // read B warp tensor from B Block window b_warp_windows(number{})(kIter) = b_warp_window_tmp; move_tile_window(b_warp_windows(number{})(kIter), {(nIter + 1) * NPerBlockPerIter, kIter * KPerBlockPerIter}); - b_warp_tensors(number{})(kIter) = + b_warp_tensors(number{}) = load_tile(b_warp_windows(number{})(kIter)); - }); - }; + }; - __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_sched_barrier(0x0000001); - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { // read A warp tensor from A block tensor AWarpTensor a_warp_tensor; @@ -158,14 +147,22 @@ struct BlockGemmARegBSmemCRegV2PrefetchN merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter][kIter]); + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter]); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); }); }); }