mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 20:21:23 +00:00
Unified implementation of 1d/2d/3d conv bwd-data. fp32/fp16/bfp16/int8 (#134)
* start convnd bwd data * add 3d laoyout name * add conv1d reference * add con3d reference * finished example client code * conv1d kernel finished * fix input error * add conv3d * add 3d layout in conv_utils.hpp * fix sepecial check * addconvnd lib * add test for bwd data * finished test * add check slice length * convnd bwd data start * profiler can be compiled * fix some bug * set input to zero * modify readme for example * fix test_convnd_bwd_data bug * test_convnd_bwd_data parameter desc * workaround for 1d * workaroud for 2d * change init value * workaround for 3d int8 * fix init value bug * remove workaround * fix acc data type * add int32 * change select function to template * tilda to tilde * remove int32 instance * fix commit for device hpp * fix comments for profiler * using profile imp to test * add pass verification * fix conv2d reference * fix conflict * remove double batched_gemm * fix exampel conv2d data and test convnd * format * change conv2d_bwd_data return value * remove repeat = 1 * remove conv bwd data Co-authored-by: ltqin <letaoqin@amd.com> Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -121,15 +121,17 @@ int main(int argc, char* argv[])
|
||||
exit(1);
|
||||
}
|
||||
|
||||
auto Run = [&](auto input_type, auto wei_type, auto out_type) {
|
||||
auto Run = [&](auto input_type, auto wei_type, auto out_type, auto acc_type) {
|
||||
using InDataType = decltype(input_type);
|
||||
using WeiDataType = decltype(wei_type);
|
||||
using OutDataType = decltype(out_type);
|
||||
using AccDataType = decltype(acc_type);
|
||||
|
||||
using ReferenceConvBwdInstance =
|
||||
ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>;
|
||||
@@ -293,33 +295,33 @@ int main(int argc, char* argv[])
|
||||
if(success)
|
||||
{
|
||||
std::cout << "test conv2d bwd : Pass" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "test conv2d bwd: Fail " << std::endl;
|
||||
return -1;
|
||||
}
|
||||
};
|
||||
|
||||
if(data_type == 0)
|
||||
{
|
||||
Run(F32(), F32(), F32());
|
||||
return Run(F32(), F32(), F32(), F32());
|
||||
}
|
||||
else if(data_type == 1)
|
||||
{
|
||||
Run(F16(), F16(), F16());
|
||||
return Run(F16(), F16(), F16(), F32());
|
||||
}
|
||||
else if(data_type == 2)
|
||||
{
|
||||
Run(BF16(), BF16(), BF16());
|
||||
return Run(BF16(), BF16(), BF16(), F32());
|
||||
}
|
||||
else if(data_type == 3)
|
||||
{
|
||||
Run(INT8(), INT8(), INT8());
|
||||
return Run(INT8(), INT8(), INT8(), int());
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user