Add Clamp/Relu bf16/fp16 cast fixes (#2279)

* Add Clamp/Relu bf16/fp16 fixes

* fix

[ROCm/composable_kernel commit: 6e5acee0f9]
Author: Bartłomiej Kocot
Date: 2025-06-03 18:31:46 +02:00
Committed by: GitHub
Parent: 1f65826b77
Commit: 12c18b697e
2 changed files with 25 additions and 18 deletions

File 1 of 2:

@@ -389,10 +389,9 @@ struct AddClamp
     __host__ __device__ constexpr void
     operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
     {
-        const float a = x0 + x1;
-        y             = a > type_convert<half_t>(floor_)
-                            ? (a < type_convert<half_t>(ceil_) ? a : type_convert<half_t>(ceil_))
-                            : type_convert<half_t>(floor_);
+        const float a = x0 + type_convert<float>(x1);
+        const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
+        y             = type_convert<half_t>(b);
     };
     template <>
@@ -408,9 +407,8 @@ struct AddClamp
     operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
     {
         const float a = x0 + type_convert<float>(x1);
-        y             = a > type_convert<bhalf_t>(floor_)
-                            ? (a < type_convert<bhalf_t>(ceil_) ? a : type_convert<bhalf_t>(ceil_))
-                            : type_convert<bhalf_t>(floor_);
+        const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
+        y             = type_convert<bhalf_t>(b);
     };
     template <>
@@ -418,9 +416,8 @@ struct AddClamp
     operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
     {
         const float a = type_convert<float>(x0) + type_convert<float>(x1);
-        y             = a > type_convert<bhalf_t>(floor_)
-                            ? (a < type_convert<bhalf_t>(ceil_) ? a : type_convert<bhalf_t>(ceil_))
-                            : type_convert<bhalf_t>(floor_);
+        const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
+        y             = type_convert<bhalf_t>(b);
     };
     template <>
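
The AddClamp hunks above all apply one pattern: do the add and both bound comparisons in float, then convert to the narrow type exactly once at the end, instead of converting floor_/ceil_ to the narrow type before comparing. A minimal standalone sketch of the fixed pattern follows; _Float16 and static_cast stand in for CK's half_t and type_convert<>, which is an assumption, and it needs a compiler with _Float16 support (recent GCC/Clang or the HIP toolchain).

    #include <cstdio>

    static _Float16 add_clamp_fixed(float x0, _Float16 x1, float floor_, float ceil_)
    {
        // Arithmetic and both comparisons stay in float; the narrow
        // conversion happens exactly once, on the final result.
        const float a = x0 + static_cast<float>(x1);
        const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
        return static_cast<_Float16>(b);
    }

    int main()
    {
        // The old code converted floor_/ceil_ to the narrow type before
        // comparing, so a bound like 6.1f had already been rounded to the
        // nearest representable half value (~6.1016) when the comparison ran.
        _Float16 y = add_clamp_fixed(3.0f, static_cast<_Float16>(4.0f), 0.0f, 6.1f);
        printf("%f\n", static_cast<float>(y)); // ~6.101562 (half-rounded 6.1)
        return 0;
    }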
@@ -476,8 +473,9 @@ struct AddRelu
     __host__ __device__ constexpr void
     operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
     {
-        const float a = x0 + x1;
-        y             = a > type_convert<half_t>(0.0f) ? a : type_convert<half_t>(0.0f);
+        const float a = x0 + type_convert<float>(x1);
+        const float b = a > 0.0f ? a : 0.0f;
+        y             = type_convert<half_t>(b);
     };
     template <>
@@ -493,7 +491,8 @@ struct AddRelu
     operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
     {
         const float a = x0 + type_convert<float>(x1);
-        y             = a > type_convert<bhalf_t>(0.0f) ? a : type_convert<bhalf_t>(0.0f);
+        const float b = a > 0.0f ? a : 0.0f;
+        y             = type_convert<bhalf_t>(b);
     };
     template <>
@@ -501,7 +500,8 @@ struct AddRelu
     operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
    {
         const float a = type_convert<float>(x0) + type_convert<float>(x1);
-        y             = a > type_convert<bhalf_t>(0.0f) ? a : type_convert<bhalf_t>(0.0f);
+        const float b = a > 0.0f ? a : 0.0f;
+        y             = type_convert<bhalf_t>(b);
     };
     template <>
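
The bf16 paths are the ones where the old code was genuinely wrong, not just imprecise: in CK, bhalf_t is a 16-bit integer typedef holding the bfloat16 bit pattern (per data_type.hpp), so `y = a > ... ? a : ...` assigned a float to a raw integer, truncating the value rather than re-encoding it, and a nonzero bound compared as its bit pattern. The sketch below reproduces that hazard under those assumptions; the truncating conversions are illustrative, while CK's type_convert rounds to nearest.

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    using bhalf_t = std::uint16_t; // raw bfloat16 bit pattern, as in CK

    static bhalf_t float_to_bf16(float f) // truncate: keep the top 16 bits
    {
        std::uint32_t bits;
        std::memcpy(&bits, &f, sizeof(bits));
        return static_cast<bhalf_t>(bits >> 16);
    }

    static float bf16_to_float(bhalf_t h)
    {
        std::uint32_t bits = static_cast<std::uint32_t>(h) << 16;
        float f;
        std::memcpy(&f, &bits, sizeof(f));
        return f;
    }

    int main()
    {
        const float a = 3.5f; // relu keeps it, so y should encode 3.5
        bhalf_t y_old, y_new;

        // Old pattern: the comparison only works because bf16(0) happens to
        // be pattern 0 (AddClamp's nonzero floor_/ceil_ compared as large
        // integers), and the assignment truncates the float *value* 3.5 to
        // the integer 3, which reads back as a tiny bfloat16 denormal.
        y_old = a > float_to_bf16(0.0f) ? a : float_to_bf16(0.0f);

        // Fixed pattern: select in float, then encode once.
        const float b = a > 0.0f ? a : 0.0f;
        y_new         = float_to_bf16(b);

        printf("old: %g  new: %g\n", bf16_to_float(y_old), bf16_to_float(y_new));
        return 0; // old: ~2.755e-40  new: 3.5
    }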

File 2 of 2:

@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -383,22 +383,29 @@ struct ReferenceConvFwd : public device::BaseOperator
                        const T& x,
                        Args... dims)
     {
+        float y_f32;
         if constexpr(NumTensor::value == 0)
         {
-            elementwise_op(y, x);
+            elementwise_op(y_f32, ck::type_convert<float>(x));
         }
         else if constexpr(NumTensor::value == 1)
         {
-            elementwise_op(y, x, elementwise_tensors[0](dims...));
+            elementwise_op(y_f32,
+                           ck::type_convert<float>(x),
+                           ck::type_convert<float>(elementwise_tensors[0](dims...)));
         }
         else if constexpr(NumTensor::value == 2)
         {
-            elementwise_op(y, x, elementwise_tensors[0](dims...), elementwise_tensors[1](dims...));
+            elementwise_op(y_f32,
+                           ck::type_convert<float>(x),
+                           ck::type_convert<float>(elementwise_tensors[0](dims...)),
+                           ck::type_convert<float>(elementwise_tensors[1](dims...)));
         }
         else
         {
             throw std::runtime_error("ElementOp not supported in reference.");
         }
+        y = ck::type_convert<T>(y_f32);
     }

     static constexpr bool IsValidCompilationParameter()
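
The reference-path hunk applies the same idea on the CPU side: feed the elementwise operator float inputs regardless of the tensor element types, and narrow the result once on the way out, so the reference matches element ops that now work purely in float. A small generic sketch of that shape follows; the names apply_elementwise and OutT are illustrative, not CK's API.

    #include <cstdio>

    template <typename Op, typename OutT, typename... Ins>
    void apply_elementwise(Op op, OutT& y, const Ins&... xs)
    {
        // Evaluate the operator entirely in float, then convert the result
        // to the output element type exactly once (the narrowing step for
        // fp16/bf16 outputs).
        float y_f32;
        op(y_f32, static_cast<float>(xs)...);
        y = static_cast<OutT>(y_f32);
    }

    int main()
    {
        auto add_relu = [](float& y, float x0, float x1) {
            const float a = x0 + x1;
            y             = a > 0.0f ? a : 0.0f;
        };
        float y;
        apply_elementwise(add_relu, y, 1.5f, -0.25f);
        printf("%f\n", y); // 1.250000
        return 0;
    }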