Grouped Conv Bwd Data out index calculation optimizations (#2917)

* Grouped Conv Bwd Data index calculation optimizations

* fixes

* refactor instances

* gfx12 fixes

* temporarily disable splitK for gfx12

[ROCm/composable_kernel commit: 5477811670]
This commit is contained in:
Bartłomiej Kocot
2025-09-29 15:59:11 +02:00
committed by GitHub
parent 4bc708f401
commit ef933ee241
17 changed files with 895 additions and 75 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -1553,6 +1553,198 @@ struct UnMerge
}
};
/**
 * @brief Transformation struct for convolution backward data output indices to GEMM indices.
 *
 * This struct is responsible for mapping the output tensor indices (N, Ho, Wo, K) from the
 * convolution backward data operation to the corresponding indices (K0, M, K1) used in the
 * implicit GEMM computation. It encapsulates the necessary parameters and transformation logic
 * required to efficiently perform the index conversion.
 *
 * Divisions on the hot path are performed with precomputed magic-number divisors
 * (MagicDivision) instead of hardware integer division.
 */
struct ConvBwdDataImplicitGemmOutTransform
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    using LowerIndex = MultiIndex<4>; // N, Ho, Wo, K
    using UpperIndex = MultiIndex<3>; // K0, M, K1

    // Output tensor extents.
    index_t N_, Ho_, Wo_, K_;
    index_t XDot_;
    index_t HTilde_, WTilde_;
    // WTildeSlice_ is the sliced W extent; TildeSlice_ is HTildeSlice * WTildeSlice
    // (combined H*W slice size used as the merge divisor for the M dimension).
    index_t WTildeSlice_, TildeSlice_;
    index_t IHTildeSliceBegin_, IWTildeSliceBegin_;
    // Embed coefficients applied to the Y/X steps (negative dilation / gcd ratios).
    index_t HRatio_, WRatio_;
    // XDotSlice * K: merge divisor separating the Y step from the (X, K) remainder.
    index_t XDotSlice_K_;
    // Padding added to reach MPadded / KPadded; used by the validity check.
    index_t MPad_, KPad_;

    Tuple<index_t, index_t, index_t> up_lengths_; // K0_, MPadded, K1_;
    Tuple<index_t, index_t, index_t, index_t>
        low_lengths_magic_divisor_multiplier_; // XDotSlice_K_, K_, TildeSlice_, WTildeSlice_
    Tuple<index_t, index_t, index_t, index_t>
        low_lengths_magic_divisor_shift_; // XDotSlice_K_, K_, TildeSlice_, WTildeSlice_

    __host__ __device__ constexpr ConvBwdDataImplicitGemmOutTransform() = default;

    /**
     * @brief Construct the transform and precompute magic-division constants.
     *
     * @param HWTildeSlice  HTildeSlice * WTildeSlice (stored as TildeSlice_).
     * @param XDotSlice_K   XDotSlice * K (stored as XDotSlice_K_).
     * @param HRatio,WRatio Embed coefficients for the Y/X steps.
     * @param MPad,KPad     Right-padding on the GEMM M and K dimensions.
     */
    __host__ __device__ constexpr ConvBwdDataImplicitGemmOutTransform(index_t N,
                                                                      index_t Ho,
                                                                      index_t Wo,
                                                                      index_t K,
                                                                      index_t XDot,
                                                                      index_t HTilde,
                                                                      index_t WTilde,
                                                                      index_t WTildeSlice,
                                                                      index_t HWTildeSlice,
                                                                      index_t IHTildeSliceBegin,
                                                                      index_t IWTildeSliceBegin,
                                                                      index_t HRatio,
                                                                      index_t WRatio,
                                                                      index_t XDotSlice_K,
                                                                      index_t K0,
                                                                      index_t MPadded,
                                                                      index_t K1,
                                                                      index_t MPad,
                                                                      index_t KPad)
        : N_{N},
          Ho_{Ho},
          Wo_{Wo},
          K_{K},
          XDot_{XDot},
          HTilde_{HTilde},
          WTilde_{WTilde},
          WTildeSlice_{WTildeSlice},
          TildeSlice_{HWTildeSlice},
          IHTildeSliceBegin_{IHTildeSliceBegin},
          IWTildeSliceBegin_{IWTildeSliceBegin},
          HRatio_{HRatio},
          WRatio_{WRatio},
          XDotSlice_K_{XDotSlice_K},
          MPad_{MPad},
          KPad_{KPad},
          up_lengths_{make_tuple(K0, MPadded, K1)},
          low_lengths_magic_divisor_multiplier_{
              MagicDivision::CalculateMagicMultiplier(XDotSlice_K_),
              MagicDivision::CalculateMagicMultiplier(K_),
              MagicDivision::CalculateMagicMultiplier(TildeSlice_),
              MagicDivision::CalculateMagicMultiplier(WTildeSlice_)},
          low_lengths_magic_divisor_shift_{MagicDivision::CalculateMagicShift(XDotSlice_K_),
                                           MagicDivision::CalculateMagicShift(K_),
                                           MagicDivision::CalculateMagicShift(TildeSlice_),
                                           MagicDivision::CalculateMagicShift(WTildeSlice_)}
    {
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 4; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 3; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    /**
     * @brief Lower-index contribution derived from the GEMM M index (idx_up[I1]).
     *
     * Decomposes M into (NStep, HStep, WStep) via merge with divisors TildeSlice_ and
     * WTildeSlice_, then offsets H/W by the slice begins. The K component is 0 here;
     * it is supplied by CalculateLowerIndexK and the two results are summed.
     */
    template <typename UpIdx>
    __host__ __device__ constexpr auto CalculateLowerIndexN(const UpIdx& idx_up) const
    {
        index_t NStep, HStep, WStep;
        // Merge
        // NStep = M_id / TildeSlice_
        NStep = MagicDivision::DoMagicDivision(idx_up[I1],
                                               this->low_lengths_magic_divisor_multiplier_[I2],
                                               this->low_lengths_magic_divisor_shift_[I2]);
        HStep = idx_up[I1] - NStep * TildeSlice_;
        // HStep = HStep / WTildeSlice_
        HStep = MagicDivision::DoMagicDivision(HStep,
                                               this->low_lengths_magic_divisor_multiplier_[I3],
                                               this->low_lengths_magic_divisor_shift_[I3]);
        WStep = idx_up[I1] - NStep * TildeSlice_ - HStep * WTildeSlice_;
        // Slice
        HStep += IHTildeSliceBegin_;
        WStep += IWTildeSliceBegin_;
        return make_tuple(NStep, HStep, WStep, 0);
    }

    /**
     * @brief Lower-index contribution derived from the GEMM K indices (idx_up[I0], idx_up[I2]).
     *
     * Rebuilds the flat K index from (K0, K1), splits it into (YStep, XStep, KStep) via
     * merge with divisors XDotSlice_K_ and K_, then embeds Y/X with HRatio_/WRatio_.
     * The N component is 0 here; see CalculateLowerIndexN.
     */
    template <typename UpIdx>
    __host__ __device__ constexpr auto CalculateLowerIndexK(const UpIdx& idx_up) const
    {
        // UnMerge
        // K_idx <- K0_idx * K1 + K1_idx
        index_t K_idx = idx_up[I0] * up_lengths_[I2] + idx_up[I2];
        // Merge
        // YStep = K_idx / XDotSlice_K_
        index_t YStep =
            MagicDivision::DoMagicDivision(K_idx,
                                           this->low_lengths_magic_divisor_multiplier_[I0],
                                           this->low_lengths_magic_divisor_shift_[I0]);
        index_t KStep = K_idx - YStep * XDotSlice_K_;
        // Xstep = KStep / K_
        index_t XStep =
            MagicDivision::DoMagicDivision(KStep,
                                           this->low_lengths_magic_divisor_multiplier_[I1],
                                           this->low_lengths_magic_divisor_shift_[I1]);
        KStep -= XStep * K_;
        // Embed
        YStep *= HRatio_;
        XStep *= WRatio_;
        return make_tuple(0, YStep, XStep, KStep);
    }

    // Full lower index = N/H/W contribution + Y/X/K contribution (element-wise sum).
    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& idx_up) const
    {
        idx_low = CalculateLowerIndexN(idx_up) + CalculateLowerIndexK(idx_up);
    }

    // Non-linear transform: the lower index is recomputed from the absolute upper
    // index rather than incrementally updated from idx_diff_up (which is unused).
    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                              const UpIdxDiff& /* idx_diff_up */,
                                              LowIdx& idx_low,
                                              const UpIdx& idx_up,
                                              Number<Hack>) const
    {
        LowIdx low_old = idx_low;
        idx_low        = CalculateLowerIndexN(idx_up) + CalculateLowerIndexK(idx_up);
        idx_diff_low   = idx_low - low_old;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return false; }

    // NOTE(review): this returns true even though IsValidUpperIndexMappedToValidLowerIndex
    // below performs a non-trivial padding check — confirm the framework still invokes the
    // per-index check when this flag is true, otherwise padded indices are never rejected.
    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    /**
     * @brief Reject upper indices that fall into the M or K padding region.
     */
    template <typename UpIdx>
    __host__ __device__ constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
    {
        // Padding
        const index_t K_idx = idx_up[I0] * up_lengths_[I2] + idx_up[I2];
        // Fix: copy the M index by const value instead of binding a mutable reference
        // (`index_t&`) into the const-qualified upper-index tuple.
        const index_t M_idx = idx_up[I1];
        const bool pad_valid = M_idx < up_lengths_[I1] - MPad_ &&
                               K_idx < up_lengths_[I0] * up_lengths_[I2] - KPad_;
        return pad_valid;
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime() { return false; }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("ConvBwdDataImplicitGemmOutTransform, ");
        printf("up_lengths_");
        print_multi_index(up_lengths_);
        printf("}");
    }
};
template <typename LowerIndex>
struct Freeze
{

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -94,6 +94,59 @@ __host__ __device__ constexpr auto make_unmerge_transform(
return UnMerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
}
/**
 * @brief Factory for ConvBwdDataImplicitGemmOutTransform.
 *
 * Computes the GEMM M/K padding from the sliced problem extents and forwards all
 * parameters (plus the derived slice/ratio products) to the transform's constructor.
 * YDot is accepted for signature symmetry but not needed by the transform.
 */
__host__ __device__ constexpr auto make_conv_bwd_data_out_transform(index_t N,
                                                                    index_t Ho,
                                                                    index_t Wo,
                                                                    index_t K,
                                                                    [[maybe_unused]] index_t YDot,
                                                                    index_t XDot,
                                                                    index_t HTilde,
                                                                    index_t WTilde,
                                                                    index_t ConvDilationH,
                                                                    index_t ConvDilationW,
                                                                    index_t HTildeSlice,
                                                                    index_t WTildeSlice,
                                                                    index_t YDotSlice,
                                                                    index_t XDotSlice,
                                                                    index_t IHTildeSliceBegin,
                                                                    index_t IWTildeSliceBegin,
                                                                    index_t GcdStrideDilationH,
                                                                    index_t GcdStrideDilationW,
                                                                    index_t K0,
                                                                    index_t K1,
                                                                    index_t MPerBlock,
                                                                    index_t GemmKPerBlock)
{
    // Round a raw extent up to the next multiple of the block tile size.
    auto round_up_to = [](index_t raw, index_t tile) {
        return math::integer_divide_ceil(raw, tile) * tile;
    };

    // GEMM M = N * HTildeSlice * WTildeSlice, padded to a multiple of MPerBlock.
    const index_t gemm_m_raw = N * HTildeSlice * WTildeSlice;
    const index_t gemm_m_pad = round_up_to(gemm_m_raw, MPerBlock) - gemm_m_raw;

    // GEMM K = YDotSlice * XDotSlice * K, padded to a multiple of GemmKPerBlock.
    const index_t gemm_k_raw = YDotSlice * XDotSlice * K;
    const index_t gemm_k_pad = round_up_to(gemm_k_raw, GemmKPerBlock) - gemm_k_raw;

    // Embed coefficients applied to the Y/X steps inside the transform.
    const index_t h_ratio = -ConvDilationH / GcdStrideDilationH;
    const index_t w_ratio = -ConvDilationW / GcdStrideDilationW;

    return ConvBwdDataImplicitGemmOutTransform{N,
                                               Ho,
                                               Wo,
                                               K,
                                               XDot,
                                               HTilde,
                                               WTilde,
                                               WTildeSlice,
                                               HTildeSlice * WTildeSlice,
                                               IHTildeSliceBegin,
                                               IWTildeSliceBegin,
                                               h_ratio,
                                               w_ratio,
                                               XDotSlice * K,
                                               K0,
                                               gemm_m_raw + gemm_m_pad,
                                               K1,
                                               gemm_m_pad,
                                               gemm_k_pad};
}
template <typename LowerIndex>
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_idx)
{

View File

@@ -1485,7 +1485,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
static bool IsSupportedArgument(const Argument& arg)
{
// gfx11 doesn't support float atomic
if(ck::is_gfx11_supported() && arg.k_batch_ > 1)
// Todo: Enable splitK for gfx12
if((ck::is_gfx12_supported() || ck::is_gfx11_supported()) && arg.k_batch_ > 1)
{
return false;
}

View File

@@ -13,6 +13,14 @@
namespace ck {
namespace tensor_operation {
/**
* @brief Enable custom tensor transform for convolution backward data output.
*
* When set to 1, this macro enables a custom transformation of the output tensor
* in convolution backward data operations.
*/
#define CK_USE_CUSTOM_TENSOR_TRANSFORM_FOR_BWD_DATA_OUT 1
template <
index_t NDimSpatial,
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization,
@@ -705,6 +713,12 @@ struct TransformConvBwdDataToGemm_v1
if constexpr(NDimSpatial == 2)
{
const index_t K0PerBlock = GemmKPerBlock / AK1;
const index_t AK0 = math::integer_divide_ceil(YDotSlice * XDotSlice * K_,
AK1 * K0PerBlock * batch_k_) *
K0PerBlock;
#if CK_USE_CUSTOM_TENSOR_TRANSFORM_FOR_BWD_DATA_OUT == 0
// A: output tensor
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_grid_desc,
@@ -762,12 +776,6 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(GemmKPerBlock, GemmMPerBlock),
Sequence<true, DoPadGemmM>{});
const index_t K0PerBlock = GemmKPerBlock / AK1;
const index_t AK0 =
math::integer_divide_ceil(out_gemmk_gemmm_padded_grid_desc.GetLength(I0),
AK1 * K0PerBlock * batch_k_) *
K0PerBlock;
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
out_gemmk_gemmm_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)),
@@ -775,8 +783,46 @@ struct TransformConvBwdDataToGemm_v1
out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return out_gemmak0_gemmm_gemmak1_grid_desc;
#else
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Ho_, I0, I0),
make_pad_transform(Wo_, I0, I0),
make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto out_n_hop_wop_k_grid_desc_final = transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc,
make_tuple(make_conv_bwd_data_out_transform(N_,
Ho_,
Wo_,
K_,
YDot_,
XDot_,
HTilde_,
WTilde_,
ConvDilationH_,
ConvDilationW_,
HTildeSlice,
WTildeSlice,
YDotSlice,
XDotSlice,
IHTildeSliceBegin,
IWTildeSliceBegin,
GcdStrideDilationH_,
GcdStrideDilationW_,
AK0,
AK1,
GemmMPerBlock,
GemmKPerBlock)),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0, 1, 2>{}));
return out_n_hop_wop_k_grid_desc_final;
#endif
}
else if constexpr(NDimSpatial == 3)
{