diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 29a04d9b6c..9cff75f049 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -47,11 +47,17 @@ concept BlockGemmPipelineDescriptor = requires(T t) { // Concept for parameters that describe a gridwise WMMA GEMM problem. template concept GridwiseWmmaGemmDescriptor = requires(T t) { - { t.k1 } -> SizeType; - { t.m_per_wmma } -> SizeType; - { t.n_per_wmma } -> SizeType; - { t.m_wmma_per_wave } -> SizeType; - { t.n_wmma_per_wave } -> SizeType; + ( + requires { { T::k1 } -> SizeType; } || + (requires { { T::ak1 } -> SizeType; } && + requires { { T::bk1 } -> SizeType; }) +) && +requires { + { T::m_per_wmma } -> SizeType; + { T::n_per_wmma } -> SizeType; + { T::m_wmma_per_wave } -> SizeType; + { T::n_wmma_per_wave } -> SizeType; +}; }; // Concept for vectorized data transfer for convolution input tensors. @@ -187,6 +193,14 @@ concept GridwiseBwdXdlGemmDescriptor = requires(T t) { { t.xdl_params } -> GridwiseXdlGemmDescriptor; }; +// Concept for parameters that describe a gridwise backward-data XDL GEMM problem (split A/B K1). +template +concept GridwiseBwdDataXdlGemmDescriptor = requires(T t) { + { t.ak1 } -> SizeType; + { t.bk1 } -> SizeType; + { t.xdl_params } -> GridwiseXdlGemmDescriptor; +}; + // Concept to check if a struct specifies gridwise XDL GEMM info. template concept SpecifiesGridwiseFwdXdlGemm = requires(T t) { @@ -199,6 +213,12 @@ concept SpecifiesGridwiseBwdXdlGemm = requires(T t) { { t.gridwise_gemm } -> GridwiseBwdXdlGemmDescriptor; }; +// Concept to check if a struct specifies gridwise backward-data XDL GEMM info. +template +concept SpecifiesGridwiseBwdDataXdlGemm = requires(T t) { + { t.gridwise_gemm } -> GridwiseBwdDataXdlGemmDescriptor; +}; + // Concept to check if a struct specifies gridwise WMMA GEMM info. 
template concept SpecifiesGridwiseWmmaGemm = requires(T t) { @@ -292,6 +312,11 @@ concept SpecifiesBwdWeightConvSpecialization = requires { { T::bwd_weight_specialization } -> std::convertible_to; }; +template +concept SpecifiesBwdDataConvSpecialization = requires { + { T::bwd_data_specialization } -> std::convertible_to; +}; + template concept SpecifiesGemmSpecialization = requires { { T::gemm_specialization } -> std::convertible_to; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index 2b09ba0b1f..84f8b688ad 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -28,24 +28,29 @@ concept FwdXdlAlgorithmBase = template concept BwdXdlAlgorithmBase = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters4D && - SpecifiesGridwiseBwdXdlGemm && SpecifiesBwdWeightConvSpecialization; + ConvAlgorithmDescriptor && SpecifiesThreadBlock && + (SpecifiesTileTransferParameters4D || SpecifiesTileTransferParameters3D) && + (SpecifiesGridwiseBwdXdlGemm || SpecifiesGridwiseBwdDataXdlGemm) && + (SpecifiesBwdWeightConvSpecialization || SpecifiesBwdDataConvSpecialization); template concept BwdXdlV3AlgorithmBase = ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && - SpecifiesGridwiseBwdXdlGemm && SpecifiesBwdWeightConvSpecialization && + (SpecifiesGridwiseBwdXdlGemm || SpecifiesGridwiseBwdDataXdlGemm) && + (SpecifiesBwdWeightConvSpecialization || SpecifiesBwdDataConvSpecialization) && SpecifiesBlockGemm && SpecifiesNumGroupsToMerge; template concept BwdWmmaAlgorithmBase = ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && - SpecifiesGridwiseWmmaGemm && SpecifiesBwdWeightConvSpecialization; + SpecifiesGridwiseWmmaGemm && + (SpecifiesBwdWeightConvSpecialization || 
SpecifiesBwdDataConvSpecialization); template concept BwdWmmaV3AlgorithmBase = ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && - SpecifiesGridwiseWmmaGemm && SpecifiesBwdWeightConvSpecialization && + SpecifiesGridwiseWmmaGemm && + (SpecifiesBwdWeightConvSpecialization || SpecifiesBwdDataConvSpecialization) && SpecifiesBlockGemm; // Reference algorithm concept @@ -107,6 +112,9 @@ concept BwdWmmaAlgorithm = BwdWmmaAlgorithmBase && SpecifiesNumPrefetchStages && SpecifiesLoopScheduler && SpecifiesGridwiseGemmPipeline && SpecifiesGenericInstance; +template +concept BwdMultiDWmmaAlgorithm = BwdWmmaAlgorithmBase && SpecifiesMultipleDSupport; + template concept BwdMultiDWmmaV3Algorithm = BwdWmmaV3AlgorithmBase && SpecifiesMultipleDSupport; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_data_multi_d_wmma_cshuffle_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_data_multi_d_wmma_cshuffle_v3_factory.hpp new file mode 100644 index 0000000000..afad6c12df --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_data_multi_d_wmma_cshuffle_v3_factory.hpp @@ -0,0 +1,115 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvBwdDataMultipleD_wmma_CShuffle_v3 instance +// of a grouped bwd Data convolution kernel. +template + requires ConvDirectionIsBackwardData +struct ConvBwdDataMultiDWmmaV3Factory +{ + static constexpr int SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdDataConvSpecialization(); + + static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + static constexpr auto BLOCK_GEMM = internal::SetBlockGemm(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. 
+ static_assert(InputVectorTransferLimits); + static_assert(InputVectorTransferLimits); + static_assert(OutputVectorTransferLimits); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + + // The backward convolution kernel class instance. + using Instance = + ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3< + SPATIAL_DIM, + typename Layouts::OutLayout, + typename Layouts::WeiLayout, + typename Layouts::DsLayout, + typename Layouts::InLayout, + typename Types::OutDataType, + typename Types::WeiDataType, + typename Types::AccDataType, + typename Types::OutComputeType, + typename Types::DsDataType, + typename Types::InDataType, + typename Ops::OutElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::InElementwiseOp, + BWD_CONV_SPECIALIZATION, + ALGORITHM.DoPadGemmM, + ALGORITHM.DoPadGemmN, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.ak1, + GRIDWISE_GEMM.bk1, + GRIDWISE_GEMM.m_per_wmma, + GRIDWISE_GEMM.n_per_wmma, + GRIDWISE_GEMM.m_wmma_per_wave, + GRIDWISE_GEMM.n_wmma_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + ck::Sequence, + BLOCK_GEMM.scheduler, + BLOCK_GEMM.pipeline_version, + typename Types::OutComputeType, + typename Types::InComputeType, + ALGORITHM.max_transpose_transfer_src_scalar_per_vector, + ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>; +}; + +} // namespace 
ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_data_multi_d_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_data_multi_d_wmma_factory.hpp new file mode 100644 index 0000000000..8a2147d12d --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_data_multi_d_wmma_factory.hpp @@ -0,0 +1,107 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvBwdDataMultipleD_wmma_CShuffle instance +// of a grouped bwd Data convolution kernel. 
+template + requires ConvDirectionIsBackwardData +struct ConvBwdDataMultiDWmmaFactory +{ + static constexpr int SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdDataConvSpecialization(); + + static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION = + internal::SetGridwiseGemmPipelineVersion(); + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. + static_assert(InputVectorTransferLimits); + static_assert(InputVectorTransferLimits); + static_assert(OutputVectorTransferLimits); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + + // The backward convolution kernel class instance. 
+ using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle< + SPATIAL_DIM, + typename Layouts::OutLayout, + typename Layouts::WeiLayout, + typename Layouts::DsLayout, + typename Layouts::InLayout, + typename Types::OutDataType, + typename Types::WeiDataType, + typename Types::AccDataType, + typename Types::OutComputeType, + typename Types::DsDataType, + typename Types::InDataType, + typename Ops::OutElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::InElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + GRIDWISE_GEMM.m_per_wmma, + GRIDWISE_GEMM.n_per_wmma, + GRIDWISE_GEMM.m_wmma_per_wave, + GRIDWISE_GEMM.n_wmma_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + ALGORITHM.num_gemm_k_prefetch_stages, + LOOP_SCHEDULER, + GRIDWISE_GEMM_PIPELINE_VERSION>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_data_multi_d_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_data_multi_d_xdl_factory.hpp new file mode 100644 index 0000000000..68ef5cbdfd --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_data_multi_d_xdl_factory.hpp @@ -0,0 +1,113 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_V1 instance +// of a grouped bwd Data convolution kernel. +template + requires ConvDirectionIsBackwardData +struct ConvBwdDataMultiDXdlFactory +{ + static constexpr int SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdDataConvSpecialization(); + + static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. 
+ static_assert(InputVectorTransferLimits); + static_assert(InputVectorTransferLimits); + static_assert(OutputVectorTransferLimits); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + + // The backward convolution kernel class instance. + using Instance = + ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< + SPATIAL_DIM, + typename Layouts::OutLayout, + typename Layouts::WeiLayout, + typename Layouts::DsLayout, + typename Layouts::InLayout, + typename Types::OutDataType, + typename Types::WeiDataType, + typename Types::AccDataType, + typename Types::OutComputeType, + typename Types::DsDataType, + typename Types::InDataType, + typename Ops::OutElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::InElementwiseOp, + BWD_CONV_SPECIALIZATION, + ALGORITHM.DoPadGemmM, + ALGORITHM.DoPadGemmN, + ALGORITHM.num_gemm_k_prefetch_stages, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.ak1, + GRIDWISE_GEMM.bk1, + XDL_PARAMS.m_per_xdl, + XDL_PARAMS.n_per_xdl, + XDL_PARAMS.m_xdl_per_wave, + XDL_PARAMS.n_xdl_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + LOOP_SCHEDULER, + typename Types::OutComputeType, + typename Types::InComputeType, + ALGORITHM.max_transpose_transfer_src_scalar_per_vector, + ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>; +}; + +} // namespace 
ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index e235db4bb0..857bc4b7c2 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -77,6 +77,9 @@ #include "ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp" #include "ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp" #include "ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_data_multi_d_xdl_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_data_multi_d_wmma_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_data_multi_d_wmma_cshuffle_v3_factory.hpp" namespace ck_tile::builder::factory { @@ -148,13 +151,32 @@ constexpr auto make_conv_instance() "WMMA, DL (NHWC layout), or Large Tensor variant."); } } - // Backward data direction (will expand with more algorithms in the future) + // Backward data direction else if constexpr(ConvDirectionIsBackwardData) { - static_assert(false, - "Backward data convolution: Only reference and tile algorithms supported " - "currently. " - "Optimized kernels (XDL, WMMA, etc.) not yet implemented."); + if constexpr(BwdMultiDXdlAlgorithm) + { + return typename ConvBwdDataMultiDXdlFactory::Instance{}; + } + else if constexpr(BwdMultiDWmmaV3Algorithm) + { + return + typename ConvBwdDataMultiDWmmaV3Factory::Instance{}; + } + else if constexpr(BwdMultiDWmmaAlgorithm) + { + return typename ConvBwdDataMultiDWmmaFactory::Instance{}; + } + else + { + static_assert( + false, + "No suitable backward data convolution kernel factory found for the provided " + "ALGORITHM. 
" + "The ALGORITHM must satisfy requirements for one of: Reference, XDL multiple d, " + "Wmma multiple d, " + "or WMMA multiple d v3."); + } } // Backward weight direction (will expand with more algorithms in the future) else if constexpr(ConvDirectionIsBackwardWeight) diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp index 129d6e9c83..411f5ad346 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp @@ -5,6 +5,7 @@ #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" @@ -186,4 +187,24 @@ SetBwdWeightConvSpecialization() } } +template +consteval ck::tensor_operation::device::ConvolutionBackwardDataSpecialization +SetBwdDataConvSpecialization() +{ + constexpr auto specialization = ALGORITHM.bwd_data_specialization; + using ck_conv_spec = ck::tensor_operation::device::ConvolutionBackwardDataSpecialization; + switch(specialization) + { + case ConvSpecialization::DEFAULT: return ck_conv_spec::Default; + case ConvSpecialization::FILTER_1X1_PAD0: + throw "FILTER_1x1_PAD0 is not supported for backward data convolution."; + case ConvSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0; + case ConvSpecialization::ODD_C: + throw "FILTER ODD_C is not supported for backward data convolution."; + case ConvSpecialization::FILTER_3x3: + throw "FILTER_3x3 is not supported for 
backward data convolution."; + default: throw "Unsupported ConvSpecialization"; + } +} + } // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp index 5c09e4b735..c14cfce63c 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp @@ -90,6 +90,10 @@ class ConvDescription : public Description 2, "Gemm padding: ", traits_.gemm_padding.value_or(builder::GemmPadding::DEFAULT)); else f.writeLine(2, "Struct does not contain optional gemm_padding argument"); + if(traits_.do_pad_gemm_m) + f.writeLine(2, "Do Pad Gemm M: ", traits_.do_pad_gemm_m.value_or(false)); + if(traits_.do_pad_gemm_n) + f.writeLine(2, "Do Pad Gemm N: ", traits_.do_pad_gemm_n.value_or(false)); f.writeLine(2, "Convolution specialization: ", traits_.conv_specialization); // Pipeline section f.writeLine(2, "Pipeline version: ", traits_.pipeline_version); @@ -103,7 +107,7 @@ class ConvDescription : public Description traits_.warp_gemm.n_iter); // Memory Access section - f.writeLast(2, "Memory access:"); + f.writeLine(2, "Memory access:"); f.writeLine(3, "A Tile transfer: "); f.writeLine(4, @@ -196,7 +200,7 @@ class ConvDescription : public Description traits_.c_tile_transfer.thread_cluster_dims[2], "×", traits_.c_tile_transfer.thread_cluster_dims[3]); - f.writeLine(4, + f.writeLast(4, "Vector access (GMEM write) instruction size: ", traits_.c_tile_transfer.scalar_per_vector); if(traits_.num_gemm_k_prefetch_stage) @@ -215,14 +219,14 @@ class ConvDescription : public Description f.writeLine(2, "Struct does not contain optional " "max_transpose_transfer_src_scalar_per_vector parameter"); - if(traits_.max_transpose_dst_scalar_per_vector) + if(traits_.max_transpose_transfer_dst_scalar_per_vector) f.writeLine(2, "Max Transpose dst scalar per 
vector: ", - traits_.max_transpose_dst_scalar_per_vector.value_or(0)); + traits_.max_transpose_transfer_dst_scalar_per_vector.value_or(0)); else - f.writeLine( - 2, - "Struct does not contain optional max_transpose_dst_scalar_per_vector parameter"); + f.writeLine(2, + "Struct does not contain optional " + "max_transpose_transfer_dst_scalar_per_vector parameter"); if(traits_.num_groups_to_merge) f.writeLast(2, "Num groups to merge: ", traits_.num_groups_to_merge.value_or(0)); else diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index 16a9c47f7e..21f6525534 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -108,8 +108,10 @@ struct ConvTraits builder::PipelineScheduler pipeline_scheduler; std::optional max_transpose_transfer_src_scalar_per_vector = std::nullopt; - std::optional max_transpose_dst_scalar_per_vector = std::nullopt; + std::optional max_transpose_transfer_dst_scalar_per_vector = std::nullopt; std::optional num_groups_to_merge = std::nullopt; + std::optional do_pad_gemm_m = std::nullopt; + std::optional do_pad_gemm_n = std::nullopt; }; } // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp new file mode 100644 index 0000000000..81ca13e2aa --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp @@ -0,0 +1,48 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdData_Wmma_CShuffle_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdData_multiple_d_Wmma_CShuffle_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(InstTraits::kK0PerBlock), + .a_tile_transfer = + conv_traits_a_transfer_params(InstTraits::kK1, InstTraits::kK0PerBlock), + .b_tile_transfer = + conv_traits_b_transfer_params(InstTraits::kK1, InstTraits::kK0PerBlock), + .warp_gemm = conv_traits_wmma_warp_gemm_params(), + .c_tile_transfer = conv_traits_wmma_c_tile_transfer( + InstTraits::kCDEShuffleBlockTransferScalarPerVector_NPerBlock), + .num_gemm_k_prefetch_stage = InstTraits::kNumGemmKPrefetchStage, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp new file mode 100644 index 
0000000000..45757c0432 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp @@ -0,0 +1,53 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdData_Wmma_CShuffle_V3_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdData_multiple_d_Wmma_CShuffle_V3_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(InstTraits::kK0PerBlock), + .a_tile_transfer = + conv_traits_a_transfer_params(InstTraits::kAK1, InstTraits::kK0PerBlock), + .b_tile_transfer = + conv_traits_b_transfer_params(InstTraits::kBK1, InstTraits::kK0PerBlock), + .warp_gemm = conv_traits_wmma_warp_gemm_params(), + .c_tile_transfer = conv_traits_wmma_c_tile_transfer( + InstTraits::kCDEShuffleBlockTransferScalarPerVector_NPerBlock[0]), + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + .max_transpose_transfer_src_scalar_per_vector = + InstTraits::kMaxTransposeTransferSrcScalarPerVector, + 
.max_transpose_transfer_dst_scalar_per_vector = + InstTraits::kMaxTransposeTransferDstScalarPerVector, + .do_pad_gemm_m = InstTraits::kDoPadGemmM, + .do_pad_gemm_n = InstTraits::kDoPadGemmN, + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle.hpp new file mode 100644 index 0000000000..50fcc9b192 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,60 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdData_Xdl_CShuffle_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdData_multiple_d_Xdl_CShuffle_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(InstTraits::kK0PerBlock), + .a_tile_transfer = + conv_traits_a_transfer_params(InstTraits::kAK1, InstTraits::kK0PerBlock), + 
.b_tile_transfer = + conv_traits_b_transfer_params(InstTraits::kBK1, InstTraits::kK0PerBlock), + .warp_gemm = conv_traits_xdl_warp_gemm_params(), + .c_tile_transfer = + {.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl}, + .num_gemm_k_prefetch_stage = InstTraits::kNumGemmKPrefetchStage, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + .max_transpose_transfer_src_scalar_per_vector = + InstTraits::kMaxTransposeTransferSrcScalarPerVector, + .max_transpose_transfer_dst_scalar_per_vector = + InstTraits::kMaxTransposeTransferDstScalarPerVector, + .do_pad_gemm_m = InstTraits::kDoPadGemmM, + .do_pad_gemm_n = InstTraits::kDoPadGemmN, + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp index 4f39b00b5c..0ce714adcc 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp @@ -42,8 +42,9 @@ constexpr ConvTraits instance_to_conv_traits() .pipeline_scheduler = get_pipeline_scheduler(), .max_transpose_transfer_src_scalar_per_vector = InstTraits::kTransposeTransferSrcScalarPerVector, - .max_transpose_dst_scalar_per_vector = InstTraits::kTransposeTransferDstScalarPerVector, - 
.num_groups_to_merge = InstTraits::kNumGroupsToMerge, + .max_transpose_transfer_dst_scalar_per_vector = + InstTraits::kTransposeTransferDstScalarPerVector, + .num_groups_to_merge = InstTraits::kNumGroupsToMerge, }; } diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 5666233091..ba663c12bb 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -49,8 +49,9 @@ constexpr ConvTraits instance_to_conv_traits() .pipeline_scheduler = get_pipeline_scheduler(), .max_transpose_transfer_src_scalar_per_vector = InstTraits::kTransposeTransferSrcScalarPerVector, - .max_transpose_dst_scalar_per_vector = InstTraits::kTransposeTransferDstScalarPerVector, - .num_groups_to_merge = InstTraits::kNumGroupsToMerge, + .max_transpose_transfer_dst_scalar_per_vector = + InstTraits::kTransposeTransferDstScalarPerVector, + .num_groups_to_merge = InstTraits::kNumGroupsToMerge, }; } diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index 13625aa182..81a7bf76fd 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -42,7 +42,8 @@ constexpr ConvTraits instance_to_conv_traits() .pipeline_scheduler = get_pipeline_scheduler(), .max_transpose_transfer_src_scalar_per_vector = 
InstTraits::kMaxTransposeTransferSrcScalarPerVector, - .max_transpose_dst_scalar_per_vector = InstTraits::kMaxTransposeTransferDstScalarPerVector, + .max_transpose_transfer_dst_scalar_per_vector = + InstTraits::kMaxTransposeTransferDstScalarPerVector, }; } diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 39fde33217..d47b2ee4d3 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -49,7 +49,8 @@ constexpr ConvTraits instance_to_conv_traits() .pipeline_scheduler = get_pipeline_scheduler(), .max_transpose_transfer_src_scalar_per_vector = InstTraits::kMaxTransposeTransferSrcScalarPerVector, - .max_transpose_dst_scalar_per_vector = InstTraits::kMaxTransposeTransferDstScalarPerVector, + .max_transpose_transfer_dst_scalar_per_vector = + InstTraits::kMaxTransposeTransferDstScalarPerVector, }; } diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp index 4baf2423ee..3b6c006e6b 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp @@ -796,7 +796,8 @@ constexpr WarpGemmParams conv_traits_xdl_warp_gemm_params() } template -constexpr OutputTileTransferInfo conv_traits_wmma_c_tile_transfer() +constexpr OutputTileTransferInfo conv_traits_wmma_c_tile_transfer( + ck::index_t CDEBlockTansferScalarPerVector = InstTraits::kCDEBlockTransferScalarPerVector) { return OutputTileTransferInfo{ .shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMRepeatPerShuffle, @@ 
-805,7 +806,7 @@ constexpr OutputTileTransferInfo conv_traits_wmma_c_tile_transfer() InstTraits::kCDEThreadClusterLengths[1], InstTraits::kCDEThreadClusterLengths[2], InstTraits::kCDEThreadClusterLengths[3]}, - .scalar_per_vector = InstTraits::kCDEBlockTransferScalarPerVector}; + .scalar_per_vector = CDEBlockTansferScalarPerVector}; } template diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp index e10baaf712..cb4b3b2175 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp @@ -18,3 +18,8 @@ #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp" #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp" #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp" + +// Bwd data instances +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp" diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp new file mode 100644 index 0000000000..f63a17a977 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp @@ -0,0 +1,315 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "instance_traits.hpp" +#include "instance_traits_util.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" + +namespace ck::tensor_operation::device { + +template +struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle; + +} // namespace ck::tensor_operation::device + +namespace ck_tile::reflect { +/// @brief Tag type for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle device kernel +struct DeviceGroupedConvBwdData_multiple_d_Wmma_CShuffle_Tag; + +template +struct InstanceTraits> +{ + static constexpr auto kTensorOpName = "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle"; + + /// @brief Tag type identifying this device kernel variant + using device_kernel_tag = DeviceGroupedConvBwdData_multiple_d_Wmma_CShuffle_Tag; + + static constexpr ck::index_t kSpatialDim = NDimSpatial; + + using InLayout = InLayout_; + using WeiLayout = WeiLayout_; + using OutLayout = OutLayout_; + using DsLayout = DsLayout_; + + using InDataType = InDataType_; + using WeiDataType = WeiDataType_; + using OutDataType = OutDataType_; + using AccDataType = AccDataType_; + using DsDataType = DsDataType_; + + using InElementwiseOperation = InElementwiseOp_; + using WeiElementwiseOperation = WeiElementwiseOp_; + using OutElementwiseOperation = OutElementwiseOp_; + + static constexpr auto kConvBwdDataSpecialization = ConvBackwardDataSpecialization; + + static constexpr ck::index_t kBlockSize = BlockSize; + static constexpr ck::index_t kMPerBlock = MPerBlock; + static constexpr ck::index_t kNPerBlock = NPerBlock; + static constexpr ck::index_t kK0PerBlock = K0PerBlock; + static constexpr ck::index_t kK1 = K1; + static constexpr ck::index_t kMPerWmma = MPerWMMA; + static constexpr ck::index_t kNPerWmma = NPerWMMA; + static constexpr ck::index_t kMRepeat = MRepeat; + static constexpr ck::index_t kNRepeat = NRepeat; + static constexpr ck::index_t kCShuffleMRepeatPerShuffle = CShuffleMRepeatPerShuffle; + 
static constexpr ck::index_t kCShuffleNRepeatPerShuffle = CShuffleNRepeatPerShuffle; + static constexpr ck::index_t kCDEShuffleBlockTransferScalarPerVector_NPerBlock = + CDEShuffleBlockTransferScalarPerVector_NPerBlock; + + static constexpr ck::PipelineVersion kPipelineVer = PipelineVer; + static constexpr int kNumGemmKPrefetchStage = NumGemmKPrefetchStage; + + using ABlockTransferThreadClusterLengths_AK0_M_AK1 = + ABlockTransferThreadClusterLengths_AK0_M_AK1_; + using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_; + using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_; + // A block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kAThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + + static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; + static constexpr ck::index_t kABlockTransferSrcScalarPerVector = + ABlockTransferSrcScalarPerVector; + static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 = + ABlockTransferDstScalarPerVector_AK1; + static constexpr bool kABlockLdsExtraM = ABlockLdsExtraM; + + using BBlockTransferThreadClusterLengths_BK0_N_BK1 = + BBlockTransferThreadClusterLengths_BK0_N_BK1_; + using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; + using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + // B block transfer thread cluster dimensions (converted to std::array) + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + 
detail::SequenceToArray::value; + static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; + static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = + BBlockTransferSrcScalarPerVector; + static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 = + BBlockTransferDstScalarPerVector_BK1; + static constexpr bool kBBlockLdsExtraN = BBlockLdsExtraN; + + using CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; + + static constexpr ck::LoopScheduler kLoopScheduler = LoopSched; + + static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray< + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; + + // Static member function to generate instance string + static std::string instance_string() + { + std::ostringstream oss; + + // Kernel type name + oss << kTensorOpName; + + // Template parameters in exact order + oss << "<" << kSpatialDim; // 1. NDimSpatial + oss << "," << detail::layout_name(); // 2. OutLayout + oss << "," << detail::layout_name(); // 3. WeiLayout + oss << "," << detail::tuple_name(); // 4. DsLayout + oss << "," << detail::layout_name(); // 5. InLayout + oss << "," << detail::type_name(); // 6. OutDataType + oss << "," << detail::type_name(); // 7. WeiDataType + oss << "," << detail::type_name(); // 8. AccDataType + oss << "," << detail::tuple_name(); // 9. DsDataType + oss << "," << detail::type_name(); // 10. InDataType + oss << "," + << detail::elementwise_op_name(); // 11. + // OutElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 12. + // WeiElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 13. InElementwiseOperation + oss << "," + << detail::conv_bwd_data_spec_name( + kConvBwdDataSpecialization); // 14. ConvBackwardDataSpecialization + oss << "," << kBlockSize; // 15. BlockSize + oss << "," << kMPerBlock; // 16. 
MPerBlock + oss << "," << kNPerBlock; // 17. NPerBlock + oss << "," << kK0PerBlock; // 18. K0PerBlock + oss << "," << kK1; // 19. ABK1 + oss << "," << kMPerWmma; // 20. MPerWmma + oss << "," << kNPerWmma; // 21. NPerWmma + oss << "," << kMRepeat; // 22. MRepeat + oss << "," << kNRepeat; // 23. NRepeat + oss << "," << detail::sequence_name(); // 24. + oss << "," << detail::sequence_name(); // 25. + oss << "," << detail::sequence_name(); // 26. + oss << "," << kABlockTransferSrcVectorDim; // 27. + oss << "," << kABlockTransferSrcScalarPerVector; // 28. + oss << "," << kABlockTransferDstScalarPerVectorK1; // 29. + oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 30. + oss << "," << detail::sequence_name(); // 31. + oss << "," << detail::sequence_name(); // 32. + oss << "," << detail::sequence_name(); // 33. + oss << "," << kBBlockTransferSrcVectorDim; // 34. + oss << "," << kBBlockTransferSrcScalarPerVector; // 35. + oss << "," << kBBlockTransferDstScalarPerVectorK1; // 36. + oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 37. + oss << "," << kCShuffleMRepeatPerShuffle; // 38. + oss << "," << kCShuffleNRepeatPerShuffle; // 39. + oss << "," + << detail::sequence_name< + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>(); // 40. + oss << "," << kCDEShuffleBlockTransferScalarPerVector_NPerBlock; // 41. + oss << "," << detail::loop_scheduler_name(kLoopScheduler); // 42. LoopSched + oss << "," << kNumGemmKPrefetchStage; // 43. + oss << "," << detail::pipeline_version_name(kPipelineVer); // 44. 
+ + oss << ">"; + + return oss.str(); + } +}; + +} // namespace ck_tile::reflect diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..13b892857c --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp @@ -0,0 +1,350 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "instance_traits.hpp" +#include "instance_traits_util.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" + +namespace ck::tensor_operation::device { + +template +struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3; + +} // namespace ck::tensor_operation::device + +namespace ck_tile::reflect { +/// @brief Tag type for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle_V3 device kernel +struct DeviceGroupedConvBwdData_multiple_d_Wmma_CShuffle_V3_Tag; + +template +struct InstanceTraits< + ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3< + NDimSpatial, + OutLayout_, // output image + WeiLayout_, // weight + DsLayout_, // bias + InLayout_, // input image + OutDataType_, // output image + WeiDataType_, // weight + AccDataType_, + CShuffleDataType_, + DsDataType_, // bias + InDataType_, // input image + OutElementwiseOp_, // output image + WeiElementwiseOp_, // weight + InElementwiseOp_, // C, bias, and input image + ConvBackwardDataSpecialization, + DoPadGemmM, + DoPadGemmN, + BlockSize, + MPerBlock, + NPerBlock, + K0PerBlock, + AK1, + BK1, + MPerWMMA, + NPerWMMA, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1_, + ABlockTransferThreadClusterArrangeOrder_, + 
ABlockTransferSrcAccessOrder_, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1_, + BBlockTransferThreadClusterArrangeOrder_, + BBlockTransferSrcAccessOrder_, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_, + CDEShuffleBlockTransferScalarPerVector_NPerBlock_, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA_, + ComputeTypeB_, + max_transpose_transfer_src_scalar_per_vector, + max_transpose_transfer_dst_scalar_per_vector>> +{ + static constexpr auto kTensorOpName = "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3"; + + /// @brief Tag type identifying this device kernel variant + using device_kernel_tag = DeviceGroupedConvBwdData_multiple_d_Wmma_CShuffle_V3_Tag; + + static constexpr ck::index_t kSpatialDim = NDimSpatial; + + using InLayout = InLayout_; + using WeiLayout = WeiLayout_; + using OutLayout = OutLayout_; + using DsLayout = DsLayout_; + + using InDataType = InDataType_; + using WeiDataType = WeiDataType_; + using OutDataType = OutDataType_; + using AccDataType = AccDataType_; + using DsDataType = DsDataType_; + + using InElementwiseOperation = InElementwiseOp_; + using WeiElementwiseOperation = WeiElementwiseOp_; + using OutElementwiseOperation = OutElementwiseOp_; + + static constexpr auto kConvBwdDataSpecialization = ConvBackwardDataSpecialization; + + static constexpr ck::index_t kBlockSize = BlockSize; + static constexpr ck::index_t kMPerBlock = MPerBlock; + static constexpr ck::index_t kNPerBlock = NPerBlock; + static constexpr ck::index_t kK0PerBlock = K0PerBlock; + static constexpr ck::index_t kAK1 = AK1; + static constexpr ck::index_t kBK1 = BK1; + static constexpr ck::index_t kMPerWmma = MPerWMMA; + static 
constexpr ck::index_t kNPerWmma = NPerWMMA; + static constexpr ck::index_t kMRepeat = MRepeat; + static constexpr ck::index_t kNRepeat = NRepeat; + static constexpr ck::index_t kCShuffleMRepeatPerShuffle = CShuffleMRepeatPerShuffle; + static constexpr ck::index_t kCShuffleNRepeatPerShuffle = CShuffleNRepeatPerShuffle; + static constexpr ck::index_t kMaxTransposeTransferSrcScalarPerVector = + max_transpose_transfer_src_scalar_per_vector; + static constexpr ck::index_t kMaxTransposeTransferDstScalarPerVector = + max_transpose_transfer_dst_scalar_per_vector; + static constexpr bool kDoPadGemmM = DoPadGemmM; + static constexpr bool kDoPadGemmN = DoPadGemmN; + using CDEShuffleBlockTransferScalarPerVector_NPerBlock = + CDEShuffleBlockTransferScalarPerVector_NPerBlock_; + + static constexpr auto kCDEShuffleBlockTransferScalarPerVector_NPerBlock = + detail::SequenceToArray::value; + + static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched; + static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer; + + using ABlockTransferThreadClusterLengths_AK0_M_AK1 = + ABlockTransferThreadClusterLengths_AK0_M_AK1_; + using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_; + using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_; + // A block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kAThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + + static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; + static constexpr ck::index_t kABlockTransferSrcScalarPerVector = + ABlockTransferSrcScalarPerVector; + static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 = + ABlockTransferDstScalarPerVector_AK1; + static constexpr bool 
kABlockLdsExtraM = ABlockLdsExtraM; + + using BBlockTransferThreadClusterLengths_BK0_N_BK1 = + BBlockTransferThreadClusterLengths_BK0_N_BK1_; + using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; + using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + // B block transfer thread cluster dimensions (converted to std::array) + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; + static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = + BBlockTransferSrcScalarPerVector; + static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 = + BBlockTransferDstScalarPerVector_BK1; + static constexpr bool kBBlockLdsExtraN = BBlockLdsExtraN; + + using CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; + + static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray< + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; + + using ComputeTypeA = ComputeTypeA_; + using ComputeTypeB = ComputeTypeB_; + // Static member function to generate instance string + static std::string instance_string() + { + std::ostringstream oss; + + // Kernel type name + oss << kTensorOpName; + + // Template parameters in exact order + oss << "<" << kSpatialDim; // 1. NDimSpatial + oss << "," << detail::layout_name(); // 2. OutLayout + oss << "," << detail::layout_name(); // 3. WeiLayout + oss << "," << detail::tuple_name(); // 4. DsLayout + oss << "," << detail::layout_name(); // 5. InLayout + oss << "," << detail::type_name(); // 6. 
OutDataType + oss << "," << detail::type_name(); // 7. WeiDataType + oss << "," << detail::type_name(); // 8. AccDataType + oss << "," << detail::tuple_name(); // 9. DsDataType + oss << "," << detail::type_name(); // 10. InDataType + oss << "," + << detail::elementwise_op_name(); // 11. + // OutElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 12. + // WeiElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 13. InElementwiseOperation + oss << "," + << detail::conv_bwd_data_spec_name( + kConvBwdDataSpecialization); // 14. ConvBackwardDataSpecialization + oss << "," << kDoPadGemmM; + oss << "," << kDoPadGemmN; + oss << "," << kBlockSize; // 15. BlockSize + oss << "," << kMPerBlock; // 16. MPerBlock + oss << "," << kNPerBlock; // 17. NPerBlock + oss << "," << kK0PerBlock; // 18. K0PerBlock + oss << "," << kAK1; // 19. ABK1 + oss << "," << kBK1; // 19. ABK1 + oss << "," << kMPerWmma; // 20. MPerWmma + oss << "," << kNPerWmma; // 21. NPerWmma + oss << "," << kMRepeat; // 22. MRepeat + oss << "," << kNRepeat; // 23. NRepeat + oss << "," << detail::sequence_name(); // 24. + oss << "," << detail::sequence_name(); // 25. + oss << "," << detail::sequence_name(); // 26. + oss << "," << kABlockTransferSrcVectorDim; // 27. + oss << "," << kABlockTransferSrcScalarPerVector; // 28. + oss << "," << kABlockTransferDstScalarPerVectorK1; // 29. + oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 30. + oss << "," << detail::sequence_name(); // 31. + oss << "," << detail::sequence_name(); // 32. + oss << "," << detail::sequence_name(); // 33. + oss << "," << kBBlockTransferSrcVectorDim; // 34. + oss << "," << kBBlockTransferSrcScalarPerVector; // 35. + oss << "," << kBBlockTransferDstScalarPerVectorK1; // 36. + oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 37. + oss << "," << kCShuffleMRepeatPerShuffle; // 38. + oss << "," << kCShuffleNRepeatPerShuffle; // 39. 
+ oss << "," + << detail::sequence_name< + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>(); // 40. + oss << "," << kCDEShuffleBlockTransferScalarPerVector_NPerBlock[0]; // 41. + oss << "," << detail::pipeline_scheduler_name(kBlkGemmPipeSched); // 43. + oss << "," << detail::pipeline_version_name(kBlkGemmPipelineVer); // 44. + oss << "," << detail::type_name(); // 45. + oss << "," << detail::type_name(); // 46. + oss << "," << kMaxTransposeTransferSrcScalarPerVector; // 47. + oss << "," << kMaxTransposeTransferDstScalarPerVector; // 48. + + oss << ">"; + + return oss.str(); + } +}; + +} // namespace ck_tile::reflect diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle.hpp new file mode 100644 index 0000000000..df2a3532c9 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,345 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "instance_traits.hpp" +#include "instance_traits_util.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" + +namespace ck::tensor_operation::device { + +template +struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1; + +} // namespace ck::tensor_operation::device + +namespace ck_tile::reflect { +/// @brief Tag type for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle device kernel +struct DeviceGroupedConvBwdData_multiple_d_Xdl_CShuffle_Tag; + +template +struct InstanceTraits< + ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< + NDimSpatial, + OutLayout_, + WeiLayout_, + DsLayout_, + InLayout_, + OutDataType_, + WeiDataType_, + AccDataType_, + OutComputeType_, + DsDataType_, + InDataType_, + OutElementwiseOperation_, + WeiElementwiseOperation_, + InElementwiseOperation_, + ConvBackwardDataSpecialization, + do_pad_gemm_m, + do_pad_gemm_n, + num_gemm_k_prefetch_stages, + BlockSize, + MPerBlock, + NPerBlock, + K0PerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1_, + ABlockTransferThreadClusterArrangeOrder_, + ABlockTransferSrcAccessOrder_, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1_, + BBlockTransferThreadClusterArrangeOrder_, + BBlockTransferSrcAccessOrder_, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BBlockLdsAddExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_, + CBlockTransferScalarPerVector_NWaveNPerXdl, + LoopSched, + ComputeTypeA_, + ComputeTypeB_, + max_transpose_transfer_src_scalar_per_vector, + max_transpose_transfer_dst_scalar_per_vector>> +{ + static constexpr auto 
kTensorOpName = "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle"; + + /// @brief Tag type identifying this device kernel variant + using device_kernel_tag = DeviceGroupedConvBwdData_multiple_d_Xdl_CShuffle_Tag; + + static constexpr ck::index_t kSpatialDim = NDimSpatial; + + using InLayout = InLayout_; + using WeiLayout = WeiLayout_; + using OutLayout = OutLayout_; + using DsLayout = DsLayout_; + + using InDataType = InDataType_; + using WeiDataType = WeiDataType_; + using OutDataType = OutDataType_; + using AccDataType = AccDataType_; + using DsDataType = DsDataType_; + + using InElementwiseOperation = InElementwiseOperation_; + using WeiElementwiseOperation = WeiElementwiseOperation_; + using OutElementwiseOperation = OutElementwiseOperation_; + + static constexpr auto kConvBwdDataSpecialization = ConvBackwardDataSpecialization; + + static constexpr ck::index_t kBlockSize = BlockSize; + static constexpr ck::index_t kMPerBlock = MPerBlock; + static constexpr ck::index_t kNPerBlock = NPerBlock; + static constexpr ck::index_t kK0PerBlock = K0PerBlock; + static constexpr ck::index_t kAK1 = AK1; + static constexpr ck::index_t kBK1 = BK1; + static constexpr ck::index_t kMPerXDL = MPerXDL; + static constexpr ck::index_t kNPerXDL = NPerXDL; + static constexpr ck::index_t kMXdlPerWave = MXdlPerWave; + static constexpr ck::index_t kNXdlPerWave = NXdlPerWave; + static constexpr ck::index_t kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle; + static constexpr ck::index_t kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle; + static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl = + CBlockTransferScalarPerVector_NWaveNPerXdl; + static constexpr ck::index_t kMaxTransposeTransferSrcScalarPerVector = + max_transpose_transfer_src_scalar_per_vector; + static constexpr ck::index_t kMaxTransposeTransferDstScalarPerVector = + max_transpose_transfer_dst_scalar_per_vector; + static constexpr bool kDoPadGemmM = do_pad_gemm_m; + static constexpr 
bool kDoPadGemmN = do_pad_gemm_n; + static constexpr int kNumGemmKPrefetchStage = num_gemm_k_prefetch_stages; + + using ABlockTransferThreadClusterLengths_K0_M_K1 = ABlockTransferThreadClusterLengths_K0_M_K1_; + using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_; + using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_; + // A block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kAThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; + static constexpr ck::index_t kABlockTransferSrcScalarPerVector = + ABlockTransferSrcScalarPerVector; + static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 = + ABlockTransferDstScalarPerVector_AK1; + static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM; + + using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_; + using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; + using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; + static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = + BBlockTransferSrcScalarPerVector; + static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 = + BBlockTransferDstScalarPerVector_BK1; + static constexpr bool 
kBBlockLdsExtraN = BBlockLdsAddExtraN; + + using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; + + using ComputeTypeA = ComputeTypeA_; + using ComputeTypeB = ComputeTypeB_; + + static constexpr ck::LoopScheduler kLoopScheduler = LoopSched; + + static constexpr auto kCThreadClusterLengths = detail::SequenceToArray< + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; + + // Static member function to generate instance string + static std::string instance_string() + { + std::ostringstream oss; + + // Kernel type name + oss << kTensorOpName; + + // Template parameters in exact order + oss << "<" << kSpatialDim; // 1. NDimSpatial + oss << "," << detail::layout_name(); // 2. OutLayout + oss << "," << detail::layout_name(); // 3. WeiLayout + oss << "," << detail::tuple_name(); // 4. DsLayout + oss << "," << detail::layout_name(); // 5. InLayout + oss << "," << detail::type_name(); // 6. OutDataType + oss << "," << detail::type_name(); // 7. WeiDataType + oss << "," << detail::type_name(); // 8. AccDataType + oss << "," << detail::tuple_name(); // 9. DsDataType + oss << "," << detail::type_name(); // 10. InDataType + oss << "," + << detail::elementwise_op_name(); // 11. + // OutElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 12. + // WeiElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 13. InElementwiseOperation + oss << "," + << detail::conv_bwd_data_spec_name( + kConvBwdDataSpecialization); // 14. ConvBackwardDataSpecialization + oss << "," << kDoPadGemmM; + oss << "," << kDoPadGemmN; + oss << "," << kNumGemmKPrefetchStage; + oss << "," << kBlockSize; // 15. BlockSize + oss << "," << kMPerBlock; // 16. MPerBlock + oss << "," << kNPerBlock; // 17. NPerBlock + oss << "," << kK0PerBlock; // 18. K0PerBlock + oss << "," << kAK1; // 19. AK1 + oss << "," << kBK1; // 19. BK1 + oss << "," << kMPerXDL; // 20. 
MPerXDL + oss << "," << kNPerXDL; // 21. NPerXDL + oss << "," << kMXdlPerWave; // 22. MXdlPerWave + oss << "," << kNXdlPerWave; // 23. NXdlPerWave + oss << "," << detail::sequence_name(); // 24. + oss << "," << detail::sequence_name(); // 25. + oss << "," << detail::sequence_name(); // 26. + oss << "," << kABlockTransferSrcVectorDim; // 27. + oss << "," << kABlockTransferSrcScalarPerVector; // 28. + oss << "," << kABlockTransferDstScalarPerVectorK1; // 29. + oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 30. + oss << "," << detail::sequence_name(); // 31. + oss << "," << detail::sequence_name(); // 32. + oss << "," << detail::sequence_name(); // 33. + oss << "," << kBBlockTransferSrcVectorDim; // 34. + oss << "," << kBBlockTransferSrcScalarPerVector; // 35. + oss << "," << kBBlockTransferDstScalarPerVectorK1; // 36. + oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 37. + oss << "," << kCShuffleMXdlPerWavePerShuffle; // 38. + oss << "," << kCShuffleNXdlPerWavePerShuffle; // 39. + oss << "," + << detail::sequence_name< + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>(); // 40. + oss << "," << kCBlockTransferScalarPerVector_NWaveNPerXdl; // 42. + oss << "," << kNumGemmKPrefetchStage; // 41. + oss << "," << detail::loop_scheduler_name(kLoopScheduler); // 43. LoopSched + oss << "," << detail::type_name(); // 44. + oss << "," << detail::type_name(); // 45. + oss << "," << kMaxTransposeTransferSrcScalarPerVector; // 46. + oss << "," << kMaxTransposeTransferDstScalarPerVector; // 47. 
+ + oss << ">"; + + return oss.str(); + } +}; + +} // namespace ck_tile::reflect diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp index 1055cbc038..b9aa8e37e0 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp @@ -18,6 +18,7 @@ #include #include #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -138,6 +139,18 @@ constexpr std::string_view conv_bwd_weight_spec_name( } } +// Convert ConvolutionBackwardDataSpecialization enum to string +constexpr std::string_view +conv_bwd_data_spec_name(ck::tensor_operation::device::ConvolutionBackwardDataSpecialization spec) +{ + using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization; + switch(spec) + { + case Default: return "Default"; + case Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0"; + } +} + // Convert GemmSpecialization enum to string constexpr std::string_view gemm_spec_name(ck::tensor_operation::device::GemmSpecialization spec) { diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 9397de5546..57fc4cc779 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -197,6 +197,9 @@ target_link_libraries(test_ckb_build_bwd_weight_instances PRIVATE utility) add_ck_builder_test(test_ckb_build_bwd_data_instances conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp + conv/ck/test_ckb_conv_bwd_data_multi_d_xdl_cshuffle.cpp + 
conv/ck/test_ckb_conv_bwd_data_multi_d_wmma_cshuffle.cpp + conv/ck/test_ckb_conv_bwd_data_multi_d_wmma_cshuffle_v3.cpp ) target_link_libraries(test_ckb_build_bwd_data_instances PRIVATE utility) diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_data_multi_d_wmma_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_data_multi_d_wmma_cshuffle.cpp new file mode 100644 index 0000000000..4f47d60879 --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_data_multi_d_wmma_cshuffle.cpp @@ -0,0 +1,43 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#include "gmock/gmock.h" +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; + +constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::BACKWARD_DATA, + .data_type = ckb::DataType::FP16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::GNHWC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle{} + .with_thread_block(cku::ThreadBlock_64_32x32x32) + .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) + .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) + .with_bwd_data_specialization(ckb::ConvSpecialization::DEFAULT) + .with_prefetch_config(1, ckb::PipelineScheduler::DEFAULT) + .with_gridwise_gemm_pipeline(ckb::PipelineVersion::V1); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdData_2DFp16_MultiD_Wmma_CShuffle_GNHWC, Create) +{ + const auto expected_transfer_parameters = 
to_string(ALGORITHM); + std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; + cku::run_test({"DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle", + expected_transfer_parameters, + "Default", + "GNHWK,GKYXC,EmptyTuple,GNHWC", + "PassThrough,PassThrough,PassThrough", + "fp16,fp16"}); // check compute types +} diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_data_multi_d_wmma_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_data_multi_d_wmma_cshuffle_v3.cpp new file mode 100644 index 0000000000..eee9d86947 --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_data_multi_d_wmma_cshuffle_v3.cpp @@ -0,0 +1,45 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#include "gmock/gmock.h" +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; + +constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::BACKWARD_DATA, + .data_type = ckb::DataType::FP16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::GNHWC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle_V3{} + .with_thread_block(cku::ThreadBlock_64_32x32x32) + .with_gemm_config(cku::GemmParamsABK1_Wmma_16x16_2x1_per_wave) + .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) + .with_bwd_data_specialization(ckb::ConvSpecialization::DEFAULT) + .with_prefetch_config(1, ckb::PipelineScheduler::DEFAULT) + .with_gemm_pad_params(0, 0) + .with_block_gemm(cku::BlockGemmDesc_v1_intrawave) 
+ .with_transpose_params(2, 2); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdData_2DFp16_MultiD_Wmma_CShuffle_V3_GNHWC, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; + cku::run_test({"DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3", + expected_transfer_parameters, + "Default", + "GNHWK,GKYXC,EmptyTuple,GNHWC", + "PassThrough,PassThrough,PassThrough", + "fp16,fp16"}); // check compute types +} diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_data_multi_d_xdl_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_data_multi_d_xdl_cshuffle.cpp new file mode 100644 index 0000000000..1adba672da --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_data_multi_d_xdl_cshuffle.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#include "gmock/gmock.h" +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; + +constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::BACKWARD_DATA, + .data_type = ckb::DataType::FP16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::GNHWC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle{} + .with_thread_block(cku::ThreadBlock_256_256x128x32) + .with_gemm_config(cku::BwdDataGemmParams_Xdl_4x4_per_wave) + .with_transfer(cku::Transfer_4x64x1) + 
.with_prefetch_config(1, ckb::PipelineScheduler::DEFAULT) + .with_bwd_data_specialization(ckb::ConvSpecialization::DEFAULT) + .with_gemm_pad_params(0, 0) + .with_transpose_params(2, 2); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdData_2DFp16_MultiD_Xdl_CShuffle_GNHWC, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; + cku::run_test({"DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle", + expected_transfer_parameters, + "Default", + "GNHWK,GKYXC,EmptyTuple,GNHWC", + "PassThrough,PassThrough,PassThrough", + "fp16,fp16"}); // check compute types +} diff --git a/experimental/builder/test/conv/ck/test_conv_traits.cpp b/experimental/builder/test/conv/ck/test_conv_traits.cpp index ba9cb0a030..a171627753 100644 --- a/experimental/builder/test/conv/ck/test_conv_traits.cpp +++ b/experimental/builder/test/conv/ck/test_conv_traits.cpp @@ -19,6 +19,9 @@ #include #include #include +#include +#include +#include namespace { @@ -35,7 +38,390 @@ class ConvTraitsTest : public ::testing::Test { }; -// Test ConvTraits with DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 +// Test ConvTraits with DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle +TEST_F(ConvTraitsTest, ConvBwdDataMultipleDCshuffleWmmaTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWK, // OutLayout + ck::tensor_layout::convolution::GKYXC, // WeiLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWC, // InLayout + ck::half_t, // OutDataType + ck::half_t, // WeiDataType + ck::half_t, // OutDataType + float, // AccDataType + ck::Tuple<>, // DsDataType + float, // OutComputeType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation 
+ ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Default, // ConvBackwardDataSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // K1 + 32, // MPerWMMA + 32, // NPerWMMA + 4, // MRepeat + 4, // NRepeat + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMRepeatPerWavePerShuffle + 1, // CShuffleNRepeatPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CDEBlockTransferScalarPerVector_NPerBlock_ + 2, // NumGemmKPrefetchStage + ck::LoopScheduler::Default, // BlkGemmPipeSched + ck::PipelineVersion::v1>; // PipelineVerison + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_DATA); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(traits.data_type, DataType::FP32); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + 
EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + 
EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + EXPECT_EQ(traits.num_gemm_k_prefetch_stage, 2); + + // Verify pipeline configuration +} + +// Test ConvTraits with DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle_V3 +TEST_F(ConvTraitsTest, ConvBwdDataMultipleDCshuffleWmmaV3TraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWK, // OutLayout + ck::tensor_layout::convolution::GKYXC, // WeiLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWC, // InLayout + ck::half_t, // OutDataType + ck::half_t, // WeiDataType + ck::half_t, // OutDataType + float, // AccDataType + ck::Tuple<>, // DsDataType + float, // OutComputeType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Default, // ConvBackwardDataSpecialization + false, // DoPadGemmM + false, // DoPadGemmN + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerWMMA + 
32, // NPerWMMA + 4, // MRepeat + 4, // NRepeat + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMRepeatPerWavePerShuffle + 1, // CShuffleNRepeatPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + ck::Sequence<8, 8, 8>, // CDEBlockTransferScalarPerVector_NPerBlock_ + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + 1, // MaxTransposeTransferSrcScalarPerVector + 1>; // MaxTransposeTransferDstScalarPerVector + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_DATA); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(traits.data_type, DataType::FP32); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(traits.pipeline_scheduler, 
PipelineScheduler::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + EXPECT_FALSE(traits.do_pad_gemm_n.value()); + EXPECT_FALSE(traits.do_pad_gemm_m.value()); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + 
EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + + EXPECT_EQ(traits.max_transpose_transfer_src_scalar_per_vector, 1); + EXPECT_EQ(traits.max_transpose_transfer_dst_scalar_per_vector, 1); + + // Verify pipeline configuration +} + +// Test ConvTraits with DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle +TEST_F(ConvTraitsTest, ConvBwdDataMultipleDCshuffleXDLTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWK, // OutLayout + ck::tensor_layout::convolution::GKYXC, // WeiLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWC, // InLayout + ck::half_t, // OutDataType + ck::half_t, // WeiDataType + ck::half_t, // OutDataType + float, // AccDataType + ck::Tuple<>, // DsDataType + float, // OutComputeType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Default, // ConvBackwardDataSpecialization + false, // DoPadGemmM + false, // DoPadGemmN + 1, // num_gemm_k_prefetch_stage + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // 
AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CDEBlockTransferScalarPerVector_NPerBlock_ + ck::LoopScheduler::Default, // BlkGemmPipeSched + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + 1, // MaxTransposeTransferSrcScalarPerVector + 1>; // MaxTransposeTransferDstScalarPerVector + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_DATA); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(traits.data_type, DataType::FP32); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + 
EXPECT_EQ(traits.num_gemm_k_prefetch_stage, 1); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + EXPECT_FALSE(traits.do_pad_gemm_n.value()); + EXPECT_FALSE(traits.do_pad_gemm_m.value()); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + 
EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + + EXPECT_EQ(traits.max_transpose_transfer_src_scalar_per_vector, 1); + EXPECT_EQ(traits.max_transpose_transfer_dst_scalar_per_vector, 1); +} + +// Test ConvTraits with DeviceGroupedConvBwdWeight_Wmma_CShuffle TEST_F(ConvTraitsTest, ConvBwdWeightCshuffleWmmaTraitsExtraction) { // Define a concrete instance type with specific template parameters @@ -270,6 +656,9 @@ TEST_F(ConvTraitsTest, ConvBwdWeightCshuffleWmmaV3TraitsExtraction) EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + EXPECT_EQ(traits.max_transpose_transfer_src_scalar_per_vector, 1); + EXPECT_EQ(traits.max_transpose_transfer_dst_scalar_per_vector, 1); + // Verify pipeline configuration } @@ -516,6 +905,9 @@ TEST_F(ConvTraitsTest, ConvBwdWeightTwoStageWmmaCshuffleTraitsExtraction) // Verify pipeline configuration EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); + + EXPECT_EQ(traits.max_transpose_transfer_src_scalar_per_vector, 1); + EXPECT_EQ(traits.max_transpose_transfer_dst_scalar_per_vector, 1); } // Test ConvTraits with DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffleV3 @@ -640,6 +1032,9 @@ TEST_F(ConvTraitsTest, ConvBwdWeightTwoStageXdlCshuffleTraitsExtraction) // Verify 
pipeline configuration EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); + + EXPECT_EQ(traits.max_transpose_transfer_src_scalar_per_vector, 1); + EXPECT_EQ(traits.max_transpose_transfer_dst_scalar_per_vector, 1); } // Test ConvTraits with DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle @@ -1001,6 +1396,9 @@ TEST_F(ConvTraitsTest, ConvBwdWeightXdlCshuffleTraitsExtraction) // Verify pipeline configuration EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); + + EXPECT_EQ(traits.max_transpose_transfer_src_scalar_per_vector, 1); + EXPECT_EQ(traits.max_transpose_transfer_dst_scalar_per_vector, 1); } // test conv traits device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index c4e83293ef..bcf17fd087 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -54,6 +54,13 @@ struct GridwiseBwdXdlGemm }; static_assert(ckb::GridwiseBwdXdlGemmDescriptor); +struct GridwiseBwdDataXdlGemm +{ + size_t ak1 = 0; + size_t bk1 = 0; + XdlParams xdl_params; +}; + // Describe gridwise WMMA GEMM parameters. 
struct GridwiseWmmaGemm { @@ -64,6 +71,16 @@ struct GridwiseWmmaGemm size_t n_wmma_per_wave = 0; }; static_assert(ckb::GridwiseWmmaGemmDescriptor); +struct GridwiseWmmaGemmABK1 +{ + size_t ak1 = 0; + size_t bk1 = 0; + size_t m_per_wmma = 0; + size_t n_per_wmma = 0; + size_t m_wmma_per_wave = 0; + size_t n_wmma_per_wave = 0; +}; +static_assert(ckb::GridwiseWmmaGemmDescriptor); struct BlockGemmPipeline { @@ -209,11 +226,21 @@ struct BwdXdlGemm_ GridwiseBwdXdlGemm gridwise_gemm; }; +struct BwdDataXdlGemm_ +{ + GridwiseBwdDataXdlGemm gridwise_gemm; +}; + struct WmmaGemm_ { GridwiseWmmaGemm gridwise_gemm; }; +struct WmmaGemmABK1_ +{ + GridwiseWmmaGemmABK1 gridwise_gemm; +}; + template struct Transfer_ { @@ -231,12 +258,23 @@ struct ConvSpecializationBwdWeight_ ConvSpecialization bwd_weight_specialization; }; +struct ConvSpecializationBwdData_ +{ + ConvSpecialization bwd_data_specialization; +}; + struct Prefetch_ { size_t num_gemm_k_prefetch_stages; PipelineScheduler loop_scheduler; }; +struct GemmPad_ +{ + size_t DoPadGemmM; + size_t DoPadGemmN; +}; + struct TransposeParams_ { size_t max_transpose_transfer_src_scalar_per_vector{1}; @@ -394,10 +432,18 @@ struct ConvAlgorithmTemplate : Components... { result.gridwise_gemm = gemm; } + else if constexpr(std::is_base_of_v) + { + result.gridwise_gemm = gemm; + } else if constexpr(std::is_base_of_v) { result.gridwise_gemm = gemm; } + else if constexpr(std::is_base_of_v) + { + result.gridwise_gemm = gemm; + } else { static_assert(false, "Unrecognized GemmConfig type"); @@ -433,6 +479,14 @@ struct ConvAlgorithmTemplate : Components... 
return result; } + constexpr auto with_bwd_data_specialization(ConvSpecialization bwd_spec) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.bwd_data_specialization = bwd_spec; + return result; + } + constexpr auto with_prefetch_config(size_t k_prefetch_stages, PipelineScheduler scheduler) const { static_assert(std::is_base_of_v); @@ -452,6 +506,15 @@ struct ConvAlgorithmTemplate : Components... return result; } + constexpr auto with_gemm_pad_params(size_t doPadGemmN_, size_t doPadGemmM_) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.DoPadGemmN = doPadGemmN_; + result.DoPadGemmM = doPadGemmM_; + return result; + } + constexpr auto with_num_conv_groups_to_merge(size_t num_groups_to_merge) const { static_assert(std::is_base_of_v); @@ -684,4 +747,35 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3 = BlockGemm_, MultipleDSpecialization_>; +// Bwd Data algorithm types +using ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle = + ConvAlgorithmTemplate, + ConvSpecializationBwdData_, + MultipleDSpecialization_, + Prefetch_, + TransposeParams_, + GemmPad_>; + +using ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle = + ConvAlgorithmTemplate, + ConvSpecializationBwdData_, + GridGemm_, + MultipleDSpecialization_, + Prefetch_>; + +using ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle_V3 = + ConvAlgorithmTemplate, + ConvSpecializationBwdData_, + BlockGemm_, + MultipleDSpecialization_, + Prefetch_, + TransposeParams_, + GemmPad_>; + } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index 4fd8ce867a..40ea364ba9 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -282,38 +282,39 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription) " ├─ Warp Gemm parameters: \n" " │ 
├─ subtile size: 16×16\n" " │ └─ Number of warp gemm iterations: 8×8\n" - " └─ Memory access:\n" - " ├─ A Tile transfer: \n" - " │ ├─ Tile dimensions: 4×256×8×\n" - " │ ├─ The innermost K subdimension size: 8\n" - " │ ├─ Spatial thread distribution over the data tile: 0×1×2\n" - " │ ├─ The order of accessing data tile axes: 0×1×2\n" - " │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" - " │ ├─ Vector access (GMEM read) instruction size: 2\n" - " │ ├─ Vector access (LDS write) instruction size: 2\n" - " │ └─ LDS data layout padding (to prevent bank conflicts): 2\n" - " ├─ B Tile transfer: \n" - " │ ├─ Tile dimensions: 4×256×8×\n" - " │ ├─ The innermost K subdimension size: 8\n" - " │ ├─ Spatial thread distribution over the data tile: 0×1×2\n" - " │ ├─ The order of accessing data tile axes: 0×1×2\n" - " │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" - " │ ├─ Vector access (GMEM read) instruction size: 2\n" - " │ ├─ Vector access (LDS write) instruction size: 2\n" - " │ └─ LDS data layout padding (to prevent bank conflicts): 2\n" - " └─ C Tile transfer: \n" - " ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n" - " ├─ Spatial thread distribution used to store data: 1×32×1×8\n" - " ├─ Vector access (GMEM write) instruction size: 2\n" + " ├─ Memory access:\n" + " │ ├─ A Tile transfer: \n" + " │ │ ├─ Tile dimensions: 4×256×8×\n" + " │ │ ├─ The innermost K subdimension size: 8\n" + " │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n" + " │ │ ├─ The order of accessing data tile axes: 0×1×2\n" + " │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" + " │ │ ├─ Vector access (GMEM read) instruction size: 2\n" + " │ │ ├─ Vector access (LDS write) instruction size: 2\n" + " │ │ └─ LDS data layout padding (to prevent bank conflicts): 2\n" + " │ ├─ B Tile transfer: \n" + " │ │ ├─ Tile dimensions: 4×256×8×\n" + " │ │ ├─ The innermost K subdimension size: 8\n" + " │ │ ├─ Spatial 
thread distribution over the data tile: 0×1×2\n" + " │ │ ├─ The order of accessing data tile axes: 0×1×2\n" + " │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" + " │ │ ├─ Vector access (GMEM read) instruction size: 2\n" + " │ │ ├─ Vector access (LDS write) instruction size: 2\n" + " │ │ └─ LDS data layout padding (to prevent bank conflicts): 2\n" + " │ └─ C Tile transfer: \n" + " │ ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n" + " │ ├─ Spatial thread distribution used to store data: 1×32×1×8\n" + " │ └─ Vector access (GMEM write) instruction size: 2\n" " ├─ Struct does not contain optional num_gemm_k_prefetch_stage parameter\n" " ├─ Struct does not contain optional max_transpose_transfer_src_scalar_per_vector " "parameter\n" - " ├─ Struct does not contain optional max_transpose_dst_scalar_per_vector parameter\n" + " ├─ Struct does not contain optional max_transpose_transfer_dst_scalar_per_vector " + "parameter\n" " └─ Struct does not contain optional num_groups_to_merge parameter")); } // Test printing of optional parameters num_groups_to_merge, -// nax_transose_transfer_src_scalar_per_vector and max_transpose_dst_scalar_per_vector +// max_transpose_transfer_src_scalar_per_vector and max_transpose_transfer_dst_scalar_per_vector TEST(ConvDescriptionTest, BwdWeightTwoStageWmmaV3DescriptionTest) { using Instance = @@ -390,29 +391,29 @@ TEST(ConvDescriptionTest, BwdWeightTwoStageWmmaV3DescriptionTest) " ├─ Warp Gemm parameters: \n" " │ ├─ subtile size: 32×32\n" " │ └─ Number of warp gemm iterations: 4×4\n" - " └─ Memory access:\n" - " ├─ A Tile transfer: \n" - " │ ├─ Tile dimensions: 2×128×8×\n" - " │ ├─ The innermost K subdimension size: 8\n" - " │ ├─ Spatial thread distribution over the data tile: 1×0×2\n" - " │ ├─ The order of accessing data tile axes: 1×0×2\n" - " │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" - " │ ├─ Vector access (GMEM read) instruction size: 8\n" - " │ ├─ Vector access 
(LDS write) instruction size: 8\n" - " │ └─ LDS data layout padding (to prevent bank conflicts): 8\n" - " ├─ B Tile transfer: \n" - " │ ├─ Tile dimensions: 2×128×8×\n" - " │ ├─ The innermost K subdimension size: 8\n" - " │ ├─ Spatial thread distribution over the data tile: 1×0×2\n" - " │ ├─ The order of accessing data tile axes: 1×0×2\n" - " │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" - " │ ├─ Vector access (GMEM read) instruction size: 8\n" - " │ ├─ Vector access (LDS write) instruction size: 8\n" - " │ └─ LDS data layout padding (to prevent bank conflicts): 8\n" - " └─ C Tile transfer: \n" - " ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n" - " ├─ Spatial thread distribution used to store data: 1×32×1×8\n" - " ├─ Vector access (GMEM write) instruction size: 8\n" + " ├─ Memory access:\n" + " │ ├─ A Tile transfer: \n" + " │ │ ├─ Tile dimensions: 2×128×8×\n" + " │ │ ├─ The innermost K subdimension size: 8\n" + " │ │ ├─ Spatial thread distribution over the data tile: 1×0×2\n" + " │ │ ├─ The order of accessing data tile axes: 1×0×2\n" + " │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" + " │ │ ├─ Vector access (GMEM read) instruction size: 8\n" + " │ │ ├─ Vector access (LDS write) instruction size: 8\n" + " │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n" + " │ ├─ B Tile transfer: \n" + " │ │ ├─ Tile dimensions: 2×128×8×\n" + " │ │ ├─ The innermost K subdimension size: 8\n" + " │ │ ├─ Spatial thread distribution over the data tile: 1×0×2\n" + " │ │ ├─ The order of accessing data tile axes: 1×0×2\n" + " │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" + " │ │ ├─ Vector access (GMEM read) instruction size: 8\n" + " │ │ ├─ Vector access (LDS write) instruction size: 8\n" + " │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n" + " │ └─ C Tile transfer: \n" + " │ ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n" + " │ ├─ Spatial 
thread distribution used to store data: 1×32×1×8\n" + " │ └─ Vector access (GMEM write) instruction size: 8\n" " ├─ Struct does not contain optional num_gemm_k_prefetch_stage parameter\n" " ├─ Max Transpose transfer scr scalar per vector: 1\n" " ├─ Max Transpose dst scalar per vector: 1\n" @@ -494,33 +495,34 @@ TEST(ConvDescriptionTest, BwdWeightWmmaCshuffleV3DescriptionTest) " ├─ Warp Gemm parameters: \n" " │ ├─ subtile size: 32×32\n" " │ └─ Number of warp gemm iterations: 4×4\n" - " └─ Memory access:\n" - " ├─ A Tile transfer: \n" - " │ ├─ Tile dimensions: 2×128×8×\n" - " │ ├─ The innermost K subdimension size: 8\n" - " │ ├─ Spatial thread distribution over the data tile: 1×0×2\n" - " │ ├─ The order of accessing data tile axes: 1×0×2\n" - " │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" - " │ ├─ Vector access (GMEM read) instruction size: 8\n" - " │ ├─ Vector access (LDS write) instruction size: 8\n" - " │ └─ LDS data layout padding (to prevent bank conflicts): 8\n" - " ├─ B Tile transfer: \n" - " │ ├─ Tile dimensions: 2×128×8×\n" - " │ ├─ The innermost K subdimension size: 8\n" - " │ ├─ Spatial thread distribution over the data tile: 1×0×2\n" - " │ ├─ The order of accessing data tile axes: 1×0×2\n" - " │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" - " │ ├─ Vector access (GMEM read) instruction size: 8\n" - " │ ├─ Vector access (LDS write) instruction size: 8\n" - " │ └─ LDS data layout padding (to prevent bank conflicts): 8\n" - " └─ C Tile transfer: \n" - " ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n" - " ├─ Spatial thread distribution used to store data: 1×32×1×8\n" - " ├─ Vector access (GMEM write) instruction size: 8\n" + " ├─ Memory access:\n" + " │ ├─ A Tile transfer: \n" + " │ │ ├─ Tile dimensions: 2×128×8×\n" + " │ │ ├─ The innermost K subdimension size: 8\n" + " │ │ ├─ Spatial thread distribution over the data tile: 1×0×2\n" + " │ │ ├─ The order of accessing data tile axes: 
1×0×2\n" + " │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" + " │ │ ├─ Vector access (GMEM read) instruction size: 8\n" + " │ │ ├─ Vector access (LDS write) instruction size: 8\n" + " │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n" + " │ ├─ B Tile transfer: \n" + " │ │ ├─ Tile dimensions: 2×128×8×\n" + " │ │ ├─ The innermost K subdimension size: 8\n" + " │ │ ├─ Spatial thread distribution over the data tile: 1×0×2\n" + " │ │ ├─ The order of accessing data tile axes: 1×0×2\n" + " │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" + " │ │ ├─ Vector access (GMEM read) instruction size: 8\n" + " │ │ ├─ Vector access (LDS write) instruction size: 8\n" + " │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n" + " │ └─ C Tile transfer: \n" + " │ ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n" + " │ ├─ Spatial thread distribution used to store data: 1×32×1×8\n" + " │ └─ Vector access (GMEM write) instruction size: 8\n" " ├─ Num gemm k prefetch stage: 1\n" " ├─ Struct does not contain optional max_transpose_transfer_src_scalar_per_vector " "parameter\n" - " ├─ Struct does not contain optional max_transpose_dst_scalar_per_vector parameter\n" + " ├─ Struct does not contain optional max_transpose_transfer_dst_scalar_per_vector " + "parameter\n" " └─ Struct does not contain optional num_groups_to_merge parameter")); } diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index e48f1dd6ba..8b7d68f8db 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -249,6 +249,26 @@ constexpr Transfer<> Transfer_4x32x1{ }, }; +constexpr GridwiseBwdDataXdlGemm BwdDataGemmParams_Xdl_4x4_per_wave{ + .ak1 = 8, + .bk1 = 8, + .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}}; + +constexpr 
GridwiseBwdDataXdlGemm BwdDataGemmParams_Xdl_4x2_per_wave{ + .ak1 = 8, + .bk1 = 8, + .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2}}; + +constexpr GridwiseBwdDataXdlGemm BwdDataGemmParams_Xdl_2x2_per_wave{ + .ak1 = 8, + .bk1 = 8, + .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}}; + +constexpr GridwiseBwdDataXdlGemm BwdDataGemmParams_Xdl_2x1_per_wave{ + .ak1 = 8, + .bk1 = 8, + .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1}}; + constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_4x4_per_wave{ .k1 = 8, .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}}; @@ -283,6 +303,13 @@ constexpr GridwiseWmmaGemm GemmParams_Wmma_2x1_per_wave{ constexpr GridwiseWmmaGemm GemmParams_Wmma_16x16_2x1_per_wave{ .k1 = 8, .m_per_wmma = 16, .n_per_wmma = 16, .m_wmma_per_wave = 2, .n_wmma_per_wave = 1}; +constexpr GridwiseWmmaGemmABK1 GemmParamsABK1_Wmma_16x16_2x1_per_wave{.ak1 = 8, + .bk1 = 8, + .m_per_wmma = 16, + .n_per_wmma = 16, + .m_wmma_per_wave = 2, + .n_wmma_per_wave = 1}; + constexpr ThreadBlock ThreadBlock_256_256x256x32{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index 178029e338..cc7dde885a 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -85,6 +85,15 @@ inline std::string to_string(ThreadBlock t) return oss.str(); } +template <> +inline std::string to_string(GridwiseBwdDataXdlGemm t) +{ + std::ostringstream oss; + oss << t.ak1 << "," << t.bk1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl + << "," << t.xdl_params.m_xdl_per_wave << "," << t.xdl_params.n_xdl_per_wave; + return oss.str(); +} + template <> inline std::string 
to_string(GridwiseBwdXdlGemm t) { @@ -112,6 +121,15 @@ inline std::string to_string(GridwiseWmmaGemm t) return oss.str(); } +template <> +inline std::string to_string(GridwiseWmmaGemmABK1 t) +{ + std::ostringstream oss; + oss << t.ak1 << "," << t.bk1 << "," << t.m_per_wmma << "," << t.n_per_wmma << "," + << t.m_wmma_per_wave << "," << t.n_wmma_per_wave; + return oss.str(); +} + template <> inline std::string to_string(BlockGemmPipeline t) { @@ -283,12 +301,24 @@ inline std::string to_string(BwdXdlGemm_ t) return to_string(t.gridwise_gemm); } +template <> +inline std::string to_string(BwdDataXdlGemm_ t) +{ + return to_string(t.gridwise_gemm); +} + template <> inline std::string to_string(WmmaGemm_ t) { return to_string(t.gridwise_gemm); } +template <> +inline std::string to_string(WmmaGemmABK1_ t) +{ + return to_string(t.gridwise_gemm); +} + template inline std::string to_string(Transfer_ t) { @@ -311,6 +341,14 @@ inline std::string to_string(ConvSpecializationBwd return oss.str(); } +template <> +inline std::string to_string(ConvSpecializationBwdData_ t) +{ + std::ostringstream oss; + oss << to_string(t.bwd_data_specialization); + return oss.str(); +} + template <> inline std::string to_string(Prefetch_ t) { @@ -495,4 +533,36 @@ inline std::string to_string +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," + << to_string(static_cast(t)) << "," + << to_string(static_cast>(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle_V3 t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << 
"," + << to_string(static_cast(t)) << "," + << to_string(static_cast>(t)); + return oss.str(); +} + } // namespace ck_tile::builder::test diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp index 04a4e7f6c8..8c4016c8ab 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp @@ -20,6 +20,11 @@ #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/io.hpp" +#ifdef CK_EXPERIMENTAL_BUILDER +#include "ck_tile/builder/reflect/description.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp" +#endif + namespace ck { namespace tensor_operation { namespace device { @@ -826,6 +831,24 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle return str.str(); } + +#ifdef CK_EXPERIMENTAL_BUILDER + std::string GetInstanceString() const override + { + static_assert(ck_tile::reflect::HasInstanceTraits, + "Specialization of instance_traits not found. 
Please check that a " + "specialization exists in file " + "ck_tile/builder/reflect/" + "instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp " + "for the given template parameters."); + return ck_tile::reflect::instance_string(); + } + + std::unique_ptr describe() const override + { + return std::make_unique(GetInstanceString()); + } +#endif }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp index dfdfd53725..bcce4ef9ca 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp @@ -28,6 +28,11 @@ #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#ifdef CK_EXPERIMENTAL_BUILDER +#include "ck_tile/builder/reflect/description.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp" +#endif + namespace ck { namespace tensor_operation { namespace device { @@ -1985,8 +1990,27 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 "The argument pointer is not an object of " "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3::Argument structure!"); } + +#ifdef CK_EXPERIMENTAL_BUILDER + std::string GetInstanceString() const override + { + static_assert(ck_tile::reflect::HasInstanceTraits, + "Specialization of instance_traits not found. 
Please check that a " + "specialization exists in file " + "ck_tile/builder/reflect/" + "instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp " + "for the given template parameters."); + return ck_tile::reflect::instance_string(); + } + + std::unique_ptr describe() const override + { + return std::make_unique(GetInstanceString()); + } +#endif }; } // namespace device } // namespace tensor_operation } // namespace ck + diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 16b12cf386..8b71a4fa40 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -27,6 +27,11 @@ #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#ifdef CK_EXPERIMENTAL_BUILDER +#include "ck_tile/builder/reflect/description.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle.hpp" +#endif + namespace ck { namespace tensor_operation { namespace device { @@ -1942,6 +1947,24 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 "The argument pointer is not an object of " "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1::Argument structure!"); } + +#ifdef CK_EXPERIMENTAL_BUILDER + std::string GetInstanceString() const override + { + static_assert(ck_tile::reflect::HasInstanceTraits, + "Specialization of instance_traits not found. 
Please check that a " + "specialization exists in file " + "ck_tile/builder/reflect/" + "instance_traits_device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle.hpp " + "for the given template parameters."); + return ck_tile::reflect::instance_string(); + } + + std::unique_ptr describe() const override + { + return std::make_unique(GetInstanceString()); + } +#endif }; } // namespace device