diff --git a/composable_kernel/include/tensor_description/tensor_descriptor.hpp b/composable_kernel/include/tensor_description/tensor_descriptor.hpp index 9821ee8641..4038ef63da 100644 --- a/composable_kernel/include/tensor_description/tensor_descriptor.hpp +++ b/composable_kernel/include/tensor_description/tensor_descriptor.hpp @@ -10,7 +10,7 @@ template struct TensorCoordinate; template -struct TensorCoordinateIterator; +struct TensorCoordinateStep; // Transforms: Tuple // LowerDimensionIdss : Tuple, ...> @@ -252,17 +252,16 @@ struct TensorCoordinate }; template -struct TensorCoordinateIterator +struct TensorCoordinateStep { // TODO make these private using VisibleIndex = MultiIndex; public: - __host__ __device__ constexpr TensorCoordinateIterator() = default; + __host__ __device__ constexpr TensorCoordinateStep() = default; - __host__ - __device__ constexpr TensorCoordinateIterator(const VisibleIndex& idx_diff_visible, - const MultiIndex& do_transforms) + __host__ __device__ constexpr TensorCoordinateStep(const VisibleIndex& idx_diff_visible, + const MultiIndex& do_transforms) : idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms} { } @@ -423,8 +422,9 @@ __host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc& tens // UpdateLowerIndexHack: Sequence<...> // HACK: control UpdateLowerIndex template -__host__ __device__ constexpr auto make_tensor_coordinate_iterator( - const TensorDesc&, const VisibleIndex& idx_diff_visible, UpdateLowerIndexHack) +__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&, + const VisibleIndex& idx_diff_visible, + UpdateLowerIndexHack) { static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(), "wrong! # of dimension inconsistent"); @@ -471,24 +471,24 @@ __host__ __device__ constexpr auto make_tensor_coordinate_iterator( set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low); }); - return TensorCoordinateIterator{ - idx_diff_visible, do_transforms}; + return TensorCoordinateStep{idx_diff_visible, + do_transforms}; } template -__host__ __device__ constexpr auto -make_tensor_coordinate_iterator(const TensorDesc&, const VisibleIndex& idx_diff_visible) +__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&, + const VisibleIndex& idx_diff_visible) { constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); - return make_tensor_coordinate_iterator( + return make_tensor_coordinate_step( TensorDesc{}, idx_diff_visible, typename uniform_sequence_gen::type{}); } -template +template __host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tensor_desc, TensorCoord& coord, - const TensorCoordIterator& coord_iterator) + const TensorCoordStep& coord_step) { constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension(); constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); @@ -497,9 +497,8 @@ __host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tens auto idx_diff_hidden = make_zero_multi_index(); // initialize visible index diff - set_container_subset(idx_diff_hidden, - TensorDesc::GetVisibleDimensionIds(), - coord_iterator.GetVisibleIndexDiff()); + set_container_subset( + idx_diff_hidden, TensorDesc::GetVisibleDimensionIds(), coord_step.GetVisibleIndexDiff()); // this is what needs to be updated auto& idx_hidden = coord.GetHiddenIndex(); @@ -508,13 +507,13 @@ __host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tens auto idx_hidden_pick_visible = get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds()); - idx_hidden_pick_visible += coord_iterator.GetIndexDiff(); + idx_hidden_pick_visible += coord_step.GetIndexDiff(); set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible); // update rest of hidden index static_for{}([&](auto itran) { - if(coord_iterator.do_transforms_[itran]) + if(coord_step.do_transforms_[itran]) { const auto& tran = tensor_desc.GetTransforms().At(itran); constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran); @@ -527,7 +526,7 @@ __host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tens MultiIndex idx_diff_low; // HACK: control UpdateLowerIndex for Merge using hack - constexpr index_t Hack = decltype(coord_iterator.update_lower_index_hack_)::At(itran); + constexpr index_t Hack = decltype(coord_step.update_lower_index_hack_)::At(itran); tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number{}); @@ -591,7 +590,7 @@ using TensorCoordinate_t = decltype(make_tensor_coordinate( TensorDesc{}, MultiIndex>::GetNumOfDimension()>{})); template -using TensorCoordinateIterator_t = decltype(make_tensor_coordinate_iterator( +using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step( TensorDesc{}, MultiIndex>::GetNumOfDimension()>{})); } // namespace ck diff --git a/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp index 4303b6a4ca..cf21123de6 100644 --- a/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp @@ -77,15 +77,14 @@ struct BlockwiseTensorSliceTransfer_v4 } } - template - __device__ void RunRead(const SrcDesc& src_desc, - const SrcBuffer& src_buf, - const SrcIteratorHacks& src_iterator_hacks) + template + __device__ void + RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) { if(BlockSize == thread_cluster_desc_.GetElementSize() or get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) { - threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks); + threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks); } } @@ -118,18 +117,18 @@ struct BlockwiseTensorSliceTransfer_v4 } } - // SrcMoveSliceWindowIteratorHack to control index calculation move slice window - template + // SrcMoveSliceWindowStepHack to control index calculation move slice window + template __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step, - const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack) + const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) { if(BlockSize == thread_cluster_desc_.GetElementSize() or get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.MoveSrcSliceWindow( - src_desc, step, src_move_slice_window_iterator_hack); + src_desc, step, src_move_slice_window_step_hack); } } diff --git a/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp b/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp index 25df52904d..4f3336f9f7 100644 --- a/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp @@ -75,15 +75,14 @@ struct BlockwiseTensorSliceTransfer_v4r1 } } - template - __device__ void RunRead(const SrcDesc& src_desc, - const SrcBuffer& src_buf, - const SrcIteratorHacks& src_iterator_hacks) + template + __device__ void + RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) { if(BlockSize == thread_cluster_desc_.GetElementSize() or get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) { - threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks); + threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks); } } @@ -106,18 +105,18 @@ struct BlockwiseTensorSliceTransfer_v4r1 } } - // SrcMoveSliceWindowIteratorHack to control index calculation move slice window - template + // SrcMoveSliceWindowStepHack to control index calculation move slice window + template __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step, - const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack) + const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) { if(BlockSize == thread_cluster_desc_.GetElementSize() or get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.MoveSrcSliceWindow( - src_desc, step, src_move_slice_window_iterator_hack); + src_desc, step, src_move_slice_window_step_hack); } } diff --git a/composable_kernel/include/tensor_operation/gridwise_contraction_dlops_v1r2.hpp b/composable_kernel/include/tensor_operation/gridwise_contraction_dlops_v1r2.hpp index 3070045554..366451dcc3 100644 --- a/composable_kernel/include/tensor_operation/gridwise_contraction_dlops_v1r2.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_contraction_dlops_v1r2.hpp @@ -84,11 +84,11 @@ template + typename AGridStepHacks, + typename BGridStepHacks, + typename CGridStepHacks, + typename AGridMoveSliceWindowStepHacks, + typename BGridMoveSliceWindowStepHacks> struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1 { static constexpr auto I0 = Number<0>{}; @@ -496,9 +496,9 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN // LDS double buffer: preload data into LDS { a_blockwise_copy.RunRead( - a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{}); + a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{}); b_blockwise_copy.RunRead( - b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{}); + b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{}); a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf); b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf); @@ -515,18 +515,18 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN // even iteration a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_block_slice_copy_step, - AGridMoveSliceWindowIteratorHacks{}); + AGridMoveSliceWindowStepHacks{}); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_block_slice_copy_step, - BGridMoveSliceWindowIteratorHacks{}); + BGridMoveSliceWindowStepHacks{}); __syncthreads(); // LDS doubel buffer: load next data from device mem a_blockwise_copy.RunRead( - a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{}); + a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{}); b_blockwise_copy.RunRead( - b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{}); + b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{}); // LDS double buffer: GEMM on current data blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1, @@ -541,18 +541,18 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN // odd iteration a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_block_slice_copy_step, - AGridMoveSliceWindowIteratorHacks{}); + AGridMoveSliceWindowStepHacks{}); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_block_slice_copy_step, - BGridMoveSliceWindowIteratorHacks{}); + BGridMoveSliceWindowStepHacks{}); __syncthreads(); // LDS doubel buffer: load next data from device mem a_blockwise_copy.RunRead( - a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{}); + a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{}); b_blockwise_copy.RunRead( - b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{}); + b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{}); // LDS double buffer: GEMM on current data blockwise_gemm.Run( @@ -571,18 +571,18 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN { a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_block_slice_copy_step, - AGridMoveSliceWindowIteratorHacks{}); + AGridMoveSliceWindowStepHacks{}); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_block_slice_copy_step, - BGridMoveSliceWindowIteratorHacks{}); + BGridMoveSliceWindowStepHacks{}); __syncthreads(); // LDS double buffer: load last data from device mem a_blockwise_copy.RunRead( - a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{}); + a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{}); b_blockwise_copy.RunRead( - b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{}); + b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{}); // LDS double buffer: GEMM on 2nd-last data blockwise_gemm.Run( @@ -650,7 +650,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN c_thread_buf, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, c_grid_buf, - CGridIteratorHacks{}); + CGridStepHacks{}); } } }; diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r2.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r2.hpp index 88f2059bbf..31a0fa342a 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r2.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r2.hpp @@ -145,11 +145,11 @@ template + typename AGridStepHacks, + typename BGridStepHacks, + typename CGridStepHacks, + typename AGridMoveSliceWindowStepHacks, + typename BGridMoveSliceWindowStepHacks> struct GridwiseGemmDlops_km_kn_mn_v1r2 { static constexpr auto I0 = Number<0>{}; @@ -475,15 +475,15 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); // hack to control index calculation when iterating over A and B matrix for threadwise copy - constexpr auto a_k_m0_m1_global_iterator_hacks = AGridIteratorHacks{}; - constexpr auto b_k_n0_n1_global_iterator_hacks = BGridIteratorHacks{}; + constexpr auto a_k_m0_m1_global_step_hacks = AGridStepHacks{}; + constexpr auto b_k_n0_n1_global_step_hacks = BGridStepHacks{}; // hack to control index calculation when move slice window for A and B matrix for // threadwise copy - constexpr auto a_k_m0_m1_global_move_slice_window_iterator_hack = - AGridMoveSliceWindowIteratorHacks{}; - constexpr auto b_k_n0_n1_global_move_slice_window_iterator_hack = - BGridMoveSliceWindowIteratorHacks{}; + constexpr auto a_k_m0_m1_global_move_slice_window_step_hack = + AGridMoveSliceWindowStepHacks{}; + constexpr auto b_k_n0_n1_global_move_slice_window_step_hack = + BGridMoveSliceWindowStepHacks{}; auto a_block_even_buf = make_dynamic_buffer( p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize()); @@ -500,9 +500,9 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2 // LDS double buffer: preload data into LDS { a_blockwise_copy.RunRead( - a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); + a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks); b_blockwise_copy.RunRead( - b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); + b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks); a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf); b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf); @@ -517,22 +517,20 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2 do { // even iteration - a_blockwise_copy.MoveSrcSliceWindow( - a_k_m0_m1_grid_desc, - a_block_slice_copy_step, - a_k_m0_m1_global_move_slice_window_iterator_hack); - b_blockwise_copy.MoveSrcSliceWindow( - b_k_n0_n1_grid_desc, - b_block_slice_copy_step, - b_k_n0_n1_global_move_slice_window_iterator_hack); + a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc, + a_block_slice_copy_step, + a_k_m0_m1_global_move_slice_window_step_hack); + b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc, + b_block_slice_copy_step, + b_k_n0_n1_global_move_slice_window_step_hack); __syncthreads(); // LDS doubel buffer: load next data from device mem a_blockwise_copy.RunRead( - a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); + a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks); b_blockwise_copy.RunRead( - b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); + b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks); // LDS double buffer: GEMM on current data blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc, @@ -545,22 +543,20 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2 b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf); // odd iteration - a_blockwise_copy.MoveSrcSliceWindow( - a_k_m0_m1_grid_desc, - a_block_slice_copy_step, - a_k_m0_m1_global_move_slice_window_iterator_hack); - b_blockwise_copy.MoveSrcSliceWindow( - b_k_n0_n1_grid_desc, - b_block_slice_copy_step, - b_k_n0_n1_global_move_slice_window_iterator_hack); + a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc, + a_block_slice_copy_step, + a_k_m0_m1_global_move_slice_window_step_hack); + b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc, + b_block_slice_copy_step, + b_k_n0_n1_global_move_slice_window_step_hack); __syncthreads(); // LDS doubel buffer: load next data from device mem a_blockwise_copy.RunRead( - a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); + a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks); b_blockwise_copy.RunRead( - b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); + b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks); // LDS double buffer: GEMM on current data blockwise_gemm.Run( @@ -579,18 +575,18 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2 { a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc, a_block_slice_copy_step, - a_k_m0_m1_global_move_slice_window_iterator_hack); + a_k_m0_m1_global_move_slice_window_step_hack); b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc, b_block_slice_copy_step, - b_k_n0_n1_global_move_slice_window_iterator_hack); + b_k_n0_n1_global_move_slice_window_step_hack); __syncthreads(); // LDS double buffer: load last data from device mem a_blockwise_copy.RunRead( - a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); + a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks); b_blockwise_copy.RunRead( - b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); + b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks); // LDS double buffer: GEMM on 2nd-last data blockwise_gemm.Run( @@ -657,7 +653,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2 c_thread_buf, c_m0_m10_m11_n0_n10_n11_grid_desc, c_grid_buf, - CGridIteratorHacks{}); + CGridStepHacks{}); } } }; diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp index 70cedf3fa0..1017dcc2a1 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp @@ -141,11 +141,11 @@ template + typename AGridStepHacks, + typename BGridStepHacks, + typename CGridStepHacks, + typename AGridMoveSliceWindowStepHacks, + typename BGridMoveSliceWindowStepHacks> struct GridwiseGemmDlops_km_kn_mn_v1r3 { static constexpr auto I0 = Number<0>{}; @@ -494,8 +494,8 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 // LDS double buffer: preload data into LDS { - a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{}); - b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{}); + a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); + b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf); b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf); @@ -514,18 +514,16 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 // even iteration a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step, - AGridMoveSliceWindowIteratorHacks{}); + AGridMoveSliceWindowStepHacks{}); b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step, - BGridMoveSliceWindowIteratorHacks{}); + BGridMoveSliceWindowStepHacks{}); __syncthreads(); // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead( - a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{}); - b_blockwise_copy.RunRead( - b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{}); + a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); + b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); // LDS double buffer: GEMM on current data blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc, @@ -540,18 +538,16 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 // odd iteration a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step, - AGridMoveSliceWindowIteratorHacks{}); + AGridMoveSliceWindowStepHacks{}); b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step, - BGridMoveSliceWindowIteratorHacks{}); + BGridMoveSliceWindowStepHacks{}); __syncthreads(); // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead( - a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{}); - b_blockwise_copy.RunRead( - b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{}); + a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); + b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); // LDS double buffer: GEMM on current data blockwise_gemm.Run( @@ -568,18 +564,16 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 // LDS double buffer: tail if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left { - a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, - a_block_slice_copy_step, - AGridMoveSliceWindowIteratorHacks{}); - b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, - b_block_slice_copy_step, - BGridMoveSliceWindowIteratorHacks{}); + a_blockwise_copy.MoveSrcSliceWindow( + a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step, AGridMoveSliceWindowStepHacks{}); + b_blockwise_copy.MoveSrcSliceWindow( + b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step, BGridMoveSliceWindowStepHacks{}); __syncthreads(); // LDS double buffer: load last data from device mem - a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{}); - b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{}); + a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); + b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); // LDS double buffer: GEMM on 2nd-last data blockwise_gemm.Run( @@ -647,7 +641,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 c_thread_buf, c_m0_m10_m11_n0_n10_n11_grid_desc, c_grid_buf, - CGridIteratorHacks{}); + CGridStepHacks{}); } } }; diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp index 484f5d938d..7fdb89781d 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp @@ -42,11 +42,11 @@ template + typename AGlobalStepHacks, + typename BGlobalStepHacks, + typename CGlobalStepHacks, + typename AGlobalMoveSliceWindowStepHacks, + typename BGlobalMoveSliceWindowStepHacks> struct GridwiseGemmDlops_km_kn_mn_v3 { __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() @@ -239,15 +239,14 @@ struct GridwiseGemmDlops_km_kn_mn_v3 constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0); // hack to control index calculation when iterating over A and B matrix for threadwise copy - constexpr auto a_e_k_global_iterator_hacks = AGlobalIteratorHacks{}; - constexpr auto b_e_n_ho_wo_global_iterator_hacks = BGlobalIteratorHacks{}; + constexpr auto a_e_k_global_step_hacks = AGlobalStepHacks{}; + constexpr auto b_e_n_ho_wo_global_step_hacks = BGlobalStepHacks{}; // hack to control index calculation when move slice window for A and B matrix for // threadwise copy - constexpr auto a_e_k_global_move_slice_window_iterator_hack = - AGlobalMoveSliceWindowIteratorHacks{}; - constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack = - BGlobalMoveSliceWindowIteratorHacks{}; + constexpr auto a_e_k_global_move_slice_window_step_hack = AGlobalMoveSliceWindowStepHacks{}; + constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack = + BGlobalMoveSliceWindowStepHacks{}; // double regsiter buffer for b StaticBuffer struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 { @@ -416,15 +416,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); // hack to control index calculation when iterating over A and B matrix for threadwise copy - constexpr auto a_k0_m_k1_grid_iterator_hacks = AGridIteratorHacks{}; - constexpr auto b_k0_n_k1_grid_iterator_hacks = BGridIteratorHacks{}; + constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{}; + constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{}; // hack to control index calculation when move slice window for A and B matrix for // threadwise copy - constexpr auto a_k0_m_k1_grid_move_slice_window_iterator_hack = - AGridMoveSliceWindowIteratorHacks{}; - constexpr auto b_k0_n_k1_grid_move_slice_window_iterator_hack = - BGridMoveSliceWindowIteratorHacks{}; + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{}; + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{}; auto a_block_buf = make_dynamic_buffer( p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); @@ -433,10 +431,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // preload data into LDS { - a_blockwise_copy.RunRead( - a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks); - b_blockwise_copy.RunRead( - b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks); + a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); + b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); @@ -449,18 +445,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 { a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc, a_block_slice_copy_step, - a_k0_m_k1_grid_move_slice_window_iterator_hack); + a_k0_m_k1_grid_move_slice_window_step_hack); b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc, b_block_slice_copy_step, - b_k0_n_k1_grid_move_slice_window_iterator_hack); + b_k0_n_k1_grid_move_slice_window_step_hack); - a_blockwise_copy.RunRead( - a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks); + a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); block_sync_lds(); - b_blockwise_copy.RunRead( - b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks); + b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); @@ -526,7 +520,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 const index_t n_thread_data_on_grid = n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - constexpr auto c_m0_m1_m2_n_grid_tensor_iterator_hacks = CGridIteratorHacks{}; + constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{}; constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat); constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat); @@ -557,7 +551,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 c_blk_buf_, c_m0_m1_m2_n_grid_desc, c_grid_buf, - c_m0_m1_m2_n_grid_tensor_iterator_hacks); + c_m0_m1_m2_n_grid_tensor_step_hacks); } #else { @@ -579,7 +573,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 const index_t n_thread_data_on_grid = n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - constexpr auto c_m0_m1_m2_n_grid_tensor_iterator_hacks = CGridIteratorHacks{}; + constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{}; auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3{}].template AsType(), c_m0_m1_m2_n_grid_desc, c_grid_buf, - c_m0_m1_m2_n_grid_tensor_iterator_hacks); + c_m0_m1_m2_n_grid_tensor_step_hacks); return c_thread_idx_; }; @@ -625,7 +619,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 c_thread_buf[Number{}].template AsType(), c_m0_m1_m2_n_grid_desc, c_grid_buf, - c_m0_m1_m2_n_grid_tensor_iterator_hacks); + c_m0_m1_m2_n_grid_tensor_step_hacks); }; auto nrepeat_plus_copy = [&](auto c_thread_idx_) { @@ -638,7 +632,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 c_thread_buf[Number{}].template AsType(), c_m0_m1_m2_n_grid_desc, c_grid_buf, - c_m0_m1_m2_n_grid_tensor_iterator_hacks); + c_m0_m1_m2_n_grid_tensor_step_hacks); }; auto mrepeat_minus_copy = [&](auto c_thread_idx_) { @@ -651,7 +645,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 c_thread_buf[Number{}].template AsType(), c_m0_m1_m2_n_grid_desc, c_grid_buf, - c_m0_m1_m2_n_grid_tensor_iterator_hacks); + c_m0_m1_m2_n_grid_tensor_step_hacks); }; auto nrepeat_minus_copy = [&](auto c_thread_idx_) { @@ -664,7 +658,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 c_thread_buf[Number{}].template AsType(), c_m0_m1_m2_n_grid_desc, c_grid_buf, - c_m0_m1_m2_n_grid_tensor_iterator_hacks); + c_m0_m1_m2_n_grid_tensor_step_hacks); }; static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp index 6eb058711e..a4128c274b 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp @@ -11,7 +11,7 @@ namespace ck { // 1. Desc is known at compile-time // 2. Buffer is StaticBuffer // 3. OriginIdx is known at compile-time -// 4. use #-iterator +// 4. use #-step template + typename DstStepHacks> __device__ void Run(const SrcDesc&, const SrcSliceOriginIdx&, const SrcBuffer& src_buf, const DstDesc& dst_desc, DstBuffer& dst_buf, - const DstIteratorHacks& dst_iterator_hacks) + const DstStepHacks& dst_step_hacks) { static_assert(SrcDesc::IsKnownAtCompileTime(), "wrong! SrcDesc need to known at compile-time"); @@ -127,31 +127,31 @@ struct ThreadwiseTensorSliceTransfer_v1r3 constexpr auto ordered_access_lengths = container_reorder_given_new2old(access_lengths, dim_access_order); - // make forward iterators - const auto dst_forward_iterators = generate_tuple( + // make forward steps + const auto dst_forward_steps = generate_tuple( [&](auto i) { - Index forward_step; + Index forward_step_idx; static_for<0, nDim, 1>{}([&](auto j) { - forward_step(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; }); - return make_tensor_coordinate_iterator( - dst_desc, forward_step, dst_iterator_hacks[I0][i]); + return make_tensor_coordinate_step( + dst_desc, forward_step_idx, dst_step_hacks[I0][i]); }, Number{}); - // make backward iterators - const auto dst_backward_iterators = generate_tuple( + // make backward steps + const auto dst_backward_steps = generate_tuple( [&](auto i) { - Index backward_step; + Index backward_step_idx; static_for<0, nDim, 1>{}([&](auto j) { - backward_step(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; }); - return make_tensor_coordinate_iterator( - dst_desc, backward_step, dst_iterator_hacks[I1][i]); + return make_tensor_coordinate_step( + dst_desc, backward_step_idx, dst_step_hacks[I1][i]); }, Number{}); @@ -236,12 +236,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3 if constexpr(forward_sweep[i]) { move_tensor_coordinate( - dst_desc, dst_coord_, dst_forward_iterators[dim_access_order[i]]); + dst_desc, dst_coord_, dst_forward_steps[dim_access_order[i]]); } else { move_tensor_coordinate( - dst_desc, dst_coord_, dst_backward_iterators[dim_access_order[i]]); + dst_desc, dst_coord_, dst_backward_steps[dim_access_order[i]]); } } }); @@ -250,10 +250,10 @@ struct ThreadwiseTensorSliceTransfer_v1r3 // move dst coordinate back to slice origin (or not) if constexpr(DstResetCoordinateAfterRun) { - const auto dst_reset_iterator = - make_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep()); + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); - move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator); + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); } } @@ -268,11 +268,11 @@ struct ThreadwiseTensorSliceTransfer_v1r3 constexpr auto zeros = typename uniform_sequence_gen::type{}; - constexpr auto dst_iterator_hacks = + constexpr auto dst_step_hacks = make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), generate_tuple([&](auto) { return zeros; }, Number{})); - Run(SrcDesc{}, SrcSliceOriginIdx{}, src_buf, dst_desc, dst_buf, dst_iterator_hacks); + Run(SrcDesc{}, SrcSliceOriginIdx{}, src_buf, dst_desc, dst_buf, dst_step_hacks); } __device__ static constexpr auto GetDstCoordinateResetStep() @@ -345,7 +345,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3 : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_iterator(dst_desc, adjusted_step_idx); + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); } @@ -382,7 +382,7 @@ struct ThreadwiseTensorSliceTransfer_v2 using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); - using SrcCoordIterator = decltype(make_tensor_coordinate_iterator(SrcDesc{}, Index{})); + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); __device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDesc& src_desc, const Index& src_slice_origin_idx) @@ -400,13 +400,13 @@ struct ThreadwiseTensorSliceTransfer_v2 template + typename SrcStepHacks> __device__ void Run(const SrcDesc& src_desc, const SrcBuffer& src_buf, const DstDesc&, const DstSliceOriginIdx&, DstBuffer& dst_buf, - const SrcIteratorHacks& src_iterator_hacks) + const SrcStepHacks& src_step_hacks) { static_assert(DstDesc::IsKnownAtCompileTime(), "wrong! DstDesc need to known at compile-time"); @@ -441,31 +441,31 @@ struct ThreadwiseTensorSliceTransfer_v2 constexpr auto ordered_access_lengths = container_reorder_given_new2old(access_lengths, dim_access_order); - // make forward iterators - const auto src_forward_iterators = generate_tuple( + // make forward steps + const auto src_forward_steps = generate_tuple( [&](auto i) { - Index forward_step; + Index forward_step_idx; static_for<0, nDim, 1>{}([&](auto j) { - forward_step(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; + forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; }); - return make_tensor_coordinate_iterator( - src_desc, forward_step, src_iterator_hacks[I0][i]); + return make_tensor_coordinate_step( + src_desc, forward_step_idx, src_step_hacks[I0][i]); }, Number{}); - // make backward iterators - const auto src_backward_iterators = generate_tuple( + // make backward steps + const auto src_backward_steps = generate_tuple( [&](auto i) { - Index backward_step; + Index backward_step_idx; static_for<0, nDim, 1>{}([&](auto j) { - backward_step(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; + backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; }); - return make_tensor_coordinate_iterator( - src_desc, backward_step, src_iterator_hacks[I1][i]); + return make_tensor_coordinate_step( + src_desc, backward_step_idx, src_step_hacks[I1][i]); }, Number{}); @@ -548,12 +548,12 @@ struct ThreadwiseTensorSliceTransfer_v2 if constexpr(forward_sweep[i]) { move_tensor_coordinate( - src_desc, src_coord_, src_forward_iterators[dim_access_order[i]]); + src_desc, src_coord_, src_forward_steps[dim_access_order[i]]); } else { move_tensor_coordinate( - src_desc, src_coord_, src_backward_iterators[dim_access_order[i]]); + src_desc, src_coord_, src_backward_steps[dim_access_order[i]]); } } }); @@ -562,10 +562,10 @@ struct ThreadwiseTensorSliceTransfer_v2 // move src coordinate back to slice origin (or not) if constexpr(SrcResetCoordinateAfterRun) { - const auto src_reset_iterator = - make_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep()); + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); - move_tensor_coordinate(src_desc, src_coord_, src_reset_iterator); + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); } } @@ -580,11 +580,11 @@ struct ThreadwiseTensorSliceTransfer_v2 constexpr auto zeros = typename uniform_sequence_gen::type{}; - constexpr auto src_iterator_hacks = + constexpr auto src_step_hacks = make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), generate_tuple([&](auto) { return zeros; }, Number{})); - Run(src_desc, src_buf, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_iterator_hacks); + Run(src_desc, src_buf, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_step_hacks); } __device__ static constexpr auto GetSrcCoordinateResetStep() @@ -657,7 +657,7 @@ struct ThreadwiseTensorSliceTransfer_v2 : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_iterator(src_desc, adjusted_step_idx); + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); move_tensor_coordinate(src_desc, src_coord_, adjusted_step); } @@ -699,8 +699,8 @@ struct ThreadwiseTensorSliceTransfer_v3 using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); - using SrcCoordIterator = decltype(make_tensor_coordinate_iterator(SrcDesc{}, Index{})); - using DstCoordIterator = decltype(make_tensor_coordinate_iterator(DstDesc{}, Index{})); + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); __device__ constexpr ThreadwiseTensorSliceTransfer_v3(const SrcDesc& src_desc, const Index& src_slice_origin, @@ -724,10 +724,9 @@ struct ThreadwiseTensorSliceTransfer_v3 dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); } - template - __device__ void RunRead(const SrcDesc& src_desc, - const SrcBuffer& src_buf, - const SrcIteratorHacks& src_iterator_hacks) + template + __device__ void + RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) { static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, @@ -755,31 +754,31 @@ struct ThreadwiseTensorSliceTransfer_v3 constexpr auto ordered_src_access_lengths = container_reorder_given_new2old(src_access_lengths, src_dim_access_order); - // make forward iterators - const auto src_forward_iterators = generate_tuple( + // make forward steps + const auto src_forward_steps = generate_tuple( [&](auto i) { - Index forward_step; + Index forward_step_idx; static_for<0, nDim, 1>{}([&](auto j) { - forward_step(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; + forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; }); - return make_tensor_coordinate_iterator( - src_desc, forward_step, src_iterator_hacks[I0][i]); + return make_tensor_coordinate_step( + src_desc, forward_step_idx, src_step_hacks[I0][i]); }, Number{}); - // make backward iterators - const auto src_backward_iterators = generate_tuple( + // make backward steps + const auto src_backward_steps = generate_tuple( [&](auto i) { - Index backward_step; + Index backward_step_idx; static_for<0, nDim, 1>{}([&](auto j) { - backward_step(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; + backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; }); - return make_tensor_coordinate_iterator( - src_desc, backward_step, src_iterator_hacks[I1][i]); + return make_tensor_coordinate_step( + src_desc, backward_step_idx, src_step_hacks[I1][i]); }, Number{}); @@ -861,12 +860,12 @@ struct ThreadwiseTensorSliceTransfer_v3 if constexpr(forward_sweep[i]) { move_tensor_coordinate( - src_desc, src_coord_, src_forward_iterators[src_dim_access_order[i]]); + src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); } else { move_tensor_coordinate( - src_desc, src_coord_, src_backward_iterators[src_dim_access_order[i]]); + src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); } } }); @@ -875,17 +874,16 @@ struct ThreadwiseTensorSliceTransfer_v3 // move src coordinate back to slice origin (or not) if constexpr(SrcResetCoordinateAfterRun) { - const auto src_reset_iterator = - make_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep()); + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); - move_tensor_coordinate(src_desc, src_coord_, src_reset_iterator); + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); } } - template - __device__ void RunWrite(const DstDesc& dst_desc, - DstBuffer& dst_buf, - const DstIteratorHacks& dst_iterator_hacks) + template + __device__ void + RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks) { static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, @@ -913,35 +911,31 @@ struct ThreadwiseTensorSliceTransfer_v3 constexpr auto ordered_dst_access_lengths = container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); - // make forward iterators - const auto dst_forward_iterators = generate_tuple( + // make forward steps + const auto dst_forward_steps = generate_tuple( [&](auto i) { - Index forward_step; + Index forward_step_idx; static_for<0, nDim, 1>{}([&](auto j) { - forward_step(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; }); - const auto forward_iterator = make_tensor_coordinate_iterator( - dst_desc, forward_step, dst_iterator_hacks[I0][i]); - - return forward_iterator; + return make_tensor_coordinate_step( + dst_desc, forward_step_idx, dst_step_hacks[I0][i]); }, Number{}); - // make backward iterators - const auto dst_backward_iterators = generate_tuple( + // make backward steps + const auto dst_backward_steps = generate_tuple( [&](auto i) { - Index backward_step; + Index backward_step_idx; static_for<0, nDim, 1>{}([&](auto j) { - backward_step(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; }); - const auto backward_iterator = make_tensor_coordinate_iterator( - dst_desc, backward_step, dst_iterator_hacks[I1][i]); - - return backward_iterator; + return make_tensor_coordinate_step( + dst_desc, backward_step_idx, dst_step_hacks[I1][i]); }, Number{}); @@ -1025,12 +1019,12 @@ struct ThreadwiseTensorSliceTransfer_v3 if constexpr(forward_sweep[i]) { move_tensor_coordinate( - dst_desc, dst_coord_, dst_forward_iterators[dst_dim_access_order[i]]); + dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]); } else { move_tensor_coordinate( - dst_desc, dst_coord_, dst_backward_iterators[dst_dim_access_order[i]]); + dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]); } } }); @@ -1039,10 +1033,10 @@ struct ThreadwiseTensorSliceTransfer_v3 // move dst coordinate back to slice origin (or not) if constexpr(DstResetCoordinateAfterRun) { - const auto dst_reset_iterator = - make_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep()); + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); - move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator); + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); } } @@ -1053,11 +1047,11 @@ struct ThreadwiseTensorSliceTransfer_v3 constexpr auto zeros = typename uniform_sequence_gen::type{}; - constexpr auto src_iterator_hacks = + constexpr auto src_step_hacks = make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), generate_tuple([&](auto) { return zeros; }, Number{})); - RunRead(src_desc, src_buf, src_iterator_hacks); + RunRead(src_desc, src_buf, src_step_hacks); } template @@ -1067,11 +1061,11 @@ struct ThreadwiseTensorSliceTransfer_v3 constexpr auto zeros = typename uniform_sequence_gen::type{}; - constexpr auto dst_iterator_hacks = + constexpr auto dst_step_hacks = make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), generate_tuple([&](auto) { return zeros; }, Number{})); - RunWrite(dst_desc, dst_buf, dst_iterator_hacks); + RunWrite(dst_desc, dst_buf, dst_step_hacks); } __device__ static constexpr auto GetSrcCoordinateResetStep() @@ -1204,17 +1198,17 @@ struct ThreadwiseTensorSliceTransfer_v3 : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_iterator(src_desc, adjusted_step_idx); + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); move_tensor_coordinate(src_desc, src_coord_, adjusted_step); } // src_slice_origin_step_idx need to be known at compile-time, for performance reason - template + template __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& src_slice_origin_step_idx, - const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack) + const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) { // if src coord was not reset by RunRead(), then need to adjust the step here const auto adjusted_step_idx = @@ -1222,8 +1216,8 @@ struct ThreadwiseTensorSliceTransfer_v3 : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_iterator( - src_desc, adjusted_step_idx, src_move_slice_window_iterator_hack); + const auto adjusted_step = make_tensor_coordinate_step( + src_desc, adjusted_step_idx, src_move_slice_window_step_hack); move_tensor_coordinate(src_desc, src_coord_, adjusted_step); } @@ -1237,7 +1231,7 @@ struct ThreadwiseTensorSliceTransfer_v3 : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_iterator(dst_desc, adjusted_step_idx); + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); } @@ -1260,7 +1254,7 @@ struct ThreadwiseTensorSliceTransfer_v3 // 2. SrcBuffer is DynamicBuffer // 3. src_ref_idx is known at run-time // 4. SrcRefToOriginDisplacement is known at compile-time -// 5. use #-iterator +// 5. use #-step // 2. dst: // 1. DstDesc is known at compile-time // 2. DstBuffer is StaticBuffer @@ -1287,7 +1281,7 @@ struct ThreadwiseTensorSliceTransfer_v4 using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); - using SrcCoordIterator = decltype(make_tensor_coordinate_iterator(SrcDesc{}, Index{})); + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); __device__ constexpr ThreadwiseTensorSliceTransfer_v4(const Index& src_ref_idx) : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx)) @@ -1386,12 +1380,12 @@ struct ThreadwiseTensorSliceTransfer_v4 constexpr auto src_ref_to_data_disp_idx = src_ref_to_origin_disp_idx + data_to_origin_disp_idx; - constexpr auto src_ref_to_data_disp_coord_iterator = - make_tensor_coordinate_iterator(src_desc, src_ref_to_data_disp_idx); + constexpr auto src_ref_to_data_disp_coord_step = + make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx); auto src_data_coord = src_ref_coord_; - move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator); + move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step); vector_type_maker_t src_tmp_vector; @@ -1431,7 +1425,7 @@ struct ThreadwiseTensorSliceTransfer_v4 constexpr auto src_desc = SrcDesc{}; const auto src_slice_move_step_iter = - make_tensor_coordinate_iterator(src_desc, to_multi_index(src_slice_move_step_idx)); + make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx)); move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter); } diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp index a2613f2e2d..ceac47a364 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp @@ -41,8 +41,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); - using SrcCoordIterator = decltype(make_tensor_coordinate_iterator(SrcDesc{}, Index{})); - using DstCoordIterator = decltype(make_tensor_coordinate_iterator(DstDesc{}, Index{})); + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(const SrcDesc& src_desc, const Index& src_slice_origin, @@ -72,10 +72,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1 dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); } - template - __device__ void RunRead(const SrcDesc& src_desc, - const SrcBuffer& src_buf, - const SrcIteratorHacks& src_iterator_hacks) + template + __device__ void + RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) { static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, @@ -108,31 +107,31 @@ struct ThreadwiseTensorSliceTransfer_v3r1 constexpr auto ordered_src_access_lengths = container_reorder_given_new2old(src_access_lengths, src_dim_access_order); - // make forward iterators - const auto src_forward_iterators = generate_tuple( + // make forward steps + const auto src_forward_steps = generate_tuple( [&](auto i) { - Index forward_step; + Index forward_step_idx; static_for<0, nDim, 1>{}([&](auto j) { - forward_step(j) = (i.value == j.value) ? src_vector_tensor_lengths[i] : 0; + forward_step_idx(j) = (i.value == j.value) ? src_vector_tensor_lengths[i] : 0; }); - return make_tensor_coordinate_iterator( - src_desc, forward_step, src_iterator_hacks[I0][i]); + return make_tensor_coordinate_step( + src_desc, forward_step_idx, src_step_hacks[I0][i]); }, Number{}); - // make backward iterators - const auto src_backward_iterators = generate_tuple( + // make backward steps + const auto src_backward_steps = generate_tuple( [&](auto i) { - Index backward_step; + Index backward_step_idx; static_for<0, nDim, 1>{}([&](auto j) { - backward_step(j) = (i.value == j.value) ? -src_vector_tensor_lengths[i] : 0; + backward_step_idx(j) = (i.value == j.value) ? -src_vector_tensor_lengths[i] : 0; }); - return make_tensor_coordinate_iterator( - src_desc, backward_step, src_iterator_hacks[I1][i]); + return make_tensor_coordinate_step( + src_desc, backward_step_idx, src_step_hacks[I1][i]); }, Number{}); @@ -220,12 +219,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1 if constexpr(forward_sweep[i]) { move_tensor_coordinate( - src_desc, src_coord_, src_forward_iterators[src_dim_access_order[i]]); + src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); } else { move_tensor_coordinate( - src_desc, src_coord_, src_backward_iterators[src_dim_access_order[i]]); + src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); } } }); @@ -234,17 +233,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1 // move src coordinate back to slice origin (or not) if constexpr(SrcResetCoordinateAfterRun) { - const auto src_reset_iterator = - make_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep()); + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); - move_tensor_coordinate(src_desc, src_coord_, src_reset_iterator); + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); } } - template - __device__ void RunWrite(const DstDesc& dst_desc, - DstBuffer& dst_buf, - const DstIteratorHacks& dst_iterator_hacks) + template + __device__ void + RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks) { static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, @@ -277,35 +275,31 @@ struct ThreadwiseTensorSliceTransfer_v3r1 constexpr auto ordered_dst_access_lengths = container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); - // make forward iterators - const auto dst_forward_iterators = generate_tuple( + // make forward steps + const auto dst_forward_steps = generate_tuple( [&](auto i) { - Index forward_step; + Index forward_step_idx; static_for<0, nDim, 1>{}([&](auto j) { - forward_step(j) = (i.value == j.value) ? dst_vector_tensor_lengths[i] : 0; + forward_step_idx(j) = (i.value == j.value) ? dst_vector_tensor_lengths[i] : 0; }); - const auto forward_iterator = make_tensor_coordinate_iterator( - dst_desc, forward_step, dst_iterator_hacks[I0][i]); - - return forward_iterator; + return make_tensor_coordinate_step( + dst_desc, forward_step_idx, dst_step_hacks[I0][i]); }, Number{}); - // make backward iterators - const auto dst_backward_iterators = generate_tuple( + // make backward steps + const auto dst_backward_steps = generate_tuple( [&](auto i) { - Index backward_step; + Index backward_step_idx; static_for<0, nDim, 1>{}([&](auto j) { - backward_step(j) = (i.value == j.value) ? -dst_vector_tensor_lengths[i] : 0; + backward_step_idx(j) = (i.value == j.value) ? -dst_vector_tensor_lengths[i] : 0; }); - const auto backward_iterator = make_tensor_coordinate_iterator( - dst_desc, backward_step, dst_iterator_hacks[I1][i]); - - return backward_iterator; + return make_tensor_coordinate_step( + dst_desc, backward_step_idx, dst_step_hacks[I1][i]); }, Number{}); @@ -395,12 +389,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1 if constexpr(forward_sweep[i]) { move_tensor_coordinate( - dst_desc, dst_coord_, dst_forward_iterators[dst_dim_access_order[i]]); + dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]); } else { move_tensor_coordinate( - dst_desc, dst_coord_, dst_backward_iterators[dst_dim_access_order[i]]); + dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]); } } }); @@ -409,10 +403,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1 // move dst coordinate back to slice origin (or not) if constexpr(DstResetCoordinateAfterRun) { - const auto dst_reset_iterator = - make_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep()); + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); - move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator); + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); } } @@ -423,11 +417,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1 constexpr auto zeros = typename uniform_sequence_gen::type{}; - constexpr auto src_iterator_hacks = + constexpr auto src_step_hacks = make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), generate_tuple([&](auto) { return zeros; }, Number{})); - RunRead(src_desc, src_buf, src_iterator_hacks); + RunRead(src_desc, src_buf, src_step_hacks); } template @@ -437,11 +431,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1 constexpr auto zeros = typename uniform_sequence_gen::type{}; - constexpr auto dst_iterator_hacks = + constexpr auto dst_step_hacks = make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), generate_tuple([&](auto) { return zeros; }, Number{})); - RunWrite(dst_desc, dst_buf, dst_iterator_hacks); + RunWrite(dst_desc, dst_buf, dst_step_hacks); } __device__ static constexpr auto GetSrcCoordinateResetStep() @@ -564,17 +558,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1 : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_iterator(src_desc, adjusted_step_idx); + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); move_tensor_coordinate(src_desc, src_coord_, adjusted_step); } // src_slice_origin_step_idx need to be known at compile-time, for performance reason - template + template __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& src_slice_origin_step_idx, - const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack) + const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) { // if src coord was not reset by RunRead(), then need to adjust the step here const auto adjusted_step_idx = @@ -582,8 +576,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_iterator( - src_desc, adjusted_step_idx, src_move_slice_window_iterator_hack); + const auto adjusted_step = make_tensor_coordinate_step( + src_desc, adjusted_step_idx, src_move_slice_window_step_hack); move_tensor_coordinate(src_desc, src_coord_, adjusted_step); } @@ -597,7 +591,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_iterator(dst_desc, adjusted_step_idx); + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); } @@ -620,7 +614,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 // 2. SrcBuffer is DynamicBuffer // 3. src_ref_idx is known at run-time // 4. SrcRefToOriginDisplacement is known at compile-time -// 5. use #-iterator +// 5. use #-step // 2. dst: // 1. DstDesc is known at compile-time // 2. DstBuffer is StaticBuffer @@ -649,7 +643,7 @@ struct ThreadwiseTensorSliceTransfer_v4r1 using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); - using SrcCoordIterator = decltype(make_tensor_coordinate_iterator(SrcDesc{}, Index{})); + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); __device__ constexpr ThreadwiseTensorSliceTransfer_v4r1(const Index& src_ref_idx) : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx)) @@ -732,12 +726,12 @@ struct ThreadwiseTensorSliceTransfer_v4r1 constexpr auto src_ref_to_data_disp_idx = src_ref_to_origin_disp_idx + data_to_origin_disp_idx; - constexpr auto src_ref_to_data_disp_coord_iterator = - make_tensor_coordinate_iterator(src_desc, src_ref_to_data_disp_idx); + constexpr auto src_ref_to_data_disp_coord_step = + make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx); auto src_data_coord = src_ref_coord_; - move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator); + move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step); vector_type_maker_t src_vector; @@ -773,7 +767,7 @@ struct ThreadwiseTensorSliceTransfer_v4r1 constexpr auto src_desc = SrcDesc{}; const auto src_slice_move_step_iter = - make_tensor_coordinate_iterator(src_desc, to_multi_index(src_slice_move_step_idx)); + make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx)); move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter); } diff --git a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp index 1843a0ca64..09a7fffa3e 100644 --- a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp +++ b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp @@ -113,16 +113,16 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy using BKNGridDesc = decltype(b_k_n_grid_desc); using CMNGridDesc = decltype(c_m_n_grid_desc); - using AGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}))); + using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}))); - using BGridIteratorHacks = + using BGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), @@ -130,21 +130,21 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}))); - using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}))); + using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}))); - using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>; - using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; + using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; + using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v1r2; + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks>; auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); @@ -249,16 +249,16 @@ extern "C" __global__ void using BKNGridDesc = decltype(b_k_n_grid_desc); using CMNGridDesc = decltype(c_m_n_grid_desc); - using AGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}))); + using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}))); - using BGridIteratorHacks = + using BGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), @@ -266,21 +266,21 @@ extern "C" __global__ void Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}))); - using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}))); + using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}))); - using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>; - using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; + using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; + using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v1r2; + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks>; constexpr auto a_k_m0_m1_grid_desc_tmp = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); diff --git a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp index d434dab6fe..51d852617f 100644 --- a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp +++ b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp @@ -110,12 +110,12 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc); using CMNGridDesc = decltype(c_m_n_grid_desc); - using AGridIteratorHacks = decltype(make_tuple( + using AGridStepHacks = decltype(make_tuple( make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), make_tuple( Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); - using BGridIteratorHacks = + using BGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), @@ -123,25 +123,25 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); - using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}))); + using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}))); - using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>; - using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; + using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; + using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); @@ -243,12 +243,12 @@ extern "C" __global__ void constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1]; constexpr auto c_m_n_grid_desc = descs[I2]; - using AGridIteratorHacks = decltype(make_tuple( + using AGridStepHacks = decltype(make_tuple( make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), make_tuple( Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); - using BGridIteratorHacks = + using BGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), @@ -256,25 +256,25 @@ extern "C" __global__ void Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); - using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}))); + using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}))); - using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>; - using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; + using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; + using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp); using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp); @@ -316,11 +316,11 @@ extern "C" __global__ void CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, - AGridIteratorHacks, - BGridIteratorHacks, - CGridIteratorHacks, - AGridMoveSliceWindowIteratorHacks, - BGridMoveSliceWindowIteratorHacks, + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks, false>; constexpr auto c_m0_m1_m2_n_grid_desc_tmp = diff --git a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp index 7678a69b12..30e4c518ce 100644 --- a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp +++ b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp @@ -110,12 +110,12 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc); using CMNGridDesc = decltype(c_m_n_grid_desc); - using BGridIteratorHacks = decltype(make_tuple( + using BGridStepHacks = decltype(make_tuple( make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), make_tuple( Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); - using AGridIteratorHacks = + using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), @@ -123,25 +123,25 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); - using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}))); + using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}))); - using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; - using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>; + using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; + using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); @@ -247,12 +247,12 @@ extern "C" __global__ void using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp); using CMNGridDesc = decltype(c_m_n_grid_desc); - using BGridIteratorHacks = decltype(make_tuple( + using BGridStepHacks = decltype(make_tuple( make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), make_tuple( Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); - using AGridIteratorHacks = + using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), @@ -260,25 +260,25 @@ extern "C" __global__ void Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); - using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}))); + using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}))); - using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; - using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>; + using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; + using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; constexpr auto c_m0_m1_m2_n_grid_desc_tmp = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); diff --git a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp index ac7e1dd6d4..9661f0e50c 100644 --- a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp +++ b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp @@ -111,7 +111,7 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N, using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1); using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1); - using AGridIteratorHacks = + using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0 Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10 @@ -123,7 +123,7 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N, Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11 Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 - using BGridIteratorHacks = decltype(make_tuple( + using BGridStepHacks = decltype(make_tuple( make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10 @@ -135,7 +135,7 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 - using CGridIteratorHacks = decltype(make_tuple( + using CGridStepHacks = decltype(make_tuple( make_tuple( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0 @@ -151,9 +151,9 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1 - using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0>; + using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0>; - using BGridMoveSliceWindowIteratorHacks = + using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>; using GridwiseContraction = @@ -191,11 +191,11 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, - AGridIteratorHacks, - BGridIteratorHacks, - CGridIteratorHacks, - AGridMoveSliceWindowIteratorHacks, - BGridMoveSliceWindowIteratorHacks>; + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks>; if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0) { @@ -254,7 +254,7 @@ extern "C" __global__ void using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1); using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1); - using AGridIteratorHacks = + using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0 Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10 @@ -266,7 +266,7 @@ extern "C" __global__ void Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11 Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 - using BGridIteratorHacks = decltype(make_tuple( + using BGridStepHacks = decltype(make_tuple( make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10 @@ -278,7 +278,7 @@ extern "C" __global__ void Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 - using CGridIteratorHacks = decltype(make_tuple( + using CGridStepHacks = decltype(make_tuple( make_tuple( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0 @@ -294,9 +294,9 @@ extern "C" __global__ void Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1 - using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0>; + using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0>; - using BGridMoveSliceWindowIteratorHacks = + using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>; using GridwiseContraction = @@ -334,11 +334,11 @@ extern "C" __global__ void CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, - AGridIteratorHacks, - BGridIteratorHacks, - CGridIteratorHacks, - AGridMoveSliceWindowIteratorHacks, - BGridMoveSliceWindowIteratorHacks>; + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks>; using AGridDesc_GK0_GM0_GM10_GM11_GK1 = decltype(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1( diff --git a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp index 5f162ec24b..7bd82bf6d5 100644 --- a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp @@ -207,7 +207,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( const auto in_gemmm_gemmn_grid_desc = descs[I2]; // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmm Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 @@ -215,7 +215,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmm Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1 - constexpr auto out_gemmk0_gemmn_gemmk1_grid_iterator_hacks = make_tuple( + constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple( make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmn Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 @@ -223,7 +223,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmn Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1 - constexpr auto in_m0_m1_m2_n_grid_iterator_hacks = make_tuple( + constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple( make_tuple( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: NRepeat @@ -243,10 +243,10 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M2 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 7-: N1 - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}; - constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = + constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{}; for(index_t i = 0; i < 5; ++i) @@ -287,11 +287,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( Sequence<1, 3, 7, 0, 2, 4, 5, 6>, 6, GemmCThreadTransferDstScalarPerVector, - decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks), - decltype(out_gemmk0_gemmn_gemmk1_grid_iterator_hacks), - decltype(in_m0_m1_m2_n_grid_iterator_hacks), - decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), - decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks), + decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(in_m0_m1_m2_n_grid_step_hacks), + decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), false // CAccessOrderMRepeatNRepeat >(static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), @@ -299,11 +299,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( wei_gemmk0_gemmm_gemmk1_grid_desc, out_gemmk0_gemmn_gemmk1_grid_desc, in_gemmm_gemmn_grid_desc, - wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks, - out_gemmk0_gemmn_gemmk1_grid_iterator_hacks, - in_m0_m1_m2_n_grid_iterator_hacks, - wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, - out_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, + wei_gemmk0_gemmm_gemmk1_grid_step_hacks, + out_gemmk0_gemmn_gemmk1_grid_step_hacks, + in_m0_m1_m2_n_grid_step_hacks, + wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, nrepeat); { diff --git a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp index 82539fdd11..0ebf8571f4 100644 --- a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp @@ -179,7 +179,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk const auto in_gemmm_gemmn_grid_desc = descs[I2]; // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto out_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple( + constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmm Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 @@ -187,7 +187,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmm Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1 - constexpr auto wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks = + constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmn Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 @@ -195,7 +195,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmn Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1 - constexpr auto in_m0_m1_m2_n_grid_iterator_hacks = make_tuple( + constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple( make_tuple( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: MRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat @@ -215,10 +215,10 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 6-: M2 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N1 - constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = + constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{}; - constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = + constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}; for(index_t i = 0; i < 5; ++i) @@ -263,11 +263,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk #endif 7, GemmCThreadTransferDstScalarPerVector, - decltype(out_gemmk0_gemmm_gemmk1_grid_iterator_hacks), - decltype(wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks), - decltype(in_m0_m1_m2_n_grid_iterator_hacks), - decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), - decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks), + decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(in_m0_m1_m2_n_grid_step_hacks), + decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), true // CAccessOrderMRepeatNRepeat >(static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), @@ -275,11 +275,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk out_gemmk0_gemmm_gemmk1_grid_desc, wei_gemmk0_gemmn_gemmk1_grid_desc, in_gemmm_gemmn_grid_desc, - out_gemmk0_gemmm_gemmk1_grid_iterator_hacks, - wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks, - in_m0_m1_m2_n_grid_iterator_hacks, - out_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, - wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, + out_gemmk0_gemmm_gemmk1_grid_step_hacks, + wei_gemmk0_gemmn_gemmk1_grid_step_hacks, + in_m0_m1_m2_n_grid_step_hacks, + out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, nrepeat); { diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp index a2af8eab28..e6554cf0fe 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp @@ -89,7 +89,7 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( in_right_pads); // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks = + constexpr auto wei_gemmk_gemmm0_gemmn1_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, @@ -99,7 +99,7 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); - constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks = + constexpr auto in_gemmk_gemmn0_gemmn1_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), @@ -107,7 +107,7 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); - constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks = + constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, @@ -121,10 +121,10 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{})); - constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks = + constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0>{}; - constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks = + constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; const auto wei_gemmk_gemmm_grid_desc = descs[I0]; @@ -171,22 +171,22 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder 5, // CThreadTransferSrcDstVectorDim GemmCThreadTransferDstScalarPerVector_N11, - decltype(wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks), - decltype(in_gemmk_gemmn0_gemmn1_grid_iterator_hacks), - decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks), - decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks), - decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks)>( + decltype(wei_gemmk_gemmm0_gemmn1_grid_step_hacks), + decltype(in_gemmk_gemmn0_gemmn1_grid_step_hacks), + decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks), + decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks), + decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks)>( static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), wei_gemmk_gemmm_grid_desc, in_gemmk_gemmn_grid_desc, out_gemmm_gemmn_grid_desc, - wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks, - in_gemmk_gemmn0_gemmn1_grid_iterator_hacks, - out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks, - wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks, - in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks, + wei_gemmk_gemmm0_gemmn1_grid_step_hacks, + in_gemmk_gemmn0_gemmn1_grid_step_hacks, + out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks, + wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks, + in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks, nrepeat); float perf = static_cast(calculate_convolution_flops( diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp index d32eeea9cd..40685e81cf 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp @@ -155,7 +155,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk( const auto out_gemmm_gemmn_grid_desc = descs[I2]; // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks = make_tuple( + constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks = make_tuple( make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmM1 @@ -165,7 +165,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GemmM1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1 - constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks = + constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0 Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmN0 Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmN1 @@ -175,7 +175,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk( Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmN1 Sequence<0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1 - constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks = + constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmM0 Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM10 Sequence<0, 0, 0, 0, 0>{}, // 2+: GemmM11 @@ -189,10 +189,10 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk( Sequence<0, 0, 0, 0, 0>{}, // 4-: GemmN10 Sequence<0, 0, 0, 0, 0>{})); // 5-: GemmN11 - constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks = + constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{}; - constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks = + constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}; for(index_t i = 0; i < 5; ++i) @@ -231,22 +231,22 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk( Sequence<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder 5, // CThreadTransferSrcDstVectorDim GemmCThreadTransferDstScalarPerVector_N11, - decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks), - decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks), - decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks), - decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks), - decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks)>( + decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks), + decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks), + decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks), + decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks), + decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks)>( static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), in_gemmk0_gemmm_gemmk1_grid_desc, wei_gemmk0_gemmn_gemmk1_grid_desc, out_gemmm_gemmn_grid_desc, - in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks, - wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks, - out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks, - in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks, - wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks, + in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks, + wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks, + out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks, + in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks, + wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks, nrepeat); { diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp index d82fbf69d6..695ffeeb36 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp @@ -92,12 +92,12 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( const auto out_gemmm_gemmn_grid_desc = descs[I2]; // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple( + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), make_tuple( Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); - constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks = + constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), @@ -105,7 +105,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); - constexpr auto out_m0_m1_m2_n_grid_iterator_hacks = + constexpr auto out_m0_m1_m2_n_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, @@ -123,10 +123,10 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{})); - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0>{}; - constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = + constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; for(index_t i = 0; i < 5; ++i) @@ -167,22 +167,22 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( Sequence<3, 0, 1, 2, 7, 5, 4, 6>, 7, GemmCThreadTransferDstScalarPerVector, - decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks), - decltype(out_m0_m1_m2_n_grid_iterator_hacks), - decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks), + decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(out_m0_m1_m2_n_grid_step_hacks), + decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), false>(static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), wei_gemmk0_gemmm_gemmk1_grid_desc, in_gemmk0_gemmn_gemmk1_grid_desc, out_gemmm_gemmn_grid_desc, - wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks, - in_gemmk0_gemmn_gemmk1_grid_iterator_hacks, - out_m0_m1_m2_n_grid_iterator_hacks, - wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, - in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, + wei_gemmk0_gemmm_gemmk1_grid_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_step_hacks, + out_m0_m1_m2_n_grid_step_hacks, + wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, nrepeat); float perf = static_cast(calculate_convolution_flops( diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp index 37d89ec5a2..141a326574 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp @@ -121,12 +121,12 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk( const auto out_gemmm_gemmn_grid_desc = descs[I2]; // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple( + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), make_tuple( Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); - constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks = + constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), @@ -134,7 +134,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); - constexpr auto out_m0_m1_m2_n_grid_iterator_hacks = + constexpr auto out_m0_m1_m2_n_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, @@ -144,10 +144,10 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk( Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{})); - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0>{}; - constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = + constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; for(index_t i = 0; i < 5; ++i) @@ -187,22 +187,22 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk( Sequence<2, 3, 0, 1>, 2, GemmCThreadTransferDstScalarPerVector, - decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks), - decltype(out_m0_m1_m2_n_grid_iterator_hacks), - decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks)>( + decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(out_m0_m1_m2_n_grid_step_hacks), + decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks)>( static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), wei_gemmk0_gemmm_gemmk1_grid_desc, in_gemmk0_gemmn_gemmk1_grid_desc, out_gemmm_gemmn_grid_desc, - wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks, - in_gemmk0_gemmn_gemmk1_grid_iterator_hacks, - out_m0_m1_m2_n_grid_iterator_hacks, - wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, - in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, + wei_gemmk0_gemmm_gemmk1_grid_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_step_hacks, + out_m0_m1_m2_n_grid_step_hacks, + wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, nrepeat); { diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp index d1671bb87c..692751bfb3 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp @@ -182,12 +182,12 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( const auto out_gemmm_gemmn_grid_desc = descs[I2]; // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple( + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), make_tuple( Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); - constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks = + constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), @@ -195,7 +195,7 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); - constexpr auto out_m0_m1_m2_n_grid_iterator_hacks = + constexpr auto out_m0_m1_m2_n_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, @@ -213,10 +213,10 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{})); - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0>{}; - constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = + constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; for(index_t i = 0; i < 5; ++i) @@ -256,11 +256,11 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( Sequence<2, 3, 0, 1, 7, 5, 4, 6>, 6, GemmCThreadTransferDstScalarPerVector, - decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks), - decltype(out_m0_m1_m2_n_grid_iterator_hacks), - decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks), + decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(out_m0_m1_m2_n_grid_step_hacks), + decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), false // CAccessOrderMRepeatNRepeat >(static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), @@ -268,11 +268,11 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( wei_gemmk0_gemmm_gemmk1_grid_desc, in_gemmk0_gemmn_gemmk1_grid_desc, out_gemmm_gemmn_grid_desc, - wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks, - in_gemmk0_gemmn_gemmk1_grid_iterator_hacks, - out_m0_m1_m2_n_grid_iterator_hacks, - wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, - in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, + wei_gemmk0_gemmm_gemmk1_grid_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_step_hacks, + out_m0_m1_m2_n_grid_step_hacks, + wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, nrepeat); { diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp index 7a38b569c9..7067291c8a 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -233,7 +233,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( const auto out_gemmm_gemmn_grid_desc = descs[I2]; // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto in_gemmk0_gemmm_gemmk1_grid_iterator_hacks = + constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmM Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), // 2+: GemmK1 @@ -241,7 +241,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmM Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1 - constexpr auto wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks = + constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1 @@ -249,7 +249,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 - constexpr auto out_m0_m1_m2_n_grid_iterator_hacks = + constexpr auto out_m0_m1_m2_n_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: MRepeat Sequence<0, 0, 0, 0, 0>{}, // 1+: NRepeat Sequence<0, 0, 0, 0, 0>{}, // 2+: MWaves @@ -267,10 +267,10 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( Sequence<0, 0, 0, 0, 0>{}, // 6-: M2 Sequence<0, 0, 0, 0, 0>{})); // 7-: N1 - constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = + constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; - constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = + constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0>{}; for(index_t i = 0; i < 5; ++i) @@ -311,11 +311,11 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( Sequence<2, 3, 0, 1, 7, 5, 4, 6>, 7, GemmCThreadTransferDstScalarPerVector, - decltype(in_gemmk0_gemmm_gemmk1_grid_iterator_hacks), - decltype(wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks), - decltype(out_m0_m1_m2_n_grid_iterator_hacks), - decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), - decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks), + decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(out_m0_m1_m2_n_grid_step_hacks), + decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), false // CAccessOrderMRepeatNRepeat >(static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), @@ -323,11 +323,11 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( in_gemmk0_gemmm_gemmk1_grid_desc, wei_gemmk0_gemmn_gemmk1_grid_desc, out_gemmm_gemmn_grid_desc, - in_gemmk0_gemmm_gemmk1_grid_iterator_hacks, - wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks, - out_m0_m1_m2_n_grid_iterator_hacks, - in_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, - wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, + in_gemmk0_gemmm_gemmk1_grid_step_hacks, + wei_gemmk0_gemmn_gemmk1_grid_step_hacks, + out_m0_m1_m2_n_grid_step_hacks, + in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, nrepeat); { diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp index f2a8a1a2b2..0d28616386 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp @@ -130,7 +130,7 @@ void device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( const auto out_grid_desc_gm0_gm1_gn0_gn1 = descs[I2]; // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_grid_iterator_hacks = + constexpr auto wei_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0 Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10 @@ -142,7 +142,7 @@ void device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11 Sequence<0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1 - constexpr auto in_grid_iterator_hacks = make_tuple( + constexpr auto in_grid_step_hacks = make_tuple( make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10 @@ -154,7 +154,7 @@ void device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1 - constexpr auto out_grid_iterator_hacks = make_tuple( + constexpr auto out_grid_step_hacks = make_tuple( make_tuple( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0 @@ -170,9 +170,9 @@ void device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 5-: GN1 - constexpr auto wei_grid_move_slice_window_iterator_hacks = Sequence<0, 0, 0, 0, 0, 0, 0>{}; + constexpr auto wei_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0>{}; - constexpr auto in_grid_move_slice_window_iterator_hacks = + constexpr auto in_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>{}; for(index_t i = 0; i < 5; ++i) @@ -211,22 +211,22 @@ void device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder 5, // CThreadTransferSrcDstVectorDim CThreadTransferDstScalarPerVector_BN1, - decltype(wei_grid_iterator_hacks), - decltype(in_grid_iterator_hacks), - decltype(out_grid_iterator_hacks), - decltype(wei_grid_move_slice_window_iterator_hacks), - decltype(in_grid_move_slice_window_iterator_hacks)>( + decltype(wei_grid_step_hacks), + decltype(in_grid_step_hacks), + decltype(out_grid_step_hacks), + decltype(wei_grid_move_slice_window_step_hacks), + decltype(in_grid_move_slice_window_step_hacks)>( static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), wei_grid_desc_gk0_gm0_gm1_gk1, in_grid_desc_gk0_gn0_gn1_gk1, out_grid_desc_gm0_gm1_gn0_gn1, - wei_grid_iterator_hacks, - in_grid_iterator_hacks, - out_grid_iterator_hacks, - wei_grid_move_slice_window_iterator_hacks, - in_grid_move_slice_window_iterator_hacks, + wei_grid_step_hacks, + in_grid_step_hacks, + out_grid_step_hacks, + wei_grid_move_slice_window_step_hacks, + in_grid_move_slice_window_step_hacks, nrepeat); float perf = static_cast(calculate_convolution_flops( diff --git a/host/driver_offline/include/driver_contraction_dlops_v1r2.hpp b/host/driver_offline/include/driver_contraction_dlops_v1r2.hpp index fbd1ce4e5e..d207728a2e 100644 --- a/host/driver_offline/include/driver_contraction_dlops_v1r2.hpp +++ b/host/driver_offline/include/driver_contraction_dlops_v1r2.hpp @@ -39,11 +39,11 @@ template + typename AGridStepHacks, + typename BGridStepHacks, + typename CGridStepHacks, + typename AGridMoveSliceWindowStepHacks, + typename BGridMoveSliceWindowStepHacks> __host__ float driver_contraction_dlops_v1r2(const FloatAB* p_a_grid, const FloatAB* p_b_grid, @@ -51,11 +51,11 @@ driver_contraction_dlops_v1r2(const FloatAB* p_a_grid, const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1, const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1, const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1, - AGridIteratorHacks, - BGridIteratorHacks, - CGridIteratorHacks, - AGridMoveSliceWindowIteratorHacks, - BGridMoveSliceWindowIteratorHacks, + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks, ck::index_t nrepeat) { @@ -104,11 +104,11 @@ driver_contraction_dlops_v1r2(const FloatAB* p_a_grid, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, - AGridIteratorHacks, - BGridIteratorHacks, - CGridIteratorHacks, - AGridMoveSliceWindowIteratorHacks, - BGridMoveSliceWindowIteratorHacks>; + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks>; const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0); diff --git a/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp index 6f4db5ff7b..efd4ce6a19 100644 --- a/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp @@ -136,13 +136,13 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad } // hack to control index calculation when iterating over a_k_m_global tensor - constexpr auto a_e_k_global_iterator_hacks = + constexpr auto a_e_k_global_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); - constexpr auto a_e_k_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{}; + constexpr auto a_e_k_global_move_slice_window_step_hack = Sequence<0, 0, 0>{}; - constexpr auto b_e_n_ho_wo_global_iterator_hacks = + constexpr auto b_e_n_ho_wo_global_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, @@ -152,12 +152,12 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); - constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack = + constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}; // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor // hack for NKHW format - constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = + constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, @@ -202,11 +202,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad Sequence<0, 2, 3, 1>, 0, CThreadTransferDstScalarPerVector_W, - decltype(a_e_k_global_iterator_hacks), - decltype(b_e_n_ho_wo_global_iterator_hacks), - decltype(c_k_n_ho_wo_global_tensor_iterator_hacks), - decltype(a_e_k_global_move_slice_window_iterator_hack), - decltype(b_e_n_ho_wo_global_move_slice_window_iterator_hack)>; + decltype(a_e_k_global_step_hacks), + decltype(b_e_n_ho_wo_global_step_hacks), + decltype(c_k_n_ho_wo_global_tensor_step_hacks), + decltype(a_e_k_global_move_slice_window_step_hack), + decltype(b_e_n_ho_wo_global_move_slice_window_step_hack)>; const auto GridSize = (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N; diff --git a/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp b/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp index 1b7179173c..70f73cbf4a 100644 --- a/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp +++ b/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp @@ -149,13 +149,13 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp } // hack to control index calculation when iterating over a_k_m_global tensor - constexpr auto a_e_k_global_iterator_hacks = + constexpr auto a_e_k_global_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); - constexpr auto a_e_k_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{}; + constexpr auto a_e_k_global_move_slice_window_step_hack = Sequence<0, 0, 0>{}; - constexpr auto b_e_n_ho_wo_global_iterator_hacks = + constexpr auto b_e_n_ho_wo_global_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, @@ -165,12 +165,12 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); - constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack = + constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}; // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor // hack for NKHW format - constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = + constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, @@ -214,11 +214,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp Sequence<0, 2, 3, 1>, 0, CThreadTransferDstScalarPerVector_W, - decltype(a_e_k_global_iterator_hacks), - decltype(b_e_n_ho_wo_global_iterator_hacks), - decltype(c_k_n_ho_wo_global_tensor_iterator_hacks), - decltype(a_e_k_global_move_slice_window_iterator_hack), - decltype(b_e_n_ho_wo_global_move_slice_window_iterator_hack)>; + decltype(a_e_k_global_step_hacks), + decltype(b_e_n_ho_wo_global_step_hacks), + decltype(c_k_n_ho_wo_global_tensor_step_hacks), + decltype(a_e_k_global_move_slice_window_step_hack), + decltype(b_e_n_ho_wo_global_move_slice_window_step_hack)>; const auto GridSize = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N; diff --git a/host/driver_offline/include/driver_gemm_dlops_v1r2.hpp b/host/driver_offline/include/driver_gemm_dlops_v1r2.hpp index 114f31e760..bf5f7f1c0f 100644 --- a/host/driver_offline/include/driver_gemm_dlops_v1r2.hpp +++ b/host/driver_offline/include/driver_gemm_dlops_v1r2.hpp @@ -43,22 +43,22 @@ template + typename AGridStepHacks, + typename BGridStepHacks, + typename CGridStepHacks, + typename AGridMoveSliceWindowStepHacks, + typename BGridMoveSliceWindowStepHacks> __host__ float driver_gemm_dlops_v1r2(const FloatAB* p_a_grid, const FloatAB* p_b_grid, FloatC* p_c_grid, const AKMGridDesc& a_k_m_grid_desc, const BKNGridDesc& b_k_n_grid_desc, const CMNGridDesc& c_m_n_grid_desc, - AGridIteratorHacks, - BGridIteratorHacks, - CGridIteratorHacks, - AGridMoveSliceWindowIteratorHacks, - BGridMoveSliceWindowIteratorHacks, + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks, ck::index_t nrepeat) { @@ -109,11 +109,11 @@ __host__ float driver_gemm_dlops_v1r2(const FloatAB* p_a_grid, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, - AGridIteratorHacks, - BGridIteratorHacks, - CGridIteratorHacks, - AGridMoveSliceWindowIteratorHacks, - BGridMoveSliceWindowIteratorHacks>; + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks>; const auto M = a_k_m_grid_desc.GetLength(I1); const auto N = b_k_n_grid_desc.GetLength(I1); diff --git a/host/driver_offline/include/driver_gemm_dlops_v1r3.hpp b/host/driver_offline/include/driver_gemm_dlops_v1r3.hpp index a9350bf0f8..4470918820 100644 --- a/host/driver_offline/include/driver_gemm_dlops_v1r3.hpp +++ b/host/driver_offline/include/driver_gemm_dlops_v1r3.hpp @@ -39,22 +39,22 @@ template + typename AGridStepHacks, + typename BGridStepHacks, + typename CGridStepHacks, + typename AGridMoveSliceWindowStepHacks, + typename BGridMoveSliceWindowStepHacks> __host__ float driver_gemm_dlops_v1r3(const FloatAB* p_a_grid, const FloatAB* p_b_grid, FloatC* p_c_grid, const AK0MK1GridDesc& a_k0_m_k1_grid_desc, const BK0NK1GridDesc& b_k0_n_k1_grid_desc, const CMNGridDesc& c_m_n_grid_desc, - AGridIteratorHacks, - BGridIteratorHacks, - CGridIteratorHacks, - AGridMoveSliceWindowIteratorHacks, - BGridMoveSliceWindowIteratorHacks, + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks, ck::index_t nrepeat) { @@ -102,11 +102,11 @@ __host__ float driver_gemm_dlops_v1r3(const FloatAB* p_a_grid, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, - AGridIteratorHacks, - BGridIteratorHacks, - CGridIteratorHacks, - AGridMoveSliceWindowIteratorHacks, - BGridMoveSliceWindowIteratorHacks>; + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks>; const auto M = a_k0_m_k1_grid_desc.GetLength(I1); const auto N = b_k0_n_k1_grid_desc.GetLength(I1); diff --git a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp index c29dbdae69..edfce52a19 100644 --- a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp +++ b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp @@ -41,11 +41,11 @@ template __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, const FloatAB* p_b_grid, @@ -53,11 +53,11 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, const AK0MK1GridDesc& a_k0_m_k1_grid_desc, const BK0NK1GridDesc& b_k0_n_k1_grid_desc, const CMNGridDesc& c_m_n_grid_desc, - AGridIteratorHacks, - BGridIteratorHacks, - CGridIteratorHacks, - AGridMoveSliceWindowIteratorHacks, - BGridMoveSliceWindowIteratorHacks, + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks, ck::index_t nrepeat) { @@ -103,11 +103,11 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, - AGridIteratorHacks, - BGridIteratorHacks, - CGridIteratorHacks, - AGridMoveSliceWindowIteratorHacks, - BGridMoveSliceWindowIteratorHacks, + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks, CAccessOrderMRepeatNRepeat>; {