mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_Builder] Add name member to unary elementwise ops & update builder traits. (#3093)
* Add name member to unary elementwise ops. * Update elementwise_op_name to check for name attribute. * Require that the layout is derived from BaseTensorLayout struct.
This commit is contained in:
@@ -60,40 +60,25 @@ consteval std::string_view type_name()
|
||||
template <typename T>
|
||||
constexpr std::string_view layout_name()
|
||||
{
|
||||
if constexpr(requires {
|
||||
if constexpr(std::is_base_of_v<ck_tile::tensor_layout::BaseTensorLayout, T> && requires {
|
||||
{ T::name } -> std::convertible_to<std::string_view>;
|
||||
})
|
||||
return T::name;
|
||||
else
|
||||
static_assert(false, "layout type is missing name attribute");
|
||||
static_assert(false,
|
||||
"Layout type must derive from BaseTensorLayout and have name attribute");
|
||||
}
|
||||
|
||||
// Convert element-wise operation types to string names
|
||||
template <typename T>
|
||||
constexpr std::string_view elementwise_op_name()
|
||||
{
|
||||
namespace element_wise = ck::tensor_operation::element_wise;
|
||||
|
||||
if constexpr(std::is_same_v<T, element_wise::PassThrough>)
|
||||
return "PassThrough";
|
||||
else if constexpr(std::is_same_v<T, element_wise::Scale>)
|
||||
return "Scale";
|
||||
else if constexpr(std::is_same_v<T, element_wise::Bilinear>)
|
||||
return "Bilinear";
|
||||
else if constexpr(std::is_same_v<T, element_wise::Add>)
|
||||
return "Add";
|
||||
else if constexpr(std::is_same_v<T, element_wise::AddRelu>)
|
||||
return "AddRelu";
|
||||
else if constexpr(std::is_same_v<T, element_wise::Relu>)
|
||||
return "Relu";
|
||||
else if constexpr(std::is_same_v<T, element_wise::BiasNormalizeInInferClamp>)
|
||||
return "BiasNormalizeInInferClamp";
|
||||
else if constexpr(std::is_same_v<T, element_wise::Clamp>)
|
||||
return "Clamp";
|
||||
else if constexpr(std::is_same_v<T, element_wise::AddClamp>)
|
||||
return "AddClamp";
|
||||
if constexpr(requires {
|
||||
{ T::name } -> std::convertible_to<std::string_view>;
|
||||
})
|
||||
return T::name;
|
||||
else
|
||||
static_assert(false, "unknown_op");
|
||||
static_assert(false, "Elementwise operation is missing name attribute");
|
||||
}
|
||||
|
||||
// Convert ConvolutionForwardSpecialization enum to string
|
||||
|
||||
@@ -349,6 +349,8 @@ CK_TILE_DEVICE bf8x4_t i4_to_bf8x4(int q)
|
||||
|
||||
struct PassThroughPack8
|
||||
{
|
||||
static constexpr const char* name = "PassThroughPack8";
|
||||
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
|
||||
|
||||
@@ -388,6 +390,8 @@ struct PassThroughPack8
|
||||
|
||||
struct DequantPack8
|
||||
{
|
||||
static constexpr const char* name = "DequantPack8";
|
||||
|
||||
template <typename Y, typename X, typename Z>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x, const Z& z) const;
|
||||
|
||||
@@ -403,6 +407,8 @@ struct DequantPack8
|
||||
|
||||
struct PassThroughPack2
|
||||
{
|
||||
static constexpr const char* name = "PassThroughPack2";
|
||||
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
|
||||
|
||||
@@ -429,6 +435,8 @@ struct PassThroughPack2
|
||||
|
||||
struct PassThrough
|
||||
{
|
||||
static constexpr const char* name = "PassThrough";
|
||||
|
||||
template <class T>
|
||||
using raw_t = std::remove_cv_t<std::remove_reference_t<T>>;
|
||||
|
||||
@@ -465,6 +473,8 @@ struct PassThrough
|
||||
|
||||
struct AddScale
|
||||
{
|
||||
static constexpr const char* name = "AddScale";
|
||||
|
||||
template <typename E, typename... As>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(E& a, const As&... as) const
|
||||
{
|
||||
@@ -482,6 +492,8 @@ struct AddScale
|
||||
|
||||
struct MultiDMultiply
|
||||
{
|
||||
static constexpr const char* name = "MultiDMultiply";
|
||||
|
||||
template <typename E, typename C, typename... Ds>
|
||||
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
|
||||
{
|
||||
@@ -497,6 +509,8 @@ struct MultiDMultiply
|
||||
|
||||
struct MultiDAdd
|
||||
{
|
||||
static constexpr const char* name = "MultiDAdd";
|
||||
|
||||
template <typename E, typename C, typename... Ds>
|
||||
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
|
||||
{
|
||||
@@ -512,6 +526,8 @@ struct MultiDAdd
|
||||
|
||||
struct UnaryConvert
|
||||
{
|
||||
static constexpr const char* name = "UnaryConvert";
|
||||
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
|
||||
{
|
||||
@@ -576,6 +592,8 @@ struct ConvertF8RNE
|
||||
|
||||
struct Scale
|
||||
{
|
||||
static constexpr const char* name = "Scale";
|
||||
|
||||
CK_TILE_HOST_DEVICE Scale(float scale = 1.f) : scale_(scale) {}
|
||||
|
||||
template <typename Y, typename X>
|
||||
@@ -623,6 +641,8 @@ struct Scale
|
||||
|
||||
struct ScaleAndResetNaNToMinusInfinity
|
||||
{
|
||||
static constexpr const char* name = "ScaleAndResetNaNToMinusInfinity";
|
||||
|
||||
CK_TILE_HOST_DEVICE ScaleAndResetNaNToMinusInfinity(float scale) : scale_(scale) {}
|
||||
|
||||
template <typename Y, typename X>
|
||||
@@ -639,6 +659,8 @@ struct ScaleAndResetNaNToMinusInfinity
|
||||
|
||||
struct UnaryDivide
|
||||
{
|
||||
static constexpr const char* name = "UnaryDivide";
|
||||
|
||||
CK_TILE_HOST_DEVICE UnaryDivide(const int32_t divider = 1) : divider_(divider) {}
|
||||
|
||||
template <typename T>
|
||||
@@ -656,6 +678,8 @@ struct UnaryDivide
|
||||
|
||||
struct UnarySquare
|
||||
{
|
||||
static constexpr const char* name = "UnarySquare";
|
||||
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
|
||||
{
|
||||
@@ -673,6 +697,8 @@ struct UnarySquare
|
||||
|
||||
struct UnaryAbs
|
||||
{
|
||||
static constexpr const char* name = "UnaryAbs";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -687,6 +713,8 @@ struct UnaryAbs
|
||||
|
||||
struct UnarySqrt
|
||||
{
|
||||
static constexpr const char* name = "UnarySqrt";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -699,6 +727,8 @@ struct UnarySqrt
|
||||
|
||||
struct Relu
|
||||
{
|
||||
static constexpr const char* name = "Relu";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -725,6 +755,8 @@ struct Relu
|
||||
// gpu code use lower accuracy "_ocml_exp_f32" and "rcp" function
|
||||
struct FastGelu
|
||||
{
|
||||
static constexpr const char* name = "FastGelu";
|
||||
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST void operator()(Y& y, const X& x) const;
|
||||
|
||||
@@ -842,6 +874,8 @@ struct FastGelu
|
||||
|
||||
struct FastGeluAsm
|
||||
{
|
||||
static constexpr const char* name = "FastGeluAsm";
|
||||
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST void operator()(Y& y, const X& x) const;
|
||||
|
||||
@@ -943,6 +977,8 @@ struct FastGeluAsm
|
||||
// y = 0.5*x*(1+erf(x/sqrt(2)))
|
||||
struct Gelu
|
||||
{
|
||||
static constexpr const char* name = "Gelu";
|
||||
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
|
||||
|
||||
@@ -963,6 +999,8 @@ struct Gelu
|
||||
|
||||
struct Sigmoid
|
||||
{
|
||||
static constexpr const char* name = "Sigmoid";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -977,6 +1015,8 @@ struct Sigmoid
|
||||
|
||||
struct Silu
|
||||
{
|
||||
static constexpr const char* name = "Silu";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1066,6 +1106,8 @@ struct SiluAsm
|
||||
|
||||
struct TanH
|
||||
{
|
||||
static constexpr const char* name = "TanH";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1080,6 +1122,8 @@ struct TanH
|
||||
|
||||
struct ACos
|
||||
{
|
||||
static constexpr const char* name = "ACos";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1094,6 +1138,8 @@ struct ACos
|
||||
|
||||
struct Neg
|
||||
{
|
||||
static constexpr const char* name = "Neg";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1108,6 +1154,8 @@ struct Neg
|
||||
|
||||
struct ATan
|
||||
{
|
||||
static constexpr const char* name = "ATan";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1122,6 +1170,8 @@ struct ATan
|
||||
|
||||
struct Sin
|
||||
{
|
||||
static constexpr const char* name = "Sin";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1136,6 +1186,8 @@ struct Sin
|
||||
|
||||
struct ASinH
|
||||
{
|
||||
static constexpr const char* name = "ASinH";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1150,6 +1202,8 @@ struct ASinH
|
||||
|
||||
struct Cos
|
||||
{
|
||||
static constexpr const char* name = "Cos";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1164,6 +1218,8 @@ struct Cos
|
||||
|
||||
struct ACosH
|
||||
{
|
||||
static constexpr const char* name = "ACosH";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1178,6 +1234,8 @@ struct ACosH
|
||||
|
||||
struct Tan
|
||||
{
|
||||
static constexpr const char* name = "Tan";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1192,6 +1250,8 @@ struct Tan
|
||||
|
||||
struct ATanH
|
||||
{
|
||||
static constexpr const char* name = "ATanH";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1206,6 +1266,8 @@ struct ATanH
|
||||
|
||||
struct SinH
|
||||
{
|
||||
static constexpr const char* name = "SinH";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1220,6 +1282,8 @@ struct SinH
|
||||
|
||||
struct Ceil
|
||||
{
|
||||
static constexpr const char* name = "Ceil";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1234,6 +1298,8 @@ struct Ceil
|
||||
|
||||
struct Exp
|
||||
{
|
||||
static constexpr const char* name = "Exp";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1248,6 +1314,8 @@ struct Exp
|
||||
|
||||
struct CosH
|
||||
{
|
||||
static constexpr const char* name = "CosH";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1262,6 +1330,8 @@ struct CosH
|
||||
|
||||
struct Floor
|
||||
{
|
||||
static constexpr const char* name = "Floor";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1276,6 +1346,8 @@ struct Floor
|
||||
|
||||
struct Log
|
||||
{
|
||||
static constexpr const char* name = "Log";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1290,6 +1362,8 @@ struct Log
|
||||
|
||||
struct ASin
|
||||
{
|
||||
static constexpr const char* name = "ASin";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1304,6 +1378,8 @@ struct ASin
|
||||
|
||||
struct Rcp
|
||||
{
|
||||
static constexpr const char* name = "Rcp";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1318,6 +1394,8 @@ struct Rcp
|
||||
|
||||
struct Swish
|
||||
{
|
||||
static constexpr const char* name = "Swish";
|
||||
|
||||
Swish(float beta = 1.0f) : beta_(beta) {}
|
||||
|
||||
template <typename Y, typename X>
|
||||
@@ -1340,6 +1418,8 @@ struct Swish
|
||||
|
||||
struct SoftRelu
|
||||
{
|
||||
static constexpr const char* name = "SoftRelu";
|
||||
|
||||
SoftRelu(float alpha = 1.f) : alpha_(alpha){};
|
||||
|
||||
template <typename T>
|
||||
@@ -1358,6 +1438,8 @@ struct SoftRelu
|
||||
|
||||
struct Power
|
||||
{
|
||||
static constexpr const char* name = "Power";
|
||||
|
||||
Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
|
||||
: alpha_(alpha), beta_(beta), gamma_(gamma){};
|
||||
|
||||
@@ -1381,6 +1463,8 @@ struct Power
|
||||
|
||||
struct ClippedRelu
|
||||
{
|
||||
static constexpr const char* name = "ClippedRelu";
|
||||
|
||||
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
|
||||
|
||||
template <typename T>
|
||||
@@ -1400,6 +1484,8 @@ struct ClippedRelu
|
||||
|
||||
struct LeakyRelu
|
||||
{
|
||||
static constexpr const char* name = "LeakyRelu";
|
||||
|
||||
LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
|
||||
|
||||
template <typename T>
|
||||
@@ -1417,6 +1503,8 @@ struct LeakyRelu
|
||||
|
||||
struct Elu
|
||||
{
|
||||
static constexpr const char* name = "Elu";
|
||||
|
||||
Elu(float alpha = 1.f) : alpha_(alpha){};
|
||||
|
||||
template <typename T>
|
||||
@@ -1434,6 +1522,8 @@ struct Elu
|
||||
|
||||
struct Logistic
|
||||
{
|
||||
static constexpr const char* name = "Logistic";
|
||||
|
||||
Logistic(float alpha = 1.f) : alpha_(alpha){};
|
||||
|
||||
template <typename T>
|
||||
@@ -1452,6 +1542,8 @@ struct Logistic
|
||||
|
||||
struct ConvInvscale
|
||||
{
|
||||
static constexpr const char* name = "ConvInvscale";
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
ConvInvscale(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
|
||||
: scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
|
||||
@@ -1475,6 +1567,8 @@ struct ConvInvscale
|
||||
|
||||
struct ConvScale
|
||||
{
|
||||
static constexpr const char* name = "ConvScale";
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
ConvScale(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
|
||||
: scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
|
||||
@@ -1498,6 +1592,8 @@ struct ConvScale
|
||||
|
||||
struct ConvScaleRelu
|
||||
{
|
||||
static constexpr const char* name = "ConvScaleRelu";
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
ConvScaleRelu(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
|
||||
: scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
|
||||
@@ -1524,6 +1620,8 @@ struct ConvScaleRelu
|
||||
template <typename DstType, typename SrcType>
|
||||
struct Cast
|
||||
{
|
||||
static constexpr const char* name = "Cast";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(DstType& y, const SrcType& x) const
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user