mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 04:49:54 +00:00
Unified implementation of 1d/2d/3d conv bwd-data. fp32/fp16/bfp16/int8 (#134)
* start convnd bwd data
* add 3d laoyout name
* add conv1d reference
* add con3d reference
* finished example client code
* conv1d kernel finished
* fix input error
* add conv3d
* add 3d layout in conv_utils.hpp
* fix sepecial check
* addconvnd lib
* add test for bwd data
* finished test
* add check slice length
* convnd bwd data start
* profiler can be compiled
* fix some bug
* set input to zero
* modify readme for example
* fix test_convnd_bwd_data bug
* test_convnd_bwd_data parameter desc
* workaround for 1d
* workaroud for 2d
* change init value
* workaround for 3d int8
* fix init value bug
* remove workaround
* fix acc data type
* add int32
* change select function to template
* tilda to tilde
* remove int32 instance
* fix commit for device hpp
* fix comments for profiler
* using profile imp to test
* add pass verification
* fix conv2d reference
* fix conflict
* remove double batched_gemm
* fix exampel conv2d data and test convnd
* format
* change conv2d_bwd_data return value
* remove repeat = 1
* remove conv bwd data
Co-authored-by: ltqin <letaoqin@amd.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
[ROCm/composable_kernel commit: 0536f2b312]
This commit is contained in:
@@ -89,6 +89,7 @@ int profile_conv_bwd_data(int argc, char* argv[])
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
ck::tensor_layout::convolution::NHWC,
|
||||
ck::tensor_layout::convolution::KYXC,
|
||||
ck::tensor_layout::convolution::NHWK>(
|
||||
@@ -114,6 +115,7 @@ int profile_conv_bwd_data(int argc, char* argv[])
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
float,
|
||||
ck::tensor_layout::convolution::NHWC,
|
||||
ck::tensor_layout::convolution::KYXC,
|
||||
ck::tensor_layout::convolution::NHWK>(
|
||||
@@ -139,6 +141,7 @@ int profile_conv_bwd_data(int argc, char* argv[])
|
||||
uint16_t,
|
||||
uint16_t,
|
||||
uint16_t,
|
||||
float,
|
||||
ck::tensor_layout::convolution::NHWC,
|
||||
ck::tensor_layout::convolution::KYXC,
|
||||
ck::tensor_layout::convolution::NHWK>(
|
||||
@@ -164,6 +167,7 @@ int profile_conv_bwd_data(int argc, char* argv[])
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int32_t,
|
||||
ck::tensor_layout::convolution::NHWC,
|
||||
ck::tensor_layout::convolution::KYXC,
|
||||
ck::tensor_layout::convolution::NHWK>(
|
||||
|
||||
Reference in New Issue
Block a user