#pragma once
#include "convolution_kernel_descriptor.hpp"
#include "convolution_problem_descriptor.hpp"
#include "convolution_implementation_descriptor.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
// NOTE(review): this header uses std::map, std::string, std::stringstream,
// std::tuple, std::tuple_element and std::conditional_t, but it does not
// include <map>, <string>, <sstream>, <tuple> or <type_traits> directly.
// Presumably they arrive transitively through the headers above. Confirm,
// or include them explicitly.

// NOTE(review): in this copy of the file, every angle-bracket argument list
// whose first character is a letter appears to have been stripped. That covers:
//   - the template parameter lists,
//   - the std::map / std::tuple / std::conditional_t arguments,
//   - the leading arguments of DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle.
// Lists beginning with a digit or a space survived (ck::Sequence<4, 8, 1>,
// std::tuple_element<0, ...>, std::get<0>). The code below cannot compile as
// shown. Restore the missing lists from version control rather than
// reconstructing them by hand.

// Primary template. It is only declared here, never defined, so within this
// file only the specialization below provides a definition.
template struct ConvolutionBuilder;

// Maps the ProblemDesc / KernelDesc / ImplementationDesc descriptors to:
//   - a concrete CK grouped-convolution device-operation type (GetInstance),
//   - a human-readable description of the chosen configuration (GetInstanceName).
// NOTE(review): the specialization's argument list is missing in this copy.
// The descriptor names are taken from their uses in the body.
template struct ConvolutionBuilder
{
    public:
    // Returns a value-initialized object of the device-operation type selected
    // for this descriptor combination.
    static constexpr auto GetInstance()
    {
        using DataType = typename ProblemDesc::DataType;
        // Accumulator type: int32_t when the (stripped) condition holds, float
        // otherwise.
        // NOTE(review): the std::conditional_t condition was stripped.
        // Presumably it tests for an 8-bit integer DataType. Restore it from
        // history.
        using AccDataType = std::conditional_t, int32_t, float>;
        // Layout types are read from the tuple type returned by GetLayout(),
        // at indices 0/1/2 = input/weight/output.
        // NOTE(review): decltype(GetLayout()) is dependent. Before C++20
        // (P0634) these ::type names need a leading `typename`.
        using InLayout = std::tuple_element<0, decltype(GetLayout())>::type;
        using WeiLayout = std::tuple_element<1,decltype( GetLayout())>::type;
        using OutLayout = std::tuple_element<2,decltype( GetLayout())>::type;
        // Forward-convolution kernel. Block size, the three TileSizes_ entries
        // and K1_ come from ImplementationDesc. Every other tuning argument is
        // a literal.
        // NOTE(review): GetInstanceName() reports MFMAInstructionSize_,
        // XdlPerWave_, GlobalTransferVectorSize_, LDSStoreVectorSize_,
        // GemmPipelineScheduler_ and GemmPipelineVersion_. None of these is
        // passed here, so the printed name can disagree with the instantiated
        // kernel. Confirm whether the literals below are meant to be those
        // descriptor fields.
        // NOTE(review): TileSizes_ is used as a type here (TileSizes_::At(i)),
        // but as an object in GetInstanceName() (std::get<i>(TileSizes_)).
        // Both cannot be valid for the same member; confirm which is intended.
        using GroupedConvFwdMultipleABD_Xdl_CShuffleInstance =
            ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
                ProblemDesc::NDimSpatial_,
                InLayout,
                WeiLayout,
                decltype(GetMultiDLayout()),
                OutLayout,
                DataType,
                DataType,
                DataType,
                AccDataType,
                typename ProblemDesc::ElementwiseOpDataTypes,
                DataType,
                ck::tensor_operation::element_wise::PassThrough,
                ck::tensor_operation::element_wise::PassThrough,
                decltype(GetOutElementwiseOp()),
                GetConvSpecialization(),
                ck::tensor_operation::device::GemmSpecialization::MNKPadding,
                1,
                ImplementationDesc::BlockSize_,
                ImplementationDesc::TileSizes_::At(0),
                ImplementationDesc::TileSizes_::At(1),
                ImplementationDesc::TileSizes_::At(2),
                ImplementationDesc::K1_,
                ImplementationDesc::K1_,
                16, 16,
                1, 1,
                ck::Sequence<4, 8, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>,
                2, 1, 4, 1,
                ck::Sequence<4, 8, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>,
                2, 1, 4, 1,
                1, 1,
                ck::Sequence<1, 32, 1, 8>,
                1>;
        // Backward-weight (two-stage) kernel.
        // NOTE(review): the opening argument list of this template is missing
        // in this copy. That is everything from its '<' up to the first '>'.
        using DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffleInstance =
            ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle,
            ck::Sequence<2, 0, 1>, ck::Sequence<1, 0, 2>, 1, 1, 4, false,
            ck::Sequence<4, 8, 1>, ck::Sequence<2, 0, 1>, ck::Sequence<1, 0, 2>, 1, 1, 4, false,
            1, 1,
            ck::Sequence<1, 8, 1, 8>,
            1>;
        // NOTE(review): the std::conditional_t arguments were stripped.
        // Presumably it picks one of the two aliases above based on the
        // convolution direction or GetKernel(). Restore it from history.
        // NOTE(review): both aliases above are formed for every direction.
        // For a BackwardWeight problem, GetConvSpecialization() yields a
        // ConvolutionBackwardWeightSpecialization value, which is then passed
        // as the forward kernel's specialization argument (presumably typed
        // ConvolutionForwardSpecialization). Check that this is accepted, or
        // select the instance type with if constexpr instead.
        using SelectedInstance = std::conditional_t;
        return SelectedInstance{};
    }

    // Builds a descriptive name: KernelToString[GetKernel()] followed by
    // "<...>". The fields inside, separated by ", ", are:
    //   BlockSize_, TileSizes_[0..2], ConvolutionSpecialization_, K1_,
    //   MFMAInstructionSize_, XdlPerWave_[0..1],
    //   GlobalTransferVectorSize_[0], LDSStoreVectorSize_[0],
    //   GlobalTransferVectorSize_[1], LDSStoreVectorSize_[1],
    //   GlobalTransferVectorSize_[2],
    //   GemmPipelineScheduler_, GemmPipelineVersion_,
    //   KernelDesc::DepthwiseOptimization_.
    // Enum values are translated through the local lookup tables below.
    // NOTE(review): the std::map template arguments were stripped in this copy.
    // NOTE(review): std::map::operator[] inserts a value-initialized mapped
    // value (presumably an empty std::string) for a key missing from a table.
    // An unmapped enum value therefore prints silently as "".
    // NOTE(review): the name reports ConvolutionSpecialization_, but
    // GetConvSpecialization() always returns ...::Default. A name such as
    // "Filter1x1Pad0" can thus describe a kernel built with Default.
    // NOTE(review): the tables are rebuilt on every call, and nothing here
    // reads instance state. This could be static.
    std::string GetInstanceName() const
    {
        auto str = std::stringstream();
        std::map KernelToString{
            {Kernel::GroupedConvFwdMultipleABD_Xdl_CShuffle, "GroupedConvFwdMultipleABD_Xdl_CShuffle"},
            {Kernel::GroupedConvBwdWeightTwoStage_Xdl_CShuffle, "GroupedConvBwdWeightTwoStage_Xdl_CShuffle"}};
        std::map GemmPipelineSchedulerToString{
            {GemmPipelineScheduler::Intrawave, "Intrawave"},
            {GemmPipelineScheduler::Interwave, "Interwave"}};
        std::map GemmPipelineVersionToString{
            {GemmPipelineVersion::Naive, "v1"},
            {GemmPipelineVersion::ComputeFriendly, "v2"},
            {GemmPipelineVersion::MemFriendly, "v3"},
            {GemmPipelineVersion::ComputeFriendlyDoubleLDS, "v4"},
            {GemmPipelineVersion::ComputeFriendlyDoubleGlobalPrefetch, "v5"}};
        std::map MFMAInstructionSizeToString{
            {MFMAInstructionSize::M16N16, "16x16"},
            {MFMAInstructionSize::M32N32, "32x32"}};
        std::map ConvolutionSpecializationToString{
            {ConvolutionSpecialization::Default, "Default"},
            {ConvolutionSpecialization::Filter1x1Pad0, "Filter1x1Pad0"},
            {ConvolutionSpecialization::Filter1x1Stride1Pad0, "Filter1x1Stride1Pad0"},
            {ConvolutionSpecialization::Filter3x3, "Filter3x3"}};
        std::map DepthwiseOptimizationToString{
            {DepthwiseOptimization::X16, "16"},
            {DepthwiseOptimization::X8, "8"},
            {DepthwiseOptimization::X4, "4"},
            {DepthwiseOptimization::X2, "2"},
            {DepthwiseOptimization::NotSupported, "1"}};
        // clang-format off
        str << KernelToString[GetKernel()] << "<"
            << ImplementationDesc::BlockSize_ << ", "
            << std::get<0>(ImplementationDesc::TileSizes_) << ", "
            << std::get<1>(ImplementationDesc::TileSizes_) << ", "
            << std::get<2>(ImplementationDesc::TileSizes_) << ", "
            << ConvolutionSpecializationToString[ImplementationDesc::ConvolutionSpecialization_] << ", "
            << ImplementationDesc::K1_ << ", "
            << MFMAInstructionSizeToString[ImplementationDesc::MFMAInstructionSize_] << ", "
            << std::get<0>(ImplementationDesc::XdlPerWave_) << ", "
            << std::get<1>(ImplementationDesc::XdlPerWave_) << ", "
            << std::get<0>(ImplementationDesc::GlobalTransferVectorSize_) << ", "
            << std::get<0>(ImplementationDesc::LDSStoreVectorSize_) << ", "
            << std::get<1>(ImplementationDesc::GlobalTransferVectorSize_) << ", "
            << std::get<1>(ImplementationDesc::LDSStoreVectorSize_) << ", "
            << std::get<2>(ImplementationDesc::GlobalTransferVectorSize_) << ", "
            << GemmPipelineSchedulerToString[ImplementationDesc::GemmPipelineScheduler_] << ", "
            << GemmPipelineVersionToString[ImplementationDesc::GemmPipelineVersion_] << ", "
            << DepthwiseOptimizationToString[KernelDesc::DepthwiseOptimization_] << ">";
        // clang-format on
        return str.str();
    }

    private:
    // Device kernels this builder can name.
    enum class Kernel
    {
        GroupedConvFwdMultipleABD_Xdl_CShuffle,
        GroupedConvBwdWeightTwoStage_Xdl_CShuffle
    };

    // Chooses the kernel from KernelDesc::GemmImplementationType_ and
    // KernelDesc::ConvolutionDirection_:
    //   - XDL + Forward        -> GroupedConvFwdMultipleABD_Xdl_CShuffle
    //   - XDL + BackwardWeight -> GroupedConvBwdWeightTwoStage_Xdl_CShuffle
    // NOTE(review): static_assert("...") is always satisfied. The string
    // literal itself is the condition, and it converts to true. The
    // "not found" branches therefore do not stop compilation. Instead they
    // fall off the end of a function returning Kernel, which is undefined
    // behavior if reached at run time (GetInstanceName() calls this).
    // A dependent-false condition would make them real compile-time errors,
    // e.g. static_assert(sizeof(KernelDesc) == 0, "Instance not found!").
    static constexpr Kernel GetKernel()
    {
        if constexpr(KernelDesc::GemmImplementationType_ == GemmImplementationType::XDL)
        {
            if constexpr(KernelDesc::ConvolutionDirection_ == ConvolutionDirection::Forward)
            {
                return Kernel::GroupedConvFwdMultipleABD_Xdl_CShuffle;
            }
            else if constexpr(KernelDesc::ConvolutionDirection_ == ConvolutionDirection::BackwardData)
            {
                static_assert("Instance not found!");
            }
            else if constexpr(KernelDesc::ConvolutionDirection_ == ConvolutionDirection::BackwardWeight)
            {
                return Kernel::GroupedConvBwdWeightTwoStage_Xdl_CShuffle;
            }
            else
            {
                static_assert("Instance not found!");
            }
        }
        else
        {
            static_assert("Instance not found!");
        }
    }

    // Returns a default-constructed std::tuple of layout types, consumed by
    // GetInstance() via std::tuple_element at indices 0/1/2 (input/weight/
    // output). Only NDimSpatial_ == 2 with
    // ConvolutionLayout::NHWGC_GKYXC_NHWGK is handled.
    // NOTE(review): the std::tuple arguments were stripped in this copy.
    // NOTE(review): these static_asserts are always satisfied as well (same
    // issue as in GetKernel). For unsupported cases the `auto` return type is
    // presumably deduced as void. The failure then surfaces later, in
    // std::tuple_element, with a less helpful diagnostic.
    static constexpr auto GetLayout()
    {
        if constexpr(ProblemDesc::NDimSpatial_ == 2)
        {
            if constexpr(ProblemDesc::ConvolutionLayout_ == ConvolutionLayout::NHWGC_GKYXC_NHWGK)
            {
                return std::tuple{};
            }
            else
            {
                static_assert("Layout not supported!");
            }
        }
        else
        {
            static_assert("Not supported spatial dim!");
        }
    }

    // Layout list for the forward kernel's extra D tensors. It is an empty
    // ck::Tuple, i.e. presumably no auxiliary D inputs.
    static constexpr auto GetMultiDLayout() { return ck::Tuple<>{}; }

    // Output elementwise operation applied by the forward kernel: PassThrough.
    static constexpr auto GetOutElementwiseOp()
    {
        return ck::tensor_operation::element_wise::PassThrough{};
    }

    // Returns the Default specialization for the convolution direction:
    //   - Forward        -> ConvolutionForwardSpecialization::Default
    //   - BackwardWeight -> ConvolutionBackwardWeightSpecialization::Default
    // The two branches return different enum types, so the deduced return
    // type depends on the direction.
    // NOTE(review): ImplementationDesc::ConvolutionSpecialization_ is not
    // consulted here.
    // NOTE(review): the else-branch static_assert is always satisfied (same
    // issue as in GetKernel).
    static constexpr auto GetConvSpecialization()
    {
        if constexpr(KernelDesc::ConvolutionDirection_ == ConvolutionDirection::Forward)
        {
            return ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
        }
        else if constexpr(KernelDesc::ConvolutionDirection_ == ConvolutionDirection::BackwardWeight)
        {
            return ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
        }
        else
        {
            static_assert("Specialization not found!");
        }
    }
};