mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Add Clamp/Relu bf16/fp16 cast fixes (#2279)
* Add Clamp/Relu bf16/fp16 fixes * fix
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -383,22 +383,29 @@ struct ReferenceConvFwd : public device::BaseOperator
|
||||
const T& x,
|
||||
Args... dims)
|
||||
{
|
||||
float y_f32;
|
||||
if constexpr(NumTensor::value == 0)
|
||||
{
|
||||
elementwise_op(y, x);
|
||||
elementwise_op(y_f32, ck::type_convert<float>(x));
|
||||
}
|
||||
else if constexpr(NumTensor::value == 1)
|
||||
{
|
||||
elementwise_op(y, x, elementwise_tensors[0](dims...));
|
||||
elementwise_op(y_f32,
|
||||
ck::type_convert<float>(x),
|
||||
ck::type_convert<float>(elementwise_tensors[0](dims...)));
|
||||
}
|
||||
else if constexpr(NumTensor::value == 2)
|
||||
{
|
||||
elementwise_op(y, x, elementwise_tensors[0](dims...), elementwise_tensors[1](dims...));
|
||||
elementwise_op(y_f32,
|
||||
ck::type_convert<float>(x),
|
||||
ck::type_convert<float>(elementwise_tensors[0](dims...)),
|
||||
ck::type_convert<float>(elementwise_tensors[1](dims...)));
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("ElementOp not supported in reference.");
|
||||
}
|
||||
y = ck::type_convert<T>(y_f32);
|
||||
}
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
|
||||
Reference in New Issue
Block a user