diff --git a/include/ck/utility/amd_ck_fp8.hpp b/include/ck/utility/amd_ck_fp8.hpp index 2edbb7c789..0b73f76155 100644 --- a/include/ck/utility/amd_ck_fp8.hpp +++ b/include/ck/utility/amd_ck_fp8.hpp @@ -33,8 +33,34 @@ namespace ck { -using f8_fnuz_t = _BitInt(8); -using bf8_fnuz_t = unsigned _BitInt(8); +struct f8_fnuz_t +{ + using data_type = unsigned char; + data_type m_data; + __host__ __device__ explicit constexpr f8_fnuz_t(data_type in_data) : m_data(in_data) {} + __host__ __device__ explicit constexpr f8_fnuz_t() = default; + __host__ __device__ bool constexpr operator==(f8_fnuz_t other) const + { + return m_data == other.m_data; + } + __host__ __device__ explicit constexpr operator data_type() const { return m_data; } +}; + +struct bf8_fnuz_t +{ + using data_type = unsigned char; + data_type m_data; + __host__ __device__ explicit constexpr bf8_fnuz_t(data_type in_data) : m_data(in_data) {} + __host__ __device__ explicit constexpr bf8_fnuz_t() = default; + __host__ __device__ bool constexpr operator==(bf8_fnuz_t other) const + { + return m_data == other.m_data; + } + __host__ __device__ explicit constexpr operator data_type() const { return m_data; } +}; + +static_assert(1 == sizeof(f8_fnuz_t)); +static_assert(1 == sizeof(bf8_fnuz_t)); typedef unsigned char fp8_storage_t; diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 48b352986e..984bb4d862 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -205,7 +205,7 @@ inline constexpr bool is_native_type() return is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value; + is_same_v || is_same_v || is_same::value; } // scalar_type @@ -300,14 +300,14 @@ struct scalar_type template <> struct scalar_type { - using type = f8_fnuz_t; + using type = f8_fnuz_t::data_type; static constexpr index_t vector_size = 1; }; template <> struct scalar_type { - using type = bf8_fnuz_t; + using type = bf8_fnuz_t::data_type; static constexpr index_t vector_size = 1; }; diff --git a/include/ck/utility/dtype_vector.hpp b/include/ck/utility/dtype_vector.hpp index ae0edb35ee..27a7545a0e 100644 --- a/include/ck/utility/dtype_vector.hpp +++ b/include/ck/utility/dtype_vector.hpp @@ -1294,6 +1294,18 @@ struct nnvb_data_t_selector using type = bf8_ocp_t::data_type; }; +template <> +struct nnvb_data_t_selector +{ + using type = f8_fnuz_t::data_type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = bf8_fnuz_t::data_type; +}; + template <> struct nnvb_data_t_selector { diff --git a/include/ck/utility/f8_utils.hpp b/include/ck/utility/f8_utils.hpp index 799683ae65..748aa07f9e 100644 --- a/include/ck/utility/f8_utils.hpp +++ b/include/ck/utility/f8_utils.hpp @@ -39,7 +39,7 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng) int exponent, bias; uint32_t head, mantissa, sign; // nan code is same for float and half - constexpr Y nan_code = 0x80; + constexpr uint8_t nan_code = 0x80; constexpr uint32_t nan_mask = NumericUtils::nan_mask; // convert to bitwise @@ -60,17 +60,17 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng) if constexpr(negative_zero_nan) { if((x_bitwise & nan_mask) == nan_mask) - return nan_code; + return Y{nan_code}; } else { if((x_bitwise & nan_mask) == nan_mask) - return signed_inf + (mantissa != 0 ? 1 : 0); + return Y{static_cast(signed_inf + (mantissa != 0 ? 1 : 0))}; } // check if x is 0.0 if(x_bitwise == 0) - return 0; + return Y{0}; // First need to check if it is normal or denorm as there is a difference of implict 1 // Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift @@ -178,9 +178,10 @@ In this case, the fp16 mantissa should be shift left by 1 */ // check if x is 0.0 or -0.0 if(out_exponent == 0 && mantissa == 0) - return negative_zero_nan ? 0 : (sign << (out_exp + out_mant)); + return Y{negative_zero_nan ? 0 : static_cast(sign << (out_exp + out_mant))}; mantissa &= (1 << out_mant) - 1; - return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa; + return Y{static_cast((sign << (out_exp + out_mant)) | (out_exponent << out_mant) | + mantissa)}; } template @@ -195,8 +196,8 @@ __host__ __device__ Y run_cast_from_f8(X x) constexpr int out_mant = NumericUtils::mant; // prepare the codes - constexpr X nan_code = 0x80; - using T_bitwise = typename NumericUtils::bitwise_type; + constexpr uint8_t nan_code = 0x80; + using T_bitwise = typename NumericUtils::bitwise_type; constexpr T_bitwise Inf_bitwise = NumericUtils::Inf; constexpr T_bitwise NegInf_bitwise = NumericUtils::NegInf; @@ -209,13 +210,13 @@ __host__ __device__ Y run_cast_from_f8(X x) constexpr Y Neg0 = bit_cast(Neg0_bitwise); // check if x is 0.0 - if(x == 0) + if(!static_cast(x)) return static_cast(0); // unpack the input - uint32_t sign = x >> (in_exp + in_mant); - uint32_t mantissa = x & ((1 << in_mant) - 1); - int exponent = (x & 0x7F) >> in_mant; + uint32_t sign = static_cast(x) >> (in_exp + in_mant); + uint32_t mantissa = static_cast(x) & ((1 << in_mant) - 1); + int exponent = (static_cast(x) & 0x7F) >> in_mant; constexpr int exp_low_cutoff = (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); @@ -223,12 +224,12 @@ __host__ __device__ Y run_cast_from_f8(X x) if constexpr(negative_zero_nan) { - if(x == nan_code) + if(static_cast(x) == nan_code) return NaN; } else { - if(x == nan_code) + if(static_cast(x) == nan_code) return Neg0; if(exponent == ((1 << in_exp) - 1)) return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN; diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 290a6c8dd6..913557fc7a 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -351,7 +351,7 @@ inline __host__ __device__ f8_fnuz_t f8_convert_sr(float x) val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8); ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos val.i32val = ival; - return val.i8val[0]; // little endian + return f8_t{val.i8val[0]}; // little endian #else constexpr bool negative_zero_nan = true; constexpr bool clip = true; @@ -419,7 +419,7 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_sr(float x) val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8); ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos val.i32val = ival; - return val.i8val[0]; // little endian + return bf8_t{val.i8val[0]}; // little endian #else constexpr bool negative_zero_nan = true; constexpr bool clip = true; @@ -655,7 +655,7 @@ inline __host__ __device__ f8_fnuz_t f8_convert_rne(float x) val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8); ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0 val.i32val = ival; - return val.i8val[0]; + return f8_t{val.i8val[0]}; #else constexpr bool negative_zero_nan = true; constexpr bool clip = true; @@ -707,7 +707,7 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_rne(float x) val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8); ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0 val.i32val = ival; - return val.i8val[0]; + return bf8_t{val.i8val[0]}; #else constexpr bool negative_zero_nan = true; constexpr bool clip = true; @@ -924,7 +924,7 @@ inline __host__ __device__ float type_convert(f8_fnuz_t x) { #if defined(__gfx94__) float fval; - uint32_t i32val = static_cast(x); + uint32_t i32val = static_cast(static_cast(x)); fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0); // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); return fval; @@ -1430,7 +1430,7 @@ inline __host__ __device__ float type_convert(bf8_fnuz_t x) { #if defined(__gfx94__) float fval; - uint32_t i32val = static_cast(x); + uint32_t i32val = static_cast(static_cast(x)); fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0); // asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); return fval; diff --git a/test/data_type/test_bf8_fnuz.cpp b/test/data_type/test_bf8_fnuz.cpp index 4ff796a614..f028c0da73 100644 --- a/test/data_type/test_bf8_fnuz.cpp +++ b/test/data_type/test_bf8_fnuz.cpp @@ -43,9 +43,8 @@ TEST(BF8FNUZ, ConvertFP32Nearest) type_convert(f8_convert_rne(std::numeric_limits::max())), abs_tol); // convert inf float to bf8_fnuz_t and check if it is qNan - ASSERT_NEAR(ck::NumericLimits::QuietNaN(), - f8_convert_rne(std::numeric_limits::infinity()), - abs_tol); + ASSERT_EQ(ck::NumericLimits::QuietNaN(), + f8_convert_rne(std::numeric_limits::infinity())); // positive norm float value to bf8 and back, check if holds float pos_float = 0.0000762939f; ASSERT_NEAR(pos_float, type_convert(f8_convert_rne(pos_float)), abs_tol); @@ -80,9 +79,8 @@ TEST(BF8FNUZ, ConvertFP32Stochastic) type_convert(f8_convert_sr(std::numeric_limits::max())), abs_tol); // convert inf float to bf8_fnuz_t and check if it is qNan - ASSERT_NEAR(ck::NumericLimits::QuietNaN(), - f8_convert_sr(std::numeric_limits::infinity()), - abs_tol); + ASSERT_EQ(ck::NumericLimits::QuietNaN(), + f8_convert_sr(std::numeric_limits::infinity())); // positive norm float value to bf8 and back, check if holds float pos_float = 0.0000762939f; ASSERT_NEAR(pos_float, type_convert(f8_convert_sr(pos_float)), abs_tol); @@ -118,9 +116,8 @@ TEST(BF8FNUZ, ConvertFP16Nearest) type_convert(f8_convert_rne(ck::NumericLimits::Max())), abs_tol); // convert QuietNaN fp16 to bf8_fnuz_t and check if it is QuietNaN - ASSERT_NEAR(ck::NumericLimits::QuietNaN(), - f8_convert_rne(ck::NumericLimits::QuietNaN()), - abs_tol); + ASSERT_EQ(ck::NumericLimits::QuietNaN(), + f8_convert_rne(ck::NumericLimits::QuietNaN())); // positive norm fp16 value to bf8 and back, check if holds half_t pos_half = half_t{0.0000762939}; ASSERT_NEAR(pos_half, type_convert(f8_convert_rne(pos_half)), abs_tol); @@ -155,9 +152,8 @@ TEST(BF8FNUZ, ConvertFP16Stochastic) type_convert(f8_convert_sr(ck::NumericLimits::Max())), abs_tol); // convert QuietNaN fp16 to bf8_fnuz_t and check if it is QuietNaN - ASSERT_NEAR(ck::NumericLimits::QuietNaN(), - f8_convert_sr(ck::NumericLimits::QuietNaN()), - abs_tol); + ASSERT_EQ(ck::NumericLimits::QuietNaN(), + f8_convert_sr(ck::NumericLimits::QuietNaN())); // positive norm fp16 value to bf8 and back, check if holds half_t pos_half = half_t{0.0000762939}; ASSERT_NEAR(pos_half, type_convert(f8_convert_sr(pos_half)), abs_tol); diff --git a/test/data_type/test_fp8_fnuz.cpp b/test/data_type/test_fp8_fnuz.cpp index c2ec6dad94..0cf775f947 100644 --- a/test/data_type/test_fp8_fnuz.cpp +++ b/test/data_type/test_fp8_fnuz.cpp @@ -48,9 +48,8 @@ TEST(FP8FNUZ, ConvertFP32Nearest) type_convert(f8_convert_rne(std::numeric_limits::max())), abs_tol); // convert inf float to f8_fnuz_t and check if it is qNan - ASSERT_NEAR(ck::NumericLimits::QuietNaN(), - f8_convert_rne(std::numeric_limits::infinity()), - abs_tol); + ASSERT_EQ(ck::NumericLimits::QuietNaN(), + f8_convert_rne(std::numeric_limits::infinity())); // positive norm float value to fp8 and back, check if holds float pos_float = 0.017578125f; ASSERT_NEAR(pos_float, type_convert(f8_convert_rne(pos_float)), abs_tol); @@ -85,9 +84,8 @@ TEST(FP8FNUZ, ConvertFP32Stochastic) type_convert(f8_convert_sr(std::numeric_limits::max())), abs_tol); // convert inf float to f8_fnuz_t and check if it is qNan - ASSERT_NEAR(ck::NumericLimits::QuietNaN(), - f8_convert_sr(std::numeric_limits::infinity()), - abs_tol); + ASSERT_EQ(ck::NumericLimits::QuietNaN(), + f8_convert_sr(std::numeric_limits::infinity())); // positive norm float value to fp8 and back, check if holds float pos_float = 0.017578125f; ASSERT_NEAR(pos_float, type_convert(f8_convert_sr(pos_float)), abs_tol); @@ -122,9 +120,8 @@ TEST(FP8FNUZ, ConvertFP16Nearest) type_convert(f8_convert_rne(ck::NumericLimits::Max())), abs_tol); // convert QuietNaN fp16 to f8_fnuz_t and check if it is QuietNaN - ASSERT_NEAR(ck::NumericLimits::QuietNaN(), - f8_convert_rne(ck::NumericLimits::QuietNaN()), - abs_tol); + ASSERT_EQ(ck::NumericLimits::QuietNaN(), + f8_convert_rne(ck::NumericLimits::QuietNaN())); // positive norm fp16 value to fp8 and back, check if holds half_t pos_half = half_t{0.017578125}; ASSERT_NEAR(pos_half, type_convert(f8_convert_rne(pos_half)), abs_tol); @@ -159,9 +156,8 @@ TEST(FP8FNUZ, ConvertFP16Stochastic) type_convert(f8_convert_sr(ck::NumericLimits::Max())), abs_tol); // convert QuietNaN fp16 to f8_fnuz_t and check if it is QuietNaN - ASSERT_NEAR(ck::NumericLimits::QuietNaN(), - f8_convert_sr(ck::NumericLimits::QuietNaN()), - abs_tol); + ASSERT_EQ(ck::NumericLimits::QuietNaN(), + f8_convert_sr(ck::NumericLimits::QuietNaN())); // positive norm fp16 value to fp8 and back, check if holds half_t pos_half = half_t{0.017578125}; ASSERT_NEAR(pos_half, type_convert(f8_convert_sr(pos_half)), abs_tol);