[CK_TILE] Fix UB and corner cases in f32/f16 to/from f8 conversion (#2571)

* Add tests for host conversion from f32/f16 to f8

* Add tests for host conversion from f8 to f32/f16

* Fix UB and corner cases in f32/f16 to/from f8 conversion

* There is undefined behavior when very small values are converted to f8:
  bit shifts can be larger than the type width. Using unsigned long long does
  not help because exponent_diff >= 64 in such cases. As a result, values
  like 2.117582368e-22 are converted to non-zero f8 in host validation
  of FMHA tests, and test_f8 either crashes with a segfault in completely
  unrelated code (e.g. GTest internals) or produces non-deterministic results.
* Fix FNUZ conversion to return NaN for NaN inputs.
* Fix compilation error (due to uint8_t << 8) in OCP e5m2 to f16
  conversion.

* Replace some magic numbers with values from numeric_traits

* Build tests only on devices supporting the type

[ROCm/composable_kernel commit: 7b074249f4]
This commit is contained in:
Anton Gorenko
2025-07-31 10:54:17 +06:00
committed by GitHub
parent 64f8c28b42
commit 2ef590ab43
3 changed files with 663 additions and 50 deletions

View File

@@ -43,19 +43,19 @@ enum class fp8_interpretation
};
/*
* ______________FNUZ_________________ | ______________OCP________________
* ______________FNUZ_________________ | ______________OCP________________
* e4m3 e5m2 | e4m3 e5m2
* bias : 8 16 | 7 15
* inf : 1.0000.000 1.00000.00 | N/A s.11111.00
* inf : N/A N/A | N/A s.11111.00
* Nan : 1.0000.000 1.00000.00 | s.1111.111 s.11111.{01, 10, 11}
* zero : 0.0000.000 0.00000.00 | s.0000.000 s.00000.00
* Max(norm) : s.1111.111 (240) s.11111.11(57344) | s.1111.110(448) s.11110.11(57344)
* Max(snorm): s.0000.111 s.00000.11 | s.0000.111 s.00000.11
* 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05
* 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05
* Min(norm) : s.0001.000 s.00001.00 | s.0001.000 s.00001.00
* 2^-7(0.00078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05)
* 2^-7(0.0078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05)
* Min(snorm): s.0000.001 s.00000.01 | s.0000.001 s.00000.01
* 2^-10(0.00097656) 2^-17(7.629395e-06)| 2^-9(0.001953125) 2^-16(1.52588e-05)
* 2^-10(0.0009765625) 2^-17(7.62939e-06) | 2^-9(0.001953125) 2^-16(1.52588e-05)
*/
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
@@ -259,50 +259,50 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
// fp8/bf8 type exponent/mantissa layout
constexpr int DstT_exp = numeric_traits<DstT>::exp; // exponent width of the destination type
constexpr int DstT_mant = numeric_traits<DstT>::mant; // mantissa width of the destination type
constexpr int DstT_bias = numeric_traits<DstT>::bias;
constexpr bool is_fnuz =
(numeric_traits<DstT>::f8_interpret == fp8_interpretation::E4M3_FNUZ) ||
(numeric_traits<DstT>::f8_interpret == fp8_interpretation::E5M2_FNUZ);
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
constexpr int bias = numeric_traits<SrcT>::bias;
constexpr unsigned int fInf = numeric_traits<SrcT>::Inf;
constexpr unsigned int abs_mask = numeric_traits<SrcT>::abs_mask;
using SrcT_bitwise = typename numeric_traits<SrcT>::bitwise_type;
SrcT_bitwise src_bitwise = bit_cast<SrcT_bitwise>(src);
unsigned long long head, mantissa;
int exponent, bias;
unsigned int head, mantissa;
int exponent;
unsigned int sign;
unsigned long long fInf, abs_mask;
head = src_bitwise & numeric_traits<SrcT>::head_mask;
mantissa = src_bitwise & numeric_traits<SrcT>::mant_mask;
exponent = (head >> SrcT_mant) & numeric_traits<SrcT>::exp_mask;
sign = head >> (SrcT_exp + SrcT_mant);
bias = numeric_traits<SrcT>::bias;
fInf = numeric_traits<SrcT>::Inf;
abs_mask = numeric_traits<SrcT>::abs_mask;
unsigned int signed_inf = 0;
unsigned int nan = 0;
if constexpr(is_fnuz)
{
signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
signed_inf = clip ? ((sign << (DstT_exp + DstT_mant)) + 0x7f) : 0x80;
nan = 0x80;
}
else
{
if constexpr(DstT_exp == 4)
{ // e4m3
signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
signed_inf = (sign << (DstT_exp + DstT_mant)) + (clip ? 0x7e : 0x7f);
}
else
{ // e5m2
signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
signed_inf = (sign << (DstT_exp + DstT_mant)) + (clip ? 0x7b : 0x7c);
}
nan = (sign << 7) + 0x7f;
nan = (sign << (DstT_exp + DstT_mant)) + 0x7f;
}
// Max values
unsigned long long ifmax = 0;
unsigned int ifmax = 0;
if constexpr(is_float)
{
if constexpr(DstT_exp == 5)
@@ -343,9 +343,6 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
// Deal with inf and NaNs
if((src_bitwise & fInf) == fInf)
{
if constexpr(is_fnuz)
return signed_inf;
return mantissa != 0 ? nan : signed_inf;
}
@@ -354,11 +351,6 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
return signed_inf;
}
if(src_bitwise == 0)
{
return 0;
}
// First need to check if it is normal or denorm as there is a difference of
// implicit 1 Then need to adjust the exponent to align with the F8 exponent,
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
@@ -367,8 +359,7 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
// bits
const int f8_bias = (1 << (DstT_exp - 1)) - 1 + (is_fnuz ? 1 : 0);
const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
constexpr int f8_denormal_act_exponent = 1 - DstT_bias; // actual exponent of f8 denormal
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// f8_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
@@ -406,11 +397,16 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
// for this case, act_exponent could be larger. Just
// that it does not need shift mantissa
}
mantissa += (1ull << SrcT_mant); // Add the implicit 1 into mantissa
mantissa += (1u << SrcT_mant); // Add the implicit 1 into mantissa
}
bool midpoint = (mantissa & ((1ull << (SrcT_mant - DstT_mant + exponent_diff)) - 1)) ==
(1ull << (SrcT_mant - DstT_mant + exponent_diff - 1));
// The value is smaller than min f8 denormal and results in zero (the early exit also prevents
// an undefined behavior of bit shifts >= type width).
if(exponent_diff > DstT_mant)
{
return is_fnuz ? 0 : (sign << (DstT_exp + DstT_mant));
}
bool midpoint = (mantissa & ((1u << (SrcT_mant - DstT_mant + exponent_diff)) - 1)) ==
(1u << (SrcT_mant - DstT_mant + exponent_diff - 1));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be
done before we shift right as shift right could rip off some residual part and
make something not midpoint look like midpoint. For example, the fp16 number
@@ -422,31 +418,31 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
mantissa >>= exponent_diff;
else if(exponent_diff == -1)
mantissa <<= -exponent_diff;
bool implicit_one = mantissa & (1ull << SrcT_mant);
bool implicit_one = mantissa & (1u << SrcT_mant);
// if there is no implicit 1, it means the f8 is denormal and need to adjust
// to denorm exponent
f8_exponent =
(act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
(act_exponent + exponent_diff) /*actual f8 exponent*/ + DstT_bias - (implicit_one ? 0 : 1);
// Now we have the exponent and mantissa adjusted
unsigned long long drop_mask = (1ull << (SrcT_mant - DstT_mant)) - 1;
unsigned int drop_mask = (1u << (SrcT_mant - DstT_mant)) - 1;
bool odd =
mantissa & (1ull << (SrcT_mant -
DstT_mant)); // if the least significant bit that is not truncated is 1
mantissa &
(1u << (SrcT_mant - DstT_mant)); // if the least significant bit that is not truncated is 1
mantissa +=
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1u) : mantissa)) & drop_mask;
// Now we deal with overflow
if(f8_exponent == 0)
{
if((1ull << SrcT_mant) & mantissa)
if((1u << SrcT_mant) & mantissa)
{
f8_exponent = 1; // denormal overflow to become normal, promote exponent
}
}
else
{
if((1ull << (SrcT_mant + 1)) & mantissa)
if((1u << (SrcT_mant + 1)) & mantissa)
{
mantissa >>= 1;
f8_exponent++;
@@ -471,9 +467,9 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
}
if(f8_exponent == 0 && mantissa == 0)
return is_fnuz ? 0 : (sign << 7);
return is_fnuz ? 0 : (sign << (DstT_exp + DstT_mant));
mantissa &= (1 << DstT_mant) - 1;
return (sign << 7) | (f8_exponent << DstT_mant) | mantissa;
return (sign << (DstT_exp + DstT_mant)) | (f8_exponent << DstT_mant) | mantissa;
}
template <typename SrcT, typename DstT, bool clip = true>
@@ -481,8 +477,9 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
{
static_assert(std::is_same<SrcT, fp8_t>::value || std::is_same<SrcT, bf8_t>::value,
"SrcT type must be fp8 or bf8.");
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
constexpr uint8_t SrcT_abs_mask = numeric_traits<SrcT>::abs_mask;
constexpr bool is_fnuz =
(numeric_traits<SrcT>::f8_interpret == fp8_interpretation::E4M3_FNUZ) ||
(numeric_traits<SrcT>::f8_interpret == fp8_interpretation::E5M2_FNUZ);
@@ -518,9 +515,9 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
return 0;
}
unsigned long long sign = x >> 7;
unsigned long long mantissa = x & ((1 << SrcT_mant) - 1);
int exponent = (x & 0x7F) >> SrcT_mant;
unsigned int sign = x >> (SrcT_exp + SrcT_mant);
unsigned int mantissa = x & ((1 << SrcT_mant) - 1);
int exponent = (x & SrcT_abs_mask) >> SrcT_mant;
if constexpr(is_fnuz)
{
if((x & 0xff) == 0x80)
@@ -559,7 +556,7 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
if constexpr(SrcT_exp == 5 && is_half && !is_fnuz)
{
retval = x << 8;
retval = static_cast<typename numeric_traits<DstT>::bitwise_type>(x) << 8;
return bit_cast<DstT>(retval);
}

View File

@@ -1,5 +1,15 @@
# Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
add_gtest_executable(test_ck_tile_pk_int4 test_pk_int4.cpp)
add_gtest_executable(test_ck_tile_pk_fp4 test_pk_fp4.cpp)
endif()
if(GPU_TARGETS MATCHES "gfx95")
add_gtest_executable(test_ck_tile_pk_fp4 test_pk_fp4.cpp)
endif()
if(CK_USE_OCP_FP8 OR CK_USE_FNUZ_FP8)
add_gtest_executable(test_ck_tile_fp8 test_fp8.cpp)
target_compile_options(test_ck_tile_fp8 PRIVATE -Wno-float-equal)
# conditionally specify the use of OCP_FP8
if(CK_USE_OCP_FP8)
target_compile_options(test_ck_tile_fp8 PRIVATE -DCK_TILE_USE_OCP_FP8)
endif()
endif()

View File

@@ -0,0 +1,606 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck_tile/core.hpp"
// Stateless typed-test fixture. T is the type parameter supplied by the typed
// test suite and is exposed to each test body as TypeParam.
template <typename T>
struct ConvertTest : ::testing::Test
{
};
// Types the ConvertTest typed suite is instantiated with: 32-bit float and
// ck_tile's 16-bit float.
using TestTypes = ::testing::Types<float, ck_tile::fp16_t>;
TYPED_TEST_SUITE(ConvertTest, TestTypes);
// Host conversion f32/f16 -> fp8_t. Expected results are raw f8 bit patterns
// written as 0b<sign>'<4 exponent bits>'<3 mantissa bits>.
TYPED_TEST(ConvertTest, ToFp8)
{
    using SrcT = TypeParam;
    using DstT = ck_tile::fp8_t;

    // Saturating conversion, result returned as raw bits.
    const auto sat = [](SrcT value) {
        const DstT f8 = ck_tile::impl::run_cast_to_f8<SrcT, DstT, true>(value);
        return static_cast<unsigned int>(ck_tile::bit_cast<uint8_t>(f8));
    };
    // Non-saturating conversion, result returned as raw bits.
    const auto nosat = [](SrcT value) {
        const DstT f8 = ck_tile::impl::run_cast_to_f8<SrcT, DstT, false>(value);
        return static_cast<unsigned int>(ck_tile::bit_cast<uint8_t>(f8));
    };

#if CK_TILE_USE_OCP_FP8
    EXPECT_EQ(sat(+1.0f), 0b0'0111'000);
    EXPECT_EQ(sat(-1.0f), 0b1'0111'000);

    // largest normal f8
    EXPECT_EQ(sat(+448.0f), 0b0'1111'110);
    EXPECT_EQ(sat(-448.0f), 0b1'1111'110);
    // smallest normal f8
    EXPECT_EQ(sat(+0.015625f), 0b0'0001'000);
    EXPECT_EQ(sat(-0.015625f), 0b1'0001'000);
    // largest subnormal f8
    EXPECT_EQ(sat(+0.013671875f), 0b0'0000'111);
    EXPECT_EQ(sat(-0.013671875f), 0b1'0000'111);
    // smallest subnormal f8
    EXPECT_EQ(sat(+0.001953125f), 0b0'0000'001);
    EXPECT_EQ(sat(-0.001953125f), 0b1'0000'001);

    // exactly representable values
    EXPECT_EQ(sat(+0.203125f), 0b0'0100'101);
    EXPECT_EQ(sat(-88.0f), 0b1'1101'011);
    // values that need rounding
    EXPECT_EQ(sat(+432.919f), 0b0'1111'110);
    EXPECT_EQ(sat(-431.111f), 0b1'1111'101);
    EXPECT_EQ(sat(-0.76123f), 0b1'0110'100);
    EXPECT_EQ(sat(+0.81234f), 0b0'0110'101);
    // exact midpoints: ties go to the even mantissa
    EXPECT_EQ(sat(+58.0f), 0b0'1100'110);
    EXPECT_EQ(sat(+62.0f), 0b0'1101'000);

    // Saturating mode clamps to the largest normal f8:
    // ... for the largest normal f32/f16
    EXPECT_EQ(sat(+ck_tile::numeric<SrcT>::max()), 0b0'1111'110);
    EXPECT_EQ(sat(-ck_tile::numeric<SrcT>::max()), 0b1'1111'110);
    // ... for f32/f16 infinity
    EXPECT_EQ(sat(+ck_tile::numeric<SrcT>::infinity()), 0b0'1111'110);
    EXPECT_EQ(sat(-ck_tile::numeric<SrcT>::infinity()), 0b1'1111'110);
    // ... for a large finite f32/f16
    EXPECT_EQ(sat(+1.23e9f), 0b0'1111'110);
    EXPECT_EQ(sat(-1.23e9f), 0b1'1111'110);

    constexpr unsigned int nan_mask = 0b0'1111'111;
    // Non-saturating mode yields f8 NaN (OCP e4m3 has no infinity):
    // ... for the largest normal f32/f16
    EXPECT_EQ(nosat(+ck_tile::numeric<SrcT>::max()) & nan_mask, nan_mask);
    EXPECT_EQ(nosat(-ck_tile::numeric<SrcT>::max()) & nan_mask, nan_mask);
    // ... for f32/f16 infinity
    EXPECT_EQ(nosat(+ck_tile::numeric<SrcT>::infinity()) & nan_mask, nan_mask);
    EXPECT_EQ(nosat(-ck_tile::numeric<SrcT>::infinity()) & nan_mask, nan_mask);
    // ... for a large finite f32/f16
    EXPECT_EQ(nosat(+1.23e9f) & nan_mask, nan_mask);
    EXPECT_EQ(nosat(-1.23e9f) & nan_mask, nan_mask);

    // NaN in -> NaN out
    EXPECT_EQ(sat(ck_tile::numeric<SrcT>::quiet_NaN()) & nan_mask, nan_mask);
    EXPECT_EQ(sat(ck_tile::numeric<SrcT>::signaling_NaN()) & nan_mask, nan_mask);

    // signed zero in -> signed zero out
    EXPECT_EQ(sat(+0.0f), 0b0'0000'000);
    EXPECT_EQ(sat(-0.0f), 0b1'0000'000);
    // smallest normal f32/f16 underflows to f8 zero
    EXPECT_EQ(sat(+ck_tile::numeric<SrcT>::min()), 0b0'0000'000);
    EXPECT_EQ(sat(-ck_tile::numeric<SrcT>::min()), 0b1'0000'000);
    // smallest subnormal f32/f16 underflows to f8 zero
    EXPECT_EQ(sat(+ck_tile::numeric<SrcT>::denorm_min()), 0b0'0000'000);
    EXPECT_EQ(sat(-ck_tile::numeric<SrcT>::denorm_min()), 0b1'0000'000);

    // Sweep every power of two from the smallest source subnormal up to 1:
    // anything below the smallest f8 subnormal must become (signed) zero,
    // everything else must stay non-zero.
    constexpr int src_min_subnorm_exp =
        -(ck_tile::numeric_traits<SrcT>::bias + ck_tile::numeric_traits<SrcT>::mant - 1);
    constexpr int dst_min_subnorm_exp =
        -(ck_tile::numeric_traits<DstT>::bias + ck_tile::numeric_traits<DstT>::mant - 1);
    for(int e = src_min_subnorm_exp; e <= 0; ++e)
    {
        const float pow2 = std::ldexp(1.0, e);
        if(e >= dst_min_subnorm_exp)
        {
            EXPECT_GT(sat(+pow2), 0b0'0000'000) << "+f = 2^" << e << " = " << +pow2;
            EXPECT_GT(sat(-pow2), 0b1'0000'000) << "-f = 2^" << e << " = " << -pow2;
        }
        else
        {
            EXPECT_EQ(sat(+pow2), 0b0'0000'000) << "+f = 2^" << e << " = " << +pow2;
            EXPECT_EQ(sat(-pow2), 0b1'0000'000) << "-f = 2^" << e << " = " << -pow2;
        }
    }
#else // FNUZ
    EXPECT_EQ(sat(+1.0f), 0b0'1000'000);
    EXPECT_EQ(sat(-1.0f), 0b1'1000'000);

    // largest normal f8
    EXPECT_EQ(sat(+240.0f), 0b0'1111'111);
    EXPECT_EQ(sat(-240.0f), 0b1'1111'111);
    // smallest normal f8
    EXPECT_EQ(sat(+0.0078125f), 0b0'0001'000);
    EXPECT_EQ(sat(-0.0078125f), 0b1'0001'000);
    // largest subnormal f8
    EXPECT_EQ(sat(+0.0068359375f), 0b0'0000'111);
    EXPECT_EQ(sat(-0.0068359375f), 0b1'0000'111);
    // smallest subnormal f8
    EXPECT_EQ(sat(+0.0009765625f), 0b0'0000'001);
    EXPECT_EQ(sat(-0.0009765625f), 0b1'0000'001);

    // exactly representable values
    EXPECT_EQ(sat(+0.1015625f), 0b0'0100'101);
    EXPECT_EQ(sat(-44.0f), 0b1'1101'011);
    // values that need rounding
    EXPECT_EQ(sat(+219.91f), 0b0'1111'110);
    EXPECT_EQ(sat(-203.11f), 0b1'1111'101);
    EXPECT_EQ(sat(-0.3639f), 0b1'0110'100);
    EXPECT_EQ(sat(+0.4139f), 0b0'0110'101);

    // Saturating mode clamps to the largest normal f8:
    // ... for the largest normal f32/f16
    EXPECT_EQ(sat(+ck_tile::numeric<SrcT>::max()), 0b0'1111'111);
    EXPECT_EQ(sat(-ck_tile::numeric<SrcT>::max()), 0b1'1111'111);
    // ... for f32/f16 infinity
    EXPECT_EQ(sat(+ck_tile::numeric<SrcT>::infinity()), 0b0'1111'111);
    EXPECT_EQ(sat(-ck_tile::numeric<SrcT>::infinity()), 0b1'1111'111);
    // ... for a large finite f32/f16
    EXPECT_EQ(sat(+1.23e9f), 0b0'1111'111);
    EXPECT_EQ(sat(-1.23e9f), 0b1'1111'111);

    constexpr unsigned int nan_value = 0b1'0000'000;
    // Non-saturating mode yields f8 NaN (FN means "finite", so no infinity):
    // ... for the largest normal f32/f16
    EXPECT_EQ(nosat(+ck_tile::numeric<SrcT>::max()), nan_value);
    EXPECT_EQ(nosat(-ck_tile::numeric<SrcT>::max()), nan_value);
    // ... for f32/f16 infinity
    EXPECT_EQ(nosat(+ck_tile::numeric<SrcT>::infinity()), nan_value);
    EXPECT_EQ(nosat(-ck_tile::numeric<SrcT>::infinity()), nan_value);
    // ... for a large finite f32/f16
    EXPECT_EQ(nosat(+1.23e9f), nan_value);
    EXPECT_EQ(nosat(-1.23e9f), nan_value);

    // NaN in -> NaN out
    EXPECT_EQ(sat(ck_tile::numeric<SrcT>::quiet_NaN()), nan_value);
    EXPECT_EQ(sat(ck_tile::numeric<SrcT>::signaling_NaN()), nan_value);

    // UZ means "unsigned zero" (0b1'0000'000 is NaN), so both source zeros
    // map to the single f8 +zero
    EXPECT_EQ(sat(+0.0f), 0b0'0000'000);
    EXPECT_EQ(sat(-0.0f), 0b0'0000'000);
    // smallest normal f32/f16 underflows to f8 +zero
    EXPECT_EQ(sat(+ck_tile::numeric<SrcT>::min()), 0b0'0000'000);
    EXPECT_EQ(sat(-ck_tile::numeric<SrcT>::min()), 0b0'0000'000);
    // smallest subnormal f32/f16 underflows to f8 +zero
    EXPECT_EQ(sat(+ck_tile::numeric<SrcT>::denorm_min()), 0b0'0000'000);
    EXPECT_EQ(sat(-ck_tile::numeric<SrcT>::denorm_min()), 0b0'0000'000);

    // Sweep every power of two from the smallest source subnormal up to 1:
    // anything below the smallest f8 subnormal must become +zero,
    // everything else must stay non-zero.
    constexpr int src_min_subnorm_exp =
        -(ck_tile::numeric_traits<SrcT>::bias + ck_tile::numeric_traits<SrcT>::mant - 1);
    constexpr int dst_min_subnorm_exp =
        -(ck_tile::numeric_traits<DstT>::bias + ck_tile::numeric_traits<DstT>::mant - 1);
    for(int e = src_min_subnorm_exp; e <= 0; ++e)
    {
        const float pow2 = std::ldexp(1.0, e);
        if(e >= dst_min_subnorm_exp)
        {
            EXPECT_GT(sat(+pow2), 0b0'0000'000) << "+f = 2^" << e << " = " << +pow2;
            EXPECT_GT(sat(-pow2), 0b0'0000'000) << "-f = 2^" << e << " = " << -pow2;
        }
        else
        {
            EXPECT_EQ(sat(+pow2), 0b0'0000'000) << "+f = 2^" << e << " = " << +pow2;
            EXPECT_EQ(sat(-pow2), 0b0'0000'000) << "-f = 2^" << e << " = " << -pow2;
        }
    }
#endif
}
// Host conversion f32/f16 -> bf8_t. Expected results are raw f8 bit patterns
// written as 0b<sign>'<5 exponent bits>'<2 mantissa bits>.
TYPED_TEST(ConvertTest, ToBf8)
{
    using SrcT = TypeParam;
    using DstT = ck_tile::bf8_t;

    // Saturating conversion, result returned as raw bits.
    auto c = [](SrcT f) {
        return static_cast<unsigned int>(
            ck_tile::bit_cast<uint8_t>(ck_tile::impl::run_cast_to_f8<SrcT, DstT, true>(f)));
    };
    // Non-saturating conversion, result returned as raw bits.
    auto c_nosat = [](SrcT f) {
        return static_cast<unsigned int>(
            ck_tile::bit_cast<uint8_t>(ck_tile::impl::run_cast_to_f8<SrcT, DstT, false>(f)));
    };

#if CK_TILE_USE_OCP_FP8
    EXPECT_EQ(c(+1.0f), 0b0'01111'00);
    EXPECT_EQ(c(-1.0f), 0b1'01111'00);

    // max f8 normal
    EXPECT_EQ(c(+57344.0f), 0b0'11110'11);
    EXPECT_EQ(c(-57344.0f), 0b1'11110'11);
    // min f8 normal
    EXPECT_EQ(c(+6.103515625e-05f), 0b0'00001'00);
    EXPECT_EQ(c(-6.103515625e-05f), 0b1'00001'00);
    // max f8 subnormal
    EXPECT_EQ(c(+4.57763671875e-05f), 0b0'00000'11);
    EXPECT_EQ(c(-4.57763671875e-05f), 0b1'00000'11);
    // min f8 subnormal
    EXPECT_EQ(c(+1.52587890625e-05f), 0b0'00000'01);
    EXPECT_EQ(c(-1.52587890625e-05f), 0b1'00000'01);

    // arbitrary values (exact)
    EXPECT_EQ(c(+0.01953125f), 0b0'01001'01);
    EXPECT_EQ(c(-3584.0f), 0b1'11010'11);
    // arbitrary values (rounded)
    EXPECT_EQ(c(+2030.56f), 0b0'11010'00);
    EXPECT_EQ(c(-1801.33f), 0b1'11001'11);
    EXPECT_EQ(c(-0.27891f), 0b1'01101'00);
    EXPECT_EQ(c(+0.33333f), 0b0'01101'01);

    // saturating mode -> max f8 normal
    // max f32/f16 normal -> max f8 normal
    EXPECT_EQ(c(+ck_tile::numeric<SrcT>::max()), 0b0'11110'11);
    EXPECT_EQ(c(-ck_tile::numeric<SrcT>::max()), 0b1'11110'11);
    // f32/f16 infinity -> max f8 normal
    EXPECT_EQ(c(+ck_tile::numeric<SrcT>::infinity()), 0b0'11110'11);
    EXPECT_EQ(c(-ck_tile::numeric<SrcT>::infinity()), 0b1'11110'11);
    // large f32/f16 -> max f8 normal
    EXPECT_EQ(c(+1.23e9f), 0b0'11110'11);
    EXPECT_EQ(c(-1.23e9f), 0b1'11110'11);

    // non-saturating mode -> f8 infinity
    // max f32/f16 normal -> f8 infinity
    EXPECT_EQ(c_nosat(+ck_tile::numeric<SrcT>::max()), 0b0'11111'00);
    EXPECT_EQ(c_nosat(-ck_tile::numeric<SrcT>::max()), 0b1'11111'00);
    // f32/f16 infinity -> f8 infinity
    EXPECT_EQ(c_nosat(+ck_tile::numeric<SrcT>::infinity()), 0b0'11111'00);
    EXPECT_EQ(c_nosat(-ck_tile::numeric<SrcT>::infinity()), 0b1'11111'00);
    // large f32/f16 -> f8 infinity
    EXPECT_EQ(c_nosat(+1.23e9f), 0b0'11111'00);
    EXPECT_EQ(c_nosat(-1.23e9f), 0b1'11111'00);

    // f32/f16 NaN -> f8 NaN
    // With the sign bit masked off, an e5m2 NaN is any pattern strictly above
    // the infinity encoding (all-ones exponent, non-zero mantissa). Comparing
    // with `!=` would also accept zero or any finite value.
    EXPECT_TRUE((c(ck_tile::numeric<SrcT>::quiet_NaN()) & 0b0'11111'11) > 0b0'11111'00);
    EXPECT_TRUE((c(ck_tile::numeric<SrcT>::signaling_NaN()) & 0b0'11111'11) > 0b0'11111'00);

    // f32/f16 zero -> f8 zero
    EXPECT_EQ(c(+0.0f), 0b0'00000'00);
    EXPECT_EQ(c(-0.0f), 0b1'00000'00);
    if constexpr(std::is_same_v<SrcT, float>)
    {
        // min f32 normal -> f8 zero
        EXPECT_EQ(c(+ck_tile::numeric<SrcT>::min()), 0b0'00000'00);
        EXPECT_EQ(c(-ck_tile::numeric<SrcT>::min()), 0b1'00000'00);
    }
    else
    {
        // min f16 normal -> min f8 normal (they are equal)
        EXPECT_EQ(c(+ck_tile::numeric<SrcT>::min()), 0b0'00001'00);
        EXPECT_EQ(c(-ck_tile::numeric<SrcT>::min()), 0b1'00001'00);
    }
    // min f32/f16 subnormal -> f8 zero
    EXPECT_EQ(c(+ck_tile::numeric<SrcT>::denorm_min()), 0b0'00000'00);
    EXPECT_EQ(c(-ck_tile::numeric<SrcT>::denorm_min()), 0b1'00000'00);

    // All values smaller than min f8 subnormal must be converted to f8 zero
    constexpr int src_min_subnorm_exp =
        -(ck_tile::numeric_traits<SrcT>::bias + ck_tile::numeric_traits<SrcT>::mant - 1);
    constexpr int dst_min_subnorm_exp =
        -(ck_tile::numeric_traits<DstT>::bias + ck_tile::numeric_traits<DstT>::mant - 1);
    for(int exp = src_min_subnorm_exp; exp <= 0; ++exp)
    {
        const float f = std::ldexp(1.0, exp);
        if(exp < dst_min_subnorm_exp)
        {
            EXPECT_EQ(c(+f), 0b0'00000'00) << "+f = 2^" << exp << " = " << +f;
            EXPECT_EQ(c(-f), 0b1'00000'00) << "-f = 2^" << exp << " = " << -f;
        }
        else
        {
            EXPECT_GT(c(+f), 0b0'00000'00) << "+f = 2^" << exp << " = " << +f;
            EXPECT_GT(c(-f), 0b1'00000'00) << "-f = 2^" << exp << " = " << -f;
        }
    }
#else // FNUZ
    EXPECT_EQ(c(+1.0f), 0b0'10000'00);
    EXPECT_EQ(c(-1.0f), 0b1'10000'00);

    // max f8 normal
    EXPECT_EQ(c(+57344.0f), 0b0'11111'11);
    EXPECT_EQ(c(-57344.0f), 0b1'11111'11);
    // min f8 normal
    EXPECT_EQ(c(+3.0517578125e-05f), 0b0'00001'00);
    EXPECT_EQ(c(-3.0517578125e-05f), 0b1'00001'00);
    // max f8 subnormal
    EXPECT_EQ(c(+2.288818359375e-05f), 0b0'00000'11);
    EXPECT_EQ(c(-2.288818359375e-05f), 0b1'00000'11);
    // min f8 subnormal
    EXPECT_EQ(c(+7.62939453125e-06f), 0b0'00000'01);
    EXPECT_EQ(c(-7.62939453125e-06f), 0b1'00000'01);

    // arbitrary values (exact)
    EXPECT_EQ(c(+0.009765625f), 0b0'01001'01);
    EXPECT_EQ(c(-1792.0f), 0b1'11010'11);
    // arbitrary values (rounded)
    EXPECT_EQ(c(+840.100f), 0b0'11001'11);
    EXPECT_EQ(c(-999.999f), 0b1'11010'00);
    EXPECT_EQ(c(-0.12789f), 0b1'01101'00);
    EXPECT_EQ(c(+0.14444f), 0b0'01101'01);

    // saturating mode -> max f8 normal
    // max f32/f16 normal -> max f8 normal
    EXPECT_EQ(c(+ck_tile::numeric<SrcT>::max()), 0b0'11111'11);
    EXPECT_EQ(c(-ck_tile::numeric<SrcT>::max()), 0b1'11111'11);
    // f32/f16 infinity -> max f8 normal
    EXPECT_EQ(c(+ck_tile::numeric<SrcT>::infinity()), 0b0'11111'11);
    EXPECT_EQ(c(-ck_tile::numeric<SrcT>::infinity()), 0b1'11111'11);
    // large f32/f16 -> max f8 normal
    EXPECT_EQ(c(+1.23e9f), 0b0'11111'11);
    EXPECT_EQ(c(-1.23e9f), 0b1'11111'11);

    constexpr unsigned int nan_value = 0b1'00000'00;
    // non-saturating mode -> f8 NaN (FN means "finite", so no infinity)
    // max f32/f16 normal -> f8 NaN
    EXPECT_EQ(c_nosat(+ck_tile::numeric<SrcT>::max()), nan_value);
    EXPECT_EQ(c_nosat(-ck_tile::numeric<SrcT>::max()), nan_value);
    // f32/f16 infinity -> f8 NaN
    EXPECT_EQ(c_nosat(+ck_tile::numeric<SrcT>::infinity()), nan_value);
    EXPECT_EQ(c_nosat(-ck_tile::numeric<SrcT>::infinity()), nan_value);
    // large f32/f16 -> f8 NaN
    EXPECT_EQ(c_nosat(+1.23e9f), nan_value);
    EXPECT_EQ(c_nosat(-1.23e9f), nan_value);

    // f32/f16 NaN -> f8 NaN
    EXPECT_EQ(c(ck_tile::numeric<SrcT>::quiet_NaN()), nan_value);
    EXPECT_EQ(c(ck_tile::numeric<SrcT>::signaling_NaN()), nan_value);

    // UZ means "unsigned zero" (0b1'00000'00 is NaN)
    // f32/f16 +-zero -> f8 +zero
    EXPECT_EQ(c(+0.0f), 0b0'00000'00);
    EXPECT_EQ(c(-0.0f), 0b0'00000'00);
    if constexpr(std::is_same_v<SrcT, float>)
    {
        // min f32 normal -> f8 +zero
        EXPECT_EQ(c(+ck_tile::numeric<SrcT>::min()), 0b0'00000'00);
        EXPECT_EQ(c(-ck_tile::numeric<SrcT>::min()), 0b0'00000'00);
    }
    else
    {
        // min f16 normal -> f8 normal
        EXPECT_EQ(c(+ck_tile::numeric<SrcT>::min()), 0b0'00010'00);
        EXPECT_EQ(c(-ck_tile::numeric<SrcT>::min()), 0b1'00010'00);
    }
    // min f32/f16 subnormal -> f8 +zero
    EXPECT_EQ(c(+ck_tile::numeric<SrcT>::denorm_min()), 0b0'00000'00);
    EXPECT_EQ(c(-ck_tile::numeric<SrcT>::denorm_min()), 0b0'00000'00);

    // All values smaller than min f8 subnormal must be converted to f8 zero
    constexpr int src_min_subnorm_exp =
        -(ck_tile::numeric_traits<SrcT>::bias + ck_tile::numeric_traits<SrcT>::mant - 1);
    constexpr int dst_min_subnorm_exp =
        -(ck_tile::numeric_traits<DstT>::bias + ck_tile::numeric_traits<DstT>::mant - 1);
    for(int exp = src_min_subnorm_exp; exp <= 0; ++exp)
    {
        const float f = std::ldexp(1.0, exp);
        if(exp < dst_min_subnorm_exp)
        {
            EXPECT_EQ(c(+f), 0b0'00000'00) << "+f = 2^" << exp << " = " << +f;
            EXPECT_EQ(c(-f), 0b0'00000'00) << "-f = 2^" << exp << " = " << -f;
        }
        else
        {
            EXPECT_GT(c(+f), 0b0'00000'00) << "+f = 2^" << exp << " = " << +f;
            EXPECT_GT(c(-f), 0b0'00000'00) << "-f = 2^" << exp << " = " << -f;
        }
    }
#endif
}
// Host conversion fp8_t -> f32/f16. Inputs are raw f8 bit patterns written as
// 0b<sign>'<4 exponent bits>'<3 mantissa bits>; results are compared as float.
TYPED_TEST(ConvertTest, FromFp8)
{
    using SrcT = ck_tile::fp8_t;
    using DstT = TypeParam;

    // Reinterpret the byte as f8, convert it to DstT, then widen to float so
    // the result can be compared against float literals.
    const auto decode = [](uint8_t bits) {
        const DstT wide =
            ck_tile::impl::run_cast_from_f8<SrcT, DstT, true>(ck_tile::bit_cast<SrcT>(bits));
        return ck_tile::type_convert<float>(wide);
    };

#if CK_TILE_USE_OCP_FP8
    EXPECT_EQ(decode(0b0'0111'000), +1.0f);
    EXPECT_EQ(decode(0b1'0111'000), -1.0f);

    // largest normal f8
    EXPECT_EQ(decode(0b0'1111'110), +448.0f);
    EXPECT_EQ(decode(0b1'1111'110), -448.0f);
    // smallest normal f8
    EXPECT_EQ(decode(0b0'0001'000), +0.015625f);
    EXPECT_EQ(decode(0b1'0001'000), -0.015625f);
    // largest subnormal f8
    EXPECT_EQ(decode(0b0'0000'111), +0.013671875f);
    EXPECT_EQ(decode(0b1'0000'111), -0.013671875f);
    // smallest subnormal f8
    EXPECT_EQ(decode(0b0'0000'001), +0.001953125f);
    EXPECT_EQ(decode(0b1'0000'001), -0.001953125f);

    // a couple of mid-range values
    EXPECT_EQ(decode(0b0'0100'101), +0.203125f);
    EXPECT_EQ(decode(0b1'1101'011), -88.0f);

    // NaN in -> NaN out (both signs)
    EXPECT_TRUE(ck_tile::isnan(decode(0b0'1111'111)));
    EXPECT_TRUE(ck_tile::isnan(decode(0b1'1111'111)));

    // f8 zero -> f32/f16 zero of the same sign.
    // NOTE(review): floating-point == treats -0 and +0 as equal, so these two
    // checks cannot detect a lost sign by themselves.
    EXPECT_EQ(decode(0b0'0000'000),
              ck_tile::bit_cast<DstT>(typename ck_tile::numeric_traits<DstT>::bitwise_type{0}));
    EXPECT_EQ(decode(0b1'0000'000),
              ck_tile::bit_cast<DstT>(ck_tile::numeric_traits<DstT>::Neg0));
#else // FNUZ
    EXPECT_EQ(decode(0b0'1000'000), +1.0f);
    EXPECT_EQ(decode(0b1'1000'000), -1.0f);

    // largest normal f8
    EXPECT_EQ(decode(0b0'1111'111), +240.0f);
    EXPECT_EQ(decode(0b1'1111'111), -240.0f);
    // smallest normal f8
    EXPECT_EQ(decode(0b0'0001'000), +0.0078125f);
    EXPECT_EQ(decode(0b1'0001'000), -0.0078125f);
    // largest subnormal f8
    EXPECT_EQ(decode(0b0'0000'111), +0.0068359375f);
    EXPECT_EQ(decode(0b1'0000'111), -0.0068359375f);
    // smallest subnormal f8
    EXPECT_EQ(decode(0b0'0000'001), +0.0009765625f);
    EXPECT_EQ(decode(0b1'0000'001), -0.0009765625f);

    // a couple of mid-range values
    EXPECT_EQ(decode(0b0'0100'101), +0.1015625f);
    EXPECT_EQ(decode(0b1'1101'011), -44.0f);

    // NaN in -> NaN out
    EXPECT_TRUE(ck_tile::isnan(decode(0b1'0000'000)));

    // UZ means "unsigned zero" (0b1'0000'000 is NaN), so only +zero exists
    EXPECT_EQ(decode(0b0'0000'000),
              ck_tile::bit_cast<DstT>(typename ck_tile::numeric_traits<DstT>::bitwise_type{0}));
#endif
}
// Host conversion bf8_t -> f32/f16. Inputs are raw f8 bit patterns written as
// 0b<sign>'<5 exponent bits>'<2 mantissa bits>; results are compared as float.
TYPED_TEST(ConvertTest, FromBf8)
{
    using SrcT = ck_tile::bf8_t;
    using DstT = TypeParam;

    // Saturating conversion of a raw byte, widened to float for comparison.
    auto c = [](uint8_t u) {
        return ck_tile::type_convert<float>(
            ck_tile::impl::run_cast_from_f8<SrcT, DstT, true>(ck_tile::bit_cast<SrcT>(u)));
    };

#if CK_TILE_USE_OCP_FP8
    // Non-saturating variant; only the OCP format has an infinity to preserve.
    auto c_nosat = [](uint8_t u) {
        return ck_tile::type_convert<float>(
            ck_tile::impl::run_cast_from_f8<SrcT, DstT, false>(ck_tile::bit_cast<SrcT>(u)));
    };

    EXPECT_EQ(c(0b0'01111'00), +1.0f);
    EXPECT_EQ(c(0b1'01111'00), -1.0f);

    // max f8 normal
    EXPECT_EQ(c(0b0'11110'11), +57344.0f);
    EXPECT_EQ(c(0b1'11110'11), -57344.0f);
    // min f8 normal
    EXPECT_EQ(c(0b0'00001'00), +6.103515625e-05f);
    EXPECT_EQ(c(0b1'00001'00), -6.103515625e-05f);
    // max f8 subnormal
    EXPECT_EQ(c(0b0'00000'11), +4.57763671875e-05f);
    EXPECT_EQ(c(0b1'00000'11), -4.57763671875e-05f);
    // min f8 subnormal
    EXPECT_EQ(c(0b0'00000'01), +1.52587890625e-05f);
    EXPECT_EQ(c(0b1'00000'01), -1.52587890625e-05f);

    // arbitrary values
    EXPECT_EQ(c(0b0'01001'01), +0.01953125f);
    EXPECT_EQ(c(0b1'11010'11), -3584.0f);

    // saturating mode
    // f8 infinity -> max f8 normal as f32/f16
    EXPECT_EQ(c(0b0'11111'00), +57344.0f);
    EXPECT_EQ(c(0b1'11111'00), -57344.0f);
    // non-saturating mode
    // f8 infinity -> f32/f16 infinity
    EXPECT_EQ(c_nosat(0b0'11111'00), +ck_tile::numeric<DstT>::infinity());
    EXPECT_EQ(c_nosat(0b1'11111'00), -ck_tile::numeric<DstT>::infinity());

    // f8 NaN -> f32/f16 NaN (every non-zero mantissa, both signs)
    EXPECT_TRUE(ck_tile::isnan(c(0b0'11111'01)));
    EXPECT_TRUE(ck_tile::isnan(c(0b0'11111'10)));
    EXPECT_TRUE(ck_tile::isnan(c(0b0'11111'11)));
    EXPECT_TRUE(ck_tile::isnan(c(0b1'11111'01)));
    EXPECT_TRUE(ck_tile::isnan(c(0b1'11111'10)));
    EXPECT_TRUE(ck_tile::isnan(c(0b1'11111'11)));

    // f8 zero -> f32/f16 zero (sign is preserved)
    EXPECT_EQ(c(0b0'00000'00),
              ck_tile::bit_cast<DstT>(typename ck_tile::numeric_traits<DstT>::bitwise_type{0}));
    EXPECT_EQ(c(0b1'00000'00), ck_tile::bit_cast<DstT>(ck_tile::numeric_traits<DstT>::Neg0));

    if constexpr(std::is_same_v<DstT, ck_tile::fp16_t>)
    {
        // min f8 normal -> min f16 normal (they are equal)
        EXPECT_EQ(c(0b0'00001'00), +ck_tile::numeric<DstT>::min());
        EXPECT_EQ(c(0b1'00001'00), -ck_tile::numeric<DstT>::min());
    }
#else // FNUZ
    EXPECT_EQ(c(0b0'10000'00), +1.0f);
    EXPECT_EQ(c(0b1'10000'00), -1.0f);

    // max f8 normal
    EXPECT_EQ(c(0b0'11111'11), +57344.0f);
    EXPECT_EQ(c(0b1'11111'11), -57344.0f);
    // min f8 normal
    EXPECT_EQ(c(0b0'00001'00), +3.0517578125e-05f);
    EXPECT_EQ(c(0b1'00001'00), -3.0517578125e-05f);
    // max f8 subnormal
    EXPECT_EQ(c(0b0'00000'11), +2.288818359375e-05f);
    EXPECT_EQ(c(0b1'00000'11), -2.288818359375e-05f);
    // min f8 subnormal
    EXPECT_EQ(c(0b0'00000'01), +7.62939453125e-06f);
    EXPECT_EQ(c(0b1'00000'01), -7.62939453125e-06f);

    // arbitrary values
    EXPECT_EQ(c(0b0'01001'01), +0.009765625f);
    EXPECT_EQ(c(0b1'11010'11), -1792.0f);

    // f8 NaN -> f32/f16 NaN
    EXPECT_TRUE(ck_tile::isnan(c(0b1'00000'00)));

    // UZ means "unsigned zero" (0b1'00000'00 is NaN)
    // f8 +zero -> f32/f16 +zero
    EXPECT_EQ(c(0b0'00000'00),
              ck_tile::bit_cast<DstT>(typename ck_tile::numeric_traits<DstT>::bitwise_type{0}));

    if constexpr(std::is_same_v<DstT, ck_tile::fp16_t>)
    {
        // one of f8 normals -> min f16 normal
        EXPECT_EQ(c(0b0'00010'00), +ck_tile::numeric<DstT>::min());
        EXPECT_EQ(c(0b1'00010'00), -ck_tile::numeric<DstT>::min());
    }
#endif
}
// Round trip f8 -> f32/f16 -> f8 over all 256 bit patterns to check that every
// value survives unchanged. The OCP types have multiple NaN representations
// (e4m3 has 2, e5m2 has 6); those patterns are skipped for simplicity.
TYPED_TEST(ConvertTest, FromFp8AndToFp8)
{
    using SrcT = ck_tile::fp8_t;
    using DstT = TypeParam;

    for(int bits = 0; bits < 256; ++bits)
    {
#if CK_TILE_USE_OCP_FP8
        // Skip the OCP e4m3 NaN patterns (all exponent and mantissa bits set).
        constexpr int nan_pattern = 0b0'1111'111;
        if((bits & nan_pattern) == nan_pattern)
            continue;
#endif
        const SrcT original = ck_tile::bit_cast<SrcT>(static_cast<uint8_t>(bits));
        const DstT widened  = ck_tile::impl::run_cast_from_f8<SrcT, DstT, false>(original);
        const SrcT restored = ck_tile::impl::run_cast_to_f8<DstT, SrcT, false>(widened);

        EXPECT_EQ(original, restored)
            << "u8: " << bits << " f32/f16: " << ck_tile::type_convert<float>(widened);
    }
}
// Round trip bf8 -> f32/f16 -> bf8 over all 256 bit patterns; each pattern
// that is not skipped must come back bit-identical.
TYPED_TEST(ConvertTest, FromBf8AndToBf8)
{
    using SrcT = ck_tile::bf8_t;
    using DstT = TypeParam;

    for(int bits = 0; bits < 256; ++bits)
    {
#if CK_TILE_USE_OCP_FP8
        // Skip the OCP e5m2 NaN patterns: with the sign masked off they are
        // the values strictly above the infinity encoding.
        constexpr int abs_mask     = 0b0'11111'11;
        constexpr int inf_encoding = 0b0'11111'00;
        if((bits & abs_mask) > inf_encoding)
            continue;
#endif
        const SrcT original = ck_tile::bit_cast<SrcT>(static_cast<uint8_t>(bits));
        const DstT widened  = ck_tile::impl::run_cast_from_f8<SrcT, DstT, false>(original);
        const SrcT restored = ck_tile::impl::run_cast_to_f8<DstT, SrcT, false>(widened);

        EXPECT_EQ(original, restored)
            << "u8: " << bits << " f32/f16: " << ck_tile::type_convert<float>(widened);
    }
}