Grouped Conv Bwd Data out index calculation optimizations (#2917)

* Grouped Conv Bwd Data index calculation optimizations

* fixes

* refactor instances

* gfx12 fixes

* temporarily disable splitK for gfx12
This commit is contained in:
Bartłomiej Kocot
2025-09-29 15:59:11 +02:00
committed by GitHub
parent 0f10e6d921
commit 5477811670
17 changed files with 895 additions and 75 deletions

View File

@@ -1485,7 +1485,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
static bool IsSupportedArgument(const Argument& arg)
{
// gfx11 doesn't support float atomic
if(ck::is_gfx11_supported() && arg.k_batch_ > 1)
// Todo: Enable splitK for gfx12
if((ck::is_gfx12_supported() || ck::is_gfx11_supported()) && arg.k_batch_ > 1)
{
return false;
}

View File

@@ -13,6 +13,14 @@
namespace ck {
namespace tensor_operation {
/**
* @brief Enable custom tensor transform for convolution backward data output.
*
* When set to 1, this macro enables a custom transformation of the output tensor
* in convolution backward data operations.
*/
#define CK_USE_CUSTOM_TENSOR_TRANSFORM_FOR_BWD_DATA_OUT 1
template <
index_t NDimSpatial,
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization,
@@ -705,6 +713,12 @@ struct TransformConvBwdDataToGemm_v1
if constexpr(NDimSpatial == 2)
{
const index_t K0PerBlock = GemmKPerBlock / AK1;
const index_t AK0 = math::integer_divide_ceil(YDotSlice * XDotSlice * K_,
AK1 * K0PerBlock * batch_k_) *
K0PerBlock;
#if CK_USE_CUSTOM_TENSOR_TRANSFORM_FOR_BWD_DATA_OUT == 0
// A: output tensor
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_grid_desc,
@@ -762,12 +776,6 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(GemmKPerBlock, GemmMPerBlock),
Sequence<true, DoPadGemmM>{});
const index_t K0PerBlock = GemmKPerBlock / AK1;
const index_t AK0 =
math::integer_divide_ceil(out_gemmk_gemmm_padded_grid_desc.GetLength(I0),
AK1 * K0PerBlock * batch_k_) *
K0PerBlock;
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
out_gemmk_gemmm_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)),
@@ -775,8 +783,46 @@ struct TransformConvBwdDataToGemm_v1
out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return out_gemmak0_gemmm_gemmak1_grid_desc;
#else
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Ho_, I0, I0),
make_pad_transform(Wo_, I0, I0),
make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto out_n_hop_wop_k_grid_desc_final = transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc,
make_tuple(make_conv_bwd_data_out_transform(N_,
Ho_,
Wo_,
K_,
YDot_,
XDot_,
HTilde_,
WTilde_,
ConvDilationH_,
ConvDilationW_,
HTildeSlice,
WTildeSlice,
YDotSlice,
XDotSlice,
IHTildeSliceBegin,
IWTildeSliceBegin,
GcdStrideDilationH_,
GcdStrideDilationW_,
AK0,
AK1,
GemmMPerBlock,
GemmKPerBlock)),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0, 1, 2>{}));
return out_n_hop_wop_k_grid_desc_final;
#endif
}
else if constexpr(NDimSpatial == 3)
{