From 03c97c9524b5d8a1362dae0d0625929f55167f94 Mon Sep 17 00:00:00 2001 From: John Shumway Date: Mon, 27 Oct 2025 22:14:08 -0700 Subject: [PATCH] [CK_BUILDER] Test and fix instance traits utils. (#3096) * Refactor instance_traits_util and add unit tests tests * Address reviewer comments. Just adds some TODOs to indicate deprecated layouts in our reflection. Our strategy is to leave the reflection code broad (covering deprecated features), but keep the builder concepts narrow. Once we've removed deprecated features from all instances, we can remove them from reflection. Also add a comment to the cmake to explain the unit test target test_conv_builder. * Addressed more reviewer comments. * Remove duplicate PassThrough::name Accidentally added this field to the end of the struct, too. The `name` field should be a the start of the struct for consistency. [ROCm/composable_kernel commit: 54746e9329e9d365512e8d81dd4df6bd2403f794] --- .../builder/reflect/instance_traits_util.hpp | 255 +++++++++++++---- experimental/builder/test/CMakeLists.txt | 13 +- .../test/test_instance_traits_util.cpp | 263 ++++++++++++++++++ .../element/binary_element_wise_operation.hpp | 8 + .../gpu/element/element_wise_operation.hpp | 2 + .../element/unary_element_wise_operation.hpp | 10 +- 6 files changed, 479 insertions(+), 72 deletions(-) create mode 100644 experimental/builder/test/test_instance_traits_util.cpp diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp index a7ec568b03..902c3b3579 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp @@ -23,19 +23,12 @@ namespace ck_tile::reflect::detail { -// Metaprogramming helper to convert ck::Sequence to constexpr std::array -template -struct SequenceToArray; - -template -struct SequenceToArray> -{ - static constexpr std::array value = {static_cast(Is)...}; -}; - -// Convert data types to string names +// Implementation detail for type name mapping +// This is the single source of truth for supported data types that +// returns an empty string to indicate an unsupported type. +namespace impl { template -consteval std::string_view type_name() +consteval std::string_view type_name_impl() { if constexpr(std::is_same_v) return "fp16"; @@ -56,22 +49,38 @@ consteval std::string_view type_name() else if constexpr(std::is_same_v) return "bf8"; else - static_assert(false, "unknown_type"); + return std::string_view{}; // Return empty for supported types +} +} // namespace impl + +// Convert data types to string names +// Fails at compile time for unsupported types +template +consteval std::string_view type_name() +{ + constexpr auto name = impl::type_name_impl(); + static_assert(!name.empty(), "Unsupported data type"); + return name; } -// Convert layout types to string names +// Concept that checks if a type is a valid data type +// Uses the impl directly to avoid triggering static_assert during concept evaluation template +concept IsDataType = !impl::type_name_impl().empty(); + +// Concept that checks valid layout types +template +concept IsLayoutType = (std::is_base_of_v || + std::is_base_of_v) && + requires { + { T::name } -> std::convertible_to; + }; + +// Convert layout types to string names +template constexpr std::string_view layout_name() { - if constexpr((std::is_base_of_v || - std::is_base_of_v) && - requires { - { T::name } -> std::convertible_to; - }) - return T::name; - else - static_assert(false, - "Layout type must derive from BaseTensorLayout and have name attribute"); + return T::name; } // Convert element-wise operation types to string names @@ -90,64 +99,64 @@ constexpr std::string_view elementwise_op_name() constexpr std::string_view conv_fwd_spec_name(ck::tensor_operation::device::ConvolutionForwardSpecialization spec) { - using ck::tensor_operation::device::ConvolutionForwardSpecialization; + using enum ck::tensor_operation::device::ConvolutionForwardSpecialization; switch(spec) { - case ConvolutionForwardSpecialization::Default: return "Default"; - case ConvolutionForwardSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0"; - case ConvolutionForwardSpecialization::Filter1x1Pad0: return "Filter1x1Pad0"; - case ConvolutionForwardSpecialization::Filter3x3: return "Filter3x3"; - case ConvolutionForwardSpecialization::OddC: return "OddC"; + case Default: return "Default"; + case Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0"; + case Filter1x1Pad0: return "Filter1x1Pad0"; + case Filter3x3: return "Filter3x3"; + case OddC: return "OddC"; } } // Convert GemmSpecialization enum to string constexpr std::string_view gemm_spec_name(ck::tensor_operation::device::GemmSpecialization spec) { - using ck::tensor_operation::device::GemmSpecialization; + using enum ck::tensor_operation::device::GemmSpecialization; switch(spec) { - case GemmSpecialization::Default: return "Default"; - case GemmSpecialization::MPadding: return "MPadding"; - case GemmSpecialization::NPadding: return "NPadding"; - case GemmSpecialization::KPadding: return "KPadding"; - case GemmSpecialization::MNPadding: return "MNPadding"; - case GemmSpecialization::MKPadding: return "MKPadding"; - case GemmSpecialization::NKPadding: return "NKPadding"; - case GemmSpecialization::MNKPadding: return "MNKPadding"; - case GemmSpecialization::OPadding: return "OPadding"; - case GemmSpecialization::MOPadding: return "MOPadding"; - case GemmSpecialization::NOPadding: return "NOPadding"; - case GemmSpecialization::KOPadding: return "KOPadding"; - case GemmSpecialization::MNOPadding: return "MNOPadding"; - case GemmSpecialization::MKOPadding: return "MKOPadding"; - case GemmSpecialization::NKOPadding: return "NKOPadding"; - case GemmSpecialization::MNKOPadding: return "MNKOPadding"; + case Default: return "Default"; + case MPadding: return "MPadding"; + case NPadding: return "NPadding"; + case KPadding: return "KPadding"; + case MNPadding: return "MNPadding"; + case MKPadding: return "MKPadding"; + case NKPadding: return "NKPadding"; + case MNKPadding: return "MNKPadding"; + case OPadding: return "OPadding"; + case MOPadding: return "MOPadding"; + case NOPadding: return "NOPadding"; + case KOPadding: return "KOPadding"; + case MNOPadding: return "MNOPadding"; + case MKOPadding: return "MKOPadding"; + case NKOPadding: return "NKOPadding"; + case MNKOPadding: return "MNKOPadding"; } } // Convert BlockGemmPipelineScheduler enum to string constexpr std::string_view pipeline_scheduler_name(ck::BlockGemmPipelineScheduler sched) { - using ck::BlockGemmPipelineScheduler; + using enum ck::BlockGemmPipelineScheduler; switch(sched) { - case BlockGemmPipelineScheduler::Intrawave: return "Intrawave"; - case BlockGemmPipelineScheduler::Interwave: return "Interwave"; + case Intrawave: return "Intrawave"; + case Interwave: return "Interwave"; } } // Convert BlockGemmPipelineVersion enum to string constexpr std::string_view pipeline_version_name(ck::BlockGemmPipelineVersion ver) { - using ck::BlockGemmPipelineVersion; + using enum ck::BlockGemmPipelineVersion; switch(ver) { - case BlockGemmPipelineVersion::v1: return "v1"; - case BlockGemmPipelineVersion::v2: return "v2"; - case BlockGemmPipelineVersion::v3: return "v3"; - case BlockGemmPipelineVersion::v4: return "v4"; - case BlockGemmPipelineVersion::v5: return "v5"; + case v1: return "v1"; + case v2: return "v2"; + case v3: return "v3"; + case v4: return "v4"; + case v5: return "v5"; } } @@ -167,12 +176,138 @@ inline std::string array_to_string(const std::array& arr) return oss.str(); } -// Handle ck::Tuple (empty tuple for DsLayout/DsDataType) -template -constexpr std::string_view tuple_name() +// Metaprogramming helper to convert ck::Sequence to constexpr std::array +template +struct SequenceToArray; + +template +struct SequenceToArray> { - // For now, just check if it's an empty tuple - return "EmptyTuple"; + static constexpr std::array value = {static_cast(Is)...}; +}; + +namespace detail { +// Generic helper to build list-like strings (Tuple, Seq, etc.) +// +// Example output: "Seq(1,2,3)" +// +// prefix: The list-like container name (e.g. "Tuple" or "Seq") +// converter_fn: A callable that converts each element to a string representation +// For types: converter_fn should be a template lambda like []() { return +// type_name(); } For values: converter_fn should be a regular lambda like [](auto value) { +// return std::to_string(value); } +template +constexpr std::string build_list_string(std::string_view prefix, const ConverterFn& converter_fn) +{ + if constexpr(sizeof...(Elements) == 0) + { + return std::string(prefix) + "()"; + } + else + { + std::string result = std::string(prefix) + "("; + std::size_t index = 0; + ((result += + (index++ > 0 ? "," : "") + std::string(converter_fn.template operator()())), + ...); + result += ")"; + return result; + } +} + +// Overload for value-based lists (sequences) +template +constexpr std::string build_list_string_values(std::string_view prefix, + const ConverterFn& converter_fn) +{ + if constexpr(sizeof...(Values) == 0) + { + return std::string(prefix) + "()"; + } + else + { + std::string result = std::string(prefix) + "("; + std::size_t index = 0; + ((result += (index++ > 0 ? "," : "") + converter_fn(Values)), ...); + result += ")"; + return result; + } +} +} // namespace detail + +// Convert ck::Sequence to string representation +// Converts a ck::Sequence type to a string in the format "Seq(v1,v2,...,vn)" +// where each value is converted using std::to_string. +// +// Template parameter: +// T: Must be a ck::Sequence<...> type +// +// Constraints: +// - Sequence elements must support std::to_string (typically ck::index_t) +// +// Examples: +// sequence_name>() returns "Seq()" +// sequence_name>() returns "Seq(42)" +// sequence_name>() returns "Seq(1,2,3)" +// sequence_name>() returns "Seq(256,128,64)" +template + requires requires { [](ck::Sequence*) {}(static_cast(nullptr)); } +constexpr std::string sequence_name() +{ + return [](ck::Sequence*) constexpr { + auto to_string_fn = [](auto value) { return std::to_string(value); }; + return detail::build_list_string_values("Seq", to_string_fn); + }(static_cast(nullptr)); +} + +// Convert ck::Tuple to string representation +// Converts a ck::Tuple type to a string in the format "Tuple(e1,e2,...,en)" +// where each element is converted based on its type (layout names or data type names). +// +// Template parameter: +// T: Must be a ck::Tuple<...> type +// +// Constraints: +// - Empty tuples are supported and return "EmptyTuple" +// - All tuple elements must be homogeneous: either all layouts (IsLayoutType) or all data types +// (IsDataType) +// - Mixed layouts and data types in the same tuple will cause a compile-time error +// +// Examples: +// tuple_name>() returns "EmptyTuple" +// tuple_name>() returns "Tuple(RowMajor)" +// tuple_name>() returns "Tuple(NCHW,NHWC)" +// tuple_name>() returns "Tuple(fp16)" +// tuple_name>() returns "Tuple(fp16,fp32,fp64)" +template + requires requires { [](ck::Tuple*) {}(static_cast(nullptr)); } +constexpr std::string tuple_name() +{ + return [](ck::Tuple*) constexpr { + if constexpr(sizeof...(Ts) == 0) + { + return std::string("EmptyTuple"); + } + else if constexpr((IsLayoutType && ...)) + { + // Lambda wrapper for layout_name + auto layout_name_fn = []() { return layout_name(); }; + return detail::build_list_string("Tuple", + layout_name_fn); + } + else if constexpr((IsDataType && ...)) + { + // Lambda wrapper for type_name + auto type_name_fn = []() { return type_name(); }; + return detail::build_list_string("Tuple", type_name_fn); + } + else + { + static_assert((IsLayoutType && ...) || (IsDataType && ...), + "Tuple elements must be all layouts or all data types, not mixed"); + return std::string{}; // unreachable + } + }(static_cast(nullptr)); } } // namespace ck_tile::reflect::detail diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index d44864938f..1dc508a0c3 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -16,20 +16,13 @@ function(add_ck_builder_test test_name) target_link_libraries(${test_name} PRIVATE GTest::gtest_main GTest::gmock) endfunction() +# The test_conv_builder target has all the unit tests (each test should run < 10 ms) add_ck_builder_test(test_conv_builder test_conv_builder.cpp test_instance_traits.cpp + test_instance_traits_util.cpp testing_utils.cpp) +# Testing the virtual GetInstanceString methods requires kernel compilation. add_ck_builder_test(test_get_instance_string test_get_instance_string.cpp) - -add_ck_builder_test(test_inline_diff test_inline_diff.cpp testing_utils.cpp) - -add_ck_builder_test(test_ckb_build_fwd_instances - conv/test_ckb_conv_fwd_2d_bf16.cpp - conv/test_ckb_conv_fwd_2d_fp16.cpp - conv/test_ckb_conv_fwd_2d_fp32.cpp - conv/test_ckb_conv_fwd_3d_bf16.cpp - conv/test_ckb_conv_fwd_3d_fp16.cpp - conv/test_ckb_conv_fwd_3d_fp32.cpp) \ No newline at end of file diff --git a/experimental/builder/test/test_instance_traits_util.cpp b/experimental/builder/test/test_instance_traits_util.cpp new file mode 100644 index 0000000000..4aa5ebf25e --- /dev/null +++ b/experimental/builder/test/test_instance_traits_util.cpp @@ -0,0 +1,263 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ck_tile::reflect::detail { +namespace { + +using ::testing::ElementsAre; +using ::testing::IsEmpty; + +TEST(InstanceTraitsUtil, SequenceToArrayReturnsEmptyArrayForEmptySequence) +{ + EXPECT_THAT(SequenceToArray>::value, IsEmpty()); +} + +TEST(InstanceTraitsUtil, SequenceToArrayReturnsArrayWithSingleElement) +{ + EXPECT_THAT(SequenceToArray>::value, ElementsAre(42)); +} + +TEST(InstanceTraitsUtil, SequenceToArrayReturnsArrayWithMultipleElements) +{ + EXPECT_THAT((SequenceToArray>::value), ElementsAre(1, 2, 3, 4, 5)); +} + +TEST(InstanceTraitsUtil, TypeNameReturnsCorrectStrings) +{ + EXPECT_THAT((std::vector{type_name(), + type_name(), + type_name(), + type_name(), + type_name(), + type_name(), + type_name(), + type_name()}), + ElementsAre("fp16", "fp32", "fp64", "s8", "s32", "bf16", "fp8", "bf8")); +} + +TEST(InstanceTraitsUtil, LayoutNameReturnsCorrectStringsForGemmLayouts) +{ + namespace gemm = ck::tensor_layout::gemm; + EXPECT_THAT((std::vector{layout_name(), + layout_name(), + layout_name()}), + ElementsAre("RowMajor", "ColumnMajor", "MFMA")); +} + +TEST(InstanceTraitsUtil, LayoutNameReturnsCorrectStringsForConvLayouts) +{ + namespace conv = ck::tensor_layout::convolution; + EXPECT_THAT((std::vector{ + // Input tensor layouts + // TODO(deprecated): Remove non-grouped layouts once instances are removed. + layout_name(), + layout_name(), + layout_name(), + layout_name(), + // Grouped input layouts + layout_name(), + layout_name(), + // Weight tensor layouts + layout_name(), + layout_name(), + layout_name(), + layout_name(), + // Output tensor layouts + layout_name(), + layout_name(), + layout_name(), + layout_name(), + // Strided layouts + // TODO(deprecated): Remove strided layouts once instances are removed. + layout_name(), + layout_name(), + layout_name(), + // Bias layouts + layout_name(), + layout_name()}), + ElementsAre("NCHW", + "NHWC", + "NCDHW", + "NDHWC", + "GNCHW", + "GNHWC", + "KCYX", + "KYXC", + "GKCYX", + "GKYXC", + "NKHW", + "NHWK", + "GNKHW", + "GNHWK", + "G_NHW_C", + "G_K_YX_C", + "G_NHW_K", + "G_C", + "G_K")); +} + +TEST(InstanceTraitsUtil, ElementwiseOpNameReturnsCorrectStrings) +{ + namespace element_wise = ck::tensor_operation::element_wise; + EXPECT_THAT((std::vector{ + elementwise_op_name(), + elementwise_op_name(), + elementwise_op_name(), + elementwise_op_name(), + elementwise_op_name(), + elementwise_op_name(), + elementwise_op_name(), + elementwise_op_name(), + elementwise_op_name()}), + ElementsAre("PassThrough", + "Scale", + "Bilinear", + "Add", + "AddRelu", + "Relu", + "BiasNormalizeInInferClamp", + "Clamp", + "AddClamp")); +} + +TEST(InstanceTraitsUtil, ConvFwdSpecNameReturnsCorrectStrings) +{ + using enum ck::tensor_operation::device::ConvolutionForwardSpecialization; + EXPECT_THAT( + (std::vector{conv_fwd_spec_name(Default), + conv_fwd_spec_name(Filter1x1Stride1Pad0), + conv_fwd_spec_name(Filter1x1Pad0), + conv_fwd_spec_name(Filter3x3), + conv_fwd_spec_name(OddC)}), + ElementsAre("Default", "Filter1x1Stride1Pad0", "Filter1x1Pad0", "Filter3x3", "OddC")); +} + +TEST(InstanceTraitsUtil, GemmSpecNameReturnsCorrectStrings) +{ + using enum ck::tensor_operation::device::GemmSpecialization; + EXPECT_THAT((std::vector{gemm_spec_name(Default), + gemm_spec_name(MPadding), + gemm_spec_name(NPadding), + gemm_spec_name(KPadding), + gemm_spec_name(MNPadding), + gemm_spec_name(MKPadding), + gemm_spec_name(NKPadding), + gemm_spec_name(MNKPadding), + gemm_spec_name(OPadding), + gemm_spec_name(MOPadding), + gemm_spec_name(NOPadding), + gemm_spec_name(KOPadding), + gemm_spec_name(MNOPadding), + gemm_spec_name(MKOPadding), + gemm_spec_name(NKOPadding), + gemm_spec_name(MNKOPadding)}), + ElementsAre("Default", + "MPadding", + "NPadding", + "KPadding", + "MNPadding", + "MKPadding", + "NKPadding", + "MNKPadding", + "OPadding", + "MOPadding", + "NOPadding", + "KOPadding", + "MNOPadding", + "MKOPadding", + "NKOPadding", + "MNKOPadding")); +} + +TEST(InstanceTraitsUtil, PipelineSchedulerNameReturnsCorrectStrings) +{ + using enum ck::BlockGemmPipelineScheduler; + EXPECT_THAT((std::vector{pipeline_scheduler_name(Intrawave), + pipeline_scheduler_name(Interwave)}), + ElementsAre("Intrawave", "Interwave")); +} + +TEST(InstanceTraitsUtil, PipelineVersionNameReturnsCorrectStrings) +{ + using enum ck::BlockGemmPipelineVersion; + EXPECT_THAT((std::vector{pipeline_version_name(v1), + pipeline_version_name(v2), + pipeline_version_name(v3), + pipeline_version_name(v4), + pipeline_version_name(v5)}), + ElementsAre("v1", "v2", "v3", "v4", "v5")); +} + +TEST(InstanceTraitsUtil, TupleNameReturnsEmptyTupleForEmptyTuple) +{ + EXPECT_EQ(tuple_name>(), "EmptyTuple"); +} + +TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForSingleLayout) +{ + EXPECT_EQ(tuple_name>(), "Tuple(NCHW)"); +} + +TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForTwoLayouts) +{ + EXPECT_EQ((tuple_name>()), + "Tuple(NCHW,NHWC)"); +} + +TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForThreeLayouts) +{ + EXPECT_EQ((tuple_name>()), + "Tuple(NCHW,NHWC,NKHW)"); +} + +TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForSingleDataType) +{ + EXPECT_EQ(tuple_name>(), "Tuple(fp16)"); +} + +TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForTwoDataTypes) +{ + EXPECT_EQ((tuple_name>()), "Tuple(fp16,fp32)"); +} + +TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForThreeDataTypes) +{ + EXPECT_EQ((tuple_name>()), "Tuple(fp16,fp32,fp64)"); +} + +TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForEmptySequence) +{ + EXPECT_EQ(sequence_name>(), "Seq()"); +} + +TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForSingleValueSequence) +{ + EXPECT_EQ(sequence_name>(), "Seq(42)"); +} + +TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForTwoValueSequence) +{ + EXPECT_EQ((sequence_name>()), "Seq(1,2)"); +} + +TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForMultipleValueSequence) +{ + EXPECT_EQ((sequence_name>()), "Seq(256,128,64,32,16)"); +} + +} // namespace +} // namespace ck_tile::reflect::detail diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index 61d249fc93..4954144aca 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -12,6 +12,8 @@ namespace element_wise { struct Add { + static constexpr const char* name = "Add"; + template __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; @@ -279,6 +281,8 @@ struct Subtract struct Bilinear { + static constexpr const char* name = "Bilinear"; + Bilinear(float alpha = 1.f, float beta = 1.f) : alpha_(alpha), beta_(beta){}; template @@ -353,6 +357,8 @@ struct Bilinear struct AddClamp { + static constexpr const char* name = "AddClamp"; + AddClamp(float floor = 0.f, float ceil = NumericLimits::Max()) : floor_(floor), ceil_(ceil){}; @@ -442,6 +448,8 @@ struct AddClamp struct AddRelu { + static constexpr const char* name = "AddRelu"; + template __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index 089d4c2a9d..5edcdd257b 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -565,6 +565,8 @@ struct NormalizeInInfer // used by Conv+Bias+BatchNorm+Clamp inference struct BiasNormalizeInInferClamp { + static constexpr const char* name = "BiasNormalizeInInferClamp"; + BiasNormalizeInInferClamp(float floor = 0.f, float ceil = NumericLimits::Max(), float epsilon = 1e-4) diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 4643c0bcb3..59292b30e2 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -332,6 +332,8 @@ struct PassThroughPack2 struct PassThrough { + static constexpr const char* name = "PassThrough"; + template __host__ __device__ void operator()(Y& y, const X& x) const; @@ -552,8 +554,6 @@ struct PassThrough { y = type_convert(x); } - - static constexpr const char* name = "PassThrough"; }; struct UnaryConvert @@ -620,6 +620,8 @@ struct ConvertF8RNE struct Scale { + static constexpr const char* name = "Scale"; + __host__ __device__ Scale(float scale = 1.f) : scale_(scale) {} template @@ -783,6 +785,8 @@ struct UnarySqrt struct Clamp { + static constexpr const char* name = "Clamp"; + Clamp(float floor = 0.f, float ceil = NumericLimits::Max()) : floor_(floor), ceil_(ceil){}; @@ -856,6 +860,8 @@ struct Clamp struct Relu { + static constexpr const char* name = "Relu"; + template __host__ __device__ void operator()(T& y, const T& x) const {