mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 12:30:16 +00:00
Replace grouped conv bwd wei wmmaV3 bilin/scale bf16f32bf16 support with bf16bf16bf16 (#3470)
* Replace grouped convolution bwd weight wmma v3 bilinear and scale bf16f32bf16 support with bf16bf16bf16 support. Update tests.
* Tentative fix for bwd weight bilinear bf16bf16bf16, seems like the bilinear elementwise overload for this case (bf16, f32 accu, bf16) was wrong.
[ROCm/composable_kernel commit: 88ae445580]
This commit is contained in:
committed by
GitHub
parent
9045cafc8c
commit
04d4dd1ada
@@ -10,16 +10,16 @@ namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_bf16_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
Tuple<GKZYXC>,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
Tuple<F32>,
|
||||
BF16,
|
||||
Tuple<BF16>,
|
||||
PassThrough,
|
||||
Bilinear,
|
||||
PassThrough>>>& instances)
|
||||
|
||||
@@ -10,14 +10,14 @@ namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_bf16_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
PassThrough,
|
||||
|
||||
Reference in New Issue
Block a user