Replace grouped conv bwd wei wmmaV3 bilin/scale bf16f32bf16 support with bf16bf16bf16 (#3470)

* Replace grouped convolution bwd weight wmma v3 bilinear and scale bf16f32bf16 support with bf16bf16bf16 support. Update tests.

* Tentative fix for bwd weight bilinear bf16bf16bf16, seems like the bilinear elementwise overload for this case (bf16, f32 accu, bf16) was wrong.

[ROCm/composable_kernel commit: 88ae445580]
This commit is contained in:
Kiefer van Teutem
2025-12-29 12:58:29 +01:00
committed by GitHub
parent 9045cafc8c
commit 04d4dd1ada
10 changed files with 47 additions and 46 deletions

View File

@@ -10,16 +10,16 @@ namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
void add_device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_bf16_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
NDHWGC,
GKZYXC,
NDHWGK,
Tuple<GKZYXC>,
BF16,
F32,
BF16,
Tuple<F32>,
BF16,
Tuple<BF16>,
PassThrough,
Bilinear,
PassThrough>>>& instances)

View File

@@ -10,14 +10,14 @@ namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
void add_device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_bf16_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
NDHWGC,
GKZYXC,
NDHWGK,
Tuple<>,
BF16,
F32,
BF16,
BF16,
Tuple<>,
PassThrough,