diff --git a/composable_kernel/include/utility/config.hpp b/composable_kernel/include/utility/config.hpp index 2f540e1083..f4181b29d4 100644 --- a/composable_kernel/include/utility/config.hpp +++ b/composable_kernel/include/utility/config.hpp @@ -5,7 +5,6 @@ #include "hip/hip_runtime.h" #include "hip/hip_fp16.h" #endif -#include "bfloat16_dev.hpp" // "Constant" address space for kernel parameter #define CONSTANT __attribute__((address_space(4))) diff --git a/composable_kernel/include/utility/data_type.hpp b/composable_kernel/include/utility/data_type.hpp index cc5ee0de0e..96157bd19d 100644 --- a/composable_kernel/include/utility/data_type.hpp +++ b/composable_kernel/include/utility/data_type.hpp @@ -927,6 +927,58 @@ using int8x16_t = typename vector_type::type; using int8x32_t = typename vector_type::type; using int8x64_t = typename vector_type::type; +__host__ __device__ float bf16_to_f32(ushort src_val) +{ + union + { + uint32_t int32; + float fp32; + } u = {uint32_t(src_val) << 16}; + return u.fp32; +} + +__host__ __device__ ushort f32_to_bf16(float src_val) +{ + union + { + float fp32; + uint32_t int32; + } u = {src_val}; + if(~u.int32 & 0x7f800000) + { + // When the exponent bits are not all 1s, then the value is zero, normal, + // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus + // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). + // This causes the bfloat16's mantissa to be incremented by 1 if the 16 + // least significant bits of the float mantissa are greater than 0x8000, + // or if they are equal to 0x8000 and the least significant bit of the + // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when + // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already + // has the value 0x7f, then incrementing it causes it to become 0x00 and + // the exponent is incremented by one, which is the next higher FP value + // to the unrounded bfloat16 value. 
When the bfloat16 value is subnormal + // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up + // to a normal value with an exponent of 0x01 and a mantissa of 0x00. + // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, + // incrementing it causes it to become an exponent of 0xFF and a mantissa + // of 0x00, which is Inf, the next higher value to the unrounded value. + u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even + } + else if(u.int32 & 0xffff) + { + // When all of the exponent bits are 1, the value is Inf or NaN. + // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero + // mantissa bit. Quiet NaN is indicated by the most significant mantissa + // bit being 1. Signaling NaN is indicated by the most significant + // mantissa bit being 0 but some other bit(s) being 1. If any of the + // lower 16 bits of the mantissa are 1, we set the least significant bit + // of the bfloat16 mantissa, in order to preserve signaling NaN in case + // the bfloat16's mantissa bits are all 0. 
+ u.int32 |= 0x10000; // Preserve signaling NaN + } + return uint16_t(u.int32 >> 16); +} + // data type conversion template struct type_convert @@ -942,14 +994,14 @@ template <> template <> __device__ float type_convert::operator()(ushort x) const { - return bfloat16_to_float(x); + return bf16_to_f32(x); } template <> template <> __device__ ushort type_convert::operator()(float x) const { - return float_to_bfloat16(x); + return f32_to_bf16(x); } // TODO: deprecate this diff --git a/composable_kernel/include/utility/inner_product.hpp b/composable_kernel/include/utility/inner_product.hpp index 51753accf3..813b559474 100644 --- a/composable_kernel/include/utility/inner_product.hpp +++ b/composable_kernel/include/utility/inner_product.hpp @@ -28,6 +28,12 @@ __device__ void inner_product(const float& a, const float& #endif } +template <> +__device__ void inner_product(const ushort& a, const ushort& b, float& c) +{ + c += bf16_to_f32(a) * bf16_to_f32(b); +} + template <> __device__ void inner_product(const float2_t& a, const float2_t& b, float& c) diff --git a/external/rocm/include/bfloat16_dev.hpp b/external/rocm/include/bfloat16_dev.hpp deleted file mode 100644 index 304d8406a8..0000000000 --- a/external/rocm/include/bfloat16_dev.hpp +++ /dev/null @@ -1,125 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2019 Advanced Micro Devices, Inc. 
- * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef BFLOAT16_DEVICE_HPP -#define BFLOAT16_DEVICE_HPP - -#ifdef __cplusplus -extern "C" { -#endif - -#ifdef __HIP_PLATFORM_HCC__ -#define EXECUTION_SPECIFIER __device__ __host__ -#else -#define EXECUTION_SPECIFIER -#endif // MIOPEN_BACKEND_HIP - -typedef union -{ - uint u32; - ushort2 ushortx2; - -// Composable kernels are written in HIP language. The language doesnt support -// ushort2.hi or ushort2.low. 
-#ifdef __HIP_PLATFORM_HCC__ - ushort ushortvec[2]; -#endif // MIOPEN_BACKEND_HIP - float f32; -} cvt_bf16_fp32_t; - -EXECUTION_SPECIFIER float bfloat16_to_float(ushort src_val) -{ - cvt_bf16_fp32_t target_val; - -#ifdef __HIP_PLATFORM_HCC__ - target_val.ushortx2 = make_ushort2(0, src_val); -#else - target_val.ushortx2 = (ushort2)(0, src_val); -#endif - - return target_val.f32; -} - -EXECUTION_SPECIFIER ushort float_to_bfloat16(float src_val) -{ - cvt_bf16_fp32_t target_val; - target_val.f32 = src_val; - // BF16 round and NaN preservation code matches - // https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/include/rocblas_bfloat16.h - if((~target_val.u32 & 0x7f800000) == 0) // Inf or NaN - { - // When all of the exponent bits are 1, the value is Inf or NaN. - // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero - // mantissa bit. Quiet NaN is indicated by the most significant mantissa - // bit being 1. Signaling NaN is indicated by the most significant - // mantissa bit being 0 but some other bit(s) being 1. If any of the - // lower 16 bits of the mantissa are 1, we set the least significant bit - // of the bfloat16 mantissa, in order to preserve signaling NaN in case - // the bloat16's mantissa bits are all 0. - if((target_val.u32 & 0xffff) != 0) - { - target_val.u32 |= 0x10000; // Preserve signaling NaN - } - } - else - { -#ifdef MIOPEN_USE_RNE_BFLOAT16 -// When the exponent bits are not all 1s, then the value is zero, normal, -// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus -// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). -// This causes the bfloat16's mantissa to be incremented by 1 if the 16 -// least significant bits of the float mantissa are greater than 0x8000, -// or if they are equal to 0x8000 and the least significant bit of the -// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when -// the lower 16 bits are exactly 0x8000. 
If the bfloat16 mantissa already -// has the value 0x7f, then incrementing it causes it to become 0x00 and -// the exponent is incremented by one, which is the next higher FP value -// to the unrounded bfloat16 value. When the bfloat16 value is subnormal -// with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up -// to a normal value with an exponent of 0x01 and a mantissa of 0x00. -// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, -// incrementing it causes it to become an exponent of 0xFF and a mantissa -// of 0x00, which is Inf, the next higher value to the unrounded value. -#ifdef __HIP_PLATFORM_HCC__ - target_val.u32 += (0x7fff + (target_val.ushortvec[1] & 1)); -#else - target_val.u32 += - (0x7fff + (target_val.ushortx2.hi & 1)); // Round to nearest, round to even -#endif // MIOPEN_BACKEND_HIP -#endif // MIOPEN_USE_RNE_BFLOAT16 - } - -#ifdef __HIP_PLATFORM_HCC__ - return target_val.ushortvec[1]; -#else - return target_val.ushortx2.hi; -#endif // MIOPEN_BACKEND_HIP -} - -#ifdef __cplusplus -} -#endif - -#endif // BFLOAT16_DEVICE_HPP diff --git a/host/driver_offline/src/conv_fwd_driver_offline.cpp b/host/driver_offline/src/conv_fwd_driver_offline.cpp index e63f176d4b..d87195e366 100644 --- a/host/driver_offline/src/conv_fwd_driver_offline.cpp +++ b/host/driver_offline/src/conv_fwd_driver_offline.cpp @@ -82,8 +82,8 @@ void host_convolution_forward(const Tensor& in, { if constexpr(is_same::value) { - v += bfloat16_to_float(in(n, c, hi, wi)) * - bfloat16_to_float(wei(k, c, y, x)); + v += ck::bf16_to_f32(in(n, c, hi, wi)) * + ck::bf16_to_f32(wei(k, c, y, x)); } else { @@ -97,7 +97,7 @@ void host_convolution_forward(const Tensor& in, if constexpr(is_same::value) { - out(n, k, ho, wo) = float_to_bfloat16(v); + out(n, k, ho, wo) = f32_to_bf16(v); } else { @@ -120,8 +120,8 @@ void host_convolution_forward(const Tensor& in, { if constexpr(is_same::value) { - v += bfloat16_to_float(in(n, hi, wi, c)) * - bfloat16_to_float(wei(k, y, 
x, c)); + v += ck::bf16_to_f32(in(n, hi, wi, c)) * + ck::bf16_to_f32(wei(k, y, x, c)); } else { @@ -134,7 +134,7 @@ void host_convolution_forward(const Tensor& in, } if constexpr(is_same::value) { - out(n, ho, wo, k) = float_to_bfloat16(v); + out(n, ho, wo, k) = f32_to_bf16(v); } else { diff --git a/host/host_tensor/include/host_gemm.hpp b/host/host_tensor/include/host_gemm.hpp index 70f1c4dfa3..b5dbedd1d0 100644 --- a/host/host_tensor/include/host_gemm.hpp +++ b/host/host_tensor/include/host_gemm.hpp @@ -16,10 +16,10 @@ void host_gemm(const Tensor& a, for(int k = 0; k < K; ++k) { - v += bfloat16_to_float(a(m, k)) * bfloat16_to_float(b(k, n)); + v += ck::bf16_to_f32(a(m, k)) * ck::bf16_to_f32(b(k, n)); } - c(m, n) = float_to_bfloat16(v); + c(m, n) = ck::f32_to_bf16(v); }; make_ParallelTensorFunctor(f_mk_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( @@ -34,10 +34,10 @@ void host_gemm(const Tensor& a, for(int k = 0; k < K; ++k) { - v += bfloat16_to_float(a(m, k)) * bfloat16_to_float(b(n, k)); + v += ck::bf16_to_f32(a(m, k)) * ck::bf16_to_f32(b(n, k)); } - c(m, n) = float_to_bfloat16(v); + c(m, n) = ck::f32_to_bf16(v); }; make_ParallelTensorFunctor(f_mk_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( @@ -52,10 +52,10 @@ void host_gemm(const Tensor& a, for(int k = 0; k < K; ++k) { - v += bfloat16_to_float(a(k, m)) * bfloat16_to_float(b(k, n)); + v += ck::bf16_to_f32(a(k, m)) * ck::bf16_to_f32(b(k, n)); } - c(m, n) = float_to_bfloat16(v); + c(m, n) = ck::f32_to_bf16(v); }; make_ParallelTensorFunctor(f_km_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( @@ -70,10 +70,10 @@ void host_gemm(const Tensor& a, for(int k = 0; k < K; ++k) { - v += bfloat16_to_float(a(k, m)) * bfloat16_to_float(b(n, k)); + v += ck::bf16_to_f32(a(k, m)) * ck::bf16_to_f32(b(n, k)); } - c(m, n) = float_to_bfloat16(v); + c(m, n) = ck::f32_to_bf16(v); }; make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( @@ -88,10 +88,10 @@ void 
host_gemm(const Tensor& a, for(int k = 0; k < K; ++k) { - v += bfloat16_to_float(a(m, k)) * bfloat16_to_float(b(k, n)); + v += ck::bf16_to_f32(a(m, k)) * ck::bf16_to_f32(b(k, n)); } - c(n, m) = float_to_bfloat16(v); + c(n, m) = ck::f32_to_bf16(v); }; make_ParallelTensorFunctor(f_mk_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( @@ -106,10 +106,10 @@ void host_gemm(const Tensor& a, for(int k = 0; k < K; ++k) { - v += bfloat16_to_float(a(m, k)) * bfloat16_to_float(b(n, k)); + v += ck::bf16_to_f32(a(m, k)) * ck::bf16_to_f32(b(n, k)); } - c(n, m) = float_to_bfloat16(v); + c(n, m) = ck::f32_to_bf16(v); }; make_ParallelTensorFunctor(f_mk_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( @@ -124,10 +124,10 @@ void host_gemm(const Tensor& a, for(int k = 0; k < K; ++k) { - v += bfloat16_to_float(a(k, m)) * bfloat16_to_float(b(k, n)); + v += ck::bf16_to_f32(a(k, m)) * ck::bf16_to_f32(b(k, n)); } - c(n, m) = float_to_bfloat16(v); + c(n, m) = ck::f32_to_bf16(v); }; make_ParallelTensorFunctor(f_km_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( @@ -142,10 +142,10 @@ void host_gemm(const Tensor& a, for(int k = 0; k < K; ++k) { - v += bfloat16_to_float(a(k, m)) * bfloat16_to_float(b(n, k)); + v += ck::bf16_to_f32(a(k, m)) * ck::bf16_to_f32(b(n, k)); } - c(n, m) = float_to_bfloat16(v); + c(n, m) = ck::f32_to_bf16(v); }; make_ParallelTensorFunctor(f_km_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( diff --git a/host/host_tensor/include/host_tensor.hpp b/host/host_tensor/include/host_tensor.hpp index 853261103c..352ccccde0 100644 --- a/host/host_tensor/include/host_tensor.hpp +++ b/host/host_tensor/include/host_tensor.hpp @@ -321,18 +321,14 @@ void check_error(const Tensor& ref, const Tensor& result) std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl; } -float bf16_to_f32(ushort src_val) +__host__ __device__ float bf16_to_f32(ushort src_val) { - typedef union + union { - ushort x, y; - float f32; - 
} bf16_f32_t; - - bf16_f32_t v; - v.x = 0; - v.y = src_val; - return v.f32; + uint32_t int32; + float fp32; + } u = {uint32_t(src_val) << 16}; + return u.fp32; } template <> @@ -354,8 +350,7 @@ void check_error(const Tensor& ref, const Tensor& result } std::cout << "error: " << error << std::endl; - std::cout << "max_diff: " << max_diff << ", ref: " << ref_value << ", res: " << result_value - << std::endl; + std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl; } #endif diff --git a/host/host_tensor/include/host_tensor_generator.hpp b/host/host_tensor/include/host_tensor_generator.hpp index c7b3fb0fb7..7734b7134b 100644 --- a/host/host_tensor/include/host_tensor_generator.hpp +++ b/host/host_tensor/include/host_tensor_generator.hpp @@ -3,6 +3,7 @@ #include #include "config.hpp" +#include "data_type.hpp" template struct GeneratorTensor_1 @@ -24,7 +25,7 @@ struct GeneratorTensor_1 template ushort operator()(Is...) { - return float_to_bfloat16(value); + return ck::f32_to_bf16(value); } }; @@ -74,7 +75,7 @@ struct GeneratorTensor_2 ushort operator()(Is...) { float tmp = (std::rand() % (max_value - min_value)) + min_value; - return float_to_bfloat16(tmp); + return ck::f32_to_bf16(tmp); } }; @@ -119,7 +120,7 @@ struct GeneratorTensor_3 float fp32_tmp = min_value + tmp * (max_value - min_value); - return float_to_bfloat16(fp32_tmp); + return ck::f32_to_bf16(fp32_tmp); } };