From 2ef590ab438163f33023f931fcf22854ec7cb38f Mon Sep 17 00:00:00 2001 From: Anton Gorenko Date: Thu, 31 Jul 2025 10:54:17 +0600 Subject: [PATCH] [CK_TILE] Fix UB and corner cases in f32/f16 to/from f8 conversion (#2571) * Add tests for host conversion f32/f16 to f8 * Add tests for host conversion from f8 to f32/f16 * Fix UB and corner cases in f32/f16 to/from f8 conversion * There are UBs when very small values are converted to f8: bitshifts can be larger than the type width. Using unsigned long long does not help because exponent_diff >= 64 in such cases. This causes that values like 2.117582368e-22 are converted to non-zero f8 in host validation of FMHA tests, test_f8 crashes with segfault in completely irrelevant code like GTest internals or produces non-deterministic results etc. * Fix FNUZ conversion to return NaN for NaN inputs. * Fix compilation error (due to uint8_t << 8) in OCP e5m2 to f16 conversion. * Replace some magic numbers with values from numeric_traits * Build tests only on devices supporting the type [ROCm/composable_kernel commit: 7b074249f44c4fda2ed71e2f4059f80806476424] --- include/ck_tile/core/numeric/float8.hpp | 93 ++-- test/ck_tile/data_type/CMakeLists.txt | 14 +- test/ck_tile/data_type/test_fp8.cpp | 606 ++++++++++++++++++++++++ 3 files changed, 663 insertions(+), 50 deletions(-) create mode 100644 test/ck_tile/data_type/test_fp8.cpp diff --git a/include/ck_tile/core/numeric/float8.hpp b/include/ck_tile/core/numeric/float8.hpp index a3ce614f84..04ca950641 100644 --- a/include/ck_tile/core/numeric/float8.hpp +++ b/include/ck_tile/core/numeric/float8.hpp @@ -43,19 +43,19 @@ enum class fp8_interpretation }; /* - * ______________FNUZ_________________ | ______________OCP________________ + * ______________FNUZ_________________ | ______________OCP________________ * e4m3 e5m2 | e4m3 e5m2 * bias : 8 16 | 7 15 - * inf : 1.0000.000 1.00000.00 | N/A s.11111.00 + * inf : N/A N/A | N/A s.11111.00 * Nan : 1.0000.000 1.00000.00 | s.1111.111 s.11111.{01, 10,
11} * zero : 0.0000.000 0.00000.00 | s.0000.000 s.00000.00 * Max(norm) : s.1111.111 (240) s.11111.11(57344) | s.1111.110(448) s.11110.11(57344) * Max(snorm): s.0000.111 s.00000.11 | s.0000.111 s.00000.11 - * 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05 + * 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05 * Min(norm) : s.0001.000 s.00001.00 | s.0001.000 s.00001.00 - * 2^-7(0.00078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05) + * 2^-7(0.0078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05) * Min(snorm): s.0000.001 s.00000.01 | s.0000.001 s.00000.01 - * 2^-10(0.00097656) 2^-17(7.629395e-06)| 2^-9(0.001953125) 2^-16(1.52588e-05) + * 2^-10(0.0009765625) 2^-17(7.62939e-06) | 2^-9(0.001953125) 2^-16(1.52588e-05) */ template (CK_TILE_FLOAT_TO_FP8_DEFAULT)> @@ -259,50 +259,50 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0) // fp8/bf8 type exponent/mantissa layout constexpr int DstT_exp = numeric_traits::exp; // exponent width of the destination type constexpr int DstT_mant = numeric_traits::mant; // mantissa width of the destination type + constexpr int DstT_bias = numeric_traits::bias; constexpr bool is_fnuz = (numeric_traits::f8_interpret == fp8_interpretation::E4M3_FNUZ) || (numeric_traits::f8_interpret == fp8_interpretation::E5M2_FNUZ); - constexpr int SrcT_exp = numeric_traits::exp; - constexpr int SrcT_mant = numeric_traits::mant; + constexpr int SrcT_exp = numeric_traits::exp; + constexpr int SrcT_mant = numeric_traits::mant; + constexpr int bias = numeric_traits::bias; + constexpr unsigned int fInf = numeric_traits::Inf; + constexpr unsigned int abs_mask = numeric_traits::abs_mask; using SrcT_bitwise = typename numeric_traits::bitwise_type; SrcT_bitwise src_bitwise = bit_cast(src); - unsigned long long head, mantissa; - int exponent, bias; + unsigned int head, mantissa; + int exponent; unsigned int sign; - unsigned long long fInf, abs_mask; head = src_bitwise & numeric_traits::head_mask; 
mantissa = src_bitwise & numeric_traits::mant_mask; exponent = (head >> SrcT_mant) & numeric_traits::exp_mask; sign = head >> (SrcT_exp + SrcT_mant); - bias = numeric_traits::bias; - fInf = numeric_traits::Inf; - abs_mask = numeric_traits::abs_mask; unsigned int signed_inf = 0; unsigned int nan = 0; if constexpr(is_fnuz) { - signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80; + signed_inf = clip ? ((sign << (DstT_exp + DstT_mant)) + 0x7f) : 0x80; nan = 0x80; } else { if constexpr(DstT_exp == 4) { // e4m3 - signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f); + signed_inf = (sign << (DstT_exp + DstT_mant)) + (clip ? 0x7e : 0x7f); } else { // e5m2 - signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c); + signed_inf = (sign << (DstT_exp + DstT_mant)) + (clip ? 0x7b : 0x7c); } - nan = (sign << 7) + 0x7f; + nan = (sign << (DstT_exp + DstT_mant)) + 0x7f; } // Max values - unsigned long long ifmax = 0; + unsigned int ifmax = 0; if constexpr(is_float) { if constexpr(DstT_exp == 5) @@ -343,9 +343,6 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0) // Deal with inf and NaNs if((src_bitwise & fInf) == fInf) { - if constexpr(is_fnuz) - return signed_inf; - return mantissa != 0 ? nan : signed_inf; } @@ -354,11 +351,6 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0) return signed_inf; } - if(src_bitwise == 0) - { - return 0; - } - // First need to check if it is normal or denorm as there is a difference of // implicit 1 Then need to adjust the exponent to align with the F8 exponent, // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng @@ -367,8 +359,7 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0) // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent // bits - const int f8_bias = (1 << (DstT_exp - 1)) - 1 + (is_fnuz ? 
1 : 0); - const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal + constexpr int f8_denormal_act_exponent = 1 - DstT_bias; // actual exponent of f8 denormal // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) // f8_exponent is the converted f8 exponent with bias encoding // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, @@ -406,11 +397,16 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0) // for this case, act_exponent could be larger. Just // that it does not need shift mantissa } - mantissa += (1ull << SrcT_mant); // Add the implicit 1 into mantissa + mantissa += (1u << SrcT_mant); // Add the implicit 1 into mantissa } - - bool midpoint = (mantissa & ((1ull << (SrcT_mant - DstT_mant + exponent_diff)) - 1)) == - (1ull << (SrcT_mant - DstT_mant + exponent_diff - 1)); + // The value is smaller than min f8 denormal and results in zero (the early exit also prevents + // an undefined behavior of bit shifts >= type width). + if(exponent_diff > DstT_mant) + { + return is_fnuz ? 0 : (sign << (DstT_exp + DstT_mant)); + } + bool midpoint = (mantissa & ((1u << (SrcT_mant - DstT_mant + exponent_diff)) - 1)) == + (1u << (SrcT_mant - DstT_mant + exponent_diff - 1)); /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we shift right as shift right could rip off some residual part and make something not midpoint look like midpoint. For example, the fp16 number @@ -422,31 +418,31 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0) mantissa >>= exponent_diff; else if(exponent_diff == -1) mantissa <<= -exponent_diff; - bool implicit_one = mantissa & (1ull << SrcT_mant); + bool implicit_one = mantissa & (1u << SrcT_mant); // if there is no implicit 1, it means the f8 is denormal and need to adjust // to denorm exponent f8_exponent = - (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 
0 : 1); + (act_exponent + exponent_diff) /*actual f8 exponent*/ + DstT_bias - (implicit_one ? 0 : 1); // Now we have the exponent and mantissa adjusted - unsigned long long drop_mask = (1ull << (SrcT_mant - DstT_mant)) - 1; + unsigned int drop_mask = (1u << (SrcT_mant - DstT_mant)) - 1; bool odd = - mantissa & (1ull << (SrcT_mant - - DstT_mant)); // if the least significant bit that is not truncated is 1 + mantissa & + (1u << (SrcT_mant - DstT_mant)); // if the least significant bit that is not truncated is 1 mantissa += - (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask; + (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1u) : mantissa)) & drop_mask; // Now we deal with overflow if(f8_exponent == 0) { - if((1ull << SrcT_mant) & mantissa) + if((1u << SrcT_mant) & mantissa) { f8_exponent = 1; // denormal overflow to become normal, promote exponent } } else { - if((1ull << (SrcT_mant + 1)) & mantissa) + if((1u << (SrcT_mant + 1)) & mantissa) { mantissa >>= 1; f8_exponent++; @@ -471,9 +467,9 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0) } if(f8_exponent == 0 && mantissa == 0) - return is_fnuz ? 0 : (sign << 7); + return is_fnuz ? 
0 : (sign << (DstT_exp + DstT_mant)); mantissa &= (1 << DstT_mant) - 1; - return (sign << 7) | (f8_exponent << DstT_mant) | mantissa; + return (sign << (DstT_exp + DstT_mant)) | (f8_exponent << DstT_mant) | mantissa; } template @@ -481,8 +477,9 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x) { static_assert(std::is_same::value || std::is_same::value, "SrcT type must be fp8 or bf8."); - constexpr int SrcT_exp = numeric_traits::exp; - constexpr int SrcT_mant = numeric_traits::mant; + constexpr int SrcT_exp = numeric_traits::exp; + constexpr int SrcT_mant = numeric_traits::mant; + constexpr uint8_t SrcT_abs_mask = numeric_traits::abs_mask; constexpr bool is_fnuz = (numeric_traits::f8_interpret == fp8_interpretation::E4M3_FNUZ) || (numeric_traits::f8_interpret == fp8_interpretation::E5M2_FNUZ); @@ -518,9 +515,9 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x) return 0; } - unsigned long long sign = x >> 7; - unsigned long long mantissa = x & ((1 << SrcT_mant) - 1); - int exponent = (x & 0x7F) >> SrcT_mant; + unsigned int sign = x >> (SrcT_exp + SrcT_mant); + unsigned int mantissa = x & ((1 << SrcT_mant) - 1); + int exponent = (x & SrcT_abs_mask) >> SrcT_mant; if constexpr(is_fnuz) { if((x & 0xff) == 0x80) @@ -559,7 +556,7 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x) if constexpr(SrcT_exp == 5 && is_half && !is_fnuz) { - retval = x << 8; + retval = static_cast::bitwise_type>(x) << 8; return bit_cast(retval); } diff --git a/test/ck_tile/data_type/CMakeLists.txt b/test/ck_tile/data_type/CMakeLists.txt index 655a0cef9c..a9ce48d1de 100644 --- a/test/ck_tile/data_type/CMakeLists.txt +++ b/test/ck_tile/data_type/CMakeLists.txt @@ -1,5 +1,15 @@ -# Currently ck_tile is only built on gfx9 if(GPU_TARGETS MATCHES "gfx9") add_gtest_executable(test_ck_tile_pk_int4 test_pk_int4.cpp) - add_gtest_executable(test_ck_tile_pk_fp4 test_pk_fp4.cpp) +endif() +if(GPU_TARGETS MATCHES "gfx95") + add_gtest_executable(test_ck_tile_pk_fp4 test_pk_fp4.cpp) +endif() + 
+if(CK_USE_OCP_FP8 OR CK_USE_FNUZ_FP8) + add_gtest_executable(test_ck_tile_fp8 test_fp8.cpp) + target_compile_options(test_ck_tile_fp8 PRIVATE -Wno-float-equal) + # conditionally specify the use of OCP_FP8 + if(CK_USE_OCP_FP8) + target_compile_options(test_ck_tile_fp8 PRIVATE -DCK_TILE_USE_OCP_FP8) + endif() endif() diff --git a/test/ck_tile/data_type/test_fp8.cpp b/test/ck_tile/data_type/test_fp8.cpp new file mode 100644 index 0000000000..49fd68591f --- /dev/null +++ b/test/ck_tile/data_type/test_fp8.cpp @@ -0,0 +1,606 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" + +#include "ck_tile/core.hpp" + +template +class ConvertTest : public ::testing::Test +{ +}; + +using TestTypes = ::testing::Types; + +TYPED_TEST_SUITE(ConvertTest, TestTypes); + +TYPED_TEST(ConvertTest, ToFp8) +{ + using SrcT = TypeParam; + using DstT = ck_tile::fp8_t; + + auto c = [](SrcT f) { + return static_cast( + ck_tile::bit_cast(ck_tile::impl::run_cast_to_f8(f))); + }; + + auto c_nosat = [](SrcT f) { + return static_cast( + ck_tile::bit_cast(ck_tile::impl::run_cast_to_f8(f))); + }; + +#if CK_TILE_USE_OCP_FP8 + EXPECT_EQ(c(+1.0f), 0b0'0111'000); + EXPECT_EQ(c(-1.0f), 0b1'0111'000); + // max f8 normal + EXPECT_EQ(c(+448.0f), 0b0'1111'110); + EXPECT_EQ(c(-448.0f), 0b1'1111'110); + // min f8 normal + EXPECT_EQ(c(+0.015625f), 0b0'0001'000); + EXPECT_EQ(c(-0.015625f), 0b1'0001'000); + // max f8 subnormal + EXPECT_EQ(c(+0.013671875f), 0b0'0000'111); + EXPECT_EQ(c(-0.013671875f), 0b1'0000'111); + // min f8 subnormal + EXPECT_EQ(c(+0.001953125f), 0b0'0000'001); + EXPECT_EQ(c(-0.001953125f), 0b1'0000'001); + // arbitrary values (exact) + EXPECT_EQ(c(+0.203125f), 0b0'0100'101); + EXPECT_EQ(c(-88.0f), 0b1'1101'011); + // arbitrary values (rounded) + EXPECT_EQ(c(+432.919f), 0b0'1111'110); + EXPECT_EQ(c(-431.111f), 0b1'1111'101); + EXPECT_EQ(c(-0.76123f), 0b1'0110'100); + EXPECT_EQ(c(+0.81234f), 0b0'0110'101); + // 
midpoint values (rounded to nearest even) + EXPECT_EQ(c(+58.0f), 0b0'1100'110); + EXPECT_EQ(c(+62.0f), 0b0'1101'000); + + // saturating mode -> max f8 normal + // max f32/f16 normal -> max f8 normal + EXPECT_EQ(c(+ck_tile::numeric::max()), 0b0'1111'110); + EXPECT_EQ(c(-ck_tile::numeric::max()), 0b1'1111'110); + // f32/f16 infinity -> max f8 normal + EXPECT_EQ(c(+ck_tile::numeric::infinity()), 0b0'1111'110); + EXPECT_EQ(c(-ck_tile::numeric::infinity()), 0b1'1111'110); + // large f32/f16 -> max f8 normal + EXPECT_EQ(c(+1.23e9f), 0b0'1111'110); + EXPECT_EQ(c(-1.23e9f), 0b1'1111'110); + + constexpr unsigned int nan_mask = 0b0'1111'111; + + // non-saturating mode -> f8 NaN (because OCP e4m3 has no infinity) + // max f32/f16 normal -> f8 NaN + EXPECT_EQ(c_nosat(+ck_tile::numeric::max()) & nan_mask, nan_mask); + EXPECT_EQ(c_nosat(-ck_tile::numeric::max()) & nan_mask, nan_mask); + // f32/f16 infinity -> f8 NaN + EXPECT_EQ(c_nosat(+ck_tile::numeric::infinity()) & nan_mask, nan_mask); + EXPECT_EQ(c_nosat(-ck_tile::numeric::infinity()) & nan_mask, nan_mask); + // large f32/f16 -> f8 NaN + EXPECT_EQ(c_nosat(+1.23e9f) & nan_mask, nan_mask); + EXPECT_EQ(c_nosat(-1.23e9f) & nan_mask, nan_mask); + + // f32/f16 NaN -> f8 NaN + EXPECT_EQ(c(ck_tile::numeric::quiet_NaN()) & nan_mask, nan_mask); + EXPECT_EQ(c(ck_tile::numeric::signaling_NaN()) & nan_mask, nan_mask); + + // f32/f16 zero -> f8 zero + EXPECT_EQ(c(+0.0f), 0b0'0000'000); + EXPECT_EQ(c(-0.0f), 0b1'0000'000); + // min f32/f16 normal -> f8 zero + EXPECT_EQ(c(+ck_tile::numeric::min()), 0b0'0000'000); + EXPECT_EQ(c(-ck_tile::numeric::min()), 0b1'0000'000); + // min f32/f16 subnormal -> f8 zero + EXPECT_EQ(c(+ck_tile::numeric::denorm_min()), 0b0'0000'000); + EXPECT_EQ(c(-ck_tile::numeric::denorm_min()), 0b1'0000'000); + + // All values smaller than min f8 subnormal must be converted to f8 zero + constexpr int src_min_subnorm_exp = + -(ck_tile::numeric_traits::bias + ck_tile::numeric_traits::mant - 1); + constexpr int 
dst_min_subnorm_exp = + -(ck_tile::numeric_traits::bias + ck_tile::numeric_traits::mant - 1); + for(int exp = src_min_subnorm_exp; exp <= 0; ++exp) + { + const float f = std::ldexp(1.0, exp); + if(exp < dst_min_subnorm_exp) + { + EXPECT_EQ(c(+f), 0b0'0000'000) << "+f = 2^" << exp << " = " << +f; + EXPECT_EQ(c(-f), 0b1'0000'000) << "-f = 2^" << exp << " = " << -f; + } + else + { + EXPECT_GT(c(+f), 0b0'0000'000) << "+f = 2^" << exp << " = " << +f; + EXPECT_GT(c(-f), 0b1'0000'000) << "-f = 2^" << exp << " = " << -f; + } + } +#else // FNUZ + EXPECT_EQ(c(+1.0f), 0b0'1000'000); + EXPECT_EQ(c(-1.0f), 0b1'1000'000); + // max f8 normal + EXPECT_EQ(c(+240.0f), 0b0'1111'111); + EXPECT_EQ(c(-240.0f), 0b1'1111'111); + // min f8 normal + EXPECT_EQ(c(+0.0078125f), 0b0'0001'000); + EXPECT_EQ(c(-0.0078125f), 0b1'0001'000); + // max f8 subnormal + EXPECT_EQ(c(+0.0068359375f), 0b0'0000'111); + EXPECT_EQ(c(-0.0068359375f), 0b1'0000'111); + // min f8 subnormal + EXPECT_EQ(c(+0.0009765625f), 0b0'0000'001); + EXPECT_EQ(c(-0.0009765625f), 0b1'0000'001); + // arbitrary values (exact) + EXPECT_EQ(c(+0.1015625f), 0b0'0100'101); + EXPECT_EQ(c(-44.0f), 0b1'1101'011); + // arbitrary values (rounded) + EXPECT_EQ(c(+219.91f), 0b0'1111'110); + EXPECT_EQ(c(-203.11f), 0b1'1111'101); + EXPECT_EQ(c(-0.3639f), 0b1'0110'100); + EXPECT_EQ(c(+0.4139f), 0b0'0110'101); + + // saturating mode -> max f8 normal + // max f32/f16 normal -> max f8 normal + EXPECT_EQ(c(+ck_tile::numeric::max()), 0b0'1111'111); + EXPECT_EQ(c(-ck_tile::numeric::max()), 0b1'1111'111); + // f32/f16 infinity -> max f8 normal + EXPECT_EQ(c(+ck_tile::numeric::infinity()), 0b0'1111'111); + EXPECT_EQ(c(-ck_tile::numeric::infinity()), 0b1'1111'111); + // large f32/f16 -> max f8 normal + EXPECT_EQ(c(+1.23e9f), 0b0'1111'111); + EXPECT_EQ(c(-1.23e9f), 0b1'1111'111); + + constexpr unsigned int nan_value = 0b1'0000'000; + + // non-saturating mode -> f8 NaN (FN means "finite", so no infinity) + // max f32/f16 normal -> f8 NaN + 
EXPECT_EQ(c_nosat(+ck_tile::numeric::max()), nan_value); + EXPECT_EQ(c_nosat(-ck_tile::numeric::max()), nan_value); + // f32/f16 infinity -> f8 NaN + EXPECT_EQ(c_nosat(+ck_tile::numeric::infinity()), nan_value); + EXPECT_EQ(c_nosat(-ck_tile::numeric::infinity()), nan_value); + // large f32/f16 -> f8 NaN + EXPECT_EQ(c_nosat(+1.23e9f), nan_value); + EXPECT_EQ(c_nosat(-1.23e9f), nan_value); + + // f32/f16 NaN -> f8 NaN + EXPECT_EQ(c(ck_tile::numeric::quiet_NaN()), nan_value); + EXPECT_EQ(c(ck_tile::numeric::signaling_NaN()), nan_value); + + // UZ means "unsigned zero" (0b1'0000'000 is NaN) + // f32/f16 +-zero -> f8 +zero + EXPECT_EQ(c(+0.0f), 0b0'0000'000); + EXPECT_EQ(c(-0.0f), 0b0'0000'000); + // min f32/f16 normal -> f8 +zero + EXPECT_EQ(c(+ck_tile::numeric::min()), 0b0'0000'000); + EXPECT_EQ(c(-ck_tile::numeric::min()), 0b0'0000'000); + // min f32/f16 subnormal -> f8 +zero + EXPECT_EQ(c(+ck_tile::numeric::denorm_min()), 0b0'0000'000); + EXPECT_EQ(c(-ck_tile::numeric::denorm_min()), 0b0'0000'000); + + // All values smaller than min f8 subnormal must be converted to f8 zero + constexpr int src_min_subnorm_exp = + -(ck_tile::numeric_traits::bias + ck_tile::numeric_traits::mant - 1); + constexpr int dst_min_subnorm_exp = + -(ck_tile::numeric_traits::bias + ck_tile::numeric_traits::mant - 1); + for(int exp = src_min_subnorm_exp; exp <= 0; ++exp) + { + const float f = std::ldexp(1.0, exp); + if(exp < dst_min_subnorm_exp) + { + EXPECT_EQ(c(+f), 0b0'0000'000) << "+f = 2^" << exp << " = " << +f; + EXPECT_EQ(c(-f), 0b0'0000'000) << "-f = 2^" << exp << " = " << -f; + } + else + { + EXPECT_GT(c(+f), 0b0'0000'000) << "+f = 2^" << exp << " = " << +f; + EXPECT_GT(c(-f), 0b0'0000'000) << "-f = 2^" << exp << " = " << -f; + } + } +#endif +} + +TYPED_TEST(ConvertTest, ToBf8) +{ + using SrcT = TypeParam; + using DstT = ck_tile::bf8_t; + + auto c = [](SrcT f) { + return static_cast( + ck_tile::bit_cast(ck_tile::impl::run_cast_to_f8(f))); + }; + + auto c_nosat = [](SrcT f) { + return 
static_cast( + ck_tile::bit_cast(ck_tile::impl::run_cast_to_f8(f))); + }; + +#if CK_TILE_USE_OCP_FP8 + EXPECT_EQ(c(+1.0f), 0b0'01111'00); + EXPECT_EQ(c(-1.0f), 0b1'01111'00); + // max f8 normal + EXPECT_EQ(c(+57344.0f), 0b0'11110'11); + EXPECT_EQ(c(-57344.0f), 0b1'11110'11); + // min f8 normal + EXPECT_EQ(c(+6.103515625e-05f), 0b0'00001'00); + EXPECT_EQ(c(-6.103515625e-05f), 0b1'00001'00); + // max f8 subnormal + EXPECT_EQ(c(+4.57763671875e-05f), 0b0'00000'11); + EXPECT_EQ(c(-4.57763671875e-05f), 0b1'00000'11); + // min f8 subnormal + EXPECT_EQ(c(+1.52587890625e-05f), 0b0'00000'01); + EXPECT_EQ(c(-1.52587890625e-05f), 0b1'00000'01); + // arbitrary values (exact) + EXPECT_EQ(c(+0.01953125f), 0b0'01001'01); + EXPECT_EQ(c(-3584.0f), 0b1'11010'11); + // arbitrary values (rounded) + EXPECT_EQ(c(+2030.56f), 0b0'11010'00); + EXPECT_EQ(c(-1801.33f), 0b1'11001'11); + EXPECT_EQ(c(-0.27891f), 0b1'0110'100); + EXPECT_EQ(c(+0.33333f), 0b0'0110'101); + + // saturating mode -> max f8 normal + // max f32/f16 normal -> max f8 normal + EXPECT_EQ(c(+ck_tile::numeric::max()), 0b0'11110'11); + EXPECT_EQ(c(-ck_tile::numeric::max()), 0b1'11110'11); + // f32/f16 infinity -> max f8 normal + EXPECT_EQ(c(+ck_tile::numeric::infinity()), 0b0'11110'11); + EXPECT_EQ(c(-ck_tile::numeric::infinity()), 0b1'11110'11); + // large f32/f16 -> max f8 normal + EXPECT_EQ(c(+1.23e9f), 0b0'11110'11); + EXPECT_EQ(c(-1.23e9f), 0b1'11110'11); + + // non-saturating mode -> f8 infinity + // max f32/f16 normal -> f8 infinity + EXPECT_EQ(c_nosat(+ck_tile::numeric::max()), 0b0'11111'00); + EXPECT_EQ(c_nosat(-ck_tile::numeric::max()), 0b1'11111'00); + // f32/f16 infinity -> f8 infinity + EXPECT_EQ(c_nosat(+ck_tile::numeric::infinity()), 0b0'11111'00); + EXPECT_EQ(c_nosat(-ck_tile::numeric::infinity()), 0b1'11111'00); + // large f32/f16 -> f8 infinity + EXPECT_EQ(c_nosat(+1.23e9f), 0b0'11111'00); + EXPECT_EQ(c_nosat(-1.23e9f), 0b1'11111'00); + + // f32/f16 NaN -> f8 NaN + EXPECT_TRUE((c(ck_tile::numeric::quiet_NaN()) 
& 0b0'11111'11) != 0b0'11111'00); + EXPECT_TRUE((c(ck_tile::numeric::signaling_NaN()) & 0b0'11111'11) != 0b0'11111'00); + + // f32/f16 zero -> f8 zero + EXPECT_EQ(c(+0.0f), 0b0'00000'00); + EXPECT_EQ(c(-0.0f), 0b1'00000'00); + if constexpr(std::is_same_v) + { + // min f32 normal -> f8 zero + EXPECT_EQ(c(+ck_tile::numeric::min()), 0b0'00000'00); + EXPECT_EQ(c(-ck_tile::numeric::min()), 0b1'00000'00); + } + else + { + // min f16 normal -> min f8 normal (they are equal) + EXPECT_EQ(c(+ck_tile::numeric::min()), 0b0'00001'00); + EXPECT_EQ(c(-ck_tile::numeric::min()), 0b1'00001'00); + } + // min f32/f16 subnormal -> f8 zero + EXPECT_EQ(c(+ck_tile::numeric::denorm_min()), 0b0'00000'00); + EXPECT_EQ(c(-ck_tile::numeric::denorm_min()), 0b1'00000'00); + + // All values smaller than min f8 subnormal must be converted to f8 zero + constexpr int src_min_subnorm_exp = + -(ck_tile::numeric_traits::bias + ck_tile::numeric_traits::mant - 1); + constexpr int dst_min_subnorm_exp = + -(ck_tile::numeric_traits::bias + ck_tile::numeric_traits::mant - 1); + for(int exp = src_min_subnorm_exp; exp <= 0; ++exp) + { + const float f = std::ldexp(1.0, exp); + if(exp < dst_min_subnorm_exp) + { + EXPECT_EQ(c(+f), 0b0'00000'00) << "+f = 2^" << exp << " = " << +f; + EXPECT_EQ(c(-f), 0b1'00000'00) << "-f = 2^" << exp << " = " << -f; + } + else + { + EXPECT_GT(c(+f), 0b0'00000'00) << "+f = 2^" << exp << " = " << +f; + EXPECT_GT(c(-f), 0b1'00000'00) << "-f = 2^" << exp << " = " << -f; + } + } +#else // FNUZ + EXPECT_EQ(c(+1.0f), 0b0'10000'00); + EXPECT_EQ(c(-1.0f), 0b1'10000'00); + // max f8 normal + EXPECT_EQ(c(+57344.0f), 0b0'11111'11); + EXPECT_EQ(c(-57344.0f), 0b1'11111'11); + // min f8 normal + EXPECT_EQ(c(+3.0517578125e-05f), 0b0'00001'00); + EXPECT_EQ(c(-3.0517578125e-05f), 0b1'00001'00); + // max f8 subnormal + EXPECT_EQ(c(+2.288818359375e-05f), 0b0'00000'11); + EXPECT_EQ(c(-2.288818359375e-05f), 0b1'00000'11); + // min f8 subnormal + EXPECT_EQ(c(+7.62939453125e-06f), 0b0'00000'01); + 
EXPECT_EQ(c(-7.62939453125e-06f), 0b1'00000'01); + // arbitrary values (exact) + EXPECT_EQ(c(+0.009765625f), 0b0'01001'01); + EXPECT_EQ(c(-1792.0f), 0b1'11010'11); + // arbitrary values (rounded) + EXPECT_EQ(c(+840.100f), 0b0'11001'11); + EXPECT_EQ(c(-999.999f), 0b1'11010'00); + EXPECT_EQ(c(-0.12789f), 0b1'0110'100); + EXPECT_EQ(c(+0.14444f), 0b0'0110'101); + + // saturating mode -> max f8 normal + // max f32/f16 normal -> max f8 normal + EXPECT_EQ(c(+ck_tile::numeric::max()), 0b0'11111'11); + EXPECT_EQ(c(-ck_tile::numeric::max()), 0b1'1111'111); + // f32/f16 infinity -> max f8 normal + EXPECT_EQ(c(+ck_tile::numeric::infinity()), 0b0'11111'11); + EXPECT_EQ(c(-ck_tile::numeric::infinity()), 0b1'1111'111); + // large f32/f16 -> max f8 normal + EXPECT_EQ(c(+1.23e9f), 0b0'11111'11); + EXPECT_EQ(c(-1.23e9f), 0b1'1111'111); + + constexpr unsigned int nan_value = 0b1'00000'00; + + // non-saturating mode -> f8 NaN (FN means "finite", so no infinity) + // max f32/f16 normal -> f8 NaN + EXPECT_EQ(c_nosat(+ck_tile::numeric::max()), nan_value); + EXPECT_EQ(c_nosat(-ck_tile::numeric::max()), nan_value); + // f32/f16 infinity -> f8 NaN + EXPECT_EQ(c_nosat(+ck_tile::numeric::infinity()), nan_value); + EXPECT_EQ(c_nosat(-ck_tile::numeric::infinity()), nan_value); + // large f32/f16 -> f8 NaN + EXPECT_EQ(c_nosat(+1.23e9f), nan_value); + EXPECT_EQ(c_nosat(-1.23e9f), nan_value); + + // f32/f16 NaN -> f8 NaN + EXPECT_EQ(c(ck_tile::numeric::quiet_NaN()), nan_value); + EXPECT_EQ(c(ck_tile::numeric::signaling_NaN()), nan_value); + + // UZ means "unsigned zero" (0b1'00000'00 is NaN) + // f32/f16 +-zero -> f8 +zero + EXPECT_EQ(c(+0.0f), 0b0'00000'00); + EXPECT_EQ(c(-0.0f), 0b0'00000'00); + if constexpr(std::is_same_v) + { + // min f32 normal -> f8 +zero + EXPECT_EQ(c(+ck_tile::numeric::min()), 0b0'00000'00); + EXPECT_EQ(c(-ck_tile::numeric::min()), 0b0'00000'00); + } + else + { + // min f16 normal -> f8 normal + EXPECT_EQ(c(+ck_tile::numeric::min()), 0b0'00010'00); + 
EXPECT_EQ(c(-ck_tile::numeric::min()), 0b1'00010'00); + } + // min f32/f16 subnormal -> f8 +zero + EXPECT_EQ(c(+ck_tile::numeric::denorm_min()), 0b0'00000'00); + EXPECT_EQ(c(-ck_tile::numeric::denorm_min()), 0b0'00000'00); + + // All values smaller than min f8 subnormal must be converted to f8 zero + constexpr int src_min_subnorm_exp = + -(ck_tile::numeric_traits::bias + ck_tile::numeric_traits::mant - 1); + constexpr int dst_min_subnorm_exp = + -(ck_tile::numeric_traits::bias + ck_tile::numeric_traits::mant - 1); + for(int exp = src_min_subnorm_exp; exp <= 0; ++exp) + { + const float f = std::ldexp(1.0, exp); + if(exp < dst_min_subnorm_exp) + { + EXPECT_EQ(c(+f), 0b0'00000'00) << "+f = 2^" << exp << " = " << +f; + EXPECT_EQ(c(-f), 0b0'00000'00) << "-f = 2^" << exp << " = " << -f; + } + else + { + EXPECT_GT(c(+f), 0b0'00000'00) << "+f = 2^" << exp << " = " << +f; + EXPECT_GT(c(-f), 0b0'00000'00) << "-f = 2^" << exp << " = " << -f; + } + } +#endif +} + +TYPED_TEST(ConvertTest, FromFp8) +{ + using SrcT = ck_tile::fp8_t; + using DstT = TypeParam; + + auto c = [](uint8_t u) { + return ck_tile::type_convert( + ck_tile::impl::run_cast_from_f8(ck_tile::bit_cast(u))); + }; + +#if CK_TILE_USE_OCP_FP8 + EXPECT_EQ(c(0b0'0111'000), +1.0f); + EXPECT_EQ(c(0b1'0111'000), -1.0f); + // max f8 normal + EXPECT_EQ(c(0b0'1111'110), +448.0f); + EXPECT_EQ(c(0b1'1111'110), -448.0f); + // min f8 normal + EXPECT_EQ(c(0b0'0001'000), +0.015625f); + EXPECT_EQ(c(0b1'0001'000), -0.015625f); + // max f8 subnormal + EXPECT_EQ(c(0b0'0000'111), +0.013671875f); + EXPECT_EQ(c(0b1'0000'111), -0.013671875f); + // min f8 subnormal + EXPECT_EQ(c(0b0'0000'001), +0.001953125f); + EXPECT_EQ(c(0b1'0000'001), -0.001953125f); + // arbitrary values + EXPECT_EQ(c(0b0'0100'101), +0.203125f); + EXPECT_EQ(c(0b1'1101'011), -88.0f); + + // f8 NaN -> f32/f16 NaN + EXPECT_TRUE(ck_tile::isnan(c(0b0'1111'111))); + EXPECT_TRUE(ck_tile::isnan(c(0b1'1111'111))); + + // f8 zero -> f32/f16 zero (sign is preserved) + 
EXPECT_EQ(c(0b0'0000'000), + ck_tile::bit_cast(typename ck_tile::numeric_traits::bitwise_type{0})); + EXPECT_EQ(c(0b1'0000'000), ck_tile::bit_cast(ck_tile::numeric_traits::Neg0)); +#else // FNUZ + EXPECT_EQ(c(0b0'1000'000), +1.0f); + EXPECT_EQ(c(0b1'1000'000), -1.0f); + // max f8 normal + EXPECT_EQ(c(0b0'1111'111), +240.0f); + EXPECT_EQ(c(0b1'1111'111), -240.0f); + // min f8 normal + EXPECT_EQ(c(0b0'0001'000), +0.0078125f); + EXPECT_EQ(c(0b1'0001'000), -0.0078125f); + // max f8 subnormal + EXPECT_EQ(c(0b0'0000'111), +0.0068359375f); + EXPECT_EQ(c(0b1'0000'111), -0.0068359375f); + // min f8 subnormal + EXPECT_EQ(c(0b0'0000'001), +0.0009765625f); + EXPECT_EQ(c(0b1'0000'001), -0.0009765625f); + // arbitrary values + EXPECT_EQ(c(0b0'0100'101), +0.1015625f); + EXPECT_EQ(c(0b1'1101'011), -44.0f); + + // f8 NaN -> f32/f16 NaN + EXPECT_TRUE(ck_tile::isnan(c(0b1'0000'000))); + + // UZ means "unsigned zero" (0b1'0000'000 is NaN) + // f8 +zero -> f32/f16 +zero + EXPECT_EQ(c(0b0'0000'000), + ck_tile::bit_cast(typename ck_tile::numeric_traits::bitwise_type{0})); +#endif +} + +TYPED_TEST(ConvertTest, FromBf8) +{ + using SrcT = ck_tile::bf8_t; + using DstT = TypeParam; + + using DstT = TypeParam; + + auto c = [](uint8_t u) { + return ck_tile::type_convert( + ck_tile::impl::run_cast_from_f8(ck_tile::bit_cast(u))); + }; + +#if CK_TILE_USE_OCP_FP8 + auto c_nosat = [](uint8_t u) { + return ck_tile::type_convert( + ck_tile::impl::run_cast_from_f8(ck_tile::bit_cast(u))); + }; + + EXPECT_EQ(c(0b0'01111'00), +1.0f); + EXPECT_EQ(c(0b1'01111'00), -1.0f); + // max f8 normal + EXPECT_EQ(c(0b0'11110'11), +57344.0f); + EXPECT_EQ(c(0b1'11110'11), -57344.0f); + // min f8 normal + EXPECT_EQ(c(0b0'00001'00), +6.103515625e-05f); + EXPECT_EQ(c(0b1'00001'00), -6.103515625e-05f); + // max f8 subnormal + EXPECT_EQ(c(0b0'00000'11), +4.57763671875e-05f); + EXPECT_EQ(c(0b1'00000'11), -4.57763671875e-05f); + // min f8 subnormal + EXPECT_EQ(c(0b0'00000'01), +1.52587890625e-05f); + EXPECT_EQ(c(0b1'00000'01), 
-1.52587890625e-05f); + // arbitrary values + EXPECT_EQ(c(0b0'01001'01), +0.01953125f); + EXPECT_EQ(c(0b1'11010'11), -3584.0f); + + // saturating mode + // f8 infinity -> max f8 normal as f32/f16 + EXPECT_EQ(c(0b0'11111'00), +57344.0f); + EXPECT_EQ(c(0b1'11111'00), -57344.0f); + + // non-saturating mode + // f8 infinity -> f32/f16 infinity + EXPECT_EQ(c_nosat(0b0'11111'00), +ck_tile::numeric::infinity()); + EXPECT_EQ(c_nosat(0b1'11111'00), -ck_tile::numeric::infinity()); + + // f8 NaN -> f32/f16 NaN + EXPECT_TRUE(ck_tile::isnan(c(0b0'11111'01))); + EXPECT_TRUE(ck_tile::isnan(c(0b0'11111'10))); + EXPECT_TRUE(ck_tile::isnan(c(0b0'11111'11))); + EXPECT_TRUE(ck_tile::isnan(c(0b1'11111'01))); + EXPECT_TRUE(ck_tile::isnan(c(0b1'11111'10))); + EXPECT_TRUE(ck_tile::isnan(c(0b1'11111'11))); + + // f8 zero -> f32/f16 zero (sign is preserved) + EXPECT_EQ(c(0b0'00000'00), + ck_tile::bit_cast(typename ck_tile::numeric_traits::bitwise_type{0})); + EXPECT_EQ(c(0b1'00000'00), ck_tile::bit_cast(ck_tile::numeric_traits::Neg0)); + if constexpr(std::is_same_v) + { + // min f8 normal -> min f16 normal (they are equal) + EXPECT_EQ(c(0b0'00001'00), +ck_tile::numeric::min()); + EXPECT_EQ(c(0b1'00001'00), -ck_tile::numeric::min()); + } +#else // FNUZ + EXPECT_EQ(c(0b0'10000'00), +1.0f); + EXPECT_EQ(c(0b1'10000'00), -1.0f); + // max f8 normal + EXPECT_EQ(c(0b0'11111'11), +57344.0f); + EXPECT_EQ(c(0b1'11111'11), -57344.0f); + // min f8 normal + EXPECT_EQ(c(0b0'00001'00), +3.0517578125e-05f); + EXPECT_EQ(c(0b1'00001'00), -3.0517578125e-05f); + // max f8 subnormal + EXPECT_EQ(c(0b0'00000'11), +2.288818359375e-05f); + EXPECT_EQ(c(0b1'00000'11), -2.288818359375e-05f); + // min f8 subnormal + EXPECT_EQ(c(0b0'00000'01), +7.62939453125e-06f); + EXPECT_EQ(c(0b1'00000'01), -7.62939453125e-06f); + // arbitrary values + EXPECT_EQ(c(0b0'01001'01), +0.009765625f); + EXPECT_EQ(c(0b1'11010'11), -1792.0f); + + // f8 NaN -> f32/f16 NaN + EXPECT_TRUE(ck_tile::isnan(c(0b1'00000'00))); + + // UZ means "unsigned 
zero" (0b1'00000'00 is NaN) + // f8 +zero -> f32/f16 +zero + EXPECT_EQ(c(0b0'00000'00), + ck_tile::bit_cast(typename ck_tile::numeric_traits::bitwise_type{0})); + if constexpr(std::is_same_v) + { + // one of f8 normals -> min f16 normal + EXPECT_EQ(c(0b0'00010'00), +ck_tile::numeric::min()); + EXPECT_EQ(c(0b1'00010'00), -ck_tile::numeric::min()); + } +#endif +} + +// Convert f8 -> f32/f16 -> f8 to check if all values are covered +// OCP types multiple NaN representations (e4m3 - 2, e5m2 - 6), they are ignored for simplicity. + +TYPED_TEST(ConvertTest, FromFp8AndToFp8) +{ + using SrcT = ck_tile::fp8_t; + using DstT = TypeParam; + + for(int i = 0; i < 256; ++i) + { +#if CK_TILE_USE_OCP_FP8 + if((i & 0b0'1111'111) == 0b0'1111'111) + { + continue; + } +#endif + const uint8_t u = static_cast(i); + const SrcT from = ck_tile::bit_cast(u); + const DstT f = ck_tile::impl::run_cast_from_f8(from); + const SrcT to = ck_tile::impl::run_cast_to_f8(f); + EXPECT_EQ(from, to) << "u8: " << i << " f32/f16: " << ck_tile::type_convert(f); + } +} + +TYPED_TEST(ConvertTest, FromBf8AndToBf8) +{ + using SrcT = ck_tile::bf8_t; + using DstT = TypeParam; + + for(int i = 0; i < 256; ++i) + { +#if CK_TILE_USE_OCP_FP8 + if((i & 0b0'11111'11) > 0b0'11111'00) + { + continue; + } +#endif + const uint8_t u = static_cast(i); + const SrcT from = ck_tile::bit_cast(u); + const DstT f = ck_tile::impl::run_cast_from_f8(from); + const SrcT to = ck_tile::impl::run_cast_to_f8(f); + EXPECT_EQ(from, to) << "u8: " << i << " f32/f16: " << ck_tile::type_convert(f); + } +}