diff --git a/CMakeLists.txt b/CMakeLists.txt index 4ca54d847d..a79af9cc32 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,6 +15,12 @@ if (DTYPES) if (DTYPES MATCHES "fp8") add_definitions(-DCK_ENABLE_FP8) set(CK_ENABLE_FP8 "ON") + add_compile_options(-Wno-bit-int-extension) + endif() + if (DTYPES MATCHES "bf8") + add_definitions(-DCK_ENABLE_BF8) + set(CK_ENABLE_BF8 "ON") + add_compile_options(-Wno-bit-int-extension) endif() if (DTYPES MATCHES "fp16") add_definitions(-DCK_ENABLE_FP16) @@ -34,8 +40,9 @@ if (DTYPES) endif() message("DTYPES macro set to ${DTYPES}") else() - add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16) + add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16) set(CK_ENABLE_ALL_DTYPES "ON") + add_compile_options(-Wno-bit-int-extension) # enable fp8 and bf8 endif() if(DL_KERNELS) @@ -365,6 +372,10 @@ IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu #message("fp8 instance found!") set(add_inst 1) endif() + if("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf8\" " AND DTYPES MATCHES "bf8") + #message("bf8 instance found!") + set(add_inst 1) + endif() if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16") #message("fp16 instance found!") set(add_inst 1) diff --git a/client_example/20_splitk_gemm/CMakeLists.txt b/client_example/20_splitk_gemm/CMakeLists.txt index a60bada473..5571ed1d70 100644 --- a/client_example/20_splitk_gemm/CMakeLists.txt +++ b/client_example/20_splitk_gemm/CMakeLists.txt @@ -1,2 +1,4 @@ -add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp) -target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_operations) +if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) + add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp) + target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_operations) +endif() diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 3dc2a0966e..5574d09001 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -69,5 +69,7 @@ if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES) endif() endif() -add_example_executable(example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp) -add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_f8) +if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) + add_example_executable(example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp) + add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_f8) +endif() diff --git a/include/ck/config.h.in b/include/ck/config.h.in index 13dc5da5d1..1748344756 100644 --- a/include/ck/config.h.in +++ b/include/ck/config.h.in @@ -43,6 +43,9 @@ #ifndef CK_ENABLE_FP8 #define CK_ENABLE_FP8 "ON" #endif +#ifndef CK_ENABLE_BF8 +#define CK_ENABLE_BF8 "ON" +#endif #ifndef CK_ENABLE_FP16 #define CK_ENABLE_FP16 "ON" #endif @@ -66,6 +69,10 @@ #cmakedefine CK_ENABLE_FP8 @CK_ENABLE_FP8@ #endif +#ifndef CK_ENABLE_BF8 +#cmakedefine CK_ENABLE_BF8 @CK_ENABLE_BF8@ +#endif + #ifndef CK_ENABLE_FP16 #cmakedefine CK_ENABLE_FP16 @CK_ENABLE_FP16@ #endif diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 905908a1c3..34ac08b665 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -89,6 +89,7 @@ struct PassThrough } #endif +#if defined CK_ENABLE_FP8 template <> __host__ __device__ void operator()(f8_t& y, const f8_t& x) const { @@ -118,6 +119,7 @@ struct PassThrough { y = type_convert(x); } +#endif }; struct UnaryConvert @@ -146,6 +148,7 @@ struct ConvertBF16RTN } }; +#if defined CK_ENABLE_FP8 struct ConvertF8SR { // convert to fp8 using stochastic rounding (SR) @@ -162,6 +165,7 @@ struct ConvertF8SR y = f8_convert_sr(x); } }; +#endif struct Scale { diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 814969ef42..9ee07b84ae 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -456,6 +456,7 @@ struct mfma_type } }; +#if defined CK_ENABLE_FP8 template <> struct mfma_type { @@ -499,6 +500,7 @@ struct mfma_type intrin_mfma_f32_16x16x32f8f8::Run(a, b, reg_c); } }; +#endif template struct MfmaSelector @@ -640,6 +642,7 @@ struct MfmaSelector } #endif +#if defined CK_ENABLE_FP8 template <> static constexpr auto GetMfma() { @@ -651,6 +654,7 @@ struct MfmaSelector { return MfmaInstr::mfma_f32_16x16x32f8f8; } +#endif static constexpr auto selected_mfma = mfma_type()>{}; @@ -852,7 +856,11 @@ struct XdlopsGemm { static_assert(is_same::value || is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value, + is_same::value +#if defined CK_ENABLE_FP8 + || is_same::value +#endif + , "base base_type must be double, float, half, bfloat16, and int8_t!"); static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 897cb4f249..694027100f 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -1127,7 +1127,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000; - +#if defined CK_ENABLE_FP8 if constexpr(is_same::value) { auto tmp = amd_buffer_load_impl( @@ -1136,10 +1136,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, } else { +#endif return amd_buffer_load_impl( src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); +#if defined CK_ENABLE_FP8 } +#endif #else +#if defined CK_ENABLE_FP8 if constexpr(is_same::value) { auto tmp = amd_buffer_load_impl( @@ -1148,11 +1152,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, } else { +#endif vector_t tmp = amd_buffer_load_impl( src_wave_buffer_resource, src_thread_addr_offset, 0); return src_thread_element_valid ? tmp : vector_t(0); +#if defined CK_ENABLE_FP8 } #endif +#endif } // buffer_load requires: @@ -1209,7 +1216,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker::type::t #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; - +#if defined CK_ENABLE_FP8 if constexpr(is_same::value) { auto tmp = @@ -1219,12 +1226,16 @@ __device__ void amd_buffer_store(const typename vector_type_maker::type::t } else { +#endif amd_buffer_store_impl( src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); +#if defined CK_ENABLE_FP8 } +#endif #else if(dst_thread_element_valid) { +#if defined CK_ENABLE_FP8 if constexpr(is_same::value) { auto tmp = bit_cast::type::type>( @@ -1234,9 +1245,12 @@ __device__ void amd_buffer_store(const typename vector_type_maker::type::t } else { +#endif amd_buffer_store_impl( src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); +#if defined CK_ENABLE_FP8 } +#endif } #endif } diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index ea7755036f..a80540515a 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -355,6 +355,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16> } }; +#if defined CK_ENABLE_FP8 template struct intrin_mfma_f32_32x32x16f8f8; @@ -417,5 +418,6 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16> #endif } }; +#endif } // namespace ck #endif diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index c240afa2b8..89100577aa 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -12,7 +12,12 @@ using half_t = _Float16; #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 using int4_t = _BitInt(4); #endif -using f8_t = uint8_t; +#if defined CK_ENABLE_FP8 +using f8_t = _BitInt(8); +#endif +#if defined CK_ENABLE_BF8 +using bf8_t = unsigned _BitInt(8); +#endif // vector_type template @@ -143,14 +148,24 @@ struct scalar_type }; #endif +#if defined CK_ENABLE_FP8 template <> struct scalar_type { using type = f8_t; static constexpr index_t vector_size = 1; }; +#endif + +#if defined CK_ENABLE_BF8 +template <> +struct scalar_type +{ + using type = bf8_t; + static constexpr index_t vector_size = 1; +}; +#endif -// template struct vector_type { @@ -953,12 +968,24 @@ using int8x32_t = typename vector_type::type; using int8x64_t = typename vector_type::type; // f8 +#if defined CK_ENABLE_FP8 using f8x2_t = typename vector_type::type; using f8x4_t = typename vector_type::type; using f8x8_t = typename vector_type::type; using f8x16_t = typename vector_type::type; using f8x32_t = typename vector_type::type; using f8x64_t = typename vector_type::type; +#endif + +// bf8 +#if defined CK_ENABLE_BF8 +using bf8x2_t = typename vector_type::type; +using bf8x4_t = typename vector_type::type; +using bf8x8_t = typename vector_type::type; +using bf8x16_t = typename vector_type::type; +using bf8x32_t = typename vector_type::type; +using bf8x64_t = typename vector_type::type; +#endif template struct NumericLimits @@ -1006,21 +1033,109 @@ struct NumericLimits }; #endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +#if defined CK_ENABLE_FP8 template <> struct NumericLimits { + // negative zero nan mode with exp bias = 8 static constexpr uint8_t binary_min = 0x08; // 0b00001000 - static constexpr uint8_t binary_max = 0x77; // 0b01110111 - static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111 + static constexpr uint8_t binary_max = 0x7F; // 0b01111111 + static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 + // ieee mode with exp bias = 7 + // static constexpr uint8_t binary_min = 0x08; // 0b00001000 + // static constexpr uint8_t binary_max = 0x77; // 0b01110111 + // static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111 + // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0 - __host__ __device__ static constexpr f8_t Min() { return bit_cast(binary_min); } + __host__ __device__ static constexpr f8_t Min() { return f8_t(binary_min); } - __host__ __device__ static constexpr f8_t Max() { return bit_cast(binary_max); } + __host__ __device__ static constexpr f8_t Max() { return f8_t(binary_max); } - __host__ __device__ static constexpr f8_t Lowest() { return bit_cast(binary_lowest); } + __host__ __device__ static constexpr f8_t Lowest() { return f8_t(binary_lowest); } - __host__ __device__ static constexpr f8_t QuietNaN() { return bit_cast(binary_qnan); } + __host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); } +}; +#endif + +#if defined CK_ENABLE_BF8 +template <> +struct NumericLimits +{ + // negative zero nan mode with exp bias = 16 + static constexpr uint8_t binary_min = 0x04; // 0b00000100 + static constexpr uint8_t binary_max = 0x7F; // 0b01111111 + static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 + static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 + // ieee mode with exp bias = 15 + // static constexpr uint8_t binary_min = 0x04; // 0b00000100 + // static constexpr uint8_t binary_max = 0x7B; // 0b01111011 + // static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 + // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!= + + __host__ __device__ static constexpr bf8_t Min() { return bf8_t(binary_min); } + + __host__ __device__ static constexpr bf8_t Max() { return bf8_t(binary_max); } + + __host__ __device__ static constexpr bf8_t Lowest() { return bf8_t(binary_lowest); } + + __host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); } +}; +#endif + +template +struct NumericUtils +{ }; +template <> +struct NumericUtils +{ + static constexpr int exp = 8; + static constexpr int mant = 23; + static constexpr uint32_t nan_mask = 0x7F800000; + static constexpr uint32_t head_mask = 0xFF800000; + static constexpr uint32_t mant_mask = 0x7FFFFF; + static constexpr uint32_t exp_mask = 0xFF; + static constexpr uint32_t Inf = 0x7F800000; + static constexpr uint32_t NegInf = 0xFF800000; + static constexpr uint32_t NaN = 0x7F800001; + static constexpr uint32_t Neg0 = 0x80000000; + using bitwise_type = uint32_t; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 5; + static constexpr int mant = 10; + static constexpr uint16_t nan_mask = 0x7C00; + static constexpr uint16_t head_mask = 0xFC00; + static constexpr uint16_t mant_mask = 0x3FF; + static constexpr uint16_t exp_mask = 0x1F; + static constexpr uint32_t Inf = 0x7C00; + static constexpr uint32_t NegInf = 0xFC00; + static constexpr uint32_t NaN = 0x7C01; + static constexpr uint32_t Neg0 = 0x8000; + using bitwise_type = uint16_t; +}; + +#if defined CK_ENABLE_FP8 +template <> +struct NumericUtils +{ + static constexpr int exp = 4; + static constexpr int mant = 3; +}; +#endif + +#if defined CK_ENABLE_BF8 +template <> +struct NumericUtils +{ + static constexpr int exp = 5; + static constexpr int mant = 2; +}; +#endif + } // namespace ck diff --git a/include/ck/utility/f8_utils.hpp b/include/ck/utility/f8_utils.hpp index bb13f98154..5fbebb708d 100644 --- a/include/ck/utility/f8_utils.hpp +++ b/include/ck/utility/f8_utils.hpp @@ -5,6 +5,7 @@ #include "ck/utility/data_type.hpp" +#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 namespace ck { // fp8 rounding modes @@ -22,53 +23,38 @@ namespace ck::utils { namespace { -template -__host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) +template +__host__ __device__ Y run_cast_to_f8(X x, uint32_t rng) { - // check data type - constexpr bool is_half = std::is_same::value; - constexpr bool is_float = std::is_same::value; + // fp8/bf8 exponent/mantissa layout + constexpr int out_exp = NumericUtils::exp; + constexpr int out_mant = NumericUtils::mant; - // fp8 exponent/mantissa layout - constexpr int f8_exp = 4; - constexpr int f8_mant = 3; - - // resulting type exponent/mantissa layout - constexpr int type_exp = is_half ? 5 : 8; - constexpr int type_mant = is_half ? 10 : 23; + // original type exponent/mantissa layout + constexpr int in_exp = NumericUtils::exp; + constexpr int in_mant = NumericUtils::mant; int exponent; uint32_t head, mantissa, sign; // nan code is same for float and half - constexpr uint8_t nan_code = 0x80; - constexpr uint32_t nan_mask = is_half ? 0x7C00 : 0x7F800000; + constexpr Y nan_code = 0x80; + constexpr uint32_t nan_mask = NumericUtils::nan_mask; // convert to bitwise - typedef typename std::conditional::value, uint16_t, uint32_t>::type - T_bitwise; + using T_bitwise = typename NumericUtils::bitwise_type; T_bitwise x_bitwise = *(reinterpret_cast(&x)); // unpack the input, depends on datatype - if constexpr(is_float) - { - head = x_bitwise & 0xFF800000; - mantissa = x_bitwise & 0x7FFFFF; - exponent = (head >> type_mant) & 0xFF; - sign = head >> (type_exp + type_mant); - } - else if constexpr(is_half) - { - head = x_bitwise & 0xFC00; - mantissa = x_bitwise & 0x3FF; - exponent = (head >> type_mant) & 0x1F; - sign = head >> (type_exp + type_mant); - } + head = x_bitwise & NumericUtils::head_mask; + mantissa = x_bitwise & NumericUtils::mant_mask; + exponent = (head >> in_mant) & NumericUtils::exp_mask; + sign = head >> (in_exp + in_mant); - uint32_t signed_inf = (sign << (type_exp + type_mant)) + (((1 << type_exp) - 1) << type_mant); - uint32_t drop_mask = (1 << (type_mant - f8_mant)) - 1; - constexpr int max_exp = (1 << f8_exp) - (negative_zero_nan ? 1 : 2); + uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant); + uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1; + constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2); constexpr int exp_low_cutoff = - (1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); + (1 << (in_exp - 1)) - (1 << (out_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); if constexpr(negative_zero_nan) { @@ -81,22 +67,35 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) return signed_inf + (mantissa != 0 ? 1 : 0); } + // if input is half and output is bf8 + if((NumericUtils::mant == 10) && (NumericUtils::mant == 2) && negative_zero_nan && + exponent == 0) + { + exponent += 1; + while(mantissa < (1 << in_mant)) + { + mantissa <<= 1; + exponent -= 1; + } + mantissa &= ~(1 << in_mant); + } + // check if x is 0.0 if(x_bitwise == 0) return 0; exponent -= exp_low_cutoff - 1; if(exponent <= 0) - drop_mask = (1 << (type_mant - f8_mant + 1 - exponent)) - 1; - mantissa += 1 << type_mant; + drop_mask = (1 << (in_mant - out_mant + 1 - exponent)) - 1; + mantissa += 1 << in_mant; // apply random number if needed mantissa += (stoch ? rng : mantissa) & drop_mask; - if(mantissa >= (2 << type_mant)) + if(mantissa >= (2 << in_mant)) { mantissa >>= 1; exponent++; } - mantissa >>= (type_mant - f8_mant); + mantissa >>= (in_mant - out_mant); // check negative exponent if(exponent <= 0) @@ -116,7 +115,7 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) { if(clip) { - mantissa = (1 << f8_mant) - 1; + mantissa = (1 << out_mant) - 1; exponent = max_exp; } else @@ -127,124 +126,120 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) // check if x is 0.0 or -0.0 if(exponent == 0 && mantissa == 0) - return negative_zero_nan ? 0 : (sign << (f8_exp + f8_mant)); - mantissa &= (1 << f8_mant) - 1; - return (sign << (f8_exp + f8_mant)) | (exponent << f8_mant) | mantissa; + return negative_zero_nan ? 0 : (sign << (out_exp + out_mant)); + mantissa &= (1 << out_mant) - 1; + return (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa; } -template -__host__ __device__ T run_cast_from_f8(f8_t x) +template +__host__ __device__ Y run_cast_from_f8(X x) { - // check data type - constexpr bool is_half = std::is_same::value; - constexpr bool is_float = std::is_same::value; - - // fp8 exponent/mantissa layout - constexpr int f8_exp = 4; - constexpr int f8_mant = 3; + // fp8/bf8 exponent/mantissa layout + constexpr int in_exp = NumericUtils::exp; + constexpr int in_mant = NumericUtils::mant; // resulting type exponent/mantissa layout - constexpr int type_exp = is_half ? 5 : 8; - constexpr int type_mant = is_half ? 10 : 23; + constexpr int out_exp = NumericUtils::exp; + constexpr int out_mant = NumericUtils::mant; // prepare the codes - constexpr uint8_t nan_code = 0x80; - T fInf, fNegInf, fNaN, fNeg0; - if constexpr(is_half) - { - constexpr uint16_t ihInf = 0x7C00; - constexpr uint16_t ihNegInf = 0xFC00; - constexpr uint16_t ihNaN = 0x7C01; - constexpr uint16_t ihNeg0 = 0x8000; - fInf = *(reinterpret_cast(&ihInf)); - fNegInf = *(reinterpret_cast(&ihNegInf)); - fNaN = *(reinterpret_cast(&ihNaN)); - fNeg0 = *(reinterpret_cast(&ihNeg0)); - } - else if constexpr(is_float) - { - constexpr uint32_t ifInf = 0x7F800000; - constexpr uint32_t ifNegInf = 0xFF800000; - constexpr uint32_t ifNaN = 0x7F800001; - constexpr uint32_t ifNeg0 = 0x80000000; - fInf = *(reinterpret_cast(&ifInf)); - fNegInf = *(reinterpret_cast(&ifNegInf)); - fNaN = *(reinterpret_cast(&ifNaN)); - fNeg0 = *(reinterpret_cast(&ifNeg0)); - } + constexpr X nan_code = 0x80; + Y Inf, NegInf, NaN, Neg0; + using T_bitwise = typename NumericUtils::bitwise_type; + + constexpr T_bitwise Inf_bitwise = NumericUtils::Inf; + constexpr T_bitwise NegInf_bitwise = NumericUtils::NegInf; + constexpr T_bitwise NaN_bitwise = NumericUtils::NaN; + constexpr T_bitwise Neg0_bitwise = NumericUtils::Neg0; + + Inf = *(reinterpret_cast(&Inf_bitwise)); + NegInf = *(reinterpret_cast(&NegInf_bitwise)); + NaN = *(reinterpret_cast(&NaN_bitwise)); + Neg0 = *(reinterpret_cast(&Neg0_bitwise)); + + // check if x is 0.0 + if(x == 0) + return static_cast(0); // unpack the input - uint32_t sign = x >> (f8_exp + f8_mant); - uint32_t mantissa = x & ((1 << f8_mant) - 1); - int exponent = (x & 0x7F) >> f8_mant; + uint32_t sign = x >> (in_exp + in_mant); + uint32_t mantissa = x & ((1 << in_mant) - 1); + int exponent = (x & 0x7F) >> in_mant; constexpr int exp_low_cutoff = - (1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); - typename std::conditional::value, uint16_t, uint32_t>::type retval; + (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); + T_bitwise retval; if constexpr(negative_zero_nan) { if(x == nan_code) - return fNaN; + return NaN; } else { if(x == nan_code) - return fNeg0; - if(exponent == ((1 << f8_exp) - 1)) - return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; + return Neg0; + if(exponent == ((1 << in_exp) - 1)) + return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN; + } + + if((NumericUtils::mant == 10) && (NumericUtils::mant == 2) && !negative_zero_nan) + { + retval = x; + retval <<= 8; + return *(reinterpret_cast(&retval)); } // subnormal input if(exponent == 0) { // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above - int sh = 1 + __builtin_clz(mantissa) - ((1 + type_exp + type_mant) - f8_mant); - mantissa <<= sh; - mantissa &= ((1 << f8_mant) - 1); - exponent += 1 - sh; + exponent++; + while(mantissa < (1 << in_mant)) + { + mantissa <<= 1; + exponent--; + } + mantissa &= ((1 << in_mant) - 1); } exponent += exp_low_cutoff - 1; - mantissa <<= type_mant - f8_mant; + mantissa <<= out_mant - in_mant; // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) if(exponent <= 0) { - mantissa |= 1 << type_mant; + mantissa |= 1 << out_mant; mantissa >>= 1 - exponent; exponent = 0; } - retval = (sign << (type_exp + type_mant)) | (exponent << type_mant) | mantissa; - return *(reinterpret_cast(&retval)); + retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa; + return *(reinterpret_cast(&retval)); } } // namespace -template -__host__ __device__ f8_t cast_to_f8(T x, uint32_t rng) +template +__host__ __device__ Y cast_to_f8(X x, uint32_t rng) { - // check datatype - constexpr bool is_half = std::is_same::value; - constexpr bool is_float = std::is_same::value; - static_assert(is_half || is_float, "Only half and float can be casted to f8."); + // check datatypes + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; + static_assert(is_half || is_float, "Only half and float can be casted."); - return run_cast_to_f8(x, rng); + return run_cast_to_f8(x, rng); } -template -__host__ __device__ T cast_from_f8(f8_t x) +template +__host__ __device__ Y cast_from_f8(X x) { // check datatype - constexpr bool is_half = std::is_same::value; - constexpr bool is_float = std::is_same::value; + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; static_assert(is_half || is_float, "only half and float are supported."); - // check if x is 0.0 - if(x == 0) - return static_cast(0); - - return run_cast_from_f8(x); + return run_cast_from_f8(x); } } // namespace ck::utils +#endif diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 65d8940377..5c5447f94e 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -80,6 +80,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert(int8_ return type_convert(x_fp32); } +#if defined CK_ENABLE_FP8 // convert fp32 to fp8 template <> inline __host__ __device__ f8_t type_convert(float x) @@ -88,8 +89,9 @@ inline __host__ __device__ f8_t type_convert(float x) constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr uint32_t rng = 0; - return utils::cast_to_f8( - x, rng); + return utils:: + cast_to_f8(x, + rng); } // convert fp8 to fp32 @@ -97,7 +99,7 @@ template <> inline __host__ __device__ float type_convert(f8_t x) { constexpr bool negative_zero_nan = true; - return utils::cast_from_f8(x); + return utils::cast_from_f8(x); } // convert fp16 to fp8 @@ -108,8 +110,9 @@ inline __host__ __device__ f8_t type_convert(half_t x) constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr uint32_t rng = 0; - return utils::cast_to_f8( - x, rng); + return utils:: + cast_to_f8( + x, rng); } // convert fp8 to fp16 @@ -117,8 +120,53 @@ template <> inline __host__ __device__ half_t type_convert(f8_t x) { constexpr bool negative_zero_nan = true; - return utils::cast_from_f8(x); + return utils::cast_from_f8(x); } +#endif + +#if defined CK_ENABLE_BF8 +// convert fp32 to bf8 +template <> +inline __host__ __device__ bf8_t type_convert(float x) +{ + constexpr bool negative_zero_nan = true; + constexpr bool clip = true; + constexpr f8_rounding_mode rm = f8_rounding_mode::standard; + constexpr uint32_t rng = 0; + return utils:: + cast_to_f8( + x, rng); +} + +// convert bf8 to fp32 +template <> +inline __host__ __device__ float type_convert(bf8_t x) +{ + constexpr bool negative_zero_nan = true; + return utils::cast_from_f8(x); +} + +// convert fp16 to bf8 +template <> +inline __host__ __device__ bf8_t type_convert(half_t x) +{ + constexpr bool negative_zero_nan = true; + constexpr bool clip = true; + constexpr f8_rounding_mode rm = f8_rounding_mode::standard; + constexpr uint32_t rng = 0; + return utils:: + cast_to_f8( + x, rng); +} + +// convert bf8 to fp16 +template <> +inline __host__ __device__ half_t type_convert(bf8_t x) +{ + constexpr bool negative_zero_nan = true; + return utils::cast_from_f8(x); +} +#endif // Declare a template function for bf16 conversion using RTN template @@ -181,6 +229,7 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn(h template __host__ __device__ constexpr Y f8_convert_sr(X x); +#if defined CK_ENABLE_FP8 // convert fp32 to fp8 with stochastic rounding template <> inline __host__ __device__ f8_t f8_convert_sr(float x) @@ -191,8 +240,9 @@ inline __host__ __device__ f8_t f8_convert_sr(float x) constexpr int seed = 42; // as thread id is not available on host, use 0 for prn generation uint32_t rng = prand_generator(reinterpret_cast(&x), x); - return utils::cast_to_f8( - x, rng); + return utils:: + cast_to_f8(x, + rng); } // convert fp16 to fp8 with stochastic rounding @@ -205,8 +255,42 @@ inline __host__ __device__ f8_t f8_convert_sr(half_t x) constexpr int seed = 42; // as thread id is not available on host, use 0 for prn generation uint32_t rng = prand_generator(reinterpret_cast(&x), x); - return utils::cast_to_f8( - x, rng); + return utils:: + cast_to_f8( + x, rng); +} +#endif + +#if defined CK_ENABLE_BF8 +// convert fp32 to bf8 with stochastic rounding +template <> +inline __host__ __device__ bf8_t f8_convert_sr(float x) +{ + constexpr bool negative_zero_nan = true; + constexpr bool clip = true; + constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; + constexpr int seed = 42; + // as thread id is not available on host, use 0 for prn generation + uint32_t rng = prand_generator(reinterpret_cast(&x), x); + return utils:: + cast_to_f8( + x, rng); } +// convert fp16 to bf8 with stochastic rounding +template <> +inline __host__ __device__ bf8_t f8_convert_sr(half_t x) +{ + constexpr bool negative_zero_nan = true; + constexpr bool clip = true; + constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; + constexpr int seed = 42; + // as thread id is not available on host, use 0 for prn generation + uint32_t rng = prand_generator(reinterpret_cast(&x), x); + return utils:: + cast_to_f8( + x, rng); +} +#endif + } // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp index 84d31ce267..ea11fd2e1a 100644 --- a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp +++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp @@ -17,10 +17,15 @@ namespace instance { using F64 = double; using F32 = float; using F16 = ck::half_t; -using F8 = ck::f8_t; using BF16 = ck::bhalf_t; using I8 = int8_t; using I32 = int32_t; +#if defined CK_ENABLE_FP8 +using F8 = ck::f8_t; +#endif +#if defined CK_ENABLE_BF8 +using BF8 = ck::bf8_t; +#endif using Empty_Tuple = ck::Tuple<>; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_add.hpp index badc06dd6f..64c74d4795 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_add.hpp @@ -45,6 +45,7 @@ void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_ PassThrough, MultiplyAdd>>>&); +#if defined CK_ENABLE_FP8 void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances( std::vector>>&); +#endif // GEMM + Multiply + Add template && is_same_v && is_same_v && is_same_v && is_same_v) @@ -150,6 +153,7 @@ struct DeviceOperationInstanceFactory>>& instances); +#if defined CK_ENABLE_FP8 void add_device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instances( std::vector>>& @@ -96,6 +97,7 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( std::vector>>& instances); +#endif template && is_same_v && is_same_v) { @@ -224,6 +227,7 @@ struct DeviceOperationInstanceFactory< add_device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instances(op_ptrs); } } +#endif return op_ptrs; } diff --git a/library/include/ck/library/utility/check_err.hpp b/library/include/ck/library/utility/check_err.hpp index 8a72631376..c0f9ba2edc 100644 --- a/library/include/ck/library/utility/check_err.hpp +++ b/library/include/ck/library/utility/check_err.hpp @@ -230,5 +230,99 @@ check_err(const Range& out, return res; } +#if defined CK_ENABLE_FP8 +template +std::enable_if_t<(std::is_same_v, ranges::range_value_t> && + std::is_same_v, f8_t>), + bool> +check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-3, + double atol = 1e-3) +{ + if(out.size() != ref.size()) + { + std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + double max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + const double o = type_convert(*std::next(std::begin(out), i)); + const double r = type_convert(*std::next(std::begin(ref), i)); + err = std::abs(o - r); + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i + << "] != ref[" << i << "]: " << o << " != " << r << std::endl; + } + res = false; + } + } + if(!res) + { + std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; + } + return res; +} +#endif + +#if defined CK_ENABLE_BF8 +template +std::enable_if_t<(std::is_same_v, ranges::range_value_t> && + std::is_same_v, bf8_t>), + bool> +check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-3, + double atol = 1e-3) +{ + if(out.size() != ref.size()) + { + std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + double max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + const double o = type_convert(*std::next(std::begin(out), i)); + const double r = type_convert(*std::next(std::begin(ref), i)); + err = std::abs(o - r); + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i + << "] != ref[" << i << "]: " << o << " != " << r << std::endl; + } + res = false; + } + } + if(!res) + { + std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; + } + return res; +} +#endif + } // namespace utils } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt index a4f7443455..36bd6a4aa3 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt @@ -1,7 +1,13 @@ -add_instance_library(device_gemm_multiply_add_instance - device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp - device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp +set(GEMM_MULTIPLY_ADD_INSTANCES) - device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp - device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp -) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp) + list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp) +endif() + +if((DTYPES MATCHES "fp16" AND DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES) + list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp) + list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp) +endif() + +add_instance_library(device_gemm_multiply_add_instance ${GEMM_MULTIPLY_ADD_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt index 89dfa8f2ed..043b28a1be 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt @@ -14,7 +14,7 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp) endif() -if(DTYPES MATCHES "fp16" OR DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES) +if((DTYPES MATCHES "fp16" AND DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES) list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance.cpp) list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_mk_nk_mn_instance.cpp) list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instance.cpp) diff --git a/profiler/include/profiler/profile_gemm_splitk_impl.hpp b/profiler/include/profiler/profile_gemm_splitk_impl.hpp index fb68bb8811..495513f665 100644 --- a/profiler/include/profiler/profile_gemm_splitk_impl.hpp +++ b/profiler/include/profiler/profile_gemm_splitk_impl.hpp @@ -214,6 +214,7 @@ bool profile_gemm_splitk_impl(int do_verification, << " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch " << kbatch_curr << std::endl; +#if defined CK_ENABLE_FP8 // set softer tolerances for fp8 if constexpr(is_same_v || is_same_v || is_same_v) @@ -226,8 +227,11 @@ bool profile_gemm_splitk_impl(int do_verification, } else { +#endif pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); +#if defined CK_ENABLE_FP8 } +#endif if(tflops > best_tflops) { diff --git a/profiler/src/profile_gemm_multiply_add.cpp b/profiler/src/profile_gemm_multiply_add.cpp index fd1f5c65c1..98973b2f01 100644 --- a/profiler/src/profile_gemm_multiply_add.cpp +++ b/profiler/src/profile_gemm_multiply_add.cpp @@ -59,9 +59,11 @@ int profile_gemm_multiply_add(int argc, char* argv[]) const int StrideD1 = std::stoi(argv[14]); const int StrideE = std::stoi(argv[15]); - using F8 = ck::f8_t; using F16 = ck::half_t; using F32 = float; +#if defined CK_ENABLE_FP8 + using F8 = ck::f8_t; +#endif using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -132,6 +134,7 @@ int profile_gemm_multiply_add(int argc, char* argv[]) { return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Row{}, Col{}, Row{}, Row{}, Row{}); } +#if defined CK_ENABLE_FP8 else if(data_type == MatrixDataType::F16_F8_F32_F32_F16 && layout == MatrixLayout::MK_KN_MN_MN_MN) { @@ -142,6 +145,7 @@ int profile_gemm_multiply_add(int argc, char* argv[]) { return profile(F16{}, F8{}, F32{}, F32{}, F32{}, F16{}, Row{}, Col{}, Row{}, Row{}, Row{}); } +#endif else { std::cout << "this data_type & layout is not implemented" << std::endl; diff --git a/profiler/src/profile_gemm_splitk.cpp b/profiler/src/profile_gemm_splitk.cpp index 617e0b9cd4..9c805fc1d1 100644 --- a/profiler/src/profile_gemm_splitk.cpp +++ b/profiler/src/profile_gemm_splitk.cpp @@ -67,7 +67,9 @@ int profile_gemm_splitk(int argc, char* argv[]) using F32 = float; using F16 = ck::half_t; - using F8 = ck::f8_t; +#if defined CK_ENABLE_FP8 + using F8 = ck::f8_t; +#endif using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -146,6 +148,7 @@ int profile_gemm_splitk(int argc, char* argv[]) { return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}); } +#if defined CK_ENABLE_FP8 else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) { return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}); @@ -178,6 +181,7 @@ int profile_gemm_splitk(int argc, char* argv[]) { return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Col{}, Row{}); } +#endif else { std::cout << "this data_type & layout is not implemented" << std::endl; diff --git a/test/data_type/CMakeLists.txt b/test/data_type/CMakeLists.txt index 2b63727f19..baf0174556 100644 --- a/test/data_type/CMakeLists.txt +++ b/test/data_type/CMakeLists.txt @@ -3,5 +3,12 @@ if (USE_BITINT_EXTENSION_INT4) target_link_libraries(test_int4 PRIVATE utility) endif() -add_gtest_executable(test_fp8 fp8.cpp) -target_link_libraries(test_fp8 PRIVATE utility) +if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES) + add_gtest_executable(test_f8 f8.cpp) + target_link_libraries(test_f8 PRIVATE utility) +endif() + +if(DTYPES MATCHES "bf8" OR NOT DEFINED DTYPES) + add_gtest_executable(test_bf8 bf8.cpp) + target_link_libraries(test_bf8 PRIVATE utility) +endif() diff --git a/test/data_type/bf8.cpp b/test/data_type/bf8.cpp new file mode 100644 index 0000000000..6a5fa281e8 --- /dev/null +++ b/test/data_type/bf8.cpp @@ -0,0 +1,158 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/utility/data_type.hpp" +#include "ck/utility/type_convert.hpp" + +using ck::bf8_t; +using ck::f8_convert_sr; +using ck::half_t; +using ck::type_convert; + +TEST(BF8, NumericLimits) +{ + // constants given for negative zero nan mode + EXPECT_EQ(ck::NumericLimits::Min(), type_convert(0x04)); + EXPECT_EQ(ck::NumericLimits::Max(), type_convert(0x7F)); + EXPECT_EQ(ck::NumericLimits::Lowest(), type_convert(0xFF)); + EXPECT_EQ(ck::NumericLimits::QuietNaN(), type_convert(0x80)); +} + +TEST(BF8, ConvertFP32Nearest) +{ + // fix the tolerance value + float abs_tol = 1e-6; + // convert 0 float to bf8 and back, check if holds + ASSERT_NEAR(0.0f, type_convert(type_convert(0.0f)), abs_tol); + // convert minimal float to bf8 and back, check if holds + ASSERT_NEAR(std::numeric_limits::min(), + type_convert(type_convert(std::numeric_limits::min())), + abs_tol); + // convert maximal bf8_t to float and check if equal to 57344.0 + ASSERT_NEAR(57344.0f, type_convert(type_convert(57344.0f)), abs_tol); + // convert maximal float to bf8 and back, check if clipped to 57344.0 + ASSERT_NEAR(57344.0f, + type_convert(type_convert(std::numeric_limits::max())), + abs_tol); + // convert inf float to bf8_t and check if it is qNan + ASSERT_NEAR(type_convert(0x80), + type_convert(std::numeric_limits::infinity()), + abs_tol); + // positive norm float value to bf8 and back, check if holds + float pos_float = 0.0000762939f; + ASSERT_NEAR(pos_float, type_convert(type_convert(pos_float)), abs_tol); + // negative norm float value to bf8 and back, check if holds + float neg_float = -0.0000610351f; + ASSERT_NEAR(neg_float, type_convert(type_convert(neg_float)), abs_tol); + // positive subnorm float value to bf8 and back, check if holds + pos_float = 0.0000305175f; + ASSERT_NEAR(pos_float, type_convert(type_convert(pos_float)), abs_tol); + // negative subnorm float value to bf8 and back, check if holds + neg_float = -0.0000152587f; + ASSERT_NEAR(neg_float, type_convert(type_convert(neg_float)), abs_tol); +} + +TEST(BF8, ConvertFP32Stochastic) +{ + // fix the tolerance value + float abs_tol = 1e-6; + // convert 0 float to bf8 and back, check if holds + ASSERT_NEAR(0.0f, type_convert(f8_convert_sr(0.0f)), abs_tol); + // convert minimal float to bf8 and back, check if holds + ASSERT_NEAR(std::numeric_limits::min(), + type_convert(f8_convert_sr(std::numeric_limits::min())), + abs_tol); + // convert maximal bf8_t to float and check if equal to 57344.0 + ASSERT_NEAR(57344.0f, type_convert(f8_convert_sr(57344.0f)), abs_tol); + // convert maximal float to bf8 and back, check if clipped to 57344.0 + ASSERT_NEAR(57344.0f, + type_convert(f8_convert_sr(std::numeric_limits::max())), + abs_tol); + // convert inf float to bf8_t and check if it is qNan + ASSERT_NEAR(type_convert(0x80), + f8_convert_sr(std::numeric_limits::infinity()), + abs_tol); + // positive norm float value to bf8 and back, check if holds + float pos_float = 0.0000762939f; + ASSERT_NEAR(pos_float, type_convert(f8_convert_sr(pos_float)), abs_tol); + // negative norm float value to bf8 and back, check if holds + float neg_float = -0.0000610351f; + ASSERT_NEAR(neg_float, type_convert(f8_convert_sr(neg_float)), abs_tol); + // positive subnorm float value to bf8 and back, check if holds + pos_float = 0.0000305175f; + ASSERT_NEAR(pos_float, type_convert(f8_convert_sr(pos_float)), abs_tol); + // negative subnorm float value to bf8 and back, check if holds + neg_float = -0.0000152587f; + ASSERT_NEAR(neg_float, type_convert(f8_convert_sr(neg_float)), abs_tol); +} + +TEST(BF8, ConvertFP16Nearest) +{ + // fix the tolerance value + float abs_tol = 1e-3; + // convert 0 fp16 to bf8 and back, check if holds + ASSERT_NEAR(half_t{0.0}, type_convert(type_convert(half_t{0.0})), abs_tol); + // convert minimal fp16 to bf8 and back, check if holds + ASSERT_NEAR(ck::NumericLimits::Min(), + type_convert(type_convert(ck::NumericLimits::Min())), + abs_tol); + // convert maximal bf8_t to fp16 and check if equal to 57344.0 + ASSERT_NEAR( + half_t{57344.0}, type_convert(type_convert(half_t{57344.0})), abs_tol); + // convert maximal fp16 to bf8 and back, check if clipped to 57344.0 + ASSERT_NEAR(half_t{57344.0}, + type_convert(type_convert(ck::NumericLimits::Max())), + abs_tol); + // convert QuietNaN fp16 to bf8_t and check if it is QuietNaN + ASSERT_NEAR(type_convert(0x80), + type_convert(ck::NumericLimits::QuietNaN()), + abs_tol); + // positive norm fp16 value to bf8 and back, check if holds + half_t pos_half = half_t{0.0000762939}; + ASSERT_NEAR(pos_half, type_convert(type_convert(pos_half)), abs_tol); + // negative norm fp16 value to bf8 and back, check if holds + half_t neg_half = half_t{-0.0000610351}; + ASSERT_NEAR(neg_half, type_convert(type_convert(neg_half)), abs_tol); + // positive subnorm fp16 value to bf8 and back, check if holds + pos_half = half_t{0.0000305175}; + ASSERT_NEAR(pos_half, type_convert(type_convert(pos_half)), abs_tol); + // negative subnorm fp16 value to bf8 and back, check if holds + neg_half = half_t{-0.0000152587}; + ASSERT_NEAR(neg_half, type_convert(type_convert(neg_half)), abs_tol); +} + +TEST(BF8, ConvertFP16Stochastic) +{ + // fix the tolerance value + float abs_tol = 1e-3; + // convert 0 fp16 to bf8 and back, check if holds + ASSERT_NEAR(half_t{0.0}, type_convert(f8_convert_sr(half_t{0.0})), abs_tol); + // convert minimal fp16 to bf8 and back, check if holds + ASSERT_NEAR(ck::NumericLimits::Min(), + type_convert(f8_convert_sr(ck::NumericLimits::Min())), + abs_tol); + // convert maximal bf8_t to fp16 and check if equal to 57344.0 + ASSERT_NEAR( + half_t{57344.0}, type_convert(f8_convert_sr(half_t{57344.0})), abs_tol); + // convert maximal fp16 to bf8 and back, check if clipped to 57344.0 + ASSERT_NEAR(half_t{57344.0}, + type_convert(f8_convert_sr(ck::NumericLimits::Max())), + abs_tol); + // convert QuietNaN fp16 to bf8_t and check if it is QuietNaN + ASSERT_NEAR(type_convert(0x80), + f8_convert_sr(ck::NumericLimits::QuietNaN()), + abs_tol); + // positive norm fp16 value to bf8 and back, check if holds + half_t pos_half = half_t{0.0000762939}; + ASSERT_NEAR(pos_half, type_convert(f8_convert_sr(pos_half)), abs_tol); + // negative norm fp16 value to bf8 and back, check if holds + half_t neg_half = half_t{-0.0000610351}; + ASSERT_NEAR(neg_half, type_convert(f8_convert_sr(neg_half)), abs_tol); + // positive subnorm fp16 value to bf8 and back, check if holds + pos_half = half_t{0.0000305175}; + ASSERT_NEAR(pos_half, type_convert(f8_convert_sr(pos_half)), abs_tol); + // negative subnorm fp16 value to bf8 and back, check if holds + neg_half = half_t{-0.0000152587}; + ASSERT_NEAR(neg_half, type_convert(f8_convert_sr(neg_half)), abs_tol); +} diff --git a/test/data_type/fp8.cpp b/test/data_type/f8.cpp similarity index 58% rename from test/data_type/fp8.cpp rename to test/data_type/f8.cpp index 5004fe9527..0612a1cf44 100644 --- a/test/data_type/fp8.cpp +++ b/test/data_type/f8.cpp @@ -12,10 +12,11 @@ using ck::type_convert; TEST(FP8, NumericLimits) { - EXPECT_EQ(ck::NumericLimits::Min(), 0x08); - EXPECT_EQ(ck::NumericLimits::Max(), 0x77); - EXPECT_EQ(ck::NumericLimits::Lowest(), 0xF7); - EXPECT_EQ(ck::NumericLimits::QuietNaN(), 0x80); + // constants given for negative zero nan mode + EXPECT_EQ(ck::NumericLimits::Min(), type_convert(0x08)); + EXPECT_EQ(ck::NumericLimits::Max(), type_convert(0x7F)); + EXPECT_EQ(ck::NumericLimits::Lowest(), type_convert(0xFF)); + EXPECT_EQ(ck::NumericLimits::QuietNaN(), type_convert(0x80)); } TEST(FP8, ConvertFP32Nearest) @@ -35,12 +36,20 @@ TEST(FP8, ConvertFP32Nearest) type_convert(type_convert(std::numeric_limits::max())), abs_tol); // convert inf float to f8_t and check if it is qNan - ASSERT_NEAR(0x80, type_convert(std::numeric_limits::infinity()), abs_tol); - // positive float value to fp8 and back, check if holds - float pos_float = 0.0078125f; + ASSERT_NEAR(type_convert(0x80), + type_convert(std::numeric_limits::infinity()), + abs_tol); + // positive norm float value to fp8 and back, check if holds + float pos_float = 0.017578125f; ASSERT_NEAR(pos_float, type_convert(type_convert(pos_float)), abs_tol); - // negative float value to fp8 and back, check if holds - float neg_float = -0.0156250f; + // negative norm float value to fp8 and back, check if holds + float neg_float = -0.015625f; + ASSERT_NEAR(neg_float, type_convert(type_convert(neg_float)), abs_tol); + // positive subnorm float value to fp8 and back, check if holds + pos_float = 0.00390625f; + ASSERT_NEAR(pos_float, type_convert(type_convert(pos_float)), abs_tol); + // negative subnorm float value to fp8 and back, check if holds + neg_float = -0.001953125f; ASSERT_NEAR(neg_float, type_convert(type_convert(neg_float)), abs_tol); } @@ -61,12 +70,20 @@ TEST(FP8, ConvertFP32Stochastic) type_convert(f8_convert_sr(std::numeric_limits::max())), abs_tol); // convert inf float to f8_t and check if it is qNan - ASSERT_NEAR(0x80, f8_convert_sr(std::numeric_limits::infinity()), abs_tol); - // positive float value to fp8 and back, check if holds - float pos_float = 0.0078125f; + ASSERT_NEAR(type_convert(0x80), + f8_convert_sr(std::numeric_limits::infinity()), + abs_tol); + // positive norm float value to fp8 and back, check if holds + float pos_float = 0.017578125f; ASSERT_NEAR(pos_float, type_convert(f8_convert_sr(pos_float)), abs_tol); - // negative float value to fp8 and back, check if holds - float neg_float = -0.0156250f; + // negative norm float value to fp8 and back, check if holds + float neg_float = -0.015625f; + ASSERT_NEAR(neg_float, type_convert(f8_convert_sr(neg_float)), abs_tol); + // positive subnorm float value to fp8 and back, check if holds + pos_float = 0.00390625f; + ASSERT_NEAR(pos_float, type_convert(f8_convert_sr(pos_float)), abs_tol); + // negative subnorm float value to fp8 and back, check if holds + neg_float = -0.001953125f; ASSERT_NEAR(neg_float, type_convert(f8_convert_sr(neg_float)), abs_tol); } @@ -87,12 +104,20 @@ TEST(FP8, ConvertFP16Nearest) type_convert(type_convert(ck::NumericLimits::Max())), abs_tol); // convert QuietNaN fp16 to f8_t and check if it is QuietNaN - ASSERT_NEAR(0x80, type_convert(ck::NumericLimits::QuietNaN()), abs_tol); - // positive fp16 value to fp8 and back, check if holds - half_t pos_half = half_t{0.0078125}; + ASSERT_NEAR(type_convert(0x80), + type_convert(ck::NumericLimits::QuietNaN()), + abs_tol); + // positive norm fp16 value to fp8 and back, check if holds + half_t pos_half = half_t{0.017578125}; ASSERT_NEAR(pos_half, type_convert(type_convert(pos_half)), abs_tol); - // negative fp16 value to fp8 and back, check if holds - half_t neg_half = half_t{-0.0156250}; + // negative norm fp16 value to fp8 and back, check if holds + half_t neg_half = half_t{-0.015625}; + ASSERT_NEAR(neg_half, type_convert(type_convert(neg_half)), abs_tol); + // positive subnorm fp16 value to fp8 and back, check if holds + pos_half = half_t{0.00390625}; + ASSERT_NEAR(pos_half, type_convert(type_convert(pos_half)), abs_tol); + // negative subnorm fp16 value to fp8 and back, check if holds + neg_half = half_t{-0.001953125}; ASSERT_NEAR(neg_half, type_convert(type_convert(neg_half)), abs_tol); } @@ -113,11 +138,19 @@ TEST(FP8, ConvertFP16Stochastic) type_convert(f8_convert_sr(ck::NumericLimits::Max())), abs_tol); // convert QuietNaN fp16 to f8_t and check if it is QuietNaN - ASSERT_NEAR(0x80, f8_convert_sr(ck::NumericLimits::QuietNaN()), abs_tol); - // positive fp16 value to fp8 and back, check if holds - half_t pos_half = half_t{0.0078125}; + ASSERT_NEAR(type_convert(0x80), + f8_convert_sr(ck::NumericLimits::QuietNaN()), + abs_tol); + // positive norm fp16 value to fp8 and back, check if holds + half_t pos_half = half_t{0.017578125}; ASSERT_NEAR(pos_half, type_convert(f8_convert_sr(pos_half)), abs_tol); - // negative fp16 value to fp8 and back, check if holds - half_t neg_half = half_t{-0.0156250}; + // negative norm fp16 value to fp8 and back, check if holds + half_t neg_half = half_t{-0.015625}; + ASSERT_NEAR(neg_half, type_convert(f8_convert_sr(neg_half)), abs_tol); + // positive subnorm fp16 value to fp8 and back, check if holds + pos_half = half_t{0.00390625}; + ASSERT_NEAR(pos_half, type_convert(f8_convert_sr(pos_half)), abs_tol); + // negative subnorm fp16 value to fp8 and back, check if holds + neg_half = half_t{-0.001953125}; ASSERT_NEAR(neg_half, type_convert(f8_convert_sr(neg_half)), abs_tol); }