diff --git a/experimental/convolution_builder/convolution_builder.hpp b/experimental/convolution_builder/convolution_builder.hpp index dc671bfd01..4d711af715 100644 --- a/experimental/convolution_builder/convolution_builder.hpp +++ b/experimental/convolution_builder/convolution_builder.hpp @@ -1,27 +1,28 @@ #pragma once +#include "convolution_kernel_descriptor.hpp" #include "convolution_problem_descriptor.hpp" #include "convolution_implementation_descriptor.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" -template +template struct ConvolutionBuilder; -template -struct ConvolutionBuilder { +template +struct ConvolutionBuilder { public: static constexpr auto GetInstance() { - using DataType = typename Implementation::DataType; + using DataType = typename ProblemDesc::DataType; using AccDataType = std::conditional_t, int32_t, float>; using InLayout = std::tuple_element<0, decltype(GetLayout())>::type; using WeiLayout = std::tuple_element<1,decltype( GetLayout())>::type; using OutLayout = std::tuple_element<2,decltype( GetLayout())>::type; - using GroupedConvFwdMultipleABD_Xdl_CShuffleInstance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< Implementation::NDimSpatial_, InLayout, WeiLayout, decltype(GetMultiDLayout()), OutLayout, DataType, DataType, DataType, AccDataType, typename Implementation::ElementwiseOpDataTypes, DataType, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, decltype(GetOutElementwiseOp()), GetConvSpecialization(), ck::tensor_operation::device::GemmSpecialization::MNKPadding, 1, Implementation::BlockSize_, Implementation::TileSizes_::At(0), Implementation::TileSizes_::At(1), Implementation::TileSizes_::At(2), Implementation::K1_, Implementation::K1_, 16, 16, 1, 1, ck::Sequence<4, 8, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 1, 4, 1, ck::Sequence<4, 8, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 1, 4, 1, 1, 1, ck::Sequence<1, 32, 1, 8>, 1>; - using DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffleInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle, ck::Sequence<2, 0, 1>, ck::Sequence<1, 0, 2>, 1, 1, 4, false, ck::Sequence<4, 8, 1>, ck::Sequence<2, 0, 1>, ck::Sequence<1, 0, 2>, 1, 1, 4, false, 1, 1, ck::Sequence<1, 8, 1, 8>, 1>; + using GroupedConvFwdMultipleABD_Xdl_CShuffleInstance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< ProblemDesc::NDimSpatial_, InLayout, WeiLayout, decltype(GetMultiDLayout()), OutLayout, DataType, DataType, DataType, AccDataType, typename ProblemDesc::ElementwiseOpDataTypes, DataType, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, decltype(GetOutElementwiseOp()), GetConvSpecialization(), ck::tensor_operation::device::GemmSpecialization::MNKPadding, 1, ImplementationDesc::BlockSize_, ImplementationDesc::TileSizes_::At(0), ImplementationDesc::TileSizes_::At(1), ImplementationDesc::TileSizes_::At(2), ImplementationDesc::K1_, ImplementationDesc::K1_, 16, 16, 1, 1, ck::Sequence<4, 8, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 1, 4, 1, ck::Sequence<4, 8, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 1, 4, 1, 1, 1, ck::Sequence<1, 32, 1, 8>, 1>; + using DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffleInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle, ck::Sequence<2, 0, 1>, ck::Sequence<1, 0, 2>, 1, 1, 4, false, ck::Sequence<4, 8, 1>, ck::Sequence<2, 0, 1>, ck::Sequence<1, 0, 2>, 1, 1, 4, false, 1, 1, ck::Sequence<1, 8, 1, 8>, 1>; - using SelectedInstance = std::conditional_t; + using SelectedInstance = std::conditional_t; return SelectedInstance{}; } @@ -64,23 +65,23 @@ public: // clang-format off str << KernelToString[GetKernel()] << "<" - << Implementation::BlockSize_ << ", " - << std::get<0>(Implementation::TileSizes_) << ", " - << std::get<1>(Implementation::TileSizes_) << ", " - << std::get<2>(Implementation::TileSizes_) << ", " - << ConvolutionSpecializationToString[Implementation::ConvolutionSpecialization_] << ", " - << Implementation::K1_ << ", " - << MFMAInstructionSizeToString[Implementation::MFMAInstructionSize_] << ", " - << std::get<0>(Implementation::XdlPerWave_) << ", " - << std::get<1>(Implementation::XdlPerWave_) << ", " - << std::get<0>(Implementation::GlobalTransferVectorSize_) << ", " - << std::get<0>(Implementation::LDSStoreVectorSize_) << ", " - << std::get<1>(Implementation::GlobalTransferVectorSize_) << ", " - << std::get<1>(Implementation::LDSStoreVectorSize_) << ", " - << std::get<2>(Implementation::GlobalTransferVectorSize_) << ", " - << GemmPipelineSchedulerToString[Problem::GemmPipelineScheduler_] << ", " - << GemmPipelineVersionToString[Problem::GemmPipelineVersion_] << ", " - << MergedGroupsToString[Problem::MergedGroups_] << ">"; + << ImplementationDesc::BlockSize_ << ", " + << std::get<0>(ImplementationDesc::TileSizes_) << ", " + << std::get<1>(ImplementationDesc::TileSizes_) << ", " + << std::get<2>(ImplementationDesc::TileSizes_) << ", " + << ConvolutionSpecializationToString[ImplementationDesc::ConvolutionSpecialization_] << ", " + << ImplementationDesc::K1_ << ", " + << MFMAInstructionSizeToString[ImplementationDesc::MFMAInstructionSize_] << ", " + << std::get<0>(ImplementationDesc::XdlPerWave_) << ", " + << std::get<1>(ImplementationDesc::XdlPerWave_) << ", " + << std::get<0>(ImplementationDesc::GlobalTransferVectorSize_) << ", " + << std::get<0>(ImplementationDesc::LDSStoreVectorSize_) << ", " + << std::get<1>(ImplementationDesc::GlobalTransferVectorSize_) << ", " + << std::get<1>(ImplementationDesc::LDSStoreVectorSize_) << ", " + << std::get<2>(ImplementationDesc::GlobalTransferVectorSize_) << ", " + << GemmPipelineSchedulerToString[KernelDesc::GemmPipelineScheduler_] << ", " + << GemmPipelineVersionToString[KernelDesc::GemmPipelineVersion_] << ", " + << MergedGroupsToString[KernelDesc::MergedGroups_] << ">"; // clang-format on return str.str(); @@ -93,12 +94,12 @@ private: }; static constexpr Kernel GetKernel() { - if constexpr(Problem::GemmImplementationType_ == GemmImplementationType::XDL) { - if constexpr(Problem::ConvolutionDirection_ == ConvolutionDirection::Forward) { + if constexpr(KernelDesc::GemmImplementationType_ == GemmImplementationType::XDL) { + if constexpr(KernelDesc::ConvolutionDirection_ == ConvolutionDirection::Forward) { return Kernel::GroupedConvFwdMultipleABD_Xdl_CShuffle; - } else if constexpr(Problem::ConvolutionDirection_ == ConvolutionDirection::BackwardData) { + } else if constexpr(KernelDesc::ConvolutionDirection_ == ConvolutionDirection::BackwardData) { static_assert("Instance not found!"); - } else if constexpr(Problem::ConvolutionDirection_ == ConvolutionDirection::BackwardWeight) { + } else if constexpr(KernelDesc::ConvolutionDirection_ == ConvolutionDirection::BackwardWeight) { return Kernel::GroupedConvBwdWeightTwoStage_Xdl_CShuffle; } else { static_assert("Instance not found!"); @@ -109,8 +110,8 @@ private: } static constexpr auto GetLayout() { - if constexpr(Implementation::NDimSpatial_ == 2) { - if constexpr(Implementation::ConvolutionLayout_ == ConvolutionLayout::NHWGC_GKYXC_NHWGK) { + if constexpr(ProblemDesc::NDimSpatial_ == 2) { + if constexpr(ProblemDesc::ConvolutionLayout_ == ConvolutionLayout::NHWGC_GKYXC_NHWGK) { return std::tuple{}; } else { static_assert("Layout not supported!"); @@ -129,9 +130,9 @@ private: } static constexpr auto GetConvSpecialization() { - if constexpr(Problem::ConvolutionDirection_ == ConvolutionDirection::Forward) { + if constexpr(KernelDesc::ConvolutionDirection_ == ConvolutionDirection::Forward) { return ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - } else if constexpr(Problem::ConvolutionDirection_ == ConvolutionDirection::BackwardWeight) { + } else if constexpr(KernelDesc::ConvolutionDirection_ == ConvolutionDirection::BackwardWeight) { return ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; } else { static_assert("Specialization not found!"); diff --git a/experimental/convolution_builder/convolution_example.cpp b/experimental/convolution_builder/convolution_example.cpp index 96a2955507..0149c23be3 100644 --- a/experimental/convolution_builder/convolution_example.cpp +++ b/experimental/convolution_builder/convolution_example.cpp @@ -4,19 +4,19 @@ #include "convolution_builder.hpp" -// Example of problem description for Forward Conv with default settings +// Example of kernel description for Forward Conv with default settings struct GroupedConvFwdXdlImplicitGemm : public GroupedConvBaseXdlV1 { static constexpr ConvolutionDirection ConvolutionDirection_ = ConvolutionDirection::Forward; static constexpr ElementwiseOperation ElementwiseOperation_ = ElementwiseOperation::Bias; }; -// Example of problem description for Backward Weight Conv with default settings and Split K Two Stage +// Example of kernel description for Backward Weight Conv with default settings and Split K Two Stage struct GroupedConvBwdWeightXdlImplicitGemmTwoStage : public GroupedConvBaseXdlV1 { static constexpr ConvolutionDirection ConvolutionDirection_ = ConvolutionDirection::BackwardWeight; static constexpr SplitKSupport SplitKSupport_ = SplitKSupport::SupportedTwoStage; }; -struct ImplementationDescriptor : public NHWCImplementationBaseV1, public BF16ImplementationBaseV1 { +struct Implementation16x16 : ImplementationDefaultV1 { static constexpr ck::index_t BlockSize_ = 64; static constexpr auto TileSizes_ = std::make_tuple(16, 16, 32); static constexpr ck::index_t K1_ = 8; @@ -26,10 +26,12 @@ struct ImplementationDescriptor : public NHWCImplementationBaseV1, public BF16Im static constexpr auto LDSStoreVectorSize_ = std::make_tuple(4, 4); }; +struct ProblemBF16NHWGC : public BF16ProblemBaseV1, public NHWGCProblemBaseV1 {}; + int main () { - ConvolutionBuilder builder_fwd; + ConvolutionBuilder builder_fwd; std::cout << builder_fwd.GetInstanceName() << std::endl; - ConvolutionBuilder builder_bwd_weight_two_stage; + ConvolutionBuilder builder_bwd_weight_two_stage; std::cout << builder_bwd_weight_two_stage.GetInstanceName() << std::endl; return 0; } diff --git a/experimental/convolution_builder/convolution_implementation_descriptor.hpp b/experimental/convolution_builder/convolution_implementation_descriptor.hpp index 27eead016d..c136147076 100644 --- a/experimental/convolution_builder/convolution_implementation_descriptor.hpp +++ b/experimental/convolution_builder/convolution_implementation_descriptor.hpp @@ -19,11 +19,6 @@ enum class ConvolutionSpecialization { Filter3x3 }; -enum class ConvolutionLayout { - NHWGC_GKYXC_NHWGK, - NGCHW_GKCYX_NGKHW -}; - enum class MFMAInstructionSize { M16N16, M32N32 @@ -33,11 +28,7 @@ enum class MFMAInstructionSize { template concept ImplementationDescriptorV1 = requires { {T::ImplementationDescriptorVersion_} -> std::convertible_to; - {T::NDimSpatial_} -> std::convertible_to; - typename T::DataType; - typename T::ElementwiseOpDataTypes; {T::ConvolutionSpecialization_} -> std::convertible_to; - {T::ConvolutionLayout_} -> std::convertible_to; {T::BlockSize_} -> std::convertible_to; {T::TileSizes_} -> std::convertible_to>; {T::K1_} -> std::convertible_to; @@ -47,36 +38,7 @@ concept ImplementationDescriptorV1 = requires { {T::LDSStoreVectorSize_} -> std::convertible_to>; } && (T::ImplementationDescriptorVersion_ == ImplementationDescriptorVersion::V1); -struct ImplementationBaseV1 { +struct ImplementationDefaultV1 { static constexpr ImplementationDescriptorVersion ImplementationDescriptorVersion_ = ImplementationDescriptorVersion::V1; - using DataType = ck::bhalf_t; - using ElementwiseOpDataTypes = ck::Tuple<>; static constexpr ConvolutionSpecialization ConvolutionSpecialization_ = ConvolutionSpecialization::Default; }; - -struct BF16ImplementationBaseV1 : public ImplementationBaseV1 { - using DataType = ck::bhalf_t; -}; - -struct F32ImplementationBaseV1 : public ImplementationBaseV1 { - using DataType = float; -}; - -struct F16ImplementationBaseV1 : public ImplementationBaseV1 { - using DataType = ck::half_t; -}; - -struct NWCImplementationBaseV1 : public ImplementationBaseV1 { - static constexpr int NDimSpatial_ = 1; - static constexpr ConvolutionLayout ConvolutionLayout_ = ConvolutionLayout::NHWGC_GKYXC_NHWGK; -}; - -struct NHWCImplementationBaseV1 : public ImplementationBaseV1 { - static constexpr int NDimSpatial_ = 2; - static constexpr ConvolutionLayout ConvolutionLayout_ = ConvolutionLayout::NHWGC_GKYXC_NHWGK; -}; - -struct NDHWCImplementationBaseV1 : public ImplementationBaseV1 { - static constexpr int NDimSpatial_ = 3; - static constexpr ConvolutionLayout ConvolutionLayout_ = ConvolutionLayout::NHWGC_GKYXC_NHWGK; -}; diff --git a/experimental/convolution_builder/convolution_kernel_descriptor.hpp b/experimental/convolution_builder/convolution_kernel_descriptor.hpp new file mode 100644 index 0000000000..42b7752848 --- /dev/null +++ b/experimental/convolution_builder/convolution_kernel_descriptor.hpp @@ -0,0 +1,116 @@ +#pragma once +#include + +enum class KernelDescriptorVersion +{ + V1 +}; + +enum class GemmImplementationType +{ + XDL, + WMMA, + DL +}; + +enum class ConvolutionDirection +{ + Forward, + BackwardData, + BackwardWeight +}; + + +enum class GemmPipelineVersion +{ + V1, + V2, + V3, + V4, + V5 +}; + +enum class GemmPipelineScheduler +{ + Intrawave, + Interwave +}; + +enum class SplitKSupport +{ + Supported, + SupportedTwoStage, + NotSupported +}; + +enum class MergedGroups +{ + X16, + X8, + X4, + X2, + NotSupported +}; + +enum class LargeTensorSupport +{ + Supported, + SplitBatch, + NotSupported +}; + +enum class ImplementationType +{ + ExplicitDefault, + ExplicitMPadding, + ExplicitNPadding, + ExplicitKPadding, + ExplicitMNPadding, + ExplicitMKPadding, + ExplicitNKPadding, + ExplicitMNKPadding, + Implicit +}; + +enum class ElementwiseOperation { + Bias, + BiasClamp, + Bilinear, + Clamp, + Scale, + PassThrough +}; + + +template +concept KernelDescriptorV1 = requires { + {T::KernelDescriptorVersion_} -> std::convertible_to; + {T::GemmImplementationType_} -> std::convertible_to; + {T::ConvolutionDirection_} -> std::convertible_to; + {T::GemmPipelineVersion_} -> std::convertible_to; + {T::GemmPipelineScheduler_} -> std::convertible_to; + {T::SplitKSupport_} -> std::convertible_to; + {T::MergedGroups_} -> std::convertible_to; + {T::LargeTensorSupport_} -> std::convertible_to; + {T::ImplementationType_} -> std::convertible_to; + {T::ElementwiseOperation_} -> std::convertible_to; +} && (T::KernelDescriptorVersion_ == KernelDescriptorVersion::V1); + +struct GroupedConvBase { + static constexpr GemmPipelineVersion GemmPipelineVersion_ = GemmPipelineVersion::V1; + static constexpr GemmPipelineScheduler GemmPipelineScheduler_ = GemmPipelineScheduler::Intrawave; + static constexpr SplitKSupport SplitKSupport_ = SplitKSupport::NotSupported; + static constexpr MergedGroups MergedGroups_ = MergedGroups::NotSupported; + static constexpr LargeTensorSupport LargeTensorSupport_ = LargeTensorSupport::NotSupported; + static constexpr ImplementationType ImplementationType_ = ImplementationType::Implicit; + static constexpr ElementwiseOperation ElementwiseOperation_ = ElementwiseOperation::PassThrough; +}; + +struct GroupedConvBaseXdl : public GroupedConvBase { + static constexpr GemmImplementationType GemmImplementationType_ = GemmImplementationType::XDL; +}; + +struct GroupedConvBaseXdlV1 : public GroupedConvBaseXdl { + static constexpr KernelDescriptorVersion KernelDescriptorVersion_ = KernelDescriptorVersion::V1; +}; + diff --git a/experimental/convolution_builder/convolution_problem_descriptor.hpp b/experimental/convolution_builder/convolution_problem_descriptor.hpp index 234baf8eb2..db29c33b8c 100644 --- a/experimental/convolution_builder/convolution_problem_descriptor.hpp +++ b/experimental/convolution_builder/convolution_problem_descriptor.hpp @@ -1,116 +1,58 @@ #pragma once #include +#include "ck/utility/data_type.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/ck.hpp" + enum class ProblemDescriptorVersion { V1 }; -enum class GemmImplementationType -{ - XDL, - WMMA, - DL +enum class ConvolutionLayout { + NHWGC_GKYXC_NHWGK, + NGCHW_GKCYX_NGKHW }; -enum class ConvolutionDirection -{ - Forward, - BackwardData, - BackwardWeight -}; - - -enum class GemmPipelineVersion -{ - V1, - V2, - V3, - V4, - V5 -}; - -enum class GemmPipelineScheduler -{ - Intrawave, - Interwave -}; - -enum class SplitKSupport -{ - Supported, - SupportedTwoStage, - NotSupported -}; - -enum class MergedGroups -{ - X16, - X8, - X4, - X2, - NotSupported -}; - -enum class LargeTensorSupport -{ - Supported, - SplitBatch, - NotSupported -}; - -enum class ImplementationType -{ - ExplicitDefault, - ExplicitMPadding, - ExplicitNPadding, - ExplicitKPadding, - ExplicitMNPadding, - ExplicitMKPadding, - ExplicitNKPadding, - ExplicitMNKPadding, - Implicit -}; - -enum class ElementwiseOperation { - Bias, - BiasClamp, - Bilinear, - Clamp, - Scale, - PassThrough -}; - - template concept ProblemDescriptorV1 = requires { {T::ProblemDescriptorVersion_} -> std::convertible_to; - {T::GemmImplementationType_} -> std::convertible_to; - {T::ConvolutionDirection_} -> std::convertible_to; - {T::GemmPipelineVersion_} -> std::convertible_to; - {T::GemmPipelineScheduler_} -> std::convertible_to; - {T::SplitKSupport_} -> std::convertible_to; - {T::MergedGroups_} -> std::convertible_to; - {T::LargeTensorSupport_} -> std::convertible_to; - {T::ImplementationType_} -> std::convertible_to; - {T::ElementwiseOperation_} -> std::convertible_to; + {T::NDimSpatial_} -> std::convertible_to; + typename T::DataType; + typename T::ElementwiseOpDataTypes; + {T::ConvolutionLayout_} -> std::convertible_to; } && (T::ProblemDescriptorVersion_ == ProblemDescriptorVersion::V1); -struct GroupedConvBase { - static constexpr GemmPipelineVersion GemmPipelineVersion_ = GemmPipelineVersion::V1; - static constexpr GemmPipelineScheduler GemmPipelineScheduler_ = GemmPipelineScheduler::Intrawave; - static constexpr SplitKSupport SplitKSupport_ = SplitKSupport::NotSupported; - static constexpr MergedGroups MergedGroups_ = MergedGroups::NotSupported; - static constexpr LargeTensorSupport LargeTensorSupport_ = LargeTensorSupport::NotSupported; - static constexpr ImplementationType ImplementationType_ = ImplementationType::Implicit; - static constexpr ElementwiseOperation ElementwiseOperation_ = ElementwiseOperation::PassThrough; -}; - -struct GroupedConvBaseXdl : public GroupedConvBase { - static constexpr GemmImplementationType GemmImplementationType_ = GemmImplementationType::XDL; -}; - -struct GroupedConvBaseXdlV1 : public GroupedConvBaseXdl { +struct ProblemBaseV1 { static constexpr ProblemDescriptorVersion ProblemDescriptorVersion_ = ProblemDescriptorVersion::V1; + using ElementwiseOpDataTypes = ck::Tuple<>; }; +struct BF16ProblemBaseV1 : public ProblemBaseV1 { + using DataType = ck::bhalf_t; +}; + +struct F32ProblemBaseV1 : public ProblemBaseV1 { + using DataType = float; +}; + +struct F16ProblemBaseV1 : public ProblemBaseV1 { + using DataType = ck::half_t; +}; + +struct NWGCProblemBaseV1 : public ProblemBaseV1 { + static constexpr int NDimSpatial_ = 1; + static constexpr ConvolutionLayout ConvolutionLayout_ = ConvolutionLayout::NHWGC_GKYXC_NHWGK; +}; + +struct NHWGCProblemBaseV1 : public ProblemBaseV1 { + static constexpr int NDimSpatial_ = 2; + static constexpr ConvolutionLayout ConvolutionLayout_ = ConvolutionLayout::NHWGC_GKYXC_NHWGK; +}; + +struct NDHWGCProblemBaseV1 : public ProblemBaseV1 { + static constexpr int NDimSpatial_ = 3; + static constexpr ConvolutionLayout ConvolutionLayout_ = ConvolutionLayout::NHWGC_GKYXC_NHWGK; +};