[CK_TILE] Fix UB and corner cases in f32/f16 to/from f8 conversion (#2571)

* Add tests for host convesion f32/f16 to f8

* Add tests for host convesion from f8 to f32/f16

* Fix UB and corner cases in f32/f16 to/from f8 conversion

* There are UBs when very small values are converted to f8: bitshifts
  can be larger that type width. Using unsigned long long does not help
  because exponent_diff >= 64 in such cases. This causes that values
  like 2.117582368e-22 are converted to non-zero f8 in host validation
  of FMHA tests, test_f8 crashes with segfault in completely irrelevant
  code like GTest internals or produces non-deterministic results etc.
* Fix FNUZ conversion to return NaN for NaN inputs.
* Fix compilation error (due to uint8_t << 8) in OCP e5m2 to f16
  conversion.

* Replace some magic numbers with values from numeric_traits

* Build tests only on devices supporting the type
This commit is contained in:
Anton Gorenko
2025-07-31 10:54:17 +06:00
committed by GitHub
parent e8709c24f4
commit 7b074249f4
3 changed files with 663 additions and 50 deletions

View File

@@ -43,19 +43,19 @@ enum class fp8_interpretation
};
/*
* ______________FNUZ_________________ | ______________OCP________________
* ______________FNUZ_________________ | ______________OCP________________
* e4m3 e5m2 | e4m3 e5m2
* bias : 8 16 | 7 15
* inf : 1.0000.000 1.00000.00 | N/A s.11111.00
* inf : N/A N/A | N/A s.11111.00
* Nan : 1.0000.000 1.00000.00 | s.1111.111 s.11111.{01, 10, 11}
* zero : 0.0000.000 0.00000.00 | s.0000.000 s.00000.00
* Max(norm) : s.1111.111 (240) s.11111.11(57344) | s.1111.110(448) s.11110.11(57344)
* Max(snorm): s.0000.111 s.00000.11 | s.0000.111 s.00000.11
* 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05
* 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05
* Min(norm) : s.0001.000 s.00001.00 | s.0001.000 s.00001.00
* 2^-7(0.00078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05)
* 2^-7(0.0078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05)
* Min(snorm): s.0000.001 s.00000.01 | s.0000.001 s.00000.01
* 2^-10(0.00097656) 2^-17(7.629395e-06)| 2^-9(0.001953125) 2^-16(1.52588e-05)
* 2^-10(0.0009765625) 2^-17(7.62939e-06) | 2^-9(0.001953125) 2^-16(1.52588e-05)
*/
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
@@ -259,50 +259,50 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
// fp8/bf8 type exponent/mantissa layout
constexpr int DstT_exp = numeric_traits<DstT>::exp; // exponent width of the destination type
constexpr int DstT_mant = numeric_traits<DstT>::mant; // mantissa width of the destination type
constexpr int DstT_bias = numeric_traits<DstT>::bias;
constexpr bool is_fnuz =
(numeric_traits<DstT>::f8_interpret == fp8_interpretation::E4M3_FNUZ) ||
(numeric_traits<DstT>::f8_interpret == fp8_interpretation::E5M2_FNUZ);
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
constexpr int bias = numeric_traits<SrcT>::bias;
constexpr unsigned int fInf = numeric_traits<SrcT>::Inf;
constexpr unsigned int abs_mask = numeric_traits<SrcT>::abs_mask;
using SrcT_bitwise = typename numeric_traits<SrcT>::bitwise_type;
SrcT_bitwise src_bitwise = bit_cast<SrcT_bitwise>(src);
unsigned long long head, mantissa;
int exponent, bias;
unsigned int head, mantissa;
int exponent;
unsigned int sign;
unsigned long long fInf, abs_mask;
head = src_bitwise & numeric_traits<SrcT>::head_mask;
mantissa = src_bitwise & numeric_traits<SrcT>::mant_mask;
exponent = (head >> SrcT_mant) & numeric_traits<SrcT>::exp_mask;
sign = head >> (SrcT_exp + SrcT_mant);
bias = numeric_traits<SrcT>::bias;
fInf = numeric_traits<SrcT>::Inf;
abs_mask = numeric_traits<SrcT>::abs_mask;
unsigned int signed_inf = 0;
unsigned int nan = 0;
if constexpr(is_fnuz)
{
signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
signed_inf = clip ? ((sign << (DstT_exp + DstT_mant)) + 0x7f) : 0x80;
nan = 0x80;
}
else
{
if constexpr(DstT_exp == 4)
{ // e4m3
signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
signed_inf = (sign << (DstT_exp + DstT_mant)) + (clip ? 0x7e : 0x7f);
}
else
{ // e5m2
signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
signed_inf = (sign << (DstT_exp + DstT_mant)) + (clip ? 0x7b : 0x7c);
}
nan = (sign << 7) + 0x7f;
nan = (sign << (DstT_exp + DstT_mant)) + 0x7f;
}
// Max values
unsigned long long ifmax = 0;
unsigned int ifmax = 0;
if constexpr(is_float)
{
if constexpr(DstT_exp == 5)
@@ -343,9 +343,6 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
// Deal with inf and NaNs
if((src_bitwise & fInf) == fInf)
{
if constexpr(is_fnuz)
return signed_inf;
return mantissa != 0 ? nan : signed_inf;
}
@@ -354,11 +351,6 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
return signed_inf;
}
if(src_bitwise == 0)
{
return 0;
}
// First need to check if it is normal or denorm as there is a difference of
// implicit 1 Then need to adjust the exponent to align with the F8 exponent,
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
@@ -367,8 +359,7 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
// bits
const int f8_bias = (1 << (DstT_exp - 1)) - 1 + (is_fnuz ? 1 : 0);
const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
constexpr int f8_denormal_act_exponent = 1 - DstT_bias; // actual exponent of f8 denormal
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// f8_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
@@ -406,11 +397,16 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
// for this case, act_exponent could be larger. Just
// that it does not need shift mantissa
}
mantissa += (1ull << SrcT_mant); // Add the implicit 1 into mantissa
mantissa += (1u << SrcT_mant); // Add the implicit 1 into mantissa
}
bool midpoint = (mantissa & ((1ull << (SrcT_mant - DstT_mant + exponent_diff)) - 1)) ==
(1ull << (SrcT_mant - DstT_mant + exponent_diff - 1));
// The value is smaller than min f8 denormal and results in zero (the early exit also prevents
// an undefined behavior of bit shifts >= type width).
if(exponent_diff > DstT_mant)
{
return is_fnuz ? 0 : (sign << (DstT_exp + DstT_mant));
}
bool midpoint = (mantissa & ((1u << (SrcT_mant - DstT_mant + exponent_diff)) - 1)) ==
(1u << (SrcT_mant - DstT_mant + exponent_diff - 1));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be
done before we shift right as shift right could rip off some residual part and
make something not midpoint look like midpoint. For example, the fp16 number
@@ -422,31 +418,31 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
mantissa >>= exponent_diff;
else if(exponent_diff == -1)
mantissa <<= -exponent_diff;
bool implicit_one = mantissa & (1ull << SrcT_mant);
bool implicit_one = mantissa & (1u << SrcT_mant);
// if there is no implicit 1, it means the f8 is denormal and need to adjust
// to denorm exponent
f8_exponent =
(act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
(act_exponent + exponent_diff) /*actual f8 exponent*/ + DstT_bias - (implicit_one ? 0 : 1);
// Now we have the exponent and mantissa adjusted
unsigned long long drop_mask = (1ull << (SrcT_mant - DstT_mant)) - 1;
unsigned int drop_mask = (1u << (SrcT_mant - DstT_mant)) - 1;
bool odd =
mantissa & (1ull << (SrcT_mant -
DstT_mant)); // if the least significant bit that is not truncated is 1
mantissa &
(1u << (SrcT_mant - DstT_mant)); // if the least significant bit that is not truncated is 1
mantissa +=
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1u) : mantissa)) & drop_mask;
// Now we deal with overflow
if(f8_exponent == 0)
{
if((1ull << SrcT_mant) & mantissa)
if((1u << SrcT_mant) & mantissa)
{
f8_exponent = 1; // denormal overflow to become normal, promote exponent
}
}
else
{
if((1ull << (SrcT_mant + 1)) & mantissa)
if((1u << (SrcT_mant + 1)) & mantissa)
{
mantissa >>= 1;
f8_exponent++;
@@ -471,9 +467,9 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
}
if(f8_exponent == 0 && mantissa == 0)
return is_fnuz ? 0 : (sign << 7);
return is_fnuz ? 0 : (sign << (DstT_exp + DstT_mant));
mantissa &= (1 << DstT_mant) - 1;
return (sign << 7) | (f8_exponent << DstT_mant) | mantissa;
return (sign << (DstT_exp + DstT_mant)) | (f8_exponent << DstT_mant) | mantissa;
}
template <typename SrcT, typename DstT, bool clip = true>
@@ -481,8 +477,9 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
{
static_assert(std::is_same<SrcT, fp8_t>::value || std::is_same<SrcT, bf8_t>::value,
"SrcT type must be fp8 or bf8.");
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
constexpr uint8_t SrcT_abs_mask = numeric_traits<SrcT>::abs_mask;
constexpr bool is_fnuz =
(numeric_traits<SrcT>::f8_interpret == fp8_interpretation::E4M3_FNUZ) ||
(numeric_traits<SrcT>::f8_interpret == fp8_interpretation::E5M2_FNUZ);
@@ -518,9 +515,9 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
return 0;
}
unsigned long long sign = x >> 7;
unsigned long long mantissa = x & ((1 << SrcT_mant) - 1);
int exponent = (x & 0x7F) >> SrcT_mant;
unsigned int sign = x >> (SrcT_exp + SrcT_mant);
unsigned int mantissa = x & ((1 << SrcT_mant) - 1);
int exponent = (x & SrcT_abs_mask) >> SrcT_mant;
if constexpr(is_fnuz)
{
if((x & 0xff) == 0x80)
@@ -559,7 +556,7 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
if constexpr(SrcT_exp == 5 && is_half && !is_fnuz)
{
retval = x << 8;
retval = static_cast<typename numeric_traits<DstT>::bitwise_type>(x) << 8;
return bit_cast<DstT>(retval);
}