mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Grouped Conv Bwd Data out index calculation optimizations (#2917)
* Grouped Conv Bwd Data index calculation optimizations
* fixes
* refactor instances
* gfx12 fixes
* temporarily disable splitK for gfx12
[ROCm/composable_kernel commit: 5477811670]
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -1553,6 +1553,198 @@ struct UnMerge
|
||||
}
|
||||
};
|
||||
|
||||
/**
 * @brief Transformation struct for convolution backward data output indices to GEMM indices.
 *
 * This struct is responsible for mapping the output tensor indices (N, Ho, Wo, K) from the
 * convolution backward data operation to the corresponding indices (K0, M, K1) used in the
 * implicit GEMM computation. It encapsulates the necessary parameters and transformation logic
 * required to efficiently perform the index conversion.
 *
 * All runtime divisions/modulos are replaced with MagicDivision (multiply + shift); the
 * magic multipliers and shifts for the four runtime divisors are precomputed once in the
 * constructor and reused on every index calculation.
 */
struct ConvBwdDataImplicitGemmOutTransform
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    using LowerIndex = MultiIndex<4>; // N, Ho, Wo, K
    using UpperIndex = MultiIndex<3>; // K0, M, K1

    // Output tensor lengths (the lower, 4D side of the transform).
    index_t N_, Ho_, Wo_, K_;
    // Transformed-filter length along X.
    index_t XDot_;
    // Transformed output spatial lengths.
    index_t HTilde_, WTilde_;
    // WTildeSlice_: sliced length of WTilde; TildeSlice_: merged HTildeSlice * WTildeSlice
    // (see make_conv_bwd_data_out_transform, which passes HTildeSlice * WTildeSlice here).
    index_t WTildeSlice_, TildeSlice_;
    // Start offsets of the H/W slices inside HTilde/WTilde.
    index_t IHTildeSliceBegin_, IWTildeSliceBegin_;
    // Embed steps used to scale the Y/X filter indices into Ho/Wo; the factory passes
    // -ConvDilation / GcdStrideDilation for these — confirm sign convention against caller.
    index_t HRatio_, WRatio_;
    // Merged length XDotSlice * K of the inner GEMM-K sub-dimension.
    index_t XDotSlice_K_;
    // Right-padding applied to GEMM M and GEMM K to reach block-size multiples.
    index_t MPad_, KPad_;
    Tuple<index_t, index_t, index_t> up_lengths_; // K0_, MPadded, K1_;

    // Precomputed magic-division constants, one entry per runtime divisor.
    Tuple<index_t, index_t, index_t, index_t>
        low_lengths_magic_divisor_multiplier_; // XDotSlice_K_, K_, TildeSlice_, WTildeSlice_
    Tuple<index_t, index_t, index_t, index_t>
        low_lengths_magic_divisor_shift_; // XDotSlice_K_, K_, TildeSlice_, WTildeSlice_

    __host__ __device__ constexpr ConvBwdDataImplicitGemmOutTransform() = default;

    /// Stores the raw lengths/offsets and precomputes the magic-division constants for
    /// the four divisors used by CalculateLowerIndexN/CalculateLowerIndexK.
    __host__ __device__ constexpr ConvBwdDataImplicitGemmOutTransform(index_t N,
                                                                      index_t Ho,
                                                                      index_t Wo,
                                                                      index_t K,
                                                                      index_t XDot,
                                                                      index_t HTilde,
                                                                      index_t WTilde,
                                                                      index_t WTildeSlice,
                                                                      index_t HWTildeSlice,
                                                                      index_t IHTildeSliceBegin,
                                                                      index_t IWTildeSliceBegin,
                                                                      index_t HRatio,
                                                                      index_t WRatio,
                                                                      index_t XDotSlice_K,
                                                                      index_t K0,
                                                                      index_t MPadded,
                                                                      index_t K1,
                                                                      index_t MPad,
                                                                      index_t KPad)
        : N_{N},
          Ho_{Ho},
          Wo_{Wo},
          K_{K},
          XDot_{XDot},
          HTilde_{HTilde},
          WTilde_{WTilde},
          WTildeSlice_{WTildeSlice},
          TildeSlice_{HWTildeSlice},
          IHTildeSliceBegin_{IHTildeSliceBegin},
          IWTildeSliceBegin_{IWTildeSliceBegin},
          HRatio_{HRatio},
          WRatio_{WRatio},
          XDotSlice_K_{XDotSlice_K},
          MPad_{MPad},
          KPad_{KPad},
          up_lengths_{make_tuple(K0, MPadded, K1)},
          low_lengths_magic_divisor_multiplier_{
              MagicDivision::CalculateMagicMultiplier(XDotSlice_K_),
              MagicDivision::CalculateMagicMultiplier(K_),
              MagicDivision::CalculateMagicMultiplier(TildeSlice_),
              MagicDivision::CalculateMagicMultiplier(WTildeSlice_)},
          low_lengths_magic_divisor_shift_{MagicDivision::CalculateMagicShift(XDotSlice_K_),
                                           MagicDivision::CalculateMagicShift(K_),
                                           MagicDivision::CalculateMagicShift(TildeSlice_),
                                           MagicDivision::CalculateMagicShift(WTildeSlice_)}
    {
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 4; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 3; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    /// Decomposes the GEMM M index (idx_up[I1], a merge of N, HTildeSlice and WTildeSlice)
    /// into its (N, Ho, Wo, K) contribution; the K component is always 0 here.
    /// The final result is this tuple plus the one from CalculateLowerIndexK.
    template <typename UpIdx>
    __host__ __device__ constexpr auto CalculateLowerIndexN(const UpIdx& idx_up) const
    {
        index_t NStep, HStep, WStep;
        // Merge
        // NStep = M_id / TildeSlice_
        NStep = MagicDivision::DoMagicDivision(idx_up[I1],
                                               this->low_lengths_magic_divisor_multiplier_[I2],
                                               this->low_lengths_magic_divisor_shift_[I2]);
        // Remainder of M inside one (HTildeSlice, WTildeSlice) tile.
        HStep = idx_up[I1] - NStep * TildeSlice_;
        // HStep = HStep / WTildeSlice_
        HStep = MagicDivision::DoMagicDivision(HStep,
                                               this->low_lengths_magic_divisor_multiplier_[I3],
                                               this->low_lengths_magic_divisor_shift_[I3]);
        // WStep = remainder of M inside one WTildeSlice row.
        WStep = idx_up[I1] - NStep * TildeSlice_ - HStep * WTildeSlice_;
        // Slice: shift the sliced H/W positions to their origin inside HTilde/WTilde.
        HStep += IHTildeSliceBegin_;
        WStep += IWTildeSliceBegin_;

        return make_tuple(NStep, HStep, WStep, 0);
    }

    /// Decomposes the GEMM K index — first un-merged from (K0, K1) as
    /// K_idx = K0_idx * K1 + K1_idx, then split over (YDotSlice, XDotSlice, K) — into its
    /// (N, Ho, Wo, K) contribution; the N component is always 0 here.
    template <typename UpIdx>
    __host__ __device__ constexpr auto CalculateLowerIndexK(const UpIdx& idx_up) const
    {
        // UnMerge
        // K_idx <- K0_idx * K1 + K1_idx
        index_t K_idx = idx_up[I0] * up_lengths_[I2] + idx_up[I2];
        // Merge
        // YStep = K_idx / XDotSlice_K_
        index_t YStep =
            MagicDivision::DoMagicDivision(K_idx,
                                           this->low_lengths_magic_divisor_multiplier_[I0],
                                           this->low_lengths_magic_divisor_shift_[I0]);
        index_t KStep = K_idx - YStep * XDotSlice_K_;
        // Xstep = KStep / K_
        index_t XStep =
            MagicDivision::DoMagicDivision(KStep,
                                           this->low_lengths_magic_divisor_multiplier_[I1],
                                           this->low_lengths_magic_divisor_shift_[I1]);
        // KStep becomes the remainder: the index along the raw K dimension.
        KStep -= XStep * K_;
        // Embed: scale the Y/X filter steps into Ho/Wo coordinates.
        YStep *= HRatio_;
        XStep *= WRatio_;

        return make_tuple(0, YStep, XStep, KStep);
    }

    /// Full upper -> lower mapping: the N-part and K-part contributions are disjoint
    /// per component (each fills the other's components with 0), so they are summed.
    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& idx_up) const
    {
        idx_low = CalculateLowerIndexN(idx_up) + CalculateLowerIndexK(idx_up);
    }

    /// Non-incremental update: recomputes the lower index from scratch (idx_diff_up is
    /// ignored) and reports the difference against the previous lower index.
    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                              const UpIdxDiff& /* idx_diff_up */,
                                              LowIdx& idx_low,
                                              const UpIdx& idx_up,
                                              Number<Hack>) const
    {
        LowIdx low_old = idx_low;
        idx_low = CalculateLowerIndexN(idx_up) + CalculateLowerIndexK(idx_up);
        idx_diff_low = idx_low - low_old;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return false; }

    // NOTE(review): returns true even though MPad_/KPad_ may be non-zero; callers appear to
    // rely on IsValidUpperIndexMappedToValidLowerIndex() below to mask the padded region —
    // confirm this is the intended contract for padded descriptors.
    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    /// True iff the (K0, M, K1) index lies inside the unpadded region:
    /// M < MPadded - MPad (== MRaw) and the merged K0*K1 index < K0*K1 - KPad (== KRaw).
    template <typename UpIdx>
    __host__ __device__ constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
    {
        // Padding
        index_t K_idx = idx_up[Number<0>{}] * up_lengths_[Number<2>{}] + idx_up[Number<2>{}];
        // NOTE(review): non-const reference bound into const idx_up — relies on operator[]
        // not yielding a const reference here; a const index_t& would be safer. Confirm.
        index_t& M_idx = idx_up[Number<1>{}];

        bool pad_valid = M_idx < up_lengths_[Number<1>{}] - MPad_ &&
                         K_idx < up_lengths_[Number<0>{}] * up_lengths_[Number<2>{}] - KPad_;
        return pad_valid;
    }

    // Lengths are runtime values, so nothing here is known at compile time.
    __host__ __device__ static constexpr bool IsKnownAtCompileTime() { return false; }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("ConvBwdDataImplicitGemmOutTransform, ");
        printf("up_lengths_");
        print_multi_index(up_lengths_);
        printf("}");
    }
};
|
||||
|
||||
template <typename LowerIndex>
|
||||
struct Freeze
|
||||
{
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -94,6 +94,59 @@ __host__ __device__ constexpr auto make_unmerge_transform(
|
||||
return UnMerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
|
||||
}
|
||||
|
||||
/**
 * @brief Factory for ConvBwdDataImplicitGemmOutTransform.
 *
 * Computes the right-padding of the implicit-GEMM M and K dimensions from the block
 * sizes, derives the merged/ratio quantities the transform needs, and forwards
 * everything positionally to the transform's constructor:
 *   - GEMM M raw length = N * HTildeSlice * WTildeSlice, padded up to MPerBlock
 *   - GEMM K raw length = YDotSlice * XDotSlice * K, padded up to GemmKPerBlock
 *   - HWTildeSlice      = HTildeSlice * WTildeSlice
 *   - HRatio / WRatio   = -ConvDilation / GcdStrideDilation (Y/X embed steps into Ho/Wo)
 *   - XDotSlice_K       = XDotSlice * K
 *
 * @note YDot is currently unused ([[maybe_unused]]); YDotSlice enters only through KRaw.
 */
__host__ __device__ constexpr auto make_conv_bwd_data_out_transform(index_t N,
                                                                    index_t Ho,
                                                                    index_t Wo,
                                                                    index_t K,
                                                                    [[maybe_unused]] index_t YDot,
                                                                    index_t XDot,
                                                                    index_t HTilde,
                                                                    index_t WTilde,
                                                                    index_t ConvDilationH,
                                                                    index_t ConvDilationW,
                                                                    index_t HTildeSlice,
                                                                    index_t WTildeSlice,
                                                                    index_t YDotSlice,
                                                                    index_t XDotSlice,
                                                                    index_t IHTildeSliceBegin,
                                                                    index_t IWTildeSliceBegin,
                                                                    index_t GcdStrideDilationH,
                                                                    index_t GcdStrideDilationW,
                                                                    index_t K0,
                                                                    index_t K1,
                                                                    index_t MPerBlock,
                                                                    index_t GemmKPerBlock)
{
    // Calculate padding: round M up to a multiple of MPerBlock, K up to GemmKPerBlock.
    const auto MRaw = N * HTildeSlice * WTildeSlice;
    const auto MPadded = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
    const auto MPad = MPadded - MRaw;

    const auto KRaw = YDotSlice * XDotSlice * K;
    const auto KPadded = math::integer_divide_ceil(KRaw, GemmKPerBlock) * GemmKPerBlock;
    const auto KPad = KPadded - KRaw;

    // Positional order must match the ConvBwdDataImplicitGemmOutTransform constructor:
    // (N, Ho, Wo, K, XDot, HTilde, WTilde, WTildeSlice, HWTildeSlice, IHTildeSliceBegin,
    //  IWTildeSliceBegin, HRatio, WRatio, XDotSlice_K, K0, MPadded, K1, MPad, KPad).
    return ConvBwdDataImplicitGemmOutTransform{N,
                                               Ho,
                                               Wo,
                                               K,
                                               XDot,
                                               HTilde,
                                               WTilde,
                                               WTildeSlice,
                                               HTildeSlice * WTildeSlice,
                                               IHTildeSliceBegin,
                                               IWTildeSliceBegin,
                                               -ConvDilationH / GcdStrideDilationH,
                                               -ConvDilationW / GcdStrideDilationW,
                                               XDotSlice * K,
                                               K0,
                                               MPadded,
                                               K1,
                                               MPad,
                                               KPad};
}
|
||||
|
||||
template <typename LowerIndex>
|
||||
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_idx)
|
||||
{
|
||||
|
||||
@@ -1485,7 +1485,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// gfx11 doesn't support float atomic
|
||||
if(ck::is_gfx11_supported() && arg.k_batch_ > 1)
|
||||
// Todo: Enable splitK for gfx12
|
||||
if((ck::is_gfx12_supported() || ck::is_gfx11_supported()) && arg.k_batch_ > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -13,6 +13,14 @@
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
/**
|
||||
* @brief Enable custom tensor transform for convolution backward data output.
|
||||
*
|
||||
* When set to 1, this macro enables a custom transformation of the output tensor
|
||||
* in convolution backward data operations.
|
||||
*/
|
||||
#define CK_USE_CUSTOM_TENSOR_TRANSFORM_FOR_BWD_DATA_OUT 1
|
||||
|
||||
template <
|
||||
index_t NDimSpatial,
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization,
|
||||
@@ -705,6 +713,12 @@ struct TransformConvBwdDataToGemm_v1
|
||||
|
||||
if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
const index_t K0PerBlock = GemmKPerBlock / AK1;
|
||||
const index_t AK0 = math::integer_divide_ceil(YDotSlice * XDotSlice * K_,
|
||||
AK1 * K0PerBlock * batch_k_) *
|
||||
K0PerBlock;
|
||||
|
||||
#if CK_USE_CUSTOM_TENSOR_TRANSFORM_FOR_BWD_DATA_OUT == 0
|
||||
// A: output tensor
|
||||
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
@@ -762,12 +776,6 @@ struct TransformConvBwdDataToGemm_v1
|
||||
make_tuple(GemmKPerBlock, GemmMPerBlock),
|
||||
Sequence<true, DoPadGemmM>{});
|
||||
|
||||
const index_t K0PerBlock = GemmKPerBlock / AK1;
|
||||
const index_t AK0 =
|
||||
math::integer_divide_ceil(out_gemmk_gemmm_padded_grid_desc.GetLength(I0),
|
||||
AK1 * K0PerBlock * batch_k_) *
|
||||
K0PerBlock;
|
||||
|
||||
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmk_gemmm_padded_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)),
|
||||
@@ -775,8 +783,46 @@ struct TransformConvBwdDataToGemm_v1
|
||||
out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return out_gemmak0_gemmm_gemmak1_grid_desc;
|
||||
#else
|
||||
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Ho_, I0, I0),
|
||||
make_pad_transform(Wo_, I0, I0),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto out_n_hop_wop_k_grid_desc_final = transform_tensor_descriptor(
|
||||
out_n_hop_wop_k_grid_desc,
|
||||
make_tuple(make_conv_bwd_data_out_transform(N_,
|
||||
Ho_,
|
||||
Wo_,
|
||||
K_,
|
||||
YDot_,
|
||||
XDot_,
|
||||
HTilde_,
|
||||
WTilde_,
|
||||
ConvDilationH_,
|
||||
ConvDilationW_,
|
||||
HTildeSlice,
|
||||
WTildeSlice,
|
||||
YDotSlice,
|
||||
XDotSlice,
|
||||
IHTildeSliceBegin,
|
||||
IWTildeSliceBegin,
|
||||
GcdStrideDilationH_,
|
||||
GcdStrideDilationW_,
|
||||
AK0,
|
||||
AK1,
|
||||
GemmMPerBlock,
|
||||
GemmKPerBlock)),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}));
|
||||
|
||||
return out_n_hop_wop_k_grid_desc_final;
|
||||
#endif
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user