mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Fix UB and corner cases in f32/f16 to/from f8 conversion (#2571)
* Add tests for host convesion f32/f16 to f8 * Add tests for host convesion from f8 to f32/f16 * Fix UB and corner cases in f32/f16 to/from f8 conversion * There are UBs when very small values are converted to f8: bitshifts can be larger that type width. Using unsigned long long does not help because exponent_diff >= 64 in such cases. This causes that values like 2.117582368e-22 are converted to non-zero f8 in host validation of FMHA tests, test_f8 crashes with segfault in completely irrelevant code like GTest internals or produces non-deterministic results etc. * Fix FNUZ conversion to return NaN for NaN inputs. * Fix compilation error (due to uint8_t << 8) in OCP e5m2 to f16 conversion. * Replace some magic numbers with values from numeric_traits * Build tests only on devices supporting the type
This commit is contained in:
@@ -43,19 +43,19 @@ enum class fp8_interpretation
|
||||
};
|
||||
|
||||
/*
|
||||
* ______________FNUZ_________________ | ______________OCP________________
|
||||
* ______________FNUZ_________________ | ______________OCP________________
|
||||
* e4m3 e5m2 | e4m3 e5m2
|
||||
* bias : 8 16 | 7 15
|
||||
* inf : 1.0000.000 1.00000.00 | N/A s.11111.00
|
||||
* inf : N/A N/A | N/A s.11111.00
|
||||
* Nan : 1.0000.000 1.00000.00 | s.1111.111 s.11111.{01, 10, 11}
|
||||
* zero : 0.0000.000 0.00000.00 | s.0000.000 s.00000.00
|
||||
* Max(norm) : s.1111.111 (240) s.11111.11(57344) | s.1111.110(448) s.11110.11(57344)
|
||||
* Max(snorm): s.0000.111 s.00000.11 | s.0000.111 s.00000.11
|
||||
* 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05
|
||||
* 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05
|
||||
* Min(norm) : s.0001.000 s.00001.00 | s.0001.000 s.00001.00
|
||||
* 2^-7(0.00078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05)
|
||||
* 2^-7(0.0078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05)
|
||||
* Min(snorm): s.0000.001 s.00000.01 | s.0000.001 s.00000.01
|
||||
* 2^-10(0.00097656) 2^-17(7.629395e-06)| 2^-9(0.001953125) 2^-16(1.52588e-05)
|
||||
* 2^-10(0.0009765625) 2^-17(7.62939e-06) | 2^-9(0.001953125) 2^-16(1.52588e-05)
|
||||
*/
|
||||
|
||||
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
|
||||
@@ -259,50 +259,50 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
|
||||
// fp8/bf8 type exponent/mantissa layout
|
||||
constexpr int DstT_exp = numeric_traits<DstT>::exp; // exponent width of the destination type
|
||||
constexpr int DstT_mant = numeric_traits<DstT>::mant; // mantissa width of the destination type
|
||||
constexpr int DstT_bias = numeric_traits<DstT>::bias;
|
||||
constexpr bool is_fnuz =
|
||||
(numeric_traits<DstT>::f8_interpret == fp8_interpretation::E4M3_FNUZ) ||
|
||||
(numeric_traits<DstT>::f8_interpret == fp8_interpretation::E5M2_FNUZ);
|
||||
|
||||
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
|
||||
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
|
||||
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
|
||||
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
|
||||
constexpr int bias = numeric_traits<SrcT>::bias;
|
||||
constexpr unsigned int fInf = numeric_traits<SrcT>::Inf;
|
||||
constexpr unsigned int abs_mask = numeric_traits<SrcT>::abs_mask;
|
||||
|
||||
using SrcT_bitwise = typename numeric_traits<SrcT>::bitwise_type;
|
||||
SrcT_bitwise src_bitwise = bit_cast<SrcT_bitwise>(src);
|
||||
|
||||
unsigned long long head, mantissa;
|
||||
int exponent, bias;
|
||||
unsigned int head, mantissa;
|
||||
int exponent;
|
||||
unsigned int sign;
|
||||
unsigned long long fInf, abs_mask;
|
||||
|
||||
head = src_bitwise & numeric_traits<SrcT>::head_mask;
|
||||
mantissa = src_bitwise & numeric_traits<SrcT>::mant_mask;
|
||||
exponent = (head >> SrcT_mant) & numeric_traits<SrcT>::exp_mask;
|
||||
sign = head >> (SrcT_exp + SrcT_mant);
|
||||
bias = numeric_traits<SrcT>::bias;
|
||||
fInf = numeric_traits<SrcT>::Inf;
|
||||
abs_mask = numeric_traits<SrcT>::abs_mask;
|
||||
|
||||
unsigned int signed_inf = 0;
|
||||
unsigned int nan = 0;
|
||||
if constexpr(is_fnuz)
|
||||
{
|
||||
signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
|
||||
signed_inf = clip ? ((sign << (DstT_exp + DstT_mant)) + 0x7f) : 0x80;
|
||||
nan = 0x80;
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(DstT_exp == 4)
|
||||
{ // e4m3
|
||||
signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
|
||||
signed_inf = (sign << (DstT_exp + DstT_mant)) + (clip ? 0x7e : 0x7f);
|
||||
}
|
||||
else
|
||||
{ // e5m2
|
||||
signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
|
||||
signed_inf = (sign << (DstT_exp + DstT_mant)) + (clip ? 0x7b : 0x7c);
|
||||
}
|
||||
nan = (sign << 7) + 0x7f;
|
||||
nan = (sign << (DstT_exp + DstT_mant)) + 0x7f;
|
||||
}
|
||||
// Max values
|
||||
unsigned long long ifmax = 0;
|
||||
unsigned int ifmax = 0;
|
||||
if constexpr(is_float)
|
||||
{
|
||||
if constexpr(DstT_exp == 5)
|
||||
@@ -343,9 +343,6 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
|
||||
// Deal with inf and NaNs
|
||||
if((src_bitwise & fInf) == fInf)
|
||||
{
|
||||
if constexpr(is_fnuz)
|
||||
return signed_inf;
|
||||
|
||||
return mantissa != 0 ? nan : signed_inf;
|
||||
}
|
||||
|
||||
@@ -354,11 +351,6 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
|
||||
return signed_inf;
|
||||
}
|
||||
|
||||
if(src_bitwise == 0)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
// First need to check if it is normal or denorm as there is a difference of
|
||||
// implicit 1 Then need to adjust the exponent to align with the F8 exponent,
|
||||
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
|
||||
@@ -367,8 +359,7 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
|
||||
|
||||
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
|
||||
// bits
|
||||
const int f8_bias = (1 << (DstT_exp - 1)) - 1 + (is_fnuz ? 1 : 0);
|
||||
const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
|
||||
constexpr int f8_denormal_act_exponent = 1 - DstT_bias; // actual exponent of f8 denormal
|
||||
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
|
||||
// f8_exponent is the converted f8 exponent with bias encoding
|
||||
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
|
||||
@@ -406,11 +397,16 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
|
||||
// for this case, act_exponent could be larger. Just
|
||||
// that it does not need shift mantissa
|
||||
}
|
||||
mantissa += (1ull << SrcT_mant); // Add the implicit 1 into mantissa
|
||||
mantissa += (1u << SrcT_mant); // Add the implicit 1 into mantissa
|
||||
}
|
||||
|
||||
bool midpoint = (mantissa & ((1ull << (SrcT_mant - DstT_mant + exponent_diff)) - 1)) ==
|
||||
(1ull << (SrcT_mant - DstT_mant + exponent_diff - 1));
|
||||
// The value is smaller than min f8 denormal and results in zero (the early exit also prevents
|
||||
// an undefined behavior of bit shifts >= type width).
|
||||
if(exponent_diff > DstT_mant)
|
||||
{
|
||||
return is_fnuz ? 0 : (sign << (DstT_exp + DstT_mant));
|
||||
}
|
||||
bool midpoint = (mantissa & ((1u << (SrcT_mant - DstT_mant + exponent_diff)) - 1)) ==
|
||||
(1u << (SrcT_mant - DstT_mant + exponent_diff - 1));
|
||||
/* This part is a bit tricky. The judgment of whether it is a tie needs to be
|
||||
done before we shift right as shift right could rip off some residual part and
|
||||
make something not midpoint look like midpoint. For example, the fp16 number
|
||||
@@ -422,31 +418,31 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
|
||||
mantissa >>= exponent_diff;
|
||||
else if(exponent_diff == -1)
|
||||
mantissa <<= -exponent_diff;
|
||||
bool implicit_one = mantissa & (1ull << SrcT_mant);
|
||||
bool implicit_one = mantissa & (1u << SrcT_mant);
|
||||
// if there is no implicit 1, it means the f8 is denormal and need to adjust
|
||||
// to denorm exponent
|
||||
f8_exponent =
|
||||
(act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
|
||||
(act_exponent + exponent_diff) /*actual f8 exponent*/ + DstT_bias - (implicit_one ? 0 : 1);
|
||||
|
||||
// Now we have the exponent and mantissa adjusted
|
||||
unsigned long long drop_mask = (1ull << (SrcT_mant - DstT_mant)) - 1;
|
||||
unsigned int drop_mask = (1u << (SrcT_mant - DstT_mant)) - 1;
|
||||
bool odd =
|
||||
mantissa & (1ull << (SrcT_mant -
|
||||
DstT_mant)); // if the least significant bit that is not truncated is 1
|
||||
mantissa &
|
||||
(1u << (SrcT_mant - DstT_mant)); // if the least significant bit that is not truncated is 1
|
||||
mantissa +=
|
||||
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
|
||||
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1u) : mantissa)) & drop_mask;
|
||||
|
||||
// Now we deal with overflow
|
||||
if(f8_exponent == 0)
|
||||
{
|
||||
if((1ull << SrcT_mant) & mantissa)
|
||||
if((1u << SrcT_mant) & mantissa)
|
||||
{
|
||||
f8_exponent = 1; // denormal overflow to become normal, promote exponent
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if((1ull << (SrcT_mant + 1)) & mantissa)
|
||||
if((1u << (SrcT_mant + 1)) & mantissa)
|
||||
{
|
||||
mantissa >>= 1;
|
||||
f8_exponent++;
|
||||
@@ -471,9 +467,9 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
|
||||
}
|
||||
|
||||
if(f8_exponent == 0 && mantissa == 0)
|
||||
return is_fnuz ? 0 : (sign << 7);
|
||||
return is_fnuz ? 0 : (sign << (DstT_exp + DstT_mant));
|
||||
mantissa &= (1 << DstT_mant) - 1;
|
||||
return (sign << 7) | (f8_exponent << DstT_mant) | mantissa;
|
||||
return (sign << (DstT_exp + DstT_mant)) | (f8_exponent << DstT_mant) | mantissa;
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, bool clip = true>
|
||||
@@ -481,8 +477,9 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
|
||||
{
|
||||
static_assert(std::is_same<SrcT, fp8_t>::value || std::is_same<SrcT, bf8_t>::value,
|
||||
"SrcT type must be fp8 or bf8.");
|
||||
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
|
||||
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
|
||||
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
|
||||
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
|
||||
constexpr uint8_t SrcT_abs_mask = numeric_traits<SrcT>::abs_mask;
|
||||
constexpr bool is_fnuz =
|
||||
(numeric_traits<SrcT>::f8_interpret == fp8_interpretation::E4M3_FNUZ) ||
|
||||
(numeric_traits<SrcT>::f8_interpret == fp8_interpretation::E5M2_FNUZ);
|
||||
@@ -518,9 +515,9 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
|
||||
return 0;
|
||||
}
|
||||
|
||||
unsigned long long sign = x >> 7;
|
||||
unsigned long long mantissa = x & ((1 << SrcT_mant) - 1);
|
||||
int exponent = (x & 0x7F) >> SrcT_mant;
|
||||
unsigned int sign = x >> (SrcT_exp + SrcT_mant);
|
||||
unsigned int mantissa = x & ((1 << SrcT_mant) - 1);
|
||||
int exponent = (x & SrcT_abs_mask) >> SrcT_mant;
|
||||
if constexpr(is_fnuz)
|
||||
{
|
||||
if((x & 0xff) == 0x80)
|
||||
@@ -559,7 +556,7 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
|
||||
|
||||
if constexpr(SrcT_exp == 5 && is_half && !is_fnuz)
|
||||
{
|
||||
retval = x << 8;
|
||||
retval = static_cast<typename numeric_traits<DstT>::bitwise_type>(x) << 8;
|
||||
return bit_cast<DstT>(retval);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user