From b2d1cf056018d1f9f5f79232471d59dc225baf49 Mon Sep 17 00:00:00 2001 From: rocking5566 Date: Thu, 26 May 2022 00:17:27 +0800 Subject: [PATCH] Hotfix binary elementwise (for broadcast on fastest axis) (#254) * Support different length of ScalarPerVector * Add example of broadcast on fastest axis * Typo * Refine fastest example * Add dimension check * Modify fastest broadcast example to 3d * Enforce users give scalarPerVector explicitly * 1. Add CScalarPerVector 2. Not only broadcast on fastest need to set scalarPerVector to 1 * Rename var * Move IsScalarPerVectorValid() inside IsSupportedArgument() * Separate GridDesc_M0 into A, B and C * rename var * Rename var of length Co-authored-by: rocking [ROCm/composable_kernel commit: 82d7d9938f897a7ae9d15fd8de210af2563ae1e2] --- example/19_binary_elementwise/CMakeLists.txt | 3 +- ...add_2d.cpp => broadcast_add_2d_amn_bn.cpp} | 17 ++- .../broadcast_add_3d_am_bmnk.cpp | 123 +++++++++++++++ .../elementwise_add_1d.cpp | 17 ++- .../elementwise_add_4d.cpp | 17 ++- .../gpu/device/device_binary_elementwise.hpp | 143 +++++++++++------- .../grid/gridwise_binary_elementwise_1d.hpp | 124 +++++++-------- 7 files changed, 319 insertions(+), 125 deletions(-) rename example/19_binary_elementwise/{broadcast_add_2d.cpp => broadcast_add_2d_amn_bn.cpp} (84%) create mode 100644 example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp diff --git a/example/19_binary_elementwise/CMakeLists.txt b/example/19_binary_elementwise/CMakeLists.txt index 6c95b2e55e..39646e0ab5 100644 --- a/example/19_binary_elementwise/CMakeLists.txt +++ b/example/19_binary_elementwise/CMakeLists.txt @@ -1,3 +1,4 @@ -add_example_executable(example_broadcast_add_2d broadcast_add_2d.cpp) +add_example_executable(example_broadcast_add_2d_amn_bn broadcast_add_2d_amn_bn.cpp) +add_example_executable(example_broadcast_add_3d_am_bmnk broadcast_add_3d_am_bmnk.cpp) add_example_executable(example_elementwise_add_1d elementwise_add_1d.cpp) 
add_example_executable(example_elementwise_add_4d elementwise_add_4d.cpp) \ No newline at end of file diff --git a/example/19_binary_elementwise/broadcast_add_2d.cpp b/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp similarity index 84% rename from example/19_binary_elementwise/broadcast_add_2d.cpp rename to example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp index 2a3ef421ff..cbe768f30b 100644 --- a/example/19_binary_elementwise/broadcast_add_2d.cpp +++ b/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp @@ -19,8 +19,17 @@ using EltwiseComputeDataType = F32; using Add = ck::tensor_operation::binary_element_wise::Add; -using DeviceElementwiseAddInstance = ck::tensor_operation::device:: - DeviceBinaryElementwise; +using DeviceElementwiseAddInstance = + ck::tensor_operation::device::DeviceBinaryElementwise; template (host_c_m_n, a_m_n, b_n, M, N, Add{}); pass &= ck::utils::check_err( - c_m_n.mData, host_c_m_n.mData, "Error: Incorrect results d1", 1e-3, 1e-3); + c_m_n.mData, host_c_m_n.mData, "Error: Incorrect results c", 1e-3, 1e-3); } return pass ? 
0 : 1; diff --git a/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp b/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp new file mode 100644 index 0000000000..06523f0cf7 --- /dev/null +++ b/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp @@ -0,0 +1,123 @@ +#include +#include +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" + +#include "device_tensor.hpp" +#include "binary_element_wise_operation.hpp" +#include "device_binary_elementwise.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ABDataType = F16; +using CDataType = F16; +using EltwiseComputeDataType = F32; + +using Add = ck::tensor_operation::binary_element_wise::Add; + +using DeviceElementwiseAddInstance = + ck::tensor_operation::device::DeviceBinaryElementwise; + +template +void host_broadcast3D_am_bmnk(HostTensorC& C, + const HostTensorA& A, + const HostTensorB& B, + const std::vector& shape, + Functor functor) +{ + using ctype = ck::remove_reference_t; + + for(std::size_t m = 0; m < shape[0]; ++m) + for(std::size_t n = 0; n < shape[1]; ++n) + for(std::size_t k = 0; k < shape[2]; ++k) + { + ComputeDataType a_val = static_cast(A(m)); + ComputeDataType b_val = static_cast(B(m, n, k)); + ComputeDataType c_val = 0; + functor(c_val, a_val, b_val); + C(m, n, k) = static_cast(c_val); + } +} + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + std::vector mnk = {4, 16, 32}; + ck::index_t M = mnk[0]; + + Tensor a_m({M}); + Tensor b_m_n_k(mnk); + Tensor c_m_n_k(mnk); + + a_m.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_m_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpace()); + DeviceMem b_m_n_k_device_buf(sizeof(ABDataType) * b_m_n_k.mDesc.GetElementSpace()); + DeviceMem c_m_n_k_device_buf(sizeof(CDataType) * c_m_n_k.mDesc.GetElementSpace()); + + 
a_m_device_buf.ToDevice(a_m.mData.data()); + b_m_n_k_device_buf.ToDevice(b_m_n_k.mData.data()); + + auto broadcastAdd = DeviceElementwiseAddInstance{}; + auto argument = broadcastAdd.MakeArgumentPointer( + a_m_device_buf.GetDeviceBuffer(), + b_m_n_k_device_buf.GetDeviceBuffer(), + c_m_n_k_device_buf.GetDeviceBuffer(), + std::vector{mnk.begin(), mnk.end()}, + {1, 0, 0}, // broadcast A on second and third dimension + std::vector{b_m_n_k.mDesc.GetStrides().begin(), + b_m_n_k.mDesc.GetStrides().end()}, + std::vector{c_m_n_k.mDesc.GetStrides().begin(), + c_m_n_k.mDesc.GetStrides().end()}, + Add{}); + + if(!broadcastAdd.IsSupportedArgument(argument.get())) + { + throw std::runtime_error("The runtime parameters seems not supported by the " + "DeviceBinaryElementwise instance, exiting!"); + }; + + auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer(); + float ave_time = + broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::cout << "Perf: " << ave_time << " ms" << std::endl; + + bool pass = true; + if(do_verification) + { + c_m_n_k_device_buf.FromDevice(c_m_n_k.mData.data()); + Tensor host_c_m_n_k(mnk); + + host_broadcast3D_am_bmnk, + Tensor, + Tensor, + EltwiseComputeDataType, + Add>(host_c_m_n_k, a_m, b_m_n_k, mnk, Add{}); + + pass &= ck::utils::check_err( + c_m_n_k.mData, host_c_m_n_k.mData, "Error: Incorrect results c", 1e-3, 1e-3); + } + + return pass ? 
0 : 1; +} diff --git a/example/19_binary_elementwise/elementwise_add_1d.cpp b/example/19_binary_elementwise/elementwise_add_1d.cpp index 455ff24c31..cebc3aa67a 100644 --- a/example/19_binary_elementwise/elementwise_add_1d.cpp +++ b/example/19_binary_elementwise/elementwise_add_1d.cpp @@ -19,8 +19,17 @@ using EltwiseComputeDataType = F32; using Add = ck::tensor_operation::binary_element_wise::Add; -using DeviceElementwiseAddInstance = ck::tensor_operation::device:: - DeviceBinaryElementwise; +using DeviceElementwiseAddInstance = + ck::tensor_operation::device::DeviceBinaryElementwise; template (host_c_m, a_m, b_m, M, Add{}); pass &= ck::utils::check_err( - c_m.mData, host_c_m.mData, "Error: Incorrect results d1", 1e-3, 1e-3); + c_m.mData, host_c_m.mData, "Error: Incorrect results c", 1e-3, 1e-3); } return pass ? 0 : 1; diff --git a/example/19_binary_elementwise/elementwise_add_4d.cpp b/example/19_binary_elementwise/elementwise_add_4d.cpp index 937a6c8c1d..7e6d1fd77b 100644 --- a/example/19_binary_elementwise/elementwise_add_4d.cpp +++ b/example/19_binary_elementwise/elementwise_add_4d.cpp @@ -19,8 +19,17 @@ using EltwiseComputeDataType = F32; using Add = ck::tensor_operation::binary_element_wise::Add; -using DeviceElementwiseAddInstance = ck::tensor_operation::device:: - DeviceBinaryElementwise; +using DeviceElementwiseAddInstance = + ck::tensor_operation::device::DeviceBinaryElementwise; template (host_c, a, b, nchw, Add{}); pass &= - ck::utils::check_err(c.mData, host_c.mData, "Error: Incorrect results d1", 1e-3, 1e-3); + ck::utils::check_err(c.mData, host_c.mData, "Error: Incorrect results c", 1e-3, 1e-3); } return pass ? 
0 : 1; diff --git a/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp b/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp index 8955aadc11..34b3a59c74 100644 --- a/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp +++ b/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp @@ -15,91 +15,107 @@ template + index_t NDim, + index_t MPerThread, + index_t AScalarPerVector, + index_t BScalarPerVector, + index_t CScalarPerVector> struct DeviceBinaryElementwise : public BaseOperator { static constexpr auto I0 = Number<0>{}; - template - static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize) + template + static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize) { - const auto m0 = desc_m0.GetLength(I0); - const index_t loop_step = gridSize * blockSize * ScalarPerVector; - const auto pad = math::integer_least_multiple(m0, loop_step) - m0; - const auto desc_m0_pad = - transform_tensor_descriptor(desc_m0, - make_tuple(make_right_pad_transform(m0, pad)), + const auto M = desc_m.GetLength(I0); + const index_t loop_step = gridSize * blockSize * MPerThread; + const auto pad = math::integer_least_multiple(M, loop_step) - M; + const auto desc_m_pad = + transform_tensor_descriptor(desc_m, + make_tuple(make_right_pad_transform(M, pad)), make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{})); - return desc_m0_pad; + return desc_m_pad; } - static auto MakeDescriptor_M0(const std::vector& shape, - const std::vector& stride, - index_t gridSize, - index_t blockSize) + static auto MakeDescriptor_M(const std::vector& lengths, + const std::vector& strides, + index_t gridSize, + index_t blockSize) { - auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number{}); - auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number{}); + auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number{}); + auto tupleOfStride = 
generate_tuple([&](auto I) { return strides[I]; }, Number{}); // nd desc - [s0, s1, s2, ...] const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); // merge nd to 1d desc - [s0 * s1 * ...] - if constexpr(Dim > 1) + if constexpr(NDim > 1) { - const auto desc_m0 = transform_tensor_descriptor( + const auto desc_m = transform_tensor_descriptor( desc, make_tuple(make_merge_transform(tupleOfShape)), - make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number{})), + make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number{})), make_tuple(Sequence<0>{})); - return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize); + return PadDescriptor_M_1d(desc_m, gridSize, blockSize); } else - return PadDescriptor_M0_1d(desc, gridSize, blockSize); + return PadDescriptor_M_1d(desc, gridSize, blockSize); } - using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1)); + using AGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + using BGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); using GridwiseBinEltwise = GridwiseBinaryElementwise_1D; + MPerThread, + AScalarPerVector, + BScalarPerVector, + CScalarPerVector>; struct Argument : public BaseArgument { Argument(const ADataType* p_a, const BDataType* p_b, CDataType* p_c, - const std::vector& shape, - const std::vector& stride_a, - const std::vector& stride_b, - const std::vector& stride_c, + const std::vector& lengths, + const std::vector& a_strides, + const std::vector& b_strides, + const std::vector& c_strides, ElementwiseFunctor functor) : p_a_(p_a), p_b_(p_b), p_c_(p_c), - shape_(shape), + lengths_(lengths), + a_strides_(a_strides), + b_strides_(b_strides), + c_strides_(c_strides), functor_(functor), blockSize_(256), gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future { - a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, blockSize_); - 
b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, blockSize_); - c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_, blockSize_); + a_grid_desc_m_ = MakeDescriptor_M(lengths, a_strides, gridSize_, blockSize_); + b_grid_desc_m_ = MakeDescriptor_M(lengths, b_strides, gridSize_, blockSize_); + c_grid_desc_m_ = MakeDescriptor_M(lengths, c_strides, gridSize_, blockSize_); } const ADataType* p_a_; const BDataType* p_b_; CDataType* p_c_; - std::vector shape_; - GridDesc_M0 a_grid_desc_m0_; - GridDesc_M0 b_grid_desc_m0_; - GridDesc_M0 c_grid_desc_m0_; + std::vector lengths_; + AGridDesc_M a_grid_desc_m_; + BGridDesc_M b_grid_desc_m_; + CGridDesc_M c_grid_desc_m_; + std::vector a_strides_; + std::vector b_strides_; + std::vector c_strides_; ElementwiseFunctor functor_; index_t blockSize_; index_t gridSize_; @@ -113,7 +129,9 @@ struct DeviceBinaryElementwise : public BaseOperator ADataType, BDataType, CDataType, - GridDesc_M0, + AGridDesc_M, + BGridDesc_M, + CGridDesc_M, ElementwiseFunctor>; float elapsed_time = launch_and_time_kernel(stream_config, @@ -124,9 +142,9 @@ struct DeviceBinaryElementwise : public BaseOperator arg.p_a_, arg.p_b_, arg.p_c_, - arg.a_grid_desc_m0_, - arg.b_grid_desc_m0_, - arg.c_grid_desc_m0_, + arg.a_grid_desc_m_, + arg.b_grid_desc_m_, + arg.c_grid_desc_m_, arg.functor_); return elapsed_time; } @@ -146,7 +164,30 @@ struct DeviceBinaryElementwise : public BaseOperator if(pArg == nullptr) return false; - if(pArg->shape_.back() % ScalarPerVector != 0) + if(pArg->lengths_.size() != NDim) + return false; + + if(pArg->lengths_.back() % MPerThread != 0) + return false; + + auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) { + bool ret = true; + + if(!isLastDimensionCoalesced) + ret = scalarPerVector == 1; + else + ret = MPerThread % scalarPerVector == 0; + + return ret; + }; + + if(!IsScalarPerVectorValid(pArg->a_strides_.back() == 1, AScalarPerVector)) + return false; + + 
if(!IsScalarPerVectorValid(pArg->b_strides_.back() == 1, BScalarPerVector)) + return false; + + if(!IsScalarPerVectorValid(pArg->c_strides_.back() == 1, CScalarPerVector)) return false; return true; @@ -155,19 +196,19 @@ struct DeviceBinaryElementwise : public BaseOperator std::unique_ptr MakeArgumentPointer(const void* p_a, const void* p_b, void* p_c, - std::vector shape, - std::vector stride_a, - std::vector stride_b, - std::vector stride_c, + std::vector lengths, + std::vector a_strides, + std::vector b_strides, + std::vector c_strides, ElementwiseFunctor functor) { return std::make_unique(static_cast(p_a), static_cast(p_b), static_cast(p_c), - shape, - stride_a, - stride_b, - stride_c, + lengths, + a_strides, + b_strides, + c_strides, functor); } @@ -180,7 +221,7 @@ struct DeviceBinaryElementwise : public BaseOperator // clang-format off str << "DeviceBinaryElementwise" << "<" - << "ScalarPerVector = " << ScalarPerVector + << "MPerThread = " << MPerThread << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp index c77d49ae94..374c4fe59a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp @@ -11,138 +11,140 @@ template __global__ void kernel_binary_elementwise_1d(const ADataType* __restrict__ p_a_global, const BDataType* __restrict__ p_b_global, CDataType* __restrict__ p_c_global, - const GridDesc_M0 a_grid_desc_m0, - const GridDesc_M0 b_grid_desc_m0, - const GridDesc_M0 c_grid_desc_m0, + const AGridDesc_M a_grid_desc_m, + const BGridDesc_M b_grid_desc_m, + const CGridDesc_M c_grid_desc_m, const ElementwiseFunctor functor) { - GridwiseBinEltwise::Run(p_a_global, - p_b_global, - p_c_global, - a_grid_desc_m0, - b_grid_desc_m0, - c_grid_desc_m0, - functor); + GridwiseBinEltwise::Run( + p_a_global, p_b_global, p_c_global, 
a_grid_desc_m, b_grid_desc_m, c_grid_desc_m, functor); } template + index_t MPerThread, + index_t AScalarPerVector, + index_t BScalarPerVector, + index_t CScalarPerVector> struct GridwiseBinaryElementwise_1D { static constexpr auto I0 = Number<0>{}; - static constexpr auto thread_desc_m0 = - make_naive_tensor_descriptor_packed(make_tuple(Number{})); + static constexpr auto thread_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); using PassThrough = tensor_operation::element_wise::PassThrough; static __device__ auto CalculateElementwiseIndex() { const index_t global_thread_id = get_thread_global_1d_id(); - return make_multi_index(global_thread_id * ScalarPerVector); + return make_multi_index(global_thread_id * MPerThread); } __device__ static void Run(const ADataType* __restrict__ p_a_global, const BDataType* __restrict__ p_b_global, CDataType* __restrict__ p_c_global, - const GridDesc_M0 a_grid_desc_m0, - const GridDesc_M0 b_grid_desc_m0, - const GridDesc_M0 c_grid_desc_m0, + const AGridDesc_M a_grid_desc_m, + const BGridDesc_M b_grid_desc_m, + const CGridDesc_M c_grid_desc_m, const ElementwiseFunctor functor) { const auto a_global_buf = make_dynamic_buffer( - p_a_global, a_grid_desc_m0.GetElementSpaceSize()); + p_a_global, a_grid_desc_m.GetElementSpaceSize()); const auto b_global_buf = make_dynamic_buffer( - p_b_global, b_grid_desc_m0.GetElementSpaceSize()); + p_b_global, b_grid_desc_m.GetElementSpaceSize()); auto c_global_buf = make_dynamic_buffer( - p_c_global, c_grid_desc_m0.GetElementSpaceSize()); + p_c_global, c_grid_desc_m.GetElementSpaceSize()); - StaticBuffer a_thread_buf; - StaticBuffer b_thread_buf; - StaticBuffer c_thread_buf; + StaticBuffer a_thread_buf; + StaticBuffer b_thread_buf; + StaticBuffer c_thread_buf; const auto thread_store_global_offset = CalculateElementwiseIndex(); auto a_global_load = ThreadwiseTensorSliceTransfer_v2, // SliceLengths - Sequence<0>, // DimAccessOrder - 0, // SrcVectorDim - ScalarPerVector, - 1, // 
SrcScalarStrideInVector - false>{a_grid_desc_m0, thread_store_global_offset}; + AGridDesc_M, + decltype(thread_desc_m), + Sequence, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + AScalarPerVector, // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{a_grid_desc_m, thread_store_global_offset}; auto b_global_load = ThreadwiseTensorSliceTransfer_v2, // SliceLengths - Sequence<0>, // DimAccessOrder - 0, // SrcVectorDim - ScalarPerVector, - 1, // SrcScalarStrideInVector - false>{b_grid_desc_m0, thread_store_global_offset}; + BGridDesc_M, + decltype(thread_desc_m), + Sequence, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + BScalarPerVector, // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{b_grid_desc_m, thread_store_global_offset}; auto c_global_write = ThreadwiseTensorSliceTransfer_v1r3, // SliceLengths - Sequence<0>, // DimAccessOrder - 0, // DstVectorDim - ScalarPerVector, + Sequence, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // DstVectorDim + CScalarPerVector, // ScalarPerVector InMemoryDataOperationEnum::Set, 1, // DstScalarStrideInVector false>{ - c_grid_desc_m0, thread_store_global_offset, PassThrough{}}; + c_grid_desc_m, thread_store_global_offset, PassThrough{}}; const index_t blockSize = get_block_size(); const index_t blockPerGrid = get_grid_size(); - const auto m0 = c_grid_desc_m0.GetLength(I0); - const index_t loop_step = blockPerGrid * blockSize * ScalarPerVector; + const auto M = c_grid_desc_m.GetLength(I0); + const index_t loop_step = blockPerGrid * blockSize * MPerThread; const auto loop_step_index = make_multi_index(loop_step); - index_t num_iter = m0 / (loop_step); + index_t num_iter = M / (loop_step); do { - // read and process ScalarPerVector elements + // read and process MPerThread elements a_global_load.Run( - a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf); + a_grid_desc_m, a_global_buf, thread_desc_m, make_tuple(I0), a_thread_buf); 
b_global_load.Run( - b_grid_desc_m0, b_global_buf, thread_desc_m0, make_tuple(I0), b_thread_buf); + b_grid_desc_m, b_global_buf, thread_desc_m, make_tuple(I0), b_thread_buf); - static_for<0, ScalarPerVector, 1>{}([&](auto m) { - constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m)); + static_for<0, MPerThread, 1>{}([&](auto m) { + constexpr auto offset = thread_desc_m.CalculateOffset(make_tuple(m)); functor(c_thread_buf(Number{}), a_thread_buf(Number{}), b_thread_buf(Number{})); }); - c_global_write.Run(thread_desc_m0, + c_global_write.Run(thread_desc_m, make_tuple(I0), // SrcSliceOriginIdx c_thread_buf, - c_grid_desc_m0, + c_grid_desc_m, c_global_buf); - a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, loop_step_index); - b_global_load.MoveSrcSliceWindow(b_grid_desc_m0, loop_step_index); - c_global_write.MoveDstSliceWindow(c_grid_desc_m0, loop_step_index); + a_global_load.MoveSrcSliceWindow(a_grid_desc_m, loop_step_index); + b_global_load.MoveSrcSliceWindow(b_grid_desc_m, loop_step_index); + c_global_write.MoveDstSliceWindow(c_grid_desc_m, loop_step_index); } while(--num_iter); } };