mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
Problem descriptor
This commit is contained in:
@@ -1,27 +1,28 @@
|
||||
#pragma once
|
||||
#include "convolution_kernel_descriptor.hpp"
|
||||
#include "convolution_problem_descriptor.hpp"
|
||||
#include "convolution_implementation_descriptor.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
|
||||
|
||||
template<typename Problem, typename Implementation>
|
||||
template<typename KernelDesc, typename ProblemDesc, typename Implementation>
|
||||
struct ConvolutionBuilder;
|
||||
|
||||
template<ProblemDescriptorV1 Problem, ImplementationDescriptorV1 Implementation>
|
||||
struct ConvolutionBuilder<Problem, Implementation> {
|
||||
template<KernelDescriptorV1 KernelDesc, ProblemDescriptorV1 ProblemDesc, ImplementationDescriptorV1 ImplementationDesc>
|
||||
struct ConvolutionBuilder<KernelDesc, ProblemDesc, ImplementationDesc> {
|
||||
public:
|
||||
static constexpr auto GetInstance() {
|
||||
using DataType = typename Implementation::DataType;
|
||||
using DataType = typename ProblemDesc::DataType;
|
||||
using AccDataType = std::conditional_t<std::is_same_v<DataType, int8_t>, int32_t, float>;
|
||||
using InLayout = std::tuple_element<0, decltype(GetLayout())>::type;
|
||||
using WeiLayout = std::tuple_element<1,decltype( GetLayout())>::type;
|
||||
using OutLayout = std::tuple_element<2,decltype( GetLayout())>::type;
|
||||
|
||||
using GroupedConvFwdMultipleABD_Xdl_CShuffleInstance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< Implementation::NDimSpatial_, InLayout, WeiLayout, decltype(GetMultiDLayout()), OutLayout, DataType, DataType, DataType, AccDataType, typename Implementation::ElementwiseOpDataTypes, DataType, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, decltype(GetOutElementwiseOp()), GetConvSpecialization(), ck::tensor_operation::device::GemmSpecialization::MNKPadding, 1, Implementation::BlockSize_, Implementation::TileSizes_::At(0), Implementation::TileSizes_::At(1), Implementation::TileSizes_::At(2), Implementation::K1_, Implementation::K1_, 16, 16, 1, 1, ck::Sequence<4, 8, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 1, 4, 1, ck::Sequence<4, 8, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 1, 4, 1, 1, 1, ck::Sequence<1, 32, 1, 8>, 1>;
|
||||
using DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffleInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<Implementation::NDimSpatial_, InLayout, WeiLayout, OutLayout, DataType, DataType, DataType, AccDataType , ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, decltype(GetOutElementwiseOp()), GetConvSpecialization(), Implementation::BlockSize_, Implementation::TileSizes_::At(0), Implementation::TileSizes_::At(1), Implementation::TileSizes_::At(2), Implementation::K1_, 16, 16, 1, 1, ck::Sequence<4, 8, 1>, ck::Sequence<2, 0, 1>, ck::Sequence<1, 0, 2>, 1, 1, 4, false, ck::Sequence<4, 8, 1>, ck::Sequence<2, 0, 1>, ck::Sequence<1, 0, 2>, 1, 1, 4, false, 1, 1, ck::Sequence<1, 8, 1, 8>, 1>;
|
||||
using GroupedConvFwdMultipleABD_Xdl_CShuffleInstance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< ProblemDesc::NDimSpatial_, InLayout, WeiLayout, decltype(GetMultiDLayout()), OutLayout, DataType, DataType, DataType, AccDataType, typename ProblemDesc::ElementwiseOpDataTypes, DataType, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, decltype(GetOutElementwiseOp()), GetConvSpecialization(), ck::tensor_operation::device::GemmSpecialization::MNKPadding, 1, ImplementationDesc::BlockSize_, ImplementationDesc::TileSizes_::At(0), ImplementationDesc::TileSizes_::At(1), ImplementationDesc::TileSizes_::At(2), ImplementationDesc::K1_, ImplementationDesc::K1_, 16, 16, 1, 1, ck::Sequence<4, 8, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 1, 4, 1, ck::Sequence<4, 8, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 1, 4, 1, 1, 1, ck::Sequence<1, 32, 1, 8>, 1>;
|
||||
using DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffleInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<ProblemDesc::NDimSpatial_, InLayout, WeiLayout, OutLayout, DataType, DataType, DataType, AccDataType , ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, decltype(GetOutElementwiseOp()), GetConvSpecialization(), ImplementationDesc::BlockSize_, ImplementationDesc::TileSizes_::At(0), ImplementationDesc::TileSizes_::At(1), ImplementationDesc::TileSizes_::At(2), ImplementationDesc::K1_, 16, 16, 1, 1, ck::Sequence<4, 8, 1>, ck::Sequence<2, 0, 1>, ck::Sequence<1, 0, 2>, 1, 1, 4, false, ck::Sequence<4, 8, 1>, ck::Sequence<2, 0, 1>, ck::Sequence<1, 0, 2>, 1, 1, 4, false, 1, 1, ck::Sequence<1, 8, 1, 8>, 1>;
|
||||
|
||||
using SelectedInstance = std::conditional_t<GetKernel() == Kernel::GroupedConvFwdMultipleABD_Xdl_CShuffle, GroupedConvFwdMultipleABD_Xdl_CShuffleInstance, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffleInstance>;
|
||||
using SelectedInstance = std::conditional_t<GetKernel() == KernelDesc::GroupedConvFwdMultipleABD_Xdl_CShuffle, GroupedConvFwdMultipleABD_Xdl_CShuffleInstance, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffleInstance>;
|
||||
return SelectedInstance{};
|
||||
}
|
||||
|
||||
@@ -64,23 +65,23 @@ public:
|
||||
// clang-format off
|
||||
str << KernelToString[GetKernel()]
|
||||
<< "<"
|
||||
<< Implementation::BlockSize_ << ", "
|
||||
<< std::get<0>(Implementation::TileSizes_) << ", "
|
||||
<< std::get<1>(Implementation::TileSizes_) << ", "
|
||||
<< std::get<2>(Implementation::TileSizes_) << ", "
|
||||
<< ConvolutionSpecializationToString[Implementation::ConvolutionSpecialization_] << ", "
|
||||
<< Implementation::K1_ << ", "
|
||||
<< MFMAInstructionSizeToString[Implementation::MFMAInstructionSize_] << ", "
|
||||
<< std::get<0>(Implementation::XdlPerWave_) << ", "
|
||||
<< std::get<1>(Implementation::XdlPerWave_) << ", "
|
||||
<< std::get<0>(Implementation::GlobalTransferVectorSize_) << ", "
|
||||
<< std::get<0>(Implementation::LDSStoreVectorSize_) << ", "
|
||||
<< std::get<1>(Implementation::GlobalTransferVectorSize_) << ", "
|
||||
<< std::get<1>(Implementation::LDSStoreVectorSize_) << ", "
|
||||
<< std::get<2>(Implementation::GlobalTransferVectorSize_) << ", "
|
||||
<< GemmPipelineSchedulerToString[Problem::GemmPipelineScheduler_] << ", "
|
||||
<< GemmPipelineVersionToString[Problem::GemmPipelineVersion_] << ", "
|
||||
<< MergedGroupsToString[Problem::MergedGroups_] << ">";
|
||||
<< ImplementationDesc::BlockSize_ << ", "
|
||||
<< std::get<0>(ImplementationDesc::TileSizes_) << ", "
|
||||
<< std::get<1>(ImplementationDesc::TileSizes_) << ", "
|
||||
<< std::get<2>(ImplementationDesc::TileSizes_) << ", "
|
||||
<< ConvolutionSpecializationToString[ImplementationDesc::ConvolutionSpecialization_] << ", "
|
||||
<< ImplementationDesc::K1_ << ", "
|
||||
<< MFMAInstructionSizeToString[ImplementationDesc::MFMAInstructionSize_] << ", "
|
||||
<< std::get<0>(ImplementationDesc::XdlPerWave_) << ", "
|
||||
<< std::get<1>(ImplementationDesc::XdlPerWave_) << ", "
|
||||
<< std::get<0>(ImplementationDesc::GlobalTransferVectorSize_) << ", "
|
||||
<< std::get<0>(ImplementationDesc::LDSStoreVectorSize_) << ", "
|
||||
<< std::get<1>(ImplementationDesc::GlobalTransferVectorSize_) << ", "
|
||||
<< std::get<1>(ImplementationDesc::LDSStoreVectorSize_) << ", "
|
||||
<< std::get<2>(ImplementationDesc::GlobalTransferVectorSize_) << ", "
|
||||
<< GemmPipelineSchedulerToString[KernelDesc::GemmPipelineScheduler_] << ", "
|
||||
<< GemmPipelineVersionToString[KernelDesc::GemmPipelineVersion_] << ", "
|
||||
<< MergedGroupsToString[KernelDesc::MergedGroups_] << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
@@ -93,12 +94,12 @@ private:
|
||||
};
|
||||
|
||||
static constexpr Kernel GetKernel() {
|
||||
if constexpr(Problem::GemmImplementationType_ == GemmImplementationType::XDL) {
|
||||
if constexpr(Problem::ConvolutionDirection_ == ConvolutionDirection::Forward) {
|
||||
if constexpr(KernelDesc::GemmImplementationType_ == GemmImplementationType::XDL) {
|
||||
if constexpr(KernelDesc::ConvolutionDirection_ == ConvolutionDirection::Forward) {
|
||||
return Kernel::GroupedConvFwdMultipleABD_Xdl_CShuffle;
|
||||
} else if constexpr(Problem::ConvolutionDirection_ == ConvolutionDirection::BackwardData) {
|
||||
} else if constexpr(KernelDesc::ConvolutionDirection_ == ConvolutionDirection::BackwardData) {
|
||||
static_assert("Instance not found!");
|
||||
} else if constexpr(Problem::ConvolutionDirection_ == ConvolutionDirection::BackwardWeight) {
|
||||
} else if constexpr(KernelDesc::ConvolutionDirection_ == ConvolutionDirection::BackwardWeight) {
|
||||
return Kernel::GroupedConvBwdWeightTwoStage_Xdl_CShuffle;
|
||||
} else {
|
||||
static_assert("Instance not found!");
|
||||
@@ -109,8 +110,8 @@ private:
|
||||
}
|
||||
|
||||
static constexpr auto GetLayout() {
|
||||
if constexpr(Implementation::NDimSpatial_ == 2) {
|
||||
if constexpr(Implementation::ConvolutionLayout_ == ConvolutionLayout::NHWGC_GKYXC_NHWGK) {
|
||||
if constexpr(ProblemDesc::NDimSpatial_ == 2) {
|
||||
if constexpr(ProblemDesc::ConvolutionLayout_ == ConvolutionLayout::NHWGC_GKYXC_NHWGK) {
|
||||
return std::tuple<ck::tensor_layout::convolution::NHWGC, ck::tensor_layout::convolution::GKYXC, ck::tensor_layout::convolution::NHWGK>{};
|
||||
} else {
|
||||
static_assert("Layout not supported!");
|
||||
@@ -129,9 +130,9 @@ private:
|
||||
}
|
||||
|
||||
static constexpr auto GetConvSpecialization() {
|
||||
if constexpr(Problem::ConvolutionDirection_ == ConvolutionDirection::Forward) {
|
||||
if constexpr(KernelDesc::ConvolutionDirection_ == ConvolutionDirection::Forward) {
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
} else if constexpr(Problem::ConvolutionDirection_ == ConvolutionDirection::BackwardWeight) {
|
||||
} else if constexpr(KernelDesc::ConvolutionDirection_ == ConvolutionDirection::BackwardWeight) {
|
||||
return ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
|
||||
} else {
|
||||
static_assert("Specialization not found!");
|
||||
|
||||
@@ -4,19 +4,19 @@
|
||||
|
||||
#include "convolution_builder.hpp"
|
||||
|
||||
// Example of problem description for Forward Conv with default settings
|
||||
// Example of kernel description for Forward Conv with default settings
|
||||
// Forward-convolution kernel descriptor built on GroupedConvBaseXdlV1. It supplies the
// ConvolutionDirection_ member that the base chain leaves undefined and replaces the
// inherited ElementwiseOperation_ default with Bias.
struct GroupedConvFwdXdlImplicitGemm : public GroupedConvBaseXdlV1 {
|
||||
static constexpr ConvolutionDirection ConvolutionDirection_ = ConvolutionDirection::Forward;
|
||||
static constexpr ElementwiseOperation ElementwiseOperation_ = ElementwiseOperation::Bias;
|
||||
};
|
||||
|
||||
// Example of problem description for Backward Weight Conv with default settings and Split K Two Stage
|
||||
// Example of kernel description for Backward Weight Conv with default settings and Split K Two Stage
|
||||
// Backward-weight kernel descriptor built on GroupedConvBaseXdlV1. It supplies
// ConvolutionDirection_ and replaces the inherited SplitKSupport_ default with
// SupportedTwoStage.
struct GroupedConvBwdWeightXdlImplicitGemmTwoStage : public GroupedConvBaseXdlV1 {
|
||||
static constexpr ConvolutionDirection ConvolutionDirection_ = ConvolutionDirection::BackwardWeight;
|
||||
static constexpr SplitKSupport SplitKSupport_ = SplitKSupport::SupportedTwoStage;
|
||||
};
|
||||
|
||||
struct ImplementationDescriptor : public NHWCImplementationBaseV1, public BF16ImplementationBaseV1 {
|
||||
struct Implementation16x16 : ImplementationDefaultV1 {
|
||||
static constexpr ck::index_t BlockSize_ = 64;
|
||||
static constexpr auto TileSizes_ = std::make_tuple(16, 16, 32);
|
||||
static constexpr ck::index_t K1_ = 8;
|
||||
@@ -26,10 +26,12 @@ struct ImplementationDescriptor : public NHWCImplementationBaseV1, public BF16Im
|
||||
static constexpr auto LDSStoreVectorSize_ = std::make_tuple(4, 4);
|
||||
};
|
||||
|
||||
// Problem descriptor composed from two mix-in bases: BF16ProblemBaseV1 (data type) and
// NHWGCProblemBaseV1 (spatial dimension and layout). It adds no members of its own.
struct ProblemBF16NHWGC : public BF16ProblemBaseV1, public NHWGCProblemBaseV1 {};
|
||||
|
||||
int main () {
|
||||
ConvolutionBuilder<GroupedConvFwdXdlImplicitGemm, ImplementationDescriptor> builder_fwd;
|
||||
ConvolutionBuilder<GroupedConvFwdXdlImplicitGemm, ProblemBF16NHWGC, Implementation16x16> builder_fwd;
|
||||
std::cout << builder_fwd.GetInstanceName() << std::endl;
|
||||
ConvolutionBuilder<GroupedConvBwdWeightXdlImplicitGemmTwoStage, ImplementationDescriptor> builder_bwd_weight_two_stage;
|
||||
ConvolutionBuilder<GroupedConvBwdWeightXdlImplicitGemmTwoStage, ProblemBF16NHWGC, Implementation16x16> builder_bwd_weight_two_stage;
|
||||
std::cout << builder_bwd_weight_two_stage.GetInstanceName() << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -19,11 +19,6 @@ enum class ConvolutionSpecialization {
|
||||
Filter3x3
|
||||
};
|
||||
|
||||
enum class ConvolutionLayout {
|
||||
NHWGC_GKYXC_NHWGK,
|
||||
NGCHW_GKCYX_NGKHW
|
||||
};
|
||||
|
||||
enum class MFMAInstructionSize {
|
||||
M16N16,
|
||||
M32N32
|
||||
@@ -33,11 +28,7 @@ enum class MFMAInstructionSize {
|
||||
template <typename T>
|
||||
concept ImplementationDescriptorV1 = requires {
|
||||
{T::ImplementationDescriptorVersion_} -> std::convertible_to<ImplementationDescriptorVersion>;
|
||||
{T::NDimSpatial_} -> std::convertible_to<int>;
|
||||
typename T::DataType;
|
||||
typename T::ElementwiseOpDataTypes;
|
||||
{T::ConvolutionSpecialization_} -> std::convertible_to<ConvolutionSpecialization>;
|
||||
{T::ConvolutionLayout_} -> std::convertible_to<ConvolutionLayout>;
|
||||
{T::BlockSize_} -> std::convertible_to<int>;
|
||||
{T::TileSizes_} -> std::convertible_to<std::tuple<int, int, int>>;
|
||||
{T::K1_} -> std::convertible_to<int>;
|
||||
@@ -47,36 +38,7 @@ concept ImplementationDescriptorV1 = requires {
|
||||
{T::LDSStoreVectorSize_} -> std::convertible_to<std::tuple<int, int>>;
|
||||
} && (T::ImplementationDescriptorVersion_ == ImplementationDescriptorVersion::V1);
|
||||
|
||||
struct ImplementationBaseV1 {
|
||||
struct ImplementationDefaultV1 {
|
||||
static constexpr ImplementationDescriptorVersion ImplementationDescriptorVersion_ = ImplementationDescriptorVersion::V1;
|
||||
using DataType = ck::bhalf_t;
|
||||
using ElementwiseOpDataTypes = ck::Tuple<>;
|
||||
static constexpr ConvolutionSpecialization ConvolutionSpecialization_ = ConvolutionSpecialization::Default;
|
||||
};
|
||||
|
||||
struct BF16ImplementationBaseV1 : public ImplementationBaseV1 {
|
||||
using DataType = ck::bhalf_t;
|
||||
};
|
||||
|
||||
struct F32ImplementationBaseV1 : public ImplementationBaseV1 {
|
||||
using DataType = float;
|
||||
};
|
||||
|
||||
struct F16ImplementationBaseV1 : public ImplementationBaseV1 {
|
||||
using DataType = ck::half_t;
|
||||
};
|
||||
|
||||
struct NWCImplementationBaseV1 : public ImplementationBaseV1 {
|
||||
static constexpr int NDimSpatial_ = 1;
|
||||
static constexpr ConvolutionLayout ConvolutionLayout_ = ConvolutionLayout::NHWGC_GKYXC_NHWGK;
|
||||
};
|
||||
|
||||
struct NHWCImplementationBaseV1 : public ImplementationBaseV1 {
|
||||
static constexpr int NDimSpatial_ = 2;
|
||||
static constexpr ConvolutionLayout ConvolutionLayout_ = ConvolutionLayout::NHWGC_GKYXC_NHWGK;
|
||||
};
|
||||
|
||||
struct NDHWCImplementationBaseV1 : public ImplementationBaseV1 {
|
||||
static constexpr int NDimSpatial_ = 3;
|
||||
static constexpr ConvolutionLayout ConvolutionLayout_ = ConvolutionLayout::NHWGC_GKYXC_NHWGK;
|
||||
};
|
||||
|
||||
@@ -0,0 +1,116 @@
|
||||
#pragma once
|
||||
#include <concepts>
|
||||
|
||||
// Version tag of the kernel-descriptor interface. A descriptor publishes it as
// KernelDescriptorVersion_, and the KernelDescriptorV1 concept requires the value V1.
enum class KernelDescriptorVersion
|
||||
{
|
||||
V1
|
||||
};
|
||||
|
||||
// GEMM implementation family named by a kernel descriptor (GemmImplementationType_).
// GroupedConvBaseXdl fixes this to XDL.
enum class GemmImplementationType
|
||||
{
|
||||
XDL,
|
||||
WMMA,
|
||||
DL
|
||||
};
|
||||
|
||||
// Convolution pass named by a kernel descriptor (ConvolutionDirection_). The
// GroupedConvBase* helpers give it no default, so each concrete descriptor sets it.
enum class ConvolutionDirection
|
||||
{
|
||||
Forward,
|
||||
BackwardData,
|
||||
BackwardWeight
|
||||
};
|
||||
|
||||
|
||||
// GEMM pipeline version named by a kernel descriptor (GemmPipelineVersion_).
// GroupedConvBase defaults it to V1.
enum class GemmPipelineVersion
|
||||
{
|
||||
V1,
|
||||
V2,
|
||||
V3,
|
||||
V4,
|
||||
V5
|
||||
};
|
||||
|
||||
// GEMM pipeline scheduler named by a kernel descriptor (GemmPipelineScheduler_).
// GroupedConvBase defaults it to Intrawave.
enum class GemmPipelineScheduler
|
||||
{
|
||||
Intrawave,
|
||||
Interwave
|
||||
};
|
||||
|
||||
// Split-K capability named by a kernel descriptor (SplitKSupport_).
// GroupedConvBase defaults it to NotSupported.
enum class SplitKSupport
|
||||
{
|
||||
Supported,
|
||||
SupportedTwoStage,
|
||||
NotSupported
|
||||
};
|
||||
|
||||
// Group-merging option named by a kernel descriptor (MergedGroups_).
// GroupedConvBase defaults it to NotSupported.
// NOTE(review): X16/X8/X4/X2 presumably denote how many groups are merged - confirm.
enum class MergedGroups
|
||||
{
|
||||
X16,
|
||||
X8,
|
||||
X4,
|
||||
X2,
|
||||
NotSupported
|
||||
};
|
||||
|
||||
// Large-tensor handling named by a kernel descriptor (LargeTensorSupport_).
// GroupedConvBase defaults it to NotSupported.
enum class LargeTensorSupport
|
||||
{
|
||||
Supported,
|
||||
SplitBatch,
|
||||
NotSupported
|
||||
};
|
||||
|
||||
// Implementation flavour named by a kernel descriptor (ImplementationType_).
// GroupedConvBase defaults it to Implicit.
// NOTE(review): the Explicit* enumerators look like explicit-GEMM variants that differ
// in which of the M/N/K dimensions are padded - confirm.
enum class ImplementationType
|
||||
{
|
||||
ExplicitDefault,
|
||||
ExplicitMPadding,
|
||||
ExplicitNPadding,
|
||||
ExplicitKPadding,
|
||||
ExplicitMNPadding,
|
||||
ExplicitMKPadding,
|
||||
ExplicitNKPadding,
|
||||
ExplicitMNKPadding,
|
||||
Implicit
|
||||
};
|
||||
|
||||
// Elementwise operation named by a kernel descriptor (ElementwiseOperation_).
// GroupedConvBase defaults it to PassThrough.
enum class ElementwiseOperation {
|
||||
Bias,
|
||||
BiasClamp,
|
||||
Bilinear,
|
||||
Clamp,
|
||||
Scale,
|
||||
PassThrough
|
||||
};
|
||||
|
||||
|
||||
// Concept for version-1 kernel descriptors. T must expose every static member listed
// below, each convertible to the stated enum type, and T::KernelDescriptorVersion_ must
// compare equal to KernelDescriptorVersion::V1.
template <typename T>
|
||||
concept KernelDescriptorV1 = requires {
|
||||
{T::KernelDescriptorVersion_} -> std::convertible_to<KernelDescriptorVersion>;
|
||||
{T::GemmImplementationType_} -> std::convertible_to<GemmImplementationType>;
|
||||
{T::ConvolutionDirection_} -> std::convertible_to<ConvolutionDirection>;
|
||||
{T::GemmPipelineVersion_} -> std::convertible_to<const GemmPipelineVersion>;
|
||||
{T::GemmPipelineScheduler_} -> std::convertible_to<const GemmPipelineScheduler>;
|
||||
{T::SplitKSupport_} -> std::convertible_to<const SplitKSupport>;
|
||||
{T::MergedGroups_} -> std::convertible_to<const MergedGroups>;
|
||||
{T::LargeTensorSupport_} -> std::convertible_to<const LargeTensorSupport>;
|
||||
{T::ImplementationType_} -> std::convertible_to<const ImplementationType>;
|
||||
{T::ElementwiseOperation_} -> std::convertible_to<const ElementwiseOperation>;
|
||||
} && (T::KernelDescriptorVersion_ == KernelDescriptorVersion::V1);
|
||||
|
||||
// Default values for the optional kernel-descriptor fields. This struct does not define
// KernelDescriptorVersion_, GemmImplementationType_ or ConvolutionDirection_, so it does
// not satisfy KernelDescriptorV1 on its own; derived descriptors add those and may
// redeclare any default below to override it.
struct GroupedConvBase {
|
||||
static constexpr GemmPipelineVersion GemmPipelineVersion_ = GemmPipelineVersion::V1;
|
||||
static constexpr GemmPipelineScheduler GemmPipelineScheduler_ = GemmPipelineScheduler::Intrawave;
|
||||
static constexpr SplitKSupport SplitKSupport_ = SplitKSupport::NotSupported;
|
||||
static constexpr MergedGroups MergedGroups_ = MergedGroups::NotSupported;
|
||||
static constexpr LargeTensorSupport LargeTensorSupport_ = LargeTensorSupport::NotSupported;
|
||||
static constexpr ImplementationType ImplementationType_ = ImplementationType::Implicit;
|
||||
static constexpr ElementwiseOperation ElementwiseOperation_ = ElementwiseOperation::PassThrough;
|
||||
};
|
||||
|
||||
// GroupedConvBase defaults plus a fixed GEMM implementation type of XDL.
struct GroupedConvBaseXdl : public GroupedConvBase {
|
||||
static constexpr GemmImplementationType GemmImplementationType_ = GemmImplementationType::XDL;
|
||||
};
|
||||
|
||||
// GroupedConvBaseXdl tagged as a version-1 kernel descriptor. ConvolutionDirection_ is
// still undefined here, so a concrete descriptor must add it to satisfy KernelDescriptorV1.
struct GroupedConvBaseXdlV1 : public GroupedConvBaseXdl {
|
||||
static constexpr KernelDescriptorVersion KernelDescriptorVersion_ = KernelDescriptorVersion::V1;
|
||||
};
|
||||
|
||||
@@ -1,116 +1,58 @@
|
||||
#pragma once
|
||||
#include <concepts>
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
// Version tag of the problem-descriptor interface. A descriptor publishes it as
// ProblemDescriptorVersion_, and the ProblemDescriptorV1 concept requires the value V1.
enum class ProblemDescriptorVersion
|
||||
{
|
||||
V1
|
||||
};
|
||||
|
||||
enum class GemmImplementationType
|
||||
{
|
||||
XDL,
|
||||
WMMA,
|
||||
DL
|
||||
// Tensor layout combination named by a problem descriptor (ConvolutionLayout_).
// NOTE(review): each enumerator name appears to list the input, weight and output
// layouts in that order, separated by underscores - confirm.
enum class ConvolutionLayout {
|
||||
NHWGC_GKYXC_NHWGK,
|
||||
NGCHW_GKCYX_NGKHW
|
||||
};
|
||||
|
||||
enum class ConvolutionDirection
|
||||
{
|
||||
Forward,
|
||||
BackwardData,
|
||||
BackwardWeight
|
||||
};
|
||||
|
||||
|
||||
enum class GemmPipelineVersion
|
||||
{
|
||||
V1,
|
||||
V2,
|
||||
V3,
|
||||
V4,
|
||||
V5
|
||||
};
|
||||
|
||||
enum class GemmPipelineScheduler
|
||||
{
|
||||
Intrawave,
|
||||
Interwave
|
||||
};
|
||||
|
||||
enum class SplitKSupport
|
||||
{
|
||||
Supported,
|
||||
SupportedTwoStage,
|
||||
NotSupported
|
||||
};
|
||||
|
||||
enum class MergedGroups
|
||||
{
|
||||
X16,
|
||||
X8,
|
||||
X4,
|
||||
X2,
|
||||
NotSupported
|
||||
};
|
||||
|
||||
enum class LargeTensorSupport
|
||||
{
|
||||
Supported,
|
||||
SplitBatch,
|
||||
NotSupported
|
||||
};
|
||||
|
||||
enum class ImplementationType
|
||||
{
|
||||
ExplicitDefault,
|
||||
ExplicitMPadding,
|
||||
ExplicitNPadding,
|
||||
ExplicitKPadding,
|
||||
ExplicitMNPadding,
|
||||
ExplicitMKPadding,
|
||||
ExplicitNKPadding,
|
||||
ExplicitMNKPadding,
|
||||
Implicit
|
||||
};
|
||||
|
||||
enum class ElementwiseOperation {
|
||||
Bias,
|
||||
BiasClamp,
|
||||
Bilinear,
|
||||
Clamp,
|
||||
Scale,
|
||||
PassThrough
|
||||
};
|
||||
|
||||
|
||||
template <typename T>
|
||||
concept ProblemDescriptorV1 = requires {
|
||||
{T::ProblemDescriptorVersion_} -> std::convertible_to<ProblemDescriptorVersion>;
|
||||
{T::GemmImplementationType_} -> std::convertible_to<GemmImplementationType>;
|
||||
{T::ConvolutionDirection_} -> std::convertible_to<ConvolutionDirection>;
|
||||
{T::GemmPipelineVersion_} -> std::convertible_to<const GemmPipelineVersion>;
|
||||
{T::GemmPipelineScheduler_} -> std::convertible_to<const GemmPipelineScheduler>;
|
||||
{T::SplitKSupport_} -> std::convertible_to<const SplitKSupport>;
|
||||
{T::MergedGroups_} -> std::convertible_to<const MergedGroups>;
|
||||
{T::LargeTensorSupport_} -> std::convertible_to<const LargeTensorSupport>;
|
||||
{T::ImplementationType_} -> std::convertible_to<const ImplementationType>;
|
||||
{T::ElementwiseOperation_} -> std::convertible_to<const ElementwiseOperation>;
|
||||
{T::NDimSpatial_} -> std::convertible_to<int>;
|
||||
typename T::DataType;
|
||||
typename T::ElementwiseOpDataTypes;
|
||||
{T::ConvolutionLayout_} -> std::convertible_to<ConvolutionLayout>;
|
||||
} && (T::ProblemDescriptorVersion_ == ProblemDescriptorVersion::V1);
|
||||
|
||||
struct GroupedConvBase {
|
||||
static constexpr GemmPipelineVersion GemmPipelineVersion_ = GemmPipelineVersion::V1;
|
||||
static constexpr GemmPipelineScheduler GemmPipelineScheduler_ = GemmPipelineScheduler::Intrawave;
|
||||
static constexpr SplitKSupport SplitKSupport_ = SplitKSupport::NotSupported;
|
||||
static constexpr MergedGroups MergedGroups_ = MergedGroups::NotSupported;
|
||||
static constexpr LargeTensorSupport LargeTensorSupport_ = LargeTensorSupport::NotSupported;
|
||||
static constexpr ImplementationType ImplementationType_ = ImplementationType::Implicit;
|
||||
static constexpr ElementwiseOperation ElementwiseOperation_ = ElementwiseOperation::PassThrough;
|
||||
};
|
||||
|
||||
struct GroupedConvBaseXdl : public GroupedConvBase {
|
||||
static constexpr GemmImplementationType GemmImplementationType_ = GemmImplementationType::XDL;
|
||||
};
|
||||
|
||||
struct GroupedConvBaseXdlV1 : public GroupedConvBaseXdl {
|
||||
// Common base of the version-1 problem-descriptor mix-ins: tags the descriptor as V1 and
// sets ElementwiseOpDataTypes to an empty ck::Tuple.
struct ProblemBaseV1 {
|
||||
static constexpr ProblemDescriptorVersion ProblemDescriptorVersion_ = ProblemDescriptorVersion::V1;
|
||||
using ElementwiseOpDataTypes = ck::Tuple<>;
|
||||
};
|
||||
|
||||
// Data-type mix-in: sets DataType to ck::bhalf_t.
struct BF16ProblemBaseV1 : public ProblemBaseV1 {
|
||||
using DataType = ck::bhalf_t;
|
||||
};
|
||||
|
||||
// Data-type mix-in: sets DataType to float.
struct F32ProblemBaseV1 : public ProblemBaseV1 {
|
||||
using DataType = float;
|
||||
};
|
||||
|
||||
// Data-type mix-in: sets DataType to ck::half_t.
struct F16ProblemBaseV1 : public ProblemBaseV1 {
|
||||
using DataType = ck::half_t;
|
||||
};
|
||||
|
||||
// Layout mix-in for one spatial dimension.
// NOTE(review): ConvolutionLayout_ uses the NHWGC_GKYXC_NHWGK enumerator although
// NDimSpatial_ is 1; ConvolutionLayout declares no 1D-named enumerator - confirm intent.
struct NWGCProblemBaseV1 : public ProblemBaseV1 {
|
||||
static constexpr int NDimSpatial_ = 1;
|
||||
static constexpr ConvolutionLayout ConvolutionLayout_ = ConvolutionLayout::NHWGC_GKYXC_NHWGK;
|
||||
};
|
||||
|
||||
// Layout mix-in for two spatial dimensions with the NHWGC_GKYXC_NHWGK layout.
struct NHWGCProblemBaseV1 : public ProblemBaseV1 {
|
||||
static constexpr int NDimSpatial_ = 2;
|
||||
static constexpr ConvolutionLayout ConvolutionLayout_ = ConvolutionLayout::NHWGC_GKYXC_NHWGK;
|
||||
};
|
||||
|
||||
// Layout mix-in for three spatial dimensions.
// NOTE(review): ConvolutionLayout_ uses the NHWGC_GKYXC_NHWGK enumerator although
// NDimSpatial_ is 3; ConvolutionLayout declares no 3D-named enumerator - confirm intent.
struct NDHWGCProblemBaseV1 : public ProblemBaseV1 {
|
||||
static constexpr int NDimSpatial_ = 3;
|
||||
static constexpr ConvolutionLayout ConvolutionLayout_ = ConvolutionLayout::NHWGC_GKYXC_NHWGK;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user