mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 14:54:47 +00:00
Replace grouped conv bwd wei wmmaV3 bilin/scale bf16f32bf16 support with bf16bf16bf16 (#3470)
* Replace grouped convolution bwd weight wmma v3 bilinear and scale bf16f32bf16 support with bf16bf16bf16 support. Update tests.
* Tentative fix for bwd weight bilinear bf16bf16bf16, seems like the bilinear elementwise overload for this case (bf16, f32 accu, bf16) was wrong.
[ROCm/composable_kernel commit: 88ae445580]
This commit is contained in:
committed by
GitHub
parent
9045cafc8c
commit
04d4dd1ada
@@ -296,6 +296,7 @@ class TestGroupedConvndBwdWeight3d : public TestGroupedConvndBwdWeight<Tuple>
|
||||
using KernelTypes3d =
|
||||
::testing::Types<std::tuple<float, float, float, ck::Number<3>>,
|
||||
std::tuple<ck::half_t, ck::half_t, ck::half_t, ck::Number<3>>,
|
||||
std::tuple<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, ck::Number<3>>,
|
||||
std::tuple<ck::bhalf_t, float, ck::bhalf_t, ck::Number<3>>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdWeight3d, KernelTypes3d);
|
||||
|
||||
@@ -269,6 +269,7 @@ class TestGroupedConvndBwdWeight3d : public TestGroupedConvndBwdWeight<Tuple>
|
||||
using KernelTypes3d =
|
||||
::testing::Types<std::tuple<float, float, float, ck::Number<3>>,
|
||||
std::tuple<ck::half_t, ck::half_t, ck::half_t, ck::Number<3>>,
|
||||
std::tuple<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, ck::Number<3>>,
|
||||
std::tuple<ck::bhalf_t, float, ck::bhalf_t, ck::Number<3>>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdWeight3d, KernelTypes3d);
|
||||
|
||||
Reference in New Issue
Block a user