Grouped Conv Bwd Data out index calculation optimizations (#2917)

* Grouped Conv Bwd Data index calculation optimizations

* fixes

* refactor instances

* gfx12 fixes

* temporarily disable splitK for gfx12

[ROCm/composable_kernel commit: 5477811670]
This commit is contained in:
Bartłomiej Kocot
2025-09-29 15:59:11 +02:00
committed by GitHub
parent 4bc708f401
commit ef933ee241
17 changed files with 895 additions and 75 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -1553,6 +1553,198 @@ struct UnMerge
}
};
/**
* @brief Transformation struct for convolution backward data output indices to GEMM indices.
*
* This struct is responsible for mapping the output tensor indices (N, Ho, Wo, K) from the
* convolution backward data operation to the corresponding indices (K0, M, K1) used in the
* implicit GEMM computation. It encapsulates the necessary parameters and transformation logic
* required to efficiently perform the index conversion.
*/
struct ConvBwdDataImplicitGemmOutTransform
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using LowerIndex = MultiIndex<4>; // N, Ho, Wo, K
using UpperIndex = MultiIndex<3>; // K0, M, K1
index_t N_, Ho_, Wo_, K_;
index_t XDot_;
index_t HTilde_, WTilde_;
index_t WTildeSlice_, TildeSlice_;
index_t IHTildeSliceBegin_, IWTildeSliceBegin_;
index_t HRatio_, WRatio_;
index_t XDotSlice_K_;
index_t MPad_, KPad_;
Tuple<index_t, index_t, index_t> up_lengths_; // K0_, MPadded, K1_;
Tuple<index_t, index_t, index_t, index_t>
low_lengths_magic_divisor_multiplier_; // XDotSlice_K_, K_, TildeSlice_, WTildeSlice_
Tuple<index_t, index_t, index_t, index_t>
low_lengths_magic_divisor_shift_; // XDotSlice_K_, K_, TildeSlice_, WTildeSlice_
__host__ __device__ constexpr ConvBwdDataImplicitGemmOutTransform() = default;
__host__ __device__ constexpr ConvBwdDataImplicitGemmOutTransform(index_t N,
index_t Ho,
index_t Wo,
index_t K,
index_t XDot,
index_t HTilde,
index_t WTilde,
index_t WTildeSlice,
index_t HWTildeSlice,
index_t IHTildeSliceBegin,
index_t IWTildeSliceBegin,
index_t HRatio,
index_t WRatio,
index_t XDotSlice_K,
index_t K0,
index_t MPadded,
index_t K1,
index_t MPad,
index_t KPad)
: N_{N},
Ho_{Ho},
Wo_{Wo},
K_{K},
XDot_{XDot},
HTilde_{HTilde},
WTilde_{WTilde},
WTildeSlice_{WTildeSlice},
TildeSlice_{HWTildeSlice},
IHTildeSliceBegin_{IHTildeSliceBegin},
IWTildeSliceBegin_{IWTildeSliceBegin},
HRatio_{HRatio},
WRatio_{WRatio},
XDotSlice_K_{XDotSlice_K},
MPad_{MPad},
KPad_{KPad},
up_lengths_{make_tuple(K0, MPadded, K1)},
low_lengths_magic_divisor_multiplier_{
MagicDivision::CalculateMagicMultiplier(XDotSlice_K_),
MagicDivision::CalculateMagicMultiplier(K_),
MagicDivision::CalculateMagicMultiplier(TildeSlice_),
MagicDivision::CalculateMagicMultiplier(WTildeSlice_)},
low_lengths_magic_divisor_shift_{MagicDivision::CalculateMagicShift(XDotSlice_K_),
MagicDivision::CalculateMagicShift(K_),
MagicDivision::CalculateMagicShift(TildeSlice_),
MagicDivision::CalculateMagicShift(WTildeSlice_)}
{
}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 4; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 3; }
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
template <typename UpIdx>
__host__ __device__ constexpr auto CalculateLowerIndexN(const UpIdx& idx_up) const
{
index_t NStep, HStep, WStep;
// Merge
// NStep = M_id / TildeSlice_
NStep = MagicDivision::DoMagicDivision(idx_up[I1],
this->low_lengths_magic_divisor_multiplier_[I2],
this->low_lengths_magic_divisor_shift_[I2]);
HStep = idx_up[I1] - NStep * TildeSlice_;
// HStep = HStep / WTildeSlice_
HStep = MagicDivision::DoMagicDivision(HStep,
this->low_lengths_magic_divisor_multiplier_[I3],
this->low_lengths_magic_divisor_shift_[I3]);
WStep = idx_up[I1] - NStep * TildeSlice_ - HStep * WTildeSlice_;
// Slice
HStep += IHTildeSliceBegin_;
WStep += IWTildeSliceBegin_;
return make_tuple(NStep, HStep, WStep, 0);
}
template <typename UpIdx>
__host__ __device__ constexpr auto CalculateLowerIndexK(const UpIdx& idx_up) const
{
// UnMerge
// K_idx <- K0_idx * K1 + K1_idx
index_t K_idx = idx_up[I0] * up_lengths_[I2] + idx_up[I2];
// Merge
// YStep = K_idx / XDotSlice_K_
index_t YStep =
MagicDivision::DoMagicDivision(K_idx,
this->low_lengths_magic_divisor_multiplier_[I0],
this->low_lengths_magic_divisor_shift_[I0]);
index_t KStep = K_idx - YStep * XDotSlice_K_;
// Xstep = KStep / K_
index_t XStep =
MagicDivision::DoMagicDivision(KStep,
this->low_lengths_magic_divisor_multiplier_[I1],
this->low_lengths_magic_divisor_shift_[I1]);
KStep -= XStep * K_;
// Embed
YStep *= HRatio_;
XStep *= WRatio_;
return make_tuple(0, YStep, XStep, KStep);
}
template <typename LowIdx, typename UpIdx>
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
const UpIdx& idx_up) const
{
idx_low = CalculateLowerIndexN(idx_up) + CalculateLowerIndexK(idx_up);
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& /* idx_diff_up */,
LowIdx& idx_low,
const UpIdx& idx_up,
Number<Hack>) const
{
LowIdx low_old = idx_low;
idx_low = CalculateLowerIndexN(idx_up) + CalculateLowerIndexK(idx_up);
idx_diff_low = idx_low - low_old;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}
template <typename UpIdx>
__host__ __device__ constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
{
// Padding
index_t K_idx = idx_up[Number<0>{}] * up_lengths_[Number<2>{}] + idx_up[Number<2>{}];
index_t& M_idx = idx_up[Number<1>{}];
bool pad_valid = M_idx < up_lengths_[Number<1>{}] - MPad_ &&
K_idx < up_lengths_[Number<0>{}] * up_lengths_[Number<2>{}] - KPad_;
return pad_valid;
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime() { return false; }
__host__ __device__ void Print() const
{
printf("{");
printf("ConvBwdDataImplicitGemmOutTransform, ");
printf("up_lengths_");
print_multi_index(up_lengths_);
printf("}");
}
};
template <typename LowerIndex>
struct Freeze
{

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -94,6 +94,59 @@ __host__ __device__ constexpr auto make_unmerge_transform(
return UnMerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
}
/**
 * @brief Factory for ConvBwdDataImplicitGemmOutTransform.
 *
 * Derives the values the transform stores but the caller does not pass directly:
 *  - GEMM M padding: N * HTildeSlice * WTildeSlice rounded up to a multiple of MPerBlock,
 *  - GEMM K padding: YDotSlice * XDotSlice * K rounded up to a multiple of GemmKPerBlock,
 *  - the H / W embed ratios: -ConvDilation / GcdStrideDilation for each direction,
 *  - the merged divisors HTildeSlice * WTildeSlice and XDotSlice * K.
 *
 * YDot is accepted for signature symmetry but is not used.
 * NOTE(review): K0 * K1 is presumably expected to equal the padded GEMM K extent computed
 * here; this function does not check it - confirm against callers.
 */
__host__ __device__ constexpr auto make_conv_bwd_data_out_transform(index_t N,
                                                                    index_t Ho,
                                                                    index_t Wo,
                                                                    index_t K,
                                                                    [[maybe_unused]] index_t YDot,
                                                                    index_t XDot,
                                                                    index_t HTilde,
                                                                    index_t WTilde,
                                                                    index_t ConvDilationH,
                                                                    index_t ConvDilationW,
                                                                    index_t HTildeSlice,
                                                                    index_t WTildeSlice,
                                                                    index_t YDotSlice,
                                                                    index_t XDotSlice,
                                                                    index_t IHTildeSliceBegin,
                                                                    index_t IWTildeSliceBegin,
                                                                    index_t GcdStrideDilationH,
                                                                    index_t GcdStrideDilationW,
                                                                    index_t K0,
                                                                    index_t K1,
                                                                    index_t MPerBlock,
                                                                    index_t GemmKPerBlock)
{
    // GEMM M extent: unpadded, then rounded up to whole MPerBlock tiles.
    const auto gemm_m_raw = N * HTildeSlice * WTildeSlice;
    const auto gemm_m_padded =
        math::integer_divide_ceil(gemm_m_raw, MPerBlock) * MPerBlock;

    // GEMM K extent: unpadded, then rounded up to whole GemmKPerBlock tiles.
    const auto gemm_k_raw = YDotSlice * XDotSlice * K;
    const auto gemm_k_padded =
        math::integer_divide_ceil(gemm_k_raw, GemmKPerBlock) * GemmKPerBlock;

    // Number of padded trailing positions in each GEMM dimension.
    const auto m_pad = gemm_m_padded - gemm_m_raw;
    const auto k_pad = gemm_k_padded - gemm_k_raw;

    // Merged divisors consumed by the transform's index split.
    const auto hw_tilde_slice = HTildeSlice * WTildeSlice;
    const auto xdot_slice_k   = XDotSlice * K;

    // Embed multipliers for the Y / X components.
    const auto h_ratio = -ConvDilationH / GcdStrideDilationH;
    const auto w_ratio = -ConvDilationW / GcdStrideDilationW;

    return ConvBwdDataImplicitGemmOutTransform{N,
                                               Ho,
                                               Wo,
                                               K,
                                               XDot,
                                               HTilde,
                                               WTilde,
                                               WTildeSlice,
                                               hw_tilde_slice,
                                               IHTildeSliceBegin,
                                               IWTildeSliceBegin,
                                               h_ratio,
                                               w_ratio,
                                               xdot_slice_k,
                                               K0,
                                               gemm_m_padded,
                                               K1,
                                               m_pad,
                                               k_pad};
}
template <typename LowerIndex>
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_idx)
{