From dbdf79d541df32e2fc0b2499ea935d47b84f0875 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?=
Date: Tue, 3 Jun 2025 18:31:46 +0200
Subject: [PATCH] Add Clamp/Relu bf16/fp16 cast fixes (#2279)

* Add Clamp/Relu bf16/fp16 fixes

* fix

[ROCm/composable_kernel commit: 6e5acee0f951e4d174ac9afb4afce83fc801305d]
---
 .../element/binary_element_wise_operation.hpp | 28 +++++++++----------
 .../cpu/reference_conv_fwd.hpp                | 15 +++++++---
 2 files changed, 25 insertions(+), 18 deletions(-)

diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
index badd64508d..34c76b89e4 100644
--- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
+++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
@@ -389,10 +389,9 @@ struct AddClamp
     __host__ __device__ constexpr void
     operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
     {
-        const float a = x0 + x1;
-        y             = a > type_convert<float>(floor_)
-                            ? (a < type_convert<float>(ceil_) ? a : type_convert<float>(ceil_))
-                            : type_convert<float>(floor_);
+        const float a = x0 + type_convert<float>(x1);
+        const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
+        y             = type_convert<half_t>(b);
     };
 
     template <>
@@ -408,9 +407,8 @@ struct AddClamp
     operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
     {
         const float a = x0 + type_convert<float>(x1);
-        y             = a > type_convert<float>(floor_)
-                            ? (a < type_convert<float>(ceil_) ? a : type_convert<float>(ceil_))
-                            : type_convert<float>(floor_);
+        const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
+        y             = type_convert<bhalf_t>(b);
     };
 
     template <>
@@ -418,9 +416,8 @@ struct AddClamp
     operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
     {
         const float a = type_convert<float>(x0) + type_convert<float>(x1);
-        y             = a > type_convert<float>(floor_)
-                            ? (a < type_convert<float>(ceil_) ? a : type_convert<float>(ceil_))
-                            : type_convert<float>(floor_);
+        const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
+        y             = type_convert<bhalf_t>(b);
     };
 
     template <>
@@ -476,8 +473,9 @@ struct AddRelu
     __host__ __device__ constexpr void
     operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
     {
-        const float a = x0 + x1;
-        y             = a > type_convert<float>(0.0f) ? a : type_convert<float>(0.0f);
+        const float a = x0 + type_convert<float>(x1);
+        const float b = a > 0.0f ? a : 0.0f;
+        y             = type_convert<half_t>(b);
     };
 
     template <>
@@ -493,7 +491,8 @@ struct AddRelu
     operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
     {
         const float a = x0 + type_convert<float>(x1);
-        y             = a > type_convert<float>(0.0f) ? a : type_convert<float>(0.0f);
+        const float b = a > 0.0f ? a : 0.0f;
+        y             = type_convert<bhalf_t>(b);
     };
 
     template <>
@@ -501,7 +500,8 @@ struct AddRelu
     operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
     {
         const float a = type_convert<float>(x0) + type_convert<float>(x1);
-        y             = a > type_convert<float>(0.0f) ? a : type_convert<float>(0.0f);
+        const float b = a > 0.0f ? a : 0.0f;
+        y             = type_convert<bhalf_t>(b);
     };
 
     template <>
diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
index 9c1349f56c..3884902bbf 100644
--- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
+++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -383,22 +383,29 @@ struct ReferenceConvFwd : public device::BaseOperator
                               const T& x,
                               Args... dims)
     {
+        float y_f32;
         if constexpr(NumTensor::value == 0)
         {
-            elementwise_op(y, x);
+            elementwise_op(y_f32, ck::type_convert<float>(x));
         }
         else if constexpr(NumTensor::value == 1)
         {
-            elementwise_op(y, x, elementwise_tensors[0](dims...));
+            elementwise_op(y_f32,
+                           ck::type_convert<float>(x),
+                           ck::type_convert<float>(elementwise_tensors[0](dims...)));
         }
         else if constexpr(NumTensor::value == 2)
        {
-            elementwise_op(y, x, elementwise_tensors[0](dims...), elementwise_tensors[1](dims...));
+            elementwise_op(y_f32,
+                           ck::type_convert<float>(x),
+                           ck::type_convert<float>(elementwise_tensors[0](dims...)),
+                           ck::type_convert<float>(elementwise_tensors[1](dims...)));
         }
         else
         {
             throw std::runtime_error("ElementOp not supported in reference.");
         }
+        y = ck::type_convert<T>(y_f32);
     }
 
     static constexpr bool IsValidCompilationParameter()