mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
Conv builder and solution descriptor
This commit is contained in:
@@ -706,4 +706,5 @@ rocm_create_package(
|
||||
)
|
||||
|
||||
# Add the experimental subdirectory
|
||||
add_subdirectory(experimental/gemm_builder)
|
||||
add_subdirectory(experimental/gemm_builder)
|
||||
add_subdirectory(experimental/convolution_builder)
|
||||
3
experimental/convolution_builder/CMakeLists.txt
Normal file
3
experimental/convolution_builder/CMakeLists.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
# The convolution builder relies on C++20 concepts, so C++20 is mandatory.
# Without CMAKE_CXX_STANDARD_REQUIRED, CMake may silently fall back to an
# older standard when the compiler does not support the requested one.
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Small driver that prints the kernel name chosen by ConvolutionBuilder.
add_executable(convolution_example convolution_example.cpp)
|
||||
31
experimental/convolution_builder/convolution_builder.hpp
Normal file
31
experimental/convolution_builder/convolution_builder.hpp
Normal file
@@ -0,0 +1,31 @@
|
||||
#pragma once
|
||||
#include <string>
#include <tuple>

#include "convolution_solution_descriptor.hpp"

#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
|
||||
|
||||
/// Maps a compile-time solution descriptor onto a concrete CK device operation.
///
/// Only two members of `Solution` are read here: `GemmImplementationType_` and
/// `ConvolutionDirection_`. Every template argument of the selected device
/// operation is hard-coded below; no other descriptor member is consulted.
template <SolutionDescriptorV1 Solution>
struct ConvolutionBuilder
{
    /// Returns a value-initialized device operation for the described
    /// convolution.
    ///
    /// Kernels exist only for XDL + Forward and XDL + BackwardWeight. Every
    /// other combination (BackwardData, or a non-XDL GEMM implementation)
    /// yields an empty `std::tuple<>` as a "no kernel" placeholder.
    static constexpr auto GetSolution()
    {
        constexpr bool is_xdl =
            Solution::GemmImplementationType_ == GemmImplementationType::XDL;

        if constexpr(is_xdl && Solution::ConvolutionDirection_ == ConvolutionDirection::Forward)
        {
            return ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
                2,
                ck::tensor_layout::convolution::NHWGC,
                ck::tensor_layout::convolution::GKYXC,
                ck::Tuple<>,
                ck::tensor_layout::convolution::NHWGK,
                float, float, float, float,
                ck::Tuple<>,
                float,
                ck::tensor_operation::element_wise::PassThrough,
                ck::tensor_operation::element_wise::PassThrough,
                ck::tensor_operation::element_wise::PassThrough,
                ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
                ck::tensor_operation::device::GemmSpecialization::MNKPadding,
                1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4,
                ck::Sequence<4, 64, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 8, 8, 1,
                ck::Sequence<4, 64, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 8, 8, 1,
                1, 1,
                ck::Sequence<1, 32, 1, 8>,
                8>{};
        }
        else if constexpr(is_xdl &&
                          Solution::ConvolutionDirection_ == ConvolutionDirection::BackwardWeight)
        {
            return ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<
                2,
                ck::tensor_layout::convolution::NHWGC,
                ck::tensor_layout::convolution::GKYXC,
                ck::tensor_layout::convolution::NHWGK,
                float, float, float, float,
                ck::tensor_operation::element_wise::PassThrough,
                ck::tensor_operation::element_wise::PassThrough,
                ck::tensor_operation::element_wise::PassThrough,
                ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default,
                64, 16, 16, 32, 8, 16, 16, 1, 1,
                ck::Sequence<4, 8, 1>, ck::Sequence<2, 0, 1>, ck::Sequence<1, 0, 2>, 1, 1, 4, false,
                ck::Sequence<4, 8, 1>, ck::Sequence<2, 0, 1>, ck::Sequence<1, 0, 2>, 1, 1, 4, false,
                1, 1,
                ck::Sequence<1, 8, 1, 8>,
                1>{};
        }
        else
        {
            // No kernel is wired up for this descriptor.
            return std::tuple<>{};
        }
    }

    /// Returns the type string reported by the selected device operation.
    ///
    /// Descriptors without a kernel never compiled here, because
    /// `std::tuple<>` has no `GetTypeString()`. The assertion keeps that
    /// behaviour but replaces the opaque member-lookup error with a message
    /// that names the real cause.
    std::string GetKernelName() const
    {
        static_assert(requires { GetSolution().GetTypeString(); },
                      "ConvolutionBuilder: no kernel is available for this solution descriptor "
                      "(only XDL Forward and XDL BackwardWeight are implemented)");
        return GetSolution().GetTypeString();
    }
};
|
||||
22
experimental/convolution_builder/convolution_example.cpp
Normal file
22
experimental/convolution_builder/convolution_example.cpp
Normal file
@@ -0,0 +1,22 @@
|
||||
#include <iostream>
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "convolution_builder.hpp"
|
||||
|
||||
// Example of solution description for Forward Conv with default settings
|
||||
struct GroupedConvFwdXdlImplicitGemm : public GroupedConvBaseXdl {
|
||||
static constexpr ConvolutionDirection ConvolutionDirection_ = ConvolutionDirection::Forward;
|
||||
};
|
||||
|
||||
// Example of solution description for Backward Weight Conv with default settings and Split K Two Stage
|
||||
struct GroupedConvBwdWeightXdlImplicitGemm : public GroupedConvBaseXdl {
|
||||
static constexpr ConvolutionDirection ConvolutionDirection_ = ConvolutionDirection::BackwardWeight;
|
||||
static constexpr SplitKSupport SplitKSupport_ = SplitKSupport::SupportedTwoStage;
|
||||
};
|
||||
|
||||
int main () {
|
||||
ConvolutionBuilder<GroupedConvFwdXdlImplicitGemm> builder;
|
||||
std::cout << builder.GetKernelName() << std::endl;
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,106 @@
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include <concepts>
|
||||
|
||||
/// GEMM implementation family a solution descriptor selects through its
/// `GemmImplementationType_` member.
enum class GemmImplementationType
{
    XDL,
    WMMA,
    DL,
};
|
||||
|
||||
/// Convolution pass a solution descriptor asks for through its
/// `ConvolutionDirection_` member.
enum class ConvolutionDirection
{
    Forward,
    BackwardData,
    BackwardWeight,
};
|
||||
|
||||
|
||||
/// Pipeline version tag carried by a solution descriptor in its
/// `GemmPipelineVersion_` member.
enum class GemmPipelineVersion
{
    V1,
    V2,
    V3,
    V4,
    V5,
};
|
||||
|
||||
/// Scheduler tag carried by a solution descriptor in its
/// `GemmPipelineScheduler_` member.
enum class GemmPipelineScheduler
{
    Intrawave,
    Interwave,
};
|
||||
|
||||
/// Split-K capability a solution descriptor declares through its
/// `SplitKSupport_` member.
enum class SplitKSupport
{
    Supported,
    SupportedTwoStage,
    NotSupported,
};
|
||||
|
||||
/// Merged-groups setting a solution descriptor declares through its
/// `MergedGroups_` member.
// NOTE(review): X16/X8/X4/X2 presumably name the number of merged groups - confirm.
enum class MergedGroups
{
    X16,
    X8,
    X4,
    X2,
    NotSupported,
};
|
||||
|
||||
/// Large-tensor handling a solution descriptor declares through its
/// `LargeTensorSupport_` member.
enum class LargeTensorSupport
{
    Supported,
    SplitBatch,
    NotSupported,
};
|
||||
|
||||
/// Implementation flavour a solution descriptor declares through its
/// `ImplementationType_` member.
// NOTE(review): the Explicit* entries presumably pair an explicit GEMM with a
// padding mode, and Implicit denotes implicit GEMM - confirm.
enum class ImplementationType
{
    ExplicitDefault,
    ExplicitMPadding,
    ExplicitNPadding,
    ExplicitKPadding,
    ExplicitMNPadding,
    ExplicitMKPadding,
    ExplicitNKPadding,
    ExplicitMNKPadding,
    Implicit,
};
|
||||
|
||||
/// Element-wise operation a solution descriptor declares through its
/// `ElementwiseOperation_` member.
enum class ElementwiseOperation
{
    Bias,
    BiasClamp,
    Bilinear,
    Clamp,
    Scale,
    PassThrough,
};
|
||||
|
||||
|
||||
template <typename T>
|
||||
concept SolutionDescriptorV1 = requires {
|
||||
{T::GemmImplementationType_} -> std::convertible_to<GemmImplementationType>;
|
||||
{T::ConvolutionDirection_} -> std::convertible_to<ConvolutionDirection>;
|
||||
{T::GemmPipelineVersion_} -> std::convertible_to<const GemmPipelineVersion>;
|
||||
{T::GemmPipelineScheduler_} -> std::convertible_to<const GemmPipelineScheduler>;
|
||||
{T::SplitKSupport_} -> std::convertible_to<const SplitKSupport>;
|
||||
{T::MergedGroups_} -> std::convertible_to<const MergedGroups>;
|
||||
{T::LargeTensorSupport_} -> std::convertible_to<const LargeTensorSupport>;
|
||||
{T::ImplementationType_} -> std::convertible_to<const ImplementationType>;
|
||||
{T::ElementwiseOperation_} -> std::convertible_to<const ElementwiseOperation>;
|
||||
};
|
||||
|
||||
struct GroupedConvBase {
|
||||
static constexpr GemmPipelineVersion GemmPipelineVersion_ = GemmPipelineVersion::V1;
|
||||
static constexpr GemmPipelineScheduler GemmPipelineScheduler_ = GemmPipelineScheduler::Intrawave;
|
||||
static constexpr SplitKSupport SplitKSupport_ = SplitKSupport::NotSupported;
|
||||
static constexpr MergedGroups MergedGroups_ = MergedGroups::NotSupported;
|
||||
static constexpr LargeTensorSupport LargeTensorSupport_ = LargeTensorSupport::NotSupported;
|
||||
static constexpr ImplementationType ImplementationType_ = ImplementationType::Implicit;
|
||||
static constexpr ElementwiseOperation ElementwiseOperation_ = ElementwiseOperation::PassThrough;
|
||||
};
|
||||
|
||||
struct GroupedConvBaseXdl : public GroupedConvBase {
|
||||
static constexpr GemmImplementationType GemmImplementationType_ = GemmImplementationType::XDL;
|
||||
};
|
||||
Reference in New Issue
Block a user