[CK_Builder] Add name member to unary elementwise ops & update builder traits. (#3093)

* Add name member to unary elementwise ops.

* Update elementwise_op_name to check for name attribute.

* Require that the layout is derived from BaseTensorLayout struct.
This commit is contained in:
Adam Osewski
2025-10-25 16:27:03 +02:00
committed by GitHub
parent e576992dca
commit f53d857b25
2 changed files with 106 additions and 23 deletions

View File

@@ -60,40 +60,25 @@ consteval std::string_view type_name()
template <typename T>
constexpr std::string_view layout_name()
{
if constexpr(requires {
if constexpr(std::is_base_of_v<ck_tile::tensor_layout::BaseTensorLayout, T> && requires {
{ T::name } -> std::convertible_to<std::string_view>;
})
return T::name;
else
static_assert(false, "layout type is missing name attribute");
static_assert(false,
"Layout type must derive from BaseTensorLayout and have name attribute");
}
// Convert element-wise operation types to string names
template <typename T>
constexpr std::string_view elementwise_op_name()
{
namespace element_wise = ck::tensor_operation::element_wise;
if constexpr(std::is_same_v<T, element_wise::PassThrough>)
return "PassThrough";
else if constexpr(std::is_same_v<T, element_wise::Scale>)
return "Scale";
else if constexpr(std::is_same_v<T, element_wise::Bilinear>)
return "Bilinear";
else if constexpr(std::is_same_v<T, element_wise::Add>)
return "Add";
else if constexpr(std::is_same_v<T, element_wise::AddRelu>)
return "AddRelu";
else if constexpr(std::is_same_v<T, element_wise::Relu>)
return "Relu";
else if constexpr(std::is_same_v<T, element_wise::BiasNormalizeInInferClamp>)
return "BiasNormalizeInInferClamp";
else if constexpr(std::is_same_v<T, element_wise::Clamp>)
return "Clamp";
else if constexpr(std::is_same_v<T, element_wise::AddClamp>)
return "AddClamp";
if constexpr(requires {
{ T::name } -> std::convertible_to<std::string_view>;
})
return T::name;
else
static_assert(false, "unknown_op");
static_assert(false, "Elementwise operation is missing name attribute");
}
// Convert ConvolutionForwardSpecialization enum to string