Add Clamp/Relu bf16/fp16 cast fixes (#2279)

* Add Clamp/Relu bf16/fp16 fixes

* fix
This commit is contained in:
Bartłomiej Kocot
2025-06-03 18:31:46 +02:00
committed by GitHub
parent 7f9eef40b0
commit 6e5acee0f9
2 changed files with 25 additions and 18 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -383,22 +383,29 @@ struct ReferenceConvFwd : public device::BaseOperator
const T& x,
Args... dims)
{
float y_f32;
if constexpr(NumTensor::value == 0)
{
elementwise_op(y, x);
elementwise_op(y_f32, ck::type_convert<float>(x));
}
else if constexpr(NumTensor::value == 1)
{
elementwise_op(y, x, elementwise_tensors[0](dims...));
elementwise_op(y_f32,
ck::type_convert<float>(x),
ck::type_convert<float>(elementwise_tensors[0](dims...)));
}
else if constexpr(NumTensor::value == 2)
{
elementwise_op(y, x, elementwise_tensors[0](dims...), elementwise_tensors[1](dims...));
elementwise_op(y_f32,
ck::type_convert<float>(x),
ck::type_convert<float>(elementwise_tensors[0](dims...)),
ck::type_convert<float>(elementwise_tensors[1](dims...)));
}
else
{
throw std::runtime_error("ElementOp not supported in reference.");
}
y = ck::type_convert<T>(y_f32);
}
static constexpr bool IsValidCompilationParameter()