From 2e92e52137c3b4e429c60a696080b7ae37407041 Mon Sep 17 00:00:00 2001
From: arai713 <67439843+arai713@users.noreply.github.com>
Date: Mon, 12 Dec 2022 07:18:10 -0800
Subject: [PATCH] Gridwise elementwise 2d (#466)

* added 2d gridwise elementwise
* added 2d version of device elementwise
* added example file with updated device elementwise call
* added CMake file
* changed NumDim into 2D
* fixed compiler issues
* fixed indexing for loop step
* fixed NumDim dimension error
* changed blockID to 2D
* updated Grid Desc
* updated kernel call
* fixed 2d thread indexing
* added dimensions for example file
* commented out unused code
* changed vector load
* removed extra code
* temporarily removing vector load on 2nd dim
* changed vector load back, still causing errors
* altered indexing
* changed isSupportedArgument for 2D
* changed indexing + do/while
* fixed isSupportedArgument
* changed dimension for debugging
* fixed
* added testing printouts
* testing change
* added variables to distribute threads through both dimensions
* testing changes
* integrated variable for thread distribution into device elementwise and added as parameter for gridwise elementwise
* removed most of the extraneous code, testing with different dimensions
* testing
* removed debugging print statements
* moved 2d elementwise permute into elementwise permute directory
* fixed formatting
* removed debugging comments from threadwise transfer

Co-authored-by: Jing Zhang
Co-authored-by: Po Yen Chen

[ROCm/composable_kernel commit: 0e5c264c3e954c34483d8a50d9b622d5d455a160]
---
 example/44_elementwise_permute/CMakeLists.txt |   1 +
 .../elementwise_permute_4D_fp16_2d.cpp        | 130 +++++++
 .../gpu/device/device_elementwise_2d.hpp      | 341 ++++++++++++++++++
 .../gpu/grid/gridwise_elementwise_2d.hpp      | 230 ++++++++++++
 4 files changed, 702 insertions(+)
 create mode 100644 example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp
 create mode 100644 include/ck/tensor_operation/gpu/device/device_elementwise_2d.hpp
 create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp

diff --git a/example/44_elementwise_permute/CMakeLists.txt b/example/44_elementwise_permute/CMakeLists.txt
index 280797ad71..0e0091a986 100644
--- a/example/44_elementwise_permute/CMakeLists.txt
+++ b/example/44_elementwise_permute/CMakeLists.txt
@@ -1 +1,2 @@
 add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp)
+add_example_executable(example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp)
diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp
new file mode 100644
index 0000000000..f16ad3b3c5
--- /dev/null
+++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp
@@ -0,0 +1,130 @@
+#include <iostream>
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/device/device_elementwise_2d.hpp"
+
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+
+using F16 = ck::half_t;
+
+using ADataType = F16;
+using BDataType = F16;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using DeviceElementwisePermuteInstance =
+    ck::tensor_operation::device::DeviceElementwise<ck::Tuple<ADataType>,
+                                                    ck::Tuple<BDataType>,
+                                                    PassThrough,
+                                                    3, // NumDim_M
+                                                    1, // NumDim_N
+                                                    8, // MPerThread
+                                                    8, // NPerThread
+                                                    ck::Sequence<8>,  // InScalarPerVectorSeq
+                                                    ck::Sequence<8>>; // OutScalarPerVectorSeq
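+// With NumDim_M = 3 and NumDim_N = 1, the four ab_lengths {N, H, W, C} chosen below
+// are merged into a 2-d problem of shape [N * H * W, C]; each thread then owns an
+// MPerThread x NPerThread = 8 x 8 tile of that view. The two Sequence<8> arguments
+// are the in/out ScalarPerVector widths, which IsSupportedArgument validates against
+// the unit-stride axis of each tensor (W for the input, C for the output).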
+
+template <typename HostTensorA, typename HostTensorB, typename Functor>
+void host_elementwise4D(HostTensorB& B_nhwc,
+                        const HostTensorA& A_nchw,
+                        const std::vector<std::size_t>& shape_nchw,
+                        Functor functor)
+{
+    for(std::size_t n = 0; n < shape_nchw[0]; ++n)
+        for(std::size_t c = 0; c < shape_nchw[1]; ++c)
+            for(std::size_t h = 0; h < shape_nchw[2]; ++h)
+                for(std::size_t w = 0; w < shape_nchw[3]; ++w)
+                {
+                    auto a_val = A_nchw(n, c, h, w);
+                    functor(B_nhwc(n, h, w, c), a_val);
+                }
+}
+
+int main()
+{
+    bool do_verification = true;
+    bool time_kernel     = true;
+
+    const int N = 120;
+    const int C = 128;
+    const int H = 32;
+    const int W = 1024;
+
+    /**const int N = 120;
+    const int H = 32;
+    const int W = 64;
+
+    const int C = 128;**/
+
+    std::vector<std::size_t> nchw = {N, C, H, W};
+    std::vector<std::size_t> nhwc = {N, H, W, C};
+
+    Tensor<ADataType> a(nchw);
+    Tensor<BDataType> b(nhwc);
+
+    a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+
+    DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
+    DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
+
+    a_device_buf.ToDevice(a.mData.data());
+    // LogRangeAsType<float>(std::cout << "Tensor a : ", a.mData, ",") << std::endl;
+
+    std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
+    std::array<void*, 1> output      = {b_device_buf.GetDeviceBuffer()};
+
+    std::array<ck::index_t, 4> ab_lengths{N, H, W, C};
+
+    std::array<ck::index_t, 4> a_strides = {C * H * W, W, 1, H * W};
+    std::array<ck::index_t, 4> b_strides = {H * W * C, W * C, C, 1};
+
+    auto broadcastPermute = DeviceElementwisePermuteInstance{};
+    auto argument         = broadcastPermute.MakeArgumentPointer(
+        ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{});
+
+    if(!broadcastPermute.IsSupportedArgument(argument.get()))
+    {
+        throw std::runtime_error(
+            "The runtime parameters are not supported by the device instance, exiting!");
+    };
+
+    std::cout << "A (nchw): " << a.mDesc << std::endl;
+    std::cout << "B (nhwc): " << b.mDesc << std::endl;
+
+    auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer();
+    float ave_time =
+        broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
+
+    std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3];
+
+    std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) +
+                            sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]);
+
+    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+    float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
+              << std::endl;
+
+    bool pass = true;
+
+    if(do_verification)
+    {
+        b_device_buf.FromDevice(b.mData.data());
+        // LogRangeAsType<float>(std::cout << "Tensor b : ", b.mData, ",") << std::endl;
+
+        Tensor<BDataType> host_b(nhwc);
+        host_elementwise4D<Tensor<ADataType>, Tensor<BDataType>, PassThrough>(
+            host_b, a, nchw, PassThrough{});
+
+        // LogRangeAsType<float>(std::cout << "Host b : ", host_b.mData, ",") << std::endl;
+        pass &=
+            ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
+    }
+
+    return pass ? 0 : 1;
+}
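The stride arrays are what make the permute implicit: both tensors are addressed with
the same logical (n, h, w, c) coordinate, and a_strides maps that coordinate back into
A's physical NCHW layout while b_strides maps it into B's physical NHWC layout. A
minimal host-side sketch of the offset arithmetic the two descriptors encode (the
a_offset/b_offset names are illustrative only, not from the patch):

    #include <cstddef>

    // Offsets implied by ab_lengths = {N, H, W, C} and the stride arrays above.
    std::size_t a_offset(std::size_t n, std::size_t h, std::size_t w, std::size_t c,
                         std::size_t C, std::size_t H, std::size_t W)
    {
        // A is physically NCHW, so dotting (n, h, w, c) with {C*H*W, W, 1, H*W}
        // reproduces n*C*H*W + c*H*W + h*W + w.
        return n * (C * H * W) + h * W + w + c * (H * W);
    }

    std::size_t b_offset(std::size_t n, std::size_t h, std::size_t w, std::size_t c,
                         std::size_t C, std::size_t H, std::size_t W)
    {
        // B is physically NHWC: dot with {H*W*C, W*C, C, 1}.
        return n * (H * W * C) + h * (W * C) + w * C + c;
    }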
diff --git a/include/ck/tensor_operation/gpu/device/device_elementwise_2d.hpp b/include/ck/tensor_operation/gpu/device/device_elementwise_2d.hpp
new file mode 100644
index 0000000000..23aada0f44
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/device_elementwise_2d.hpp
@@ -0,0 +1,341 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <array>
+#include <tuple>
+
+#include "ck/utility/math.hpp"
+#include "ck/utility/sequence.hpp"
+#include "ck/tensor_operation/gpu/device/device_elementwise_base.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+
+#include "ck/host_utility/kernel_launch.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+template <typename InDataTypeTuple,
+          typename OutDataTypeTuple,
+          typename ElementwiseOperation,
+          index_t NumDim_m,
+          index_t NumDim_n,
+          index_t MPerThread,
+          index_t NPerThread,
+          typename InScalarPerVectorSeq,
+          typename OutScalarPerVectorSeq>
+struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
+                                                        OutDataTypeTuple,
+                                                        ElementwiseOperation,
+                                                        NumDim_m + NumDim_n>
+{
+    static constexpr index_t NumDim = NumDim_m + NumDim_n;
+
+    static constexpr int NumInput  = InDataTypeTuple::Size();
+    static constexpr int NumOutput = OutDataTypeTuple::Size();
+
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+
+    static_assert(NumInput == InScalarPerVectorSeq::Size() &&
+                      NumOutput == OutScalarPerVectorSeq::Size(),
+                  "Tuple size is inconsistent with the number of in/out!");
+
+    static auto GenerateInDataTypePointerTuple()
+    {
+        return generate_tuple(
+            [&](auto I) {
+                using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
+
+                return static_cast<const DataType*>(nullptr);
+            },
+            Number<NumInput>{});
+    };
+
+    static auto GenerateOutDataTypePointerTuple()
+    {
+        return generate_tuple(
+            [&](auto I) {
+                using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
+
+                return static_cast<DataType*>(nullptr);
+            },
+            Number<NumOutput>{});
+    };
+
+    using InDataTypePointerTuple  = decltype(GenerateInDataTypePointerTuple());
+    using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
+
+    template <typename Desc_MN>
+    static auto PadDescriptor_MN_2d(Desc_MN desc_mn,
+                                    index_t gridSize,
+                                    index_t blockSize,
+                                    index_t num_threads_m,
+                                    index_t num_threads_n)
+    {
+        std::ignore = blockSize;
+        std::ignore = gridSize;
+
+        const auto m              = desc_mn.GetLength(I0);
+        const auto n              = desc_mn.GetLength(I1);
+        const index_t loop_step_m = num_threads_m * MPerThread;
+        const index_t loop_step_n = num_threads_n * NPerThread;
+        const auto pad_m          = math::integer_least_multiple(m, loop_step_m) - m;
+        const auto pad_n          = math::integer_least_multiple(n, loop_step_n) - n;
+
+        const auto desc_mn_pad = transform_tensor_descriptor(
+            desc_mn,
+            make_tuple(make_right_pad_transform(m, pad_m), make_right_pad_transform(n, pad_n)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+        return desc_mn_pad;
+    }
+
+    static auto MakeDescriptor_MN(const std::array<index_t, NumDim>& lengths,
+                                  const std::array<index_t, NumDim>& stride,
+                                  index_t gridSize,
+                                  index_t blockSize,
+                                  index_t num_threads_m,
+                                  index_t num_threads_n)
+    {
+        auto tupleOfShape  = generate_tuple([&](auto I) { return lengths[I]; }, Number<NumDim>{});
+        auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NumDim>{});
+
+        // nd desc - [s0, s1, s2, ...]
+        const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
+
+        constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDim_m, 1>::type();
+        constexpr auto nDimIds =
+            typename arithmetic_sequence_gen<NumDim_m, NumDim, 1>::type();
+
+        const auto mLengths = get_container_subset(tupleOfShape, mDimIds);
+        const auto nLengths = get_container_subset(tupleOfShape, nDimIds);
+
+        // merge nd desc into 2d desc - [m0 * m1 * ..., n0 * n1 * ...]
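+        // e.g. with NumDim_m = 3 and NumDim_n = 1, the 4-d view [d0, d1, d2, d3]
+        // becomes the 2-d view [d0 * d1 * d2, d3]; PadDescriptor_MN_2d then
+        // right-pads each axis up to a multiple of its loop step
+        // (num_threads_m * MPerThread and num_threads_n * NPerThread).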
+        if constexpr(NumDim > 2)
+        {
+            const auto desc_mn = transform_tensor_descriptor(
+                desc,
+                make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
+                make_tuple(mDimIds, nDimIds),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+            return PadDescriptor_MN_2d(
+                desc_mn, gridSize, blockSize, num_threads_m, num_threads_n);
+        }
+        else
+            return PadDescriptor_MN_2d(desc, gridSize, blockSize, num_threads_m, num_threads_n);
+    }
+
+    template <index_t TupleSize>
+    static auto GenerateInOutGrid2dDescTuple(Number<TupleSize>)
+    {
+        return generate_tuple(
+            [&](auto) {
+                if constexpr(NumDim > 2)
+                {
+                    return MakeDescriptor_MN({1, 1}, {1, 1}, 1, 1, 1, 1);
+                }
+                else
+                {
+                    return MakeDescriptor_MN({1}, {1}, 1, 1, 1, 1);
+                };
+            },
+            Number<TupleSize>{});
+    };
+
+    using OutGrid2dDescTuple = decltype(GenerateInOutGrid2dDescTuple(Number<NumOutput>{}));
+    using InGrid2dDescTuple  = decltype(GenerateInOutGrid2dDescTuple(Number<NumInput>{}));
+
+    using GridwiseElementwise = GridwiseElementwise_2D<InGrid2dDescTuple,
+                                                       OutGrid2dDescTuple,
+                                                       InDataTypePointerTuple,
+                                                       OutDataTypePointerTuple,
+                                                       ElementwiseOperation,
+                                                       MPerThread,
+                                                       NPerThread,
+                                                       InScalarPerVectorSeq,
+                                                       OutScalarPerVectorSeq>;
+
+    struct Argument : public BaseArgument
+    {
+        Argument(const std::array<index_t, NumDim> lengths,
+                 const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
+                 const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
+                 const std::array<const void*, NumInput> in_dev_buffers,
+                 const std::array<void*, NumOutput> out_dev_buffers,
+                 ElementwiseOperation elementwise_op)
+
+            : lengths_(lengths),
+              inStridesArray_(inStridesArray),
+              outStridesArray_(outStridesArray),
+              elementwise_op_(elementwise_op),
+              blockSize_(256),
+              gridSize_(120), // FIXME - Calculate the grid size by number of CU in the future
+              num_threads_m_((gridSize_ * blockSize_) / 16),
+              num_threads_n_(16)
+        {
+            static_assert(NumDim_m > 0, "NumDim_m must be at least 1!");
+            static_assert(NumDim_n > 0, "NumDim_n must be at least 1!");
+
+            in_dev_buffers_ = generate_tuple(
+                [&](auto I) {
+                    using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
+                    return static_cast<const DataType*>(in_dev_buffers[I.value]);
+                },
+                Number<NumInput>{});
+
+            out_dev_buffers_ = generate_tuple(
+                [&](auto I) {
+                    using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
+                    return static_cast<DataType*>(out_dev_buffers[I.value]);
+                },
+                Number<NumOutput>{});
+
+            in_grid_2d_desc_tuple_ = generate_tuple(
+                [&](auto I) {
+                    return MakeDescriptor_MN(lengths,
+                                             inStridesArray[I.value],
+                                             gridSize_,
+                                             blockSize_,
+                                             num_threads_m_,
+                                             num_threads_n_);
+                },
+                Number<NumInput>{});
+
+            out_grid_2d_desc_tuple_ = generate_tuple(
+                [&](auto I) {
+                    return MakeDescriptor_MN(lengths,
+                                             outStridesArray[I.value],
+                                             gridSize_,
+                                             blockSize_,
+                                             num_threads_m_,
+                                             num_threads_n_);
+                },
+                Number<NumOutput>{});
+        }
+
+        InDataTypePointerTuple in_dev_buffers_;
+        OutDataTypePointerTuple out_dev_buffers_;
+        InGrid2dDescTuple in_grid_2d_desc_tuple_;
+        OutGrid2dDescTuple out_grid_2d_desc_tuple_;
+
+        std::array<index_t, NumDim> lengths_;
+        std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_;
+        std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_;
+
+        ElementwiseOperation elementwise_op_;
+        index_t blockSize_;
+        index_t gridSize_;
+        index_t num_threads_m_;
+        index_t num_threads_n_;
+    };
+
+    struct Invoker : public BaseInvoker
+    {
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+        {
+            const auto kernel = kernel_elementwise_2d<GridwiseElementwise,
+                                                      InGrid2dDescTuple,
+                                                      OutGrid2dDescTuple,
+                                                      InDataTypePointerTuple,
+                                                      OutDataTypePointerTuple,
+                                                      ElementwiseOperation>;
+
+            float elapsed_time = launch_and_time_kernel(stream_config,
+                                                        kernel,
+                                                        dim3(arg.gridSize_),
+                                                        dim3(arg.blockSize_),
+                                                        0,
+                                                        arg.in_grid_2d_desc_tuple_,
+                                                        arg.out_grid_2d_desc_tuple_,
+                                                        arg.in_dev_buffers_,
+                                                        arg.out_dev_buffers_,
+                                                        arg.elementwise_op_,
+                                                        arg.num_threads_m_,
+                                                        arg.num_threads_n_);
+            return elapsed_time;
+        }
+
+        // polymorphic
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
+        {
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
+        }
+    };
+
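+    // A tensor qualifies for the requested ScalarPerVector either when the checked
+    // dimension has unit stride (so the vector elements are contiguous) or, for
+    // non-unit strides, when ScalarPerVector equals that dimension's stride.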
+    bool IsSupportedArgument(const BaseArgument* p_arg) override
+    {
+        const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
+
+        if(pArg == nullptr)
+            return false;
+
+        if(pArg->lengths_.back() % MPerThread != 0)
+            return false;
+
+        auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
+                                          const std::array<index_t, NumDim>& strides,
+                                          index_t scalarPerVector,
+                                          index_t vectorDim) {
+            if(strides[vectorDim] == 1 &&
+               (lengths[vectorDim] % scalarPerVector == 0 ||
+                lengths[vectorDim] % scalarPerVector == lengths[vectorDim]))
+            {
+                return true;
+            }
+            if(strides[vectorDim] != 1 && scalarPerVector == strides[vectorDim])
+            {
+                return true;
+            }
+            return false;
+        };
+
+        bool valid = true;
+        static_for<0, NumInput, 1>{}([&](auto I) {
+            if(!IsScalarPerVectorValid(pArg->lengths_,
+                                       pArg->inStridesArray_[I.value],
+                                       InScalarPerVectorSeq::At(I),
+                                       NumDim_m - 1))
+                valid = false;
+        });
+
+        static_for<0, NumOutput, 1>{}([&](auto I) {
+            if(!IsScalarPerVectorValid(pArg->lengths_,
+                                       pArg->outStridesArray_[I.value],
+                                       OutScalarPerVectorSeq::At(I),
+                                       NumDim - 1))
+                valid = false;
+        });
+
+        return valid;
+    };
+
+    std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
+                        const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
+                        const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
+                        const std::array<const void*, NumInput> in_dev_buffers,
+                        const std::array<void*, NumOutput> out_dev_buffers,
+                        ElementwiseOperation elementwise_op) override
+    {
+        return std::make_unique<Argument>(lengths,
+                                          inStridesArray,
+                                          outStridesArray,
+                                          in_dev_buffers,
+                                          out_dev_buffers,
+                                          elementwise_op);
+    }
+
+    static auto MakeInvoker() { return Invoker{}; }
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
+    {
+        return std::make_unique<Invoker>();
+    };
+};
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
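Before the kernel itself, the shape of the dispatch is worth keeping in mind: the
invoker launches gridSize_ * blockSize_ = 120 * 256 = 30720 flat threads, which the
gridwise kernel folds into a num_threads_m x num_threads_n = 1920 x 16 logical grid.
With the example's padded [3932160, 128] problem this gives num_iter_m = 256 and
num_iter_n = 1 per thread. A host-side sketch of that folding and of the resulting
grid-stride walk over tiles (it mirrors GridwiseElementwise_2D::Run below; names and
the function itself are illustrative, not part of the diff):

    #include <cstdint>

    // Which MPerThread x NPerThread tile origins a single flat thread id visits.
    void visit_tiles(int32_t thread_1d_id, int32_t M, int32_t N,
                     int32_t num_threads_m, int32_t num_threads_n,
                     int32_t m_per_thread, int32_t n_per_thread)
    {
        const int32_t tid_m       = thread_1d_id / num_threads_n; // logical row
        const int32_t tid_n       = thread_1d_id % num_threads_n; // logical column
        const int32_t loop_step_m = num_threads_m * m_per_thread;
        const int32_t loop_step_n = num_threads_n * n_per_thread;

        // M and N are already padded to multiples of the loop steps, so the
        // kernel's do/while loops visit exactly these tile origins.
        for(int32_t m = tid_m * m_per_thread; m < M; m += loop_step_m)
            for(int32_t n = tid_n * n_per_thread; n < N; n += loop_step_n)
            {
                // process the m_per_thread x n_per_thread tile at origin (m, n)
            }
    }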
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp
new file mode 100644
index 0000000000..05257d1627
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp
@@ -0,0 +1,230 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck/tensor_description/cluster_descriptor.hpp"
+#include "ck/utility/data_type.hpp"
+#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+namespace ck {
+
+template <typename GridwiseElementwise2dFunctor,
+          typename InGrid2dDescTuple,
+          typename OutGrid2dDescTuple,
+          typename InDataTypePointerTuple,
+          typename OutDataTypePointerTuple,
+          typename ElementwiseOperation>
+__global__ void kernel_elementwise_2d(const InGrid2dDescTuple in_grid_2d_desc_tuple,
+                                      const OutGrid2dDescTuple out_grid_2d_desc_tuple,
+                                      const InDataTypePointerTuple p_in_global_tuple,
+                                      const OutDataTypePointerTuple p_out_global_tuple,
+                                      const ElementwiseOperation elementwise_op,
+                                      const index_t num_threads_m,
+                                      const index_t num_threads_n)
+{
+    GridwiseElementwise2dFunctor::Run(in_grid_2d_desc_tuple,
+                                      out_grid_2d_desc_tuple,
+                                      p_in_global_tuple,
+                                      p_out_global_tuple,
+                                      elementwise_op,
+                                      num_threads_m,
+                                      num_threads_n);
+}
+
+template <typename InGrid2dDescTuple,
+          typename OutGrid2dDescTuple,
+          typename InDataTypePointerTuple,
+          typename OutDataTypePointerTuple,
+          typename ElementwiseOperation,
+          index_t MPerThread,
+          index_t NPerThread,
+          typename InScalarPerVectorSeq,
+          typename OutScalarPerVectorSeq>
+struct GridwiseElementwise_2D
+{
+    static constexpr index_t NumInput  = InDataTypePointerTuple::Size();
+    static constexpr index_t NumOutput = OutDataTypePointerTuple::Size();
+
+    static_assert(NumInput == InScalarPerVectorSeq::Size() &&
+                      NumOutput == OutScalarPerVectorSeq::Size() &&
+                      NumInput == InGrid2dDescTuple::Size() &&
+                      NumOutput == OutGrid2dDescTuple::Size(),
+                  "Tuple size is inconsistent with the number of in/out!");
+
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+
+    static constexpr auto thread_buffer_desc_mn = make_naive_tensor_descriptor_packed(
+        make_tuple(Number<MPerThread>{}, Number<NPerThread>{}));
+
+    using PassThroughOp = tensor_operation::element_wise::PassThrough;
+
+    __device__ static void Run(const InGrid2dDescTuple in_grid_2d_desc_tuple,
+                               const OutGrid2dDescTuple out_grid_2d_desc_tuple,
+                               const InDataTypePointerTuple p_in_global_tuple,
+                               const OutDataTypePointerTuple p_out_global_tuple,
+                               const ElementwiseOperation elementwise_op,
+                               const index_t num_threads_m,
+                               const index_t num_threads_n)
+    {
+        auto in_thread_buf_tuple = generate_tuple(
+            [&](auto I) {
+                using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
+                using DataType        = remove_cv_t<remove_pointer_t<DataTypePointer>>;
+
+                return StaticBuffer<AddressSpaceEnum::Vgpr,
+                                    DataType,
+                                    MPerThread * NPerThread,
+                                    true>{};
+            },
+            Number<NumInput>{});
+
+        auto out_thread_buf_tuple = generate_tuple(
+            [&](auto I) {
+                using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
+                using DataType        = remove_pointer_t<DataTypePointer>;
+
+                return StaticBuffer<AddressSpaceEnum::Vgpr,
+                                    DataType,
+                                    MPerThread * NPerThread,
+                                    true>{};
+            },
+            Number<NumOutput>{});
+
+        auto in_global_buf_tuple = generate_tuple(
+            [&](auto I) {
+                return make_dynamic_buffer<AddressSpaceEnum::Global>(
+                    p_in_global_tuple[I], in_grid_2d_desc_tuple[I].GetElementSpaceSize());
+            },
+            Number<NumInput>{});
+
+        auto out_global_buf_tuple = generate_tuple(
+            [&](auto I) {
+                return make_dynamic_buffer<AddressSpaceEnum::Global>(
+                    p_out_global_tuple[I], out_grid_2d_desc_tuple[I].GetElementSpaceSize());
+            },
+            Number<NumOutput>{});
+
+        const auto M = in_grid_2d_desc_tuple[I0].GetLength(I0);
+        const auto N = in_grid_2d_desc_tuple[I0].GetLength(I1);
+
+        const index_t loop_step_m = num_threads_m * MPerThread;
+        const index_t loop_step_n = num_threads_n * NPerThread;
+
+        const index_t thread_1d_id = get_thread_global_1d_id();
+        index_t tid_m              = thread_1d_id / num_threads_n;
+        index_t tid_n              = thread_1d_id % num_threads_n;
+
+        const auto thread_global_offset = make_multi_index(tid_m * MPerThread, tid_n * NPerThread);
+
+        auto in_global_load_tuple = generate_tuple(
+            [&](auto I) {
+                using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
+                using DataType        = remove_cv_t<remove_pointer_t<DataTypePointer>>;
+
+                return ThreadwiseTensorSliceTransfer_v2<
+                    DataType,
+                    DataType,
+                    decltype(in_grid_2d_desc_tuple[I]),
+                    decltype(thread_buffer_desc_mn),
+                    Sequence<MPerThread, NPerThread>, // SliceLengths
+                    Sequence<0, 1>,                   // DimAccessOrder
+                    0,                                // SrcVectorDim
+                    InScalarPerVectorSeq::At(I),      // ScalarPerVector
+                    1,                                // SrcScalarStrideInVector
+                    true>{in_grid_2d_desc_tuple[I], thread_global_offset};
+            },
+            Number<NumInput>{});
+
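+        // Loads vectorize along dim 0 (the merged M axis) with width
+        // InScalarPerVectorSeq::At(I); the stores built below run along dim 1
+        // with DstScalarPerVector pinned to 1, i.e. element-wise writes.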
+        auto out_global_store_tuple = generate_tuple(
+            [&](auto I) {
+                using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
+                using DataType        = remove_pointer_t<DataTypePointer>;
+
+                return ThreadwiseTensorSliceTransfer_v1r3<
+                    DataType,
+                    DataType,
+                    decltype(thread_buffer_desc_mn),
+                    decltype(out_grid_2d_desc_tuple[I]),
+                    PassThroughOp,
+                    Sequence<MPerThread, NPerThread>, // SliceLengths
+                    Sequence<0, 1>,                   // DimAccessOrder
+                    1,                                // DstVectorDim
+                    1,                                // DstScalarPerVector (OutScalarPerVectorSeq::At(I) disabled)
+                    InMemoryDataOperationEnum::Set,
+                    1,                                // DstScalarStrideInVector
+                    true>(out_grid_2d_desc_tuple[I], thread_global_offset, PassThroughOp{});
+            },
+            Number<NumOutput>{});
+
+        index_t num_iter_m = M / loop_step_m;
+        do
+        {
+            index_t num_iter_n = N / loop_step_n;
+            do
+            {
+                static_for<0, NumInput, 1>{}([&](auto I) {
+                    in_global_load_tuple(I).Run(in_grid_2d_desc_tuple[I],
+                                                in_global_buf_tuple[I],
+                                                thread_buffer_desc_mn,
+                                                make_tuple(I0, I0),
+                                                in_thread_buf_tuple(I));
+
+                    in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_2d_desc_tuple[I],
+                                                               make_multi_index(0, loop_step_n));
+                });
+
+                static_for<0, MPerThread, 1>{}([&](auto iM) {
+                    static_for<0, NPerThread, 1>{}([&](auto iN) {
+                        constexpr auto offset =
+                            thread_buffer_desc_mn.CalculateOffset(make_tuple(iM, iN));
+
+                        // get reference to in data
+                        const auto in_data_refs = generate_tie(
+                            // return type should be lvalue
+                            [&](auto I) -> const auto& {
+                                return in_thread_buf_tuple(I)(Number<offset>{});
+                            },
+                            Number<NumInput>{});
+
+                        // get reference to dst data
+                        auto out_data_refs = generate_tie(
+                            // return type should be lvalue
+                            [&](auto I) -> auto& {
+                                return out_thread_buf_tuple(I)(Number<offset>{});
+                            },
+                            Number<NumOutput>{});
+
+                        unpack2(elementwise_op, out_data_refs, in_data_refs);
+                    });
+                });
+
+                static_for<0, NumOutput, 1>{}([&](auto I) {
+                    out_global_store_tuple(I).Run(thread_buffer_desc_mn,
+                                                  make_tuple(I0, I0),
+                                                  out_thread_buf_tuple[I],
+                                                  out_grid_2d_desc_tuple[I],
+                                                  out_global_buf_tuple(I));
+
+                    out_global_store_tuple(I).MoveDstSliceWindow(
+                        out_grid_2d_desc_tuple[I], make_multi_index(0, loop_step_n));
+                });
+
+            } while(--num_iter_n);
+
+            // rewind the N window to column 0 and step the M window forward
+            static_for<0, NumInput, 1>{}([&](auto I) {
+                in_global_load_tuple(I).MoveSrcSliceWindow(
+                    in_grid_2d_desc_tuple[I],
+                    make_multi_index(loop_step_m, -(N / loop_step_n) * loop_step_n));
+            });
+
+            static_for<0, NumOutput, 1>{}([&](auto I) {
+                out_global_store_tuple(I).MoveDstSliceWindow(
+                    out_grid_2d_desc_tuple[I],
+                    make_multi_index(loop_step_m, -(N / loop_step_n) * loop_step_n));
+            });
+        } while(--num_iter_m);
+    }
+};
+
+} // namespace ck
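A final note on the slice-window bookkeeping above: each inner-loop iteration advances
the N window by loop_step_n, and the inner loop runs num_iter_n = N / loop_step_n times
(N is padded to a multiple of loop_step_n), so the outer loop's move of
(loop_step_m, -(N / loop_step_n) * loop_step_n) steps M forward while rewinding N
exactly back to column 0. A small self-check of that arithmetic (illustrative only,
not part of the diff):

    #include <cassert>

    // Net N-window displacement over one inner loop plus the outer-loop rewind.
    void check_window_rewind(int N, int loop_step_n)
    {
        const int num_iter_n = N / loop_step_n;                  // inner trip count
        const int advanced   = num_iter_n * loop_step_n;         // forward movement in the loop
        const int rewind     = -(N / loop_step_n) * loop_step_n; // N component of the outer move
        assert(advanced + rewind == 0); // window is back at column 0 for the next row block
    }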