From 09bc04e7a4892d6be55f0dc822151d5f3c15dbbd Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Mon, 19 Jun 2023 11:20:35 -0500 Subject: [PATCH] FP8 enablement - add a pseudorandom number generator, add conversion methods (#708) * Add basic fp8 definitions and prn-generator * Format * Add fp8<->fp32 type_convert * Format * Split type_convert and cast_to/from_f8 * Format * Minor fix * Minor fix * Move fp8 utils to a separate header * Add elementwise ops * Add fp8_convert_sr * Format * Add element op * Eliminate magic numbers * Split f8_convert_sr in host and device * Format * Add some constexpr * Add a datatype test * Format * Another format * Add fp8<->fp16 tests * Update type_converts * Format * Add fp16 casting functions * Format * Use seed as a runtime arg * Use element location for PRNG * Format * Add fp8<->fp16 to PassThrough element op * Clean up * Merge host and device implementations * Add comments on rounding modes * Remove leftover code * Put type_converts into a separate header * Put random number gen to a separate header * Rearrange f8_utils' namespaces * Refactor type_convert.hpp * Move f8_t definition [ROCm/composable_kernel commit: f0c620c42e753432ded96040abdac1bbae89aab5] --- .../element/unary_element_wise_operation.hpp | 48 ++++ include/ck/utility/common_header.hpp | 1 + include/ck/utility/data_type.hpp | 177 +++---------- include/ck/utility/f8_utils.hpp | 250 ++++++++++++++++++ include/ck/utility/inner_product.hpp | 1 + include/ck/utility/random_gen.hpp | 53 ++++ include/ck/utility/reduction_operator.hpp | 1 + include/ck/utility/type_convert.hpp | 230 ++++++++++++++++ .../ck/library/utility/host_tensor.hpp | 1 + test/data_type/CMakeLists.txt | 3 + test/data_type/fp8.cpp | 123 +++++++++ 11 files changed, 743 insertions(+), 145 deletions(-) create mode 100644 include/ck/utility/f8_utils.hpp create mode 100644 include/ck/utility/random_gen.hpp create mode 100644 include/ck/utility/type_convert.hpp create mode 100644 test/data_type/fp8.cpp diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index c3e7706ef3..4fb061fadb 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -6,6 +6,7 @@ #include "ck/utility/data_type.hpp" #include "ck/utility/math.hpp" #include "ck/utility/math_v2.hpp" +#include "ck/utility/type_convert.hpp" namespace ck { namespace tensor_operation { @@ -81,6 +82,36 @@ struct PassThrough y = x; } #endif + + template <> + __host__ __device__ void operator()(f8_t& y, const f8_t& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(float& y, const f8_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(f8_t& y, const float& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(half_t& y, const f8_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(f8_t& y, const half_t& x) const + { + y = type_convert(x); + } }; struct UnaryConvert @@ -109,6 +140,23 @@ struct ConvertBF16RTN } }; +struct ConvertF8SR +{ + // convert to fp8 using stochastic rounding (SR) + template + __host__ __device__ void operator()(Y& y, const X& x) const + { + // check Y datatype + static_assert(is_same::value, "Data type is not supported by this operation!"); + + // check X datatype + static_assert(is_same::value || is_same::value, + "Data type is not supported by this operation!"); + + y = f8_convert_sr(x); + } +}; + struct Scale { __host__ __device__ Scale(float scale) : scale_(scale) {} diff --git a/include/ck/utility/common_header.hpp b/include/ck/utility/common_header.hpp index 41a9d0b585..f95660a8a4 100644 --- a/include/ck/utility/common_header.hpp +++ b/include/ck/utility/common_header.hpp @@ -24,6 +24,7 @@ #include "ck/utility/tuple.hpp" #include "ck/utility/tuple_helper.hpp" #include "ck/utility/type.hpp" +#include "ck/utility/type_convert.hpp" #include "ck/utility/magic_division.hpp" #include "ck/utility/c_style_pointer_cast.hpp" #include "ck/utility/is_known_at_compile_time.hpp" diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index d43af8a2e3..c240afa2b8 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -12,6 +12,7 @@ using half_t = _Float16; #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 using int4_t = _BitInt(4); #endif +using f8_t = uint8_t; // vector_type template @@ -142,6 +143,13 @@ struct scalar_type }; #endif +template <> +struct scalar_type +{ + using type = f8_t; + static constexpr index_t vector_size = 1; +}; + // template struct vector_type @@ -944,151 +952,13 @@ using int8x16_t = typename vector_type::type; using int8x32_t = typename vector_type::type; using int8x64_t = typename vector_type::type; -// Convert X to Y -template -__host__ __device__ constexpr Y type_convert(X x) -{ - static_assert(!std::is_reference_v && !std::is_reference_v); - - return static_cast(x); -} - -// convert bfp16 to fp32 -template <> -inline __host__ __device__ constexpr float type_convert(bhalf_t x) -{ - union - { - uint32_t int32; - float fp32; - } u = {uint32_t(x) << 16}; - - return u.fp32; -} - -// convert fp32 to bfp16 -template <> -inline __host__ __device__ constexpr bhalf_t type_convert(float x) -{ - union - { - float fp32; - uint32_t int32; - } u = {x}; - - return uint16_t(u.int32 >> 16); -} - -// convert bfp16 to fp16 via fp32 -template <> -inline __host__ __device__ constexpr half_t type_convert(bhalf_t x) -{ - float x_fp32 = type_convert(x); - - return static_cast(x_fp32); -} - -// convert fp16 to bfp16 via fp32 -template <> -inline __host__ __device__ constexpr bhalf_t type_convert(half_t x) -{ - float x_fp32 = static_cast(x); - - return type_convert(x_fp32); -} - -// convert bfp16 to int32 via fp32 -template <> -inline __host__ __device__ constexpr int32_t type_convert(bhalf_t x) -{ - float x_fp32 = type_convert(x); - - return static_cast(x_fp32); -} - -// convert int32 to bfp16 via fp32 -template <> -inline __host__ __device__ constexpr bhalf_t type_convert(int32_t x) -{ - float x_fp32 = static_cast(x); - - return type_convert(x_fp32); -} - -// convert bfp16 to int8 via fp32 -template <> -inline __host__ __device__ constexpr int8_t type_convert(bhalf_t x) -{ - float x_fp32 = type_convert(x); - - return static_cast(x_fp32); -} - -// convert int8 to bfp16 via fp32 -template <> -inline __host__ __device__ constexpr bhalf_t type_convert(int8_t x) -{ - float x_fp32 = static_cast(x); - - return type_convert(x_fp32); -} - -// Declare a template function for bf16 conversion using RTN -template -__host__ __device__ constexpr Y bf16_convert_rtn(X x); - -// Convert fp32 to bf16 with RTN if higher precision is needed -template <> -inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn(float x) -{ - union - { - float fp32; - uint32_t int32; - } u = {x}; - - // When the exponent bits are not all 1s, then the value is zero, normal, - // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus - // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). - // This causes the bfloat16's mantissa to be incremented by 1 if the 16 - // least significant bits of the float mantissa are greater than 0x8000, - // or if they are equal to 0x8000 and the least significant bit of the - // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when - // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already - // has the value 0x7f, then incrementing it causes it to become 0x00 and - // the exponent is incremented by one, which is the next higher FP value - // to the unrounded bfloat16 value. When the bfloat16 value is subnormal - // with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up - // to a normal value with an exponent of 0x01 and a mantissa of 0x00. - // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, - // incrementing it causes it to become an exponent of 0xFF and a mantissa - // of 0x00, which is Inf, the next higher value to the unrounded value. - bool flag0 = ~u.int32 & 0x7f800000; - - // When all of the exponent bits are 1, the value is Inf or NaN. - // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero - // mantissa bit. Quiet NaN is indicated by the most significant mantissa - // bit being 1. Signaling NaN is indicated by the most significant - // mantissa bit being 0 but some other bit(s) being 1. If any of the - // lower 16 bits of the mantissa are 1, we set the least significant bit - // of the bfloat16 mantissa, in order to preserve signaling NaN in case - // the bfloat16's mantissa bits are all 0. - bool flag1 = !flag0 && (u.int32 & 0xffff); - - u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even - u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN - - return uint16_t(u.int32 >> 16); -} - -// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed -template <> -inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn(half_t x) -{ - float x_fp32 = static_cast(x); - - return bf16_convert_rtn(x_fp32); -} +// f8 +using f8x2_t = typename vector_type::type; +using f8x4_t = typename vector_type::type; +using f8x8_t = typename vector_type::type; +using f8x16_t = typename vector_type::type; +using f8x32_t = typename vector_type::type; +using f8x64_t = typename vector_type::type; template struct NumericLimits @@ -1136,4 +1006,21 @@ struct NumericLimits }; #endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min = 0x08; // 0b00001000 + static constexpr uint8_t binary_max = 0x77; // 0b01110111 + static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111 + static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 + + __host__ __device__ static constexpr f8_t Min() { return bit_cast(binary_min); } + + __host__ __device__ static constexpr f8_t Max() { return bit_cast(binary_max); } + + __host__ __device__ static constexpr f8_t Lowest() { return bit_cast(binary_lowest); } + + __host__ __device__ static constexpr f8_t QuietNaN() { return bit_cast(binary_qnan); } +}; + } // namespace ck diff --git a/include/ck/utility/f8_utils.hpp b/include/ck/utility/f8_utils.hpp new file mode 100644 index 0000000000..bb13f98154 --- /dev/null +++ b/include/ck/utility/f8_utils.hpp @@ -0,0 +1,250 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" + +namespace ck { + +// fp8 rounding modes +// use standard for rounding to nearest, the faster one +// use stochastic for stochastic rounding, helps to avoid error accumulation +enum class f8_rounding_mode +{ + standard, + stochastic +}; + +} // namespace ck + +namespace ck::utils { + +namespace { + +template +__host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) +{ + // check data type + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; + + // fp8 exponent/mantissa layout + constexpr int f8_exp = 4; + constexpr int f8_mant = 3; + + // resulting type exponent/mantissa layout + constexpr int type_exp = is_half ? 5 : 8; + constexpr int type_mant = is_half ? 10 : 23; + + int exponent; + uint32_t head, mantissa, sign; + // nan code is same for float and half + constexpr uint8_t nan_code = 0x80; + constexpr uint32_t nan_mask = is_half ? 0x7C00 : 0x7F800000; + + // convert to bitwise + typedef typename std::conditional::value, uint16_t, uint32_t>::type + T_bitwise; + T_bitwise x_bitwise = *(reinterpret_cast(&x)); + + // unpack the input, depends on datatype + if constexpr(is_float) + { + head = x_bitwise & 0xFF800000; + mantissa = x_bitwise & 0x7FFFFF; + exponent = (head >> type_mant) & 0xFF; + sign = head >> (type_exp + type_mant); + } + else if constexpr(is_half) + { + head = x_bitwise & 0xFC00; + mantissa = x_bitwise & 0x3FF; + exponent = (head >> type_mant) & 0x1F; + sign = head >> (type_exp + type_mant); + } + + uint32_t signed_inf = (sign << (type_exp + type_mant)) + (((1 << type_exp) - 1) << type_mant); + uint32_t drop_mask = (1 << (type_mant - f8_mant)) - 1; + constexpr int max_exp = (1 << f8_exp) - (negative_zero_nan ? 1 : 2); + constexpr int exp_low_cutoff = + (1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); + + if constexpr(negative_zero_nan) + { + if((x_bitwise & nan_mask) == nan_mask) + return nan_code; + } + else + { + if((x_bitwise & nan_mask) == nan_mask) + return signed_inf + (mantissa != 0 ? 1 : 0); + } + + // check if x is 0.0 + if(x_bitwise == 0) + return 0; + + exponent -= exp_low_cutoff - 1; + if(exponent <= 0) + drop_mask = (1 << (type_mant - f8_mant + 1 - exponent)) - 1; + mantissa += 1 << type_mant; + // apply random number if needed + mantissa += (stoch ? rng : mantissa) & drop_mask; + if(mantissa >= (2 << type_mant)) + { + mantissa >>= 1; + exponent++; + } + mantissa >>= (type_mant - f8_mant); + + // check negative exponent + if(exponent <= 0) + { + if(x_bitwise == 0) + return 0; + else + { + // subnormal range; represented by a subnormal float8 (exponent 0) + // and involves loss of accuracy + mantissa >>= 1 - exponent; + exponent = 0; + } + } + // above range: quantize to maximum possible float of the same sign + else if(exponent > max_exp) + { + if(clip) + { + mantissa = (1 << f8_mant) - 1; + exponent = max_exp; + } + else + { + return signed_inf; + } + } + + // check if x is 0.0 or -0.0 + if(exponent == 0 && mantissa == 0) + return negative_zero_nan ? 0 : (sign << (f8_exp + f8_mant)); + mantissa &= (1 << f8_mant) - 1; + return (sign << (f8_exp + f8_mant)) | (exponent << f8_mant) | mantissa; +} + +template +__host__ __device__ T run_cast_from_f8(f8_t x) +{ + // check data type + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; + + // fp8 exponent/mantissa layout + constexpr int f8_exp = 4; + constexpr int f8_mant = 3; + + // resulting type exponent/mantissa layout + constexpr int type_exp = is_half ? 5 : 8; + constexpr int type_mant = is_half ? 10 : 23; + + // prepare the codes + constexpr uint8_t nan_code = 0x80; + T fInf, fNegInf, fNaN, fNeg0; + if constexpr(is_half) + { + constexpr uint16_t ihInf = 0x7C00; + constexpr uint16_t ihNegInf = 0xFC00; + constexpr uint16_t ihNaN = 0x7C01; + constexpr uint16_t ihNeg0 = 0x8000; + fInf = *(reinterpret_cast(&ihInf)); + fNegInf = *(reinterpret_cast(&ihNegInf)); + fNaN = *(reinterpret_cast(&ihNaN)); + fNeg0 = *(reinterpret_cast(&ihNeg0)); + } + else if constexpr(is_float) + { + constexpr uint32_t ifInf = 0x7F800000; + constexpr uint32_t ifNegInf = 0xFF800000; + constexpr uint32_t ifNaN = 0x7F800001; + constexpr uint32_t ifNeg0 = 0x80000000; + fInf = *(reinterpret_cast(&ifInf)); + fNegInf = *(reinterpret_cast(&ifNegInf)); + fNaN = *(reinterpret_cast(&ifNaN)); + fNeg0 = *(reinterpret_cast(&ifNeg0)); + } + + // unpack the input + uint32_t sign = x >> (f8_exp + f8_mant); + uint32_t mantissa = x & ((1 << f8_mant) - 1); + int exponent = (x & 0x7F) >> f8_mant; + + constexpr int exp_low_cutoff = + (1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); + typename std::conditional::value, uint16_t, uint32_t>::type retval; + + if constexpr(negative_zero_nan) + { + if(x == nan_code) + return fNaN; + } + else + { + if(x == nan_code) + return fNeg0; + if(exponent == ((1 << f8_exp) - 1)) + return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; + } + + // subnormal input + if(exponent == 0) + { + // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above + int sh = 1 + __builtin_clz(mantissa) - ((1 + type_exp + type_mant) - f8_mant); + mantissa <<= sh; + mantissa &= ((1 << f8_mant) - 1); + exponent += 1 - sh; + } + exponent += exp_low_cutoff - 1; + mantissa <<= type_mant - f8_mant; + + // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) + if(exponent <= 0) + { + mantissa |= 1 << type_mant; + mantissa >>= 1 - exponent; + exponent = 0; + } + + retval = (sign << (type_exp + type_mant)) | (exponent << type_mant) | mantissa; + return *(reinterpret_cast(&retval)); +} + +} // namespace + +template +__host__ __device__ f8_t cast_to_f8(T x, uint32_t rng) +{ + // check datatype + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; + static_assert(is_half || is_float, "Only half and float can be casted to f8."); + + return run_cast_to_f8(x, rng); +} + +template +__host__ __device__ T cast_from_f8(f8_t x) +{ + // check datatype + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; + static_assert(is_half || is_float, "only half and float are supported."); + + // check if x is 0.0 + if(x == 0) + return static_cast(0); + + return run_cast_from_f8(x); +} + +} // namespace ck::utils diff --git a/include/ck/utility/inner_product.hpp b/include/ck/utility/inner_product.hpp index 7828d21d7f..b13bccb5ab 100644 --- a/include/ck/utility/inner_product.hpp +++ b/include/ck/utility/inner_product.hpp @@ -3,6 +3,7 @@ #pragma once #include "data_type.hpp" +#include "type_convert.hpp" namespace ck { diff --git a/include/ck/utility/random_gen.hpp b/include/ck/utility/random_gen.hpp new file mode 100644 index 0000000000..b7edf26507 --- /dev/null +++ b/include/ck/utility/random_gen.hpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { + +// Pseudo random number generator +// version for fp32 +template {}, bool> = false> +__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t) +{ + uint32_t x = *(reinterpret_cast(&val)); + uint32_t drop_bits = uint32_t(x) & 0xFFFFu; + drop_bits ^= x >> 16; + drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5); + drop_bits *= 0x7000149; + // NOTE: If id is in 64 bit, we are only using lower 32 bit. + // So, it can have an effect of using same id for multiple elements when the id is very + // large! + uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed); + return rng; +} + +// version for fp16 +template {}, bool> = false> +__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t) +{ + uint16_t x = *(reinterpret_cast(&val)); + uint32_t drop_bits = uint32_t(x) & 0xFFFFu; + drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5); + drop_bits *= 0x7000149; + // NOTE: If id is in 64 bit, we are only using lower 32 bit. + // So, it can have an effect of using same id for multiple elements when the id is very + // large! + uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed); + return rng; +} + +// return 0 if data is not fp16 or fp32 +template {} || std::is_same{}), bool> = false> +__host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t) +{ + std::ignore = id; + std::ignore = val; + std::ignore = seed; + + return 0; +} + +} // namespace ck diff --git a/include/ck/utility/reduction_operator.hpp b/include/ck/utility/reduction_operator.hpp index 0f5b73cb03..36c25203ea 100644 --- a/include/ck/utility/reduction_operator.hpp +++ b/include/ck/utility/reduction_operator.hpp @@ -6,6 +6,7 @@ #include "ck/ck.hpp" #include "ck/utility/data_type.hpp" #include "ck/utility/type.hpp" +#include "ck/utility/type_convert.hpp" namespace ck { diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp new file mode 100644 index 0000000000..ed83964936 --- /dev/null +++ b/include/ck/utility/type_convert.hpp @@ -0,0 +1,230 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/f8_utils.hpp" +#include "ck/utility/random_gen.hpp" + +namespace ck { + +// Convert X to Y +template +__host__ __device__ constexpr Y type_convert(X x) +{ + static_assert(!std::is_reference_v && !std::is_reference_v); + + return static_cast(x); +} + +// convert bfp16 to fp32 +template <> +inline __host__ __device__ constexpr float type_convert(bhalf_t x) +{ + union + { + uint32_t int32; + float fp32; + } u = {uint32_t(x) << 16}; + + return u.fp32; +} + +// convert fp32 to bfp16 +template <> +inline __host__ __device__ constexpr bhalf_t type_convert(float x) +{ + union + { + float fp32; + uint32_t int32; + } u = {x}; + + return uint16_t(u.int32 >> 16); +} + +// convert bfp16 to fp16 via fp32 +template <> +inline __host__ __device__ constexpr half_t type_convert(bhalf_t x) +{ + float x_fp32 = type_convert(x); + + return static_cast(x_fp32); +} + +// convert fp16 to bfp16 via fp32 +template <> +inline __host__ __device__ constexpr bhalf_t type_convert(half_t x) +{ + float x_fp32 = static_cast(x); + + return type_convert(x_fp32); +} + +// convert bfp16 to int32 via fp32 +template <> +inline __host__ __device__ constexpr int32_t type_convert(bhalf_t x) +{ + float x_fp32 = type_convert(x); + + return static_cast(x_fp32); +} + +// convert int32 to bfp16 via fp32 +template <> +inline __host__ __device__ constexpr bhalf_t type_convert(int32_t x) +{ + float x_fp32 = static_cast(x); + + return type_convert(x_fp32); +} + +// convert bfp16 to int8 via fp32 +template <> +inline __host__ __device__ constexpr int8_t type_convert(bhalf_t x) +{ + float x_fp32 = type_convert(x); + + return static_cast(x_fp32); +} + +// convert int8 to bfp16 via fp32 +template <> +inline __host__ __device__ constexpr bhalf_t type_convert(int8_t x) +{ + float x_fp32 = static_cast(x); + + return type_convert(x_fp32); +} + +// convert fp32 to fp8 +template <> +inline __host__ __device__ f8_t type_convert(float x) +{ + constexpr bool negative_zero_nan = true; + constexpr bool clip = true; + constexpr f8_rounding_mode rm = f8_rounding_mode::standard; + constexpr uint32_t rng = 0; + return utils::cast_to_f8( + x, rng); +} + +// convert fp8 to fp32 +template <> +inline __host__ __device__ float type_convert(f8_t x) +{ + constexpr bool negative_zero_nan = true; + return utils::cast_from_f8(x); +} + +// convert fp16 to fp8 +template <> +inline __host__ __device__ f8_t type_convert(half_t x) +{ + constexpr bool negative_zero_nan = true; + constexpr bool clip = true; + constexpr f8_rounding_mode rm = f8_rounding_mode::standard; + constexpr uint32_t rng = 0; + return utils::cast_to_f8( + x, rng); +} + +// convert fp8 to fp16 +template <> +inline __host__ __device__ half_t type_convert(f8_t x) +{ + constexpr bool negative_zero_nan = true; + return utils::cast_from_f8(x); +} + +// Declare a template function for bf16 conversion using RTN +template +__host__ __device__ constexpr Y bf16_convert_rtn(X x); + +// Convert fp32 to bf16 with RTN if higher precision is needed +template <> +inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn(float x) +{ + union + { + float fp32; + uint32_t int32; + } u = {x}; + + // When the exponent bits are not all 1s, then the value is zero, normal, + // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus + // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). + // This causes the bfloat16's mantissa to be incremented by 1 if the 16 + // least significant bits of the float mantissa are greater than 0x8000, + // or if they are equal to 0x8000 and the least significant bit of the + // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when + // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already + // has the value 0x7f, then incrementing it causes it to become 0x00 and + // the exponent is incremented by one, which is the next higher FP value + // to the unrounded bfloat16 value. When the bfloat16 value is subnormal + // with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up + // to a normal value with an exponent of 0x01 and a mantissa of 0x00. + // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, + // incrementing it causes it to become an exponent of 0xFF and a mantissa + // of 0x00, which is Inf, the next higher value to the unrounded value. + bool flag0 = ~u.int32 & 0x7f800000; + + // When all of the exponent bits are 1, the value is Inf or NaN. + // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero + // mantissa bit. Quiet NaN is indicated by the most significant mantissa + // bit being 1. Signaling NaN is indicated by the most significant + // mantissa bit being 0 but some other bit(s) being 1. If any of the + // lower 16 bits of the mantissa are 1, we set the least significant bit + // of the bfloat16 mantissa, in order to preserve signaling NaN in case + // the bfloat16's mantissa bits are all 0. + bool flag1 = !flag0 && (u.int32 & 0xffff); + + u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even + u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN + + return uint16_t(u.int32 >> 16); +} + +// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed +template <> +inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn(half_t x) +{ + float x_fp32 = static_cast(x); + + return bf16_convert_rtn(x_fp32); +} + +// Declare a template function for fp8 conversion using SR +template +__host__ __device__ constexpr Y f8_convert_sr(X x); + +// convert fp32 to fp8 with stochastic rounding +template <> +inline __host__ __device__ f8_t f8_convert_sr(float x) +{ + constexpr bool negative_zero_nan = true; + constexpr bool clip = true; + constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; + constexpr int seed = 42; + // as thread id is not available on host, use 0 for prn generation + uint32_t rng = prand_generator(reinterpret_cast(&x), x); + return utils::cast_to_f8( + x, rng); +} + +// convert fp16 to fp8 with stochastic rounding +template <> +inline __host__ __device__ f8_t f8_convert_sr(half_t x) +{ + constexpr bool negative_zero_nan = true; + constexpr bool clip = true; + constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; + constexpr int seed = 42; + // as thread id is not available on host, use 0 for prn generation + uint32_t rng = prand_generator(reinterpret_cast(&x), x); + return utils::cast_to_f8( + x, rng); +} + +} // namespace ck diff --git a/library/include/ck/library/utility/host_tensor.hpp b/library/include/ck/library/utility/host_tensor.hpp index 91293d29f7..816d834130 100644 --- a/library/include/ck/library/utility/host_tensor.hpp +++ b/library/include/ck/library/utility/host_tensor.hpp @@ -13,6 +13,7 @@ #include "ck/utility/data_type.hpp" #include "ck/utility/span.hpp" +#include "ck/utility/type_convert.hpp" #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/ranges.hpp" diff --git a/test/data_type/CMakeLists.txt b/test/data_type/CMakeLists.txt index 088fbfec71..2b63727f19 100644 --- a/test/data_type/CMakeLists.txt +++ b/test/data_type/CMakeLists.txt @@ -2,3 +2,6 @@ if (USE_BITINT_EXTENSION_INT4) add_gtest_executable(test_int4 int4.cpp) target_link_libraries(test_int4 PRIVATE utility) endif() + +add_gtest_executable(test_fp8 fp8.cpp) +target_link_libraries(test_fp8 PRIVATE utility) diff --git a/test/data_type/fp8.cpp b/test/data_type/fp8.cpp new file mode 100644 index 0000000000..5004fe9527 --- /dev/null +++ b/test/data_type/fp8.cpp @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/utility/data_type.hpp" +#include "ck/utility/type_convert.hpp" + +using ck::f8_convert_sr; +using ck::f8_t; +using ck::half_t; +using ck::type_convert; + +TEST(FP8, NumericLimits) +{ + EXPECT_EQ(ck::NumericLimits::Min(), 0x08); + EXPECT_EQ(ck::NumericLimits::Max(), 0x77); + EXPECT_EQ(ck::NumericLimits::Lowest(), 0xF7); + EXPECT_EQ(ck::NumericLimits::QuietNaN(), 0x80); +} + +TEST(FP8, ConvertFP32Nearest) +{ + // fix the tolerance value + float abs_tol = 1e-6; + // convert 0 float to fp8 and back, check if holds + ASSERT_NEAR(0.0f, type_convert(type_convert(0.0f)), abs_tol); + // convert minimal float to fp8 and back, check if holds + ASSERT_NEAR(std::numeric_limits::min(), + type_convert(type_convert(std::numeric_limits::min())), + abs_tol); + // convert maximal f8_t to float and check if equal to 240.0 + ASSERT_NEAR(240.0f, type_convert(type_convert(240.0f)), abs_tol); + // convert maximal float to fp8 and back, check if clipped to 240.0 + ASSERT_NEAR(240.0f, + type_convert(type_convert(std::numeric_limits::max())), + abs_tol); + // convert inf float to f8_t and check if it is qNan + ASSERT_NEAR(0x80, type_convert(std::numeric_limits::infinity()), abs_tol); + // positive float value to fp8 and back, check if holds + float pos_float = 0.0078125f; + ASSERT_NEAR(pos_float, type_convert(type_convert(pos_float)), abs_tol); + // negative float value to fp8 and back, check if holds + float neg_float = -0.0156250f; + ASSERT_NEAR(neg_float, type_convert(type_convert(neg_float)), abs_tol); +} + +TEST(FP8, ConvertFP32Stochastic) +{ + // fix the tolerance value + float abs_tol = 1e-6; + // convert 0 float to fp8 and back, check if holds + ASSERT_NEAR(0.0f, type_convert(f8_convert_sr(0.0f)), abs_tol); + // convert minimal float to fp8 and back, check if holds + ASSERT_NEAR(std::numeric_limits::min(), + type_convert(f8_convert_sr(std::numeric_limits::min())), + abs_tol); + // convert maximal f8_t to float and check if equal to 240.0 + ASSERT_NEAR(240.0f, type_convert(f8_convert_sr(240.0f)), abs_tol); + // convert maximal float to fp8 and back, check if clipped to 240.0 + ASSERT_NEAR(240.0f, + type_convert(f8_convert_sr(std::numeric_limits::max())), + abs_tol); + // convert inf float to f8_t and check if it is qNan + ASSERT_NEAR(0x80, f8_convert_sr(std::numeric_limits::infinity()), abs_tol); + // positive float value to fp8 and back, check if holds + float pos_float = 0.0078125f; + ASSERT_NEAR(pos_float, type_convert(f8_convert_sr(pos_float)), abs_tol); + // negative float value to fp8 and back, check if holds + float neg_float = -0.0156250f; + ASSERT_NEAR(neg_float, type_convert(f8_convert_sr(neg_float)), abs_tol); +} + +TEST(FP8, ConvertFP16Nearest) +{ + // fix the tolerance value + float abs_tol = 1e-3; + // convert 0 fp16 to fp8 and back, check if holds + ASSERT_NEAR(half_t{0.0}, type_convert(type_convert(half_t{0.0})), abs_tol); + // convert minimal fp16 to fp8 and back, check if holds + ASSERT_NEAR(ck::NumericLimits::Min(), + type_convert(type_convert(ck::NumericLimits::Min())), + abs_tol); + // convert maximal f8_t to fp16 and check if equal to 240.0 + ASSERT_NEAR(half_t{240.0}, type_convert(type_convert(half_t{240.0})), abs_tol); + // convert maximal fp16 to fp8 and back, check if clipped to 240.0 + ASSERT_NEAR(half_t{240.0}, + type_convert(type_convert(ck::NumericLimits::Max())), + abs_tol); + // convert QuietNaN fp16 to f8_t and check if it is QuietNaN + ASSERT_NEAR(0x80, type_convert(ck::NumericLimits::QuietNaN()), abs_tol); + // positive fp16 value to fp8 and back, check if holds + half_t pos_half = half_t{0.0078125}; + ASSERT_NEAR(pos_half, type_convert(type_convert(pos_half)), abs_tol); + // negative fp16 value to fp8 and back, check if holds + half_t neg_half = half_t{-0.0156250}; + ASSERT_NEAR(neg_half, type_convert(type_convert(neg_half)), abs_tol); +} + +TEST(FP8, ConvertFP16Stochastic) +{ + // fix the tolerance value + float abs_tol = 1e-3; + // convert 0 fp16 to fp8 and back, check if holds + ASSERT_NEAR(half_t{0.0}, type_convert(f8_convert_sr(half_t{0.0})), abs_tol); + // convert minimal fp16 to fp8 and back, check if holds + ASSERT_NEAR(ck::NumericLimits::Min(), + type_convert(f8_convert_sr(ck::NumericLimits::Min())), + abs_tol); + // convert maximal f8_t to fp16 and check if equal to 240.0 + ASSERT_NEAR(half_t{240.0}, type_convert(f8_convert_sr(half_t{240.0})), abs_tol); + // convert maximal fp16 to fp8 and back, check if clipped to 240.0 + ASSERT_NEAR(half_t{240.0}, + type_convert(f8_convert_sr(ck::NumericLimits::Max())), + abs_tol); + // convert QuietNaN fp16 to f8_t and check if it is QuietNaN + ASSERT_NEAR(0x80, f8_convert_sr(ck::NumericLimits::QuietNaN()), abs_tol); + // positive fp16 value to fp8 and back, check if holds + half_t pos_half = half_t{0.0078125}; + ASSERT_NEAR(pos_half, type_convert(f8_convert_sr(pos_half)), abs_tol); + // negative fp16 value to fp8 and back, check if holds + half_t neg_half = half_t{-0.0156250}; + ASSERT_NEAR(neg_half, type_convert(f8_convert_sr(neg_half)), abs_tol); +}