Merge commit '0f10e6d9218ce9d00a34a66572c0686dce1e45ea' into develop

This commit is contained in:
assistant-librarian[bot]
2025-09-29 11:12:04 +00:00
parent 2593ecf5b5
commit f9767142cf
16 changed files with 198 additions and 55 deletions

View File

@@ -117,12 +117,8 @@ using bf16_raw_t = uint16_t;
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_rtn_raw(float f)
{
union
{
float fp32;
uint32_t int32;
} u = {f};
if(~u.int32 & 0x7f800000)
uint32_t bits = bit_cast<uint32_t>(f);
if(~bits & 0x7f800000)
{
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
@@ -140,9 +136,9 @@ constexpr uint16_t float_to_bf16_rtn_raw(float f)
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even
bits += 0x7fff + ((bits >> 16) & 1); // Round to nearest, round to even
}
else if(u.int32 & 0xffff)
else if(bits & 0xffff)
{
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
@@ -152,9 +148,9 @@ constexpr uint16_t float_to_bf16_rtn_raw(float f)
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bloat16's mantissa bits are all 0.
u.int32 |= 0x10000; // Preserve signaling NaN
bits |= 0x10000; // Preserve signaling NaN
}
return uint16_t(u.int32 >> 16);
return uint16_t(bits >> 16);
}
CK_TILE_HOST
@@ -225,24 +221,16 @@ uint16_t float_to_bf16_rta_asm(float f)
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
{
union
{
float fp32;
uint32_t int32;
} u = {f};
return uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff));
uint32_t bits = bit_cast<uint32_t>(f);
return static_cast<uint16_t>(bits >> 16) | (!(~bits & 0x7f800000) && (bits & 0xffff));
}
// Fast truncate instead of rounding, RTZ
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_raw(float f)
{
union
{
float fp32;
uint32_t int32;
} u = {f};
return uint16_t(u.int32 >> 16);
uint32_t bits = bit_cast<uint32_t>(f);
return static_cast<uint16_t>(bits >> 16);
}
template <bf16_rounding_mode rounding>
@@ -287,7 +275,7 @@ template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant<rounding> = {})
{
#if defined(__gfx950__)
#if CK_TILE_USE_LLVM_BUILTIN_BF16
return static_cast<bfloat16_t>(f);
#else
return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));