[CK tests] Extend conv GPU reference (#3539)

* test_convnd_fwd

* test_convnd_bwd_data

* test_conv_bwd_data_scale

* test_grouped_convnd_fwd_clamp

* test_grouped_convnd_fwd_scale

* multiple A/B tensors and D tensor for fwd GPU ref

* test_grouped_convnd_fwd_scaleadd_ab

* test_grouped_convnd_fwd_bias_clamp

* test_grouped_convnd_fwd_bilinear

* test_grouped_convnd_fwd_gk_bias_clamp

* Extend GPU reference to enable batchnorm epilogue

* test_grouped_convnd_fwd{,_gk}_bias_bnorm_clamp

* test_grouped_conv_bwd_data_bilinear

* test_grouped_convnd_bwd_weight_bilinear

* Add missing template instantiation

* Perform operations in float in reference

* Slightly increase tolerance for batchnorm profiler

* Revert "Slightly increase tolerance for batchnorm profiler"

This reverts commit a3b2475229.

* Revert "test_grouped_convnd_fwd{,_gk}_bias_bnorm_clamp"

This reverts commit 6da4576060.

* Revert "Extend GPU reference to enable batchnorm epilogue"

This reverts commit e2f75fa10e.

* Clarify variable names

* Refactor elementwise ops into helper functions

* Make helpers C++17-compatible
This commit is contained in:
Johannes Graner
2026-01-27 09:49:42 +01:00
committed by GitHub
parent cc75948d1c
commit c190d8d61f
24 changed files with 2217 additions and 473 deletions

View File

@@ -1631,6 +1631,13 @@ struct ConvInvscale
e = type_convert<f8_t>(c / scale_in_ / scale_wei_ / scale_out_);
};
template <>
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& e, const f8_t& c) const
{
    // Promote the f8 accumulator to float so every division happens in full
    // precision; divide by each scale in sequence, then narrow back to f8.
    float result = type_convert<float>(c);
    result       = result / scale_in_;
    result       = result / scale_wei_;
    result       = result / scale_out_;
    e            = type_convert<f8_t>(result);
};
float scale_in_;
float scale_wei_;
float scale_out_;
@@ -1656,6 +1663,13 @@ struct ConvScale
e = type_convert<f8_t>(c * scale_in_ * scale_wei_ * scale_out_);
};
template <>
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& e, const f8_t& c) const
{
    // Promote the f8 accumulator to float so every multiplication happens in
    // full precision; apply each scale in sequence, then narrow back to f8.
    float result = type_convert<float>(c);
    result       = result * scale_in_;
    result       = result * scale_wei_;
    result       = result * scale_out_;
    e            = type_convert<f8_t>(result);
};
float scale_in_;
float scale_wei_;
float scale_out_;
@@ -1683,6 +1697,15 @@ struct ConvScaleRelu
e = type_convert<f8_t>(x * scale_out_);
};
template <>
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& e, const f8_t& c) const
{
    // Promote to float, apply the input/weight scales (left-to-right, matching
    // the generic path), run Relu in float, then apply the output scale and
    // narrow back to f8.
    const float scaled = type_convert<float>(c) * scale_in_ * scale_wei_;
    float activated;
    Relu{}.template operator()<float>(activated, scaled);
    e = type_convert<f8_t>(activated * scale_out_);
};
float scale_in_;
float scale_wei_;
float scale_out_;