mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 12:00:07 +00:00
Grouped conv bwd wei NDHWGC/NDHWGK (#804)
[ROCm/composable_kernel commit: 10732847e7]
This commit is contained in:
@@ -83,19 +83,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
|
||||
using GNWC = ck::tensor_layout::convolution::GNWC;
|
||||
using GNHWC = ck::tensor_layout::convolution::GNHWC;
|
||||
using NHWGC = ck::tensor_layout::convolution::NHWGC;
|
||||
using GNDHWC = ck::tensor_layout::convolution::GNDHWC;
|
||||
|
||||
using GKXC = ck::tensor_layout::convolution::GKXC;
|
||||
using GKYXC = ck::tensor_layout::convolution::GKYXC;
|
||||
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
|
||||
|
||||
using GNWK = ck::tensor_layout::convolution::GNWK;
|
||||
using GNHWK = ck::tensor_layout::convolution::GNHWK;
|
||||
using NHWGK = ck::tensor_layout::convolution::NHWGK;
|
||||
using GNDHWK = ck::tensor_layout::convolution::GNDHWK;
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
constexpr auto I1 = ck::Number<1>{};
|
||||
constexpr auto I2 = ck::Number<2>{};
|
||||
@@ -194,6 +182,22 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_F32_BF16)
|
||||
{
|
||||
// fp32 atomic add is used for weight tensor in bf16 kernel
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, F32{}, BF16{});
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "this data_type & layout is not implemented" << std::endl;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user