From 5bb30d4077e6ed2151035bed790675f2aac86442 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 24 Jul 2024 19:12:37 +0200 Subject: [PATCH] Add support for half_t and bfloat to reduction operations (#1395) * Add support for half_t and bfloat to reduction operations * Fix bhalf convert * Next fix bf16 [ROCm/composable_kernel commit: ffabd70a15ca907f271b61a8301decc0c05ffee0] --- include/ck/utility/reduction_operator.hpp | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/include/ck/utility/reduction_operator.hpp b/include/ck/utility/reduction_operator.hpp index 5480a98409..fffd0ac49e 100644 --- a/include/ck/utility/reduction_operator.hpp +++ b/include/ck/utility/reduction_operator.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -52,11 +52,19 @@ struct Add __host__ __device__ inline constexpr void operator()(T& a, T b) const { static_assert(is_same::value || is_same::value || - is_same::value, + is_same::value || is_same::value, "The data type is not supported by the Add accumulator!"); a = a + b; } + + __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const + { + float a_ = type_convert(a); + float b_ = type_convert(b); + + a = type_convert(a_ + b_); + } }; struct SquaredAdd @@ -104,11 +112,19 @@ struct Mul __host__ __device__ inline constexpr void operator()(T& a, T b) const { static_assert(is_same::value || is_same::value || - is_same::value, + is_same::value || is_same::value, "The data type is not supported by the Mul accumulator!"); a = a * b; } + + __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const + { + float a_ = type_convert(a); + float b_ = type_convert(b); + + a = type_convert(a_ * b_); + } }; struct Max