Replace grouped conv bwd wei wmmaV3 bilin/scale bf16f32bf16 support with bf16bf16bf16 (#3470)

* Replace grouped convolution bwd weight wmma v3 bilinear and scale bf16f32bf16 support with bf16bf16bf16 support. Update tests.

* Tentative fix for bwd weight bilinear bf16bf16bf16, seems like the bilinear elementwise overload for this case (bf16, f32 accu, bf16) was wrong.

[ROCm/composable_kernel commit: 88ae445580]
This commit is contained in:
Kiefer van Teutem
2025-12-29 12:58:29 +01:00
committed by GitHub
parent 13134864cc
commit ac28f1b016
10 changed files with 47 additions and 46 deletions

View File

@@ -296,6 +296,7 @@ class TestGroupedConvndBwdWeight3d : public TestGroupedConvndBwdWeight<Tuple>
using KernelTypes3d =
::testing::Types<std::tuple<float, float, float, ck::Number<3>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, ck::Number<3>>,
std::tuple<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, ck::Number<3>>,
std::tuple<ck::bhalf_t, float, ck::bhalf_t, ck::Number<3>>>;
TYPED_TEST_SUITE(TestGroupedConvndBwdWeight3d, KernelTypes3d);

View File

@@ -269,6 +269,7 @@ class TestGroupedConvndBwdWeight3d : public TestGroupedConvndBwdWeight<Tuple>
using KernelTypes3d =
::testing::Types<std::tuple<float, float, float, ck::Number<3>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, ck::Number<3>>,
std::tuple<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, ck::Number<3>>,
std::tuple<ck::bhalf_t, float, ck::bhalf_t, ck::Number<3>>>;
TYPED_TEST_SUITE(TestGroupedConvndBwdWeight3d, KernelTypes3d);