From bc2c2fd54e8be42ba38563abec04b5aaa9085cdb Mon Sep 17 00:00:00 2001 From: kiefer Date: Mon, 1 Sep 2025 12:19:47 +0000 Subject: [PATCH] Add some more tests for vanilla grouped conv fwd --- .../test_grouped_convnd_fwd.cpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp index 6cb4f1eed3..35ddee94e2 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp @@ -135,9 +135,8 @@ TYPED_TEST(TestGroupedConvndFwd2d, Test2D) this->conv_params.clear(); // TODO: not all filter sizes accepted at the moment, related to output N size and // CDEBlockTransferScalarPerVector_NPerBlock - // this->conv_params.push_back( - // {2, 3, 5, 96, 200, {1, 1}, {73, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); - + this->conv_params.push_back( + {2, 3, 5, 96, 200, {1, 1}, {73, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back( {2, 1, 1, 32, 32, {1, 1}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back( @@ -149,8 +148,8 @@ TYPED_TEST(TestGroupedConvndFwd2d, Test2D) this->conv_params.push_back( {2, 1, 1, 32, 32, {9, 9}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); - // this->conv_params.push_back( - // {2, 2, 32, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 2, 32, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back( {2, 2, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); @@ -169,14 +168,16 @@ TYPED_TEST(TestGroupedConvndFwd2d, Test2D) TYPED_TEST(TestGroupedConvndFwd3d, Test3D) { this->conv_params.clear(); - // this->conv_params.push_back( - // {3, 3, 5, 96, 200, {1, 1, 1}, {17, 27, 13}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 3, 5, 96, 200, {1, 1, 1}, {17, 27, 13}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); this->conv_params.push_back( {3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); this->conv_params.push_back( {3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->conv_params.push_back( {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {16, 16, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); this->conv_params.push_back( {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->conv_params.push_back(