mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Clean up conv_traits.hpp (#3354)
When I asked for a description of operators that didn't have ConvTraits, I was getting very long confusing errors about ConvTraits not being defined. Now we get specific errors explaining which concepts are violated, making it easier to know which code to generalize or update. * Add concepts to conv_traits.hpp to get better error message. * Put the correct requires clauses in the right places to get descriptive error messages. * General cleanup of functions in conv_traits.hpp to make functions easier to read.
This commit is contained in:
@@ -251,14 +251,10 @@ class ConvDescription : public Description
|
||||
};
|
||||
} // namespace conv
|
||||
|
||||
/// @brief Helper concept to detect if a type has ConvTraits specialization
|
||||
template <typename T>
|
||||
concept HasConvTraits = requires { typename conv::ConvTraits<T>; };
|
||||
|
||||
/// @brief Factory function to create ConvDescription from a convolution instance type
|
||||
/// @tparam Instance The convolution instance type (must have InstanceTraits specialization)
|
||||
/// @tparam Instance The convolution instance type (must have ConvTraits specialization)
|
||||
/// @return A ConvDescription object populated with the instance's configuration details
|
||||
template <HasConvTraits Instance>
|
||||
template <conv::HasConvTraits Instance>
|
||||
conv::ConvDescription describe()
|
||||
{
|
||||
using Traits = conv::ConvTraits<Instance>;
|
||||
|
||||
@@ -21,6 +21,57 @@
|
||||
|
||||
namespace ck_tile::reflect::conv {
|
||||
|
||||
// Forward convolution layout concept - checks for A/B/E layout types
|
||||
template <typename T>
|
||||
concept HasFwdConvLayouts = requires {
|
||||
typename T::ALayout;
|
||||
typename T::BLayout;
|
||||
typename T::ELayout;
|
||||
};
|
||||
|
||||
// GEMM specialization concept - checks for kGemmSpecialization member
|
||||
template <typename T>
|
||||
concept HasGemmSpec = requires {
|
||||
{
|
||||
T::kGemmSpecialization
|
||||
} -> std::convertible_to<ck::tensor_operation::device::GemmSpecialization>;
|
||||
};
|
||||
|
||||
// Data types concept - checks for ADataType member
|
||||
template <typename T>
|
||||
concept HasDataTypes = requires { typename T::ADataType; };
|
||||
|
||||
// Elementwise operations concept - checks for A/B/CDE elementwise operation types
|
||||
template <typename T>
|
||||
concept HasElementwiseOps = requires {
|
||||
typename T::AElementwiseOperation;
|
||||
typename T::BElementwiseOperation;
|
||||
typename T::CDEElementwiseOperation;
|
||||
};
|
||||
|
||||
// Tile parameters concept - checks for tile dimension and transfer members
|
||||
template <typename T>
|
||||
concept HasTileParams = requires {
|
||||
{ T::kKPerBlock } -> std::convertible_to<int>;
|
||||
{ T::kMPerBlock } -> std::convertible_to<int>;
|
||||
{ T::kNPerBlock } -> std::convertible_to<int>;
|
||||
{ T::kAK1 } -> std::convertible_to<int>;
|
||||
{ T::kBK1 } -> std::convertible_to<int>;
|
||||
T::kCThreadClusterLengths;
|
||||
};
|
||||
|
||||
// Comprehensive concept that checks if an instance has all XDL forward convolution traits
|
||||
// This concept is used to constrain ConvTraits specialization that expect XDL forward convolutions
|
||||
template <typename T>
|
||||
concept IsXdlFwdConv = HasFwdConvLayouts<T> && HasGemmSpec<T> && HasDataTypes<T> &&
|
||||
HasElementwiseOps<T> && HasTileParams<T>;
|
||||
|
||||
// Primary concept for checking if a type can be described
|
||||
// Currently only forward convolutions are supported, but this can be extended
|
||||
// in the future to include backward data and backward weight convolutions
|
||||
template <typename T>
|
||||
concept HasConvTraits = IsXdlFwdConv<InstanceTraits<T>>;
|
||||
|
||||
// Helper metafunctions to convert from ck enums to builder enums
|
||||
|
||||
/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum.
|
||||
@@ -35,16 +86,15 @@ constexpr auto convert_pipeline_version()
|
||||
{
|
||||
using enum ck::BlockGemmPipelineVersion;
|
||||
using enum builder::PipelineVersion;
|
||||
if constexpr(ck_ver == v1)
|
||||
return V1;
|
||||
else if constexpr(ck_ver == v2)
|
||||
return V2;
|
||||
else if constexpr(ck_ver == v3)
|
||||
return V3;
|
||||
else if constexpr(ck_ver == v4)
|
||||
return V4;
|
||||
else if constexpr(ck_ver == v5)
|
||||
return V5;
|
||||
|
||||
switch(ck_ver)
|
||||
{
|
||||
case v1: return V1;
|
||||
case v2: return V2;
|
||||
case v3: return V3;
|
||||
case v4: return V4;
|
||||
case v5: return V5;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum.
|
||||
@@ -59,14 +109,14 @@ constexpr auto convert_pipeline_version()
|
||||
{
|
||||
using enum ck::PipelineVersion;
|
||||
using enum builder::PipelineVersion;
|
||||
if constexpr(ck_ver == v1)
|
||||
return V1;
|
||||
else if constexpr(ck_ver == v2)
|
||||
return V2;
|
||||
else if constexpr(ck_ver == v4)
|
||||
return V4;
|
||||
else if constexpr(ck_ver == weight_only)
|
||||
return WEIGHT_ONLY;
|
||||
|
||||
switch(ck_ver)
|
||||
{
|
||||
case v1: return V1;
|
||||
case v2: return V2;
|
||||
case v4: return V4;
|
||||
case weight_only: return WEIGHT_ONLY;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum.
|
||||
@@ -82,10 +132,12 @@ constexpr auto convert_pipeline_scheduler()
|
||||
{
|
||||
using enum ck::BlockGemmPipelineScheduler;
|
||||
using enum builder::PipelineScheduler;
|
||||
if constexpr(ck_sched == Intrawave)
|
||||
return INTRAWAVE;
|
||||
else if constexpr(ck_sched == Interwave)
|
||||
return INTERWAVE;
|
||||
|
||||
switch(ck_sched)
|
||||
{
|
||||
case Intrawave: return INTRAWAVE;
|
||||
case Interwave: return INTERWAVE;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum.
|
||||
@@ -101,10 +153,12 @@ constexpr auto convert_pipeline_scheduler()
|
||||
{
|
||||
using enum ck::LoopScheduler;
|
||||
using enum builder::PipelineScheduler;
|
||||
if constexpr(ck_sched == Default)
|
||||
return DEFAULT;
|
||||
else if constexpr(ck_sched == Interwave)
|
||||
return INTERWAVE;
|
||||
|
||||
switch(ck_sched)
|
||||
{
|
||||
case Default: return DEFAULT;
|
||||
case Interwave: return INTERWAVE;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Helper structures for organizing trait data with domain-specific naming
|
||||
@@ -213,21 +267,13 @@ constexpr builder::ConvDirection conv_direction()
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
if constexpr(requires { &InstTraits::kConvForwardSpecialization; })
|
||||
{
|
||||
return builder::ConvDirection::FORWARD;
|
||||
}
|
||||
else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; })
|
||||
{
|
||||
return builder::ConvDirection::BACKWARD_DATA;
|
||||
}
|
||||
else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; })
|
||||
{
|
||||
return builder::ConvDirection::BACKWARD_WEIGHT;
|
||||
}
|
||||
else
|
||||
{
|
||||
return builder::ConvDirection::FORWARD; // Default fallback
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Derives the convolution-specific specialization from a device kernel `Instance` type.
|
||||
@@ -242,60 +288,52 @@ constexpr auto conv_spec()
|
||||
if constexpr(requires { InstTraits::kConvForwardSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
using enum builder::ConvFwdSpecialization;
|
||||
|
||||
if constexpr(InstTraits::kConvForwardSpecialization == Default)
|
||||
switch(InstTraits::kConvForwardSpecialization)
|
||||
{
|
||||
return builder::ConvFwdSpecialization::DEFAULT;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Pad0)
|
||||
{
|
||||
return builder::ConvFwdSpecialization::FILTER_1X1_PAD0;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Stride1Pad0)
|
||||
{
|
||||
return builder::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvForwardSpecialization == Filter3x3)
|
||||
{
|
||||
return builder::ConvFwdSpecialization::FILTER_3x3;
|
||||
case Default: return DEFAULT;
|
||||
case Filter1x1Pad0: return FILTER_1X1_PAD0;
|
||||
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
|
||||
case Filter3x3: return FILTER_3x3;
|
||||
}
|
||||
}
|
||||
else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
|
||||
using enum builder::ConvBwdDataSpecialization;
|
||||
|
||||
if constexpr(InstTraits::kConvBwdDataSpecialization == Default)
|
||||
switch(InstTraits::kConvBwdDataSpecialization)
|
||||
{
|
||||
return builder::ConvBwdDataSpecialization::DEFAULT;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvBwdDataSpecialization == Filter1x1Stride1Pad0)
|
||||
{
|
||||
return builder::ConvBwdDataSpecialization::FILTER_1X1_STRIDE1_PAD0;
|
||||
case Default: return DEFAULT;
|
||||
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
|
||||
}
|
||||
}
|
||||
else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
|
||||
using enum builder::ConvBwdWeightSpecialization;
|
||||
|
||||
if constexpr(InstTraits::kConvBwdWeightSpecialization == Default)
|
||||
switch(InstTraits::kConvBwdWeightSpecialization)
|
||||
{
|
||||
return builder::ConvBwdWeightSpecialization::DEFAULT;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Stride1Pad0)
|
||||
{
|
||||
return builder::ConvBwdWeightSpecialization::FILTER_1X1_STRIDE1_PAD0;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Pad0)
|
||||
{
|
||||
return builder::ConvBwdWeightSpecialization::FILTER_1X1_PAD0;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvBwdWeightSpecialization == OddC)
|
||||
{
|
||||
return builder::ConvBwdWeightSpecialization::ODD_C;
|
||||
case Default: return DEFAULT;
|
||||
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
|
||||
case Filter1x1Pad0: return FILTER_1X1_PAD0;
|
||||
case OddC: return ODD_C;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper variable template to check if CK layout enums match
|
||||
template <typename A,
|
||||
typename B,
|
||||
typename E,
|
||||
typename ExpectedA,
|
||||
typename ExpectedB,
|
||||
typename ExpectedE>
|
||||
inline constexpr bool layouts_are =
|
||||
std::is_same_v<A, ExpectedA> && std::is_same_v<B, ExpectedB> && std::is_same_v<E, ExpectedE>;
|
||||
|
||||
/// @brief Derives the grouped convolution layout from a device kernel `Instance` type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return An std::array corresponding to the tensor layouts:
|
||||
@@ -304,112 +342,49 @@ constexpr auto conv_spec()
|
||||
/// index 2 -> Output layout
|
||||
template <typename Instance>
|
||||
constexpr auto conv_layout()
|
||||
requires HasFwdConvLayouts<InstanceTraits<Instance>>
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
using ALayout = typename InstTraits::ALayout;
|
||||
using BLayout = typename InstTraits::BLayout;
|
||||
using ELayout = typename InstTraits::ELayout;
|
||||
// Helper lambda to construct layout array
|
||||
auto layouts = [](auto... Ls) { return std::array<builder::TensorLayout, 3>{Ls...}; };
|
||||
|
||||
namespace ctc = ck::tensor_layout::convolution;
|
||||
using A = typename InstanceTraits<Instance>::ALayout;
|
||||
using B = typename InstanceTraits<Instance>::BLayout;
|
||||
using E = typename InstanceTraits<Instance>::ELayout;
|
||||
namespace ctl = ck::tensor_layout::convolution;
|
||||
using enum builder::TensorLayout;
|
||||
|
||||
if constexpr(InstTraits::kSpatialDim == 1)
|
||||
switch(InstanceTraits<Instance>::kSpatialDim)
|
||||
{
|
||||
if constexpr(std::is_same_v<ALayout, ctc::GNWC> && std::is_same_v<BLayout, ctc::GKXC> &&
|
||||
std::is_same_v<ELayout, ctc::GNWK>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::GNWC,
|
||||
builder::TensorLayout::GKXC,
|
||||
builder::TensorLayout::GNWK};
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NWGC> &&
|
||||
std::is_same_v<BLayout, ctc::GKXC> && std::is_same_v<ELayout, ctc::NWGK>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NWGC,
|
||||
builder::TensorLayout::GKXC,
|
||||
builder::TensorLayout::NWGK};
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCW> &&
|
||||
std::is_same_v<BLayout, ctc::GKXC> && std::is_same_v<ELayout, ctc::NGKW>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCW,
|
||||
builder::TensorLayout::GKXC,
|
||||
builder::TensorLayout::NGKW};
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCW> &&
|
||||
std::is_same_v<BLayout, ctc::GKCX> && std::is_same_v<ELayout, ctc::NGKW>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCW,
|
||||
builder::TensorLayout::GKCX,
|
||||
builder::TensorLayout::NGKW};
|
||||
}
|
||||
}
|
||||
else if constexpr(InstTraits::kSpatialDim == 2)
|
||||
{
|
||||
if constexpr(std::is_same_v<ALayout, ctc::GNHWC> && std::is_same_v<BLayout, ctc::GKYXC> &&
|
||||
std::is_same_v<ELayout, ctc::GNHWK>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::GNHWC,
|
||||
builder::TensorLayout::GKYXC,
|
||||
builder::TensorLayout::GNHWK};
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NHWGC> &&
|
||||
std::is_same_v<BLayout, ctc::GKYXC> &&
|
||||
std::is_same_v<ELayout, ctc::NHWGK>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NHWGC,
|
||||
builder::TensorLayout::GKYXC,
|
||||
builder::TensorLayout::NHWGK};
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCHW> &&
|
||||
std::is_same_v<BLayout, ctc::GKYXC> &&
|
||||
std::is_same_v<ELayout, ctc::NGKHW>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCHW,
|
||||
builder::TensorLayout::GKYXC,
|
||||
builder::TensorLayout::NGKHW};
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCHW> &&
|
||||
std::is_same_v<BLayout, ctc::GKCYX> &&
|
||||
std::is_same_v<ELayout, ctc::NGKHW>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCHW,
|
||||
builder::TensorLayout::GKCYX,
|
||||
builder::TensorLayout::NGKHW};
|
||||
}
|
||||
}
|
||||
else if constexpr(InstTraits::kSpatialDim == 3)
|
||||
{
|
||||
if constexpr(std::is_same_v<ALayout, ctc::GNDHWC> && std::is_same_v<BLayout, ctc::GKZYXC> &&
|
||||
std::is_same_v<ELayout, ctc::GNDHWK>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::GNDHWC,
|
||||
builder::TensorLayout::GKZYXC,
|
||||
builder::TensorLayout::GNDHWK};
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NDHWGC> &&
|
||||
std::is_same_v<BLayout, ctc::GKZYXC> &&
|
||||
std::is_same_v<ELayout, ctc::NDHWGK>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NDHWGC,
|
||||
builder::TensorLayout::GKZYXC,
|
||||
builder::TensorLayout::NDHWGK};
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCDHW> &&
|
||||
std::is_same_v<BLayout, ctc::GKZYXC> &&
|
||||
std::is_same_v<ELayout, ctc::NGKDHW>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCDHW,
|
||||
builder::TensorLayout::GKZYXC,
|
||||
builder::TensorLayout::NGKDHW};
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCDHW> &&
|
||||
std::is_same_v<BLayout, ctc::GKCZYX> &&
|
||||
std::is_same_v<ELayout, ctc::NGKDHW>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCDHW,
|
||||
builder::TensorLayout::GKCZYX,
|
||||
builder::TensorLayout::NGKDHW};
|
||||
}
|
||||
case 1:
|
||||
if constexpr(layouts_are<A, B, E, ctl::GNWC, ctl::GKXC, ctl::GNWK>)
|
||||
return layouts(GNWC, GKXC, GNWK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NWGC, ctl::GKXC, ctl::NWGK>)
|
||||
return layouts(NWGC, GKXC, NWGK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCW, ctl::GKXC, ctl::NGKW>)
|
||||
return layouts(NGCW, GKXC, NGKW);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCW, ctl::GKCX, ctl::NGKW>)
|
||||
return layouts(NGCW, GKCX, NGKW);
|
||||
break;
|
||||
case 2:
|
||||
if constexpr(layouts_are<A, B, E, ctl::GNHWC, ctl::GKYXC, ctl::GNHWK>)
|
||||
return layouts(GNHWC, GKYXC, GNHWK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NHWGC, ctl::GKYXC, ctl::NHWGK>)
|
||||
return layouts(NHWGC, GKYXC, NHWGK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCHW, ctl::GKYXC, ctl::NGKHW>)
|
||||
return layouts(NGCHW, GKYXC, NGKHW);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCHW, ctl::GKCYX, ctl::NGKHW>)
|
||||
return layouts(NGCHW, GKCYX, NGKHW);
|
||||
break;
|
||||
case 3:
|
||||
if constexpr(layouts_are<A, B, E, ctl::GNDHWC, ctl::GKZYXC, ctl::GNDHWK>)
|
||||
return layouts(GNDHWC, GKZYXC, GNDHWK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NDHWGC, ctl::GKZYXC, ctl::NDHWGK>)
|
||||
return layouts(NDHWGC, GKZYXC, NDHWGK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCDHW, ctl::GKZYXC, ctl::NGKDHW>)
|
||||
return layouts(NGCDHW, GKZYXC, NGKDHW);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCDHW, ctl::GKCZYX, ctl::NGKDHW>)
|
||||
return layouts(NGCDHW, GKCZYX, NGKDHW);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -418,39 +393,26 @@ constexpr auto conv_layout()
|
||||
/// @return A `builder::DataType` enum value (e.g., FP16, BF16, FP32).
|
||||
template <typename Instance>
|
||||
constexpr builder::DataType conv_data_type()
|
||||
requires HasDataTypes<InstanceTraits<Instance>>
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
using ADataType = typename InstTraits::ADataType;
|
||||
using enum builder::DataType;
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, ck::half_t>)
|
||||
{
|
||||
return builder::DataType::FP16;
|
||||
}
|
||||
return FP16;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::bhalf_t>)
|
||||
{
|
||||
return builder::DataType::BF16;
|
||||
}
|
||||
return BF16;
|
||||
else if constexpr(std::is_same_v<ADataType, float>)
|
||||
{
|
||||
return builder::DataType::FP32;
|
||||
}
|
||||
return FP32;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::f8_t>)
|
||||
{
|
||||
return builder::DataType::FP8;
|
||||
}
|
||||
return FP8;
|
||||
else if constexpr(std::is_same_v<ADataType, int8_t>)
|
||||
{
|
||||
return builder::DataType::I8;
|
||||
}
|
||||
return I8;
|
||||
else if constexpr(std::is_same_v<ADataType, uint8_t>)
|
||||
{
|
||||
return builder::DataType::U8;
|
||||
}
|
||||
return U8;
|
||||
else
|
||||
{
|
||||
// Default fallback
|
||||
return builder::DataType::FP32;
|
||||
}
|
||||
return FP32; // Default fallback
|
||||
}
|
||||
|
||||
/// @brief Derives the elementwise operation from op type.
|
||||
@@ -459,27 +421,19 @@ constexpr builder::DataType conv_data_type()
|
||||
template <typename ElementwiseOp>
|
||||
constexpr builder::ElementwiseOperation elementwise_op()
|
||||
{
|
||||
using enum builder::ElementwiseOperation;
|
||||
constexpr std::string_view name = detail::elementwise_op_name<ElementwiseOp>();
|
||||
|
||||
if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp"))
|
||||
{
|
||||
return builder::ElementwiseOperation::BIAS_BNORM_CLAMP;
|
||||
}
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Clamp"))
|
||||
{
|
||||
return builder::ElementwiseOperation::CLAMP;
|
||||
}
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Scale"))
|
||||
{
|
||||
return builder::ElementwiseOperation::SCALE;
|
||||
}
|
||||
else if constexpr(detail::case_insensitive_equal(name, "PassThrough"))
|
||||
{
|
||||
return builder::ElementwiseOperation::PASS_THROUGH;
|
||||
}
|
||||
else if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu"))
|
||||
{
|
||||
return builder::ElementwiseOperation::SCALEADD_SCALEADD_RELU;
|
||||
}
|
||||
return BIAS_BNORM_CLAMP;
|
||||
if constexpr(detail::case_insensitive_equal(name, "Clamp"))
|
||||
return CLAMP;
|
||||
if constexpr(detail::case_insensitive_equal(name, "Scale"))
|
||||
return SCALE;
|
||||
if constexpr(detail::case_insensitive_equal(name, "PassThrough"))
|
||||
return PASS_THROUGH;
|
||||
if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu"))
|
||||
return SCALEADD_SCALEADD_RELU;
|
||||
}
|
||||
|
||||
/// @brief Derives a gemm padding from a kernel instance type.
|
||||
@@ -487,6 +441,7 @@ constexpr builder::ElementwiseOperation elementwise_op()
|
||||
/// @return A `builder::GemmPadding` enum value corresponding to kernel padding.
|
||||
template <typename Instance>
|
||||
constexpr builder::GemmPadding gemm_spec()
|
||||
requires HasGemmSpec<InstanceTraits<Instance>>
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
using enum builder::GemmPadding;
|
||||
@@ -494,69 +449,24 @@ constexpr builder::GemmPadding gemm_spec()
|
||||
|
||||
constexpr auto gemm_spec = InstTraits::kGemmSpecialization;
|
||||
|
||||
if constexpr(gemm_spec == Default)
|
||||
switch(gemm_spec)
|
||||
{
|
||||
return DEFAULT;
|
||||
}
|
||||
else if constexpr(gemm_spec == MPadding)
|
||||
{
|
||||
return M_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == NPadding)
|
||||
{
|
||||
return N_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == KPadding)
|
||||
{
|
||||
return K_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MNPadding)
|
||||
{
|
||||
return MN_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MKPadding)
|
||||
{
|
||||
return MK_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == NKPadding)
|
||||
{
|
||||
return NK_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MNKPadding)
|
||||
{
|
||||
return MNK_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == OPadding)
|
||||
{
|
||||
return O_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MOPadding)
|
||||
{
|
||||
return MO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == NOPadding)
|
||||
{
|
||||
return NO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == KOPadding)
|
||||
{
|
||||
return KO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MNOPadding)
|
||||
{
|
||||
return MNO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MKOPadding)
|
||||
{
|
||||
return MKO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == NKOPadding)
|
||||
{
|
||||
return NKO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MNKOPadding)
|
||||
{
|
||||
return MNKO_PADDING;
|
||||
case Default: return DEFAULT;
|
||||
case MPadding: return M_PADDING;
|
||||
case NPadding: return N_PADDING;
|
||||
case KPadding: return K_PADDING;
|
||||
case MNPadding: return MN_PADDING;
|
||||
case MKPadding: return MK_PADDING;
|
||||
case NKPadding: return NK_PADDING;
|
||||
case MNKPadding: return MNK_PADDING;
|
||||
case OPadding: return O_PADDING;
|
||||
case MOPadding: return MO_PADDING;
|
||||
case NOPadding: return NO_PADDING;
|
||||
case KOPadding: return KO_PADDING;
|
||||
case MNOPadding: return MNO_PADDING;
|
||||
case MKOPadding: return MKO_PADDING;
|
||||
case NKOPadding: return NKO_PADDING;
|
||||
case MNKOPadding: return MNKO_PADDING;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -571,6 +481,7 @@ struct ConvTraits;
|
||||
/// set of traits directly from a fully-formed device kernel `Instance` type.
|
||||
/// It uses `InstanceTraits` to access the kernel's template parameters.
|
||||
template <HasInstanceTraits Instance>
|
||||
requires IsXdlFwdConv<InstanceTraits<Instance>>
|
||||
struct ConvTraits<Instance>
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
Reference in New Issue
Block a user