[CK_BUILDER] Test and fix instance traits utils. (#3096)

* Refactor instance_traits_util and add unit tests

* Address reviewer comments.

Just adds some TODOs to indicate deprecated layouts in our reflection. Our strategy is to leave the reflection code broad (covering deprecated features), but keep the builder concepts narrow. Once we've removed deprecated features from all instances, we can remove them from reflection.

Also add a comment to the cmake to explain the unit test target test_conv_builder.

* Addressed more reviewer comments.

* Remove duplicate PassThrough::name

Accidentally added this field to the end of the struct, too. The `name` field should be at the start of the struct for consistency.
This commit is contained in:
John Shumway
2025-10-27 22:14:08 -07:00
committed by GitHub
parent e02b1e7caf
commit 54746e9329
6 changed files with 479 additions and 72 deletions

View File

@@ -12,6 +12,8 @@ namespace element_wise {
struct Add
{
static constexpr const char* name = "Add";
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
@@ -279,6 +281,8 @@ struct Subtract
struct Bilinear
{
static constexpr const char* name = "Bilinear";
Bilinear(float alpha = 1.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
template <typename Y, typename X0, typename X1>
@@ -353,6 +357,8 @@ struct Bilinear
struct AddClamp
{
static constexpr const char* name = "AddClamp";
AddClamp(float floor = 0.f, float ceil = NumericLimits<float>::Max())
: floor_(floor), ceil_(ceil){};
@@ -442,6 +448,8 @@ struct AddClamp
struct AddRelu
{
static constexpr const char* name = "AddRelu";
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;

View File

@@ -565,6 +565,8 @@ struct NormalizeInInfer
// used by Conv+Bias+BatchNorm+Clamp inference
struct BiasNormalizeInInferClamp
{
static constexpr const char* name = "BiasNormalizeInInferClamp";
BiasNormalizeInInferClamp(float floor = 0.f,
float ceil = NumericLimits<float>::Max(),
float epsilon = 1e-4)

View File

@@ -332,6 +332,8 @@ struct PassThroughPack2
struct PassThrough
{
static constexpr const char* name = "PassThrough";
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
@@ -552,8 +554,6 @@ struct PassThrough
{
y = type_convert<bf8_t>(x);
}
static constexpr const char* name = "PassThrough";
};
struct UnaryConvert
@@ -620,6 +620,8 @@ struct ConvertF8RNE
struct Scale
{
static constexpr const char* name = "Scale";
__host__ __device__ Scale(float scale = 1.f) : scale_(scale) {}
template <typename Y, typename X>
@@ -783,6 +785,8 @@ struct UnarySqrt
struct Clamp
{
static constexpr const char* name = "Clamp";
Clamp(float floor = 0.f, float ceil = NumericLimits<float>::Max())
: floor_(floor), ceil_(ceil){};
@@ -856,6 +860,8 @@ struct Clamp
struct Relu
{
static constexpr const char* name = "Relu";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{