diff --git a/test/grouped_convnd_bwd_weight/grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/grouped_convnd_bwd_weight.cpp index e14173cb41..6545b6e566 100644 --- a/test/grouped_convnd_bwd_weight/grouped_convnd_bwd_weight.cpp +++ b/test/grouped_convnd_bwd_weight/grouped_convnd_bwd_weight.cpp @@ -61,7 +61,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight, Test1D) { this->conv_params.clear(); this->conv_params.push_back({1, 4, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}}); - this->conv_params.push_back({1, 4, 128, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}}); + this->conv_params.push_back({1, 4, 64, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}}); this->conv_params.push_back({1, 4, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}}); this->template Run<1>(); } @@ -72,7 +72,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight, Test2D) this->conv_params.push_back( {2, 4, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back( - {2, 4, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + {2, 4, 8, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back( {2, 4, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); this->template Run<2>(); @@ -84,7 +84,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight, Test3D) this->conv_params.push_back( {3, 4, 128, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); this->conv_params.push_back( - {3, 4, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + {3, 4, 8, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->conv_params.push_back( {3, 4, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); this->template Run<3>();