mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK tests] Extend conv GPU reference (#3539)
* test_convnd_fwd
* test_convnd_bwd_data
* test_conv_bwd_data_scale
* test_grouped_convnd_fwd_clamp
* test_grouped_convnd_fwd_scale
* multiple A/B tensors and D tensor for fwd GPU ref
* test_grouped_convnd_fwd_scaleadd_ab
* test_grouped_convnd_fwd_bias_clamp
* test_grouped_convnd_fwd_bilinear
* test_grouped_convnd_fwd_gk_bias_clamp
* Extend GPU reference to enable batchnorm epilogue
* test_grouped_convnd_fwd{,_gk}_bias_bnorm_clamp
* test_grouped_conv_bwd_data_bilinear
* test_grouped_convnd_bwd_weight_bilinear
* Add missing template instantiation
* Perform operations in float in reference
* Slightly increase tolerance for batchnorm profiler
* Revert "Slightly increase tolerance for batchnorm profiler"
This reverts commit a3b2475229.
* Revert "test_grouped_convnd_fwd{,_gk}_bias_bnorm_clamp"
This reverts commit 6da4576060.
* Revert "Extend GPU reference to enable batchnorm epilogue"
This reverts commit e2f75fa10e.
* Clarify variable names
* Refactor elementwise ops into helper functions
* Make helpers C++17-compatible
This commit is contained in:
@@ -1631,6 +1631,13 @@ struct ConvInvscale
|
||||
e = type_convert<f8_t>(c / scale_in_ / scale_wei_ / scale_out_);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& e, const f8_t& c) const
|
||||
{
|
||||
const float c_float = type_convert<float>(c);
|
||||
e = type_convert<f8_t>(c_float / scale_in_ / scale_wei_ / scale_out_);
|
||||
};
|
||||
|
||||
float scale_in_;
|
||||
float scale_wei_;
|
||||
float scale_out_;
|
||||
@@ -1656,6 +1663,13 @@ struct ConvScale
|
||||
e = type_convert<f8_t>(c * scale_in_ * scale_wei_ * scale_out_);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& e, const f8_t& c) const
|
||||
{
|
||||
const float c_float = type_convert<float>(c);
|
||||
e = type_convert<f8_t>(c_float * scale_in_ * scale_wei_ * scale_out_);
|
||||
};
|
||||
|
||||
float scale_in_;
|
||||
float scale_wei_;
|
||||
float scale_out_;
|
||||
@@ -1683,6 +1697,15 @@ struct ConvScaleRelu
|
||||
e = type_convert<f8_t>(x * scale_out_);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& e, const f8_t& c) const
|
||||
{
|
||||
const float c_float = type_convert<float>(c);
|
||||
float x;
|
||||
Relu{}.template operator()<float>(x, c_float * scale_in_ * scale_wei_);
|
||||
e = type_convert<f8_t>(x * scale_out_);
|
||||
};
|
||||
|
||||
float scale_in_;
|
||||
float scale_wei_;
|
||||
float scale_out_;
|
||||
|
||||
Reference in New Issue
Block a user