mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Add support for half_t and bfloat to reduction operations (#1395)
* Add support for half_t and bfloat to reduction operations
* Fix bhalf convert
* Next fix bf16
[ROCm/composable_kernel commit: ffabd70a15]
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -52,11 +52,19 @@ struct Add
|
||||
__host__ __device__ inline constexpr void operator()(T& a, T b) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
|
||||
-is_same<T, int32_t>::value,
|
||||
+is_same<T, int32_t>::value || is_same<T, half_t>::value,
|
||||
"The data type is not supported by the Add accumulator!");
|
||||
|
||||
a = a + b;
|
||||
}
|
||||
|
||||
+__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
|
||||
+{
|
||||
+float a_ = type_convert<float>(a);
|
||||
+float b_ = type_convert<float>(b);
|
||||
|
||||
+a = type_convert<bhalf_t>(a_ + b_);
|
||||
+}
|
||||
};
|
||||
|
||||
struct SquaredAdd
|
||||
@@ -104,11 +112,19 @@ struct Mul
|
||||
__host__ __device__ inline constexpr void operator()(T& a, T b) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
|
||||
-is_same<T, int32_t>::value,
|
||||
+is_same<T, int32_t>::value || is_same<T, half_t>::value,
|
||||
"The data type is not supported by the Mul accumulator!");
|
||||
|
||||
a = a * b;
|
||||
}
|
||||
|
||||
+__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
|
||||
+{
|
||||
+float a_ = type_convert<float>(a);
|
||||
+float b_ = type_convert<float>(b);
|
||||
|
||||
+a = type_convert<bhalf_t>(a_ * b_);
|
||||
+}
|
||||
};
|
||||
|
||||
struct Max
|
||||
|
||||
Reference in New Issue
Block a user