mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
updated bfloat16_to_float
This commit is contained in:
@@ -5,7 +5,6 @@
|
||||
#include "hip/hip_runtime.h"
|
||||
#include "hip/hip_fp16.h"
|
||||
#endif
|
||||
#include "bfloat16_dev.hpp"
|
||||
|
||||
// "Constant" address space for kernel parameter
|
||||
#define CONSTANT __attribute__((address_space(4)))
|
||||
|
||||
@@ -927,6 +927,58 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
|
||||
using int8x32_t = typename vector_type<int8_t, 32>::type;
|
||||
using int8x64_t = typename vector_type<int8_t, 64>::type;
|
||||
|
||||
// Widen a raw bfloat16 bit pattern to float.
//
// src_val carries the 16 bfloat16 bits. They are shifted into the upper half
// of a 32-bit word (the lower 16 bits become zero) and that word is read back
// as a float through the union; no arithmetic or rounding is performed.
//
// Declared inline because this non-template definition is pulled into every
// translation unit that includes it; without inline each of them would emit
// its own external definition of the host function.
inline __host__ __device__ float bf16_to_f32(ushort src_val)
{
    union
    {
        uint32_t int32;
        float fp32;
    } u = {uint32_t(src_val) << 16}; // initializes the first member, int32
    return u.fp32;
}
|
||||
|
||||
// Narrow a float to a raw bfloat16 bit pattern (the upper 16 bits of the
// float word), rounding to nearest with ties to even.
//
// Returns the 16 bfloat16 bits as a ushort. Inf is passed through unchanged;
// NaN inputs stay NaN (see the second branch).
//
// Declared inline because this non-template definition is pulled into every
// translation unit that includes it; without inline each of them would emit
// its own external definition of the host function.
inline __host__ __device__ ushort f32_to_bf16(float src_val)
{
    union
    {
        float fp32;
        uint32_t int32;
    } u = {src_val}; // initializes the first member, fp32

    if(~u.int32 & 0x7f800000)
    {
        // When the exponent bits are not all 1s, then the value is zero, normal,
        // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
        // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
        // This causes the bfloat16's mantissa to be incremented by 1 if the 16
        // least significant bits of the float mantissa are greater than 0x8000,
        // or if they are equal to 0x8000 and the least significant bit of the
        // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
        // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
        // has the value 0x7f, then incrementing it causes it to become 0x00 and
        // the exponent is incremented by one, which is the next higher FP value
        // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
        // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
        // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
        // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
        // incrementing it causes it to become an exponent of 0xFF and a mantissa
        // of 0x00, which is Inf, the next higher value to the unrounded value.
        u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even
    }
    else if(u.int32 & 0xffff)
    {
        // When all of the exponent bits are 1, the value is Inf or NaN.
        // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
        // mantissa bit. Quiet NaN is indicated by the most significant mantissa
        // bit being 1. Signaling NaN is indicated by the most significant
        // mantissa bit being 0 but some other bit(s) being 1. If any of the
        // lower 16 bits of the mantissa are 1, we set the least significant bit
        // of the bfloat16 mantissa, in order to preserve signaling NaN in case
        // the bfloat16's mantissa bits are all 0.
        u.int32 |= 0x10000; // Preserve signaling NaN
    }

    // The bfloat16 result is the upper half of the (possibly adjusted) word.
    return uint16_t(u.int32 >> 16);
}
|
||||
|
||||
// data type conversion
|
||||
template <typename T>
|
||||
struct type_convert
|
||||
@@ -942,14 +994,14 @@ template <>
|
||||
template <>
|
||||
__device__ float type_convert<float>::operator()<ushort>(ushort x) const
|
||||
{
|
||||
return bfloat16_to_float(x);
|
||||
return bf16_to_f32(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
template <>
|
||||
__device__ ushort type_convert<ushort>::operator()<float>(float x) const
|
||||
{
|
||||
return float_to_bfloat16(x);
|
||||
return f32_to_bf16(x);
|
||||
}
|
||||
|
||||
// TODO: deprecate this
|
||||
|
||||
@@ -28,6 +28,12 @@ __device__ void inner_product<float, float, float>(const float& a, const float&
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void inner_product<ushort, ushort, float>(const ushort& a, const ushort& b, float& c)
|
||||
{
|
||||
c += bf16_to_f32(a) * bf16_to_f32(b);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void
|
||||
inner_product<float2_t, float2_t, float>(const float2_t& a, const float2_t& b, float& c)
|
||||
|
||||
Reference in New Issue
Block a user