Grouped Conv Bwd Data out index calculation optimizations (#2917)

* Grouped Conv Bwd Data index calculation optimizations

* fixes

* refactor instances

* gfx12 fixes

* temporary disable splitK for gfx12

[ROCm/composable_kernel commit: 5477811670]
This commit is contained in:
Bartłomiej Kocot
2025-09-29 15:59:11 +02:00
committed by GitHub
parent 0da205766b
commit 1d9ec09cf2
17 changed files with 895 additions and 75 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -1553,6 +1553,198 @@ struct UnMerge
}
};
/**
* @brief Transformation struct for convolution backward data output indices to GEMM indices.
*
* This struct is responsible for mapping the output tensor indices (N, Ho, Wo, K) from the
* convolution backward data operation to the corresponding indices (K0, M, K1) used in the
* implicit GEMM computation. It encapsulates the necessary parameters and transformation logic
* required to efficiently perform the index conversion.
*/
struct ConvBwdDataImplicitGemmOutTransform
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using LowerIndex = MultiIndex<4>; // N, Ho, Wo, K
using UpperIndex = MultiIndex<3>; // K0, M, K1
index_t N_, Ho_, Wo_, K_;
index_t XDot_;
index_t HTilde_, WTilde_;
index_t WTildeSlice_, TildeSlice_;
index_t IHTildeSliceBegin_, IWTildeSliceBegin_;
index_t HRatio_, WRatio_;
index_t XDotSlice_K_;
index_t MPad_, KPad_;
Tuple<index_t, index_t, index_t> up_lengths_; // K0_, MPadded, K1_;
Tuple<index_t, index_t, index_t, index_t>
low_lengths_magic_divisor_multiplier_; // XDotSlice_K_, K_, TildeSlice_, WTildeSlice_
Tuple<index_t, index_t, index_t, index_t>
low_lengths_magic_divisor_shift_; // XDotSlice_K_, K_, TildeSlice_, WTildeSlice_
__host__ __device__ constexpr ConvBwdDataImplicitGemmOutTransform() = default;
__host__ __device__ constexpr ConvBwdDataImplicitGemmOutTransform(index_t N,
index_t Ho,
index_t Wo,
index_t K,
index_t XDot,
index_t HTilde,
index_t WTilde,
index_t WTildeSlice,
index_t HWTildeSlice,
index_t IHTildeSliceBegin,
index_t IWTildeSliceBegin,
index_t HRatio,
index_t WRatio,
index_t XDotSlice_K,
index_t K0,
index_t MPadded,
index_t K1,
index_t MPad,
index_t KPad)
: N_{N},
Ho_{Ho},
Wo_{Wo},
K_{K},
XDot_{XDot},
HTilde_{HTilde},
WTilde_{WTilde},
WTildeSlice_{WTildeSlice},
TildeSlice_{HWTildeSlice},
IHTildeSliceBegin_{IHTildeSliceBegin},
IWTildeSliceBegin_{IWTildeSliceBegin},
HRatio_{HRatio},
WRatio_{WRatio},
XDotSlice_K_{XDotSlice_K},
MPad_{MPad},
KPad_{KPad},
up_lengths_{make_tuple(K0, MPadded, K1)},
low_lengths_magic_divisor_multiplier_{
MagicDivision::CalculateMagicMultiplier(XDotSlice_K_),
MagicDivision::CalculateMagicMultiplier(K_),
MagicDivision::CalculateMagicMultiplier(TildeSlice_),
MagicDivision::CalculateMagicMultiplier(WTildeSlice_)},
low_lengths_magic_divisor_shift_{MagicDivision::CalculateMagicShift(XDotSlice_K_),
MagicDivision::CalculateMagicShift(K_),
MagicDivision::CalculateMagicShift(TildeSlice_),
MagicDivision::CalculateMagicShift(WTildeSlice_)}
{
}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 4; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 3; }
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
template <typename UpIdx>
__host__ __device__ constexpr auto CalculateLowerIndexN(const UpIdx& idx_up) const
{
index_t NStep, HStep, WStep;
// Merge
// NStep = M_id / TildeSlice_
NStep = MagicDivision::DoMagicDivision(idx_up[I1],
this->low_lengths_magic_divisor_multiplier_[I2],
this->low_lengths_magic_divisor_shift_[I2]);
HStep = idx_up[I1] - NStep * TildeSlice_;
// HStep = HStep / WTildeSlice_
HStep = MagicDivision::DoMagicDivision(HStep,
this->low_lengths_magic_divisor_multiplier_[I3],
this->low_lengths_magic_divisor_shift_[I3]);
WStep = idx_up[I1] - NStep * TildeSlice_ - HStep * WTildeSlice_;
// Slice
HStep += IHTildeSliceBegin_;
WStep += IWTildeSliceBegin_;
return make_tuple(NStep, HStep, WStep, 0);
}
template <typename UpIdx>
__host__ __device__ constexpr auto CalculateLowerIndexK(const UpIdx& idx_up) const
{
// UnMerge
// K_idx <- K0_idx * K1 + K1_idx
index_t K_idx = idx_up[I0] * up_lengths_[I2] + idx_up[I2];
// Merge
// YStep = K_idx / XDotSlice_K_
index_t YStep =
MagicDivision::DoMagicDivision(K_idx,
this->low_lengths_magic_divisor_multiplier_[I0],
this->low_lengths_magic_divisor_shift_[I0]);
index_t KStep = K_idx - YStep * XDotSlice_K_;
// Xstep = KStep / K_
index_t XStep =
MagicDivision::DoMagicDivision(KStep,
this->low_lengths_magic_divisor_multiplier_[I1],
this->low_lengths_magic_divisor_shift_[I1]);
KStep -= XStep * K_;
// Embed
YStep *= HRatio_;
XStep *= WRatio_;
return make_tuple(0, YStep, XStep, KStep);
}
template <typename LowIdx, typename UpIdx>
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
const UpIdx& idx_up) const
{
idx_low = CalculateLowerIndexN(idx_up) + CalculateLowerIndexK(idx_up);
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& /* idx_diff_up */,
LowIdx& idx_low,
const UpIdx& idx_up,
Number<Hack>) const
{
LowIdx low_old = idx_low;
idx_low = CalculateLowerIndexN(idx_up) + CalculateLowerIndexK(idx_up);
idx_diff_low = idx_low - low_old;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}
template <typename UpIdx>
__host__ __device__ constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
{
// Padding
index_t K_idx = idx_up[Number<0>{}] * up_lengths_[Number<2>{}] + idx_up[Number<2>{}];
index_t& M_idx = idx_up[Number<1>{}];
bool pad_valid = M_idx < up_lengths_[Number<1>{}] - MPad_ &&
K_idx < up_lengths_[Number<0>{}] * up_lengths_[Number<2>{}] - KPad_;
return pad_valid;
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime() { return false; }
__host__ __device__ void Print() const
{
printf("{");
printf("ConvBwdDataImplicitGemmOutTransform, ");
printf("up_lengths_");
print_multi_index(up_lengths_);
printf("}");
}
};
template <typename LowerIndex>
struct Freeze
{

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -94,6 +94,59 @@ __host__ __device__ constexpr auto make_unmerge_transform(
return UnMerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
}
__host__ __device__ constexpr auto make_conv_bwd_data_out_transform(index_t N,
index_t Ho,
index_t Wo,
index_t K,
[[maybe_unused]] index_t YDot,
index_t XDot,
index_t HTilde,
index_t WTilde,
index_t ConvDilationH,
index_t ConvDilationW,
index_t HTildeSlice,
index_t WTildeSlice,
index_t YDotSlice,
index_t XDotSlice,
index_t IHTildeSliceBegin,
index_t IWTildeSliceBegin,
index_t GcdStrideDilationH,
index_t GcdStrideDilationW,
index_t K0,
index_t K1,
index_t MPerBlock,
index_t GemmKPerBlock)
{
// Calculate padding
const auto MRaw = N * HTildeSlice * WTildeSlice;
const auto MPadded = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto MPad = MPadded - MRaw;
const auto KRaw = YDotSlice * XDotSlice * K;
const auto KPadded = math::integer_divide_ceil(KRaw, GemmKPerBlock) * GemmKPerBlock;
const auto KPad = KPadded - KRaw;
return ConvBwdDataImplicitGemmOutTransform{N,
Ho,
Wo,
K,
XDot,
HTilde,
WTilde,
WTildeSlice,
HTildeSlice * WTildeSlice,
IHTildeSliceBegin,
IWTildeSliceBegin,
-ConvDilationH / GcdStrideDilationH,
-ConvDilationW / GcdStrideDilationW,
XDotSlice * K,
K0,
MPadded,
K1,
MPad,
KPad};
}
template <typename LowerIndex>
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_idx)
{

View File

@@ -1485,7 +1485,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
static bool IsSupportedArgument(const Argument& arg)
{
// gfx11 doesn't support float atomic
if(ck::is_gfx11_supported() && arg.k_batch_ > 1)
// Todo: Enable splitK for gfx12
if((ck::is_gfx12_supported() || ck::is_gfx11_supported()) && arg.k_batch_ > 1)
{
return false;
}

View File

@@ -13,6 +13,14 @@
namespace ck {
namespace tensor_operation {
/**
* @brief Enable custom tensor transform for convolution backward data output.
*
* When set to 1, this macro enables a custom transformation of the output tensor
* in convolution backward data operations.
*/
#define CK_USE_CUSTOM_TENSOR_TRANSFORM_FOR_BWD_DATA_OUT 1
template <
index_t NDimSpatial,
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization,
@@ -705,6 +713,12 @@ struct TransformConvBwdDataToGemm_v1
if constexpr(NDimSpatial == 2)
{
const index_t K0PerBlock = GemmKPerBlock / AK1;
const index_t AK0 = math::integer_divide_ceil(YDotSlice * XDotSlice * K_,
AK1 * K0PerBlock * batch_k_) *
K0PerBlock;
#if CK_USE_CUSTOM_TENSOR_TRANSFORM_FOR_BWD_DATA_OUT == 0
// A: output tensor
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_grid_desc,
@@ -762,12 +776,6 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(GemmKPerBlock, GemmMPerBlock),
Sequence<true, DoPadGemmM>{});
const index_t K0PerBlock = GemmKPerBlock / AK1;
const index_t AK0 =
math::integer_divide_ceil(out_gemmk_gemmm_padded_grid_desc.GetLength(I0),
AK1 * K0PerBlock * batch_k_) *
K0PerBlock;
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
out_gemmk_gemmm_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)),
@@ -775,8 +783,46 @@ struct TransformConvBwdDataToGemm_v1
out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return out_gemmak0_gemmm_gemmak1_grid_desc;
#else
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Ho_, I0, I0),
make_pad_transform(Wo_, I0, I0),
make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto out_n_hop_wop_k_grid_desc_final = transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc,
make_tuple(make_conv_bwd_data_out_transform(N_,
Ho_,
Wo_,
K_,
YDot_,
XDot_,
HTilde_,
WTilde_,
ConvDilationH_,
ConvDilationW_,
HTildeSlice,
WTildeSlice,
YDotSlice,
XDotSlice,
IHTildeSliceBegin,
IWTildeSliceBegin,
GcdStrideDilationH_,
GcdStrideDilationW_,
AK0,
AK1,
GemmMPerBlock,
GemmKPerBlock)),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0, 1, 2>{}));
return out_n_hop_wop_k_grid_desc_final;
#endif
}
else if constexpr(NDimSpatial == 3)
{

View File

@@ -76,6 +76,47 @@ using device_grouped_conv_bwd_data_xdl_f16_16_16_instances =
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_f16_optimized_loads_instances =
std::tuple<
// clang-format off
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// A K1 one access for each thread per load
// 32x32
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 64, 16, 16, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 4, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 64, 16, 16, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 64, 16, 16, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>,
// 16x16
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 2, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 2, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
@@ -178,6 +219,48 @@ using device_grouped_conv_bwd_data_xdl_bf16_16_16_instances = std::tuple<
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_bf16_optimized_loads_instances = std::tuple<
// clang-format off
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// A K1 one access for each thread per load
// 32x32
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 64, 16, 16, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 4, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 64, 16, 16, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 64, 16, 16, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>,
// 16x16
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 2, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 2, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>
// clang-format on
>;
@@ -257,6 +340,41 @@ using device_grouped_conv_bwd_data_xdl_f32_16_16_instances =
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_f32_optimized_loads_instances =
std::tuple<
// clang-format off
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// A K1 one access for each thread per load
// 32x32
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 4, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>,
// 16x16
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 2, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,

View File

@@ -111,6 +111,8 @@ struct DeviceOperationInstanceFactory<
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_16_16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_optimized_loads_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
@@ -121,6 +123,8 @@ struct DeviceOperationInstanceFactory<
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_16_16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
@@ -132,6 +136,8 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_16_16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_optimized_loads_instances(
op_ptrs);
}
#endif
}
@@ -251,6 +257,8 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_16_16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_optimized_loads_instances(
op_ptrs);
}
#endif
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
@@ -271,6 +279,8 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_16_16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
@@ -282,6 +292,8 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_optimized_loads_instances(
op_ptrs);
}
#endif
}

View File

@@ -83,6 +83,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_16_16_instance
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_optimized_loads_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
@@ -112,6 +126,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_16_16_instance
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
@@ -141,6 +169,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_16_16_instanc
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_optimized_loads_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
@@ -393,6 +435,20 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_16_16_insta
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_optimized_loads_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
@@ -422,6 +478,20 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_16_16_insta
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
@@ -451,6 +521,20 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_inst
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_optimized_loads_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(

View File

@@ -10,6 +10,9 @@ add_instance_library(
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_16_16_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_16_16_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_16_16_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_optimized_loads_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_optimized_loads_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_optimized_loads_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp

View File

@@ -0,0 +1,49 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_optimized_loads_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_bf16_optimized_loads_instances<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(instances,
device_grouped_conv_bwd_data_xdl_bf16_optimized_loads_instances<
2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,49 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_optimized_loads_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f16_optimized_loads_instances<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(instances,
device_grouped_conv_bwd_data_xdl_f16_optimized_loads_instances<
2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,49 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f32_optimized_loads_instances<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(instances,
device_grouped_conv_bwd_data_xdl_f32_optimized_loads_instances<
2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -9,6 +9,9 @@ set(GROUPED_CONV3D_BWD_DATA
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16_16_instance.cpp
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16_16_instance.cpp
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16_16_instance.cpp
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_optimized_loads_instance.cpp
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_optimized_loads_instance.cpp
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_optimized_loads_instance.cpp
xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instance.cpp
xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instance.cpp
xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp

View File

@@ -0,0 +1,49 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_optimized_loads_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_bf16_optimized_loads_instances<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(instances,
device_grouped_conv_bwd_data_xdl_bf16_optimized_loads_instances<
3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,49 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_optimized_loads_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f16_optimized_loads_instances<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(instances,
device_grouped_conv_bwd_data_xdl_f16_optimized_loads_instances<
3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,49 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f32_optimized_loads_instances<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(instances,
device_grouped_conv_bwd_data_xdl_f32_optimized_loads_instances<
3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -185,11 +185,17 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
// Use higher threshold
rtol = std::max(rtol, rtol_split_k);
atol = std::max(atol, atol_split_k);
pass &= ck::utils::check_err(
in_device, in_host, "Error: Incorrect results!", rtol, atol);
std::cout << "Relative error threshold: " << rtol
<< " Absolute error threshold: " << atol << std::endl;
if(split_k_for_run > 1)
{
pass &= ck::utils::check_err(
in_device, in_host, "Error: Incorrect results!", rtol, atol);
std::cout << "Relative error threshold: " << rtol
<< " Absolute error threshold: " << atol << std::endl;
}
else
{
pass &= ck::utils::check_err(in_device, in_host, "Error: Incorrect results!");
}
if(do_log)
{

View File

@@ -10,7 +10,7 @@ import subprocess
def init_const_args(args):
args.ck_profiler_cmd = '../build/bin/ckProfiler'
args.ck_profiler_cmd = "../build/bin/ckProfiler"
# use decimal values
args.init_method = 2
# don't print tensor values
@@ -27,52 +27,62 @@ def run_ck_profiler_cmd(cmd):
def parse_layouts(args):
if args.in_layout == "NCW" or args.in_layout == "NCHW" or \
args.in_layout == "NCDHW":
if args.in_layout == "NCW" or args.in_layout == "NCHW" or args.in_layout == "NCDHW":
if args.ck_profier_op == "grouped_conv_bwd_weight":
args.layout = 4
elif args.ck_profier_op == "grouped_conv_fwd" or \
args.ck_profier_op == "grouped_conv_bwd_data":
elif (
args.ck_profier_op == "grouped_conv_fwd"
or args.ck_profier_op == "grouped_conv_bwd_data"
):
args.layout = 3
else:
print('Not supported layout for this op')
print("Not supported layout for this op")
exit(1)
elif args.in_layout == "NWC" or args.in_layout == "NHWC" or \
args.in_layout == "NDHWC":
elif (
args.in_layout == "NWC" or args.in_layout == "NHWC" or args.in_layout == "NDHWC"
):
if args.ck_profier_op == "grouped_conv_bwd_weight":
args.layout = 2
elif args.ck_profier_op == "grouped_conv_bwd_data" or \
args.ck_profier_op == "grouped_conv_fwd":
elif (
args.ck_profier_op == "grouped_conv_bwd_data"
or args.ck_profier_op == "grouped_conv_fwd"
):
args.layout = 1
else:
print('Not supported layout for this op')
print("Not supported layout for this op")
exit(1)
def parse_data_type(args):
if args.data_type == "fp32":
if args.ck_profier_op == "grouped_conv_bwd_weight" or \
args.ck_profier_op == "grouped_conv_bwd_data" or \
args.ck_profier_op == "grouped_conv_fwd":
if (
args.ck_profier_op == "grouped_conv_bwd_weight"
or args.ck_profier_op == "grouped_conv_bwd_data"
or args.ck_profier_op == "grouped_conv_fwd"
):
args.data_type = 0
if args.data_type == "fp16":
if args.ck_profier_op == "grouped_conv_bwd_weight" or \
args.ck_profier_op == "grouped_conv_bwd_data" or \
args.ck_profier_op == "grouped_conv_fwd":
if (
args.ck_profier_op == "grouped_conv_bwd_weight"
or args.ck_profier_op == "grouped_conv_bwd_data"
or args.ck_profier_op == "grouped_conv_fwd"
):
args.data_type = 1
if args.data_type == "int8":
if args.ck_profier_op == "grouped_conv_bwd_weight":
args.data_type = 4
if args.ck_profier_op == "grouped_conv_bwd_data":
print('Not supported data type for grouped_conv_bwd_data')
print("Not supported data type for grouped_conv_bwd_data")
exit(1)
if args.ck_profier_op == "grouped_conv_fwd":
args.data_type = 3
if args.data_type == "bfp16":
if args.ck_profier_op == "grouped_conv_bwd_weight":
args.data_type = 5
if args.ck_profier_op == "grouped_conv_bwd_data" or \
args.ck_profier_op == "grouped_conv_fwd":
if (
args.ck_profier_op == "grouped_conv_bwd_data"
or args.ck_profier_op == "grouped_conv_fwd"
):
args.data_type = 2
@@ -93,13 +103,11 @@ def add_conv_params_to_cmd(args, cmd):
cmd += [str(args.in_d), str(args.in_h), str(args.in_w)]
cmd += [str(args.conv_stride_d), str(args.conv_stride_h)]
cmd += [str(args.conv_stride_w)]
cmd += [str(args.dilation_d),
str(args.dilation_h),
str(args.dilation_w)]
cmd += [str(args.dilation_d), str(args.dilation_h), str(args.dilation_w)]
cmd += [str(args.pad_d), str(args.pad_h), str(args.pad_w)]
cmd += [str(args.pad_d), str(args.pad_h), str(args.pad_w)]
else:
print('Not supported spatial dim (supported: 1, 2, 3)')
print("Not supported spatial dim (supported: 1, 2, 3)")
exit(1)
@@ -147,7 +155,7 @@ def run_ck_grouped_conv_bwd_weight(args):
parse_data_type(args)
parse_layouts(args)
# Test all split K value from the list {1, 2, 4, 8, 32, 64, 128}
args.split_k_value = -1
args.split_k_value = "all"
cmd = [str(args.ck_profiler_cmd), str(args.ck_profier_op)]
cmd += [str(args.data_type), str(args.layout)]
@@ -161,23 +169,23 @@ def run_ck_grouped_conv_bwd_weight(args):
cmd += [str(args.split_k_value)]
run_ck_profiler_cmd(cmd)
# Get name of miopen driver, remove it from unknown
def process_miopen_driver_name(args, unknown):
if "convint8" in unknown:
args.data_type = 'int8'
args.data_type = "int8"
unknown.remove("convint8")
elif "convbfp16" in unknown:
args.data_type = 'bfp16'
args.data_type = "bfp16"
unknown.remove("convbfp16")
elif "convfp16" in unknown:
args.data_type = 'fp16'
args.data_type = "fp16"
unknown.remove("convfp16")
elif "conv" in unknown:
args.data_type = 'fp32'
args.data_type = "fp32"
unknown.remove("conv")
else:
print('Not supported driver (supported: conv, convfp16, convint8,'
' convbfp16).')
print("Not supported driver (supported: conv, convfp16, convint8, convbfp16).")
exit(1)
@@ -199,11 +207,11 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="converter",
description="Convert miopen driver command to ck Profiler"
"\nExample: python3 "
"../script/convert_miopen_driver_to_profiler.py "
"/opt/rocm/bin/MIOpenDriver conv -n 32 -c 64 -H 28 -W 28 "
"-k 64 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -m conv -g "
"32 -F 1 -t 1",
"\nExample: python3 "
"../script/convert_miopen_driver_to_profiler.py "
"/opt/rocm/bin/MIOpenDriver conv -n 32 -c 64 -H 28 -W 28 "
"-k 64 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -m conv -g "
"32 -F 1 -t 1",
)
parser.add_argument(
"-in_layout",
@@ -213,7 +221,7 @@ if __name__ == "__main__":
default="NCHW",
type=str,
required=False,
help="Input Layout (Default=NCHW for 2d conv, NCDHW for 3d conv)"
help="Input Layout (Default=NCHW for 2d conv, NCDHW for 3d conv)",
)
parser.add_argument(
"-forw",
@@ -230,7 +238,7 @@ if __name__ == "__main__":
"\n4 wrw only"
"\n3 fwd+bwd"
"\n5 fwd+wrw"
"\n6 bwd+wrw"
"\n6 bwd+wrw",
)
parser.add_argument(
"-spatial_dim",
@@ -240,7 +248,7 @@ if __name__ == "__main__":
default=2,
type=int,
required=False,
help="convolution spatial dimension (Default-2)"
help="convolution spatial dimension (Default-2)",
)
parser.add_argument(
"-batchsize",
@@ -250,7 +258,7 @@ if __name__ == "__main__":
default=100,
type=int,
required=False,
help="Mini-batch size (Default=100)"
help="Mini-batch size (Default=100)",
)
parser.add_argument(
"-in_channels",
@@ -260,7 +268,7 @@ if __name__ == "__main__":
default=3,
type=int,
required=False,
help="Number of Input Channels (Default=3)"
help="Number of Input Channels (Default=3)",
)
parser.add_argument(
"-in_d",
@@ -270,7 +278,7 @@ if __name__ == "__main__":
default=32,
type=int,
required=False,
help="Input Depth (Default=32)"
help="Input Depth (Default=32)",
)
parser.add_argument(
"-in_h",
@@ -280,7 +288,7 @@ if __name__ == "__main__":
default=32,
type=int,
required=False,
help="Input Height (Default=32)"
help="Input Height (Default=32)",
)
parser.add_argument(
"-in_w",
@@ -290,7 +298,7 @@ if __name__ == "__main__":
default=32,
type=int,
required=False,
help="Input Width (Default=32)"
help="Input Width (Default=32)",
)
parser.add_argument(
"-out_channels",
@@ -300,7 +308,7 @@ if __name__ == "__main__":
default=32,
type=int,
required=False,
help="Number of Output Channels (Default=32)"
help="Number of Output Channels (Default=32)",
)
parser.add_argument(
"-fil_d",
@@ -310,7 +318,7 @@ if __name__ == "__main__":
default=3,
type=int,
required=False,
help="Filter Depth (Default=3)"
help="Filter Depth (Default=3)",
)
parser.add_argument(
"-fil_h",
@@ -320,7 +328,7 @@ if __name__ == "__main__":
default=3,
type=int,
required=False,
help="Filter Height (Default=3)"
help="Filter Height (Default=3)",
)
parser.add_argument(
"-fil_w",
@@ -330,7 +338,7 @@ if __name__ == "__main__":
default=3,
type=int,
required=False,
help="Filter Width (Default=3)"
help="Filter Width (Default=3)",
)
parser.add_argument(
"-conv_stride_d",
@@ -340,7 +348,7 @@ if __name__ == "__main__":
default=1,
type=int,
required=False,
help="Convolution Stride for Depth (Default=1)"
help="Convolution Stride for Depth (Default=1)",
)
parser.add_argument(
"-conv_stride_h",
@@ -350,7 +358,7 @@ if __name__ == "__main__":
default=1,
type=int,
required=False,
help="Convolution Stride for Height (Default=1)"
help="Convolution Stride for Height (Default=1)",
)
parser.add_argument(
"-conv_stride_w",
@@ -360,7 +368,7 @@ if __name__ == "__main__":
default=1,
type=int,
required=False,
help="Convolution Stride for Width (Default=1)"
help="Convolution Stride for Width (Default=1)",
)
parser.add_argument(
"-pad_d",
@@ -370,7 +378,7 @@ if __name__ == "__main__":
default=1,
type=int,
required=False,
help="Zero Padding for Depth (Default=0)"
help="Zero Padding for Depth (Default=0)",
)
parser.add_argument(
"-pad_h",
@@ -380,7 +388,7 @@ if __name__ == "__main__":
default=1,
type=int,
required=False,
help="Zero Padding for Height (Default=0)"
help="Zero Padding for Height (Default=0)",
)
parser.add_argument(
"-pad_w",
@@ -390,7 +398,7 @@ if __name__ == "__main__":
default=1,
type=int,
required=False,
help="Zero Padding for Width (Default=0)"
help="Zero Padding for Width (Default=0)",
)
parser.add_argument(
"-verify",
@@ -400,7 +408,7 @@ if __name__ == "__main__":
default=1,
type=int,
required=False,
help="Verify Each Layer (Default=1)"
help="Verify Each Layer (Default=1)",
)
parser.add_argument(
"-time",
@@ -410,7 +418,7 @@ if __name__ == "__main__":
default=0,
type=int,
required=False,
help="Time Each Layer (Default=0)"
help="Time Each Layer (Default=0)",
)
parser.add_argument(
"-dilation_d",
@@ -420,7 +428,7 @@ if __name__ == "__main__":
default=1,
type=int,
required=False,
help="Dilation of Filter Depth (Default=1)"
help="Dilation of Filter Depth (Default=1)",
)
parser.add_argument(
"-dilation_h",
@@ -430,7 +438,7 @@ if __name__ == "__main__":
default=1,
type=int,
required=False,
help="Dilation of Filter Height (Default=1)"
help="Dilation of Filter Height (Default=1)",
)
parser.add_argument(
"-dilation_w",
@@ -440,7 +448,7 @@ if __name__ == "__main__":
default=1,
type=int,
required=False,
help="Dilation of Filter Width (Default=1)"
help="Dilation of Filter Width (Default=1)",
)
parser.add_argument(
"-group_count",
@@ -450,7 +458,7 @@ if __name__ == "__main__":
type=int,
default=1,
required=False,
help="Number of Groups (Default=1)"
help="Number of Groups (Default=1)",
)
args, unknown = parser.parse_known_args()