From dc6d0327f9de3b076e9475a74358a877069d66f5 Mon Sep 17 00:00:00 2001 From: Adam Osewski <19374865+aosewski@users.noreply.github.com> Date: Sat, 25 Oct 2025 16:27:03 +0200 Subject: [PATCH] [CK_Builder] Add name member to unary elementwise ops & update builder traits. (#3093) * Add name member to unary elementwise ops. * Update elementwise_op_name to check for name attribute. * Require that the layout is derived from BaseTensorLayout struct. [ROCm/composable_kernel commit: f53d857b2552c072b0f8f14fd7609e88168d6e44] --- .../builder/reflect/instance_traits_util.hpp | 31 ++---- .../unary_element_wise_operation.hpp | 98 +++++++++++++++++++ 2 files changed, 106 insertions(+), 23 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp index 90e42528e1..4bc091f203 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp @@ -60,40 +60,25 @@ consteval std::string_view type_name() template constexpr std::string_view layout_name() { - if constexpr(requires { + if constexpr(std::is_base_of_v && requires { { T::name } -> std::convertible_to; }) return T::name; else - static_assert(false, "layout type is missing name attribute"); + static_assert(false, + "Layout type must derive from BaseTensorLayout and have name attribute"); } // Convert element-wise operation types to string names template constexpr std::string_view elementwise_op_name() { - namespace element_wise = ck::tensor_operation::element_wise; - - if constexpr(std::is_same_v) - return "PassThrough"; - else if constexpr(std::is_same_v) - return "Scale"; - else if constexpr(std::is_same_v) - return "Bilinear"; - else if constexpr(std::is_same_v) - return "Add"; - else if constexpr(std::is_same_v) - return "AddRelu"; - else if constexpr(std::is_same_v) - return "Relu"; - else if constexpr(std::is_same_v) - return "BiasNormalizeInInferClamp"; - else if constexpr(std::is_same_v) - return "Clamp"; - else if constexpr(std::is_same_v) - return "AddClamp"; + if constexpr(requires { + { T::name } -> std::convertible_to; + }) + return T::name; else - static_assert(false, "unknown_op"); + static_assert(false, "Elementwise operation is missing name attribute"); } // Convert ConvolutionForwardSpecialization enum to string diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index ea8ba4557e..c6f2db639c 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -349,6 +349,8 @@ CK_TILE_DEVICE bf8x4_t i4_to_bf8x4(int q) struct PassThroughPack8 { + static constexpr const char* name = "PassThroughPack8"; + template CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const; @@ -388,6 +390,8 @@ struct PassThroughPack8 struct DequantPack8 { + static constexpr const char* name = "DequantPack8"; + template CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x, const Z& z) const; @@ -403,6 +407,8 @@ struct DequantPack8 struct PassThroughPack2 { + static constexpr const char* name = "PassThroughPack2"; + template CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const; @@ -429,6 +435,8 @@ struct PassThroughPack2 struct PassThrough { + static constexpr const char* name = "PassThrough"; + template using raw_t = std::remove_cv_t>; @@ -465,6 +473,8 @@ struct PassThrough struct AddScale { + static constexpr const char* name = "AddScale"; + template CK_TILE_HOST_DEVICE constexpr void operator()(E& a, const As&... as) const { @@ -482,6 +492,8 @@ struct AddScale struct MultiDMultiply { + static constexpr const char* name = "MultiDMultiply"; + template CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void { @@ -497,6 +509,8 @@ struct MultiDMultiply struct MultiDAdd { + static constexpr const char* name = "MultiDAdd"; + template CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void { @@ -512,6 +526,8 @@ struct MultiDAdd struct UnaryConvert { + static constexpr const char* name = "UnaryConvert"; + template CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const { @@ -576,6 +592,8 @@ struct ConvertF8RNE struct Scale { + static constexpr const char* name = "Scale"; + CK_TILE_HOST_DEVICE Scale(float scale = 1.f) : scale_(scale) {} template @@ -623,6 +641,8 @@ struct Scale struct ScaleAndResetNaNToMinusInfinity { + static constexpr const char* name = "ScaleAndResetNaNToMinusInfinity"; + CK_TILE_HOST_DEVICE ScaleAndResetNaNToMinusInfinity(float scale) : scale_(scale) {} template @@ -639,6 +659,8 @@ struct ScaleAndResetNaNToMinusInfinity struct UnaryDivide { + static constexpr const char* name = "UnaryDivide"; + CK_TILE_HOST_DEVICE UnaryDivide(const int32_t divider = 1) : divider_(divider) {} template @@ -656,6 +678,8 @@ struct UnaryDivide struct UnarySquare { + static constexpr const char* name = "UnarySquare"; + template CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const { @@ -673,6 +697,8 @@ struct UnarySquare struct UnaryAbs { + static constexpr const char* name = "UnaryAbs"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -687,6 +713,8 @@ struct UnaryAbs struct UnarySqrt { + static constexpr const char* name = "UnarySqrt"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -699,6 +727,8 @@ struct UnarySqrt struct Relu { + static constexpr const char* name = "Relu"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -725,6 +755,8 @@ struct Relu // gpu code use lower accuracy "_ocml_exp_f32" and "rcp" function struct FastGelu { + static constexpr const char* name = "FastGelu"; + template CK_TILE_HOST void operator()(Y& y, const X& x) const; @@ -842,6 +874,8 @@ struct FastGelu struct FastGeluAsm { + static constexpr const char* name = "FastGeluAsm"; + template CK_TILE_HOST void operator()(Y& y, const X& x) const; @@ -943,6 +977,8 @@ struct FastGeluAsm // y = 0.5*x*(1+erf(x/sqrt(2))) struct Gelu { + static constexpr const char* name = "Gelu"; + template CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const; @@ -963,6 +999,8 @@ struct Gelu struct Sigmoid { + static constexpr const char* name = "Sigmoid"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -977,6 +1015,8 @@ struct Sigmoid struct Silu { + static constexpr const char* name = "Silu"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1066,6 +1106,8 @@ struct SiluAsm struct TanH { + static constexpr const char* name = "TanH"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1080,6 +1122,8 @@ struct TanH struct ACos { + static constexpr const char* name = "ACos"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1094,6 +1138,8 @@ struct ACos struct Neg { + static constexpr const char* name = "Neg"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1108,6 +1154,8 @@ struct Neg struct ATan { + static constexpr const char* name = "ATan"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1122,6 +1170,8 @@ struct ATan struct Sin { + static constexpr const char* name = "Sin"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1136,6 +1186,8 @@ struct Sin struct ASinH { + static constexpr const char* name = "ASinH"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1150,6 +1202,8 @@ struct ASinH struct Cos { + static constexpr const char* name = "Cos"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1164,6 +1218,8 @@ struct Cos struct ACosH { + static constexpr const char* name = "ACosH"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1178,6 +1234,8 @@ struct ACosH struct Tan { + static constexpr const char* name = "Tan"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1192,6 +1250,8 @@ struct Tan struct ATanH { + static constexpr const char* name = "ATanH"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1206,6 +1266,8 @@ struct ATanH struct SinH { + static constexpr const char* name = "SinH"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1220,6 +1282,8 @@ struct SinH struct Ceil { + static constexpr const char* name = "Ceil"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1234,6 +1298,8 @@ struct Ceil struct Exp { + static constexpr const char* name = "Exp"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1248,6 +1314,8 @@ struct Exp struct CosH { + static constexpr const char* name = "CosH"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1262,6 +1330,8 @@ struct CosH struct Floor { + static constexpr const char* name = "Floor"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1276,6 +1346,8 @@ struct Floor struct Log { + static constexpr const char* name = "Log"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1290,6 +1362,8 @@ struct Log struct ASin { + static constexpr const char* name = "ASin"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1304,6 +1378,8 @@ struct ASin struct Rcp { + static constexpr const char* name = "Rcp"; + template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const { @@ -1318,6 +1394,8 @@ struct Rcp struct Swish { + static constexpr const char* name = "Swish"; + Swish(float beta = 1.0f) : beta_(beta) {} template @@ -1340,6 +1418,8 @@ struct Swish struct SoftRelu { + static constexpr const char* name = "SoftRelu"; + SoftRelu(float alpha = 1.f) : alpha_(alpha){}; template @@ -1358,6 +1438,8 @@ struct SoftRelu struct Power { + static constexpr const char* name = "Power"; + Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f) : alpha_(alpha), beta_(beta), gamma_(gamma){}; @@ -1381,6 +1463,8 @@ struct Power struct ClippedRelu { + static constexpr const char* name = "ClippedRelu"; + ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){}; template @@ -1400,6 +1484,8 @@ struct ClippedRelu struct LeakyRelu { + static constexpr const char* name = "LeakyRelu"; + LeakyRelu(float alpha = 0.01f) : alpha_(alpha){}; template @@ -1417,6 +1503,8 @@ struct LeakyRelu struct Elu { + static constexpr const char* name = "Elu"; + Elu(float alpha = 1.f) : alpha_(alpha){}; template @@ -1434,6 +1522,8 @@ struct Elu struct Logistic { + static constexpr const char* name = "Logistic"; + Logistic(float alpha = 1.f) : alpha_(alpha){}; template @@ -1452,6 +1542,8 @@ struct Logistic struct ConvInvscale { + static constexpr const char* name = "ConvInvscale"; + CK_TILE_HOST_DEVICE ConvInvscale(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f) : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out) @@ -1475,6 +1567,8 @@ struct ConvInvscale struct ConvScale { + static constexpr const char* name = "ConvScale"; + CK_TILE_HOST_DEVICE ConvScale(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f) : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out) @@ -1498,6 +1592,8 @@ struct ConvScale struct ConvScaleRelu { + static constexpr const char* name = "ConvScaleRelu"; + CK_TILE_HOST_DEVICE ConvScaleRelu(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f) : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out) @@ -1524,6 +1620,8 @@ struct ConvScaleRelu template struct Cast { + static constexpr const char* name = "Cast"; + template CK_TILE_HOST_DEVICE void operator()(DstType& y, const SrcType& x) const {