diff --git a/include/ck/utility/amd_ck_fp8.hpp b/include/ck/utility/amd_ck_fp8.hpp index c5525d5ff8..91745376ac 100644 --- a/include/ck/utility/amd_ck_fp8.hpp +++ b/include/ck/utility/amd_ck_fp8.hpp @@ -32,34 +32,8 @@ namespace ck { -struct f8_fnuz_t -{ - using data_type = unsigned char; - data_type m_data = data_type{}; - __host__ __device__ explicit constexpr f8_fnuz_t(data_type in_data) : m_data(in_data) {} - __host__ __device__ explicit constexpr f8_fnuz_t() = default; - __host__ __device__ bool constexpr operator==(f8_fnuz_t other) const - { - return m_data == other.m_data; - } - __host__ __device__ explicit constexpr operator data_type() const { return m_data; } -}; - -struct bf8_fnuz_t -{ - using data_type = unsigned char; - data_type m_data = data_type{}; - __host__ __device__ explicit constexpr bf8_fnuz_t(data_type in_data) : m_data(in_data) {} - __host__ __device__ explicit constexpr bf8_fnuz_t() = default; - __host__ __device__ bool constexpr operator==(bf8_fnuz_t other) const - { - return m_data == other.m_data; - } - __host__ __device__ explicit constexpr operator data_type() const { return m_data; } -}; - -static_assert(1 == sizeof(f8_fnuz_t)); -static_assert(1 == sizeof(bf8_fnuz_t)); +using f8_fnuz_t = _BitInt(8); +using bf8_fnuz_t = unsigned _BitInt(8); typedef unsigned char fp8_storage_t; diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 574269b94a..a962e27b3d 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -205,7 +205,7 @@ inline constexpr bool is_native_type() return is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || - is_same_v || is_same_v || is_same::value; + is_same::value || is_same::value || is_same::value; } // scalar_type @@ -300,14 +300,14 @@ struct scalar_type template <> struct scalar_type { - using type = f8_fnuz_t::data_type; + using type = f8_fnuz_t; static constexpr index_t vector_size = 1; }; template <> struct scalar_type { - using type = bf8_fnuz_t::data_type; + using type = bf8_fnuz_t; static constexpr index_t vector_size = 1; }; diff --git a/include/ck/utility/dtype_vector.hpp b/include/ck/utility/dtype_vector.hpp index 084240f84b..8c5fb3ecf7 100644 --- a/include/ck/utility/dtype_vector.hpp +++ b/include/ck/utility/dtype_vector.hpp @@ -1295,18 +1295,6 @@ struct nnvb_data_t_selector }; #ifndef CK_CODE_GEN_RTC -template <> -struct nnvb_data_t_selector -{ - using type = f8_fnuz_t::data_type; -}; - -template <> -struct nnvb_data_t_selector -{ - using type = bf8_fnuz_t::data_type; -}; - template <> struct nnvb_data_t_selector { diff --git a/include/ck/utility/f8_utils.hpp b/include/ck/utility/f8_utils.hpp index 94c2f84c8c..77e4d44796 100644 --- a/include/ck/utility/f8_utils.hpp +++ b/include/ck/utility/f8_utils.hpp @@ -39,7 +39,7 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng) int exponent, bias; uint32_t head, mantissa, sign; // nan code is same for float and half - constexpr uint8_t nan_code = 0x80; + constexpr Y nan_code = 0x80; constexpr uint32_t nan_mask = NumericUtils::nan_mask; // convert to bitwise @@ -60,17 +60,17 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng) if constexpr(negative_zero_nan) { if((x_bitwise & nan_mask) == nan_mask) - return Y{nan_code}; + return nan_code; } else { if((x_bitwise & nan_mask) == nan_mask) - return Y{static_cast(signed_inf + (mantissa != 0 ? 1 : 0))}; + return signed_inf + (mantissa != 0 ? 1 : 0); } // check if x is 0.0 if(x_bitwise == 0) - return Y{0}; + return 0; // First need to check if it is normal or denorm as there is a difference of implict 1 // Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift @@ -178,10 +178,9 @@ In this case, the fp16 mantissa should be shift left by 1 */ // check if x is 0.0 or -0.0 if(out_exponent == 0 && mantissa == 0) - return Y{negative_zero_nan ? 0 : static_cast(sign << (out_exp + out_mant))}; + return negative_zero_nan ? 0 : (sign << (out_exp + out_mant)); mantissa &= (1 << out_mant) - 1; - return Y{static_cast((sign << (out_exp + out_mant)) | (out_exponent << out_mant) | - mantissa)}; + return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa; } template @@ -196,8 +195,8 @@ __host__ __device__ Y run_cast_from_f8(X x) constexpr int out_mant = NumericUtils::mant; // prepare the codes - constexpr uint8_t nan_code = 0x80; - using T_bitwise = typename NumericUtils::bitwise_type; + constexpr X nan_code = 0x80; + using T_bitwise = typename NumericUtils::bitwise_type; constexpr T_bitwise Inf_bitwise = NumericUtils::Inf; constexpr T_bitwise NegInf_bitwise = NumericUtils::NegInf; @@ -210,13 +209,13 @@ __host__ __device__ Y run_cast_from_f8(X x) constexpr Y Neg0 = bit_cast(Neg0_bitwise); // check if x is 0.0 - if(!static_cast(x)) + if(x == 0) return static_cast(0); // unpack the input - uint32_t sign = static_cast(x) >> (in_exp + in_mant); - uint32_t mantissa = static_cast(x) & ((1 << in_mant) - 1); - int exponent = (static_cast(x) & 0x7F) >> in_mant; + uint32_t sign = x >> (in_exp + in_mant); + uint32_t mantissa = x & ((1 << in_mant) - 1); + int exponent = (x & 0x7F) >> in_mant; constexpr int exp_low_cutoff = (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); @@ -224,12 +223,12 @@ __host__ __device__ Y run_cast_from_f8(X x) if constexpr(negative_zero_nan) { - if(static_cast(x) == nan_code) + if(x == nan_code) return NaN; } else { - if(static_cast(x) == nan_code) + if(x == nan_code) return Neg0; if(exponent == ((1 << in_exp) - 1)) return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN; diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 701b2686c7..aef2948fb3 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -351,7 +351,7 @@ inline __host__ __device__ f8_fnuz_t f8_convert_sr(float x) val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8); ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos val.i32val = ival; - return f8_fnuz_t{val.i8val[0]}; // little endian + return val.i8val[0]; // little endian #else constexpr bool negative_zero_nan = true; constexpr bool clip = true; @@ -419,7 +419,7 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_sr(float x) val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8); ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos val.i32val = ival; - return bf8_fnuz_t{val.i8val[0]}; // little endian + return val.i8val[0]; // little endian #else constexpr bool negative_zero_nan = true; constexpr bool clip = true; @@ -655,7 +655,7 @@ inline __host__ __device__ f8_fnuz_t f8_convert_rne(float x) val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8); ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0 val.i32val = ival; - return f8_fnuz_t{val.i8val[0]}; + return val.i8val[0]; #else constexpr bool negative_zero_nan = true; constexpr bool clip = true; @@ -707,7 +707,7 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_rne(float x) val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8); ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0 val.i32val = ival; - return bf8_fnuz_t{val.i8val[0]}; + return val.i8val[0]; #else constexpr bool negative_zero_nan = true; constexpr bool clip = true; @@ -924,7 +924,7 @@ inline __host__ __device__ float type_convert(f8_fnuz_t x) { #if defined(__gfx94__) float fval; - uint32_t i32val = static_cast(static_cast(x)); + uint32_t i32val = static_cast(x); fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0); // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); return fval; @@ -1430,7 +1430,7 @@ inline __host__ __device__ float type_convert(bf8_fnuz_t x) { #if defined(__gfx94__) float fval; - uint32_t i32val = static_cast(static_cast(x)); + uint32_t i32val = static_cast(x); fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0); // asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); return fval; diff --git a/test/data_type/test_bf8_fnuz.cpp b/test/data_type/test_bf8_fnuz.cpp index f028c0da73..4ff796a614 100644 --- a/test/data_type/test_bf8_fnuz.cpp +++ b/test/data_type/test_bf8_fnuz.cpp @@ -43,8 +43,9 @@ TEST(BF8FNUZ, ConvertFP32Nearest) type_convert(f8_convert_rne(std::numeric_limits::max())), abs_tol); // convert inf float to bf8_fnuz_t and check if it is qNan - ASSERT_EQ(ck::NumericLimits::QuietNaN(), - f8_convert_rne(std::numeric_limits::infinity())); + ASSERT_NEAR(ck::NumericLimits::QuietNaN(), + f8_convert_rne(std::numeric_limits::infinity()), + abs_tol); // positive norm float value to bf8 and back, check if holds float pos_float = 0.0000762939f; ASSERT_NEAR(pos_float, type_convert(f8_convert_rne(pos_float)), abs_tol); @@ -79,8 +80,9 @@ TEST(BF8FNUZ, ConvertFP32Stochastic) type_convert(f8_convert_sr(std::numeric_limits::max())), abs_tol); // convert inf float to bf8_fnuz_t and check if it is qNan - ASSERT_EQ(ck::NumericLimits::QuietNaN(), - f8_convert_sr(std::numeric_limits::infinity())); + ASSERT_NEAR(ck::NumericLimits::QuietNaN(), + f8_convert_sr(std::numeric_limits::infinity()), + abs_tol); // positive norm float value to bf8 and back, check if holds float pos_float = 0.0000762939f; ASSERT_NEAR(pos_float, type_convert(f8_convert_sr(pos_float)), abs_tol); @@ -116,8 +118,9 @@ TEST(BF8FNUZ, ConvertFP16Nearest) type_convert(f8_convert_rne(ck::NumericLimits::Max())), abs_tol); // convert QuietNaN fp16 to bf8_fnuz_t and check if it is QuietNaN - ASSERT_EQ(ck::NumericLimits::QuietNaN(), - f8_convert_rne(ck::NumericLimits::QuietNaN())); + ASSERT_NEAR(ck::NumericLimits::QuietNaN(), + f8_convert_rne(ck::NumericLimits::QuietNaN()), + abs_tol); // positive norm fp16 value to bf8 and back, check if holds half_t pos_half = half_t{0.0000762939}; ASSERT_NEAR(pos_half, type_convert(f8_convert_rne(pos_half)), abs_tol); @@ -152,8 +155,9 @@ TEST(BF8FNUZ, ConvertFP16Stochastic) type_convert(f8_convert_sr(ck::NumericLimits::Max())), abs_tol); // convert QuietNaN fp16 to bf8_fnuz_t and check if it is QuietNaN - ASSERT_EQ(ck::NumericLimits::QuietNaN(), - f8_convert_sr(ck::NumericLimits::QuietNaN())); + ASSERT_NEAR(ck::NumericLimits::QuietNaN(), + f8_convert_sr(ck::NumericLimits::QuietNaN()), + abs_tol); // positive norm fp16 value to bf8 and back, check if holds half_t pos_half = half_t{0.0000762939}; ASSERT_NEAR(pos_half, type_convert(f8_convert_sr(pos_half)), abs_tol); diff --git a/test/data_type/test_fp8_fnuz.cpp b/test/data_type/test_fp8_fnuz.cpp index 0cf775f947..c2ec6dad94 100644 --- a/test/data_type/test_fp8_fnuz.cpp +++ b/test/data_type/test_fp8_fnuz.cpp @@ -48,8 +48,9 @@ TEST(FP8FNUZ, ConvertFP32Nearest) type_convert(f8_convert_rne(std::numeric_limits::max())), abs_tol); // convert inf float to f8_fnuz_t and check if it is qNan - ASSERT_EQ(ck::NumericLimits::QuietNaN(), - f8_convert_rne(std::numeric_limits::infinity())); + ASSERT_NEAR(ck::NumericLimits::QuietNaN(), + f8_convert_rne(std::numeric_limits::infinity()), + abs_tol); // positive norm float value to fp8 and back, check if holds float pos_float = 0.017578125f; ASSERT_NEAR(pos_float, type_convert(f8_convert_rne(pos_float)), abs_tol); @@ -84,8 +85,9 @@ TEST(FP8FNUZ, ConvertFP32Stochastic) type_convert(f8_convert_sr(std::numeric_limits::max())), abs_tol); // convert inf float to f8_fnuz_t and check if it is qNan - ASSERT_EQ(ck::NumericLimits::QuietNaN(), - f8_convert_sr(std::numeric_limits::infinity())); + ASSERT_NEAR(ck::NumericLimits::QuietNaN(), + f8_convert_sr(std::numeric_limits::infinity()), + abs_tol); // positive norm float value to fp8 and back, check if holds float pos_float = 0.017578125f; ASSERT_NEAR(pos_float, type_convert(f8_convert_sr(pos_float)), abs_tol); @@ -120,8 +122,9 @@ TEST(FP8FNUZ, ConvertFP16Nearest) type_convert(f8_convert_rne(ck::NumericLimits::Max())), abs_tol); // convert QuietNaN fp16 to f8_fnuz_t and check if it is QuietNaN - ASSERT_EQ(ck::NumericLimits::QuietNaN(), - f8_convert_rne(ck::NumericLimits::QuietNaN())); + ASSERT_NEAR(ck::NumericLimits::QuietNaN(), + f8_convert_rne(ck::NumericLimits::QuietNaN()), + abs_tol); // positive norm fp16 value to fp8 and back, check if holds half_t pos_half = half_t{0.017578125}; ASSERT_NEAR(pos_half, type_convert(f8_convert_rne(pos_half)), abs_tol); @@ -156,8 +159,9 @@ TEST(FP8FNUZ, ConvertFP16Stochastic) type_convert(f8_convert_sr(ck::NumericLimits::Max())), abs_tol); // convert QuietNaN fp16 to f8_fnuz_t and check if it is QuietNaN - ASSERT_EQ(ck::NumericLimits::QuietNaN(), - f8_convert_sr(ck::NumericLimits::QuietNaN())); + ASSERT_NEAR(ck::NumericLimits::QuietNaN(), + f8_convert_sr(ck::NumericLimits::QuietNaN()), + abs_tol); // positive norm fp16 value to fp8 and back, check if holds half_t pos_half = half_t{0.017578125}; ASSERT_NEAR(pos_half, type_convert(f8_convert_sr(pos_half)), abs_tol);