Grouped conv bwd weight with grouped gemm (#2304)

* Grouped conv bwd weight with grouped gemm

* fixes

* fix

* Fixes

* test comments

* restore atol

* fix

[ROCm/composable_kernel commit: bb4f471b09]
This commit is contained in:
Bartłomiej Kocot
2025-06-12 10:15:07 +02:00
committed by GitHub
parent a7eb83a51b
commit adfcb7ea2b
5 changed files with 242 additions and 153 deletions

View File

@@ -96,6 +96,18 @@ TYPED_TEST(TestGroupedConvndBwdDataXdl2d, Test2D)
{
this->conv_params.clear();
// GroupedGemmGroupsNum = 4, ZTilde * YTilde * XTilde = 4, MaxGroupedGemmGroupsNum = 32
this->conv_params.push_back(
{2, 2, 2, 16, 16, {3, 3}, {28, 28}, {2, 2}, {1, 1}, {1, 1}, {1, 1}});
// GroupedGemmGroupsNum = 9, ZTilde * YTilde * XTilde = 36, MaxGroupedGemmGroupsNum = 32
this->conv_params.push_back(
{2, 2, 2, 16, 16, {3, 3}, {28, 28}, {6, 6}, {1, 1}, {1, 1}, {1, 1}});
// GroupedGemmGroupsNum = 36, ZTilde * YTilde * XTilde = 36, MaxGroupedGemmGroupsNum = 32
this->conv_params.push_back(
{2, 2, 2, 16, 16, {6, 6}, {28, 28}, {6, 6}, {1, 1}, {1, 1}, {1, 1}});
// GroupedGemmGroupsNum = 32, ZTilde * YTilde * XTilde = 32, MaxGroupedGemmGroupsNum = 32
this->conv_params.push_back(
{2, 2, 2, 16, 16, {4, 8}, {28, 28}, {4, 8}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 2, 2, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(