mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 03:19:48 +00:00
Update Group convolution (#341)
* add conv oddC
* update example
* update example
* fix bug in example
* fix bug in group conv example
[ROCm/composable_kernel commit: 75ab874e02]
This commit is contained in:
@@ -89,6 +89,15 @@ int main(int argc, char* argv[])
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
// conventional group conv definition
|
||||
// G = 2
|
||||
// [N, C, Hi, Wi] = [128, 384, 71, 71]
|
||||
// [K, C, Y, X] = [512, 192, 3, 3]
|
||||
// [N, K, Ho, Wo] = [128, 512, 36, 36]
|
||||
// CK group conv definition
|
||||
// [G, N, C, Hi, Wi] = [2, 128, 192, 71, 71]
|
||||
// [G, K, C, Y, X] = [2, 256, 192, 3, 3]
|
||||
// [G, N, K, Ho, Wo] = [2, 128, 256, 36, 36]
|
||||
ck::utils::conv::ConvParam conv_param{
|
||||
2, 2, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};
|
||||
|
||||
@@ -135,10 +144,10 @@ int main(int argc, char* argv[])
|
||||
const auto wei_g_k_c_xs_desc = HostTensorDescriptor(
|
||||
{conv_param.G_, conv_param.K_, conv_param.C_, conv_param.filter_spatial_lengths_[0]},
|
||||
{
|
||||
conv_param.C_, // g
|
||||
conv_param.filter_spatial_lengths_[0] * conv_param.G_ * conv_param.C_, // k
|
||||
conv_param.K_ * conv_param.filter_spatial_lengths_[0] * conv_param.C_, // g
|
||||
conv_param.filter_spatial_lengths_[0] * conv_param.C_, // k
|
||||
1, // c
|
||||
conv_param.G_ * conv_param.C_ // x
|
||||
conv_param.C_ // x
|
||||
});
|
||||
|
||||
const auto bias_g_n_k_wos_desc = HostTensorDescriptor(
|
||||
@@ -194,7 +203,7 @@ int main(int argc, char* argv[])
|
||||
conv_param.input_spatial_lengths_[0],
|
||||
conv_param.input_spatial_lengths_[1]},
|
||||
{
|
||||
conv_param.output_spatial_lengths_[0] * conv_param.C_, // g
|
||||
conv_param.C_, // g
|
||||
conv_param.input_spatial_lengths_[0] * conv_param.input_spatial_lengths_[1] *
|
||||
conv_param.G_ * conv_param.C_, // n
|
||||
1, // c
|
||||
@@ -202,20 +211,21 @@ int main(int argc, char* argv[])
|
||||
conv_param.G_ * conv_param.C_ // wi
|
||||
});
|
||||
|
||||
const auto wei_g_k_c_xs_desc = HostTensorDescriptor(
|
||||
{conv_param.G_,
|
||||
conv_param.K_,
|
||||
conv_param.C_,
|
||||
conv_param.filter_spatial_lengths_[0],
|
||||
conv_param.filter_spatial_lengths_[1]},
|
||||
{
|
||||
conv_param.C_, // g
|
||||
conv_param.filter_spatial_lengths_[0] * conv_param.filter_spatial_lengths_[1] *
|
||||
conv_param.G_ * conv_param.C_, // k
|
||||
1, // c
|
||||
conv_param.filter_spatial_lengths_[1] * conv_param.G_ * conv_param.C_, // y
|
||||
conv_param.G_ * conv_param.C_ // x
|
||||
});
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
HostTensorDescriptor({conv_param.G_,
|
||||
conv_param.K_,
|
||||
conv_param.C_,
|
||||
conv_param.filter_spatial_lengths_[0],
|
||||
conv_param.filter_spatial_lengths_[1]},
|
||||
{
|
||||
conv_param.K_ * conv_param.filter_spatial_lengths_[0] *
|
||||
conv_param.filter_spatial_lengths_[1] * conv_param.C_, // g
|
||||
conv_param.filter_spatial_lengths_[0] *
|
||||
conv_param.filter_spatial_lengths_[1] * conv_param.C_, // k
|
||||
1, // c
|
||||
conv_param.filter_spatial_lengths_[1] * conv_param.C_, // y
|
||||
conv_param.C_ // x
|
||||
});
|
||||
|
||||
const auto bias_g_n_k_wos_desc =
|
||||
HostTensorDescriptor({conv_param.G_,
|
||||
@@ -282,7 +292,7 @@ int main(int argc, char* argv[])
|
||||
conv_param.input_spatial_lengths_[1],
|
||||
conv_param.input_spatial_lengths_[2]},
|
||||
{
|
||||
conv_param.output_spatial_lengths_[0] * conv_param.C_, // g
|
||||
conv_param.C_, // g
|
||||
conv_param.input_spatial_lengths_[0] * conv_param.input_spatial_lengths_[1] *
|
||||
conv_param.input_spatial_lengths_[2] * conv_param.G_ * conv_param.C_, // n
|
||||
1, // c
|
||||
@@ -300,14 +310,16 @@ int main(int argc, char* argv[])
|
||||
conv_param.filter_spatial_lengths_[1],
|
||||
conv_param.filter_spatial_lengths_[2]},
|
||||
{
|
||||
conv_param.C_, // g
|
||||
conv_param.K_ * conv_param.filter_spatial_lengths_[0] *
|
||||
conv_param.filter_spatial_lengths_[1] * conv_param.filter_spatial_lengths_[2] *
|
||||
conv_param.C_, // g
|
||||
conv_param.filter_spatial_lengths_[0] * conv_param.filter_spatial_lengths_[1] *
|
||||
conv_param.filter_spatial_lengths_[2] * conv_param.G_ * conv_param.C_, // k
|
||||
1, // c
|
||||
conv_param.filter_spatial_lengths_[2] * conv_param.C_, // k
|
||||
1, // c
|
||||
conv_param.filter_spatial_lengths_[1] * conv_param.filter_spatial_lengths_[2] *
|
||||
conv_param.G_ * conv_param.C_, // z
|
||||
conv_param.filter_spatial_lengths_[2] * conv_param.G_ * conv_param.C_, // y
|
||||
conv_param.G_ * conv_param.C_ // x
|
||||
conv_param.C_, // z
|
||||
conv_param.filter_spatial_lengths_[2] * conv_param.C_, // y
|
||||
conv_param.C_ // x
|
||||
});
|
||||
|
||||
const auto bias_g_n_k_wos_desc =
|
||||
|
||||
Reference in New Issue
Block a user