mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Unified implementation of 1d/2d/3d conv bwd-data. fp32/fp16/bfp16/int8 (#134)
* start convnd bwd data * add 3d laoyout name * add conv1d reference * add con3d reference * finished example client code * conv1d kernel finished * fix input error * add conv3d * add 3d layout in conv_utils.hpp * fix sepecial check * addconvnd lib * add test for bwd data * finished test * add check slice length * convnd bwd data start * profiler can be compiled * fix some bug * set input to zero * modify readme for example * fix test_convnd_bwd_data bug * test_convnd_bwd_data parameter desc * workaround for 1d * workaroud for 2d * change init value * workaround for 3d int8 * fix init value bug * remove workaround * fix acc data type * add int32 * change select function to template * tilda to tilde * remove int32 instance * fix commit for device hpp * fix comments for profiler * using profile imp to test * add pass verification * fix conv2d reference * fix conflict * remove double batched_gemm * fix exampel conv2d data and test convnd * format * change conv2d_bwd_data return value * remove repeat = 1 * remove conv bwd data Co-authored-by: ltqin <letaoqin@amd.com> Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -68,6 +68,7 @@ using DeviceConvBwdDataInstance = ck::tensor_operation::device::
|
||||
using ReferenceConvBwdInstance = ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>;
|
||||
|
||||
1
example/17_convnd_bwd_data_xdl/CMakeLists.txt
Normal file
1
example/17_convnd_bwd_data_xdl/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_example_executable(example_convnd_bwd_data_xdl convnd_bwd_data_xdl.cpp)
|
||||
80
example/17_convnd_bwd_data_xdl/README.md
Normal file
80
example/17_convnd_bwd_data_xdl/README.md
Normal file
@@ -0,0 +1,80 @@
|
||||
# Instructions for ```convnd_bwd_data_xdl``` Example
|
||||
|
||||
## Docker script
|
||||
```bash
|
||||
docker run \
|
||||
-it \
|
||||
--rm \
|
||||
--privileged \
|
||||
--group-add sudo \
|
||||
-w /root/workspace \
|
||||
-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
|
||||
rocm/tensorflow:rocm4.3.1-tf2.6-dev \
|
||||
/bin/bash
|
||||
```
|
||||
|
||||
## Build ```convnd_bwd_data_xdl```
|
||||
```bash
|
||||
mkdir build && cd build
|
||||
```
|
||||
|
||||
```bash
|
||||
# Need to specify target ID, example below is gfx908
|
||||
cmake \
|
||||
-D BUILD_DEV=OFF \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 " \
|
||||
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
..
|
||||
```
|
||||
|
||||
```bash
|
||||
make -j convnd_bwd_data_xdl
|
||||
```
|
||||
|
||||
## Run ```example_convnd_bwd_data_xdl```
|
||||
```bash
|
||||
#arg1: verification (0=no, 1=yes)
|
||||
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
|
||||
#arg3: run kernel # of times (>1)
|
||||
#arg4: num_dim_spatial(1|2|3)
|
||||
#arg5 to ...: N, K, C, [Z,] [Y,] X, [Di,] [Hi,] Wi, S[z,] [Sy,] Sx, [Dz,] [Dy,] Dx, [LeftPz,] [LeftPy,] LeftPx, [RightPy,] [RightPy,] RightPx
|
||||
./bin/convnd_bwd_data_xdl 0 1 5
|
||||
```
|
||||
|
||||
Result
|
||||
```
|
||||
in_n_c_hi_wi: dim 4, lengths {128, 128, 71, 71}, strides {645248, 1, 9088, 128}
|
||||
wei_k_c_y_x: dim 4, lengths {256, 128, 3, 3}, strides {1152, 1, 384, 128}
|
||||
out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256}
|
||||
arg.a_grid_desc_k0_m_k1_container_{128, 175232, 8}
|
||||
arg.b_grid_desc_k0_n_k1_container_{128, 128, 8}
|
||||
arg.c_grid_desc_m_n_container_{ 175232, 128}
|
||||
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 2, 2, 2, 4, 2 )
|
||||
launch_and_time_kernel: grid_dim {1369, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up
|
||||
Start running 1 times...
|
||||
arg.a_grid_desc_k0_m_k1_container_{64, 175232, 8}
|
||||
arg.b_grid_desc_k0_n_k1_container_{64, 128, 8}
|
||||
arg.c_grid_desc_m_n_container_{ 175232, 128}
|
||||
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 2, 2, 2, 4, 2 )
|
||||
launch_and_time_kernel: grid_dim {1369, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up
|
||||
Start running 1 times...
|
||||
arg.a_grid_desc_k0_m_k1_container_{64, 175232, 8}
|
||||
arg.b_grid_desc_k0_n_k1_container_{64, 128, 8}
|
||||
arg.c_grid_desc_m_n_container_{ 175232, 128}
|
||||
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 2, 2, 2, 4, 2 )
|
||||
launch_and_time_kernel: grid_dim {1369, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up
|
||||
Start running 1 times...
|
||||
arg.a_grid_desc_k0_m_k1_container_{32, 175232, 8}
|
||||
arg.b_grid_desc_k0_n_k1_container_{32, 128, 8}
|
||||
arg.c_grid_desc_m_n_container_{ 175232, 128}
|
||||
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 2, 2, 2, 4, 2 )
|
||||
launch_and_time_kernel: grid_dim {1369, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up
|
||||
Start running 1 times...
|
||||
Perf: 1.40031 ms, 69.8734 TFlops, 179.037 GB/s
|
||||
```
|
||||
415
example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp
Normal file
415
example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp
Normal file
@@ -0,0 +1,415 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
|
||||
#include "config.hpp"
|
||||
#include "conv_utils.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp"
|
||||
#include "reference_conv_bwd_data.hpp"
|
||||
|
||||
using InDataType = ck::half_t;
|
||||
using WeiDataType = ck::half_t;
|
||||
using OutDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
static constexpr auto ConvBwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Default;
|
||||
|
||||
using DeviceConvBwdDataBasePtr =
|
||||
ck::tensor_operation::device::DeviceConvBwdDataPtr<InElementOp, WeiElementOp, OutElementOp>;
|
||||
|
||||
template <ck::index_t NumDimSpatial>
|
||||
using DeviceConvNDBwdDataInstance = ck::tensor_operation::device::
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K<
|
||||
InDataType, // InDataType
|
||||
WeiDataType, // WeiDataType
|
||||
OutDataType, // OutDataType
|
||||
AccDataType, // AccDataType
|
||||
InElementOp, // InElementwiseOperation
|
||||
WeiElementOp, // WeiElementwiseOperation
|
||||
OutElementOp, // OutElementwiseOperation
|
||||
ConvBwdDefault, // ConvolutionBackwardDataSpecialization_t
|
||||
NumDimSpatial, // NumDimSpatial
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
4, // K0PerBlock
|
||||
8, // K1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
2, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_K1
|
||||
true, // ABlockLdsAddExtraM
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<0, 2, 1>, // BBlockTransferSrcAccessOrder
|
||||
1, // BBlockTransferSrcVectorDim
|
||||
2, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_K1
|
||||
true, // BBlockLdsAddExtraN
|
||||
7,
|
||||
1>; // GemmCThreadTransferDstScalarPerVector
|
||||
|
||||
template <ck::index_t NumDimSpatial>
|
||||
using ReferenceConvBwdDataInstance =
|
||||
ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
NumDimSpatial>;
|
||||
|
||||
void PrintUseMsg()
|
||||
{
|
||||
std::cout << "arg1: verification (0=no, 1=yes)\n"
|
||||
<< "arg2: initialization (0=no init, 1=random value, 2= init to 1 )\n"
|
||||
<< "arg3: run kernel # of times (>1)\n"
|
||||
<< "arg4: N spatial dimensions (default 2)\n"
|
||||
<< "Following arguments (depending on number of spatial dims):\n"
|
||||
<< " N, K, C, \n"
|
||||
<< " <filter spatial dimensions>, (ie Y, X for 2D)\n"
|
||||
<< " <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
|
||||
<< " <strides>, (ie Sy, Sx for 2D)\n"
|
||||
<< " <dilations>, (ie Dy, Dx for 2D)\n"
|
||||
<< " <left padding>, (ie LeftPy, LeftPx for 2D)\n"
|
||||
<< " <right padding>, (ie RightPy, RightPx for 2D)\n"
|
||||
<< std::endl;
|
||||
}
|
||||
ck::conv_util::ConvParams ParseConvParams(int num_dim_spatial, char* argv[])
|
||||
{
|
||||
// (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right)
|
||||
ck::conv_util::ConvParams params;
|
||||
int arg_idx = 5;
|
||||
|
||||
params.num_dim_spatial = num_dim_spatial;
|
||||
params.N = std::stoi(argv[arg_idx++]);
|
||||
params.K = std::stoi(argv[arg_idx++]);
|
||||
params.C = std::stoi(argv[arg_idx++]);
|
||||
|
||||
params.filter_spatial_lengths.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_spatial_lengths.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.conv_filter_strides.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.conv_filter_dilations.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_left_pads.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_left_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_right_pads.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_right_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
HostTensorDescriptor GetInputHostTensorDescriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2)
|
||||
{
|
||||
namespace tl = ck::tensor_layout::convolution;
|
||||
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 3: {
|
||||
return ck::conv_util::GetHostTensorDescriptor(dims, tl::NDHWC{});
|
||||
}
|
||||
case 2: {
|
||||
return ck::conv_util::GetHostTensorDescriptor(dims, tl::NHWC{});
|
||||
}
|
||||
case 1: {
|
||||
return ck::conv_util::GetHostTensorDescriptor(dims, tl::NWC{});
|
||||
}
|
||||
default: {
|
||||
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||
}
|
||||
}
|
||||
}
|
||||
HostTensorDescriptor GetFiltersHostTensorDescriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2)
|
||||
{
|
||||
namespace tl = ck::tensor_layout::convolution;
|
||||
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 3: {
|
||||
return ck::conv_util::GetHostTensorDescriptor(dims, tl::KZYXC{});
|
||||
}
|
||||
case 2: {
|
||||
return ck::conv_util::GetHostTensorDescriptor(dims, tl::KYXC{});
|
||||
}
|
||||
case 1: {
|
||||
return ck::conv_util::GetHostTensorDescriptor(dims, tl::KXC{});
|
||||
}
|
||||
default: {
|
||||
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
HostTensorDescriptor GetOutputHostTensorDescriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2)
|
||||
{
|
||||
namespace tl = ck::tensor_layout::convolution;
|
||||
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 3: {
|
||||
return ck::conv_util::GetHostTensorDescriptor(dims, tl::NDHWK{});
|
||||
}
|
||||
case 2: {
|
||||
return ck::conv_util::GetHostTensorDescriptor(dims, tl::NHWK{});
|
||||
}
|
||||
case 1: {
|
||||
return ck::conv_util::GetHostTensorDescriptor(dims, tl::NWK{});
|
||||
}
|
||||
|
||||
default: {
|
||||
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
DeviceConvBwdDataBasePtr GetConvInstance(int num_dim_spatial)
|
||||
{
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 3: {
|
||||
return std::make_unique<DeviceConvNDBwdDataInstance<3>>();
|
||||
}
|
||||
case 2: {
|
||||
return std::make_unique<DeviceConvNDBwdDataInstance<2>>();
|
||||
}
|
||||
case 1: {
|
||||
return std::make_unique<DeviceConvNDBwdDataInstance<1>>();
|
||||
}
|
||||
default: {
|
||||
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = 0;
|
||||
int init_method = 0;
|
||||
int nrepeat = 5;
|
||||
int num_dim_spatial = 2;
|
||||
|
||||
ck::conv_util::ConvParams params;
|
||||
params.C = 128;
|
||||
|
||||
if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc > 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
num_dim_spatial = std::stoi(argv[4]);
|
||||
// check args number
|
||||
int conv_args = 3 + num_dim_spatial * 6;
|
||||
int cmdline_nargs = conv_args + 5;
|
||||
if(cmdline_nargs != argc)
|
||||
{
|
||||
PrintUseMsg();
|
||||
exit(1);
|
||||
}
|
||||
|
||||
params = ParseConvParams(num_dim_spatial, argv);
|
||||
}
|
||||
else if(argc != 1)
|
||||
{
|
||||
PrintUseMsg();
|
||||
exit(1);
|
||||
}
|
||||
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N),
|
||||
static_cast<std::size_t>(params.C)};
|
||||
input_dims.insert(std::end(input_dims),
|
||||
std::begin(params.input_spatial_lengths),
|
||||
std::end(params.input_spatial_lengths));
|
||||
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K),
|
||||
static_cast<std::size_t>(params.C)};
|
||||
filter_dims.insert(std::end(filter_dims),
|
||||
std::begin(params.filter_spatial_lengths),
|
||||
std::end(params.filter_spatial_lengths));
|
||||
|
||||
const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths();
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N),
|
||||
static_cast<std::size_t>(params.K)};
|
||||
output_dims.insert(std::end(output_dims),
|
||||
std::begin(output_spatial_lengths),
|
||||
std::end(output_spatial_lengths));
|
||||
|
||||
Tensor<InDataType> in_n_c_hi_wi_host_result(
|
||||
GetInputHostTensorDescriptor(input_dims, num_dim_spatial));
|
||||
Tensor<InDataType> in_n_c_hi_wi_device_result(
|
||||
GetInputHostTensorDescriptor(input_dims, num_dim_spatial));
|
||||
Tensor<WeiDataType> wei_k_c_y_x(GetFiltersHostTensorDescriptor(filter_dims, num_dim_spatial));
|
||||
Tensor<OutDataType> out_n_k_ho_wo(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial));
|
||||
|
||||
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi_host_result.mDesc << std::endl;
|
||||
std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl;
|
||||
std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.2, 0.2});
|
||||
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.2, 0.2});
|
||||
break;
|
||||
default:
|
||||
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
|
||||
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
|
||||
}
|
||||
|
||||
DeviceMem in_device_buf(sizeof(InDataType) *
|
||||
in_n_c_hi_wi_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||
DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
||||
|
||||
out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
// reset input to zero
|
||||
in_n_c_hi_wi_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{0});
|
||||
in_device_buf.ToDevice(in_n_c_hi_wi_device_result.mData.data());
|
||||
|
||||
// do GEMM
|
||||
auto conv = GetConvInstance(num_dim_spatial);
|
||||
auto invoker = conv->MakeInvokerPointer();
|
||||
auto argument =
|
||||
conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
params.N,
|
||||
params.K,
|
||||
params.C,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
|
||||
if(!conv->IsSupportedArgument(argument.get()))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_conv with the specified compilation parameters does "
|
||||
"not support this Conv problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker->Run(argument.get(), nrepeat);
|
||||
|
||||
std::size_t flop = ck::conv_util::GetFlops(
|
||||
params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths);
|
||||
std::size_t num_btype =
|
||||
ck::conv_util::GetBtype<InDataType, WeiDataType, OutDataType>(params.N,
|
||||
params.C,
|
||||
params.K,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
output_spatial_lengths);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
auto verify_f = [&](const auto& ref_conv) {
|
||||
auto ref_invoker = ref_conv.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi_host_result,
|
||||
wei_k_c_y_x,
|
||||
out_n_k_ho_wo,
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data());
|
||||
|
||||
check_error(in_n_c_hi_wi_host_result, in_n_c_hi_wi_device_result);
|
||||
};
|
||||
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 3: {
|
||||
auto ref_conv = ReferenceConvBwdDataInstance<3>();
|
||||
verify_f(ref_conv);
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
auto ref_conv = ReferenceConvBwdDataInstance<2>();
|
||||
verify_f(ref_conv);
|
||||
break;
|
||||
}
|
||||
case 1: {
|
||||
auto ref_conv = ReferenceConvBwdDataInstance<1>();
|
||||
verify_f(ref_conv);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -39,5 +39,6 @@ add_subdirectory(11_conv2d_bwd_wgt)
|
||||
add_subdirectory(12_reduce)
|
||||
add_subdirectory(13_pool2d_fwd)
|
||||
add_subdirectory(14_gemm_xdl_requant_relu_requant)
|
||||
add_subdirectory(17_convnd_bwd_data_xdl)
|
||||
add_subdirectory(15_grouped_gemm)
|
||||
add_subdirectory(16_gemm_reduce)
|
||||
|
||||
@@ -7,9 +7,9 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Number of GEMMs = YTilda * XTilda
|
||||
// Number of GEMMs = YTilde * XTilde
|
||||
// GemmM = C
|
||||
// GemmN = N * HTildaSlice * WTildaSlice
|
||||
// GemmN = N * HTildeSlice * WTildeSlice
|
||||
// GemmK = K * YDotSlice * XDotSlice
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
@@ -18,8 +18,8 @@ template <typename... Wei,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t IYTildaValue,
|
||||
index_t IXTildaValue,
|
||||
index_t IYTildeValue,
|
||||
index_t IXTildeValue,
|
||||
index_t GemmK1Value>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
|
||||
@@ -30,8 +30,8 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Number<IYTildaValue>,
|
||||
Number<IXTildaValue>,
|
||||
Number<IYTildeValue>,
|
||||
Number<IXTildeValue>,
|
||||
Number<GemmK1Value>)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
@@ -40,8 +40,8 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
constexpr auto IYTilda = Number<IYTildaValue>{};
|
||||
constexpr auto IXTilda = Number<IXTildaValue>{};
|
||||
constexpr auto IYTilde = Number<IYTildeValue>{};
|
||||
constexpr auto IXTilde = Number<IXTildeValue>{};
|
||||
|
||||
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
||||
@@ -71,55 +71,55 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
|
||||
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
|
||||
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
const auto YTilda = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilda = ConvStrideW / GcdStrideDilationW;
|
||||
const auto YTilde = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilde = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
const auto YDot = math::integer_divide_ceil(Y, YTilda);
|
||||
const auto XDot = math::integer_divide_ceil(X, XTilda);
|
||||
const auto YDot = math::integer_divide_ceil(Y, YTilde);
|
||||
const auto XDot = math::integer_divide_ceil(X, XTilde);
|
||||
|
||||
const auto HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
|
||||
const auto WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
|
||||
const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
|
||||
const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
|
||||
|
||||
// only work on HTilda and WTilda that contribute to non-padding area of input tensor
|
||||
const auto IHTildaSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadH - ConvDilationH * (YTilda - I1)), ConvStrideH);
|
||||
const auto IWTildaSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadW - ConvDilationW * (XTilda - I1)), ConvStrideW);
|
||||
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
|
||||
const auto IHTildeSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
|
||||
const auto IWTildeSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
|
||||
|
||||
const auto IHTildaSliceEnd =
|
||||
math::min(HTilda, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
|
||||
const auto IWTildaSliceEnd =
|
||||
math::min(WTilda, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
|
||||
const auto IHTildeSliceEnd =
|
||||
math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
|
||||
const auto IWTildeSliceEnd =
|
||||
math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
|
||||
|
||||
const auto HTildaSlice = IHTildaSliceEnd - IHTildaSliceBegin;
|
||||
const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin;
|
||||
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
|
||||
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
|
||||
|
||||
// GemmK is different for each GEMM
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - IYTilda, YTilda);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - IXTilda, XTilda);
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - IYTilde, YTilde);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - IXTilde, XTilde);
|
||||
|
||||
const auto K1 = GemmK1;
|
||||
const auto K0 = K / K1;
|
||||
|
||||
// weight tensor
|
||||
const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_tensor_descriptor(
|
||||
const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
|
||||
wei_k_y_x_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K),
|
||||
make_embed_transform(make_tuple(YDot, YTilda),
|
||||
make_embed_transform(make_tuple(YDot, YTilde),
|
||||
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
|
||||
make_embed_transform(make_tuple(XDot, XTilda),
|
||||
make_embed_transform(make_tuple(XDot, XTilde),
|
||||
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
|
||||
transform_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc,
|
||||
transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
|
||||
make_slice_transform(YDot, I0, YDotSlice),
|
||||
make_slice_transform(XDot, I0, XDotSlice),
|
||||
make_freeze_transform(IYTilda),
|
||||
make_freeze_transform(IXTilda),
|
||||
make_freeze_transform(IYTilde),
|
||||
make_freeze_transform(IXTilde),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
@@ -163,25 +163,25 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_tensor_descriptor(
|
||||
const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
|
||||
out_n_hop_wop_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(YDot, HTilda),
|
||||
make_embed_transform(make_tuple(YDot, HTilde),
|
||||
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
|
||||
make_embed_transform(make_tuple(XDot, WTilda),
|
||||
make_embed_transform(make_tuple(XDot, WTilde),
|
||||
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc =
|
||||
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_n_ydot_htilda_xdot_wtilda_k_grid_desc,
|
||||
out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_slice_transform(YDot, I0, YDotSlice),
|
||||
make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice),
|
||||
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
|
||||
make_slice_transform(XDot, I0, XDotSlice),
|
||||
make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice),
|
||||
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
|
||||
make_unmerge_transform(make_tuple(K0, K1))),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
@@ -198,17 +198,17 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
|
||||
|
||||
#if 1
|
||||
const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
|
||||
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
|
||||
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
|
||||
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
#else
|
||||
const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
|
||||
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
|
||||
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
|
||||
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
@@ -224,24 +224,24 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_tensor_descriptor(
|
||||
const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(YTilda, HTilda),
|
||||
make_embed_transform(make_tuple(YTilde, HTilde),
|
||||
make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(XTilda, WTilda),
|
||||
make_embed_transform(make_tuple(XTilde, WTilde),
|
||||
make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc,
|
||||
const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_freeze_transform(IYTilda),
|
||||
make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice),
|
||||
make_freeze_transform(IXTilda),
|
||||
make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice),
|
||||
make_freeze_transform(IYTilde),
|
||||
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
|
||||
make_freeze_transform(IXTilde),
|
||||
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
@@ -257,9 +257,9 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
|
||||
Sequence<3>{}));
|
||||
|
||||
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_n_htildaslice_wtildaslice_c_grid_desc,
|
||||
in_n_htildeslice_wtildeslice_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(C),
|
||||
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice))),
|
||||
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice))),
|
||||
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
|
||||
@@ -10,8 +10,8 @@ namespace ck {
|
||||
// A: out
|
||||
// B: wei
|
||||
// C: in
|
||||
// Number of GEMMs = YTilda * XTilda
|
||||
// GemmM = N * HTildaSlice * WTildaSlice
|
||||
// Number of GEMMs = YTilde * XTilde
|
||||
// GemmM = N * HTildeSlice * WTildeSlice
|
||||
// GemmN = C
|
||||
// GemmK = K * YDotSlice * XDotSlice
|
||||
template <typename... Wei,
|
||||
@@ -21,8 +21,8 @@ template <typename... Wei,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
typename IYTilda,
|
||||
typename IXTilda,
|
||||
typename IYTilde,
|
||||
typename IXTilde,
|
||||
index_t GemmK1Value>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
@@ -33,8 +33,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
IYTilda i_ytilda,
|
||||
IXTilda i_xtilda,
|
||||
IYTilde i_ytilde,
|
||||
IXTilde i_xtilde,
|
||||
Number<GemmK1Value>)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
@@ -72,32 +72,32 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
|
||||
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
const auto YTilda = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilda = ConvStrideW / GcdStrideDilationW;
|
||||
const auto YTilde = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilde = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
const auto YDot = math::integer_divide_ceil(Y, YTilda);
|
||||
const auto XDot = math::integer_divide_ceil(X, XTilda);
|
||||
const auto YDot = math::integer_divide_ceil(Y, YTilde);
|
||||
const auto XDot = math::integer_divide_ceil(X, XTilde);
|
||||
|
||||
const auto HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
|
||||
const auto WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
|
||||
const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
|
||||
const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
|
||||
|
||||
// only work on HTilda and WTilda that contribute to non-padding area of input tensor
|
||||
const auto IHTildaSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadH - ConvDilationH * (YTilda - I1)), ConvStrideH);
|
||||
const auto IWTildaSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadW - ConvDilationW * (XTilda - I1)), ConvStrideW);
|
||||
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
|
||||
const auto IHTildeSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
|
||||
const auto IWTildeSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
|
||||
|
||||
const auto IHTildaSliceEnd =
|
||||
math::min(HTilda, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
|
||||
const auto IWTildaSliceEnd =
|
||||
math::min(WTilda, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
|
||||
const auto IHTildeSliceEnd =
|
||||
math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
|
||||
const auto IWTildeSliceEnd =
|
||||
math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
|
||||
|
||||
const auto HTildaSlice = IHTildaSliceEnd - IHTildaSliceBegin;
|
||||
const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin;
|
||||
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
|
||||
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
|
||||
|
||||
// GemmK is different for each GEMM
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilda, YTilda);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilda, XTilda);
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
|
||||
|
||||
const auto K1 = GemmK1;
|
||||
const auto K0 = K / K1;
|
||||
@@ -113,25 +113,25 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_tensor_descriptor(
|
||||
const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
|
||||
out_n_hop_wop_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(YDot, HTilda),
|
||||
make_embed_transform(make_tuple(YDot, HTilde),
|
||||
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
|
||||
make_embed_transform(make_tuple(XDot, WTilda),
|
||||
make_embed_transform(make_tuple(XDot, WTilde),
|
||||
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc =
|
||||
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_n_ydot_htilda_xdot_wtilda_k_grid_desc,
|
||||
out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_slice_transform(YDot, I0, YDotSlice),
|
||||
make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice),
|
||||
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
|
||||
make_slice_transform(XDot, I0, XDotSlice),
|
||||
make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice),
|
||||
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
|
||||
make_unmerge_transform(make_tuple(K0, K1))),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
@@ -148,41 +148,41 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
|
||||
#if 1
|
||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
|
||||
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
|
||||
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
|
||||
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
#else
|
||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
|
||||
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
|
||||
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
|
||||
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
#endif
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_tensor_descriptor(
|
||||
const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
|
||||
wei_k_y_x_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K),
|
||||
make_embed_transform(make_tuple(YDot, YTilda),
|
||||
make_embed_transform(make_tuple(YDot, YTilde),
|
||||
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
|
||||
make_embed_transform(make_tuple(XDot, XTilda),
|
||||
make_embed_transform(make_tuple(XDot, XTilde),
|
||||
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
|
||||
transform_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc,
|
||||
transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
|
||||
make_slice_transform(YDot, I0, YDotSlice),
|
||||
make_slice_transform(XDot, I0, XDotSlice),
|
||||
make_freeze_transform(i_ytilda),
|
||||
make_freeze_transform(i_xtilda),
|
||||
make_freeze_transform(i_ytilde),
|
||||
make_freeze_transform(i_xtilde),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
@@ -225,24 +225,24 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_tensor_descriptor(
|
||||
const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(YTilda, HTilda),
|
||||
make_embed_transform(make_tuple(YTilde, HTilde),
|
||||
make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(XTilda, WTilda),
|
||||
make_embed_transform(make_tuple(XTilde, WTilde),
|
||||
make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc,
|
||||
const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_freeze_transform(i_ytilda),
|
||||
make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice),
|
||||
make_freeze_transform(i_xtilda),
|
||||
make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice),
|
||||
make_freeze_transform(i_ytilde),
|
||||
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
|
||||
make_freeze_transform(i_xtilde),
|
||||
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
@@ -258,8 +258,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
Sequence<3>{}));
|
||||
|
||||
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_n_htildaslice_wtildaslice_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
|
||||
in_n_htildeslice_wtildeslice_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
@@ -108,6 +108,28 @@ struct ConvParams
|
||||
input_right_pads(2, 1)
|
||||
{
|
||||
}
|
||||
ConvParams(ck::index_t n_dim_spatial,
|
||||
ck::index_t n,
|
||||
ck::index_t k,
|
||||
ck::index_t c,
|
||||
std::vector<ck::index_t> filter_lengths,
|
||||
std::vector<ck::index_t> input_lengths,
|
||||
std::vector<ck::index_t> conv_strides,
|
||||
std::vector<ck::index_t> conv_dilations,
|
||||
std::vector<ck::index_t> left_pads,
|
||||
std::vector<ck::index_t> right_pads)
|
||||
: num_dim_spatial(n_dim_spatial),
|
||||
N(n),
|
||||
K(k),
|
||||
C(c),
|
||||
filter_spatial_lengths(filter_lengths),
|
||||
input_spatial_lengths(input_lengths),
|
||||
conv_filter_strides(conv_strides),
|
||||
conv_filter_dilations(conv_dilations),
|
||||
input_left_pads(left_pads),
|
||||
input_right_pads(right_pads)
|
||||
{
|
||||
}
|
||||
|
||||
ck::index_t num_dim_spatial;
|
||||
ck::index_t N;
|
||||
@@ -206,7 +228,7 @@ HostTensorDescriptor GetHostTensorDescriptor(const std::vector<std::size_t>& dim
|
||||
return HostTensorDescriptor(
|
||||
dims,
|
||||
std::vector<std::size_t>{
|
||||
C * dims[2] * dims[3] * dims[4], 1, C * dims[3] * dims[4], C * dims[4], C});
|
||||
C * dims[2] * dims[3] * dims[4], 1, dims[3] * dims[4] * C, dims[4] * C, C});
|
||||
}
|
||||
|
||||
std::stringstream err_msg;
|
||||
|
||||
@@ -95,8 +95,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
index_t i_ytilda,
|
||||
index_t i_xtilda)
|
||||
index_t i_ytilde,
|
||||
index_t i_xtilde)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -177,34 +177,34 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
|
||||
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
const auto YTilda = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilda = ConvStrideW / GcdStrideDilationW;
|
||||
const auto YTilde = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilde = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
const auto YDot = math::integer_divide_ceil(Y, YTilda);
|
||||
const auto XDot = math::integer_divide_ceil(X, XTilda);
|
||||
const auto YDot = math::integer_divide_ceil(Y, YTilde);
|
||||
const auto XDot = math::integer_divide_ceil(X, XTilde);
|
||||
|
||||
const auto HTilda =
|
||||
const auto HTilde =
|
||||
Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
|
||||
const auto WTilda =
|
||||
const auto WTilde =
|
||||
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
|
||||
|
||||
// only work on HTilda and WTilda that contribute to non-padding area of input tensor
|
||||
const auto IHTildaSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadH - ConvDilationH * (YTilda - I1)), ConvStrideH);
|
||||
const auto IWTildaSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadW - ConvDilationW * (XTilda - I1)), ConvStrideW);
|
||||
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
|
||||
const auto IHTildeSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
|
||||
const auto IWTildeSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
|
||||
|
||||
const auto IHTildaSliceEnd = math::min(
|
||||
HTilda, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
|
||||
const auto IWTildaSliceEnd = math::min(
|
||||
WTilda, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
|
||||
const auto IHTildeSliceEnd = math::min(
|
||||
HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
|
||||
const auto IWTildeSliceEnd = math::min(
|
||||
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
|
||||
|
||||
const auto HTildaSlice = IHTildaSliceEnd - IHTildaSliceBegin;
|
||||
const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin;
|
||||
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
|
||||
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
|
||||
|
||||
// GemmK is different for each GEMM
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilda, YTilda);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilda, XTilda);
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
|
||||
|
||||
// A: output tensor
|
||||
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
|
||||
@@ -216,26 +216,26 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_tensor_descriptor(
|
||||
const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
|
||||
out_n_hop_wop_k_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(YDot, HTilda),
|
||||
make_embed_transform(make_tuple(YDot, HTilde),
|
||||
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
|
||||
make_embed_transform(make_tuple(XDot, WTilda),
|
||||
make_embed_transform(make_tuple(XDot, WTilde),
|
||||
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc =
|
||||
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_n_ydot_htilda_xdot_wtilda_k_grid_desc,
|
||||
out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_slice_transform(YDot, I0, YDotSlice),
|
||||
make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice),
|
||||
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
|
||||
make_slice_transform(XDot, I0, XDotSlice),
|
||||
make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice),
|
||||
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
|
||||
make_unmerge_transform(make_tuple(K0, K1))),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
@@ -251,32 +251,32 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
Sequence<5, 6>{}));
|
||||
|
||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
|
||||
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
|
||||
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
|
||||
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
// B weight tensor
|
||||
const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_tensor_descriptor(
|
||||
const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
|
||||
wei_k_y_x_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K),
|
||||
make_embed_transform(make_tuple(YDot, YTilda),
|
||||
make_embed_transform(make_tuple(YDot, YTilde),
|
||||
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
|
||||
make_embed_transform(make_tuple(XDot, XTilda),
|
||||
make_embed_transform(make_tuple(XDot, XTilde),
|
||||
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
|
||||
transform_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc,
|
||||
transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
|
||||
make_slice_transform(YDot, I0, YDotSlice),
|
||||
make_slice_transform(XDot, I0, XDotSlice),
|
||||
make_freeze_transform(i_ytilda),
|
||||
make_freeze_transform(i_xtilda),
|
||||
make_freeze_transform(i_ytilde),
|
||||
make_freeze_transform(i_xtilde),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
@@ -309,24 +309,24 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_tensor_descriptor(
|
||||
const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(YTilda, HTilda),
|
||||
make_embed_transform(make_tuple(YTilde, HTilde),
|
||||
make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(XTilda, WTilda),
|
||||
make_embed_transform(make_tuple(XTilde, WTilde),
|
||||
make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc,
|
||||
const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_freeze_transform(i_ytilda),
|
||||
make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice),
|
||||
make_freeze_transform(i_xtilda),
|
||||
make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice),
|
||||
make_freeze_transform(i_ytilde),
|
||||
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
|
||||
make_freeze_transform(i_xtilde),
|
||||
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
@@ -342,8 +342,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
Sequence<3>{}));
|
||||
|
||||
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_n_htildaslice_wtildaslice_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
|
||||
in_n_htildeslice_wtildeslice_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
@@ -452,18 +452,18 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
|
||||
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
const auto YTilda = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilda = ConvStrideW / GcdStrideDilationW;
|
||||
const auto YTilde = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilde = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
for(index_t i_ytilda = 0; i_ytilda < YTilda; ++i_ytilda)
|
||||
for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
|
||||
{
|
||||
for(index_t i_xtilda = 0; i_xtilda < XTilda; ++i_xtilda)
|
||||
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
|
||||
{
|
||||
// check slice is valid
|
||||
const index_t Y = filter_spatial_lengths_[0];
|
||||
const index_t X = filter_spatial_lengths_[1];
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilda, YTilda);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilda, XTilda);
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
|
||||
if(YDotSlice * XDotSlice <= 0)
|
||||
{
|
||||
continue;
|
||||
@@ -480,8 +480,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
i_ytilda,
|
||||
i_xtilda);
|
||||
i_ytilde,
|
||||
i_xtilde);
|
||||
a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
|
||||
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
|
||||
c_grid_desc_m_n_container_.push_back(descs[I2]);
|
||||
@@ -533,7 +533,6 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
{
|
||||
nrepeat = 1;
|
||||
float ave_time = 0;
|
||||
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
|
||||
{
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -100,7 +100,6 @@ struct NDHWK : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "NDHWK";
|
||||
};
|
||||
|
||||
struct NCDHW : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "NCDHW";
|
||||
|
||||
@@ -303,14 +303,14 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
|
||||
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
|
||||
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
const auto YTilda = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilda = ConvStrideW / GcdStrideDilationW;
|
||||
const auto YTilde = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilde = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
for(index_t i_ytilda = 0; i_ytilda < YTilda; ++i_ytilda)
|
||||
for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
|
||||
{
|
||||
for(index_t i_xtilda = 0; i_xtilda < XTilda; ++i_xtilda)
|
||||
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
|
||||
{
|
||||
const auto descs =
|
||||
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
@@ -321,8 +321,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
i_ytilda,
|
||||
i_xtilda,
|
||||
i_ytilde,
|
||||
i_xtilde,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
|
||||
@@ -14,17 +14,20 @@ namespace host {
|
||||
template <typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
typename OutElementwiseOperation,
|
||||
ck::index_t NumDimSpatial = 2,
|
||||
typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
|
||||
struct ReferenceConvBwdData : public device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public device::BaseArgument
|
||||
{
|
||||
Argument(Tensor<InDataType>& in_n_c_hi_wi,
|
||||
const Tensor<WeiDataType>& wei_k_c_y_x,
|
||||
const Tensor<OutDataType>& out_n_k_ho_wo,
|
||||
Argument(Tensor<InDataType>& input,
|
||||
const Tensor<WeiDataType>& weight,
|
||||
const Tensor<OutDataType>& output,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
@@ -32,9 +35,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op)
|
||||
: in_n_c_hi_wi_{in_n_c_hi_wi},
|
||||
wei_k_c_y_x_{wei_k_c_y_x},
|
||||
out_n_k_ho_wo_{out_n_k_ho_wo},
|
||||
: input_{input},
|
||||
weight_{weight},
|
||||
output_{output},
|
||||
conv_strides_{conv_filter_strides},
|
||||
conv_dilations_{conv_filter_dilations},
|
||||
in_left_pads_{input_left_pads},
|
||||
@@ -45,9 +48,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
|
||||
{
|
||||
}
|
||||
|
||||
Tensor<InDataType>& in_n_c_hi_wi_;
|
||||
const Tensor<WeiDataType>& wei_k_c_y_x_;
|
||||
const Tensor<OutDataType>& out_n_k_ho_wo_;
|
||||
Tensor<InDataType>& input_;
|
||||
const Tensor<WeiDataType>& weight_;
|
||||
const Tensor<OutDataType>& output_;
|
||||
|
||||
std::vector<index_t> conv_strides_;
|
||||
std::vector<index_t> conv_dilations_;
|
||||
@@ -66,67 +69,199 @@ struct ReferenceConvBwdData : public device::BaseOperator
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
|
||||
std::size_t K = arg.wei_k_c_y_x_.mDesc.GetLengths()[0];
|
||||
std::size_t Y = arg.wei_k_c_y_x_.mDesc.GetLengths()[2];
|
||||
std::size_t X = arg.wei_k_c_y_x_.mDesc.GetLengths()[3];
|
||||
if constexpr(NumDimSpatial == 1)
|
||||
{
|
||||
auto f_nchw = [&](auto n, auto c, auto wi) {
|
||||
std::size_t K = arg.weight_.mDesc.GetLengths()[0];
|
||||
std::size_t X = arg.weight_.mDesc.GetLengths()[2];
|
||||
std::size_t Wo = arg.output_.mDesc.GetLengths()[2];
|
||||
|
||||
std::size_t Ho = arg.out_n_k_ho_wo_.mDesc.GetLengths()[2];
|
||||
std::size_t Wo = arg.out_n_k_ho_wo_.mDesc.GetLengths()[3];
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
float v_acc = 0;
|
||||
|
||||
for(int y = 0; y < Y; ++y)
|
||||
{
|
||||
int h_tmp = hi + arg.in_left_pads_[0] - y * arg.conv_dilations_[0];
|
||||
if(h_tmp % arg.conv_strides_[0] == 0)
|
||||
for(int x = 0; x < X; ++x)
|
||||
{
|
||||
int ho = h_tmp / arg.conv_strides_[0];
|
||||
if(ho >= 0 && ho < Ho)
|
||||
int w_tmp = wi + arg.in_left_pads_[0] - x * arg.conv_dilations_[0];
|
||||
if(w_tmp % arg.conv_strides_[0] == 0)
|
||||
{
|
||||
for(int x = 0; x < X; ++x)
|
||||
int wo = w_tmp / arg.conv_strides_[0];
|
||||
if(wo >= 0 && wo < Wo)
|
||||
{
|
||||
int w_tmp = wi + arg.in_left_pads_[1] - x * arg.conv_dilations_[1];
|
||||
if(w_tmp % arg.conv_strides_[1] == 0)
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
int wo = w_tmp / arg.conv_strides_[1];
|
||||
if(wo >= 0 && wo < Wo)
|
||||
AccDataType v_out = 0;
|
||||
AccDataType v_wei = 0;
|
||||
|
||||
arg.out_element_op_(
|
||||
v_out,
|
||||
ck::type_convert<AccDataType>(arg.output_(n, k, wo)));
|
||||
arg.wei_element_op_(
|
||||
v_wei, ck::type_convert<AccDataType>(arg.weight_(k, c, x)));
|
||||
|
||||
v_acc += v_out * v_wei;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float v_in;
|
||||
arg.in_element_op_(v_in, v_acc);
|
||||
arg.input_(n, c, wi) = ck::type_convert<InDataType>(v_in);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
arg.input_.mDesc.GetLengths()[0],
|
||||
arg.input_.mDesc.GetLengths()[1],
|
||||
arg.input_.mDesc.GetLengths()[2])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 2)
|
||||
{
|
||||
auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
|
||||
std::size_t K = arg.weight_.mDesc.GetLengths()[0];
|
||||
std::size_t Y = arg.weight_.mDesc.GetLengths()[2];
|
||||
std::size_t X = arg.weight_.mDesc.GetLengths()[3];
|
||||
|
||||
std::size_t Ho = arg.output_.mDesc.GetLengths()[2];
|
||||
std::size_t Wo = arg.output_.mDesc.GetLengths()[3];
|
||||
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
for(int y = 0; y < Y; ++y)
|
||||
{
|
||||
int h_tmp = hi + arg.in_left_pads_[0] - y * arg.conv_dilations_[0];
|
||||
if(h_tmp % arg.conv_strides_[0] == 0)
|
||||
{
|
||||
int ho = h_tmp / arg.conv_strides_[0];
|
||||
if(ho >= 0 && ho < Ho)
|
||||
{
|
||||
for(int x = 0; x < X; ++x)
|
||||
{
|
||||
int w_tmp =
|
||||
wi + arg.in_left_pads_[1] - x * arg.conv_dilations_[1];
|
||||
if(w_tmp % arg.conv_strides_[1] == 0)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
int wo = w_tmp / arg.conv_strides_[1];
|
||||
if(wo >= 0 && wo < Wo)
|
||||
{
|
||||
float v_out = 0;
|
||||
float v_wei = 0;
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
AccDataType v_out = 0;
|
||||
AccDataType v_wei = 0;
|
||||
|
||||
arg.out_element_op_(
|
||||
v_out,
|
||||
ck::type_convert<float>(
|
||||
arg.out_n_k_ho_wo_(n, k, ho, wo)));
|
||||
arg.wei_element_op_(v_wei,
|
||||
ck::type_convert<float>(
|
||||
arg.wei_k_c_y_x_(k, c, y, x)));
|
||||
arg.out_element_op_(v_out,
|
||||
ck::type_convert<AccDataType>(
|
||||
arg.output_(n, k, ho, wo)));
|
||||
arg.wei_element_op_(v_wei,
|
||||
ck::type_convert<AccDataType>(
|
||||
arg.weight_(k, c, y, x)));
|
||||
|
||||
v_acc += v_out * v_wei;
|
||||
v_acc += v_out * v_wei;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float v_in;
|
||||
arg.in_element_op_(v_in, v_acc);
|
||||
arg.in_n_c_hi_wi_(n, c, hi, wi) = ck::type_convert<InDataType>(v_in);
|
||||
};
|
||||
AccDataType v_in;
|
||||
arg.in_element_op_(v_in, v_acc);
|
||||
arg.input_(n, c, hi, wi) = ck::type_convert<InDataType>(v_in);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
arg.in_n_c_hi_wi_.mDesc.GetLengths()[0],
|
||||
arg.in_n_c_hi_wi_.mDesc.GetLengths()[1],
|
||||
arg.in_n_c_hi_wi_.mDesc.GetLengths()[2],
|
||||
arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])(
|
||||
std::thread::hardware_concurrency());
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
arg.input_.mDesc.GetLengths()[0],
|
||||
arg.input_.mDesc.GetLengths()[1],
|
||||
arg.input_.mDesc.GetLengths()[2],
|
||||
arg.input_.mDesc.GetLengths()[3])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
return 0;
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 3)
|
||||
{
|
||||
auto f_nchw = [&](auto n, auto c, auto di, auto hi, auto wi) {
|
||||
std::size_t K = arg.weight_.mDesc.GetLengths()[0];
|
||||
std::size_t Z = arg.weight_.mDesc.GetLengths()[2];
|
||||
std::size_t Y = arg.weight_.mDesc.GetLengths()[3];
|
||||
std::size_t X = arg.weight_.mDesc.GetLengths()[4];
|
||||
|
||||
std::size_t Do = arg.output_.mDesc.GetLengths()[2];
|
||||
std::size_t Ho = arg.output_.mDesc.GetLengths()[3];
|
||||
std::size_t Wo = arg.output_.mDesc.GetLengths()[4];
|
||||
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
for(int z = 0; z < Z; ++z)
|
||||
{
|
||||
int d_tmp = di + arg.in_left_pads_[0] - z * arg.conv_dilations_[0];
|
||||
if(d_tmp % arg.conv_strides_[0] == 0)
|
||||
{
|
||||
int do_ = d_tmp / arg.conv_strides_[0];
|
||||
if(do_ >= 0 && do_ < Do)
|
||||
{
|
||||
for(int y = 0; y < Y; ++y)
|
||||
{
|
||||
int h_tmp =
|
||||
hi + arg.in_left_pads_[1] - y * arg.conv_dilations_[1];
|
||||
if(h_tmp % arg.conv_strides_[1] == 0)
|
||||
{
|
||||
int ho = h_tmp / arg.conv_strides_[1];
|
||||
if(ho >= 0 && ho < Ho)
|
||||
{
|
||||
for(int x = 0; x < X; ++x)
|
||||
{
|
||||
int w_tmp = wi + arg.in_left_pads_[2] -
|
||||
x * arg.conv_dilations_[2];
|
||||
if(w_tmp % arg.conv_strides_[2] == 0)
|
||||
{
|
||||
int wo = w_tmp / arg.conv_strides_[2];
|
||||
if(wo >= 0 && wo < Wo)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
AccDataType v_out = 0;
|
||||
AccDataType v_wei = 0;
|
||||
|
||||
arg.out_element_op_(
|
||||
v_out,
|
||||
ck::type_convert<AccDataType>(
|
||||
arg.output_(
|
||||
n, k, do_, ho, wo)));
|
||||
arg.wei_element_op_(
|
||||
v_wei,
|
||||
ck::type_convert<AccDataType>(
|
||||
arg.weight_(k, c, z, y, x)));
|
||||
|
||||
v_acc += v_out * v_wei;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
AccDataType v_in;
|
||||
arg.in_element_op_(v_in, v_acc);
|
||||
arg.input_(n, c, di, hi, wi) = ck::type_convert<InDataType>(v_in);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
arg.input_.mDesc.GetLengths()[0],
|
||||
arg.input_.mDesc.GetLengths()[1],
|
||||
arg.input_.mDesc.GetLengths()[2],
|
||||
arg.input_.mDesc.GetLengths()[3],
|
||||
arg.input_.mDesc.GetLengths()[4])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
float Run(const device::BaseArgument* p_arg, int) override
|
||||
@@ -143,9 +278,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
|
||||
|
||||
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
|
||||
|
||||
static auto MakeArgument(Tensor<InDataType>& in_n_c_hi_wi,
|
||||
const Tensor<WeiDataType>& wei_k_c_y_x,
|
||||
const Tensor<OutDataType>& out_n_k_ho_wo,
|
||||
static auto MakeArgument(Tensor<InDataType>& input,
|
||||
const Tensor<WeiDataType>& weight,
|
||||
const Tensor<OutDataType>& output,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
@@ -154,9 +289,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op)
|
||||
{
|
||||
return Argument{in_n_c_hi_wi,
|
||||
wei_k_c_y_x,
|
||||
out_n_k_ho_wo,
|
||||
return Argument{input,
|
||||
weight,
|
||||
output,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
|
||||
@@ -37,4 +37,5 @@ add_subdirectory(conv2d_fwd_bias_relu_add)
|
||||
add_subdirectory(conv2d_fwd_bias_relu_atomic_add)
|
||||
add_subdirectory(conv2d_bwd_data)
|
||||
add_subdirectory(reduce)
|
||||
add_subdirectory(convnd_bwd_data)
|
||||
add_subdirectory(grouped_gemm)
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
# device_convnd_bwd_data_instance
|
||||
set(DEVICE_CONVND_BWD_DATA_INSTANCE_SOURCE
|
||||
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp;
|
||||
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp;
|
||||
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp;
|
||||
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instance.cpp;
|
||||
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp;
|
||||
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp;
|
||||
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp;
|
||||
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp;
|
||||
device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp;
|
||||
device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp;
|
||||
device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp;
|
||||
device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp;
|
||||
)
|
||||
|
||||
add_library(device_convnd_bwd_data_instance SHARED ${DEVICE_CONVND_BWD_DATA_INSTANCE_SOURCE})
|
||||
target_compile_features(device_convnd_bwd_data_instance PUBLIC)
|
||||
set_target_properties(device_convnd_bwd_data_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
install(TARGETS device_convnd_bwd_data_instance LIBRARY DESTINATION lib)
|
||||
|
||||
clang_tidy_check(device_convnd_bwd_data_instance)
|
||||
@@ -0,0 +1,84 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_conv2d_bwd_data_instance {
|
||||
|
||||
using BF16 = ushort;
|
||||
using F32 = float;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
static constexpr auto ConvBwdDataDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Default;
|
||||
|
||||
static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
|
||||
using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_bf16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(
|
||||
std::vector<DeviceConvBwdDataPtr<PassThrough, PassThrough, PassThrough>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances{});
|
||||
add_device_operation_instances(
|
||||
instances, device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_bf16_instances{});
|
||||
}
|
||||
|
||||
} // namespace device_conv2d_bwd_data_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,86 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_conv2d_bwd_data_instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
static constexpr auto ConvBwdDataDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Default;
|
||||
|
||||
static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
|
||||
using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
#if 1
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
#endif
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_f16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(
|
||||
std::vector<DeviceConvBwdDataPtr<PassThrough, PassThrough, PassThrough>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances{});
|
||||
add_device_operation_instances(
|
||||
instances, device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_f16_instances{});
|
||||
}
|
||||
|
||||
} // namespace device_conv2d_bwd_data_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,83 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_conv2d_bwd_data_instance {
|
||||
|
||||
using F32 = float;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
static constexpr auto ConvBwdDataDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Default;
|
||||
|
||||
static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
|
||||
using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_f32_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(
|
||||
std::vector<DeviceConvBwdDataPtr<PassThrough, PassThrough, PassThrough>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances{});
|
||||
add_device_operation_instances(
|
||||
instances, device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_f32_instances{});
|
||||
}
|
||||
|
||||
} // namespace device_conv2d_bwd_data_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,86 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_conv2d_bwd_data_instance {
|
||||
|
||||
using DataType = int8_t;
|
||||
using AccType = int32_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
static constexpr auto ConvBwdDataDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Default;
|
||||
|
||||
static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
|
||||
using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
#if 1
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
|
||||
#endif
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_int8_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//##############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//##############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//##############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
|
||||
std::vector<DeviceConvBwdDataPtr<PassThrough, PassThrough, PassThrough>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances{});
|
||||
add_device_operation_instances(
|
||||
instances, device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_int8_instances{});
|
||||
}
|
||||
|
||||
} // namespace device_conv2d_bwd_data_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,84 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_conv2d_bwd_data_instance {
|
||||
|
||||
using BF16 = ushort;
|
||||
using F32 = float;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
static constexpr auto ConvBwdDataDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Default;
|
||||
|
||||
static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
|
||||
using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
|
||||
std::vector<DeviceConvBwdDataPtr<PassThrough, PassThrough, PassThrough>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances{});
|
||||
add_device_operation_instances(
|
||||
instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances{});
|
||||
}
|
||||
|
||||
} // namespace device_conv2d_bwd_data_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,86 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_conv2d_bwd_data_instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
static constexpr auto ConvBwdDataDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Default;
|
||||
|
||||
static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
|
||||
using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
#if 1
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
|
||||
#endif
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(
|
||||
std::vector<DeviceConvBwdDataPtr<PassThrough, PassThrough, PassThrough>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances{});
|
||||
add_device_operation_instances(
|
||||
instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{});
|
||||
}
|
||||
|
||||
} // namespace device_conv2d_bwd_data_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,83 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_conv2d_bwd_data_instance {
|
||||
|
||||
using F32 = float;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
static constexpr auto ConvBwdDataDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Default;
|
||||
|
||||
static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
|
||||
using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
|
||||
std::vector<DeviceConvBwdDataPtr<PassThrough, PassThrough, PassThrough>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances{});
|
||||
add_device_operation_instances(
|
||||
instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{});
|
||||
}
|
||||
|
||||
} // namespace device_conv2d_bwd_data_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,88 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_conv2d_bwd_data_instance {
|
||||
|
||||
using DataType = int8_t;
|
||||
using AccType = int32_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
static constexpr auto ConvBwdDataDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Default;
|
||||
|
||||
static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
|
||||
using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
#if 1
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
#endif
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//##############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//##############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//##############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
#if 1
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
#endif
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
|
||||
std::vector<DeviceConvBwdDataPtr<PassThrough, PassThrough, PassThrough>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances{});
|
||||
add_device_operation_instances(
|
||||
instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances{});
|
||||
}
|
||||
|
||||
} // namespace device_conv2d_bwd_data_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,84 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_conv2d_bwd_data_instance {
|
||||
|
||||
using BF16 = ushort;
|
||||
using F32 = float;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
static constexpr auto ConvBwdDataDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Default;
|
||||
|
||||
static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
|
||||
using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_bf16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | ./ | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
|
||||
std::vector<DeviceConvBwdDataPtr<PassThrough, PassThrough, PassThrough>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances{});
|
||||
add_device_operation_instances(
|
||||
instances, device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_bf16_instances{});
|
||||
}
|
||||
|
||||
} // namespace device_conv2d_bwd_data_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,86 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_conv2d_bwd_data_instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
static constexpr auto ConvBwdDataDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Default;
|
||||
|
||||
static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
|
||||
using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
#if 1
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
#endif
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
|
||||
std::vector<DeviceConvBwdDataPtr<PassThrough, PassThrough, PassThrough>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances{});
|
||||
add_device_operation_instances(
|
||||
instances, device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f16_instances{});
|
||||
}
|
||||
|
||||
} // namespace device_conv2d_bwd_data_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,83 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_conv2d_bwd_data_instance {
|
||||
|
||||
using F32 = float;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
static constexpr auto ConvBwdDataDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Default;
|
||||
|
||||
static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
|
||||
using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f32_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
|
||||
std::vector<DeviceConvBwdDataPtr<PassThrough, PassThrough, PassThrough>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances{});
|
||||
add_device_operation_instances(
|
||||
instances, device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f32_instances{});
|
||||
}
|
||||
|
||||
} // namespace device_conv2d_bwd_data_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,86 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_conv2d_bwd_data_instance {
|
||||
|
||||
using DataType = int8_t;
|
||||
using AccType = int32_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
static constexpr auto ConvBwdDataDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Default;
|
||||
|
||||
static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
|
||||
using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
#if 1
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
|
||||
#endif
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_int8_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//##############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//##############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//##############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
|
||||
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
|
||||
std::vector<DeviceConvBwdDataPtr<PassThrough, PassThrough, PassThrough>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances{});
|
||||
add_device_operation_instances(
|
||||
instances, device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_int8_instances{});
|
||||
}
|
||||
|
||||
} // namespace device_conv2d_bwd_data_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -32,7 +32,7 @@ set(PROFILER_SOURCE
|
||||
src/profile_conv_fwd_bias_relu.cpp
|
||||
src/profile_conv_fwd_bias_relu_add.cpp
|
||||
src/profile_conv_fwd_bias_relu_atomic_add.cpp
|
||||
src/profile_conv_bwd_data.cpp
|
||||
src/profile_convnd_bwd_data.cpp
|
||||
src/profile_reduce.cpp
|
||||
src/profile_grouped_gemm.cpp
|
||||
)
|
||||
@@ -50,7 +50,7 @@ target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_data_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
|
||||
|
||||
@@ -42,6 +42,7 @@ template <int NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
@@ -123,6 +124,7 @@ void profile_conv_bwd_data_impl(int do_verification,
|
||||
ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>;
|
||||
|
||||
514
profiler/include/profile_convnd_bwd_data_impl.hpp
Normal file
514
profiler/include/profile_convnd_bwd_data_impl.hpp
Normal file
@@ -0,0 +1,514 @@
|
||||
#pragma once
|
||||
#include "config.hpp"
|
||||
#include "device.hpp"
|
||||
#include "conv_utils.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_conv_bwd_data.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "reference_conv_bwd_data.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using BF16 = ushort;
|
||||
using INT8 = int8_t;
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_conv2d_bwd_data_instance {
|
||||
|
||||
using DeviceConvBwdDataNoOpPtr =
|
||||
DeviceConvBwdDataPtr<ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
|
||||
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
|
||||
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
} // namespace device_conv2d_bwd_data_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
using DeviceConvBwdDataNoOpPtr =
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::DeviceConvBwdDataNoOpPtr;
|
||||
|
||||
template <typename InLayout>
|
||||
HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2)
|
||||
{
|
||||
namespace tl = ck::tensor_layout::convolution;
|
||||
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 3: {
|
||||
return ck::conv_util::GetHostTensorDescriptor(dims, InLayout{});
|
||||
}
|
||||
case 2: {
|
||||
return ck::conv_util::GetHostTensorDescriptor(dims, InLayout{});
|
||||
}
|
||||
case 1: {
|
||||
return ck::conv_util::GetHostTensorDescriptor(dims, InLayout{});
|
||||
}
|
||||
default: {
|
||||
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||
}
|
||||
}
|
||||
}
|
||||
template <typename WeiLayout>
|
||||
HostTensorDescriptor get_filters_host_tensor_descriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2)
|
||||
{
|
||||
namespace tl = ck::tensor_layout::convolution;
|
||||
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 3: {
|
||||
return ck::conv_util::GetHostTensorDescriptor(dims, WeiLayout{});
|
||||
}
|
||||
case 2: {
|
||||
return ck::conv_util::GetHostTensorDescriptor(dims, WeiLayout{});
|
||||
}
|
||||
case 1: {
|
||||
return ck::conv_util::GetHostTensorDescriptor(dims, WeiLayout{});
|
||||
}
|
||||
default: {
|
||||
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||
}
|
||||
}
|
||||
}
|
||||
template <typename OutLayout>
|
||||
HostTensorDescriptor get_output_host_ensor_descriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2)
|
||||
{
|
||||
namespace tl = ck::tensor_layout::convolution;
|
||||
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 3: {
|
||||
return ck::conv_util::GetHostTensorDescriptor(dims, OutLayout{});
|
||||
}
|
||||
case 2: {
|
||||
return ck::conv_util::GetHostTensorDescriptor(dims, OutLayout{});
|
||||
}
|
||||
case 1: {
|
||||
return ck::conv_util::GetHostTensorDescriptor(dims, OutLayout{});
|
||||
}
|
||||
|
||||
default: {
|
||||
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||
}
|
||||
}
|
||||
}
|
||||
template <typename InDataType, typename WeiDataType, typename OutDataType>
|
||||
void get_device_conv_bwd_data_op_ptr(
|
||||
InDataType, WeiDataType, OutDataType, std::vector<DeviceConvBwdDataNoOpPtr>&, int)
|
||||
{
|
||||
std::cout << "can not find device conv bwd data" << std::endl;
|
||||
exit(1);
|
||||
}
|
||||
template <>
|
||||
void get_device_conv_bwd_data_op_ptr(
|
||||
F32, F32, F32, std::vector<DeviceConvBwdDataNoOpPtr>& conv_ptrs, int num_dim_spatial)
|
||||
{
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 1:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs);
|
||||
break;
|
||||
case 2:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
|
||||
break;
|
||||
case 3:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs);
|
||||
break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
template <>
|
||||
void get_device_conv_bwd_data_op_ptr(
|
||||
F16, F16, F16, std::vector<DeviceConvBwdDataNoOpPtr>& conv_ptrs, int num_dim_spatial)
|
||||
{
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 1:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs);
|
||||
break;
|
||||
case 2:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
|
||||
break;
|
||||
case 3:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs);
|
||||
break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
template <>
|
||||
void get_device_conv_bwd_data_op_ptr(
|
||||
BF16, BF16, BF16, std::vector<DeviceConvBwdDataNoOpPtr>& conv_ptrs, int num_dim_spatial)
|
||||
{
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 1:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs);
|
||||
break;
|
||||
case 2:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
|
||||
break;
|
||||
case 3:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs);
|
||||
break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
template <>
|
||||
void get_device_conv_bwd_data_op_ptr(
|
||||
INT8, INT8, INT8, std::vector<DeviceConvBwdDataNoOpPtr>& conv_ptrs, int num_dim_spatial)
|
||||
{
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 1:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs);
|
||||
break;
|
||||
case 2:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs);
|
||||
break;
|
||||
case 3:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs);
|
||||
break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
|
||||
{
|
||||
float max_diff = 1e-6;
|
||||
|
||||
for(int i = 0; i < ref.mData.size(); ++i)
|
||||
{
|
||||
float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
|
||||
if(max_diff < diff)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
template <typename DataType>
|
||||
void show_data_nhwc_layout(Tensor<DataType>& nhwc)
|
||||
{
|
||||
std::cout << "[";
|
||||
for(int n = 0; n < nhwc.mDesc.GetLengths()[0]; n++)
|
||||
{
|
||||
std::cout << "[";
|
||||
for(int hi = 0; hi < nhwc.mDesc.GetLengths()[2]; hi++)
|
||||
{
|
||||
std::cout << "[";
|
||||
for(int wi = 0; wi < nhwc.mDesc.GetLengths()[3]; wi++)
|
||||
{
|
||||
std::cout << "[";
|
||||
for(int c = 0; c < nhwc.mDesc.GetLengths()[1]; c++)
|
||||
{
|
||||
std::cout << static_cast<float>(nhwc(n, c, hi, wi)) << " ";
|
||||
}
|
||||
std::cout << "]";
|
||||
}
|
||||
std::cout << "]";
|
||||
}
|
||||
std::cout << "]";
|
||||
}
|
||||
std::cout << "]";
|
||||
}
|
||||
|
||||
template <int NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
bool profile_convnd_bwd_data_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
int nrepeat,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads)
|
||||
{
|
||||
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
const auto in_element_op = InElementOp{};
|
||||
const auto wei_element_op = WeiElementOp{};
|
||||
const auto out_element_op = OutElementOp{};
|
||||
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(N), static_cast<std::size_t>(C)};
|
||||
input_dims.insert(
|
||||
std::end(input_dims), std::begin(input_spatial_lengths), std::end(input_spatial_lengths));
|
||||
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(K), static_cast<std::size_t>(C)};
|
||||
filter_dims.insert(std::end(filter_dims),
|
||||
std::begin(filter_spatial_lengths),
|
||||
std::end(filter_spatial_lengths));
|
||||
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(N), static_cast<std::size_t>(K)};
|
||||
output_dims.insert(std::end(output_dims),
|
||||
std::begin(output_spatial_lengths),
|
||||
std::end(output_spatial_lengths));
|
||||
|
||||
Tensor<InDataType> in_n_c_hi_wi_host_result(
|
||||
get_input_host_tensor_descriptor<InLayout>(input_dims, NDimSpatial));
|
||||
Tensor<InDataType> in_n_c_hi_wi_device_result(
|
||||
get_input_host_tensor_descriptor<InLayout>(input_dims, NDimSpatial));
|
||||
Tensor<WeiDataType> wei_k_c_y_x(
|
||||
get_filters_host_tensor_descriptor<WeiLayout>(filter_dims, NDimSpatial));
|
||||
Tensor<OutDataType> out_n_k_ho_wo(
|
||||
get_output_host_ensor_descriptor<OutLayout>(output_dims, NDimSpatial));
|
||||
|
||||
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi_host_result.mDesc << std::endl;
|
||||
std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl;
|
||||
std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
|
||||
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
|
||||
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
|
||||
}
|
||||
|
||||
DeviceMem in_device_buf(sizeof(InDataType) *
|
||||
in_n_c_hi_wi_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||
DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
||||
|
||||
out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
|
||||
// reset input to zero
|
||||
in_n_c_hi_wi_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{0});
|
||||
in_device_buf.ToDevice(in_n_c_hi_wi_device_result.mData.data());
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
auto RunReference = [&](auto& ref_conv) {
|
||||
auto ref_invoker = ref_conv.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi_host_result,
|
||||
wei_k_c_y_x,
|
||||
out_n_k_ho_wo,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
ref_invoker.Run(ref_argument);
|
||||
};
|
||||
switch(NDimSpatial)
|
||||
{
|
||||
case 3: {
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
3>();
|
||||
RunReference(ref_conv);
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
2>();
|
||||
RunReference(ref_conv);
|
||||
break;
|
||||
}
|
||||
case 1: {
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
1>();
|
||||
RunReference(ref_conv);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// add device Conv instances
|
||||
std::vector<DeviceConvBwdDataNoOpPtr> conv_ptrs;
|
||||
get_device_conv_bwd_data_op_ptr(
|
||||
InDataType{}, WeiDataType{}, OutDataType{}, conv_ptrs, NDimSpatial);
|
||||
|
||||
if(conv_ptrs.size() <= 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! no device Conv instance found");
|
||||
}
|
||||
|
||||
std::string best_conv_name;
|
||||
float best_ave_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
|
||||
// profile device Conv instances
|
||||
bool success = true;
|
||||
for(auto& conv_ptr : conv_ptrs)
|
||||
{
|
||||
auto argument_ptr = conv_ptr->MakeArgumentPointer(
|
||||
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
|
||||
auto invoker_ptr = conv_ptr->MakeInvokerPointer();
|
||||
|
||||
if(conv_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
std::string conv_name = conv_ptr->GetTypeString();
|
||||
|
||||
float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat);
|
||||
|
||||
std::size_t flop =
|
||||
ck::conv_util::GetFlops(N, C, K, filter_spatial_lengths, output_spatial_lengths);
|
||||
std::size_t num_btype = ck::conv_util::GetBtype<InDataType, WeiDataType, OutDataType>(
|
||||
N, C, K, input_spatial_lengths, filter_spatial_lengths, output_spatial_lengths);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s" << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
best_conv_name = conv_name;
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data());
|
||||
|
||||
if(!check_out(in_n_c_hi_wi_host_result, in_n_c_hi_wi_device_result))
|
||||
{
|
||||
std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl;
|
||||
|
||||
success = false;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Pass Info: " << conv_ptr->GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
check_error(in_n_c_hi_wi_host_result, in_n_c_hi_wi_device_result);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
std::cout << "in : ";
|
||||
show_data_nhwc_layout(out_n_k_ho_wo);
|
||||
std::cout << std::endl;
|
||||
|
||||
std::cout << "wei: ";
|
||||
show_data_nhwc_layout(wei_k_c_y_x);
|
||||
std::cout << std::endl;
|
||||
|
||||
std::cout << "out_host : ";
|
||||
show_data_nhwc_layout(in_n_c_hi_wi_host_result);
|
||||
std::cout << std::endl;
|
||||
|
||||
std::cout << "out_device: ";
|
||||
show_data_nhwc_layout(in_n_c_hi_wi_device_result);
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
|
||||
<< best_gb_per_sec << " GB/s, " << best_conv_name << std::endl;
|
||||
return success;
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace ck
|
||||
@@ -89,6 +89,7 @@ int profile_conv_bwd_data(int argc, char* argv[])
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
ck::tensor_layout::convolution::NHWC,
|
||||
ck::tensor_layout::convolution::KYXC,
|
||||
ck::tensor_layout::convolution::NHWK>(
|
||||
@@ -114,6 +115,7 @@ int profile_conv_bwd_data(int argc, char* argv[])
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
float,
|
||||
ck::tensor_layout::convolution::NHWC,
|
||||
ck::tensor_layout::convolution::KYXC,
|
||||
ck::tensor_layout::convolution::NHWK>(
|
||||
@@ -139,6 +141,7 @@ int profile_conv_bwd_data(int argc, char* argv[])
|
||||
uint16_t,
|
||||
uint16_t,
|
||||
uint16_t,
|
||||
float,
|
||||
ck::tensor_layout::convolution::NHWC,
|
||||
ck::tensor_layout::convolution::KYXC,
|
||||
ck::tensor_layout::convolution::NHWK>(
|
||||
@@ -164,6 +167,7 @@ int profile_conv_bwd_data(int argc, char* argv[])
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int32_t,
|
||||
ck::tensor_layout::convolution::NHWC,
|
||||
ck::tensor_layout::convolution::KYXC,
|
||||
ck::tensor_layout::convolution::NHWK>(
|
||||
|
||||
224
profiler/src/profile_convnd_bwd_data.cpp
Normal file
224
profiler/src/profile_convnd_bwd_data.cpp
Normal file
@@ -0,0 +1,224 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
|
||||
#include "profile_convnd_bwd_data_impl.hpp"
|
||||
|
||||
enum ConvDataType
|
||||
{
|
||||
F32_F32_F32, // 0
|
||||
F16_F16_F16, // 1
|
||||
BF16_BF16_BF16, // 2
|
||||
INT8_INT8_INT8, // 3
|
||||
};
|
||||
|
||||
enum ConvInputLayout
|
||||
{
|
||||
NCHW, // 0
|
||||
NHWC, // 1
|
||||
};
|
||||
|
||||
enum ConvWeightLayout
|
||||
{
|
||||
KCYX, // 0
|
||||
KYXC, // 1
|
||||
};
|
||||
|
||||
enum ConvOutputLayout
|
||||
{
|
||||
NKHW, // 0
|
||||
NHWK, // 1
|
||||
};
|
||||
ck::conv_util::ConvParams parse_conv_params(int num_dim_spatial, char* argv[], int arg_idx)
|
||||
{
|
||||
// (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right)
|
||||
ck::conv_util::ConvParams params;
|
||||
|
||||
params.num_dim_spatial = num_dim_spatial;
|
||||
params.N = std::stoi(argv[arg_idx++]);
|
||||
params.K = std::stoi(argv[arg_idx++]);
|
||||
params.C = std::stoi(argv[arg_idx++]);
|
||||
|
||||
params.filter_spatial_lengths.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_spatial_lengths.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.conv_filter_strides.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.conv_filter_dilations.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_left_pads.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_left_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_right_pads.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_right_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial)
|
||||
{
|
||||
const int preParams = 10;
|
||||
int conv_args = 3 + num_dim_spatial * 6;
|
||||
int cmdline_nargs = conv_args + preParams;
|
||||
if(cmdline_nargs != argc)
|
||||
{
|
||||
printf("arg1: tensor operation (conv[1|2|3]d_bwd_data: BackwardConvolution)\n");
|
||||
printf("arg2: data type (0: fp32; 1: fp16)\n");
|
||||
printf("arg3: input tensor layout (0: NCHW; 1: NHWC)\n");
|
||||
printf("arg4: weight tensor layout (0: KCYX; 1: KYXC)\n");
|
||||
printf("arg5: output tensor layout (0: NKHW; 1: NHWK)\n");
|
||||
printf("arg6: verification (0: no; 1: yes)\n");
|
||||
printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n");
|
||||
printf("arg8: print tensor value (0: no; 1: yes)\n");
|
||||
printf("arg9: run kernel # of times (>1)\n");
|
||||
printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
|
||||
"RightPx\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
const int data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
|
||||
const int in_layout = static_cast<ConvInputLayout>(std::stoi(argv[3]));
|
||||
const int wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
|
||||
const int out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
|
||||
const bool do_verification = std::stoi(argv[6]);
|
||||
const int init_method = std::stoi(argv[7]);
|
||||
const bool do_log = std::stoi(argv[8]);
|
||||
const int nrepeat = std::stoi(argv[9]);
|
||||
|
||||
ck::conv_util::ConvParams params = parse_conv_params(num_dim_spatial, argv, preParams);
|
||||
|
||||
auto Run = [&](auto input_type, auto wei_type, auto out_type, auto acc_type) {
|
||||
using InDataType = decltype(input_type);
|
||||
using WeiDataType = decltype(wei_type);
|
||||
using OutDataType = decltype(out_type);
|
||||
using AccDataType = decltype(acc_type);
|
||||
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 1:
|
||||
ck::profiler::profile_convnd_bwd_data_impl<1,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
ck::tensor_layout::convolution::NWC,
|
||||
ck::tensor_layout::convolution::KXC,
|
||||
ck::tensor_layout::convolution::NWK>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
nrepeat,
|
||||
params.N,
|
||||
params.K,
|
||||
params.C,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
params.GetOutputSpatialLengths(),
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads);
|
||||
break;
|
||||
|
||||
case 2:
|
||||
ck::profiler::profile_convnd_bwd_data_impl<2,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
ck::tensor_layout::convolution::NHWC,
|
||||
ck::tensor_layout::convolution::KYXC,
|
||||
ck::tensor_layout::convolution::NHWK>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
nrepeat,
|
||||
params.N,
|
||||
params.K,
|
||||
params.C,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
params.GetOutputSpatialLengths(),
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads);
|
||||
break;
|
||||
|
||||
case 3:
|
||||
ck::profiler::profile_convnd_bwd_data_impl<3,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
ck::tensor_layout::convolution::NDHWC,
|
||||
ck::tensor_layout::convolution::KZYXC,
|
||||
ck::tensor_layout::convolution::NDHWK>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
nrepeat,
|
||||
params.N,
|
||||
params.K,
|
||||
params.C,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
params.GetOutputSpatialLengths(),
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads);
|
||||
break;
|
||||
|
||||
default: break;
|
||||
}
|
||||
};
|
||||
if(data_type == ConvDataType::F32_F32_F32 && in_layout == ConvInputLayout::NHWC &&
|
||||
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
|
||||
{
|
||||
Run(float{}, float{}, float{}, float{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC &&
|
||||
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
|
||||
{
|
||||
Run(ck::half_t{}, ck::half_t{}, ck::half_t{}, float{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16 && in_layout == ConvInputLayout::NHWC &&
|
||||
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
|
||||
{
|
||||
Run(ck::bhalf_t{}, ck::bhalf_t{}, ck::bhalf_t{}, float{});
|
||||
}
|
||||
else if(data_type == ConvDataType::INT8_INT8_INT8 && in_layout == ConvInputLayout::NHWC &&
|
||||
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
|
||||
{
|
||||
Run(int8_t{}, int8_t{}, int8_t{}, int32_t{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "wrong! this Conv data_type & layout is not implemented" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -15,7 +15,7 @@ int profile_conv_fwd(int, char*[]);
|
||||
int profile_conv_fwd_bias_relu(int, char*[]);
|
||||
int profile_conv_fwd_bias_relu_add(int, char*[]);
|
||||
int profile_conv_fwd_bias_relu_atomic_add(int, char*[]);
|
||||
int profile_conv_bwd_data(int, char*[]);
|
||||
int profile_convnd_bwd_data(int, char*[], int);
|
||||
int profile_reduce(int, char*[]);
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
@@ -64,9 +64,17 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
return profile_conv_fwd_bias_relu_atomic_add(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "conv_bwd") == 0)
|
||||
else if(strcmp(argv[1], "conv1d_bwd_data") == 0)
|
||||
{
|
||||
return profile_conv_bwd_data(argc, argv);
|
||||
return profile_convnd_bwd_data(argc, argv, 1);
|
||||
}
|
||||
else if(strcmp(argv[1], "conv2d_bwd_data") == 0)
|
||||
{
|
||||
return profile_convnd_bwd_data(argc, argv, 2);
|
||||
}
|
||||
else if(strcmp(argv[1], "conv3d_bwd_data") == 0)
|
||||
{
|
||||
return profile_convnd_bwd_data(argc, argv, 3);
|
||||
}
|
||||
else if(strcmp(argv[1], "reduce") == 0)
|
||||
{
|
||||
@@ -85,8 +93,11 @@ int main(int argc, char* argv[])
|
||||
" conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU\n"
|
||||
" conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add\n"
|
||||
" conv_fwd_bias_relu_atomic_add: ForwardConvolution+Bias+ReLU+AtomicAdd\n"
|
||||
" conv_bwd: BackwardConvolution\n"
|
||||
" reduce: Reduce\n");
|
||||
" conv1d_bwd_data: BackwardConvolution data 1 dim\n"
|
||||
" conv2d_bwd_data: BackwardConvolution data 2 dim\n"
|
||||
" conv3d_bwd_data: BackwardConvolution data 3 dim\n"
|
||||
" grouped_gemm: Grouped Gemm\n"
|
||||
" reduce: REDUCE\n");
|
||||
// clang-format on
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -41,5 +41,4 @@ add_subdirectory(gemm_reduce)
|
||||
add_subdirectory(batched_gemm)
|
||||
add_subdirectory(grouped_gemm)
|
||||
add_subdirectory(convnd_fwd)
|
||||
add_subdirectory(conv2d_bwd_data)
|
||||
add_subdirectory(reduce)
|
||||
|
||||
@@ -121,15 +121,17 @@ int main(int argc, char* argv[])
|
||||
exit(1);
|
||||
}
|
||||
|
||||
auto Run = [&](auto input_type, auto wei_type, auto out_type) {
|
||||
auto Run = [&](auto input_type, auto wei_type, auto out_type, auto acc_type) {
|
||||
using InDataType = decltype(input_type);
|
||||
using WeiDataType = decltype(wei_type);
|
||||
using OutDataType = decltype(out_type);
|
||||
using AccDataType = decltype(acc_type);
|
||||
|
||||
using ReferenceConvBwdInstance =
|
||||
ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>;
|
||||
@@ -293,33 +295,33 @@ int main(int argc, char* argv[])
|
||||
if(success)
|
||||
{
|
||||
std::cout << "test conv2d bwd : Pass" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "test conv2d bwd: Fail " << std::endl;
|
||||
return -1;
|
||||
}
|
||||
};
|
||||
|
||||
if(data_type == 0)
|
||||
{
|
||||
Run(F32(), F32(), F32());
|
||||
return Run(F32(), F32(), F32(), F32());
|
||||
}
|
||||
else if(data_type == 1)
|
||||
{
|
||||
Run(F16(), F16(), F16());
|
||||
return Run(F16(), F16(), F16(), F32());
|
||||
}
|
||||
else if(data_type == 2)
|
||||
{
|
||||
Run(BF16(), BF16(), BF16());
|
||||
return Run(BF16(), BF16(), BF16(), F32());
|
||||
}
|
||||
else if(data_type == 3)
|
||||
{
|
||||
Run(INT8(), INT8(), INT8());
|
||||
return Run(INT8(), INT8(), INT8(), int());
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
8
test/convnd_bwd_data/CMakeLists.txt
Normal file
8
test/convnd_bwd_data/CMakeLists.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
include_directories(BEFORE
|
||||
${PROJECT_SOURCE_DIR}/profiler/include
|
||||
${PROJECT_SOURCE_DIR}/external/include/half
|
||||
)
|
||||
|
||||
add_test_executable(test_convnd_bwd_data convnd_bwd_data.cpp)
|
||||
target_link_libraries(test_convnd_bwd_data PRIVATE host_tensor)
|
||||
target_link_libraries(test_convnd_bwd_data PRIVATE device_convnd_bwd_data_instance)
|
||||
330
test/convnd_bwd_data/convnd_bwd_data.cpp
Normal file
330
test/convnd_bwd_data/convnd_bwd_data.cpp
Normal file
@@ -0,0 +1,330 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
#include <vector>
|
||||
|
||||
#include "profile_convnd_bwd_data_impl.hpp"
|
||||
|
||||
int main()
|
||||
{
|
||||
bool pass = true;
|
||||
// check 1d
|
||||
std::vector<ck::conv_util::ConvParams> params;
|
||||
params.push_back({1, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}});
|
||||
params.push_back({1, 128, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}});
|
||||
params.push_back({1, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}});
|
||||
|
||||
for(auto& param : params)
|
||||
{
|
||||
pass &= ck::profiler::profile_convnd_bwd_data_impl<1,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
ck::tensor_layout::convolution::NWC,
|
||||
ck::tensor_layout::convolution::KXC,
|
||||
ck::tensor_layout::convolution::NWK>(
|
||||
1, // do_verification,
|
||||
1, // init_method,
|
||||
0, // do_log,
|
||||
1, // nrepeat,
|
||||
param.N,
|
||||
param.K,
|
||||
param.C,
|
||||
param.input_spatial_lengths,
|
||||
param.filter_spatial_lengths,
|
||||
param.GetOutputSpatialLengths(),
|
||||
param.conv_filter_strides,
|
||||
param.conv_filter_dilations,
|
||||
param.input_left_pads,
|
||||
param.input_right_pads);
|
||||
|
||||
pass &= ck::profiler::profile_convnd_bwd_data_impl<1,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
float,
|
||||
ck::tensor_layout::convolution::NWC,
|
||||
ck::tensor_layout::convolution::KXC,
|
||||
ck::tensor_layout::convolution::NWK>(
|
||||
1, // do_verification,
|
||||
1, // init_method,
|
||||
0, // do_log,
|
||||
1, // nrepeat,
|
||||
param.N,
|
||||
param.K,
|
||||
param.C,
|
||||
param.input_spatial_lengths,
|
||||
param.filter_spatial_lengths,
|
||||
param.GetOutputSpatialLengths(),
|
||||
param.conv_filter_strides,
|
||||
param.conv_filter_dilations,
|
||||
param.input_left_pads,
|
||||
param.input_right_pads);
|
||||
|
||||
pass &= ck::profiler::profile_convnd_bwd_data_impl<1,
|
||||
ck::bhalf_t,
|
||||
ck::bhalf_t,
|
||||
ck::bhalf_t,
|
||||
float,
|
||||
ck::tensor_layout::convolution::NWC,
|
||||
ck::tensor_layout::convolution::KXC,
|
||||
ck::tensor_layout::convolution::NWK>(
|
||||
1, // do_verification,
|
||||
1, // init_method,
|
||||
0, // do_log,
|
||||
1, // nrepeat,
|
||||
param.N,
|
||||
param.K,
|
||||
param.C,
|
||||
param.input_spatial_lengths,
|
||||
param.filter_spatial_lengths,
|
||||
param.GetOutputSpatialLengths(),
|
||||
param.conv_filter_strides,
|
||||
param.conv_filter_dilations,
|
||||
param.input_left_pads,
|
||||
param.input_right_pads);
|
||||
|
||||
pass &= ck::profiler::profile_convnd_bwd_data_impl<1,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int,
|
||||
ck::tensor_layout::convolution::NWC,
|
||||
ck::tensor_layout::convolution::KXC,
|
||||
ck::tensor_layout::convolution::NWK>(
|
||||
1, // do_verification,
|
||||
1, // init_method,
|
||||
0, // do_log,
|
||||
1, // nrepeat,
|
||||
param.N,
|
||||
param.K,
|
||||
param.C,
|
||||
param.input_spatial_lengths,
|
||||
param.filter_spatial_lengths,
|
||||
param.GetOutputSpatialLengths(),
|
||||
param.conv_filter_strides,
|
||||
param.conv_filter_dilations,
|
||||
param.input_left_pads,
|
||||
param.input_right_pads);
|
||||
}
|
||||
|
||||
// check 2d
|
||||
params.clear();
|
||||
params.push_back({2, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
|
||||
params.push_back({2, 128, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
params.push_back({2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
|
||||
|
||||
for(auto& param : params)
|
||||
{
|
||||
pass &= ck::profiler::profile_convnd_bwd_data_impl<2,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
ck::tensor_layout::convolution::NHWC,
|
||||
ck::tensor_layout::convolution::KYXC,
|
||||
ck::tensor_layout::convolution::NHWK>(
|
||||
1, // do_verification,
|
||||
1, // init_method,
|
||||
0, // do_log,
|
||||
1, // nrepeat,
|
||||
param.N,
|
||||
param.K,
|
||||
param.C,
|
||||
param.input_spatial_lengths,
|
||||
param.filter_spatial_lengths,
|
||||
param.GetOutputSpatialLengths(),
|
||||
param.conv_filter_strides,
|
||||
param.conv_filter_dilations,
|
||||
param.input_left_pads,
|
||||
param.input_right_pads);
|
||||
|
||||
pass &= ck::profiler::profile_convnd_bwd_data_impl<2,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
float,
|
||||
ck::tensor_layout::convolution::NHWC,
|
||||
ck::tensor_layout::convolution::KYXC,
|
||||
ck::tensor_layout::convolution::NHWK>(
|
||||
1, // do_verification,
|
||||
1, // init_method,
|
||||
0, // do_log,
|
||||
1, // nrepeat,
|
||||
param.N,
|
||||
param.K,
|
||||
param.C,
|
||||
param.input_spatial_lengths,
|
||||
param.filter_spatial_lengths,
|
||||
param.GetOutputSpatialLengths(),
|
||||
param.conv_filter_strides,
|
||||
param.conv_filter_dilations,
|
||||
param.input_left_pads,
|
||||
param.input_right_pads);
|
||||
|
||||
pass &= ck::profiler::profile_convnd_bwd_data_impl<2,
|
||||
ck::bhalf_t,
|
||||
ck::bhalf_t,
|
||||
ck::bhalf_t,
|
||||
float,
|
||||
ck::tensor_layout::convolution::NHWC,
|
||||
ck::tensor_layout::convolution::KYXC,
|
||||
ck::tensor_layout::convolution::NHWK>(
|
||||
1, // do_verification,
|
||||
1, // init_method,
|
||||
0, // do_log,
|
||||
1, // nrepeat,
|
||||
param.N,
|
||||
param.K,
|
||||
param.C,
|
||||
param.input_spatial_lengths,
|
||||
param.filter_spatial_lengths,
|
||||
param.GetOutputSpatialLengths(),
|
||||
param.conv_filter_strides,
|
||||
param.conv_filter_dilations,
|
||||
param.input_left_pads,
|
||||
param.input_right_pads);
|
||||
|
||||
pass &= ck::profiler::profile_convnd_bwd_data_impl<2,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int,
|
||||
ck::tensor_layout::convolution::NHWC,
|
||||
ck::tensor_layout::convolution::KYXC,
|
||||
ck::tensor_layout::convolution::NHWK>(
|
||||
1, // do_verification,
|
||||
1, // init_method,
|
||||
0, // do_log,
|
||||
1, // nrepeat,
|
||||
param.N,
|
||||
param.K,
|
||||
param.C,
|
||||
param.input_spatial_lengths,
|
||||
param.filter_spatial_lengths,
|
||||
param.GetOutputSpatialLengths(),
|
||||
param.conv_filter_strides,
|
||||
param.conv_filter_dilations,
|
||||
param.input_left_pads,
|
||||
param.input_right_pads);
|
||||
}
|
||||
|
||||
// check 3d
|
||||
params.clear();
|
||||
params.push_back(
|
||||
{3, 128, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
|
||||
params.push_back(
|
||||
{3, 128, 128, 256, {3, 3, 3}, {14, 14, 14}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
params.push_back(
|
||||
{3, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
|
||||
|
||||
for(auto& param : params)
|
||||
{
|
||||
pass &= ck::profiler::profile_convnd_bwd_data_impl<3,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
ck::tensor_layout::convolution::NDHWC,
|
||||
ck::tensor_layout::convolution::KZYXC,
|
||||
ck::tensor_layout::convolution::NDHWK>(
|
||||
1, // do_verification,
|
||||
1, // init_method,
|
||||
0, // do_log,
|
||||
1, // nrepeat,
|
||||
param.N,
|
||||
param.K,
|
||||
param.C,
|
||||
param.input_spatial_lengths,
|
||||
param.filter_spatial_lengths,
|
||||
param.GetOutputSpatialLengths(),
|
||||
param.conv_filter_strides,
|
||||
param.conv_filter_dilations,
|
||||
param.input_left_pads,
|
||||
param.input_right_pads);
|
||||
|
||||
pass &= ck::profiler::profile_convnd_bwd_data_impl<3,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
float,
|
||||
ck::tensor_layout::convolution::NDHWC,
|
||||
ck::tensor_layout::convolution::KZYXC,
|
||||
ck::tensor_layout::convolution::NDHWK>(
|
||||
1, // do_verification,
|
||||
1, // init_method,
|
||||
0, // do_log,
|
||||
1, // nrepeat,
|
||||
param.N,
|
||||
param.K,
|
||||
param.C,
|
||||
param.input_spatial_lengths,
|
||||
param.filter_spatial_lengths,
|
||||
param.GetOutputSpatialLengths(),
|
||||
param.conv_filter_strides,
|
||||
param.conv_filter_dilations,
|
||||
param.input_left_pads,
|
||||
param.input_right_pads);
|
||||
|
||||
pass &= ck::profiler::profile_convnd_bwd_data_impl<3,
|
||||
ck::bhalf_t,
|
||||
ck::bhalf_t,
|
||||
ck::bhalf_t,
|
||||
float,
|
||||
ck::tensor_layout::convolution::NDHWC,
|
||||
ck::tensor_layout::convolution::KZYXC,
|
||||
ck::tensor_layout::convolution::NDHWK>(
|
||||
1, // do_verification,
|
||||
1, // init_method,
|
||||
0, // do_log,
|
||||
1, // nrepeat,
|
||||
param.N,
|
||||
param.K,
|
||||
param.C,
|
||||
param.input_spatial_lengths,
|
||||
param.filter_spatial_lengths,
|
||||
param.GetOutputSpatialLengths(),
|
||||
param.conv_filter_strides,
|
||||
param.conv_filter_dilations,
|
||||
param.input_left_pads,
|
||||
param.input_right_pads);
|
||||
|
||||
pass &= ck::profiler::profile_convnd_bwd_data_impl<3,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int,
|
||||
ck::tensor_layout::convolution::NDHWC,
|
||||
ck::tensor_layout::convolution::KZYXC,
|
||||
ck::tensor_layout::convolution::NDHWK>(
|
||||
1, // do_verification,
|
||||
1, // init_method,
|
||||
0, // do_log,
|
||||
1, // nrepeat,
|
||||
param.N,
|
||||
param.K,
|
||||
param.C,
|
||||
param.input_spatial_lengths,
|
||||
param.filter_spatial_lengths,
|
||||
param.GetOutputSpatialLengths(),
|
||||
param.conv_filter_strides,
|
||||
param.conv_filter_dilations,
|
||||
param.input_left_pads,
|
||||
param.input_right_pads);
|
||||
}
|
||||
|
||||
if(pass)
|
||||
{
|
||||
std::cout << "test convnd bwd : Pass" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "test convnd bwd: Fail " << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user