Grouped conv bwd wei NDHWGC/NDHWGK (#804)

[ROCm/composable_kernel commit: 10732847e7]
This commit is contained in:
Bartłomiej Kocot
2023-07-21 19:00:55 +02:00
committed by GitHub
parent b795a0f549
commit 4687f4eb2a
9 changed files with 414 additions and 252 deletions

View File

@@ -83,19 +83,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using GNWC = ck::tensor_layout::convolution::GNWC;
using GNHWC = ck::tensor_layout::convolution::GNHWC;
using NHWGC = ck::tensor_layout::convolution::NHWGC;
using GNDHWC = ck::tensor_layout::convolution::GNDHWC;
using GKXC = ck::tensor_layout::convolution::GKXC;
using GKYXC = ck::tensor_layout::convolution::GKYXC;
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
using GNWK = ck::tensor_layout::convolution::GNWK;
using GNHWK = ck::tensor_layout::convolution::GNHWK;
using NHWGK = ck::tensor_layout::convolution::NHWGK;
using GNDHWK = ck::tensor_layout::convolution::GNDHWK;
using namespace ck::tensor_layout::convolution;
constexpr auto I1 = ck::Number<1>{};
constexpr auto I2 = ck::Number<2>{};
@@ -194,6 +182,22 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{});
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_F32_BF16)
{
// fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, F32{}, BF16{});
}
}
std::cout << "this data_type & layout is not implemented" << std::endl;