mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Grouped Conv Bwd Data out index calculation optimizations (#2917)
* Grouped Conv Bwd Data index calculation optimizations
* fixes
* refactor instances
* gfx12 fixes
* temporary disable splitK for gfx12
[ROCm/composable_kernel commit: 5477811670]
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -1553,6 +1553,198 @@ struct UnMerge
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Transformation struct for convolution backward data output indices to GEMM indices.
|
||||
*
|
||||
* This struct is responsible for mapping the output tensor indices (N, Ho, Wo, K) from the
|
||||
* convolution backward data operation to the corresponding indices (K0, M, K1) used in the
|
||||
* implicit GEMM computation. It encapsulates the necessary parameters and transformation logic
|
||||
* required to efficiently perform the index conversion.
|
||||
*/
|
||||
struct ConvBwdDataImplicitGemmOutTransform
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using LowerIndex = MultiIndex<4>; // N, Ho, Wo, K
|
||||
using UpperIndex = MultiIndex<3>; // K0, M, K1
|
||||
|
||||
index_t N_, Ho_, Wo_, K_;
|
||||
index_t XDot_;
|
||||
index_t HTilde_, WTilde_;
|
||||
index_t WTildeSlice_, TildeSlice_;
|
||||
index_t IHTildeSliceBegin_, IWTildeSliceBegin_;
|
||||
index_t HRatio_, WRatio_;
|
||||
index_t XDotSlice_K_;
|
||||
index_t MPad_, KPad_;
|
||||
Tuple<index_t, index_t, index_t> up_lengths_; // K0_, MPadded, K1_;
|
||||
|
||||
Tuple<index_t, index_t, index_t, index_t>
|
||||
low_lengths_magic_divisor_multiplier_; // XDotSlice_K_, K_, TildeSlice_, WTildeSlice_
|
||||
Tuple<index_t, index_t, index_t, index_t>
|
||||
low_lengths_magic_divisor_shift_; // XDotSlice_K_, K_, TildeSlice_, WTildeSlice_
|
||||
|
||||
__host__ __device__ constexpr ConvBwdDataImplicitGemmOutTransform() = default;
|
||||
|
||||
__host__ __device__ constexpr ConvBwdDataImplicitGemmOutTransform(index_t N,
|
||||
index_t Ho,
|
||||
index_t Wo,
|
||||
index_t K,
|
||||
index_t XDot,
|
||||
index_t HTilde,
|
||||
index_t WTilde,
|
||||
index_t WTildeSlice,
|
||||
index_t HWTildeSlice,
|
||||
index_t IHTildeSliceBegin,
|
||||
index_t IWTildeSliceBegin,
|
||||
index_t HRatio,
|
||||
index_t WRatio,
|
||||
index_t XDotSlice_K,
|
||||
index_t K0,
|
||||
index_t MPadded,
|
||||
index_t K1,
|
||||
index_t MPad,
|
||||
index_t KPad)
|
||||
: N_{N},
|
||||
Ho_{Ho},
|
||||
Wo_{Wo},
|
||||
K_{K},
|
||||
XDot_{XDot},
|
||||
HTilde_{HTilde},
|
||||
WTilde_{WTilde},
|
||||
WTildeSlice_{WTildeSlice},
|
||||
TildeSlice_{HWTildeSlice},
|
||||
IHTildeSliceBegin_{IHTildeSliceBegin},
|
||||
IWTildeSliceBegin_{IWTildeSliceBegin},
|
||||
HRatio_{HRatio},
|
||||
WRatio_{WRatio},
|
||||
XDotSlice_K_{XDotSlice_K},
|
||||
MPad_{MPad},
|
||||
KPad_{KPad},
|
||||
up_lengths_{make_tuple(K0, MPadded, K1)},
|
||||
low_lengths_magic_divisor_multiplier_{
|
||||
MagicDivision::CalculateMagicMultiplier(XDotSlice_K_),
|
||||
MagicDivision::CalculateMagicMultiplier(K_),
|
||||
MagicDivision::CalculateMagicMultiplier(TildeSlice_),
|
||||
MagicDivision::CalculateMagicMultiplier(WTildeSlice_)},
|
||||
low_lengths_magic_divisor_shift_{MagicDivision::CalculateMagicShift(XDotSlice_K_),
|
||||
MagicDivision::CalculateMagicShift(K_),
|
||||
MagicDivision::CalculateMagicShift(TildeSlice_),
|
||||
MagicDivision::CalculateMagicShift(WTildeSlice_)}
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 4; }
|
||||
|
||||
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 3; }
|
||||
|
||||
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
|
||||
|
||||
template <typename UpIdx>
|
||||
__host__ __device__ constexpr auto CalculateLowerIndexN(const UpIdx& idx_up) const
|
||||
{
|
||||
index_t NStep, HStep, WStep;
|
||||
// Merge
|
||||
// NStep = M_id / TildeSlice_
|
||||
NStep = MagicDivision::DoMagicDivision(idx_up[I1],
|
||||
this->low_lengths_magic_divisor_multiplier_[I2],
|
||||
this->low_lengths_magic_divisor_shift_[I2]);
|
||||
HStep = idx_up[I1] - NStep * TildeSlice_;
|
||||
// HStep = HStep / WTildeSlice_
|
||||
HStep = MagicDivision::DoMagicDivision(HStep,
|
||||
this->low_lengths_magic_divisor_multiplier_[I3],
|
||||
this->low_lengths_magic_divisor_shift_[I3]);
|
||||
WStep = idx_up[I1] - NStep * TildeSlice_ - HStep * WTildeSlice_;
|
||||
// Slice
|
||||
HStep += IHTildeSliceBegin_;
|
||||
WStep += IWTildeSliceBegin_;
|
||||
|
||||
return make_tuple(NStep, HStep, WStep, 0);
|
||||
}
|
||||
|
||||
template <typename UpIdx>
|
||||
__host__ __device__ constexpr auto CalculateLowerIndexK(const UpIdx& idx_up) const
|
||||
{
|
||||
// UnMerge
|
||||
// K_idx <- K0_idx * K1 + K1_idx
|
||||
index_t K_idx = idx_up[I0] * up_lengths_[I2] + idx_up[I2];
|
||||
// Merge
|
||||
// YStep = K_idx / XDotSlice_K_
|
||||
index_t YStep =
|
||||
MagicDivision::DoMagicDivision(K_idx,
|
||||
this->low_lengths_magic_divisor_multiplier_[I0],
|
||||
this->low_lengths_magic_divisor_shift_[I0]);
|
||||
index_t KStep = K_idx - YStep * XDotSlice_K_;
|
||||
// Xstep = KStep / K_
|
||||
index_t XStep =
|
||||
MagicDivision::DoMagicDivision(KStep,
|
||||
this->low_lengths_magic_divisor_multiplier_[I1],
|
||||
this->low_lengths_magic_divisor_shift_[I1]);
|
||||
KStep -= XStep * K_;
|
||||
// Embed
|
||||
YStep *= HRatio_;
|
||||
XStep *= WRatio_;
|
||||
|
||||
return make_tuple(0, YStep, XStep, KStep);
|
||||
}
|
||||
|
||||
template <typename LowIdx, typename UpIdx>
|
||||
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
|
||||
const UpIdx& idx_up) const
|
||||
{
|
||||
idx_low = CalculateLowerIndexN(idx_up) + CalculateLowerIndexK(idx_up);
|
||||
}
|
||||
|
||||
template <typename LowIdxDiff,
|
||||
typename UpIdxDiff,
|
||||
typename LowIdx,
|
||||
typename UpIdx,
|
||||
index_t Hack>
|
||||
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
|
||||
const UpIdxDiff& /* idx_diff_up */,
|
||||
LowIdx& idx_low,
|
||||
const UpIdx& idx_up,
|
||||
Number<Hack>) const
|
||||
{
|
||||
LowIdx low_old = idx_low;
|
||||
idx_low = CalculateLowerIndexN(idx_up) + CalculateLowerIndexK(idx_up);
|
||||
idx_diff_low = idx_low - low_old;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
|
||||
|
||||
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename UpIdx>
|
||||
__host__ __device__ constexpr bool
|
||||
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
|
||||
{
|
||||
// Padding
|
||||
index_t K_idx = idx_up[Number<0>{}] * up_lengths_[Number<2>{}] + idx_up[Number<2>{}];
|
||||
index_t& M_idx = idx_up[Number<1>{}];
|
||||
|
||||
bool pad_valid = M_idx < up_lengths_[Number<1>{}] - MPad_ &&
|
||||
K_idx < up_lengths_[Number<0>{}] * up_lengths_[Number<2>{}] - KPad_;
|
||||
return pad_valid;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsKnownAtCompileTime() { return false; }
|
||||
|
||||
__host__ __device__ void Print() const
|
||||
{
|
||||
printf("{");
|
||||
printf("ConvBwdDataImplicitGemmOutTransform, ");
|
||||
printf("up_lengths_");
|
||||
print_multi_index(up_lengths_);
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowerIndex>
|
||||
struct Freeze
|
||||
{
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -94,6 +94,59 @@ __host__ __device__ constexpr auto make_unmerge_transform(
|
||||
return UnMerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto make_conv_bwd_data_out_transform(index_t N,
|
||||
index_t Ho,
|
||||
index_t Wo,
|
||||
index_t K,
|
||||
[[maybe_unused]] index_t YDot,
|
||||
index_t XDot,
|
||||
index_t HTilde,
|
||||
index_t WTilde,
|
||||
index_t ConvDilationH,
|
||||
index_t ConvDilationW,
|
||||
index_t HTildeSlice,
|
||||
index_t WTildeSlice,
|
||||
index_t YDotSlice,
|
||||
index_t XDotSlice,
|
||||
index_t IHTildeSliceBegin,
|
||||
index_t IWTildeSliceBegin,
|
||||
index_t GcdStrideDilationH,
|
||||
index_t GcdStrideDilationW,
|
||||
index_t K0,
|
||||
index_t K1,
|
||||
index_t MPerBlock,
|
||||
index_t GemmKPerBlock)
|
||||
{
|
||||
// Calculate padding
|
||||
const auto MRaw = N * HTildeSlice * WTildeSlice;
|
||||
const auto MPadded = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
|
||||
const auto MPad = MPadded - MRaw;
|
||||
|
||||
const auto KRaw = YDotSlice * XDotSlice * K;
|
||||
const auto KPadded = math::integer_divide_ceil(KRaw, GemmKPerBlock) * GemmKPerBlock;
|
||||
const auto KPad = KPadded - KRaw;
|
||||
|
||||
return ConvBwdDataImplicitGemmOutTransform{N,
|
||||
Ho,
|
||||
Wo,
|
||||
K,
|
||||
XDot,
|
||||
HTilde,
|
||||
WTilde,
|
||||
WTildeSlice,
|
||||
HTildeSlice * WTildeSlice,
|
||||
IHTildeSliceBegin,
|
||||
IWTildeSliceBegin,
|
||||
-ConvDilationH / GcdStrideDilationH,
|
||||
-ConvDilationW / GcdStrideDilationW,
|
||||
XDotSlice * K,
|
||||
K0,
|
||||
MPadded,
|
||||
K1,
|
||||
MPad,
|
||||
KPad};
|
||||
}
|
||||
|
||||
template <typename LowerIndex>
|
||||
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_idx)
|
||||
{
|
||||
|
||||
@@ -1485,7 +1485,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// gfx11 doesn't support float atomic
|
||||
if(ck::is_gfx11_supported() && arg.k_batch_ > 1)
|
||||
// Todo: Enable splitK for gfx12
|
||||
if((ck::is_gfx12_supported() || ck::is_gfx11_supported()) && arg.k_batch_ > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -13,6 +13,14 @@
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
/**
|
||||
* @brief Enable custom tensor transform for convolution backward data output.
|
||||
*
|
||||
* When set to 1, this macro enables a custom transformation of the output tensor
|
||||
* in convolution backward data operations.
|
||||
*/
|
||||
#define CK_USE_CUSTOM_TENSOR_TRANSFORM_FOR_BWD_DATA_OUT 1
|
||||
|
||||
template <
|
||||
index_t NDimSpatial,
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization,
|
||||
@@ -705,6 +713,12 @@ struct TransformConvBwdDataToGemm_v1
|
||||
|
||||
if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
const index_t K0PerBlock = GemmKPerBlock / AK1;
|
||||
const index_t AK0 = math::integer_divide_ceil(YDotSlice * XDotSlice * K_,
|
||||
AK1 * K0PerBlock * batch_k_) *
|
||||
K0PerBlock;
|
||||
|
||||
#if CK_USE_CUSTOM_TENSOR_TRANSFORM_FOR_BWD_DATA_OUT == 0
|
||||
// A: output tensor
|
||||
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
@@ -762,12 +776,6 @@ struct TransformConvBwdDataToGemm_v1
|
||||
make_tuple(GemmKPerBlock, GemmMPerBlock),
|
||||
Sequence<true, DoPadGemmM>{});
|
||||
|
||||
const index_t K0PerBlock = GemmKPerBlock / AK1;
|
||||
const index_t AK0 =
|
||||
math::integer_divide_ceil(out_gemmk_gemmm_padded_grid_desc.GetLength(I0),
|
||||
AK1 * K0PerBlock * batch_k_) *
|
||||
K0PerBlock;
|
||||
|
||||
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmk_gemmm_padded_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)),
|
||||
@@ -775,8 +783,46 @@ struct TransformConvBwdDataToGemm_v1
|
||||
out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return out_gemmak0_gemmm_gemmak1_grid_desc;
|
||||
#else
|
||||
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Ho_, I0, I0),
|
||||
make_pad_transform(Wo_, I0, I0),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto out_n_hop_wop_k_grid_desc_final = transform_tensor_descriptor(
|
||||
out_n_hop_wop_k_grid_desc,
|
||||
make_tuple(make_conv_bwd_data_out_transform(N_,
|
||||
Ho_,
|
||||
Wo_,
|
||||
K_,
|
||||
YDot_,
|
||||
XDot_,
|
||||
HTilde_,
|
||||
WTilde_,
|
||||
ConvDilationH_,
|
||||
ConvDilationW_,
|
||||
HTildeSlice,
|
||||
WTildeSlice,
|
||||
YDotSlice,
|
||||
XDotSlice,
|
||||
IHTildeSliceBegin,
|
||||
IWTildeSliceBegin,
|
||||
GcdStrideDilationH_,
|
||||
GcdStrideDilationW_,
|
||||
AK0,
|
||||
AK1,
|
||||
GemmMPerBlock,
|
||||
GemmKPerBlock)),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}));
|
||||
|
||||
return out_n_hop_wop_k_grid_desc_final;
|
||||
#endif
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
|
||||
@@ -76,6 +76,47 @@ using device_grouped_conv_bwd_data_xdl_f16_16_16_instances =
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardDataSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_data_xdl_f16_optimized_loads_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
|
||||
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// A K1 one access for each thread per load
|
||||
// 32x32
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 64, 16, 16, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 4, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 8>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 64, 16, 16, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 64, 16, 16, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>,
|
||||
// 16x16
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 2, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 2, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
@@ -178,6 +219,48 @@ using device_grouped_conv_bwd_data_xdl_bf16_16_16_instances = std::tuple<
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>
|
||||
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardDataSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_data_xdl_bf16_optimized_loads_instances = std::tuple<
|
||||
// clang-format off
|
||||
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
|
||||
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// A K1 one access for each thread per load
|
||||
// 32x32
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 64, 16, 16, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 4, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 8>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 64, 16, 16, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 64, 16, 16, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>,
|
||||
// 16x16
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 2, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 2, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>
|
||||
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
@@ -257,6 +340,41 @@ using device_grouped_conv_bwd_data_xdl_f32_16_16_instances =
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardDataSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_data_xdl_f32_optimized_loads_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
|
||||
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// A K1 one access for each thread per load
|
||||
// 32x32
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 4, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>,
|
||||
// 16x16
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 2, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
|
||||
@@ -111,6 +111,8 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
@@ -121,6 +123,8 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
@@ -132,6 +136,8 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -251,6 +257,8 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
@@ -271,6 +279,8 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
@@ -282,6 +292,8 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -83,6 +83,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_16_16_instance
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_optimized_loads_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
|
||||
@@ -112,6 +126,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_16_16_instance
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
|
||||
@@ -141,6 +169,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_16_16_instanc
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_optimized_loads_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
@@ -393,6 +435,20 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_16_16_insta
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_optimized_loads_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
|
||||
@@ -422,6 +478,20 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_16_16_insta
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
@@ -451,6 +521,20 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_inst
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_optimized_loads_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(
|
||||
|
||||
@@ -10,6 +10,9 @@ add_instance_library(
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_16_16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_16_16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_16_16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_optimized_loads_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_optimized_loads_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_optimized_loads_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_optimized_loads_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_bf16_optimized_loads_instances<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
ConvBwdDataDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_bwd_data_xdl_bf16_optimized_loads_instances<
|
||||
2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,49 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_optimized_loads_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f16_optimized_loads_instances<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
ConvBwdDataDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_bwd_data_xdl_f16_optimized_loads_instances<
|
||||
2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,49 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f32_optimized_loads_instances<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
ConvBwdDataDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_bwd_data_xdl_f32_optimized_loads_instances<
|
||||
2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -9,6 +9,9 @@ set(GROUPED_CONV3D_BWD_DATA
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16_16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16_16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16_16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_optimized_loads_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_optimized_loads_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_optimized_loads_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_optimized_loads_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_bf16_optimized_loads_instances<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
ConvBwdDataDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_bwd_data_xdl_bf16_optimized_loads_instances<
|
||||
3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,49 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_optimized_loads_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f16_optimized_loads_instances<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
ConvBwdDataDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_bwd_data_xdl_f16_optimized_loads_instances<
|
||||
3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,49 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f32_optimized_loads_instances<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
ConvBwdDataDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_bwd_data_xdl_f32_optimized_loads_instances<
|
||||
3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
@@ -185,11 +185,17 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
    // Use higher threshold
    rtol = std::max(rtol, rtol_split_k);
    atol = std::max(atol, atol_split_k);

    pass &= ck::utils::check_err(
        in_device, in_host, "Error: Incorrect results!", rtol, atol);
    std::cout << "Relative error threshold: " << rtol
              << " Absolute error threshold: " << atol << std::endl;
    if(split_k_for_run > 1)
    {
        pass &= ck::utils::check_err(
            in_device, in_host, "Error: Incorrect results!", rtol, atol);
        std::cout << "Relative error threshold: " << rtol
                  << " Absolute error threshold: " << atol << std::endl;
    }
    else
    {
        pass &= ck::utils::check_err(in_device, in_host, "Error: Incorrect results!");
    }

    if(do_log)
{
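A side note on the hunk above: the relaxed thresholds appear to be applied only when the run actually uses split-K (split_k_for_run > 1); otherwise check_err falls back to its default tolerances. Below is a minimal sketch of that selection policy, assuming the same variable names as the profiler; it is not the profiler's exact code.

#include <algorithm>
#include <utility>

// Returns the (rtol, atol) pair to hand to the verification routine.
std::pair<double, double> select_tolerances(int split_k_for_run,
                                            double rtol,
                                            double atol,
                                            double rtol_split_k,
                                            double atol_split_k)
{
    if(split_k_for_run > 1)
    {
        // Split-K runs accumulate partial results, so verification uses the
        // higher of the default and split-K thresholds.
        rtol = std::max(rtol, rtol_split_k);
        atol = std::max(atol, atol_split_k);
    }
    return {rtol, atol};
}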

@@ -10,7 +10,7 @@ import subprocess


def init_const_args(args):
    args.ck_profiler_cmd = '../build/bin/ckProfiler'
    args.ck_profiler_cmd = "../build/bin/ckProfiler"
    # use decimal values
    args.init_method = 2
    # don't print tensor values
@@ -27,52 +27,62 @@ def run_ck_profiler_cmd(cmd):


def parse_layouts(args):
    if args.in_layout == "NCW" or args.in_layout == "NCHW" or \
            args.in_layout == "NCDHW":
    if args.in_layout == "NCW" or args.in_layout == "NCHW" or args.in_layout == "NCDHW":
        if args.ck_profier_op == "grouped_conv_bwd_weight":
            args.layout = 4
        elif args.ck_profier_op == "grouped_conv_fwd" or \
                args.ck_profier_op == "grouped_conv_bwd_data":
        elif (
            args.ck_profier_op == "grouped_conv_fwd"
            or args.ck_profier_op == "grouped_conv_bwd_data"
        ):
            args.layout = 3
        else:
            print('Not supported layout for this op')
            print("Not supported layout for this op")
            exit(1)
    elif args.in_layout == "NWC" or args.in_layout == "NHWC" or \
            args.in_layout == "NDHWC":
    elif (
        args.in_layout == "NWC" or args.in_layout == "NHWC" or args.in_layout == "NDHWC"
    ):
        if args.ck_profier_op == "grouped_conv_bwd_weight":
            args.layout = 2
        elif args.ck_profier_op == "grouped_conv_bwd_data" or \
                args.ck_profier_op == "grouped_conv_fwd":
        elif (
            args.ck_profier_op == "grouped_conv_bwd_data"
            or args.ck_profier_op == "grouped_conv_fwd"
        ):
            args.layout = 1
        else:
            print('Not supported layout for this op')
            print("Not supported layout for this op")
            exit(1)


def parse_data_type(args):
    if args.data_type == "fp32":
        if args.ck_profier_op == "grouped_conv_bwd_weight" or \
                args.ck_profier_op == "grouped_conv_bwd_data" or \
                args.ck_profier_op == "grouped_conv_fwd":
        if (
            args.ck_profier_op == "grouped_conv_bwd_weight"
            or args.ck_profier_op == "grouped_conv_bwd_data"
            or args.ck_profier_op == "grouped_conv_fwd"
        ):
            args.data_type = 0
    if args.data_type == "fp16":
        if args.ck_profier_op == "grouped_conv_bwd_weight" or \
                args.ck_profier_op == "grouped_conv_bwd_data" or \
                args.ck_profier_op == "grouped_conv_fwd":
        if (
            args.ck_profier_op == "grouped_conv_bwd_weight"
            or args.ck_profier_op == "grouped_conv_bwd_data"
            or args.ck_profier_op == "grouped_conv_fwd"
        ):
            args.data_type = 1
    if args.data_type == "int8":
        if args.ck_profier_op == "grouped_conv_bwd_weight":
            args.data_type = 4
        if args.ck_profier_op == "grouped_conv_bwd_data":
            print('Not supported data type for grouped_conv_bwd_data')
            print("Not supported data type for grouped_conv_bwd_data")
            exit(1)
        if args.ck_profier_op == "grouped_conv_fwd":
            args.data_type = 3
    if args.data_type == "bfp16":
        if args.ck_profier_op == "grouped_conv_bwd_weight":
            args.data_type = 5
        if args.ck_profier_op == "grouped_conv_bwd_data" or \
                args.ck_profier_op == "grouped_conv_fwd":
        if (
            args.ck_profier_op == "grouped_conv_bwd_data"
            or args.ck_profier_op == "grouped_conv_fwd"
        ):
            args.data_type = 2

@@ -93,13 +103,11 @@ def add_conv_params_to_cmd(args, cmd):
        cmd += [str(args.in_d), str(args.in_h), str(args.in_w)]
        cmd += [str(args.conv_stride_d), str(args.conv_stride_h)]
        cmd += [str(args.conv_stride_w)]
        cmd += [str(args.dilation_d),
                str(args.dilation_h),
                str(args.dilation_w)]
        cmd += [str(args.dilation_d), str(args.dilation_h), str(args.dilation_w)]
        cmd += [str(args.pad_d), str(args.pad_h), str(args.pad_w)]
        cmd += [str(args.pad_d), str(args.pad_h), str(args.pad_w)]
    else:
        print('Not supported spatial dim (supported: 1, 2, 3)')
        print("Not supported spatial dim (supported: 1, 2, 3)")
        exit(1)

@@ -147,7 +155,7 @@ def run_ck_grouped_conv_bwd_weight(args):
    parse_data_type(args)
    parse_layouts(args)
    # Test all split K value from the list {1, 2, 4, 8, 32, 64, 128}
    args.split_k_value = -1
    args.split_k_value = "all"

    cmd = [str(args.ck_profiler_cmd), str(args.ck_profier_op)]
    cmd += [str(args.data_type), str(args.layout)]
@@ -161,23 +169,23 @@ def run_ck_grouped_conv_bwd_weight(args):
    cmd += [str(args.split_k_value)]
    run_ck_profiler_cmd(cmd)

# Get name of miopen driver, remove it from unknown
def process_miopen_driver_name(args, unknown):
    if "convint8" in unknown:
        args.data_type = 'int8'
        args.data_type = "int8"
        unknown.remove("convint8")
    elif "convbfp16" in unknown:
        args.data_type = 'bfp16'
        args.data_type = "bfp16"
        unknown.remove("convbfp16")
    elif "convfp16" in unknown:
        args.data_type = 'fp16'
        args.data_type = "fp16"
        unknown.remove("convfp16")
    elif "conv" in unknown:
        args.data_type = 'fp32'
        args.data_type = "fp32"
        unknown.remove("conv")
    else:
        print('Not supported driver (supported: conv, convfp16, convint8,'
              ' convbfp16).')
        print("Not supported driver (supported: conv, convfp16, convint8, convbfp16).")
        exit(1)

@@ -199,11 +207,11 @@ if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="converter",
        description="Convert miopen driver command to ck Profiler"
        "\nExample: python3 "
        "../script/convert_miopen_driver_to_profiler.py "
        "/opt/rocm/bin/MIOpenDriver conv -n 32 -c 64 -H 28 -W 28 "
        "-k 64 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -m conv -g "
        "32 -F 1 -t 1",
        "\nExample: python3 "
        "../script/convert_miopen_driver_to_profiler.py "
        "/opt/rocm/bin/MIOpenDriver conv -n 32 -c 64 -H 28 -W 28 "
        "-k 64 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -m conv -g "
        "32 -F 1 -t 1",
    )
    parser.add_argument(
        "-in_layout",
@@ -213,7 +221,7 @@ if __name__ == "__main__":
        default="NCHW",
        type=str,
        required=False,
        help="Input Layout (Default=NCHW for 2d conv, NCDHW for 3d conv)"
        help="Input Layout (Default=NCHW for 2d conv, NCDHW for 3d conv)",
    )
    parser.add_argument(
        "-forw",
@@ -230,7 +238,7 @@ if __name__ == "__main__":
        "\n4 wrw only"
        "\n3 fwd+bwd"
        "\n5 fwd+wrw"
        "\n6 bwd+wrw"
        "\n6 bwd+wrw",
    )
    parser.add_argument(
        "-spatial_dim",
@@ -240,7 +248,7 @@ if __name__ == "__main__":
        default=2,
        type=int,
        required=False,
        help="convolution spatial dimension (Default-2)"
        help="convolution spatial dimension (Default-2)",
    )
    parser.add_argument(
        "-batchsize",
@@ -250,7 +258,7 @@ if __name__ == "__main__":
        default=100,
        type=int,
        required=False,
        help="Mini-batch size (Default=100)"
        help="Mini-batch size (Default=100)",
    )
    parser.add_argument(
        "-in_channels",
@@ -260,7 +268,7 @@ if __name__ == "__main__":
        default=3,
        type=int,
        required=False,
        help="Number of Input Channels (Default=3)"
        help="Number of Input Channels (Default=3)",
    )
    parser.add_argument(
        "-in_d",
@@ -270,7 +278,7 @@ if __name__ == "__main__":
        default=32,
        type=int,
        required=False,
        help="Input Depth (Default=32)"
        help="Input Depth (Default=32)",
    )
    parser.add_argument(
        "-in_h",
@@ -280,7 +288,7 @@ if __name__ == "__main__":
        default=32,
        type=int,
        required=False,
        help="Input Height (Default=32)"
        help="Input Height (Default=32)",
    )
    parser.add_argument(
        "-in_w",
@@ -290,7 +298,7 @@ if __name__ == "__main__":
        default=32,
        type=int,
        required=False,
        help="Input Width (Default=32)"
        help="Input Width (Default=32)",
    )
    parser.add_argument(
        "-out_channels",
@@ -300,7 +308,7 @@ if __name__ == "__main__":
        default=32,
        type=int,
        required=False,
        help="Number of Output Channels (Default=32)"
        help="Number of Output Channels (Default=32)",
    )
    parser.add_argument(
        "-fil_d",
@@ -310,7 +318,7 @@ if __name__ == "__main__":
        default=3,
        type=int,
        required=False,
        help="Filter Depth (Default=3)"
        help="Filter Depth (Default=3)",
    )
    parser.add_argument(
        "-fil_h",
@@ -320,7 +328,7 @@ if __name__ == "__main__":
        default=3,
        type=int,
        required=False,
        help="Filter Height (Default=3)"
        help="Filter Height (Default=3)",
    )
    parser.add_argument(
        "-fil_w",
@@ -330,7 +338,7 @@ if __name__ == "__main__":
        default=3,
        type=int,
        required=False,
        help="Filter Width (Default=3)"
        help="Filter Width (Default=3)",
    )
    parser.add_argument(
        "-conv_stride_d",
@@ -340,7 +348,7 @@ if __name__ == "__main__":
        default=1,
        type=int,
        required=False,
        help="Convolution Stride for Depth (Default=1)"
        help="Convolution Stride for Depth (Default=1)",
    )
    parser.add_argument(
        "-conv_stride_h",
@@ -350,7 +358,7 @@ if __name__ == "__main__":
        default=1,
        type=int,
        required=False,
        help="Convolution Stride for Height (Default=1)"
        help="Convolution Stride for Height (Default=1)",
    )
    parser.add_argument(
        "-conv_stride_w",
@@ -360,7 +368,7 @@ if __name__ == "__main__":
        default=1,
        type=int,
        required=False,
        help="Convolution Stride for Width (Default=1)"
        help="Convolution Stride for Width (Default=1)",
    )
    parser.add_argument(
        "-pad_d",
@@ -370,7 +378,7 @@ if __name__ == "__main__":
        default=1,
        type=int,
        required=False,
        help="Zero Padding for Depth (Default=0)"
        help="Zero Padding for Depth (Default=0)",
    )
    parser.add_argument(
        "-pad_h",
@@ -380,7 +388,7 @@ if __name__ == "__main__":
        default=1,
        type=int,
        required=False,
        help="Zero Padding for Height (Default=0)"
        help="Zero Padding for Height (Default=0)",
    )
    parser.add_argument(
        "-pad_w",
@@ -390,7 +398,7 @@ if __name__ == "__main__":
        default=1,
        type=int,
        required=False,
        help="Zero Padding for Width (Default=0)"
        help="Zero Padding for Width (Default=0)",
    )
    parser.add_argument(
        "-verify",
@@ -400,7 +408,7 @@ if __name__ == "__main__":
        default=1,
        type=int,
        required=False,
        help="Verify Each Layer (Default=1)"
        help="Verify Each Layer (Default=1)",
    )
    parser.add_argument(
        "-time",
@@ -410,7 +418,7 @@ if __name__ == "__main__":
        default=0,
        type=int,
        required=False,
        help="Time Each Layer (Default=0)"
        help="Time Each Layer (Default=0)",
    )
    parser.add_argument(
        "-dilation_d",
@@ -420,7 +428,7 @@ if __name__ == "__main__":
        default=1,
        type=int,
        required=False,
        help="Dilation of Filter Depth (Default=1)"
        help="Dilation of Filter Depth (Default=1)",
    )
    parser.add_argument(
        "-dilation_h",
@@ -430,7 +438,7 @@ if __name__ == "__main__":
        default=1,
        type=int,
        required=False,
        help="Dilation of Filter Height (Default=1)"
        help="Dilation of Filter Height (Default=1)",
    )
    parser.add_argument(
        "-dilation_w",
@@ -440,7 +448,7 @@ if __name__ == "__main__":
        default=1,
        type=int,
        required=False,
        help="Dilation of Filter Width (Default=1)"
        help="Dilation of Filter Width (Default=1)",
    )
    parser.add_argument(
        "-group_count",
@@ -450,7 +458,7 @@ if __name__ == "__main__":
        type=int,
        default=1,
        required=False,
        help="Number of Groups (Default=1)"
        help="Number of Groups (Default=1)",
    )

    args, unknown = parser.parse_known_args()