mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
NHWC conv 2d: bwd fp32/fp16/bfp16/int8, Device level tuning and host API (#92)
* start conv2d bwd api
* kernel running
* add bwd reference
* change to no shuffle
* fix bwd reference
* pass verification
* add Filter1x1Stride1Pad0 and start testing
* change some tuning parameter
* fix test error
* add fp16 tuning parameter
* add bf16 tuning parameter
* add int8 tuning parameters
* change fp32 tuning parameter
* add bwd to profiler
* fix bug for bwd profiler
* fix ckProfiler bug
* change conv2d_bwd_xdl to fp16
* fix bug in comments
* fix precompile id
* fix enum conv name
* chage _bwd_ to _bwd_data_
* change conv2d_bwd example id
* bwd to bwd data
* fix prehead
* fix MakeDefaultBlock2CTileMap ,import form merge develop
* format bwd instance
* bwd to bwd data
* change name bwd to bwd data
* change name bwd to bwd data in example
* formate code
* change conv2d bwd data id in example
* rewrite readme for example
* fix CalculateMagicNumbers about div zero
* add workaround CK_WORKAROUND_SWDEV_325164
* change test_conf2d_bwd_data show info
* format
* fix bug for workaround:CK_WORKAROUND_SWDEV_325164
* formate tuning parameters
* formate tuning parameters again
* formate tuning parameters 3
* formate tuning parameters 4
* remove add function template
* format
* update comment
Co-authored-by: ltqin <letaoqin@amd.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
[ROCm/composable_kernel commit: c254e5abd2]
This commit is contained in:
47
device_operation/include/device_conv_bwd_data.hpp
Normal file
47
device_operation/include/device_conv_bwd_data.hpp
Normal file
@@ -0,0 +1,47 @@
|
||||
#ifndef DEVICE_CONV_BWD_DATA_HPP
|
||||
#define DEVICE_CONV_BWD_DATA_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include "device_base.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
struct DeviceConvBwdData : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(void* p_in,
|
||||
const void* p_wei,
|
||||
const void* p_out,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
using DeviceConvBwdDataPtr = std::unique_ptr<
|
||||
DeviceConvBwdData<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
Reference in New Issue
Block a user