mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 10:59:55 +00:00
Add Conv Backward Data on Navi21 for ResNet50 (#499)
* start add example
* add device dl
* change launch kernel
* change init data method
* change example config
* add config valid check
* add instance for dl bwd
* add instance to ckProfiler
* reserver to profiler and cmakelist
* add instance to ckProfiler2
* change instance f32 config
* fix example return value
Co-authored-by: letaoqin <letaoqin@amd.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
[ROCm/composable_kernel commit: db0eb1ea9c]
This commit is contained in:
@@ -101,6 +101,42 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
// conv2d dl
|
||||
// Declaration only; the definition is not part of this view.
// Takes a mutable reference to a vector of owning pointers to
// DeviceConvBwdData<2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>.
// Presumably appends the "dl" F16 conv2d backward-data instances to `instances`
// (inferred from the name) -- TODO confirm against the definition.
// In this file it is invoked with `op_ptrs` in the half_t branch of the instance
// factory, immediately after the xdl f16 counterpart.
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<2,
|
||||
NHWC,
|
||||
KYXC,
|
||||
NHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
// Declaration only; the definition is not part of this view.
// Takes a mutable reference to a vector of owning pointers to
// DeviceConvBwdData<2, NHWC, KYXC, NHWK, F32, F32, F32, PassThrough, PassThrough, PassThrough>.
// Presumably appends the "dl" F32 conv2d backward-data instances to `instances`
// (inferred from the name) -- TODO confirm against the definition.
// In this file it is invoked with `op_ptrs` in the float branch of the instance
// factory, immediately after the xdl f32 counterpart.
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<2,
|
||||
NHWC,
|
||||
KYXC,
|
||||
NHWK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
// Declaration only; the definition is not part of this view.
// Takes a mutable reference to a vector of owning pointers to
// DeviceConvBwdData<2, NHWC, KYXC, NHWK, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>.
// Presumably appends the "dl" int8 conv2d backward-data instances to `instances`
// (inferred from the name) -- TODO confirm against the definition.
// In this file it is invoked with `op_ptrs` in the int8_t branch of the instance
// factory, immediately after the xdl int8 counterpart.
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<2,
|
||||
NHWC,
|
||||
KYXC,
|
||||
NHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
// conv3d backward data
|
||||
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<3,
|
||||
@@ -216,11 +252,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
|
||||
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
|
||||
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
@@ -232,6 +270,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
|
||||
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
|
||||
|
||||
Reference in New Issue
Block a user