From 55a2a4cac72987bd65b518aec83bfbd7e909f016 Mon Sep 17 00:00:00 2001 From: Bartlomiej Kocot Date: Thu, 7 Aug 2025 10:28:54 +0000 Subject: [PATCH] Fixes problem descriptor --- .../convolution_builder/CMakeLists.txt | 6 ++- .../convolution_builder.hpp | 54 +++++++++++++------ .../convolution_example.cpp | 14 ++--- .../convolution_implementation_descriptor.hpp | 5 ++ ...hpp => convolution_problem_descriptor.hpp} | 16 ++++-- 5 files changed, 70 insertions(+), 25 deletions(-) create mode 100644 experimental/convolution_builder/convolution_implementation_descriptor.hpp rename experimental/convolution_builder/{convolution_solution_descriptor.hpp => convolution_problem_descriptor.hpp} (85%) diff --git a/experimental/convolution_builder/CMakeLists.txt b/experimental/convolution_builder/CMakeLists.txt index f7e2ddc5d8..42646313e6 100644 --- a/experimental/convolution_builder/CMakeLists.txt +++ b/experimental/convolution_builder/CMakeLists.txt @@ -1,3 +1,7 @@ set(CMAKE_CXX_STANDARD 20) -add_executable(convolution_example convolution_example.cpp) \ No newline at end of file +add_executable(convolution_example convolution_example.cpp) + +set(EXAMPLE_COMPILE_OPTIONS) +list(APPEND EXAMPLE_COMPILE_OPTIONS -g -fverbose-asm --save-temps -Wno-gnu-line-marker) +target_compile_options(convolution_example PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) \ No newline at end of file diff --git a/experimental/convolution_builder/convolution_builder.hpp b/experimental/convolution_builder/convolution_builder.hpp index 074a39c0c2..0e6496a020 100644 --- a/experimental/convolution_builder/convolution_builder.hpp +++ b/experimental/convolution_builder/convolution_builder.hpp @@ -1,31 +1,55 @@ #pragma once -#include "convolution_solution_descriptor.hpp" +#include "convolution_problem_descriptor.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" 
-template -struct ConvolutionBuilder { +template +struct ConvolutionBuilder; - static constexpr auto GetSolution() { - if constexpr(Solution::GemmImplementationType_ == GemmImplementationType::XDL) { - if constexpr(Solution::ConvolutionDirection_ == ConvolutionDirection::Forward) { - return ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2, ck::tensor_layout::convolution::NHWGC, ck::tensor_layout::convolution::GKYXC, ck::Tuple<>, ck::tensor_layout::convolution::NHWGK, float, float, float, float, ck::Tuple<>, float, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, ck::tensor_operation::device::GemmSpecialization::MNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, ck::Sequence<4, 64, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 8, 8, 1, ck::Sequence<4, 64, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 8, 8, 1, 1, 1, ck::Sequence<1, 32, 1, 8>, 8>{}; - } else if constexpr(Solution::ConvolutionDirection_ == ConvolutionDirection::BackwardData) { - return std::tuple<>{}; - } else if constexpr(Solution::ConvolutionDirection_ == ConvolutionDirection::BackwardWeight) { - return ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<2, ck::tensor_layout::convolution::NHWGC, ck::tensor_layout::convolution::GKYXC, ck::tensor_layout::convolution::NHWGK, float, float, float, float, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default, 64, 16, 16, 32, 8, 16, 16, 1, 1, ck::Sequence<4, 8, 1>, ck::Sequence<2, 0, 1>, ck::Sequence<1, 0, 2>, 1, 1, 4, false, ck::Sequence<4, 8, 1>, ck::Sequence<2, 0, 1>, ck::Sequence<1, 0, 2>, 1, 1, 4, false, 1, 1, ck::Sequence<1, 8, 1, 8>, 1>{}; 
+template +struct ConvolutionBuilder { + + enum class Kernel { + GroupedConvFwdMultipleABD_Xdl_CShuffle, + GroupedConvBwdWeightTwoStage_Xdl_CShuffle + }; + + static constexpr Kernel GetKernel() { + if constexpr(Problem::GemmImplementationType_ == GemmImplementationType::XDL) { + if constexpr(Problem::ConvolutionDirection_ == ConvolutionDirection::Forward) { + return Kernel::GroupedConvFwdMultipleABD_Xdl_CShuffle; + } else if constexpr(Problem::ConvolutionDirection_ == ConvolutionDirection::BackwardData) { + static_assert("Instance not found!"); + } else if constexpr(Problem::ConvolutionDirection_ == ConvolutionDirection::BackwardWeight) { + return Kernel::GroupedConvBwdWeightTwoStage_Xdl_CShuffle; } else { - return std::tuple<>{}; + static_assert("Instance not found!"); } } else { - return std::tuple<>{}; + static_assert("Instance not found!"); } } + static constexpr auto GetInstance() { + using FwdInstance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2, ck::tensor_layout::convolution::NHWGC, ck::tensor_layout::convolution::GKYXC, ck::Tuple<>, ck::tensor_layout::convolution::NHWGK, float, float, float, float, ck::Tuple<>, float, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, ck::tensor_operation::device::GemmSpecialization::MNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, ck::Sequence<4, 64, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 8, 8, 1, ck::Sequence<4, 64, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 8, 8, 1, 1, 1, ck::Sequence<1, 32, 1, 8>, 8>; + using BwdWeightInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<2, ck::tensor_layout::convolution::NHWGC, ck::tensor_layout::convolution::GKYXC, ck::tensor_layout::convolution::NHWGK, float, float, float, float, 
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default, 64, 16, 16, 32, 8, 16, 16, 1, 1, ck::Sequence<4, 8, 1>, ck::Sequence<2, 0, 1>, ck::Sequence<1, 0, 2>, 1, 1, 4, false, ck::Sequence<4, 8, 1>, ck::Sequence<2, 0, 1>, ck::Sequence<1, 0, 2>, 1, 1, 4, false, 1, 1, ck::Sequence<1, 8, 1, 8>, 1>; + + using SelectedInstance = std::conditional_t; + return SelectedInstance{}; + } - std::string GetKernelName() const + std::string GetInstanceName() const { - return GetSolution().GetTypeString(); + if constexpr(GetKernel() == Kernel::GroupedConvFwdMultipleABD_Xdl_CShuffle) { + return "GroupedConvFwdMultipleABD_Xdl_CShuffle"; + } else if constexpr (GetKernel() == Kernel::GroupedConvBwdWeightTwoStage_Xdl_CShuffle) { + return "GroupedConvBwdWeightTwoStage_Xdl_CShuffle"; + } else { + return "Not found."; + } + // const auto instance = GetInstance(); + // return instance.GetTypeString(); } }; + diff --git a/experimental/convolution_builder/convolution_example.cpp b/experimental/convolution_builder/convolution_example.cpp index b60522c49b..c092b8a1f7 100644 --- a/experimental/convolution_builder/convolution_example.cpp +++ b/experimental/convolution_builder/convolution_example.cpp @@ -4,19 +4,21 @@ #include "convolution_builder.hpp" -// Example of solution description for Forward Conv with default settings -struct GroupedConvFwdXdlImplicitGemm : public GroupedConvBaseXdl { +// Example of problem description for Forward Conv with default settings +struct GroupedConvFwdXdlImplicitGemm : public GroupedConvBaseXdlV1 { static constexpr ConvolutionDirection ConvolutionDirection_ = ConvolutionDirection::Forward; }; -// Example of solution description for Backward Weight Conv with default settings and Split K Two Stage -struct GroupedConvBwdWeightXdlImplicitGemm : public GroupedConvBaseXdl { +// Example of problem 
description for Backward Weight Conv with default settings and Split K Two Stage +struct GroupedConvBwdWeightXdlImplicitGemmTwoStage : public GroupedConvBaseXdlV1 { static constexpr ConvolutionDirection ConvolutionDirection_ = ConvolutionDirection::BackwardWeight; static constexpr SplitKSupport SplitKSupport_ = SplitKSupport::SupportedTwoStage; }; int main () { - ConvolutionBuilder builder; - std::cout << builder.GetKernelName() << std::endl; + ConvolutionBuilder builder_fwd; + std::cout << builder_fwd.GetInstanceName() << std::endl; + ConvolutionBuilder builder_bwd_weight_two_stage; + std::cout << builder_bwd_weight_two_stage.GetInstanceName() << std::endl; return 0; } diff --git a/experimental/convolution_builder/convolution_implementation_descriptor.hpp b/experimental/convolution_builder/convolution_implementation_descriptor.hpp new file mode 100644 index 0000000000..5003e9ea84 --- /dev/null +++ b/experimental/convolution_builder/convolution_implementation_descriptor.hpp @@ -0,0 +1,5 @@ +#pragma once +#include + +template +concept \ No newline at end of file diff --git a/experimental/convolution_builder/convolution_solution_descriptor.hpp b/experimental/convolution_builder/convolution_problem_descriptor.hpp similarity index 85% rename from experimental/convolution_builder/convolution_solution_descriptor.hpp rename to experimental/convolution_builder/convolution_problem_descriptor.hpp index 1bfb174600..234baf8eb2 100644 --- a/experimental/convolution_builder/convolution_solution_descriptor.hpp +++ b/experimental/convolution_builder/convolution_problem_descriptor.hpp @@ -1,7 +1,11 @@ #pragma once -#include #include +enum class ProblemDescriptorVersion +{ + V1 +}; + enum class GemmImplementationType { XDL, @@ -79,7 +83,8 @@ enum class ElementwiseOperation { template -concept SolutionDescriptorV1 = requires { +concept ProblemDescriptorV1 = requires { + {T::ProblemDescriptorVersion_} -> std::convertible_to; {T::GemmImplementationType_} -> std::convertible_to; 
{T::ConvolutionDirection_} -> std::convertible_to; {T::GemmPipelineVersion_} -> std::convertible_to; @@ -89,7 +94,7 @@ concept SolutionDescriptorV1 = requires { {T::LargeTensorSupport_} -> std::convertible_to; {T::ImplementationType_} -> std::convertible_to; {T::ElementwiseOperation_} -> std::convertible_to; -}; +} && (T::ProblemDescriptorVersion_ == ProblemDescriptorVersion::V1); struct GroupedConvBase { static constexpr GemmPipelineVersion GemmPipelineVersion_ = GemmPipelineVersion::V1; @@ -104,3 +109,8 @@ struct GroupedConvBase { struct GroupedConvBaseXdl : public GroupedConvBase { static constexpr GemmImplementationType GemmImplementationType_ = GemmImplementationType::XDL; }; + +struct GroupedConvBaseXdlV1 : public GroupedConvBaseXdl { + static constexpr ProblemDescriptorVersion ProblemDescriptorVersion_ = ProblemDescriptorVersion::V1; +}; +