Update Group convolution (#341)

* add conv oddC

* update example

* update example

* fix bug in example

* fix bug in group conv example

[ROCm/composable_kernel commit: 75ab874e02]
This commit is contained in:
Chao Liu
2022-08-03 12:28:33 -05:00
committed by GitHub
parent dee2696501
commit 159f0bc1b4
6 changed files with 121 additions and 918 deletions

View File

@@ -89,6 +89,15 @@ int main(int argc, char* argv[])
int init_method = 1;
bool time_kernel = false;
// conventional group conv definition
// G = 2
// [N, C, Hi, Wi] = [128, 384, 71, 71]
// [K, C / G, Y, X] = [512, 192, 3, 3]
// [N, K, Ho, Wo] = [128, 512, 36, 36]
// CK group conv definition
// [G, N, C, Hi, Wi] = [2, 128, 192, 71, 71]
// [G, K, C, Y, X] = [2, 256, 192, 3, 3]
// [G, N, K, Ho, Wo] = [2, 128, 256, 36, 36]
ck::utils::conv::ConvParam conv_param{
2, 2, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};
@@ -135,10 +144,10 @@ int main(int argc, char* argv[])
const auto wei_g_k_c_xs_desc = HostTensorDescriptor(
{conv_param.G_, conv_param.K_, conv_param.C_, conv_param.filter_spatial_lengths_[0]},
{
conv_param.C_, // g
conv_param.filter_spatial_lengths_[0] * conv_param.G_ * conv_param.C_, // k
conv_param.K_ * conv_param.filter_spatial_lengths_[0] * conv_param.C_, // g
conv_param.filter_spatial_lengths_[0] * conv_param.C_, // k
1, // c
conv_param.G_ * conv_param.C_ // x
conv_param.C_ // x
});
const auto bias_g_n_k_wos_desc = HostTensorDescriptor(
@@ -194,7 +203,7 @@ int main(int argc, char* argv[])
conv_param.input_spatial_lengths_[0],
conv_param.input_spatial_lengths_[1]},
{
conv_param.output_spatial_lengths_[0] * conv_param.C_, // g
conv_param.C_, // g
conv_param.input_spatial_lengths_[0] * conv_param.input_spatial_lengths_[1] *
conv_param.G_ * conv_param.C_, // n
1, // c
@@ -202,20 +211,21 @@ int main(int argc, char* argv[])
conv_param.G_ * conv_param.C_ // wi
});
const auto wei_g_k_c_xs_desc = HostTensorDescriptor(
{conv_param.G_,
conv_param.K_,
conv_param.C_,
conv_param.filter_spatial_lengths_[0],
conv_param.filter_spatial_lengths_[1]},
{
conv_param.C_, // g
conv_param.filter_spatial_lengths_[0] * conv_param.filter_spatial_lengths_[1] *
conv_param.G_ * conv_param.C_, // k
1, // c
conv_param.filter_spatial_lengths_[1] * conv_param.G_ * conv_param.C_, // y
conv_param.G_ * conv_param.C_ // x
});
const auto wei_g_k_c_xs_desc =
HostTensorDescriptor({conv_param.G_,
conv_param.K_,
conv_param.C_,
conv_param.filter_spatial_lengths_[0],
conv_param.filter_spatial_lengths_[1]},
{
conv_param.K_ * conv_param.filter_spatial_lengths_[0] *
conv_param.filter_spatial_lengths_[1] * conv_param.C_, // g
conv_param.filter_spatial_lengths_[0] *
conv_param.filter_spatial_lengths_[1] * conv_param.C_, // k
1, // c
conv_param.filter_spatial_lengths_[1] * conv_param.C_, // y
conv_param.C_ // x
});
const auto bias_g_n_k_wos_desc =
HostTensorDescriptor({conv_param.G_,
@@ -282,7 +292,7 @@ int main(int argc, char* argv[])
conv_param.input_spatial_lengths_[1],
conv_param.input_spatial_lengths_[2]},
{
conv_param.output_spatial_lengths_[0] * conv_param.C_, // g
conv_param.C_, // g
conv_param.input_spatial_lengths_[0] * conv_param.input_spatial_lengths_[1] *
conv_param.input_spatial_lengths_[2] * conv_param.G_ * conv_param.C_, // n
1, // c
@@ -300,14 +310,16 @@ int main(int argc, char* argv[])
conv_param.filter_spatial_lengths_[1],
conv_param.filter_spatial_lengths_[2]},
{
conv_param.C_, // g
conv_param.K_ * conv_param.filter_spatial_lengths_[0] *
conv_param.filter_spatial_lengths_[1] * conv_param.filter_spatial_lengths_[2] *
conv_param.C_, // g
conv_param.filter_spatial_lengths_[0] * conv_param.filter_spatial_lengths_[1] *
conv_param.filter_spatial_lengths_[2] * conv_param.G_ * conv_param.C_, // k
1, // c
conv_param.filter_spatial_lengths_[2] * conv_param.C_, // k
1, // c
conv_param.filter_spatial_lengths_[1] * conv_param.filter_spatial_lengths_[2] *
conv_param.G_ * conv_param.C_, // z
conv_param.filter_spatial_lengths_[2] * conv_param.G_ * conv_param.C_, // y
conv_param.G_ * conv_param.C_ // x
conv_param.C_, // z
conv_param.filter_spatial_lengths_[2] * conv_param.C_, // y
conv_param.C_ // x
});
const auto bias_g_n_k_wos_desc =