diff --git a/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_0.hpp b/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_0.hpp index 9bb80cf258..b267dbdca9 100644 --- a/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_0.hpp +++ b/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_0.hpp @@ -87,24 +87,10 @@ struct BlockGemmARegBSmemCRegV2Hack_0 b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0}, make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); -#if 0 // FIXME: using array will cause register spill - array, NIterPerWarp> b_warp_windows{ - {b_warp_window_tmp}}; - - for(index_t nIter = 0; nIter < NIterPerWarp; nIter++) - { - for(index_t kIter = 0; kIter < KIterPerWarp; kIter++) - { - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); - } - } -#else statically_indexed_array< statically_indexed_array, NIterPerWarp> b_warp_windows; -#endif // check C-block-distribution static_assert( @@ -128,7 +114,6 @@ struct BlockGemmARegBSmemCRegV2Hack_0 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; constexpr auto I0 = number<0>{}; - constexpr auto I1 = number<1>{}; // hot loop: static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { @@ -136,6 +121,7 @@ struct BlockGemmARegBSmemCRegV2Hack_0 statically_indexed_array b_warp_tensors; + // read B warp tensor from B Block window b_warp_windows(nIter)(I0) = b_warp_window_tmp; move_tile_window(b_warp_windows(nIter)(I0), {nIter * NPerBlockPerIter, 0 * KPerBlockPerIter}); @@ -143,36 +129,10 @@ struct BlockGemmARegBSmemCRegV2Hack_0 __builtin_amdgcn_sched_barrier(0); - b_warp_windows(nIter)(I1) = b_warp_window_tmp; - move_tile_window(b_warp_windows(nIter)(I1), - {nIter * NPerBlockPerIter, 1 * KPerBlockPerIter}); - b_warp_tensors[I1] = load_tile(b_warp_windows(nIter)(I1)); - - __builtin_amdgcn_sched_barrier(0); - - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block tensor - AWarpTensor a_warp_tensor; - - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - - // warp GEMM - auto c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensors[I0]); - // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); - - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - - static_for<1, KIterPerWarp, 1>{}([&](auto kIter) { - // read B warp tensor from B Block window + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { if constexpr(kIter < KIterPerWarp - 1) { + // read B warp tensor from B Block window b_warp_windows(nIter)(number{}) = b_warp_window_tmp; move_tile_window(b_warp_windows(nIter)(number{}), {nIter * NPerBlockPerIter, (kIter + 1) * KPerBlockPerIter}); @@ -193,13 +153,22 @@ struct BlockGemmARegBSmemCRegV2Hack_0 // read C warp tensor from C block tensor CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + if constexpr(kIter == 0) + { + // warp GEMM + c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensors[kIter]); + // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); + } + else + { + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[kIter]); - // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[kIter]); + // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); + }; // write C warp tensor into C block tensor c_block_tensor.set_y_sliced_thread_data( diff --git a/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_1.hpp b/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_1.hpp index a7a21ef311..8131639512 100644 --- a/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_1.hpp +++ b/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_1.hpp @@ -115,42 +115,31 @@ struct BlockGemmARegBSmemCRegV2Hack_1 using b_warp_tensor_type = decltype(load_tile(b_warp_windows(I0)(I0))); - statically_indexed_array, - NIterPerWarp> - b_warp_tensors; - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + statically_indexed_array b_warp_tensors; + + // read B warp tensor from B Block window b_warp_windows(I0)(kIter) = b_warp_window_tmp; move_tile_window(b_warp_windows(I0)(kIter), {0 * NPerBlockPerIter, kIter * KPerBlockPerIter}); - b_warp_tensors(I0)(kIter) = load_tile(b_warp_windows(I0)(kIter)); - }); - - __builtin_amdgcn_sched_barrier(0); - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - if constexpr(nIter < NIterPerWarp - 1) - { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_warp_windows(number{})(kIter) = b_warp_window_tmp; - move_tile_window(b_warp_windows(number{})(kIter), - {(nIter + 1) * NPerBlockPerIter, kIter * KPerBlockPerIter}); - b_warp_tensors(number{})(kIter) = - load_tile(b_warp_windows(number{})(kIter)); - }); - }; + b_warp_tensors(I0) = load_tile(b_warp_windows(I0)(kIter)); __builtin_amdgcn_sched_barrier(0); - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + if constexpr(nIter < NIterPerWarp - 1) + { + // read B warp tensor from B Block window + b_warp_windows(number{})(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(number{})(kIter), + {(nIter + 1) * NPerBlockPerIter, kIter * KPerBlockPerIter}); + b_warp_tensors(number{}) = + load_tile(b_warp_windows(number{})(kIter)); + }; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + __builtin_amdgcn_sched_barrier(0); - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { // read A warp tensor from A block tensor AWarpTensor a_warp_tensor; @@ -158,14 +147,22 @@ struct BlockGemmARegBSmemCRegV2Hack_1 merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter][kIter]); + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter]); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); }); }); } diff --git a/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_trload_creg_v2_hack_1.hpp b/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_trload_creg_v2_hack_1.hpp index 359e55a0e0..716ad8fb5a 100644 --- a/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_trload_creg_v2_hack_1.hpp +++ b/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_trload_creg_v2_hack_1.hpp @@ -119,42 +119,31 @@ struct BlockGemmARegBSmemTrLoadCRegV2Hack_1 using b_warp_tensor_type = decltype(load_tile_transpose(b_warp_windows(I0)(I0))); - statically_indexed_array, - NIterPerWarp> - b_warp_tensors; - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + statically_indexed_array b_warp_tensors; + + // read B warp tensor from B Block window b_warp_windows(I0)(kIter) = b_warp_window_tmp; move_tile_window(b_warp_windows(I0)(kIter), {kIter * KPerBlockPerIter, 0 * NPerBlockPerIter}); - b_warp_tensors(I0)(kIter) = load_tile_transpose(b_warp_windows(I0)(kIter)); - }); - - __builtin_amdgcn_sched_barrier(0); - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - if constexpr(nIter < NIterPerWarp - 1) - { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_warp_windows(number{})(kIter) = b_warp_window_tmp; - move_tile_window(b_warp_windows(number{})(kIter), - {kIter * KPerBlockPerIter, (nIter + 1) * NPerBlockPerIter}); - b_warp_tensors(number{})(kIter) = - load_tile_transpose(b_warp_windows(number{})(kIter)); - }); - }; + b_warp_tensors(I0) = load_tile_transpose(b_warp_windows(I0)(kIter)); __builtin_amdgcn_sched_barrier(0); - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + if constexpr(nIter < NIterPerWarp - 1) + { + // read B warp tensor from B Block window + b_warp_windows(number{})(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(number{})(kIter), + {kIter * KPerBlockPerIter, (nIter + 1) * NPerBlockPerIter}); + b_warp_tensors(number{}) = + load_tile_transpose(b_warp_windows(number{})(kIter)); + }; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + __builtin_amdgcn_sched_barrier(0); - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { // read A warp tensor from A block tensor AWarpTensor a_warp_tensor; @@ -162,15 +151,22 @@ struct BlockGemmARegBSmemTrLoadCRegV2Hack_1 merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter][kIter]); - }); + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter]); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); }); }); }