Add support for half_t and bfloat to reduction operations (#1395)

* Add support for half_t and bfloat to reduction operations

* Fix bhalf convert

* Next fix bf16

[ROCm/composable_kernel commit: ffabd70a15]
This commit is contained in:
Bartłomiej Kocot
2024-07-24 19:12:37 +02:00
committed by GitHub
parent fed6f87e5c
commit 5bb30d4077

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -52,11 +52,19 @@ struct Add
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, int32_t>::value,
is_same<T, int32_t>::value || is_same<T, half_t>::value,
"The data type is not supported by the Add accumulator!");
a = a + b;
}
__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
a = type_convert<bhalf_t>(a_ + b_);
}
};
struct SquaredAdd
@@ -104,11 +112,19 @@ struct Mul
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, int32_t>::value,
is_same<T, int32_t>::value || is_same<T, half_t>::value,
"The data type is not supported by the Mul accumulator!");
a = a * b;
}
__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
a = type_convert<bhalf_t>(a_ * b_);
}
};
struct Max