mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE] Fix UB and corner cases in f32/f16 to/from f8 conversion (#2571)
* Add tests for host conversion f32/f16 to f8
* Add tests for host conversion from f8 to f32/f16
* Fix UB and corner cases in f32/f16 to/from f8 conversion
* There are UBs when very small values are converted to f8: bit shifts
can be larger than the type width. Using unsigned long long does not help
because exponent_diff >= 64 in such cases. As a result, values
like 2.117582368e-22 are converted to non-zero f8 in host validation
of FMHA tests, and test_f8 crashes with a segfault in completely unrelated
code (such as GTest internals), produces non-deterministic results, etc.
* Fix FNUZ conversion to return NaN for NaN inputs.
* Fix compilation error (due to uint8_t << 8) in OCP e5m2 to f16
conversion.
* Replace some magic numbers with values from numeric_traits
* Build tests only on devices supporting the type
[ROCm/composable_kernel commit: 7b074249f4]
This commit is contained in:
@@ -43,19 +43,19 @@ enum class fp8_interpretation
|
||||
};
|
||||
|
||||
/*
|
||||
* ______________FNUZ_________________ | ______________OCP________________
|
||||
* ______________FNUZ_________________ | ______________OCP________________
|
||||
* e4m3 e5m2 | e4m3 e5m2
|
||||
* bias : 8 16 | 7 15
|
||||
* inf : 1.0000.000 1.00000.00 | N/A s.11111.00
|
||||
* inf : N/A N/A | N/A s.11111.00
|
||||
* Nan : 1.0000.000 1.00000.00 | s.1111.111 s.11111.{01, 10, 11}
|
||||
* zero : 0.0000.000 0.00000.00 | s.0000.000 s.00000.00
|
||||
* Max(norm) : s.1111.111 (240) s.11111.11(57344) | s.1111.110(448) s.11110.11(57344)
|
||||
* Max(snorm): s.0000.111 s.00000.11 | s.0000.111 s.00000.11
|
||||
* 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05
|
||||
* 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05
|
||||
* Min(norm) : s.0001.000 s.00001.00 | s.0001.000 s.00001.00
|
||||
* 2^-7(0.00078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05)
|
||||
* 2^-7(0.0078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05)
|
||||
* Min(snorm): s.0000.001 s.00000.01 | s.0000.001 s.00000.01
|
||||
* 2^-10(0.00097656) 2^-17(7.629395e-06)| 2^-9(0.001953125) 2^-16(1.52588e-05)
|
||||
* 2^-10(0.0009765625) 2^-17(7.62939e-06) | 2^-9(0.001953125) 2^-16(1.52588e-05)
|
||||
*/
|
||||
|
||||
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
|
||||
@@ -259,50 +259,50 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
|
||||
// fp8/bf8 type exponent/mantissa layout
|
||||
constexpr int DstT_exp = numeric_traits<DstT>::exp; // exponent width of the destination type
|
||||
constexpr int DstT_mant = numeric_traits<DstT>::mant; // mantissa width of the destination type
|
||||
constexpr int DstT_bias = numeric_traits<DstT>::bias;
|
||||
constexpr bool is_fnuz =
|
||||
(numeric_traits<DstT>::f8_interpret == fp8_interpretation::E4M3_FNUZ) ||
|
||||
(numeric_traits<DstT>::f8_interpret == fp8_interpretation::E5M2_FNUZ);
|
||||
|
||||
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
|
||||
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
|
||||
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
|
||||
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
|
||||
constexpr int bias = numeric_traits<SrcT>::bias;
|
||||
constexpr unsigned int fInf = numeric_traits<SrcT>::Inf;
|
||||
constexpr unsigned int abs_mask = numeric_traits<SrcT>::abs_mask;
|
||||
|
||||
using SrcT_bitwise = typename numeric_traits<SrcT>::bitwise_type;
|
||||
SrcT_bitwise src_bitwise = bit_cast<SrcT_bitwise>(src);
|
||||
|
||||
unsigned long long head, mantissa;
|
||||
int exponent, bias;
|
||||
unsigned int head, mantissa;
|
||||
int exponent;
|
||||
unsigned int sign;
|
||||
unsigned long long fInf, abs_mask;
|
||||
|
||||
head = src_bitwise & numeric_traits<SrcT>::head_mask;
|
||||
mantissa = src_bitwise & numeric_traits<SrcT>::mant_mask;
|
||||
exponent = (head >> SrcT_mant) & numeric_traits<SrcT>::exp_mask;
|
||||
sign = head >> (SrcT_exp + SrcT_mant);
|
||||
bias = numeric_traits<SrcT>::bias;
|
||||
fInf = numeric_traits<SrcT>::Inf;
|
||||
abs_mask = numeric_traits<SrcT>::abs_mask;
|
||||
|
||||
unsigned int signed_inf = 0;
|
||||
unsigned int nan = 0;
|
||||
if constexpr(is_fnuz)
|
||||
{
|
||||
signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
|
||||
signed_inf = clip ? ((sign << (DstT_exp + DstT_mant)) + 0x7f) : 0x80;
|
||||
nan = 0x80;
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(DstT_exp == 4)
|
||||
{ // e4m3
|
||||
signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
|
||||
signed_inf = (sign << (DstT_exp + DstT_mant)) + (clip ? 0x7e : 0x7f);
|
||||
}
|
||||
else
|
||||
{ // e5m2
|
||||
signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
|
||||
signed_inf = (sign << (DstT_exp + DstT_mant)) + (clip ? 0x7b : 0x7c);
|
||||
}
|
||||
nan = (sign << 7) + 0x7f;
|
||||
nan = (sign << (DstT_exp + DstT_mant)) + 0x7f;
|
||||
}
|
||||
// Max values
|
||||
unsigned long long ifmax = 0;
|
||||
unsigned int ifmax = 0;
|
||||
if constexpr(is_float)
|
||||
{
|
||||
if constexpr(DstT_exp == 5)
|
||||
@@ -343,9 +343,6 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
|
||||
// Deal with inf and NaNs
|
||||
if((src_bitwise & fInf) == fInf)
|
||||
{
|
||||
if constexpr(is_fnuz)
|
||||
return signed_inf;
|
||||
|
||||
return mantissa != 0 ? nan : signed_inf;
|
||||
}
|
||||
|
||||
@@ -354,11 +351,6 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
|
||||
return signed_inf;
|
||||
}
|
||||
|
||||
if(src_bitwise == 0)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
// First need to check if it is normal or denorm as there is a difference of
|
||||
// implicit 1 Then need to adjust the exponent to align with the F8 exponent,
|
||||
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
|
||||
@@ -367,8 +359,7 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
|
||||
|
||||
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
|
||||
// bits
|
||||
const int f8_bias = (1 << (DstT_exp - 1)) - 1 + (is_fnuz ? 1 : 0);
|
||||
const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
|
||||
constexpr int f8_denormal_act_exponent = 1 - DstT_bias; // actual exponent of f8 denormal
|
||||
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
|
||||
// f8_exponent is the converted f8 exponent with bias encoding
|
||||
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
|
||||
@@ -406,11 +397,16 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
|
||||
// for this case, act_exponent could be larger. Just
|
||||
// that it does not need shift mantissa
|
||||
}
|
||||
mantissa += (1ull << SrcT_mant); // Add the implicit 1 into mantissa
|
||||
mantissa += (1u << SrcT_mant); // Add the implicit 1 into mantissa
|
||||
}
|
||||
|
||||
bool midpoint = (mantissa & ((1ull << (SrcT_mant - DstT_mant + exponent_diff)) - 1)) ==
|
||||
(1ull << (SrcT_mant - DstT_mant + exponent_diff - 1));
|
||||
// The value is smaller than min f8 denormal and results in zero (the early exit also prevents
|
||||
// an undefined behavior of bit shifts >= type width).
|
||||
if(exponent_diff > DstT_mant)
|
||||
{
|
||||
return is_fnuz ? 0 : (sign << (DstT_exp + DstT_mant));
|
||||
}
|
||||
bool midpoint = (mantissa & ((1u << (SrcT_mant - DstT_mant + exponent_diff)) - 1)) ==
|
||||
(1u << (SrcT_mant - DstT_mant + exponent_diff - 1));
|
||||
/* This part is a bit tricky. The judgment of whether it is a tie needs to be
|
||||
done before we shift right as shift right could rip off some residual part and
|
||||
make something not midpoint look like midpoint. For example, the fp16 number
|
||||
@@ -422,31 +418,31 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
|
||||
mantissa >>= exponent_diff;
|
||||
else if(exponent_diff == -1)
|
||||
mantissa <<= -exponent_diff;
|
||||
bool implicit_one = mantissa & (1ull << SrcT_mant);
|
||||
bool implicit_one = mantissa & (1u << SrcT_mant);
|
||||
// if there is no implicit 1, it means the f8 is denormal and need to adjust
|
||||
// to denorm exponent
|
||||
f8_exponent =
|
||||
(act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
|
||||
(act_exponent + exponent_diff) /*actual f8 exponent*/ + DstT_bias - (implicit_one ? 0 : 1);
|
||||
|
||||
// Now we have the exponent and mantissa adjusted
|
||||
unsigned long long drop_mask = (1ull << (SrcT_mant - DstT_mant)) - 1;
|
||||
unsigned int drop_mask = (1u << (SrcT_mant - DstT_mant)) - 1;
|
||||
bool odd =
|
||||
mantissa & (1ull << (SrcT_mant -
|
||||
DstT_mant)); // if the least significant bit that is not truncated is 1
|
||||
mantissa &
|
||||
(1u << (SrcT_mant - DstT_mant)); // if the least significant bit that is not truncated is 1
|
||||
mantissa +=
|
||||
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
|
||||
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1u) : mantissa)) & drop_mask;
|
||||
|
||||
// Now we deal with overflow
|
||||
if(f8_exponent == 0)
|
||||
{
|
||||
if((1ull << SrcT_mant) & mantissa)
|
||||
if((1u << SrcT_mant) & mantissa)
|
||||
{
|
||||
f8_exponent = 1; // denormal overflow to become normal, promote exponent
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if((1ull << (SrcT_mant + 1)) & mantissa)
|
||||
if((1u << (SrcT_mant + 1)) & mantissa)
|
||||
{
|
||||
mantissa >>= 1;
|
||||
f8_exponent++;
|
||||
@@ -471,9 +467,9 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
|
||||
}
|
||||
|
||||
if(f8_exponent == 0 && mantissa == 0)
|
||||
return is_fnuz ? 0 : (sign << 7);
|
||||
return is_fnuz ? 0 : (sign << (DstT_exp + DstT_mant));
|
||||
mantissa &= (1 << DstT_mant) - 1;
|
||||
return (sign << 7) | (f8_exponent << DstT_mant) | mantissa;
|
||||
return (sign << (DstT_exp + DstT_mant)) | (f8_exponent << DstT_mant) | mantissa;
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, bool clip = true>
|
||||
@@ -481,8 +477,9 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
|
||||
{
|
||||
static_assert(std::is_same<SrcT, fp8_t>::value || std::is_same<SrcT, bf8_t>::value,
|
||||
"SrcT type must be fp8 or bf8.");
|
||||
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
|
||||
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
|
||||
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
|
||||
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
|
||||
constexpr uint8_t SrcT_abs_mask = numeric_traits<SrcT>::abs_mask;
|
||||
constexpr bool is_fnuz =
|
||||
(numeric_traits<SrcT>::f8_interpret == fp8_interpretation::E4M3_FNUZ) ||
|
||||
(numeric_traits<SrcT>::f8_interpret == fp8_interpretation::E5M2_FNUZ);
|
||||
@@ -518,9 +515,9 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
|
||||
return 0;
|
||||
}
|
||||
|
||||
unsigned long long sign = x >> 7;
|
||||
unsigned long long mantissa = x & ((1 << SrcT_mant) - 1);
|
||||
int exponent = (x & 0x7F) >> SrcT_mant;
|
||||
unsigned int sign = x >> (SrcT_exp + SrcT_mant);
|
||||
unsigned int mantissa = x & ((1 << SrcT_mant) - 1);
|
||||
int exponent = (x & SrcT_abs_mask) >> SrcT_mant;
|
||||
if constexpr(is_fnuz)
|
||||
{
|
||||
if((x & 0xff) == 0x80)
|
||||
@@ -559,7 +556,7 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
|
||||
|
||||
if constexpr(SrcT_exp == 5 && is_half && !is_fnuz)
|
||||
{
|
||||
retval = x << 8;
|
||||
retval = static_cast<typename numeric_traits<DstT>::bitwise_type>(x) << 8;
|
||||
return bit_cast<DstT>(retval);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,15 @@
|
||||
# Currently ck_tile is only built on gfx9
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_gtest_executable(test_ck_tile_pk_int4 test_pk_int4.cpp)
|
||||
add_gtest_executable(test_ck_tile_pk_fp4 test_pk_fp4.cpp)
|
||||
endif()
|
||||
if(GPU_TARGETS MATCHES "gfx95")
|
||||
add_gtest_executable(test_ck_tile_pk_fp4 test_pk_fp4.cpp)
|
||||
endif()
|
||||
|
||||
if(CK_USE_OCP_FP8 OR CK_USE_FNUZ_FP8)
|
||||
add_gtest_executable(test_ck_tile_fp8 test_fp8.cpp)
|
||||
target_compile_options(test_ck_tile_fp8 PRIVATE -Wno-float-equal)
|
||||
# conditionally specify the use of OCP_FP8
|
||||
if(CK_USE_OCP_FP8)
|
||||
target_compile_options(test_ck_tile_fp8 PRIVATE -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
606
test/ck_tile/data_type/test_fp8.cpp
Normal file
606
test/ck_tile/data_type/test_fp8.cpp
Normal file
@@ -0,0 +1,606 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
template <typename T>
|
||||
class ConvertTest : public ::testing::Test
|
||||
{
|
||||
};
|
||||
|
||||
using TestTypes = ::testing::Types<float, ck_tile::fp16_t>;
|
||||
|
||||
TYPED_TEST_SUITE(ConvertTest, TestTypes);
|
||||
|
||||
TYPED_TEST(ConvertTest, ToFp8)
|
||||
{
|
||||
using SrcT = TypeParam;
|
||||
using DstT = ck_tile::fp8_t;
|
||||
|
||||
auto c = [](SrcT f) {
|
||||
return static_cast<unsigned int>(
|
||||
ck_tile::bit_cast<uint8_t>(ck_tile::impl::run_cast_to_f8<SrcT, DstT, true>(f)));
|
||||
};
|
||||
|
||||
auto c_nosat = [](SrcT f) {
|
||||
return static_cast<unsigned int>(
|
||||
ck_tile::bit_cast<uint8_t>(ck_tile::impl::run_cast_to_f8<SrcT, DstT, false>(f)));
|
||||
};
|
||||
|
||||
#if CK_TILE_USE_OCP_FP8
|
||||
EXPECT_EQ(c(+1.0f), 0b0'0111'000);
|
||||
EXPECT_EQ(c(-1.0f), 0b1'0111'000);
|
||||
// max f8 normal
|
||||
EXPECT_EQ(c(+448.0f), 0b0'1111'110);
|
||||
EXPECT_EQ(c(-448.0f), 0b1'1111'110);
|
||||
// min f8 normal
|
||||
EXPECT_EQ(c(+0.015625f), 0b0'0001'000);
|
||||
EXPECT_EQ(c(-0.015625f), 0b1'0001'000);
|
||||
// max f8 subnormal
|
||||
EXPECT_EQ(c(+0.013671875f), 0b0'0000'111);
|
||||
EXPECT_EQ(c(-0.013671875f), 0b1'0000'111);
|
||||
// min f8 subnormal
|
||||
EXPECT_EQ(c(+0.001953125f), 0b0'0000'001);
|
||||
EXPECT_EQ(c(-0.001953125f), 0b1'0000'001);
|
||||
// arbitrary values (exact)
|
||||
EXPECT_EQ(c(+0.203125f), 0b0'0100'101);
|
||||
EXPECT_EQ(c(-88.0f), 0b1'1101'011);
|
||||
// arbitrary values (rounded)
|
||||
EXPECT_EQ(c(+432.919f), 0b0'1111'110);
|
||||
EXPECT_EQ(c(-431.111f), 0b1'1111'101);
|
||||
EXPECT_EQ(c(-0.76123f), 0b1'0110'100);
|
||||
EXPECT_EQ(c(+0.81234f), 0b0'0110'101);
|
||||
// midpoint values (rounded to nearest even)
|
||||
EXPECT_EQ(c(+58.0f), 0b0'1100'110);
|
||||
EXPECT_EQ(c(+62.0f), 0b0'1101'000);
|
||||
|
||||
// saturating mode -> max f8 normal
|
||||
// max f32/f16 normal -> max f8 normal
|
||||
EXPECT_EQ(c(+ck_tile::numeric<SrcT>::max()), 0b0'1111'110);
|
||||
EXPECT_EQ(c(-ck_tile::numeric<SrcT>::max()), 0b1'1111'110);
|
||||
// f32/f16 infinity -> max f8 normal
|
||||
EXPECT_EQ(c(+ck_tile::numeric<SrcT>::infinity()), 0b0'1111'110);
|
||||
EXPECT_EQ(c(-ck_tile::numeric<SrcT>::infinity()), 0b1'1111'110);
|
||||
// large f32/f16 -> max f8 normal
|
||||
EXPECT_EQ(c(+1.23e9f), 0b0'1111'110);
|
||||
EXPECT_EQ(c(-1.23e9f), 0b1'1111'110);
|
||||
|
||||
constexpr unsigned int nan_mask = 0b0'1111'111;
|
||||
|
||||
// non-saturating mode -> f8 NaN (because OCP e4m3 has no infinity)
|
||||
// max f32/f16 normal -> f8 NaN
|
||||
EXPECT_EQ(c_nosat(+ck_tile::numeric<SrcT>::max()) & nan_mask, nan_mask);
|
||||
EXPECT_EQ(c_nosat(-ck_tile::numeric<SrcT>::max()) & nan_mask, nan_mask);
|
||||
// f32/f16 infinity -> f8 NaN
|
||||
EXPECT_EQ(c_nosat(+ck_tile::numeric<SrcT>::infinity()) & nan_mask, nan_mask);
|
||||
EXPECT_EQ(c_nosat(-ck_tile::numeric<SrcT>::infinity()) & nan_mask, nan_mask);
|
||||
// large f32/f16 -> f8 NaN
|
||||
EXPECT_EQ(c_nosat(+1.23e9f) & nan_mask, nan_mask);
|
||||
EXPECT_EQ(c_nosat(-1.23e9f) & nan_mask, nan_mask);
|
||||
|
||||
// f32/f16 NaN -> f8 NaN
|
||||
EXPECT_EQ(c(ck_tile::numeric<SrcT>::quiet_NaN()) & nan_mask, nan_mask);
|
||||
EXPECT_EQ(c(ck_tile::numeric<SrcT>::signaling_NaN()) & nan_mask, nan_mask);
|
||||
|
||||
// f32/f16 zero -> f8 zero
|
||||
EXPECT_EQ(c(+0.0f), 0b0'0000'000);
|
||||
EXPECT_EQ(c(-0.0f), 0b1'0000'000);
|
||||
// min f32/f16 normal -> f8 zero
|
||||
EXPECT_EQ(c(+ck_tile::numeric<SrcT>::min()), 0b0'0000'000);
|
||||
EXPECT_EQ(c(-ck_tile::numeric<SrcT>::min()), 0b1'0000'000);
|
||||
// min f32/f16 subnormal -> f8 zero
|
||||
EXPECT_EQ(c(+ck_tile::numeric<SrcT>::denorm_min()), 0b0'0000'000);
|
||||
EXPECT_EQ(c(-ck_tile::numeric<SrcT>::denorm_min()), 0b1'0000'000);
|
||||
|
||||
// All values smaller than min f8 subnormal must be converted to f8 zero
|
||||
constexpr int src_min_subnorm_exp =
|
||||
-(ck_tile::numeric_traits<SrcT>::bias + ck_tile::numeric_traits<SrcT>::mant - 1);
|
||||
constexpr int dst_min_subnorm_exp =
|
||||
-(ck_tile::numeric_traits<DstT>::bias + ck_tile::numeric_traits<DstT>::mant - 1);
|
||||
for(int exp = src_min_subnorm_exp; exp <= 0; ++exp)
|
||||
{
|
||||
const float f = std::ldexp(1.0, exp);
|
||||
if(exp < dst_min_subnorm_exp)
|
||||
{
|
||||
EXPECT_EQ(c(+f), 0b0'0000'000) << "+f = 2^" << exp << " = " << +f;
|
||||
EXPECT_EQ(c(-f), 0b1'0000'000) << "-f = 2^" << exp << " = " << -f;
|
||||
}
|
||||
else
|
||||
{
|
||||
EXPECT_GT(c(+f), 0b0'0000'000) << "+f = 2^" << exp << " = " << +f;
|
||||
EXPECT_GT(c(-f), 0b1'0000'000) << "-f = 2^" << exp << " = " << -f;
|
||||
}
|
||||
}
|
||||
#else // FNUZ
|
||||
EXPECT_EQ(c(+1.0f), 0b0'1000'000);
|
||||
EXPECT_EQ(c(-1.0f), 0b1'1000'000);
|
||||
// max f8 normal
|
||||
EXPECT_EQ(c(+240.0f), 0b0'1111'111);
|
||||
EXPECT_EQ(c(-240.0f), 0b1'1111'111);
|
||||
// min f8 normal
|
||||
EXPECT_EQ(c(+0.0078125f), 0b0'0001'000);
|
||||
EXPECT_EQ(c(-0.0078125f), 0b1'0001'000);
|
||||
// max f8 subnormal
|
||||
EXPECT_EQ(c(+0.0068359375f), 0b0'0000'111);
|
||||
EXPECT_EQ(c(-0.0068359375f), 0b1'0000'111);
|
||||
// min f8 subnormal
|
||||
EXPECT_EQ(c(+0.0009765625f), 0b0'0000'001);
|
||||
EXPECT_EQ(c(-0.0009765625f), 0b1'0000'001);
|
||||
// arbitrary values (exact)
|
||||
EXPECT_EQ(c(+0.1015625f), 0b0'0100'101);
|
||||
EXPECT_EQ(c(-44.0f), 0b1'1101'011);
|
||||
// arbitrary values (rounded)
|
||||
EXPECT_EQ(c(+219.91f), 0b0'1111'110);
|
||||
EXPECT_EQ(c(-203.11f), 0b1'1111'101);
|
||||
EXPECT_EQ(c(-0.3639f), 0b1'0110'100);
|
||||
EXPECT_EQ(c(+0.4139f), 0b0'0110'101);
|
||||
|
||||
// saturating mode -> max f8 normal
|
||||
// max f32/f16 normal -> max f8 normal
|
||||
EXPECT_EQ(c(+ck_tile::numeric<SrcT>::max()), 0b0'1111'111);
|
||||
EXPECT_EQ(c(-ck_tile::numeric<SrcT>::max()), 0b1'1111'111);
|
||||
// f32/f16 infinity -> max f8 normal
|
||||
EXPECT_EQ(c(+ck_tile::numeric<SrcT>::infinity()), 0b0'1111'111);
|
||||
EXPECT_EQ(c(-ck_tile::numeric<SrcT>::infinity()), 0b1'1111'111);
|
||||
// large f32/f16 -> max f8 normal
|
||||
EXPECT_EQ(c(+1.23e9f), 0b0'1111'111);
|
||||
EXPECT_EQ(c(-1.23e9f), 0b1'1111'111);
|
||||
|
||||
constexpr unsigned int nan_value = 0b1'0000'000;
|
||||
|
||||
// non-saturating mode -> f8 NaN (FN means "finite", so no infinity)
|
||||
// max f32/f16 normal -> f8 NaN
|
||||
EXPECT_EQ(c_nosat(+ck_tile::numeric<SrcT>::max()), nan_value);
|
||||
EXPECT_EQ(c_nosat(-ck_tile::numeric<SrcT>::max()), nan_value);
|
||||
// f32/f16 infinity -> f8 NaN
|
||||
EXPECT_EQ(c_nosat(+ck_tile::numeric<SrcT>::infinity()), nan_value);
|
||||
EXPECT_EQ(c_nosat(-ck_tile::numeric<SrcT>::infinity()), nan_value);
|
||||
// large f32/f16 -> f8 NaN
|
||||
EXPECT_EQ(c_nosat(+1.23e9f), nan_value);
|
||||
EXPECT_EQ(c_nosat(-1.23e9f), nan_value);
|
||||
|
||||
// f32/f16 NaN -> f8 NaN
|
||||
EXPECT_EQ(c(ck_tile::numeric<SrcT>::quiet_NaN()), nan_value);
|
||||
EXPECT_EQ(c(ck_tile::numeric<SrcT>::signaling_NaN()), nan_value);
|
||||
|
||||
// UZ means "unsigned zero" (0b1'0000'000 is NaN)
|
||||
// f32/f16 +-zero -> f8 +zero
|
||||
EXPECT_EQ(c(+0.0f), 0b0'0000'000);
|
||||
EXPECT_EQ(c(-0.0f), 0b0'0000'000);
|
||||
// min f32/f16 normal -> f8 +zero
|
||||
EXPECT_EQ(c(+ck_tile::numeric<SrcT>::min()), 0b0'0000'000);
|
||||
EXPECT_EQ(c(-ck_tile::numeric<SrcT>::min()), 0b0'0000'000);
|
||||
// min f32/f16 subnormal -> f8 +zero
|
||||
EXPECT_EQ(c(+ck_tile::numeric<SrcT>::denorm_min()), 0b0'0000'000);
|
||||
EXPECT_EQ(c(-ck_tile::numeric<SrcT>::denorm_min()), 0b0'0000'000);
|
||||
|
||||
// All values smaller than min f8 subnormal must be converted to f8 zero
|
||||
constexpr int src_min_subnorm_exp =
|
||||
-(ck_tile::numeric_traits<SrcT>::bias + ck_tile::numeric_traits<SrcT>::mant - 1);
|
||||
constexpr int dst_min_subnorm_exp =
|
||||
-(ck_tile::numeric_traits<DstT>::bias + ck_tile::numeric_traits<DstT>::mant - 1);
|
||||
for(int exp = src_min_subnorm_exp; exp <= 0; ++exp)
|
||||
{
|
||||
const float f = std::ldexp(1.0, exp);
|
||||
if(exp < dst_min_subnorm_exp)
|
||||
{
|
||||
EXPECT_EQ(c(+f), 0b0'0000'000) << "+f = 2^" << exp << " = " << +f;
|
||||
EXPECT_EQ(c(-f), 0b0'0000'000) << "-f = 2^" << exp << " = " << -f;
|
||||
}
|
||||
else
|
||||
{
|
||||
EXPECT_GT(c(+f), 0b0'0000'000) << "+f = 2^" << exp << " = " << +f;
|
||||
EXPECT_GT(c(-f), 0b0'0000'000) << "-f = 2^" << exp << " = " << -f;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
TYPED_TEST(ConvertTest, ToBf8)
|
||||
{
|
||||
using SrcT = TypeParam;
|
||||
using DstT = ck_tile::bf8_t;
|
||||
|
||||
auto c = [](SrcT f) {
|
||||
return static_cast<unsigned int>(
|
||||
ck_tile::bit_cast<uint8_t>(ck_tile::impl::run_cast_to_f8<SrcT, DstT, true>(f)));
|
||||
};
|
||||
|
||||
auto c_nosat = [](SrcT f) {
|
||||
return static_cast<unsigned int>(
|
||||
ck_tile::bit_cast<uint8_t>(ck_tile::impl::run_cast_to_f8<SrcT, DstT, false>(f)));
|
||||
};
|
||||
|
||||
#if CK_TILE_USE_OCP_FP8
|
||||
EXPECT_EQ(c(+1.0f), 0b0'01111'00);
|
||||
EXPECT_EQ(c(-1.0f), 0b1'01111'00);
|
||||
// max f8 normal
|
||||
EXPECT_EQ(c(+57344.0f), 0b0'11110'11);
|
||||
EXPECT_EQ(c(-57344.0f), 0b1'11110'11);
|
||||
// min f8 normal
|
||||
EXPECT_EQ(c(+6.103515625e-05f), 0b0'00001'00);
|
||||
EXPECT_EQ(c(-6.103515625e-05f), 0b1'00001'00);
|
||||
// max f8 subnormal
|
||||
EXPECT_EQ(c(+4.57763671875e-05f), 0b0'00000'11);
|
||||
EXPECT_EQ(c(-4.57763671875e-05f), 0b1'00000'11);
|
||||
// min f8 subnormal
|
||||
EXPECT_EQ(c(+1.52587890625e-05f), 0b0'00000'01);
|
||||
EXPECT_EQ(c(-1.52587890625e-05f), 0b1'00000'01);
|
||||
// arbitrary values (exact)
|
||||
EXPECT_EQ(c(+0.01953125f), 0b0'01001'01);
|
||||
EXPECT_EQ(c(-3584.0f), 0b1'11010'11);
|
||||
// arbitrary values (rounded)
|
||||
EXPECT_EQ(c(+2030.56f), 0b0'11010'00);
|
||||
EXPECT_EQ(c(-1801.33f), 0b1'11001'11);
|
||||
EXPECT_EQ(c(-0.27891f), 0b1'0110'100);
|
||||
EXPECT_EQ(c(+0.33333f), 0b0'0110'101);
|
||||
|
||||
// saturating mode -> max f8 normal
|
||||
// max f32/f16 normal -> max f8 normal
|
||||
EXPECT_EQ(c(+ck_tile::numeric<SrcT>::max()), 0b0'11110'11);
|
||||
EXPECT_EQ(c(-ck_tile::numeric<SrcT>::max()), 0b1'11110'11);
|
||||
// f32/f16 infinity -> max f8 normal
|
||||
EXPECT_EQ(c(+ck_tile::numeric<SrcT>::infinity()), 0b0'11110'11);
|
||||
EXPECT_EQ(c(-ck_tile::numeric<SrcT>::infinity()), 0b1'11110'11);
|
||||
// large f32/f16 -> max f8 normal
|
||||
EXPECT_EQ(c(+1.23e9f), 0b0'11110'11);
|
||||
EXPECT_EQ(c(-1.23e9f), 0b1'11110'11);
|
||||
|
||||
// non-saturating mode -> f8 infinity
|
||||
// max f32/f16 normal -> f8 infinity
|
||||
EXPECT_EQ(c_nosat(+ck_tile::numeric<SrcT>::max()), 0b0'11111'00);
|
||||
EXPECT_EQ(c_nosat(-ck_tile::numeric<SrcT>::max()), 0b1'11111'00);
|
||||
// f32/f16 infinity -> f8 infinity
|
||||
EXPECT_EQ(c_nosat(+ck_tile::numeric<SrcT>::infinity()), 0b0'11111'00);
|
||||
EXPECT_EQ(c_nosat(-ck_tile::numeric<SrcT>::infinity()), 0b1'11111'00);
|
||||
// large f32/f16 -> f8 infinity
|
||||
EXPECT_EQ(c_nosat(+1.23e9f), 0b0'11111'00);
|
||||
EXPECT_EQ(c_nosat(-1.23e9f), 0b1'11111'00);
|
||||
|
||||
// f32/f16 NaN -> f8 NaN
|
||||
EXPECT_TRUE((c(ck_tile::numeric<SrcT>::quiet_NaN()) & 0b0'11111'11) != 0b0'11111'00);
|
||||
EXPECT_TRUE((c(ck_tile::numeric<SrcT>::signaling_NaN()) & 0b0'11111'11) != 0b0'11111'00);
|
||||
|
||||
// f32/f16 zero -> f8 zero
|
||||
EXPECT_EQ(c(+0.0f), 0b0'00000'00);
|
||||
EXPECT_EQ(c(-0.0f), 0b1'00000'00);
|
||||
if constexpr(std::is_same_v<SrcT, float>)
|
||||
{
|
||||
// min f32 normal -> f8 zero
|
||||
EXPECT_EQ(c(+ck_tile::numeric<SrcT>::min()), 0b0'00000'00);
|
||||
EXPECT_EQ(c(-ck_tile::numeric<SrcT>::min()), 0b1'00000'00);
|
||||
}
|
||||
else
|
||||
{
|
||||
// min f16 normal -> min f8 normal (they are equal)
|
||||
EXPECT_EQ(c(+ck_tile::numeric<SrcT>::min()), 0b0'00001'00);
|
||||
EXPECT_EQ(c(-ck_tile::numeric<SrcT>::min()), 0b1'00001'00);
|
||||
}
|
||||
// min f32/f16 subnormal -> f8 zero
|
||||
EXPECT_EQ(c(+ck_tile::numeric<SrcT>::denorm_min()), 0b0'00000'00);
|
||||
EXPECT_EQ(c(-ck_tile::numeric<SrcT>::denorm_min()), 0b1'00000'00);
|
||||
|
||||
// All values smaller than min f8 subnormal must be converted to f8 zero
|
||||
constexpr int src_min_subnorm_exp =
|
||||
-(ck_tile::numeric_traits<SrcT>::bias + ck_tile::numeric_traits<SrcT>::mant - 1);
|
||||
constexpr int dst_min_subnorm_exp =
|
||||
-(ck_tile::numeric_traits<DstT>::bias + ck_tile::numeric_traits<DstT>::mant - 1);
|
||||
for(int exp = src_min_subnorm_exp; exp <= 0; ++exp)
|
||||
{
|
||||
const float f = std::ldexp(1.0, exp);
|
||||
if(exp < dst_min_subnorm_exp)
|
||||
{
|
||||
EXPECT_EQ(c(+f), 0b0'00000'00) << "+f = 2^" << exp << " = " << +f;
|
||||
EXPECT_EQ(c(-f), 0b1'00000'00) << "-f = 2^" << exp << " = " << -f;
|
||||
}
|
||||
else
|
||||
{
|
||||
EXPECT_GT(c(+f), 0b0'00000'00) << "+f = 2^" << exp << " = " << +f;
|
||||
EXPECT_GT(c(-f), 0b1'00000'00) << "-f = 2^" << exp << " = " << -f;
|
||||
}
|
||||
}
|
||||
#else // FNUZ
|
||||
EXPECT_EQ(c(+1.0f), 0b0'10000'00);
|
||||
EXPECT_EQ(c(-1.0f), 0b1'10000'00);
|
||||
// max f8 normal
|
||||
EXPECT_EQ(c(+57344.0f), 0b0'11111'11);
|
||||
EXPECT_EQ(c(-57344.0f), 0b1'11111'11);
|
||||
// min f8 normal
|
||||
EXPECT_EQ(c(+3.0517578125e-05f), 0b0'00001'00);
|
||||
EXPECT_EQ(c(-3.0517578125e-05f), 0b1'00001'00);
|
||||
// max f8 subnormal
|
||||
EXPECT_EQ(c(+2.288818359375e-05f), 0b0'00000'11);
|
||||
EXPECT_EQ(c(-2.288818359375e-05f), 0b1'00000'11);
|
||||
// min f8 subnormal
|
||||
EXPECT_EQ(c(+7.62939453125e-06f), 0b0'00000'01);
|
||||
EXPECT_EQ(c(-7.62939453125e-06f), 0b1'00000'01);
|
||||
// arbitrary values (exact)
|
||||
EXPECT_EQ(c(+0.009765625f), 0b0'01001'01);
|
||||
EXPECT_EQ(c(-1792.0f), 0b1'11010'11);
|
||||
// arbitrary values (rounded)
|
||||
EXPECT_EQ(c(+840.100f), 0b0'11001'11);
|
||||
EXPECT_EQ(c(-999.999f), 0b1'11010'00);
|
||||
EXPECT_EQ(c(-0.12789f), 0b1'0110'100);
|
||||
EXPECT_EQ(c(+0.14444f), 0b0'0110'101);
|
||||
|
||||
// saturating mode -> max f8 normal
|
||||
// max f32/f16 normal -> max f8 normal
|
||||
EXPECT_EQ(c(+ck_tile::numeric<SrcT>::max()), 0b0'11111'11);
|
||||
EXPECT_EQ(c(-ck_tile::numeric<SrcT>::max()), 0b1'1111'111);
|
||||
// f32/f16 infinity -> max f8 normal
|
||||
EXPECT_EQ(c(+ck_tile::numeric<SrcT>::infinity()), 0b0'11111'11);
|
||||
EXPECT_EQ(c(-ck_tile::numeric<SrcT>::infinity()), 0b1'1111'111);
|
||||
// large f32/f16 -> max f8 normal
|
||||
EXPECT_EQ(c(+1.23e9f), 0b0'11111'11);
|
||||
EXPECT_EQ(c(-1.23e9f), 0b1'1111'111);
|
||||
|
||||
constexpr unsigned int nan_value = 0b1'00000'00;
|
||||
|
||||
// non-saturating mode -> f8 NaN (FN means "finite", so no infinity)
|
||||
// max f32/f16 normal -> f8 NaN
|
||||
EXPECT_EQ(c_nosat(+ck_tile::numeric<SrcT>::max()), nan_value);
|
||||
EXPECT_EQ(c_nosat(-ck_tile::numeric<SrcT>::max()), nan_value);
|
||||
// f32/f16 infinity -> f8 NaN
|
||||
EXPECT_EQ(c_nosat(+ck_tile::numeric<SrcT>::infinity()), nan_value);
|
||||
EXPECT_EQ(c_nosat(-ck_tile::numeric<SrcT>::infinity()), nan_value);
|
||||
// large f32/f16 -> f8 NaN
|
||||
EXPECT_EQ(c_nosat(+1.23e9f), nan_value);
|
||||
EXPECT_EQ(c_nosat(-1.23e9f), nan_value);
|
||||
|
||||
// f32/f16 NaN -> f8 NaN
|
||||
EXPECT_EQ(c(ck_tile::numeric<SrcT>::quiet_NaN()), nan_value);
|
||||
EXPECT_EQ(c(ck_tile::numeric<SrcT>::signaling_NaN()), nan_value);
|
||||
|
||||
// UZ means "unsigned zero" (0b1'00000'00 is NaN)
|
||||
// f32/f16 +-zero -> f8 +zero
|
||||
EXPECT_EQ(c(+0.0f), 0b0'00000'00);
|
||||
EXPECT_EQ(c(-0.0f), 0b0'00000'00);
|
||||
if constexpr(std::is_same_v<SrcT, float>)
|
||||
{
|
||||
// min f32 normal -> f8 +zero
|
||||
EXPECT_EQ(c(+ck_tile::numeric<SrcT>::min()), 0b0'00000'00);
|
||||
EXPECT_EQ(c(-ck_tile::numeric<SrcT>::min()), 0b0'00000'00);
|
||||
}
|
||||
else
|
||||
{
|
||||
// min f16 normal -> f8 normal
|
||||
EXPECT_EQ(c(+ck_tile::numeric<SrcT>::min()), 0b0'00010'00);
|
||||
EXPECT_EQ(c(-ck_tile::numeric<SrcT>::min()), 0b1'00010'00);
|
||||
}
|
||||
// min f32/f16 subnormal -> f8 +zero
|
||||
EXPECT_EQ(c(+ck_tile::numeric<SrcT>::denorm_min()), 0b0'00000'00);
|
||||
EXPECT_EQ(c(-ck_tile::numeric<SrcT>::denorm_min()), 0b0'00000'00);
|
||||
|
||||
// All values smaller than min f8 subnormal must be converted to f8 zero
|
||||
constexpr int src_min_subnorm_exp =
|
||||
-(ck_tile::numeric_traits<SrcT>::bias + ck_tile::numeric_traits<SrcT>::mant - 1);
|
||||
constexpr int dst_min_subnorm_exp =
|
||||
-(ck_tile::numeric_traits<DstT>::bias + ck_tile::numeric_traits<DstT>::mant - 1);
|
||||
for(int exp = src_min_subnorm_exp; exp <= 0; ++exp)
|
||||
{
|
||||
const float f = std::ldexp(1.0, exp);
|
||||
if(exp < dst_min_subnorm_exp)
|
||||
{
|
||||
EXPECT_EQ(c(+f), 0b0'00000'00) << "+f = 2^" << exp << " = " << +f;
|
||||
EXPECT_EQ(c(-f), 0b0'00000'00) << "-f = 2^" << exp << " = " << -f;
|
||||
}
|
||||
else
|
||||
{
|
||||
EXPECT_GT(c(+f), 0b0'00000'00) << "+f = 2^" << exp << " = " << +f;
|
||||
EXPECT_GT(c(-f), 0b0'00000'00) << "-f = 2^" << exp << " = " << -f;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
TYPED_TEST(ConvertTest, FromFp8)
|
||||
{
|
||||
using SrcT = ck_tile::fp8_t;
|
||||
using DstT = TypeParam;
|
||||
|
||||
auto c = [](uint8_t u) {
|
||||
return ck_tile::type_convert<float>(
|
||||
ck_tile::impl::run_cast_from_f8<SrcT, DstT, true>(ck_tile::bit_cast<SrcT>(u)));
|
||||
};
|
||||
|
||||
#if CK_TILE_USE_OCP_FP8
|
||||
EXPECT_EQ(c(0b0'0111'000), +1.0f);
|
||||
EXPECT_EQ(c(0b1'0111'000), -1.0f);
|
||||
// max f8 normal
|
||||
EXPECT_EQ(c(0b0'1111'110), +448.0f);
|
||||
EXPECT_EQ(c(0b1'1111'110), -448.0f);
|
||||
// min f8 normal
|
||||
EXPECT_EQ(c(0b0'0001'000), +0.015625f);
|
||||
EXPECT_EQ(c(0b1'0001'000), -0.015625f);
|
||||
// max f8 subnormal
|
||||
EXPECT_EQ(c(0b0'0000'111), +0.013671875f);
|
||||
EXPECT_EQ(c(0b1'0000'111), -0.013671875f);
|
||||
// min f8 subnormal
|
||||
EXPECT_EQ(c(0b0'0000'001), +0.001953125f);
|
||||
EXPECT_EQ(c(0b1'0000'001), -0.001953125f);
|
||||
// arbitrary values
|
||||
EXPECT_EQ(c(0b0'0100'101), +0.203125f);
|
||||
EXPECT_EQ(c(0b1'1101'011), -88.0f);
|
||||
|
||||
// f8 NaN -> f32/f16 NaN
|
||||
EXPECT_TRUE(ck_tile::isnan(c(0b0'1111'111)));
|
||||
EXPECT_TRUE(ck_tile::isnan(c(0b1'1111'111)));
|
||||
|
||||
// f8 zero -> f32/f16 zero (sign is preserved)
|
||||
EXPECT_EQ(c(0b0'0000'000),
|
||||
ck_tile::bit_cast<DstT>(typename ck_tile::numeric_traits<DstT>::bitwise_type{0}));
|
||||
EXPECT_EQ(c(0b1'0000'000), ck_tile::bit_cast<DstT>(ck_tile::numeric_traits<DstT>::Neg0));
|
||||
#else // FNUZ
|
||||
EXPECT_EQ(c(0b0'1000'000), +1.0f);
|
||||
EXPECT_EQ(c(0b1'1000'000), -1.0f);
|
||||
// max f8 normal
|
||||
EXPECT_EQ(c(0b0'1111'111), +240.0f);
|
||||
EXPECT_EQ(c(0b1'1111'111), -240.0f);
|
||||
// min f8 normal
|
||||
EXPECT_EQ(c(0b0'0001'000), +0.0078125f);
|
||||
EXPECT_EQ(c(0b1'0001'000), -0.0078125f);
|
||||
// max f8 subnormal
|
||||
EXPECT_EQ(c(0b0'0000'111), +0.0068359375f);
|
||||
EXPECT_EQ(c(0b1'0000'111), -0.0068359375f);
|
||||
// min f8 subnormal
|
||||
EXPECT_EQ(c(0b0'0000'001), +0.0009765625f);
|
||||
EXPECT_EQ(c(0b1'0000'001), -0.0009765625f);
|
||||
// arbitrary values
|
||||
EXPECT_EQ(c(0b0'0100'101), +0.1015625f);
|
||||
EXPECT_EQ(c(0b1'1101'011), -44.0f);
|
||||
|
||||
// f8 NaN -> f32/f16 NaN
|
||||
EXPECT_TRUE(ck_tile::isnan(c(0b1'0000'000)));
|
||||
|
||||
// UZ means "unsigned zero" (0b1'0000'000 is NaN)
|
||||
// f8 +zero -> f32/f16 +zero
|
||||
EXPECT_EQ(c(0b0'0000'000),
|
||||
ck_tile::bit_cast<DstT>(typename ck_tile::numeric_traits<DstT>::bitwise_type{0}));
|
||||
#endif
|
||||
}
|
||||
|
||||
// Checks conversion of raw bf8 (e5m2) bit patterns to the destination type TypeParam.
// Every pattern is written as sign'exponent'mantissa, cast with run_cast_from_f8 and then
// widened to float so that it can be compared against a float literal.
// The expected values depend on the f8 flavour selected by CK_TILE_USE_OCP_FP8 (OCP or FNUZ).
TYPED_TEST(ConvertTest, FromBf8)
{
    using SrcT = ck_tile::bf8_t;
    using DstT = TypeParam;

    // Conversion with the third template argument set to true (the "saturating mode" below).
    auto c = [](uint8_t u) {
        return ck_tile::type_convert<float>(
            ck_tile::impl::run_cast_from_f8<SrcT, DstT, true>(ck_tile::bit_cast<SrcT>(u)));
    };

#if CK_TILE_USE_OCP_FP8
    // Conversion with the third template argument set to false (the "non-saturating mode" below).
    auto c_nosat = [](uint8_t u) {
        return ck_tile::type_convert<float>(
            ck_tile::impl::run_cast_from_f8<SrcT, DstT, false>(ck_tile::bit_cast<SrcT>(u)));
    };

    EXPECT_EQ(c(0b0'01111'00), +1.0f);
    EXPECT_EQ(c(0b1'01111'00), -1.0f);
    // max f8 normal
    EXPECT_EQ(c(0b0'11110'11), +57344.0f);
    EXPECT_EQ(c(0b1'11110'11), -57344.0f);
    // min f8 normal
    EXPECT_EQ(c(0b0'00001'00), +6.103515625e-05f);
    EXPECT_EQ(c(0b1'00001'00), -6.103515625e-05f);
    // max f8 subnormal
    EXPECT_EQ(c(0b0'00000'11), +4.57763671875e-05f);
    EXPECT_EQ(c(0b1'00000'11), -4.57763671875e-05f);
    // min f8 subnormal
    EXPECT_EQ(c(0b0'00000'01), +1.52587890625e-05f);
    EXPECT_EQ(c(0b1'00000'01), -1.52587890625e-05f);
    // arbitrary values
    EXPECT_EQ(c(0b0'01001'01), +0.01953125f);
    EXPECT_EQ(c(0b1'11010'11), -3584.0f);

    // saturating mode
    // f8 infinity -> max f8 normal as f32/f16
    EXPECT_EQ(c(0b0'11111'00), +57344.0f);
    EXPECT_EQ(c(0b1'11111'00), -57344.0f);

    // non-saturating mode
    // f8 infinity -> f32/f16 infinity
    EXPECT_EQ(c_nosat(0b0'11111'00), +ck_tile::numeric<DstT>::infinity());
    EXPECT_EQ(c_nosat(0b1'11111'00), -ck_tile::numeric<DstT>::infinity());

    // f8 NaN -> f32/f16 NaN
    EXPECT_TRUE(ck_tile::isnan(c(0b0'11111'01)));
    EXPECT_TRUE(ck_tile::isnan(c(0b0'11111'10)));
    EXPECT_TRUE(ck_tile::isnan(c(0b0'11111'11)));
    EXPECT_TRUE(ck_tile::isnan(c(0b1'11111'01)));
    EXPECT_TRUE(ck_tile::isnan(c(0b1'11111'10)));
    EXPECT_TRUE(ck_tile::isnan(c(0b1'11111'11)));

    // f8 zero -> f32/f16 zero (sign is preserved)
    // NOTE(review): EXPECT_EQ compares with operator==, under which -0 and +0 are equal for
    // IEEE floats, so these two checks may not actually verify the sign - confirm.
    EXPECT_EQ(c(0b0'00000'00),
              ck_tile::bit_cast<DstT>(typename ck_tile::numeric_traits<DstT>::bitwise_type{0}));
    EXPECT_EQ(c(0b1'00000'00), ck_tile::bit_cast<DstT>(ck_tile::numeric_traits<DstT>::Neg0));
    if constexpr(std::is_same_v<DstT, ck_tile::fp16_t>)
    {
        // min f8 normal -> min f16 normal (they are equal)
        EXPECT_EQ(c(0b0'00001'00), +ck_tile::numeric<DstT>::min());
        EXPECT_EQ(c(0b1'00001'00), -ck_tile::numeric<DstT>::min());
    }
#else // FNUZ
    EXPECT_EQ(c(0b0'10000'00), +1.0f);
    EXPECT_EQ(c(0b1'10000'00), -1.0f);
    // max f8 normal
    EXPECT_EQ(c(0b0'11111'11), +57344.0f);
    EXPECT_EQ(c(0b1'11111'11), -57344.0f);
    // min f8 normal
    EXPECT_EQ(c(0b0'00001'00), +3.0517578125e-05f);
    EXPECT_EQ(c(0b1'00001'00), -3.0517578125e-05f);
    // max f8 subnormal
    EXPECT_EQ(c(0b0'00000'11), +2.288818359375e-05f);
    EXPECT_EQ(c(0b1'00000'11), -2.288818359375e-05f);
    // min f8 subnormal
    EXPECT_EQ(c(0b0'00000'01), +7.62939453125e-06f);
    EXPECT_EQ(c(0b1'00000'01), -7.62939453125e-06f);
    // arbitrary values
    EXPECT_EQ(c(0b0'01001'01), +0.009765625f);
    EXPECT_EQ(c(0b1'11010'11), -1792.0f);

    // f8 NaN -> f32/f16 NaN
    EXPECT_TRUE(ck_tile::isnan(c(0b1'00000'00)));

    // UZ means "unsigned zero" (0b1'00000'00 is NaN)
    // f8 +zero -> f32/f16 +zero
    EXPECT_EQ(c(0b0'00000'00),
              ck_tile::bit_cast<DstT>(typename ck_tile::numeric_traits<DstT>::bitwise_type{0}));
    if constexpr(std::is_same_v<DstT, ck_tile::fp16_t>)
    {
        // one of f8 normals -> min f16 normal
        EXPECT_EQ(c(0b0'00010'00), +ck_tile::numeric<DstT>::min());
        EXPECT_EQ(c(0b1'00010'00), -ck_tile::numeric<DstT>::min());
    }
#endif
}
|
||||
|
||||
// Convert f8 -> f32/f16 -> f8 to check if all values are covered
|
||||
// OCP types have multiple NaN representations (e4m3 - 2, e5m2 - 6); they are ignored for simplicity.
|
||||
|
||||
// Round-trips every 8-bit fp8 encoding through the destination type under test and back,
// expecting the original fp8 value to be reproduced.
TYPED_TEST(ConvertTest, FromFp8AndToFp8)
{
    using SrcT = ck_tile::fp8_t;
    using DstT = TypeParam;

    constexpr int encoding_count = 256; // number of distinct 8-bit patterns

    for(int i = 0; i < encoding_count; ++i)
    {
#if CK_TILE_USE_OCP_FP8
        // Patterns with all exponent and mantissa bits set (s.1111.111) are left out of the
        // round trip, whatever the sign bit is.
        constexpr int magnitude_bits = 0b0'1111'111;
        const bool is_excluded       = (i & magnitude_bits) == magnitude_bits;
        if(is_excluded)
        {
            continue;
        }
#endif
        const SrcT from    = ck_tile::bit_cast<SrcT>(static_cast<uint8_t>(i));
        const DstT widened = ck_tile::impl::run_cast_from_f8<SrcT, DstT, false>(from);
        const SrcT to      = ck_tile::impl::run_cast_to_f8<DstT, SrcT, false>(widened);

        // On mismatch, report the raw encoding and the intermediate value as float.
        EXPECT_EQ(from, to) << "u8: " << i
                            << " f32/f16: " << ck_tile::type_convert<float>(widened);
    }
}
|
||||
|
||||
// Round-trips every 8-bit bf8 encoding through the destination type under test and back,
// expecting the original bf8 value to be reproduced.
TYPED_TEST(ConvertTest, FromBf8AndToBf8)
{
    using SrcT = ck_tile::bf8_t;
    using DstT = TypeParam;

    constexpr int encoding_count = 256; // number of distinct 8-bit patterns

    for(int i = 0; i < encoding_count; ++i)
    {
#if CK_TILE_USE_OCP_FP8
        // Patterns whose magnitude lies above s.11111.00 (all exponent bits set and a
        // non-zero mantissa) are left out; s.11111.00 itself is still round-tripped.
        constexpr int magnitude_bits = 0b0'11111'11;
        constexpr int last_included  = 0b0'11111'00;
        const bool is_excluded       = (i & magnitude_bits) > last_included;
        if(is_excluded)
        {
            continue;
        }
#endif
        const SrcT from    = ck_tile::bit_cast<SrcT>(static_cast<uint8_t>(i));
        const DstT widened = ck_tile::impl::run_cast_from_f8<SrcT, DstT, false>(from);
        const SrcT to      = ck_tile::impl::run_cast_to_f8<DstT, SrcT, false>(widened);

        // On mismatch, report the raw encoding and the intermediate value as float.
        EXPECT_EQ(from, to) << "u8: " << i
                            << " f32/f16: " << ck_tile::type_convert<float>(widened);
    }
}
|
||||
Reference in New Issue
Block a user