Grouped Conv Bwd Data out index calculation optimizations (#2917)

* Grouped Conv Bwd Data index calculation optimizations

* fixes

* refactor instances

* gfx12 fixes

* temporary disable splitK for gfx12

[ROCm/composable_kernel commit: 5477811670]
This commit is contained in:
Bartłomiej Kocot
2025-09-29 15:59:11 +02:00
committed by GitHub
parent 4bc708f401
commit ef933ee241
17 changed files with 895 additions and 75 deletions

View File

@@ -10,6 +10,9 @@ add_instance_library(
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_16_16_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_16_16_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_16_16_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_optimized_loads_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_optimized_loads_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_optimized_loads_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp

View File

@@ -0,0 +1,49 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_optimized_loads_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_bf16_optimized_loads_instances<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(instances,
device_grouped_conv_bwd_data_xdl_bf16_optimized_loads_instances<
2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,49 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_optimized_loads_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f16_optimized_loads_instances<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(instances,
device_grouped_conv_bwd_data_xdl_f16_optimized_loads_instances<
2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,49 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f32_optimized_loads_instances<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(instances,
device_grouped_conv_bwd_data_xdl_f32_optimized_loads_instances<
2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -9,6 +9,9 @@ set(GROUPED_CONV3D_BWD_DATA
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16_16_instance.cpp
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16_16_instance.cpp
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16_16_instance.cpp
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_optimized_loads_instance.cpp
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_optimized_loads_instance.cpp
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_optimized_loads_instance.cpp
xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instance.cpp
xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instance.cpp
xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp

View File

@@ -0,0 +1,49 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_optimized_loads_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_bf16_optimized_loads_instances<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(instances,
device_grouped_conv_bwd_data_xdl_bf16_optimized_loads_instances<
3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,49 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_optimized_loads_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f16_optimized_loads_instances<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(instances,
device_grouped_conv_bwd_data_xdl_f16_optimized_loads_instances<
3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,49 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f32_optimized_loads_instances<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(instances,
device_grouped_conv_bwd_data_xdl_f32_optimized_loads_instances<
3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck