mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
Add support for NGCHW in grouped conv fwd (#1499)
* Support NGCHW in grouped conv fwd * Remove not needed variable * Fixes
This commit is contained in:
@@ -45,6 +45,8 @@ static void print_helper_msg()
|
||||
"N, Ho, Wo, K]\n"
|
||||
<< " 2: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, "
|
||||
"Ho, Wo, G, K]\n"
|
||||
<< " 3: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, "
|
||||
"G, K, Ho, Wo]\n"
|
||||
<< "arg4: verification (0: no, 1: yes)\n"
|
||||
<< "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n"
|
||||
<< "arg6: print tensor value (0: no; 1: yes)\n"
|
||||
|
||||
@@ -15,6 +15,7 @@ enum struct ConvLayout
|
||||
{
|
||||
GNHWC_GKYXC_GNHWK, // 0
|
||||
NHWGC_GKYXC_NHWGK, // 1
|
||||
NGCHW_GKYXC_NGKHW, // 2
|
||||
};
|
||||
|
||||
enum struct ConvDataType
|
||||
@@ -54,6 +55,8 @@ static void print_helper_msg()
|
||||
<< "arg3: indexing data type (0: 32-bit, 1: 64-bit)\n"
|
||||
<< "arg4: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
|
||||
<< " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n"
|
||||
<< " 2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, "
|
||||
"G, K, Ho, Wo]\n"
|
||||
<< "arg5: verification (0: no, 1: yes)\n"
|
||||
<< "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n"
|
||||
<< "arg7: print tensor value (0: no; 1: yes)\n"
|
||||
@@ -111,6 +114,11 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
using GNHWK = ck::tensor_layout::convolution::GNHWK;
|
||||
using GNDHWK = ck::tensor_layout::convolution::GNDHWK;
|
||||
|
||||
//
|
||||
using NGCHW = ck::tensor_layout::convolution::NGCHW;
|
||||
|
||||
using NGKHW = ck::tensor_layout::convolution::NGKHW;
|
||||
|
||||
//
|
||||
using NWGC = ck::tensor_layout::convolution::NWGC;
|
||||
using NHWGC = ck::tensor_layout::convolution::NHWGC;
|
||||
@@ -284,6 +292,17 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
|
||||
Reference in New Issue
Block a user