mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_Builder] Add name member to unary elementwise ops & update builder traits. (#3093)
* Add name member to unary elementwise ops. * Update elementwise_op_name to check for name attribute. * Require that the layout is derived from BaseTensorLayout struct.
This commit is contained in:
@@ -60,40 +60,25 @@ consteval std::string_view type_name()
|
||||
template <typename T>
|
||||
constexpr std::string_view layout_name()
|
||||
{
|
||||
if constexpr(requires {
|
||||
if constexpr(std::is_base_of_v<ck_tile::tensor_layout::BaseTensorLayout, T> && requires {
|
||||
{ T::name } -> std::convertible_to<std::string_view>;
|
||||
})
|
||||
return T::name;
|
||||
else
|
||||
static_assert(false, "layout type is missing name attribute");
|
||||
static_assert(false,
|
||||
"Layout type must derive from BaseTensorLayout and have name attribute");
|
||||
}
|
||||
|
||||
// Convert element-wise operation types to string names
|
||||
template <typename T>
|
||||
constexpr std::string_view elementwise_op_name()
|
||||
{
|
||||
namespace element_wise = ck::tensor_operation::element_wise;
|
||||
|
||||
if constexpr(std::is_same_v<T, element_wise::PassThrough>)
|
||||
return "PassThrough";
|
||||
else if constexpr(std::is_same_v<T, element_wise::Scale>)
|
||||
return "Scale";
|
||||
else if constexpr(std::is_same_v<T, element_wise::Bilinear>)
|
||||
return "Bilinear";
|
||||
else if constexpr(std::is_same_v<T, element_wise::Add>)
|
||||
return "Add";
|
||||
else if constexpr(std::is_same_v<T, element_wise::AddRelu>)
|
||||
return "AddRelu";
|
||||
else if constexpr(std::is_same_v<T, element_wise::Relu>)
|
||||
return "Relu";
|
||||
else if constexpr(std::is_same_v<T, element_wise::BiasNormalizeInInferClamp>)
|
||||
return "BiasNormalizeInInferClamp";
|
||||
else if constexpr(std::is_same_v<T, element_wise::Clamp>)
|
||||
return "Clamp";
|
||||
else if constexpr(std::is_same_v<T, element_wise::AddClamp>)
|
||||
return "AddClamp";
|
||||
if constexpr(requires {
|
||||
{ T::name } -> std::convertible_to<std::string_view>;
|
||||
})
|
||||
return T::name;
|
||||
else
|
||||
static_assert(false, "unknown_op");
|
||||
static_assert(false, "Elementwise operation is missing name attribute");
|
||||
}
|
||||
|
||||
// Convert ConvolutionForwardSpecialization enum to string
|
||||
|
||||
Reference in New Issue
Block a user