[CK] Add BF16^3 support to grouped conv bwd weight: bilinear and scale (#4591)

## Motivation

Until now, XDL grouped conv bwd weight for bilinear and scale only
supported bf16f32bf16.
Therefore, bf16bf16bf16 support should be added.

## Technical Details

Instances were added to the relevant files in
`library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/`
folder. In addition, `add()` functions were included in new files in
`library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/xdl/`
and
`library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/xdl/`
folders. The new .cpp files were also included in the `CMakeFiles.txt`
files of both folders.

## Test Plan

Execute `grouped_convnd_bwd_weight` tests to check execution on
different architectures.
The tests for bilinear and scale already include the tuple
`std::tuple<ck::half_t, ck::half_t, ck::half_t, ck::Number<3>>`, so in
principle, there is nothing to modify in the tests themselves.

## Test Result

`gfx1201`: Tests passed.
`gfx1100`: Tests passed.
`gfx90a`: Tests passed.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: Fernando Jiménez <fernando.jimenez@streamhpc.com>
This commit is contained in:
JP-Fernando
2026-03-11 13:05:44 +01:00
committed by GitHub
parent 25d9fdfc16
commit 5d4107862b
10 changed files with 195 additions and 30 deletions

View File

@@ -6,7 +6,9 @@ set(GROUPED_CONV3D_BWD_WEIGHT_BILINEAR
xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp)
xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp
)
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
list(APPEND GROUPED_CONV3D_BWD_WEIGHT_BILINEAR

View File

@@ -0,0 +1,50 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
NDHWGC,
GKZYXC,
NDHWGK,
Tuple<GKZYXC>,
BF16,
F32,
BF16,
Tuple<F32>,
PassThrough,
Bilinear,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_bilinear_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
ConvBwdWeightDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_bilinear_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
ConvBwdWeightFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -9,22 +9,21 @@ namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
NDHWGC,
GKZYXC,
NDHWGK,
Tuple<GKZYXC>,
BF16,
F32,
BF16,
Tuple<F32>,
BF16,
Tuple<BF16>,
PassThrough,
Bilinear,
PassThrough>>>& instances)
{
// 1. Default
// Default bwd weight bilinear
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_bilinear_instances<
@@ -33,15 +32,6 @@ void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16
GKZYXC,
NDHWGK,
ConvBwdWeightDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_bilinear_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
ConvBwdWeightFilter1x1Stride1Pad0>{});
}
} // namespace instance

View File

@@ -6,7 +6,9 @@ set(GROUPED_CONV3D_BWD_WEIGHT_SCALE
xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp)
xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp
)
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
list(APPEND GROUPED_CONV3D_BWD_WEIGHT_SCALE

View File

@@ -0,0 +1,51 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_scale_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
NDHWGC,
GKZYXC,
NDHWGK,
Tuple<>,
BF16,
F32,
BF16,
Tuple<>,
PassThrough,
Scale,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_scale_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
ConvBwdWeightDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_scale_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
ConvBwdWeightFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -10,21 +10,21 @@ namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
NDHWGC,
GKZYXC,
NDHWGK,
Tuple<>,
BF16,
F32,
BF16,
BF16,
Tuple<>,
PassThrough,
Scale,
PassThrough>>>& instances)
{
// 1. Default
// Default conv bwd weight
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_scale_instances<3,
@@ -32,15 +32,6 @@ void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f3
GKZYXC,
NDHWGK,
ConvBwdWeightDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_scale_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
ConvBwdWeightFilter1x1Stride1Pad0>{});
}
} // namespace instance