Conv builder and solution descriptor

This commit is contained in:
Bartlomiej Kocot
2025-08-06 13:45:02 -04:00
parent 594858dd6e
commit 4d2ecedab2
5 changed files with 164 additions and 1 deletions

View File

@@ -706,4 +706,5 @@ rocm_create_package(
)
# Add the experimental subdirectory
add_subdirectory(experimental/gemm_builder)
add_subdirectory(experimental/gemm_builder)
add_subdirectory(experimental/convolution_builder)

View File

@@ -0,0 +1,3 @@
set(CMAKE_CXX_STANDARD 20)
add_executable(convolution_example convolution_example.cpp)

View File

@@ -0,0 +1,31 @@
#pragma once
#include "convolution_solution_descriptor.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
template<SolutionDescriptorV1 Solution>
struct ConvolutionBuilder {
static constexpr auto GetSolution() {
if constexpr(Solution::GemmImplementationType_ == GemmImplementationType::XDL) {
if constexpr(Solution::ConvolutionDirection_ == ConvolutionDirection::Forward) {
return ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2, ck::tensor_layout::convolution::NHWGC, ck::tensor_layout::convolution::GKYXC, ck::Tuple<>, ck::tensor_layout::convolution::NHWGK, float, float, float, float, ck::Tuple<>, float, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, ck::tensor_operation::device::GemmSpecialization::MNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, ck::Sequence<4, 64, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 8, 8, 1, ck::Sequence<4, 64, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 8, 8, 1, 1, 1, ck::Sequence<1, 32, 1, 8>, 8>{};
} else if constexpr(Solution::ConvolutionDirection_ == ConvolutionDirection::BackwardData) {
return std::tuple<>{};
} else if constexpr(Solution::ConvolutionDirection_ == ConvolutionDirection::BackwardWeight) {
return ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<2, ck::tensor_layout::convolution::NHWGC, ck::tensor_layout::convolution::GKYXC, ck::tensor_layout::convolution::NHWGK, float, float, float, float, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default, 64, 16, 16, 32, 8, 16, 16, 1, 1, ck::Sequence<4, 8, 1>, ck::Sequence<2, 0, 1>, ck::Sequence<1, 0, 2>, 1, 1, 4, false, ck::Sequence<4, 8, 1>, ck::Sequence<2, 0, 1>, ck::Sequence<1, 0, 2>, 1, 1, 4, false, 1, 1, ck::Sequence<1, 8, 1, 8>, 1>{};
} else {
return std::tuple<>{};
}
} else {
return std::tuple<>{};
}
}
std::string GetKernelName() const
{
return GetSolution().GetTypeString();
}
};

View File

@@ -0,0 +1,22 @@
#include <iostream>
#include <hip/hip_runtime.h>
#include "convolution_builder.hpp"
// Example of solution description for Forward Conv with default settings
struct GroupedConvFwdXdlImplicitGemm : public GroupedConvBaseXdl {
static constexpr ConvolutionDirection ConvolutionDirection_ = ConvolutionDirection::Forward;
};
// Example of solution description for Backward Weight Conv with default settings and Split K Two Stage
struct GroupedConvBwdWeightXdlImplicitGemm : public GroupedConvBaseXdl {
static constexpr ConvolutionDirection ConvolutionDirection_ = ConvolutionDirection::BackwardWeight;
static constexpr SplitKSupport SplitKSupport_ = SplitKSupport::SupportedTwoStage;
};
int main () {
ConvolutionBuilder<GroupedConvFwdXdlImplicitGemm> builder;
std::cout << builder.GetKernelName() << std::endl;
return 0;
}

View File

@@ -0,0 +1,106 @@
#pragma once
#include <iostream>
#include <concepts>
enum class GemmImplementationType
{
XDL,
WMMA,
DL
};
enum class ConvolutionDirection
{
Forward,
BackwardData,
BackwardWeight
};
enum class GemmPipelineVersion
{
V1,
V2,
V3,
V4,
V5
};
enum class GemmPipelineScheduler
{
Intrawave,
Interwave
};
enum class SplitKSupport
{
Supported,
SupportedTwoStage,
NotSupported
};
enum class MergedGroups
{
X16,
X8,
X4,
X2,
NotSupported
};
enum class LargeTensorSupport
{
Supported,
SplitBatch,
NotSupported
};
enum class ImplementationType
{
ExplicitDefault,
ExplicitMPadding,
ExplicitNPadding,
ExplicitKPadding,
ExplicitMNPadding,
ExplicitMKPadding,
ExplicitNKPadding,
ExplicitMNKPadding,
Implicit
};
enum class ElementwiseOperation {
Bias,
BiasClamp,
Bilinear,
Clamp,
Scale,
PassThrough
};
template <typename T>
concept SolutionDescriptorV1 = requires {
{T::GemmImplementationType_} -> std::convertible_to<GemmImplementationType>;
{T::ConvolutionDirection_} -> std::convertible_to<ConvolutionDirection>;
{T::GemmPipelineVersion_} -> std::convertible_to<const GemmPipelineVersion>;
{T::GemmPipelineScheduler_} -> std::convertible_to<const GemmPipelineScheduler>;
{T::SplitKSupport_} -> std::convertible_to<const SplitKSupport>;
{T::MergedGroups_} -> std::convertible_to<const MergedGroups>;
{T::LargeTensorSupport_} -> std::convertible_to<const LargeTensorSupport>;
{T::ImplementationType_} -> std::convertible_to<const ImplementationType>;
{T::ElementwiseOperation_} -> std::convertible_to<const ElementwiseOperation>;
};
struct GroupedConvBase {
static constexpr GemmPipelineVersion GemmPipelineVersion_ = GemmPipelineVersion::V1;
static constexpr GemmPipelineScheduler GemmPipelineScheduler_ = GemmPipelineScheduler::Intrawave;
static constexpr SplitKSupport SplitKSupport_ = SplitKSupport::NotSupported;
static constexpr MergedGroups MergedGroups_ = MergedGroups::NotSupported;
static constexpr LargeTensorSupport LargeTensorSupport_ = LargeTensorSupport::NotSupported;
static constexpr ImplementationType ImplementationType_ = ImplementationType::Implicit;
static constexpr ElementwiseOperation ElementwiseOperation_ = ElementwiseOperation::PassThrough;
};
struct GroupedConvBaseXdl : public GroupedConvBase {
static constexpr GemmImplementationType GemmImplementationType_ = GemmImplementationType::XDL;
};