diff --git a/example/51_avgpool3d_bwd/CMakeLists.txt b/example/51_avgpool3d_bwd/CMakeLists.txt new file mode 100644 index 0000000000..fef0c66835 --- /dev/null +++ b/example/51_avgpool3d_bwd/CMakeLists.txt @@ -0,0 +1,3 @@ +add_example_executable(example_avgpool3d_bwd_bf16 avgpool3d_bwd_bf16.cpp) +add_example_executable(example_avgpool3d_bwd_fp16 avgpool3d_bwd_fp16.cpp) +add_example_executable(example_avgpool3d_bwd_fp32 avgpool3d_bwd_fp32.cpp) diff --git a/example/51_avgpool3d_bwd/avgpool3d_bwd_bf16.cpp b/example/51_avgpool3d_bwd/avgpool3d_bwd_bf16.cpp new file mode 100644 index 0000000000..a5ab6ef24a --- /dev/null +++ b/example/51_avgpool3d_bwd/avgpool3d_bwd_bf16.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp" + +#include "avgpool3d_bwd_common.hpp" + +using DOutDataType = ck::bhalf_t; +using DInDataType = ck::bhalf_t; +using ComputeDataType = float; + +#if 1 +using DOutLayout = ck::tensor_layout::convolution::NDHWC; +using DInLayout = ck::tensor_layout::convolution::NDHWC; +#else +using DOutLayout = ck::tensor_layout::convolution::NCDHW; +using DInLayout = ck::tensor_layout::convolution::NCDHW; +#endif + +using DevicePoolBwdInstance = + ck::tensor_operation::device::DeviceAvgPool3dBwd_NDHWC_NDHWC; // InSrcOutDstVectorSize + +int main() +{ + std::vector window_lengths = {5, 5, 5}; + std::vector window_strides = {2, 2, 2}; + std::vector window_dilations = {2, 2, 2}; + std::vector dinput_left_pads = {0, 0, 0}; + std::vector dinput_right_pads = {0, 0, 0}; + + ck::index_t N = 1; + ck::index_t C = 16; + ck::index_t Di = 40; + ck::index_t Hi = 40; + ck::index_t Wi = 40; + + pool3d_bwd_test( + true, + false, + N, + C, + Di, + Hi, + Wi, + window_lengths, + window_strides, + window_dilations, + dinput_left_pads, + dinput_right_pads); +} diff --git a/example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp b/example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp new file mode 100644 index 0000000000..394f046b1e --- /dev/null +++ b/example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp @@ -0,0 +1,147 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_avgpool_bwd.hpp" + +template +std::vector f_tensor_strides_ncdhw(ck::index_t N_, + ck::index_t C_, + ck::index_t D, + ck::index_t H, + ck::index_t W, + TensorLayout layout) +{ + using namespace ck::literals; + (void)N_; + if constexpr(ck::is_same::value) + return {C_ * D * H * W, D * H * W, H * W, W, 1_uz}; + else if constexpr(ck::is_same::value) + return {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}; +}; + +template +HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_, + std::size_t C_, + std::size_t D, + std::size_t H, + std::size_t W, + TensorLayout layout) +{ + using namespace ck::literals; + + if constexpr(ck::is_same::value) + { + return HostTensorDescriptor({N_, C_, D, H, W}, {C_ * D * H * W, D * H * W, H * W, W, 1_uz}); + } + else if constexpr(ck::is_same::value) + { + return HostTensorDescriptor({N_, C_, D, H, W}, + {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}); + } +}; + +template +bool pool3d_bwd_test(bool do_verification, + bool time_kernel, + ck::index_t N, + ck::index_t C, + ck::index_t Di, + ck::index_t Hi, + ck::index_t Wi, + std::vector window_lengths, + std::vector window_strides, + std::vector window_dilations, + std::vector dinput_left_pads, + std::vector dinput_right_pads) +{ + auto OutSpatialLength = [&](auto InSpatialLength, int index) { + ck::index_t left_pad = dinput_left_pads[index]; + ck::index_t right_pad = dinput_right_pads[index]; + ck::index_t window_len = window_lengths[index]; + ck::index_t stride = window_strides[index]; + ck::index_t dilation = window_dilations[index]; + ck::index_t eff = (window_len - 1) * dilation + 1; + return (InSpatialLength + left_pad + right_pad - eff) / stride + 1; + }; + + ck::index_t Do = OutSpatialLength(Di, 0); + ck::index_t Ho = OutSpatialLength(Hi, 1); + ck::index_t Wo = OutSpatialLength(Wi, 2); + + Tensor dout(f_host_tensor_descriptor(N, C, Do, Ho, Wo, DOutLayout{})); + Tensor din_dev(f_host_tensor_descriptor(N, C, Di, Hi, Wi, DInLayout{})); + Tensor din_host(f_host_tensor_descriptor(N, C, Di, Hi, Wi, DInLayout{})); + + std::cout << "dout: " << dout.mDesc << std::endl; + std::cout << "din_host: " << din_host.mDesc << std::endl; + + dout.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem dout_device_buf(sizeof(DOutDataType) * dout.mDesc.GetElementSpaceSize()); + DeviceMem din_device_buf(sizeof(DInDataType) * din_dev.mDesc.GetElementSpaceSize()); + + dout_device_buf.ToDevice(dout.mData.data()); + din_device_buf.SetZero(); + + auto pool = DevicePoolBwdInstance{}; + auto invoker_ptr = pool.MakeInvokerPointer(); + auto argument_ptr = + pool.MakeArgumentPointer(static_cast(dout_device_buf.GetDeviceBuffer()), + static_cast(din_device_buf.GetDeviceBuffer()), + {N, C, Do, Ho, Wo}, + {N, C, Di, Hi, Wi}, + f_tensor_strides_ncdhw(N, C, Do, Ho, Wo, DOutLayout{}), + f_tensor_strides_ncdhw(N, C, Di, Hi, Wi, DInLayout{}), + window_lengths, + window_strides, + window_dilations, + dinput_left_pads, + dinput_right_pads); + + if(!pool.IsSupportedArgument(argument_ptr.get())) + { + throw std::runtime_error("wrong! device_op with the specified compilation parameters does " + "not support this problem"); + } + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + std::cout << "Perf: " << ave_time << std::endl; + + bool pass = true; + + if(do_verification) + { + auto ref_pool = + ck::tensor_operation::host::ReferenceAvgPoolBwd<3, DInDataType, DOutDataType>(); + + auto ref_invoker = ref_pool.MakeInvoker(); + + auto ref_argument = ref_pool.MakeArgument(din_host, + dout, + window_lengths, + window_strides, + window_dilations, + dinput_left_pads, + dinput_right_pads); + + ref_invoker.Run(ref_argument); + + din_device_buf.FromDevice(din_dev.mData.data()); + pass = ck::utils::check_err(din_dev, din_host); + } + + return pass; +} diff --git a/example/51_avgpool3d_bwd/avgpool3d_bwd_fp16.cpp b/example/51_avgpool3d_bwd/avgpool3d_bwd_fp16.cpp new file mode 100644 index 0000000000..578f563d5c --- /dev/null +++ b/example/51_avgpool3d_bwd/avgpool3d_bwd_fp16.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp" + +#include "avgpool3d_bwd_common.hpp" + +using DOutDataType = ck::half_t; +using DInDataType = ck::half_t; +using ComputeDataType = float; + +#if 1 +using DOutLayout = ck::tensor_layout::convolution::NDHWC; +using DInLayout = ck::tensor_layout::convolution::NDHWC; +#else +using DOutLayout = ck::tensor_layout::convolution::NCDHW; +using DInLayout = ck::tensor_layout::convolution::NCDHW; +#endif + +using DevicePoolBwdInstance = + ck::tensor_operation::device::DeviceAvgPool3dBwd_NDHWC_NDHWC; // InSrcOutDstVectorSize + +int main() +{ + std::vector window_lengths = {5, 5, 5}; + std::vector window_strides = {2, 2, 2}; + std::vector window_dilations = {2, 2, 2}; + std::vector dinput_left_pads = {0, 0, 0}; + std::vector dinput_right_pads = {0, 0, 0}; + + ck::index_t N = 1; + ck::index_t C = 16; + ck::index_t Di = 40; + ck::index_t Hi = 40; + ck::index_t Wi = 40; + + pool3d_bwd_test( + true, + false, + N, + C, + Di, + Hi, + Wi, + window_lengths, + window_strides, + window_dilations, + dinput_left_pads, + dinput_right_pads); +} diff --git a/example/51_avgpool3d_bwd/avgpool3d_bwd_fp32.cpp b/example/51_avgpool3d_bwd/avgpool3d_bwd_fp32.cpp new file mode 100644 index 0000000000..c2910c5567 --- /dev/null +++ b/example/51_avgpool3d_bwd/avgpool3d_bwd_fp32.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp" + +#include "avgpool3d_bwd_common.hpp" + +using DOutDataType = float; +using DInDataType = float; +using ComputeDataType = float; + +#if 1 +using DOutLayout = ck::tensor_layout::convolution::NDHWC; +using DInLayout = ck::tensor_layout::convolution::NDHWC; +#else +using DOutLayout = ck::tensor_layout::convolution::NCDHW; +using DInLayout = ck::tensor_layout::convolution::NCDHW; +#endif + +using DevicePoolBwdInstance = + ck::tensor_operation::device::DeviceAvgPool3dBwd_NDHWC_NDHWC; // InSrcOutDstVectorSize + +int main() +{ + std::vector window_lengths = {5, 5, 5}; + std::vector window_strides = {2, 2, 2}; + std::vector window_dilations = {2, 2, 2}; + std::vector dinput_left_pads = {0, 0, 0}; + std::vector dinput_right_pads = {0, 0, 0}; + + ck::index_t N = 1; + ck::index_t C = 16; + ck::index_t Di = 40; + ck::index_t Hi = 40; + ck::index_t Wi = 40; + + pool3d_bwd_test( + true, + false, + N, + C, + Di, + Hi, + Wi, + window_lengths, + window_strides, + window_dilations, + dinput_left_pads, + dinput_right_pads); +} diff --git a/include/ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp b/include/ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp new file mode 100644 index 0000000000..5a2b082c31 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceAvgPoolBwd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_dout, + void* p_din, + std::vector dout_n_k_wos_lengths, + std::vector dout_n_k_wos_strides, + std::vector din_n_k_wos_length, + std::vector din_n_k_wos_strides, + std::vector window_k_c_xs_lengths, + std::vector window_strides, + std::vector window_dilations, + std::vector input_left_pads, + std::vector input_right_pads) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_put_element.hpp b/include/ck/tensor_operation/gpu/device/device_put_element.hpp index 9186827491..17df2de37b 100644 --- a/include/ck/tensor_operation/gpu/device/device_put_element.hpp +++ b/include/ck/tensor_operation/gpu/device/device_put_element.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp b/include/ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp new file mode 100644 index 0000000000..3a2280a75c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp @@ -0,0 +1,575 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// In and Din = [N, C, Di, Hi, Wi] +// Out and Dout = [N, C, Do, Ho, Wo] +// Out = AvgPoolFwd(In) +// Din = AvgPoolBwd(Dout) +// Pooling dimension = D, H, W +template +struct DeviceAvgPool3dBwd_NDHWC_NDHWC : public DeviceAvgPoolBwd<3, + DOutDataType, + DInDataType, + tensor_layout::convolution::NDHWC, + tensor_layout::convolution::NDHWC> +{ + static constexpr ck::index_t NDimSpatial = 3; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto + Make3DGridDescriptor_Out_M_K_In_M(const std::vector& dout_n_c_wos_lengths, + const std::vector& din_n_c_wos_length, + const std::vector& dout_n_c_wos_strides, + const std::vector& din_n_c_wos_strides, + const std::vector& window_lengths, + const std::vector& window_strides, + const std::vector& window_dilations, + const std::vector& input_left_pads, + const std::vector& input_right_pads, + const std::vector& tildes) + { + index_t i_ztilde = tildes[0]; + index_t i_ytilde = tildes[1]; + index_t i_xtilde = tildes[2]; + + const index_t N = dout_n_c_wos_lengths[0]; + const index_t C = dout_n_c_wos_lengths[1]; + + const index_t Di = din_n_c_wos_length[2]; + const index_t Hi = din_n_c_wos_length[3]; + const index_t Wi = din_n_c_wos_length[4]; + + const index_t Do = dout_n_c_wos_lengths[2]; + const index_t Ho = dout_n_c_wos_lengths[3]; + const index_t Wo = dout_n_c_wos_lengths[4]; + + const index_t Z = window_lengths[0]; + const index_t Y = window_lengths[1]; + const index_t X = window_lengths[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const index_t ConvStrideD = window_strides[0]; + const index_t ConvStrideH = window_strides[1]; + const index_t ConvStrideW = window_strides[2]; + + const index_t ConvDilationD = window_dilations[0]; + const index_t ConvDilationH = window_dilations[1]; + const index_t ConvDilationW = window_dilations[2]; + + const auto out_n_do_ho_wo_c_grid_desc = + make_naive_tensor_descriptor(make_tuple(N, Do, Ho, Wo, C), + make_tuple(dout_n_c_wos_strides[0], + dout_n_c_wos_strides[2], + dout_n_c_wos_strides[3], + dout_n_c_wos_strides[4], + dout_n_c_wos_strides[1])); + + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto ZTilde = ConvStrideD / GcdStrideDilationD; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto ZDot = math::integer_divide_ceil(Z, ZTilde); + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto DTilde = Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD); + const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on Tildes that contribute to non-padding area of input tensor + const auto IDTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD); + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IDTildeSliceEnd = + math::min(DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1); + const auto IHTildeSliceEnd = + math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = + math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin; + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // ReduceK is different for each Reduce + const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde); + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + // Problem size of reduction kernel + const index_t MRaw = N * DTildeSlice * HTildeSlice * WTildeSlice * C; + const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw; + + const index_t KRaw = ZDotSlice * YDotSlice * XDotSlice; + const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw; + + // Out[ReduceM, ReduceK] + const auto out_n_dop_hop_wop_c_grid_desc = transform_tensor_descriptor( + out_n_do_ho_wo_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Do, I0, I0), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_c_grid_desc = + transform_tensor_descriptor( + out_n_dop_hop_wop_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(ZDot, DTilde), + make_tuple(-ConvDilationD / GcdStrideDilationD, I1)), + make_embed_transform(make_tuple(YDot, HTilde), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto + out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc = + transform_tensor_descriptor( + out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(ZDot, I0, ZDotSlice), + make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{})); + + const auto out_grid_desc_reducemraw_reducekraw = transform_tensor_descriptor( + out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc, + make_tuple( + make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice, C)), + make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice))), + make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_grid_desc_reducem_reducek = transform_tensor_descriptor( + out_grid_desc_reducemraw_reducekraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // In[ReduceM] + const auto in_n_di_hi_wi_c_grid_desc = + make_naive_tensor_descriptor(make_tuple(N, Di, Hi, Wi, C), + make_tuple(din_n_c_wos_strides[0], + din_n_c_wos_strides[2], + din_n_c_wos_strides[3], + din_n_c_wos_strides[4], + din_n_c_wos_strides[1])); + + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc = + transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(XTilde, DTilde), + make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(YTilde, HTilde), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc = + transform_tensor_descriptor( + in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_ztilde), + make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), + make_freeze_transform(i_ytilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<>{}, + Sequence<3>{}, + Sequence<4>{})); + + const auto in_grid_desc_reducemraw = transform_tensor_descriptor( + in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc, + make_tuple( + make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice, C))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto in_grid_desc_reducem = + transform_tensor_descriptor(in_grid_desc_reducemraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + + return make_tuple(out_grid_desc_reducem_reducek, in_grid_desc_reducem); + } + + using DoutDinGridDesc = decltype(Make3DGridDescriptor_Out_M_K_In_M({0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}, + {0, 0, 0})); + + using DoutGridDesc_M_K = remove_cvref_t>; + using DinGridDesc_M = remove_cvref_t>; + + // FIXME + // for NDHWC, the dim C is the fastest dimension, and is not reduced. + // Hence, it is in M dimension for reduction kernel. + static constexpr index_t OutSrcInDstVectorDim = 0; // 0: M, 1: K + + using PassThrough = tensor_operation::element_wise::PassThrough; + using Div = tensor_operation::element_wise::UnaryDivide; + + using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise; + + struct Argument : public BaseArgument + { + Argument(const DOutDataType* p_dout, + DInDataType* p_din, + std::vector dout_n_c_wos_lengths, + std::vector din_n_c_wos_length, + std::vector dout_n_c_wos_strides, + std::vector din_n_c_wos_strides, + std::vector window_lengths, + std::vector window_strides, + std::vector window_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + : p_dout_grid_{p_dout}, + p_din_grid_{p_din}, + dout_n_c_wos_lengths_{dout_n_c_wos_lengths}, + din_n_c_wos_length_{din_n_c_wos_length}, + dout_n_c_wos_strides_{dout_n_c_wos_strides}, + din_n_c_wos_strides_{din_n_c_wos_strides}, + num_reduce_{1}, + div_element_op_{window_lengths[0] * window_lengths[1] * window_lengths[2]} + { + std::vector Tildes(NDimSpatial); + for(int i = 0; i < NDimSpatial; ++i) + { + int GcdStrideDilation = math::gcd(window_strides[i], window_dilations[i]); + Tildes[i] = window_strides[i] / GcdStrideDilation; + num_reduce_ *= Tildes[i]; + } + + for(index_t i_ztilde = 0; i_ztilde < Tildes[0]; ++i_ztilde) + { + for(index_t i_ytilde = 0; i_ytilde < Tildes[1]; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < Tildes[2]; ++i_xtilde) + { + // check slice is valid + const auto ZDotSlice = + math::integer_divide_ceil(window_lengths[0] - i_ztilde, Tildes[0]); + const auto YDotSlice = + math::integer_divide_ceil(window_lengths[1] - i_ytilde, Tildes[1]); + const auto XDotSlice = + math::integer_divide_ceil(window_lengths[2] - i_xtilde, Tildes[2]); + + if(ZDotSlice * YDotSlice * XDotSlice <= 0) + { + continue; + } + + const auto dout_din_grid_desc = + Make3DGridDescriptor_Out_M_K_In_M(dout_n_c_wos_lengths, + din_n_c_wos_length, + dout_n_c_wos_strides, + din_n_c_wos_strides, + window_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + {i_ztilde, i_ytilde, i_xtilde}); + + dout_grid_desc_m_k_container_.push_back(dout_din_grid_desc[I0]); + din_grid_desc_m_container_.push_back(dout_din_grid_desc[I1]); + } + } + } + } + + const DOutDataType* p_dout_grid_; + DInDataType* p_din_grid_; + std::vector dout_n_c_wos_lengths_; + std::vector din_n_c_wos_length_; + std::vector dout_n_c_wos_strides_; + std::vector din_n_c_wos_strides_; + + int num_reduce_; + std::vector dout_grid_desc_m_k_container_; + std::vector din_grid_desc_m_container_; + + Div div_element_op_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + + for(index_t i = 0; i < arg.num_reduce_; i++) + { + const auto kernel = kernel_reduce_threadwise; + + ck::index_t M = arg.dout_grid_desc_m_k_container_[i].GetLength(I0); + const index_t grid_size = (M / M_BlockTileSize); + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.dout_grid_desc_m_k_container_[i], + arg.din_grid_desc_m_container_[i], + PassThrough{}, + arg.div_element_op_, + float(1), + arg.p_dout_grid_, + nullptr, + float(0), + arg.p_din_grid_, + nullptr); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + constexpr index_t Rank = NDimSpatial + 2; + int doutFastestDim = -1; + int dinFastestDim = -1; + + for(int i = 0; i < Rank; ++i) + { + if(arg.dout_n_c_wos_strides_[i] == 1) + doutFastestDim = i; + if(arg.din_n_c_wos_strides_[i] == 1) + dinFastestDim = i; + } + + if(doutFastestDim == -1 || dinFastestDim == -1) + { + if constexpr(InSrcOutDstVectorSize != 1) + return false; + } + else + { + if(arg.dout_n_c_wos_lengths_[doutFastestDim] % InSrcOutDstVectorSize != 0) + return false; + if(arg.din_n_c_wos_length_[dinFastestDim] % InSrcOutDstVectorSize != 0) + return false; + } + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + std::unique_ptr + MakeArgumentPointer(const void* p_dout, + void* p_din, + std::vector dout_n_c_wos_lengths, + std::vector din_n_c_wos_length, + std::vector dout_n_c_wos_strides, + std::vector din_n_c_wos_strides, + std::vector window_lengths, + std::vector window_strides, + std::vector window_dilations, + std::vector input_left_pads, + std::vector input_right_pads) override + { + constexpr index_t Rank = NDimSpatial + 2; + + if(dout_n_c_wos_strides.size() != Rank || din_n_c_wos_strides.size() != Rank || + dout_n_c_wos_lengths.size() != Rank || din_n_c_wos_length.size() != Rank) + throw std::runtime_error("dimension is incorrect"); + + if(window_lengths.size() != NDimSpatial || window_strides.size() != NDimSpatial || + window_dilations.size() != NDimSpatial || input_left_pads.size() != NDimSpatial || + input_right_pads.size() != NDimSpatial) + throw std::runtime_error("dimension is incorrect"); + + return std::make_unique(static_cast(p_dout), + static_cast(p_din), + dout_n_c_wos_lengths, + din_n_c_wos_length, + dout_n_c_wos_strides, + din_n_c_wos_strides, + window_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceAvgPool3dBwd<" << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp index 8f72f88b06..c41eef8c45 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_avgpool_bwd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_avgpool_bwd.hpp new file mode 100644 index 0000000000..fa06e77560 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_avgpool_bwd.hpp @@ -0,0 +1,354 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// dinput descriptor in [N, C, Do, Ho, Wo] order +// doutput descriptor in [N, C, Di, Hi, Wi] order +// phyiscal layout is irrelavent +template = 1 && NDimSpatial <= 3, bool>::type = false> +struct ReferenceAvgPoolBwd : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(Tensor& dinput, + const Tensor& doutput, + std::vector window_spatial_lengths, + std::vector window_strides, + std::vector window_dilations, + std::vector dinput_left_pads, + std::vector dinput_right_pads) + : dinput_{dinput}, + doutput_{doutput}, + window_spatial_lengths_{window_spatial_lengths}, + window_strides_{window_strides}, + window_dilations_{window_dilations}, + in_left_pads_{dinput_left_pads}, + in_right_pads_{dinput_right_pads} + { + } + + Tensor& dinput_; + const Tensor& doutput_; + + std::vector window_spatial_lengths_; + std::vector window_strides_; + std::vector window_dilations_; + std::vector in_left_pads_; + std::vector in_right_pads_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceAvgPoolBwd::Argument; + + template ::type = false> + float RunAvgPoolBwd(const Argument& arg) + { + // Let input = x, outpu = y + // shape of x = [10], y = [6] + // window_size = 5, pad = 0, stride = 1, dilation = 1 + // Forward: + // y0 = 1/5 * (x0 + x1 + x2 + x3 + x4) + // y1 = 1/5 * (x1 + x2 + x3 + x4 + x5) + // ... + // y5 = 1/5 * (x5 + x6 + x7 + x8 + x9) + // y6 = 1/5 * (x6 + x7 + x8 + x9) + // ... + // y9 = 1/5 * (x9) + + // Backward: + // shape of dy = [6], dx = [10] + // dx0 = 1/5 * dy0 + // dx1 = 1/5 * (dy0 + dy1) + // dx2 = 1/5 * (dy0 + dy1 + dy2) + // ... + // dx4 = 1/5 * (dy0 + dy1 + dy2 + dy3 + dy4) + // dx5 = 1/5 * (dy1 + dy2 + dy3 + dy4 + dy5) + // ... + // dx9 = 1/5 * (dy5 + dy6 + dy7 + dy8 + dy9) + + auto f_ncw = [&](auto n, auto c, auto wi) { + std::size_t X = arg.window_spatial_lengths_[0]; + std::size_t Wo = arg.doutput_.GetLengths()[2]; + + float v_acc = 0; + + for(std::size_t x = 0; x < X; ++x) + { + // Out_Position = (In_Position + pad - x * dilation) / stride + auto w_tmp = static_cast(wi) + + static_cast(arg.in_left_pads_[0]) - + static_cast(x * arg.window_dilations_[0]); + + // Check the input pixel validity (in perspective of being affected by some + // doutput pixel) + if(w_tmp % arg.window_strides_[0] == 0) + { + auto wo = static_cast(w_tmp) / + static_cast(arg.window_strides_[0]); + + // Get the doutput pixel in valid range to accumulate the gradients for this + // input pixel + if(wo >= 0 && ck::type_convert(wo) < Wo) + { + v_acc += ck::type_convert(arg.doutput_(n, c, wo)); + } + } + } + + v_acc /= ck::type_convert(X); + arg.dinput_(n, c, wi) = ck::type_convert(v_acc); + }; + + make_ParallelTensorFunctor(f_ncw, + arg.dinput_.GetLengths()[0], + arg.dinput_.GetLengths()[1], + arg.dinput_.GetLengths()[2])( + std::thread::hardware_concurrency()); + + return 0; + } + + template ::type = false> + float RunAvgPoolBwd(const Argument& arg) + { + auto f_nchw = [&](auto n, auto c, auto hi, auto wi) { + std::size_t Y = arg.window_spatial_lengths_[0]; + std::size_t X = arg.window_spatial_lengths_[1]; + + std::size_t Ho = arg.doutput_.GetLengths()[2]; + std::size_t Wo = arg.doutput_.GetLengths()[3]; + + float v_acc = 0; + + for(std::size_t y = 0; y < Y; ++y) + { + // Out_Position = (In_Position + pad - x * dilation) / stride + auto h_tmp = static_cast(hi) + + static_cast(arg.in_left_pads_[0]) - + static_cast(y * arg.window_dilations_[0]); + + // Check the input pixel validity (in perspective of being affected by some + // doutput pixel) + if(h_tmp % arg.window_strides_[0] == 0) + { + auto ho = static_cast(h_tmp) / + static_cast(arg.window_strides_[0]); + + // Get the doutput pixel in valid range to accumulate the gradients for this + // input pixel + if(ho >= 0 && ck::type_convert(ho) < Ho) + { + for(std::size_t x = 0; x < X; ++x) + { + auto w_tmp = + static_cast(wi) + + static_cast(arg.in_left_pads_[1]) - + static_cast(x * arg.window_dilations_[1]); + if(w_tmp % arg.window_strides_[1] == 0) + { + auto wo = static_cast(w_tmp) / + static_cast(arg.window_strides_[1]); + if(wo >= 0 && ck::type_convert(wo) < Wo) + { + v_acc += + ck::type_convert(arg.doutput_(n, c, ho, wo)); + } + } + } + } + } + } + + v_acc /= ck::type_convert(Y * X); + arg.dinput_(n, c, hi, wi) = ck::type_convert(v_acc); + }; + + make_ParallelTensorFunctor(f_nchw, + arg.dinput_.GetLengths()[0], + arg.dinput_.GetLengths()[1], + arg.dinput_.GetLengths()[2], + arg.dinput_.GetLengths()[3])( + std::thread::hardware_concurrency()); + + return 0; + } + + template ::type = false> + float RunAvgPoolBwd(const Argument& arg) + { + auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) { + std::size_t Z = arg.window_spatial_lengths_[0]; + std::size_t Y = arg.window_spatial_lengths_[1]; + std::size_t X = arg.window_spatial_lengths_[2]; + + std::size_t Do = arg.doutput_.GetLengths()[2]; + std::size_t Ho = arg.doutput_.GetLengths()[3]; + std::size_t Wo = arg.doutput_.GetLengths()[4]; + + float v_acc = 0; + + for(std::size_t z = 0; z < Z; ++z) + { + // Out_Position = (In_Position + pad - x * dilation) / stride + auto d_tmp = static_cast(di) + + static_cast(arg.in_left_pads_[0]) - + static_cast(z * arg.window_dilations_[0]); + + // Check the input pixel validity (in perspective of being affected by some + // doutput pixel) + if(d_tmp % arg.window_strides_[0] == 0) + { + auto do_ = static_cast(d_tmp) / + static_cast(arg.window_strides_[0]); + + // Get the doutput pixel in valid range to accumulate the gradients for this + // input pixel + if(do_ >= 0 && ck::type_convert(do_) < Do) + { + for(std::size_t y = 0; y < Y; ++y) + { + auto h_tmp = + static_cast(hi) + + static_cast(arg.in_left_pads_[1]) - + static_cast(y * arg.window_dilations_[1]); + if(h_tmp % arg.window_strides_[1] == 0) + { + auto ho = static_cast(h_tmp) / + static_cast(arg.window_strides_[1]); + if(ho >= 0 && ck::type_convert(ho) < Ho) + { + for(std::size_t x = 0; x < X; ++x) + { + auto w_tmp = static_cast(wi) + + static_cast( + arg.in_left_pads_[2]) - + static_cast( + x * arg.window_dilations_[2]); + + if(w_tmp % arg.window_strides_[2] == 0) + { + auto wo = static_cast(w_tmp) / + static_cast( + arg.window_strides_[2]); + if(wo >= 0 && + ck::type_convert(wo) < Wo) + { + v_acc += ck::type_convert( + arg.doutput_(n, c, do_, ho, wo)); + } + } + } + } + } + } + } + } + } + + v_acc /= ck::type_convert(Z * Y * X); + arg.dinput_(n, c, di, hi, wi) = ck::type_convert(v_acc); + }; + + make_ParallelTensorFunctor(f_ncdhw, + arg.dinput_.GetLengths()[0], + arg.dinput_.GetLengths()[1], + arg.dinput_.GetLengths()[2], + arg.dinput_.GetLengths()[3], + arg.dinput_.GetLengths()[4])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const Argument& arg) + { + if(!(arg.dinput_.GetNumOfDimension() == NDimSpatial + 2 && + arg.doutput_.GetNumOfDimension() == NDimSpatial + 2)) + { + throw std::runtime_error("wrong! inconsistent dimension"); + } + + return RunAvgPoolBwd(arg); + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(Tensor& dinput, + const Tensor& doutput, + std::vector window_spatial_lengths, + std::vector window_strides, + std::vector window_dilations, + std::vector dinput_left_pads, + std::vector dinput_right_pads) + { + if(window_spatial_lengths.size() != NDimSpatial || window_strides.size() != NDimSpatial || + window_dilations.size() != NDimSpatial || dinput_left_pads.size() != NDimSpatial || + dinput_right_pads.size() != NDimSpatial) + throw std::runtime_error("dimension is incorrect"); + + return Argument{dinput, + doutput, + window_spatial_lengths, + window_strides, + window_dilations, + dinput_left_pads, + dinput_right_pads}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceAvgPoolBwd" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp index 449734f434..50040a2441 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp @@ -125,7 +125,7 @@ struct ReferenceConvBwdData : public device::BaseOperator arg.in_element_op_(v_in, v_acc); - arg.input_(g, n, c, wi) = ck::type_convert(v_acc); + arg.input_(g, n, c, wi) = ck::type_convert(v_in); }; make_ParallelTensorFunctor(f_ncw, @@ -201,7 +201,7 @@ struct ReferenceConvBwdData : public device::BaseOperator arg.in_element_op_(v_in, v_acc); - arg.input_(g, n, c, hi, wi) = ck::type_convert(v_acc); + arg.input_(g, n, c, hi, wi) = ck::type_convert(v_in); }; make_ParallelTensorFunctor(f_nchw, @@ -299,7 +299,7 @@ struct ReferenceConvBwdData : public device::BaseOperator arg.in_element_op_(v_in, v_acc); - arg.input_(g, n, c, di, hi, wi) = ck::type_convert(v_acc); + arg.input_(g, n, c, di, hi, wi) = ck::type_convert(v_in); }; make_ParallelTensorFunctor(f_ncdhw,