mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Merge commit '13f6d635653bd5ffbfcac8577f1ef09590c23d78' into develop
This commit is contained in:
@@ -251,14 +251,10 @@ class ConvDescription : public Description
|
||||
};
|
||||
} // namespace conv
|
||||
|
||||
/// @brief Helper concept to detect if a type has ConvTraits specialization
|
||||
template <typename T>
|
||||
concept HasConvTraits = requires { typename conv::ConvTraits<T>; };
|
||||
|
||||
/// @brief Factory function to create ConvDescription from a convolution instance type
|
||||
/// @tparam Instance The convolution instance type (must have InstanceTraits specialization)
|
||||
/// @tparam Instance The convolution instance type (must have ConvTraits specialization)
|
||||
/// @return A ConvDescription object populated with the instance's configuration details
|
||||
template <HasConvTraits Instance>
|
||||
template <conv::HasConvTraits Instance>
|
||||
conv::ConvDescription describe()
|
||||
{
|
||||
using Traits = conv::ConvTraits<Instance>;
|
||||
|
||||
@@ -21,6 +21,57 @@
|
||||
|
||||
namespace ck_tile::reflect::conv {
|
||||
|
||||
// Forward convolution layout concept - checks for A/B/E layout types
|
||||
template <typename T>
|
||||
concept HasFwdConvLayouts = requires {
|
||||
typename T::ALayout;
|
||||
typename T::BLayout;
|
||||
typename T::ELayout;
|
||||
};
|
||||
|
||||
// GEMM specialization concept - checks for kGemmSpecialization member
|
||||
template <typename T>
|
||||
concept HasGemmSpec = requires {
|
||||
{
|
||||
T::kGemmSpecialization
|
||||
} -> std::convertible_to<ck::tensor_operation::device::GemmSpecialization>;
|
||||
};
|
||||
|
||||
// Data types concept - checks for ADataType member
|
||||
template <typename T>
|
||||
concept HasDataTypes = requires { typename T::ADataType; };
|
||||
|
||||
// Elementwise operations concept - checks for A/B/CDE elementwise operation types
|
||||
template <typename T>
|
||||
concept HasElementwiseOps = requires {
|
||||
typename T::AElementwiseOperation;
|
||||
typename T::BElementwiseOperation;
|
||||
typename T::CDEElementwiseOperation;
|
||||
};
|
||||
|
||||
// Tile parameters concept - checks for tile dimension and transfer members
|
||||
template <typename T>
|
||||
concept HasTileParams = requires {
|
||||
{ T::kKPerBlock } -> std::convertible_to<int>;
|
||||
{ T::kMPerBlock } -> std::convertible_to<int>;
|
||||
{ T::kNPerBlock } -> std::convertible_to<int>;
|
||||
{ T::kAK1 } -> std::convertible_to<int>;
|
||||
{ T::kBK1 } -> std::convertible_to<int>;
|
||||
T::kCThreadClusterLengths;
|
||||
};
|
||||
|
||||
// Comprehensive concept that checks if an instance has all XDL forward convolution traits
|
||||
// This concept is used to constrain ConvTraits specialization that expect XDL forward convolutions
|
||||
template <typename T>
|
||||
concept IsXdlFwdConv = HasFwdConvLayouts<T> && HasGemmSpec<T> && HasDataTypes<T> &&
|
||||
HasElementwiseOps<T> && HasTileParams<T>;
|
||||
|
||||
// Primary concept for checking if a type can be described
|
||||
// Currently only forward convolutions are supported, but this can be extended
|
||||
// in the future to include backward data and backward weight convolutions
|
||||
template <typename T>
|
||||
concept HasConvTraits = IsXdlFwdConv<InstanceTraits<T>>;
|
||||
|
||||
// Helper metafunctions to convert from ck enums to builder enums
|
||||
|
||||
/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum.
|
||||
@@ -35,16 +86,15 @@ constexpr auto convert_pipeline_version()
|
||||
{
|
||||
using enum ck::BlockGemmPipelineVersion;
|
||||
using enum builder::PipelineVersion;
|
||||
if constexpr(ck_ver == v1)
|
||||
return V1;
|
||||
else if constexpr(ck_ver == v2)
|
||||
return V2;
|
||||
else if constexpr(ck_ver == v3)
|
||||
return V3;
|
||||
else if constexpr(ck_ver == v4)
|
||||
return V4;
|
||||
else if constexpr(ck_ver == v5)
|
||||
return V5;
|
||||
|
||||
switch(ck_ver)
|
||||
{
|
||||
case v1: return V1;
|
||||
case v2: return V2;
|
||||
case v3: return V3;
|
||||
case v4: return V4;
|
||||
case v5: return V5;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum.
|
||||
@@ -59,14 +109,14 @@ constexpr auto convert_pipeline_version()
|
||||
{
|
||||
using enum ck::PipelineVersion;
|
||||
using enum builder::PipelineVersion;
|
||||
if constexpr(ck_ver == v1)
|
||||
return V1;
|
||||
else if constexpr(ck_ver == v2)
|
||||
return V2;
|
||||
else if constexpr(ck_ver == v4)
|
||||
return V4;
|
||||
else if constexpr(ck_ver == weight_only)
|
||||
return WEIGHT_ONLY;
|
||||
|
||||
switch(ck_ver)
|
||||
{
|
||||
case v1: return V1;
|
||||
case v2: return V2;
|
||||
case v4: return V4;
|
||||
case weight_only: return WEIGHT_ONLY;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum.
|
||||
@@ -82,10 +132,12 @@ constexpr auto convert_pipeline_scheduler()
|
||||
{
|
||||
using enum ck::BlockGemmPipelineScheduler;
|
||||
using enum builder::PipelineScheduler;
|
||||
if constexpr(ck_sched == Intrawave)
|
||||
return INTRAWAVE;
|
||||
else if constexpr(ck_sched == Interwave)
|
||||
return INTERWAVE;
|
||||
|
||||
switch(ck_sched)
|
||||
{
|
||||
case Intrawave: return INTRAWAVE;
|
||||
case Interwave: return INTERWAVE;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum.
|
||||
@@ -101,10 +153,12 @@ constexpr auto convert_pipeline_scheduler()
|
||||
{
|
||||
using enum ck::LoopScheduler;
|
||||
using enum builder::PipelineScheduler;
|
||||
if constexpr(ck_sched == Default)
|
||||
return DEFAULT;
|
||||
else if constexpr(ck_sched == Interwave)
|
||||
return INTERWAVE;
|
||||
|
||||
switch(ck_sched)
|
||||
{
|
||||
case Default: return DEFAULT;
|
||||
case Interwave: return INTERWAVE;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Helper structures for organizing trait data with domain-specific naming
|
||||
@@ -213,21 +267,13 @@ constexpr builder::ConvDirection conv_direction()
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
if constexpr(requires { &InstTraits::kConvForwardSpecialization; })
|
||||
{
|
||||
return builder::ConvDirection::FORWARD;
|
||||
}
|
||||
else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; })
|
||||
{
|
||||
return builder::ConvDirection::BACKWARD_DATA;
|
||||
}
|
||||
else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; })
|
||||
{
|
||||
return builder::ConvDirection::BACKWARD_WEIGHT;
|
||||
}
|
||||
else
|
||||
{
|
||||
return builder::ConvDirection::FORWARD; // Default fallback
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Derives the convolution-specific specialization from a device kernel `Instance` type.
|
||||
@@ -242,60 +288,52 @@ constexpr auto conv_spec()
|
||||
if constexpr(requires { InstTraits::kConvForwardSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
using enum builder::ConvFwdSpecialization;
|
||||
|
||||
if constexpr(InstTraits::kConvForwardSpecialization == Default)
|
||||
switch(InstTraits::kConvForwardSpecialization)
|
||||
{
|
||||
return builder::ConvFwdSpecialization::DEFAULT;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Pad0)
|
||||
{
|
||||
return builder::ConvFwdSpecialization::FILTER_1X1_PAD0;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Stride1Pad0)
|
||||
{
|
||||
return builder::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvForwardSpecialization == Filter3x3)
|
||||
{
|
||||
return builder::ConvFwdSpecialization::FILTER_3x3;
|
||||
case Default: return DEFAULT;
|
||||
case Filter1x1Pad0: return FILTER_1X1_PAD0;
|
||||
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
|
||||
case Filter3x3: return FILTER_3x3;
|
||||
}
|
||||
}
|
||||
else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
|
||||
using enum builder::ConvBwdDataSpecialization;
|
||||
|
||||
if constexpr(InstTraits::kConvBwdDataSpecialization == Default)
|
||||
switch(InstTraits::kConvBwdDataSpecialization)
|
||||
{
|
||||
return builder::ConvBwdDataSpecialization::DEFAULT;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvBwdDataSpecialization == Filter1x1Stride1Pad0)
|
||||
{
|
||||
return builder::ConvBwdDataSpecialization::FILTER_1X1_STRIDE1_PAD0;
|
||||
case Default: return DEFAULT;
|
||||
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
|
||||
}
|
||||
}
|
||||
else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
|
||||
using enum builder::ConvBwdWeightSpecialization;
|
||||
|
||||
if constexpr(InstTraits::kConvBwdWeightSpecialization == Default)
|
||||
switch(InstTraits::kConvBwdWeightSpecialization)
|
||||
{
|
||||
return builder::ConvBwdWeightSpecialization::DEFAULT;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Stride1Pad0)
|
||||
{
|
||||
return builder::ConvBwdWeightSpecialization::FILTER_1X1_STRIDE1_PAD0;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Pad0)
|
||||
{
|
||||
return builder::ConvBwdWeightSpecialization::FILTER_1X1_PAD0;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvBwdWeightSpecialization == OddC)
|
||||
{
|
||||
return builder::ConvBwdWeightSpecialization::ODD_C;
|
||||
case Default: return DEFAULT;
|
||||
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
|
||||
case Filter1x1Pad0: return FILTER_1X1_PAD0;
|
||||
case OddC: return ODD_C;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper variable template to check if CK layout enums match
|
||||
template <typename A,
|
||||
typename B,
|
||||
typename E,
|
||||
typename ExpectedA,
|
||||
typename ExpectedB,
|
||||
typename ExpectedE>
|
||||
inline constexpr bool layouts_are =
|
||||
std::is_same_v<A, ExpectedA> && std::is_same_v<B, ExpectedB> && std::is_same_v<E, ExpectedE>;
|
||||
|
||||
/// @brief Derives the grouped convolution layout from a device kernel `Instance` type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return An std::array corresponding to the tensor layouts:
|
||||
@@ -304,112 +342,49 @@ constexpr auto conv_spec()
|
||||
/// index 2 -> Output layout
|
||||
template <typename Instance>
|
||||
constexpr auto conv_layout()
|
||||
requires HasFwdConvLayouts<InstanceTraits<Instance>>
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
using ALayout = typename InstTraits::ALayout;
|
||||
using BLayout = typename InstTraits::BLayout;
|
||||
using ELayout = typename InstTraits::ELayout;
|
||||
// Helper lambda to construct layout array
|
||||
auto layouts = [](auto... Ls) { return std::array<builder::TensorLayout, 3>{Ls...}; };
|
||||
|
||||
namespace ctc = ck::tensor_layout::convolution;
|
||||
using A = typename InstanceTraits<Instance>::ALayout;
|
||||
using B = typename InstanceTraits<Instance>::BLayout;
|
||||
using E = typename InstanceTraits<Instance>::ELayout;
|
||||
namespace ctl = ck::tensor_layout::convolution;
|
||||
using enum builder::TensorLayout;
|
||||
|
||||
if constexpr(InstTraits::kSpatialDim == 1)
|
||||
switch(InstanceTraits<Instance>::kSpatialDim)
|
||||
{
|
||||
if constexpr(std::is_same_v<ALayout, ctc::GNWC> && std::is_same_v<BLayout, ctc::GKXC> &&
|
||||
std::is_same_v<ELayout, ctc::GNWK>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::GNWC,
|
||||
builder::TensorLayout::GKXC,
|
||||
builder::TensorLayout::GNWK};
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NWGC> &&
|
||||
std::is_same_v<BLayout, ctc::GKXC> && std::is_same_v<ELayout, ctc::NWGK>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NWGC,
|
||||
builder::TensorLayout::GKXC,
|
||||
builder::TensorLayout::NWGK};
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCW> &&
|
||||
std::is_same_v<BLayout, ctc::GKXC> && std::is_same_v<ELayout, ctc::NGKW>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCW,
|
||||
builder::TensorLayout::GKXC,
|
||||
builder::TensorLayout::NGKW};
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCW> &&
|
||||
std::is_same_v<BLayout, ctc::GKCX> && std::is_same_v<ELayout, ctc::NGKW>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCW,
|
||||
builder::TensorLayout::GKCX,
|
||||
builder::TensorLayout::NGKW};
|
||||
}
|
||||
}
|
||||
else if constexpr(InstTraits::kSpatialDim == 2)
|
||||
{
|
||||
if constexpr(std::is_same_v<ALayout, ctc::GNHWC> && std::is_same_v<BLayout, ctc::GKYXC> &&
|
||||
std::is_same_v<ELayout, ctc::GNHWK>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::GNHWC,
|
||||
builder::TensorLayout::GKYXC,
|
||||
builder::TensorLayout::GNHWK};
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NHWGC> &&
|
||||
std::is_same_v<BLayout, ctc::GKYXC> &&
|
||||
std::is_same_v<ELayout, ctc::NHWGK>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NHWGC,
|
||||
builder::TensorLayout::GKYXC,
|
||||
builder::TensorLayout::NHWGK};
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCHW> &&
|
||||
std::is_same_v<BLayout, ctc::GKYXC> &&
|
||||
std::is_same_v<ELayout, ctc::NGKHW>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCHW,
|
||||
builder::TensorLayout::GKYXC,
|
||||
builder::TensorLayout::NGKHW};
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCHW> &&
|
||||
std::is_same_v<BLayout, ctc::GKCYX> &&
|
||||
std::is_same_v<ELayout, ctc::NGKHW>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCHW,
|
||||
builder::TensorLayout::GKCYX,
|
||||
builder::TensorLayout::NGKHW};
|
||||
}
|
||||
}
|
||||
else if constexpr(InstTraits::kSpatialDim == 3)
|
||||
{
|
||||
if constexpr(std::is_same_v<ALayout, ctc::GNDHWC> && std::is_same_v<BLayout, ctc::GKZYXC> &&
|
||||
std::is_same_v<ELayout, ctc::GNDHWK>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::GNDHWC,
|
||||
builder::TensorLayout::GKZYXC,
|
||||
builder::TensorLayout::GNDHWK};
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NDHWGC> &&
|
||||
std::is_same_v<BLayout, ctc::GKZYXC> &&
|
||||
std::is_same_v<ELayout, ctc::NDHWGK>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NDHWGC,
|
||||
builder::TensorLayout::GKZYXC,
|
||||
builder::TensorLayout::NDHWGK};
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCDHW> &&
|
||||
std::is_same_v<BLayout, ctc::GKZYXC> &&
|
||||
std::is_same_v<ELayout, ctc::NGKDHW>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCDHW,
|
||||
builder::TensorLayout::GKZYXC,
|
||||
builder::TensorLayout::NGKDHW};
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCDHW> &&
|
||||
std::is_same_v<BLayout, ctc::GKCZYX> &&
|
||||
std::is_same_v<ELayout, ctc::NGKDHW>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCDHW,
|
||||
builder::TensorLayout::GKCZYX,
|
||||
builder::TensorLayout::NGKDHW};
|
||||
}
|
||||
case 1:
|
||||
if constexpr(layouts_are<A, B, E, ctl::GNWC, ctl::GKXC, ctl::GNWK>)
|
||||
return layouts(GNWC, GKXC, GNWK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NWGC, ctl::GKXC, ctl::NWGK>)
|
||||
return layouts(NWGC, GKXC, NWGK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCW, ctl::GKXC, ctl::NGKW>)
|
||||
return layouts(NGCW, GKXC, NGKW);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCW, ctl::GKCX, ctl::NGKW>)
|
||||
return layouts(NGCW, GKCX, NGKW);
|
||||
break;
|
||||
case 2:
|
||||
if constexpr(layouts_are<A, B, E, ctl::GNHWC, ctl::GKYXC, ctl::GNHWK>)
|
||||
return layouts(GNHWC, GKYXC, GNHWK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NHWGC, ctl::GKYXC, ctl::NHWGK>)
|
||||
return layouts(NHWGC, GKYXC, NHWGK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCHW, ctl::GKYXC, ctl::NGKHW>)
|
||||
return layouts(NGCHW, GKYXC, NGKHW);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCHW, ctl::GKCYX, ctl::NGKHW>)
|
||||
return layouts(NGCHW, GKCYX, NGKHW);
|
||||
break;
|
||||
case 3:
|
||||
if constexpr(layouts_are<A, B, E, ctl::GNDHWC, ctl::GKZYXC, ctl::GNDHWK>)
|
||||
return layouts(GNDHWC, GKZYXC, GNDHWK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NDHWGC, ctl::GKZYXC, ctl::NDHWGK>)
|
||||
return layouts(NDHWGC, GKZYXC, NDHWGK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCDHW, ctl::GKZYXC, ctl::NGKDHW>)
|
||||
return layouts(NGCDHW, GKZYXC, NGKDHW);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCDHW, ctl::GKCZYX, ctl::NGKDHW>)
|
||||
return layouts(NGCDHW, GKCZYX, NGKDHW);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -418,39 +393,26 @@ constexpr auto conv_layout()
|
||||
/// @return A `builder::DataType` enum value (e.g., FP16, BF16, FP32).
|
||||
template <typename Instance>
|
||||
constexpr builder::DataType conv_data_type()
|
||||
requires HasDataTypes<InstanceTraits<Instance>>
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
using ADataType = typename InstTraits::ADataType;
|
||||
using enum builder::DataType;
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, ck::half_t>)
|
||||
{
|
||||
return builder::DataType::FP16;
|
||||
}
|
||||
return FP16;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::bhalf_t>)
|
||||
{
|
||||
return builder::DataType::BF16;
|
||||
}
|
||||
return BF16;
|
||||
else if constexpr(std::is_same_v<ADataType, float>)
|
||||
{
|
||||
return builder::DataType::FP32;
|
||||
}
|
||||
return FP32;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::f8_t>)
|
||||
{
|
||||
return builder::DataType::FP8;
|
||||
}
|
||||
return FP8;
|
||||
else if constexpr(std::is_same_v<ADataType, int8_t>)
|
||||
{
|
||||
return builder::DataType::I8;
|
||||
}
|
||||
return I8;
|
||||
else if constexpr(std::is_same_v<ADataType, uint8_t>)
|
||||
{
|
||||
return builder::DataType::U8;
|
||||
}
|
||||
return U8;
|
||||
else
|
||||
{
|
||||
// Default fallback
|
||||
return builder::DataType::FP32;
|
||||
}
|
||||
return FP32; // Default fallback
|
||||
}
|
||||
|
||||
/// @brief Derives the elementwise operation from op type.
|
||||
@@ -459,27 +421,19 @@ constexpr builder::DataType conv_data_type()
|
||||
template <typename ElementwiseOp>
|
||||
constexpr builder::ElementwiseOperation elementwise_op()
|
||||
{
|
||||
using enum builder::ElementwiseOperation;
|
||||
constexpr std::string_view name = detail::elementwise_op_name<ElementwiseOp>();
|
||||
|
||||
if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp"))
|
||||
{
|
||||
return builder::ElementwiseOperation::BIAS_BNORM_CLAMP;
|
||||
}
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Clamp"))
|
||||
{
|
||||
return builder::ElementwiseOperation::CLAMP;
|
||||
}
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Scale"))
|
||||
{
|
||||
return builder::ElementwiseOperation::SCALE;
|
||||
}
|
||||
else if constexpr(detail::case_insensitive_equal(name, "PassThrough"))
|
||||
{
|
||||
return builder::ElementwiseOperation::PASS_THROUGH;
|
||||
}
|
||||
else if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu"))
|
||||
{
|
||||
return builder::ElementwiseOperation::SCALEADD_SCALEADD_RELU;
|
||||
}
|
||||
return BIAS_BNORM_CLAMP;
|
||||
if constexpr(detail::case_insensitive_equal(name, "Clamp"))
|
||||
return CLAMP;
|
||||
if constexpr(detail::case_insensitive_equal(name, "Scale"))
|
||||
return SCALE;
|
||||
if constexpr(detail::case_insensitive_equal(name, "PassThrough"))
|
||||
return PASS_THROUGH;
|
||||
if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu"))
|
||||
return SCALEADD_SCALEADD_RELU;
|
||||
}
|
||||
|
||||
/// @brief Derives a gemm padding from a kernel instance type.
|
||||
@@ -487,6 +441,7 @@ constexpr builder::ElementwiseOperation elementwise_op()
|
||||
/// @return A `builder::GemmPadding` enum value corresponding to kernel padding.
|
||||
template <typename Instance>
|
||||
constexpr builder::GemmPadding gemm_spec()
|
||||
requires HasGemmSpec<InstanceTraits<Instance>>
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
using enum builder::GemmPadding;
|
||||
@@ -494,69 +449,24 @@ constexpr builder::GemmPadding gemm_spec()
|
||||
|
||||
constexpr auto gemm_spec = InstTraits::kGemmSpecialization;
|
||||
|
||||
if constexpr(gemm_spec == Default)
|
||||
switch(gemm_spec)
|
||||
{
|
||||
return DEFAULT;
|
||||
}
|
||||
else if constexpr(gemm_spec == MPadding)
|
||||
{
|
||||
return M_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == NPadding)
|
||||
{
|
||||
return N_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == KPadding)
|
||||
{
|
||||
return K_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MNPadding)
|
||||
{
|
||||
return MN_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MKPadding)
|
||||
{
|
||||
return MK_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == NKPadding)
|
||||
{
|
||||
return NK_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MNKPadding)
|
||||
{
|
||||
return MNK_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == OPadding)
|
||||
{
|
||||
return O_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MOPadding)
|
||||
{
|
||||
return MO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == NOPadding)
|
||||
{
|
||||
return NO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == KOPadding)
|
||||
{
|
||||
return KO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MNOPadding)
|
||||
{
|
||||
return MNO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MKOPadding)
|
||||
{
|
||||
return MKO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == NKOPadding)
|
||||
{
|
||||
return NKO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MNKOPadding)
|
||||
{
|
||||
return MNKO_PADDING;
|
||||
case Default: return DEFAULT;
|
||||
case MPadding: return M_PADDING;
|
||||
case NPadding: return N_PADDING;
|
||||
case KPadding: return K_PADDING;
|
||||
case MNPadding: return MN_PADDING;
|
||||
case MKPadding: return MK_PADDING;
|
||||
case NKPadding: return NK_PADDING;
|
||||
case MNKPadding: return MNK_PADDING;
|
||||
case OPadding: return O_PADDING;
|
||||
case MOPadding: return MO_PADDING;
|
||||
case NOPadding: return NO_PADDING;
|
||||
case KOPadding: return KO_PADDING;
|
||||
case MNOPadding: return MNO_PADDING;
|
||||
case MKOPadding: return MKO_PADDING;
|
||||
case NKOPadding: return NKO_PADDING;
|
||||
case MNKOPadding: return MNKO_PADDING;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -571,6 +481,7 @@ struct ConvTraits;
|
||||
/// set of traits directly from a fully-formed device kernel `Instance` type.
|
||||
/// It uses `InstanceTraits` to access the kernel's template parameters.
|
||||
template <HasInstanceTraits Instance>
|
||||
requires IsXdlFwdConv<InstanceTraits<Instance>>
|
||||
struct ConvTraits<Instance>
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
Reference in New Issue
Block a user