mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 21:39:15 +00:00
Add Conv Backward Data on Navi21 for ResNet50 (#499)
* start add example
* add device dl
* change launch kernel
* change init data method
* change example config
* add config valid check
* add instance for dl bwd
* add instance to ckProfiler
* reserver to profiler and cmakelist
* add instance to ckProfiler2
* change instance f32 config
* fix example return value
Co-authored-by: letaoqin <letaoqin@amd.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
[ROCm/composable_kernel commit: db0eb1ea9c]
This commit is contained in:
@@ -1,2 +1,5 @@
|
|||||||
add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp)
|
add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp)
|
||||||
target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility)
|
target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility)
|
||||||
|
|
||||||
|
add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp)
|
||||||
|
target_link_libraries(example_convnd_bwd_data_dl_fp16 PRIVATE utility)
|
||||||
|
|||||||
@@ -61,9 +61,13 @@ int run_conv_bwd_data(bool do_verification,
|
|||||||
out.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
|
out.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
|
||||||
wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
|
wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
|
||||||
break;
|
break;
|
||||||
default:
|
case 2:
|
||||||
out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
|
out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
|
||||||
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
|
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
out.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
|
||||||
|
wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
|
||||||
}
|
}
|
||||||
|
|
||||||
DeviceMem in_device_buf(sizeof(InDataType) * in_device.mDesc.GetElementSpaceSize());
|
DeviceMem in_device_buf(sizeof(InDataType) * in_device.mDesc.GetElementSpaceSize());
|
||||||
@@ -98,9 +102,8 @@ int run_conv_bwd_data(bool do_verification,
|
|||||||
|
|
||||||
if(!conv.IsSupportedArgument(argument))
|
if(!conv.IsSupportedArgument(argument))
|
||||||
{
|
{
|
||||||
throw std::runtime_error(
|
std::cout << "Not support,please check parameters or device";
|
||||||
"wrong! device_conv with the specified compilation parameters does "
|
return 0;
|
||||||
"not support this Conv problem");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||||
|
|||||||
180
example/17_convnd_bwd_data/convnd_bwd_data_dl_fp16.cpp
Normal file
180
example/17_convnd_bwd_data/convnd_bwd_data_dl_fp16.cpp
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||||
|
|
||||||
|
#include "convnd_bwd_data_common.hpp"
|
||||||
|
|
||||||
|
#include "ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp"
|
||||||
|
|
||||||
|
using InDataType = ck::half_t;
|
||||||
|
using WeiDataType = ck::half_t;
|
||||||
|
using OutDataType = ck::half_t;
|
||||||
|
using AccDataType = float;
|
||||||
|
|
||||||
|
template <ck::index_t... Is>
|
||||||
|
using S = ck::Sequence<Is...>;
|
||||||
|
|
||||||
|
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||||
|
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||||
|
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||||
|
|
||||||
|
static constexpr auto ConvBwdDefault =
|
||||||
|
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default;
|
||||||
|
|
||||||
|
template <ck::index_t NDimSpatial>
|
||||||
|
// clang-format off
|
||||||
|
using DeviceConvNdBwdDataInstance = ck::tensor_operation::device::DeviceConvNdBwdDataNwcKxcNwk_Dl<
|
||||||
|
// ######| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| Convolution| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
|
||||||
|
// ######| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Forward| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
|
||||||
|
// ######| | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
|
||||||
|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||||
|
NDimSpatial, InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, ConvBwdDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 1, 8, 2>, S<16, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>;
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
int main(int argc, char* argv[])
|
||||||
|
{
|
||||||
|
namespace ctc = ck::tensor_layout::convolution;
|
||||||
|
|
||||||
|
print_helper_msg();
|
||||||
|
|
||||||
|
bool do_verification = true;
|
||||||
|
int init_method = 1;
|
||||||
|
bool time_kernel = false;
|
||||||
|
|
||||||
|
ck::utils::conv::ConvParam conv_param{
|
||||||
|
2, 1, 128, 256, 256, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};
|
||||||
|
|
||||||
|
if(argc == 1)
|
||||||
|
{
|
||||||
|
// use default
|
||||||
|
}
|
||||||
|
else if(argc == 4)
|
||||||
|
{
|
||||||
|
do_verification = std::stoi(argv[1]);
|
||||||
|
init_method = std::stoi(argv[2]);
|
||||||
|
time_kernel = std::stoi(argv[3]);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
do_verification = std::stoi(argv[1]);
|
||||||
|
init_method = std::stoi(argv[2]);
|
||||||
|
time_kernel = std::stoi(argv[3]);
|
||||||
|
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
|
||||||
|
|
||||||
|
conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv);
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto in_element_op = InElementOp{};
|
||||||
|
const auto wei_element_op = WeiElementOp{};
|
||||||
|
const auto out_element_op = OutElementOp{};
|
||||||
|
|
||||||
|
if(conv_param.num_dim_spatial_ == 1)
|
||||||
|
{
|
||||||
|
using InLayout = ctc::GNWC;
|
||||||
|
using WeiLayout = ctc::GKXC;
|
||||||
|
using OutLayout = ctc::GNWK;
|
||||||
|
|
||||||
|
const auto in_g_n_c_wis_desc =
|
||||||
|
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
|
||||||
|
conv_param);
|
||||||
|
|
||||||
|
const auto wei_g_k_c_xs_desc =
|
||||||
|
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
|
||||||
|
conv_param);
|
||||||
|
|
||||||
|
const auto out_g_n_k_wos_desc =
|
||||||
|
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
|
||||||
|
conv_param);
|
||||||
|
|
||||||
|
return run_conv_bwd_data<1,
|
||||||
|
InDataType,
|
||||||
|
WeiDataType,
|
||||||
|
OutDataType,
|
||||||
|
InElementOp,
|
||||||
|
WeiElementOp,
|
||||||
|
OutElementOp,
|
||||||
|
DeviceConvNdBwdDataInstance<1>>(do_verification,
|
||||||
|
init_method,
|
||||||
|
time_kernel,
|
||||||
|
conv_param,
|
||||||
|
in_g_n_c_wis_desc,
|
||||||
|
wei_g_k_c_xs_desc,
|
||||||
|
out_g_n_k_wos_desc,
|
||||||
|
in_element_op,
|
||||||
|
wei_element_op,
|
||||||
|
out_element_op);
|
||||||
|
}
|
||||||
|
else if(conv_param.num_dim_spatial_ == 2)
|
||||||
|
{
|
||||||
|
using InLayout = ctc::GNHWC;
|
||||||
|
using WeiLayout = ctc::GKYXC;
|
||||||
|
using OutLayout = ctc::GNHWK;
|
||||||
|
|
||||||
|
const auto in_g_n_c_wis_desc =
|
||||||
|
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
|
||||||
|
conv_param);
|
||||||
|
|
||||||
|
const auto wei_g_k_c_xs_desc =
|
||||||
|
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
|
||||||
|
conv_param);
|
||||||
|
|
||||||
|
const auto out_g_n_k_wos_desc =
|
||||||
|
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
|
||||||
|
conv_param);
|
||||||
|
|
||||||
|
return run_conv_bwd_data<2,
|
||||||
|
InDataType,
|
||||||
|
WeiDataType,
|
||||||
|
OutDataType,
|
||||||
|
InElementOp,
|
||||||
|
WeiElementOp,
|
||||||
|
OutElementOp,
|
||||||
|
DeviceConvNdBwdDataInstance<2>>(do_verification,
|
||||||
|
init_method,
|
||||||
|
time_kernel,
|
||||||
|
conv_param,
|
||||||
|
in_g_n_c_wis_desc,
|
||||||
|
wei_g_k_c_xs_desc,
|
||||||
|
out_g_n_k_wos_desc,
|
||||||
|
in_element_op,
|
||||||
|
wei_element_op,
|
||||||
|
out_element_op);
|
||||||
|
}
|
||||||
|
else if(conv_param.num_dim_spatial_ == 3)
|
||||||
|
{
|
||||||
|
using InLayout = ctc::GNDHWC;
|
||||||
|
using WeiLayout = ctc::GKZYXC;
|
||||||
|
using OutLayout = ctc::GNDHWK;
|
||||||
|
|
||||||
|
const auto in_g_n_c_wis_desc =
|
||||||
|
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
|
||||||
|
conv_param);
|
||||||
|
|
||||||
|
const auto wei_g_k_c_xs_desc =
|
||||||
|
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
|
||||||
|
conv_param);
|
||||||
|
|
||||||
|
const auto out_g_n_k_wos_desc =
|
||||||
|
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
|
||||||
|
conv_param);
|
||||||
|
|
||||||
|
return run_conv_bwd_data<3,
|
||||||
|
InDataType,
|
||||||
|
WeiDataType,
|
||||||
|
OutDataType,
|
||||||
|
InElementOp,
|
||||||
|
WeiElementOp,
|
||||||
|
OutElementOp,
|
||||||
|
DeviceConvNdBwdDataInstance<3>>(do_verification,
|
||||||
|
init_method,
|
||||||
|
time_kernel,
|
||||||
|
conv_param,
|
||||||
|
in_g_n_c_wis_desc,
|
||||||
|
wei_g_k_c_xs_desc,
|
||||||
|
out_g_n_k_wos_desc,
|
||||||
|
in_element_op,
|
||||||
|
wei_element_op,
|
||||||
|
out_element_op);
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -101,6 +101,42 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
|
|||||||
PassThrough,
|
PassThrough,
|
||||||
PassThrough>>>& instances);
|
PassThrough>>>& instances);
|
||||||
|
|
||||||
|
// conv2d dl
|
||||||
|
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(
|
||||||
|
std::vector<std::unique_ptr<DeviceConvBwdData<2,
|
||||||
|
NHWC,
|
||||||
|
KYXC,
|
||||||
|
NHWK,
|
||||||
|
F16,
|
||||||
|
F16,
|
||||||
|
F16,
|
||||||
|
PassThrough,
|
||||||
|
PassThrough,
|
||||||
|
PassThrough>>>& instances);
|
||||||
|
|
||||||
|
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(
|
||||||
|
std::vector<std::unique_ptr<DeviceConvBwdData<2,
|
||||||
|
NHWC,
|
||||||
|
KYXC,
|
||||||
|
NHWK,
|
||||||
|
F32,
|
||||||
|
F32,
|
||||||
|
F32,
|
||||||
|
PassThrough,
|
||||||
|
PassThrough,
|
||||||
|
PassThrough>>>& instances);
|
||||||
|
|
||||||
|
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
|
||||||
|
std::vector<std::unique_ptr<DeviceConvBwdData<2,
|
||||||
|
NHWC,
|
||||||
|
KYXC,
|
||||||
|
NHWK,
|
||||||
|
int8_t,
|
||||||
|
int8_t,
|
||||||
|
int8_t,
|
||||||
|
PassThrough,
|
||||||
|
PassThrough,
|
||||||
|
PassThrough>>>& instances);
|
||||||
// conv3d backward data
|
// conv3d backward data
|
||||||
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
|
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
|
||||||
std::vector<std::unique_ptr<DeviceConvBwdData<3,
|
std::vector<std::unique_ptr<DeviceConvBwdData<3,
|
||||||
@@ -216,11 +252,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
|
|||||||
is_same_v<OutDataType, float>)
|
is_same_v<OutDataType, float>)
|
||||||
{
|
{
|
||||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
|
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
|
||||||
|
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
|
||||||
}
|
}
|
||||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||||
is_same_v<OutDataType, half_t>)
|
is_same_v<OutDataType, half_t>)
|
||||||
{
|
{
|
||||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
|
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
|
||||||
|
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
|
||||||
}
|
}
|
||||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||||
@@ -232,6 +270,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
|
|||||||
is_same_v<OutDataType, int8_t>)
|
is_same_v<OutDataType, int8_t>)
|
||||||
{
|
{
|
||||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
|
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
|
||||||
|
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
|
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
|
||||||
|
|||||||
@@ -3,4 +3,8 @@ add_instance_library(device_conv2d_bwd_data_instance
|
|||||||
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp
|
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp
|
||||||
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp
|
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp
|
||||||
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp
|
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp
|
||||||
|
|
||||||
|
device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp
|
||||||
|
device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp
|
||||||
|
device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -0,0 +1,83 @@
|
|||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||||
|
|
||||||
|
#include <cstdlib>
|
||||||
|
|
||||||
|
#include "ck/ck.hpp"
|
||||||
|
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||||
|
#include "ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp"
|
||||||
|
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||||
|
|
||||||
|
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||||
|
|
||||||
|
namespace ck {
|
||||||
|
namespace tensor_operation {
|
||||||
|
namespace device {
|
||||||
|
namespace instance {
|
||||||
|
|
||||||
|
using InDataType = ck::half_t;
|
||||||
|
using WeiDataType = ck::half_t;
|
||||||
|
using OutDataType = ck::half_t;
|
||||||
|
using AccDataType = float;
|
||||||
|
|
||||||
|
template <ck::index_t... Is>
|
||||||
|
using S = ck::Sequence<Is...>;
|
||||||
|
|
||||||
|
using NHWC = ck::tensor_layout::convolution::NHWC;
|
||||||
|
using KYXC = ck::tensor_layout::convolution::KYXC;
|
||||||
|
using NHWK = ck::tensor_layout::convolution::NHWK;
|
||||||
|
|
||||||
|
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||||
|
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||||
|
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||||
|
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||||
|
|
||||||
|
static constexpr auto ConvBwdDataDefault =
|
||||||
|
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default;
|
||||||
|
|
||||||
|
static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
|
||||||
|
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;
|
||||||
|
|
||||||
|
// Compilation parameters for the backward-data pass of in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
|
||||||
|
using device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances = std::tuple<
|
||||||
|
// clang-format off
|
||||||
|
//#########################| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| Convolution| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
|
||||||
|
//#########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Forward| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
|
||||||
|
//#########################| | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
|
||||||
|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||||
|
DeviceConvNdBwdDataNwcKxcNwk_Dl< 2, InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, ConvBwdDataDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 1, 8, 2>, S<16, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>
|
||||||
|
// clang-format on
|
||||||
|
>;
|
||||||
|
|
||||||
|
using device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple<
|
||||||
|
// clang-format off
|
||||||
|
//#########################| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| Convolution| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
|
||||||
|
//#########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Forward| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
|
||||||
|
//#########################| | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
|
||||||
|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||||
|
DeviceConvNdBwdDataNwcKxcNwk_Dl< 2, InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 1, 8, 2>, S<16, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>
|
||||||
|
// clang-format on
|
||||||
|
>;
|
||||||
|
|
||||||
|
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(
|
||||||
|
std::vector<std::unique_ptr<DeviceConvBwdData<2,
|
||||||
|
NHWC,
|
||||||
|
KYXC,
|
||||||
|
NHWK,
|
||||||
|
InDataType,
|
||||||
|
WeiDataType,
|
||||||
|
OutDataType,
|
||||||
|
PassThrough,
|
||||||
|
PassThrough,
|
||||||
|
PassThrough>>>& instances)
|
||||||
|
{
|
||||||
|
add_device_operation_instances(instances,
|
||||||
|
device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances{});
|
||||||
|
add_device_operation_instances(
|
||||||
|
instances, device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace instance
|
||||||
|
} // namespace device
|
||||||
|
} // namespace tensor_operation
|
||||||
|
} // namespace ck
|
||||||
@@ -0,0 +1,83 @@
|
|||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||||
|
|
||||||
|
#include <cstdlib>
|
||||||
|
|
||||||
|
#include "ck/ck.hpp"
|
||||||
|
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||||
|
#include "ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp"
|
||||||
|
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||||
|
|
||||||
|
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||||
|
|
||||||
|
namespace ck {
|
||||||
|
namespace tensor_operation {
|
||||||
|
namespace device {
|
||||||
|
namespace instance {
|
||||||
|
|
||||||
|
using InDataType = float;
|
||||||
|
using WeiDataType = float;
|
||||||
|
using OutDataType = float;
|
||||||
|
using AccDataType = float;
|
||||||
|
|
||||||
|
template <ck::index_t... Is>
|
||||||
|
using S = ck::Sequence<Is...>;
|
||||||
|
|
||||||
|
using NHWC = ck::tensor_layout::convolution::NHWC;
|
||||||
|
using KYXC = ck::tensor_layout::convolution::KYXC;
|
||||||
|
using NHWK = ck::tensor_layout::convolution::NHWK;
|
||||||
|
|
||||||
|
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||||
|
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||||
|
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||||
|
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||||
|
|
||||||
|
static constexpr auto ConvBwdDataDefault =
|
||||||
|
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default;
|
||||||
|
|
||||||
|
static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
|
||||||
|
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;
|
||||||
|
|
||||||
|
// Compilation parameters for the backward-data pass of in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
|
||||||
|
using device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances = std::tuple<
|
||||||
|
// clang-format off
|
||||||
|
//#########################| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| Convolution| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
|
||||||
|
//#########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Forward| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
|
||||||
|
//#########################| | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
|
||||||
|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||||
|
DeviceConvNdBwdDataNwcKxcNwk_Dl< 2, InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, ConvBwdDataDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 1, 8, 1>, S<16, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>
|
||||||
|
// clang-format on
|
||||||
|
>;
|
||||||
|
|
||||||
|
using device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = std::tuple<
|
||||||
|
// clang-format off
|
||||||
|
//#########################| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| Convolution| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
|
||||||
|
//#########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Forward| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
|
||||||
|
//#########################| | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
|
||||||
|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||||
|
DeviceConvNdBwdDataNwcKxcNwk_Dl< 2, InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 1, 8, 1>, S<16, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>
|
||||||
|
// clang-format on
|
||||||
|
>;
|
||||||
|
|
||||||
|
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(
|
||||||
|
std::vector<std::unique_ptr<DeviceConvBwdData<2,
|
||||||
|
NHWC,
|
||||||
|
KYXC,
|
||||||
|
NHWK,
|
||||||
|
InDataType,
|
||||||
|
WeiDataType,
|
||||||
|
OutDataType,
|
||||||
|
PassThrough,
|
||||||
|
PassThrough,
|
||||||
|
PassThrough>>>& instances)
|
||||||
|
{
|
||||||
|
add_device_operation_instances(instances,
|
||||||
|
device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances{});
|
||||||
|
add_device_operation_instances(
|
||||||
|
instances, device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace instance
|
||||||
|
} // namespace device
|
||||||
|
} // namespace tensor_operation
|
||||||
|
} // namespace ck
|
||||||
@@ -0,0 +1,83 @@
|
|||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||||
|
|
||||||
|
#include <cstdlib>
|
||||||
|
|
||||||
|
#include "ck/ck.hpp"
|
||||||
|
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||||
|
#include "ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp"
|
||||||
|
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||||
|
|
||||||
|
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||||
|
|
||||||
|
namespace ck {
|
||||||
|
namespace tensor_operation {
|
||||||
|
namespace device {
|
||||||
|
namespace instance {
|
||||||
|
|
||||||
|
using InDataType = int8_t;
|
||||||
|
using WeiDataType = int8_t;
|
||||||
|
using OutDataType = int8_t;
|
||||||
|
using AccDataType = int32_t;
|
||||||
|
|
||||||
|
template <ck::index_t... Is>
|
||||||
|
using S = ck::Sequence<Is...>;
|
||||||
|
|
||||||
|
using NHWC = ck::tensor_layout::convolution::NHWC;
|
||||||
|
using KYXC = ck::tensor_layout::convolution::KYXC;
|
||||||
|
using NHWK = ck::tensor_layout::convolution::NHWK;
|
||||||
|
|
||||||
|
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||||
|
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||||
|
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||||
|
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||||
|
|
||||||
|
static constexpr auto ConvBwdDataDefault =
|
||||||
|
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default;
|
||||||
|
|
||||||
|
static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
|
||||||
|
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;
|
||||||
|
|
||||||
|
// Compilation parameters for the backward-data pass of in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
|
||||||
|
using device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances = std::tuple<
|
||||||
|
// clang-format off
|
||||||
|
//#########################| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| Convolution| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
|
||||||
|
//#########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Forward| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
|
||||||
|
//#########################| | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
|
||||||
|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||||
|
DeviceConvNdBwdDataNwcKxcNwk_Dl< 2, InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, ConvBwdDataDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<1, 1, 8, 4>, S<16, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>
|
||||||
|
// clang-format on
|
||||||
|
>;
|
||||||
|
|
||||||
|
using device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = std::tuple<
|
||||||
|
// clang-format off
|
||||||
|
//#########################| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| Convolution| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
|
||||||
|
//#########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Forward| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
|
||||||
|
//#########################| | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
|
||||||
|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||||
|
DeviceConvNdBwdDataNwcKxcNwk_Dl< 2, InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<1, 1, 8, 4>, S<16, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>
|
||||||
|
// clang-format on
|
||||||
|
>;
|
||||||
|
|
||||||
|
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
|
||||||
|
std::vector<std::unique_ptr<DeviceConvBwdData<2,
|
||||||
|
NHWC,
|
||||||
|
KYXC,
|
||||||
|
NHWK,
|
||||||
|
InDataType,
|
||||||
|
WeiDataType,
|
||||||
|
OutDataType,
|
||||||
|
PassThrough,
|
||||||
|
PassThrough,
|
||||||
|
PassThrough>>>& instances)
|
||||||
|
{
|
||||||
|
add_device_operation_instances(instances,
|
||||||
|
device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances{});
|
||||||
|
add_device_operation_instances(
|
||||||
|
instances, device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances{});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace instance
|
||||||
|
} // namespace device
|
||||||
|
} // namespace tensor_operation
|
||||||
|
} // namespace ck
|
||||||
Reference in New Issue
Block a user