diff --git a/include/ck/tensor_description/multi_index_transform.hpp b/include/ck/tensor_description/multi_index_transform.hpp index c152cbfb1e..e24227ecc3 100644 --- a/include/ck/tensor_description/multi_index_transform.hpp +++ b/include/ck/tensor_description/multi_index_transform.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -1553,6 +1553,198 @@ struct UnMerge } }; +/** + * @brief Transformation struct for convolution backward data output indices to GEMM indices. + * + * This struct is responsible for mapping the output tensor indices (N, Ho, Wo, K) from the + * convolution backward data operation to the corresponding indices (K0, M, K1) used in the + * implicit GEMM computation. It encapsulates the necessary parameters and transformation logic + * required to efficiently perform the index conversion. + */ +struct ConvBwdDataImplicitGemmOutTransform +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using LowerIndex = MultiIndex<4>; // N, Ho, Wo, K + using UpperIndex = MultiIndex<3>; // K0, M, K1 + + index_t N_, Ho_, Wo_, K_; + index_t XDot_; + index_t HTilde_, WTilde_; + index_t WTildeSlice_, TildeSlice_; + index_t IHTildeSliceBegin_, IWTildeSliceBegin_; + index_t HRatio_, WRatio_; + index_t XDotSlice_K_; + index_t MPad_, KPad_; + Tuple up_lengths_; // K0_, MPadded, K1_; + + Tuple + low_lengths_magic_divisor_multiplier_; // XDotSlice_K_, K_, TildeSlice_, WTildeSlice_ + Tuple + low_lengths_magic_divisor_shift_; // XDotSlice_K_, K_, TildeSlice_, WTildeSlice_ + + __host__ __device__ constexpr ConvBwdDataImplicitGemmOutTransform() = default; + + __host__ __device__ constexpr ConvBwdDataImplicitGemmOutTransform(index_t N, + index_t Ho, + index_t Wo, + index_t K, + index_t XDot, + index_t HTilde, + index_t WTilde, + index_t WTildeSlice, + index_t HWTildeSlice, + index_t IHTildeSliceBegin, + index_t IWTildeSliceBegin, + index_t HRatio, + index_t WRatio, + index_t XDotSlice_K, + index_t K0, + index_t MPadded, + index_t K1, + index_t MPad, + index_t KPad) + : N_{N}, + Ho_{Ho}, + Wo_{Wo}, + K_{K}, + XDot_{XDot}, + HTilde_{HTilde}, + WTilde_{WTilde}, + WTildeSlice_{WTildeSlice}, + TildeSlice_{HWTildeSlice}, + IHTildeSliceBegin_{IHTildeSliceBegin}, + IWTildeSliceBegin_{IWTildeSliceBegin}, + HRatio_{HRatio}, + WRatio_{WRatio}, + XDotSlice_K_{XDotSlice_K}, + MPad_{MPad}, + KPad_{KPad}, + up_lengths_{make_tuple(K0, MPadded, K1)}, + low_lengths_magic_divisor_multiplier_{ + MagicDivision::CalculateMagicMultiplier(XDotSlice_K_), + MagicDivision::CalculateMagicMultiplier(K_), + MagicDivision::CalculateMagicMultiplier(TildeSlice_), + MagicDivision::CalculateMagicMultiplier(WTildeSlice_)}, + low_lengths_magic_divisor_shift_{MagicDivision::CalculateMagicShift(XDotSlice_K_), + MagicDivision::CalculateMagicShift(K_), + MagicDivision::CalculateMagicShift(TildeSlice_), + MagicDivision::CalculateMagicShift(WTildeSlice_)} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 4; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 3; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr auto CalculateLowerIndexN(const UpIdx& idx_up) const + { + index_t NStep, HStep, WStep; + // Merge + // NStep = M_id / TildeSlice_ + NStep = MagicDivision::DoMagicDivision(idx_up[I1], + this->low_lengths_magic_divisor_multiplier_[I2], + this->low_lengths_magic_divisor_shift_[I2]); + HStep = idx_up[I1] - NStep * TildeSlice_; + // HStep = HStep / WTildeSlice_ + HStep = MagicDivision::DoMagicDivision(HStep, + this->low_lengths_magic_divisor_multiplier_[I3], + this->low_lengths_magic_divisor_shift_[I3]); + WStep = idx_up[I1] - NStep * TildeSlice_ - HStep * WTildeSlice_; + // Slice + HStep += IHTildeSliceBegin_; + WStep += IWTildeSliceBegin_; + + return make_tuple(NStep, HStep, WStep, 0); + } + + template + __host__ __device__ constexpr auto CalculateLowerIndexK(const UpIdx& idx_up) const + { + // UnMerge + // K_idx <- K0_idx * K1 + K1_idx + index_t K_idx = idx_up[I0] * up_lengths_[I2] + idx_up[I2]; + // Merge + // YStep = K_idx / XDotSlice_K_ + index_t YStep = + MagicDivision::DoMagicDivision(K_idx, + this->low_lengths_magic_divisor_multiplier_[I0], + this->low_lengths_magic_divisor_shift_[I0]); + index_t KStep = K_idx - YStep * XDotSlice_K_; + // Xstep = KStep / K_ + index_t XStep = + MagicDivision::DoMagicDivision(KStep, + this->low_lengths_magic_divisor_multiplier_[I1], + this->low_lengths_magic_divisor_shift_[I1]); + KStep -= XStep * K_; + // Embed + YStep *= HRatio_; + XStep *= WRatio_; + + return make_tuple(0, YStep, XStep, KStep); + } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + idx_low = CalculateLowerIndexN(idx_up) + CalculateLowerIndexK(idx_up); + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& /* idx_diff_up */, + LowIdx& idx_low, + const UpIdx& idx_up, + Number) const + { + LowIdx low_old = idx_low; + idx_low = CalculateLowerIndexN(idx_up) + CalculateLowerIndexK(idx_up); + idx_diff_low = idx_low - low_old; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const + { + // Padding + index_t K_idx = idx_up[Number<0>{}] * up_lengths_[Number<2>{}] + idx_up[Number<2>{}]; + index_t& M_idx = idx_up[Number<1>{}]; + + bool pad_valid = M_idx < up_lengths_[Number<1>{}] - MPad_ && + K_idx < up_lengths_[Number<0>{}] * up_lengths_[Number<2>{}] - KPad_; + return pad_valid; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() { return false; } + + __host__ __device__ void Print() const + { + printf("{"); + printf("ConvBwdDataImplicitGemmOutTransform, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("}"); + } +}; + template struct Freeze { diff --git a/include/ck/tensor_description/multi_index_transform_helper.hpp b/include/ck/tensor_description/multi_index_transform_helper.hpp index 8feadf63c6..a6626ae252 100644 --- a/include/ck/tensor_description/multi_index_transform_helper.hpp +++ b/include/ck/tensor_description/multi_index_transform_helper.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -94,6 +94,59 @@ __host__ __device__ constexpr auto make_unmerge_transform( return UnMerge{up_lengths}; } +__host__ __device__ constexpr auto make_conv_bwd_data_out_transform(index_t N, + index_t Ho, + index_t Wo, + index_t K, + [[maybe_unused]] index_t YDot, + index_t XDot, + index_t HTilde, + index_t WTilde, + index_t ConvDilationH, + index_t ConvDilationW, + index_t HTildeSlice, + index_t WTildeSlice, + index_t YDotSlice, + index_t XDotSlice, + index_t IHTildeSliceBegin, + index_t IWTildeSliceBegin, + index_t GcdStrideDilationH, + index_t GcdStrideDilationW, + index_t K0, + index_t K1, + index_t MPerBlock, + index_t GemmKPerBlock) +{ + // Calculate padding + const auto MRaw = N * HTildeSlice * WTildeSlice; + const auto MPadded = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto MPad = MPadded - MRaw; + + const auto KRaw = YDotSlice * XDotSlice * K; + const auto KPadded = math::integer_divide_ceil(KRaw, GemmKPerBlock) * GemmKPerBlock; + const auto KPad = KPadded - KRaw; + + return ConvBwdDataImplicitGemmOutTransform{N, + Ho, + Wo, + K, + XDot, + HTilde, + WTilde, + WTildeSlice, + HTildeSlice * WTildeSlice, + IHTildeSliceBegin, + IWTildeSliceBegin, + -ConvDilationH / GcdStrideDilationH, + -ConvDilationW / GcdStrideDilationW, + XDotSlice * K, + K0, + MPadded, + K1, + MPad, + KPad}; +} + template __host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_idx) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 57ea476ced..383b872832 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -1485,7 +1485,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 static bool IsSupportedArgument(const Argument& arg) { // gfx11 doesn't support float atomic - if(ck::is_gfx11_supported() && arg.k_batch_ > 1) + // Todo: Enable splitK for gfx12 + if((ck::is_gfx12_supported() || ck::is_gfx11_supported()) && arg.k_batch_ > 1) { return false; } diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp index 977c622f06..03c1945d95 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp @@ -13,6 +13,14 @@ namespace ck { namespace tensor_operation { +/** + * @brief Enable custom tensor transform for convolution backward data output. + * + * When set to 1, this macro enables a custom transformation of the output tensor + * in convolution backward data operations. + */ +#define CK_USE_CUSTOM_TENSOR_TRANSFORM_FOR_BWD_DATA_OUT 1 + template < index_t NDimSpatial, ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization, @@ -705,6 +713,12 @@ struct TransformConvBwdDataToGemm_v1 if constexpr(NDimSpatial == 2) { + const index_t K0PerBlock = GemmKPerBlock / AK1; + const index_t AK0 = math::integer_divide_ceil(YDotSlice * XDotSlice * K_, + AK1 * K0PerBlock * batch_k_) * + K0PerBlock; + +#if CK_USE_CUSTOM_TENSOR_TRANSFORM_FOR_BWD_DATA_OUT == 0 // A: output tensor const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( out_grid_desc, @@ -762,12 +776,6 @@ struct TransformConvBwdDataToGemm_v1 make_tuple(GemmKPerBlock, GemmMPerBlock), Sequence{}); - const index_t K0PerBlock = GemmKPerBlock / AK1; - const index_t AK0 = - math::integer_divide_ceil(out_gemmk_gemmm_padded_grid_desc.GetLength(I0), - AK1 * K0PerBlock * batch_k_) * - K0PerBlock; - const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor( out_gemmk_gemmm_padded_grid_desc, make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)), @@ -775,8 +783,46 @@ struct TransformConvBwdDataToGemm_v1 out_gemmk_gemmm_padded_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - return out_gemmak0_gemmm_gemmak1_grid_desc; +#else + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Ho_, I0, I0), + make_pad_transform(Wo_, I0, I0), + make_pass_through_transform(K_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto out_n_hop_wop_k_grid_desc_final = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple(make_conv_bwd_data_out_transform(N_, + Ho_, + Wo_, + K_, + YDot_, + XDot_, + HTilde_, + WTilde_, + ConvDilationH_, + ConvDilationW_, + HTildeSlice, + WTildeSlice, + YDotSlice, + XDotSlice, + IHTildeSliceBegin, + IWTildeSliceBegin, + GcdStrideDilationH_, + GcdStrideDilationW_, + AK0, + AK1, + GemmMPerBlock, + GemmKPerBlock)), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0, 1, 2>{})); + + return out_n_hop_wop_k_grid_desc_final; +#endif } else if constexpr(NDimSpatial == 3) { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp index 11a8ff8e91..f16a345e14 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp @@ -76,6 +76,47 @@ using device_grouped_conv_bwd_data_xdl_f16_16_16_instances = // clang-format on >; +template +using device_grouped_conv_bwd_data_xdl_f16_optimized_loads_instances = + std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // A K1 one access for each thread per load + // 32x32 + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 64, 16, 16, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 4, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 8>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 64, 16, 16, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 64, 16, 16, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>, + // 16x16 + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 2, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 2, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> + // clang-format on + >; + template , S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4> + + // clang-format on + >; + +template +using device_grouped_conv_bwd_data_xdl_bf16_optimized_loads_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // A K1 one access for each thread per load + // 32x32 + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 64, 16, 16, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 4, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 8>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 64, 16, 16, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 64, 16, 16, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>, + // 16x16 + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 2, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 2, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> + // clang-format on >; @@ -257,6 +340,41 @@ using device_grouped_conv_bwd_data_xdl_f32_16_16_instances = // clang-format on >; +template +using device_grouped_conv_bwd_data_xdl_f32_optimized_loads_instances = + std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // A K1 one access for each thread per load + // 32x32 + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 4, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 32, 1, 8>, 4>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>, + // 16x16 + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 2, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> + // clang-format on + >; + template >>& instances); + +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_optimized_loads_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP32 void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( @@ -112,6 +126,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_16_16_instance PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( @@ -141,6 +169,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_16_16_instanc PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_optimized_loads_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP16 @@ -393,6 +435,20 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_16_16_insta PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_optimized_loads_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP32 void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( @@ -422,6 +478,20 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_16_16_insta PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( @@ -451,6 +521,20 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_inst PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_optimized_loads_instances( + std::vector>>& instances); #endif #if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances( diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt index 0ef09c55ee..2598325d62 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt @@ -10,6 +10,9 @@ add_instance_library( xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_16_16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_16_16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_16_16_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_optimized_loads_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_optimized_loads_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_optimized_loads_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_optimized_loads_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_optimized_loads_instance.cpp new file mode 100644 index 0000000000..ff4ce04949 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_optimized_loads_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_optimized_loads_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bf16_optimized_loads_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_data_xdl_bf16_optimized_loads_instances< + 2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_optimized_loads_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_optimized_loads_instance.cpp new file mode 100644 index 0000000000..69f70c81a9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_optimized_loads_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_optimized_loads_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f16_optimized_loads_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_data_xdl_f16_optimized_loads_instances< + 2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_optimized_loads_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_optimized_loads_instance.cpp new file mode 100644 index 0000000000..7d2c2454e1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_optimized_loads_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_optimized_loads_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_data_xdl_f32_optimized_loads_instances< + 2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt index 4bb05e5000..8652d9fd9d 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt @@ -9,6 +9,9 @@ set(GROUPED_CONV3D_BWD_DATA xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16_16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16_16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16_16_instance.cpp + xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_optimized_loads_instance.cpp + xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_optimized_loads_instance.cpp + xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_optimized_loads_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_optimized_loads_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_optimized_loads_instance.cpp new file mode 100644 index 0000000000..63cdfcdad8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_optimized_loads_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_optimized_loads_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bf16_optimized_loads_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_data_xdl_bf16_optimized_loads_instances< + 3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_optimized_loads_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_optimized_loads_instance.cpp new file mode 100644 index 0000000000..7a1ac75a03 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_optimized_loads_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_optimized_loads_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f16_optimized_loads_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_data_xdl_f16_optimized_loads_instances< + 3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_optimized_loads_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_optimized_loads_instance.cpp new file mode 100644 index 0000000000..c76f32479e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_optimized_loads_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_optimized_loads_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_data_xdl_f32_optimized_loads_instances< + 3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp index 0aeefaabfb..29b2fece6b 100644 --- a/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp @@ -185,11 +185,17 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification, // Use higher threshold rtol = std::max(rtol, rtol_split_k); atol = std::max(atol, atol_split_k); - - pass &= ck::utils::check_err( - in_device, in_host, "Error: Incorrect results!", rtol, atol); - std::cout << "Relative error threshold: " << rtol - << " Absolute error threshold: " << atol << std::endl; + if(split_k_for_run > 1) + { + pass &= ck::utils::check_err( + in_device, in_host, "Error: Incorrect results!", rtol, atol); + std::cout << "Relative error threshold: " << rtol + << " Absolute error threshold: " << atol << std::endl; + } + else + { + pass &= ck::utils::check_err(in_device, in_host, "Error: Incorrect results!"); + } if(do_log) { diff --git a/script/convert_miopen_driver_to_profiler.py b/script/convert_miopen_driver_to_profiler.py index 9e2f436e68..d814e0719c 100644 --- a/script/convert_miopen_driver_to_profiler.py +++ b/script/convert_miopen_driver_to_profiler.py @@ -10,7 +10,7 @@ import subprocess def init_const_args(args): - args.ck_profiler_cmd = '../build/bin/ckProfiler' + args.ck_profiler_cmd = "../build/bin/ckProfiler" # use decimal values args.init_method = 2 # don't print tensor values @@ -27,52 +27,62 @@ def run_ck_profiler_cmd(cmd): def parse_layouts(args): - if args.in_layout == "NCW" or args.in_layout == "NCHW" or \ - args.in_layout == "NCDHW": + if args.in_layout == "NCW" or args.in_layout == "NCHW" or args.in_layout == "NCDHW": if args.ck_profier_op == "grouped_conv_bwd_weight": args.layout = 4 - elif args.ck_profier_op == "grouped_conv_fwd" or \ - args.ck_profier_op == "grouped_conv_bwd_data": + elif ( + args.ck_profier_op == "grouped_conv_fwd" + or args.ck_profier_op == "grouped_conv_bwd_data" + ): args.layout = 3 else: - print('Not supported layout for this op') + print("Not supported layout for this op") exit(1) - elif args.in_layout == "NWC" or args.in_layout == "NHWC" or \ - args.in_layout == "NDHWC": + elif ( + args.in_layout == "NWC" or args.in_layout == "NHWC" or args.in_layout == "NDHWC" + ): if args.ck_profier_op == "grouped_conv_bwd_weight": args.layout = 2 - elif args.ck_profier_op == "grouped_conv_bwd_data" or \ - args.ck_profier_op == "grouped_conv_fwd": + elif ( + args.ck_profier_op == "grouped_conv_bwd_data" + or args.ck_profier_op == "grouped_conv_fwd" + ): args.layout = 1 else: - print('Not supported layout for this op') + print("Not supported layout for this op") exit(1) def parse_data_type(args): if args.data_type == "fp32": - if args.ck_profier_op == "grouped_conv_bwd_weight" or \ - args.ck_profier_op == "grouped_conv_bwd_data" or \ - args.ck_profier_op == "grouped_conv_fwd": + if ( + args.ck_profier_op == "grouped_conv_bwd_weight" + or args.ck_profier_op == "grouped_conv_bwd_data" + or args.ck_profier_op == "grouped_conv_fwd" + ): args.data_type = 0 if args.data_type == "fp16": - if args.ck_profier_op == "grouped_conv_bwd_weight" or \ - args.ck_profier_op == "grouped_conv_bwd_data" or \ - args.ck_profier_op == "grouped_conv_fwd": + if ( + args.ck_profier_op == "grouped_conv_bwd_weight" + or args.ck_profier_op == "grouped_conv_bwd_data" + or args.ck_profier_op == "grouped_conv_fwd" + ): args.data_type = 1 if args.data_type == "int8": if args.ck_profier_op == "grouped_conv_bwd_weight": args.data_type = 4 if args.ck_profier_op == "grouped_conv_bwd_data": - print('Not supported data type for grouped_conv_bwd_data') + print("Not supported data type for grouped_conv_bwd_data") exit(1) if args.ck_profier_op == "grouped_conv_fwd": args.data_type = 3 if args.data_type == "bfp16": if args.ck_profier_op == "grouped_conv_bwd_weight": args.data_type = 5 - if args.ck_profier_op == "grouped_conv_bwd_data" or \ - args.ck_profier_op == "grouped_conv_fwd": + if ( + args.ck_profier_op == "grouped_conv_bwd_data" + or args.ck_profier_op == "grouped_conv_fwd" + ): args.data_type = 2 @@ -93,13 +103,11 @@ def add_conv_params_to_cmd(args, cmd): cmd += [str(args.in_d), str(args.in_h), str(args.in_w)] cmd += [str(args.conv_stride_d), str(args.conv_stride_h)] cmd += [str(args.conv_stride_w)] - cmd += [str(args.dilation_d), - str(args.dilation_h), - str(args.dilation_w)] + cmd += [str(args.dilation_d), str(args.dilation_h), str(args.dilation_w)] cmd += [str(args.pad_d), str(args.pad_h), str(args.pad_w)] cmd += [str(args.pad_d), str(args.pad_h), str(args.pad_w)] else: - print('Not supported spatial dim (supported: 1, 2, 3)') + print("Not supported spatial dim (supported: 1, 2, 3)") exit(1) @@ -147,7 +155,7 @@ def run_ck_grouped_conv_bwd_weight(args): parse_data_type(args) parse_layouts(args) # Test all split K value from the list {1, 2, 4, 8, 32, 64, 128} - args.split_k_value = -1 + args.split_k_value = "all" cmd = [str(args.ck_profiler_cmd), str(args.ck_profier_op)] cmd += [str(args.data_type), str(args.layout)] @@ -161,23 +169,23 @@ def run_ck_grouped_conv_bwd_weight(args): cmd += [str(args.split_k_value)] run_ck_profiler_cmd(cmd) + # Get name of miopen driver, remove it from unknown def process_miopen_driver_name(args, unknown): if "convint8" in unknown: - args.data_type = 'int8' + args.data_type = "int8" unknown.remove("convint8") elif "convbfp16" in unknown: - args.data_type = 'bfp16' + args.data_type = "bfp16" unknown.remove("convbfp16") elif "convfp16" in unknown: - args.data_type = 'fp16' + args.data_type = "fp16" unknown.remove("convfp16") elif "conv" in unknown: - args.data_type = 'fp32' + args.data_type = "fp32" unknown.remove("conv") else: - print('Not supported driver (supported: conv, convfp16, convint8,' - ' convbfp16).') + print("Not supported driver (supported: conv, convfp16, convint8, convbfp16).") exit(1) @@ -199,11 +207,11 @@ if __name__ == "__main__": parser = argparse.ArgumentParser( prog="converter", description="Convert miopen driver command to ck Profiler" - "\nExample: python3 " - "../script/convert_miopen_driver_to_profiler.py " - "/opt/rocm/bin/MIOpenDriver conv -n 32 -c 64 -H 28 -W 28 " - "-k 64 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -m conv -g " - "32 -F 1 -t 1", + "\nExample: python3 " + "../script/convert_miopen_driver_to_profiler.py " + "/opt/rocm/bin/MIOpenDriver conv -n 32 -c 64 -H 28 -W 28 " + "-k 64 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -m conv -g " + "32 -F 1 -t 1", ) parser.add_argument( "-in_layout", @@ -213,7 +221,7 @@ if __name__ == "__main__": default="NCHW", type=str, required=False, - help="Input Layout (Default=NCHW for 2d conv, NCDHW for 3d conv)" + help="Input Layout (Default=NCHW for 2d conv, NCDHW for 3d conv)", ) parser.add_argument( "-forw", @@ -230,7 +238,7 @@ if __name__ == "__main__": "\n4 wrw only" "\n3 fwd+bwd" "\n5 fwd+wrw" - "\n6 bwd+wrw" + "\n6 bwd+wrw", ) parser.add_argument( "-spatial_dim", @@ -240,7 +248,7 @@ if __name__ == "__main__": default=2, type=int, required=False, - help="convolution spatial dimension (Default-2)" + help="convolution spatial dimension (Default-2)", ) parser.add_argument( "-batchsize", @@ -250,7 +258,7 @@ if __name__ == "__main__": default=100, type=int, required=False, - help="Mini-batch size (Default=100)" + help="Mini-batch size (Default=100)", ) parser.add_argument( "-in_channels", @@ -260,7 +268,7 @@ if __name__ == "__main__": default=3, type=int, required=False, - help="Number of Input Channels (Default=3)" + help="Number of Input Channels (Default=3)", ) parser.add_argument( "-in_d", @@ -270,7 +278,7 @@ if __name__ == "__main__": default=32, type=int, required=False, - help="Input Depth (Default=32)" + help="Input Depth (Default=32)", ) parser.add_argument( "-in_h", @@ -280,7 +288,7 @@ if __name__ == "__main__": default=32, type=int, required=False, - help="Input Height (Default=32)" + help="Input Height (Default=32)", ) parser.add_argument( "-in_w", @@ -290,7 +298,7 @@ if __name__ == "__main__": default=32, type=int, required=False, - help="Input Width (Default=32)" + help="Input Width (Default=32)", ) parser.add_argument( "-out_channels", @@ -300,7 +308,7 @@ if __name__ == "__main__": default=32, type=int, required=False, - help="Number of Output Channels (Default=32)" + help="Number of Output Channels (Default=32)", ) parser.add_argument( "-fil_d", @@ -310,7 +318,7 @@ if __name__ == "__main__": default=3, type=int, required=False, - help="Filter Depth (Default=3)" + help="Filter Depth (Default=3)", ) parser.add_argument( "-fil_h", @@ -320,7 +328,7 @@ if __name__ == "__main__": default=3, type=int, required=False, - help="Filter Height (Default=3)" + help="Filter Height (Default=3)", ) parser.add_argument( "-fil_w", @@ -330,7 +338,7 @@ if __name__ == "__main__": default=3, type=int, required=False, - help="Filter Width (Default=3)" + help="Filter Width (Default=3)", ) parser.add_argument( "-conv_stride_d", @@ -340,7 +348,7 @@ if __name__ == "__main__": default=1, type=int, required=False, - help="Convolution Stride for Depth (Default=1)" + help="Convolution Stride for Depth (Default=1)", ) parser.add_argument( "-conv_stride_h", @@ -350,7 +358,7 @@ if __name__ == "__main__": default=1, type=int, required=False, - help="Convolution Stride for Height (Default=1)" + help="Convolution Stride for Height (Default=1)", ) parser.add_argument( "-conv_stride_w", @@ -360,7 +368,7 @@ if __name__ == "__main__": default=1, type=int, required=False, - help="Convolution Stride for Width (Default=1)" + help="Convolution Stride for Width (Default=1)", ) parser.add_argument( "-pad_d", @@ -370,7 +378,7 @@ if __name__ == "__main__": default=1, type=int, required=False, - help="Zero Padding for Depth (Default=0)" + help="Zero Padding for Depth (Default=0)", ) parser.add_argument( "-pad_h", @@ -380,7 +388,7 @@ if __name__ == "__main__": default=1, type=int, required=False, - help="Zero Padding for Height (Default=0)" + help="Zero Padding for Height (Default=0)", ) parser.add_argument( "-pad_w", @@ -390,7 +398,7 @@ if __name__ == "__main__": default=1, type=int, required=False, - help="Zero Padding for Width (Default=0)" + help="Zero Padding for Width (Default=0)", ) parser.add_argument( "-verify", @@ -400,7 +408,7 @@ if __name__ == "__main__": default=1, type=int, required=False, - help="Verify Each Layer (Default=1)" + help="Verify Each Layer (Default=1)", ) parser.add_argument( "-time", @@ -410,7 +418,7 @@ if __name__ == "__main__": default=0, type=int, required=False, - help="Time Each Layer (Default=0)" + help="Time Each Layer (Default=0)", ) parser.add_argument( "-dilation_d", @@ -420,7 +428,7 @@ if __name__ == "__main__": default=1, type=int, required=False, - help="Dilation of Filter Depth (Default=1)" + help="Dilation of Filter Depth (Default=1)", ) parser.add_argument( "-dilation_h", @@ -430,7 +438,7 @@ if __name__ == "__main__": default=1, type=int, required=False, - help="Dilation of Filter Height (Default=1)" + help="Dilation of Filter Height (Default=1)", ) parser.add_argument( "-dilation_w", @@ -440,7 +448,7 @@ if __name__ == "__main__": default=1, type=int, required=False, - help="Dilation of Filter Width (Default=1)" + help="Dilation of Filter Width (Default=1)", ) parser.add_argument( "-group_count", @@ -450,7 +458,7 @@ if __name__ == "__main__": type=int, default=1, required=False, - help="Number of Groups (Default=1)" + help="Number of Groups (Default=1)", ) args, unknown = parser.parse_known_args()