From f9938dfdc60d38ebd170656987f8d810777d2c3c Mon Sep 17 00:00:00 2001 From: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> Date: Mon, 27 Jan 2025 11:59:49 -0700 Subject: [PATCH] Add OCP FP8 support in CK_TILE (#1829) * Add OCP FP8 to CK_TILE * Validate OCP FP8 in FMHA FWD under VALID=1 [ROCm/composable_kernel commit: 35aebe5936431a950022cce65b40c78cf3a06cf3] --- example/ck_tile/01_fmha/CMakeLists.txt | 5 + include/ck_tile/core/config.hpp | 18 +- include/ck_tile/core/numeric/float8.hpp | 1225 +++++++++++++--------- include/ck_tile/core/numeric/half.hpp | 11 +- include/ck_tile/core/numeric/numeric.hpp | 3 +- 5 files changed, 767 insertions(+), 495 deletions(-) diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 1ba76a523e..9ba3a453fc 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -102,6 +102,11 @@ else() list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0) endif() +# conditionally specify the use of OCP_FP8 +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() + # Allow comparing floating points directly in order to check sentinel values list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal) list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal) diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index a15d2c0402..4c495ba781 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -1,13 +1,13 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) + defined(__gfx942__) || defined(__gfx950__) #define __gfx9__ #endif -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__) #define __gfx94__ #endif #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \ @@ -230,3 +230,15 @@ #ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID #define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1 #endif + +#ifndef __HIP_DEVICE_COMPILE__ // for host code +#ifdef CK_TILE_USE_OCP_FP8 +#define CK_TILE_USE_OCP_FP8 1 +#else +#define CK_TILE_USE_OCP_FP8 0 +#endif +#elif defined(__gfx950__) || defined(__gfx12__) // for GPU code +#define CK_TILE_USE_OCP_FP8 1 +#else // for GPU code +#define CK_TILE_USE_OCP_FP8 0 +#endif diff --git a/include/ck_tile/core/numeric/float8.hpp b/include/ck_tile/core/numeric/float8.hpp index b3b1a1f3fb..c4fc6890c6 100644 --- a/include/ck_tile/core/numeric/float8.hpp +++ b/include/ck_tile/core/numeric/float8.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck_tile/core/config.hpp" #include "ck_tile/core/utility/bit_cast.hpp" @@ -14,6 +14,12 @@ #pragma once +#if(defined(__gfx94__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__ +#define CK_TILE_FP8_CVT_DEVICE 1 +#else +#define CK_TILE_FP8_CVT_DEVICE 0 +#endif + namespace ck_tile { // fp8 rounding modes @@ -25,15 +31,26 @@ enum class fp8_rounding_mode stochastic }; +/** + * \brief FP8 interpretation used in conversion algorithms + */ +enum class fp8_interpretation +{ + E4M3_OCP = 0, // OCP FP8 E4M3 + E5M2_OCP = 1, // OCP BF8 E5M2 + E4M3_FNUZ = 2, // FNUZ FP8 E4M3 + E5M2_FNUZ = 3, // FNUZ BF8 E5M2 +}; + /* - * ______________NANOO_________________ | ______________IEEE________________ + * ______________FNUZ_________________ | ______________OCP________________ * e4m3 e5m2 | e4m3 e5m2 * bias : 8 16 | 7 15 * inf : 1.0000.000 1.00000.00 | N/A s.11111.00 * Nan : 1.0000.000 1.00000.00 | s.1111.111 s.11111.{01, 10, 11} * zero : 0.0000.000 0.00000.00 | s.0000.000 s.00000.00 * Max(norm) : s.1111.111 (240) s.11111.11(57344) | s.1111.110(448) s.11110.11(57344) - * Max(snorm): s.0000.111 s.00000.11 | s.0000.111(448) s.00000.11(57344) + * Max(snorm): s.0000.111 s.00000.11 | s.0000.111 s.00000.11 * 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05 * Min(norm) : s.0001.000 s.00001.00 | s.0001.000 s.00001.00 * 2^-7(0.00078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05) @@ -55,10 +72,10 @@ struct alignas(1) float8_e4m3_t { static constexpr int exponent = 4; static constexpr int mantissa = 3; -#if defined(__gfx94__) - static constexpr int bias = 1 << (exponent - 1); // NANOO +#if CK_TILE_USE_OCP_FP8 + static constexpr int bias = 7; // OCP #else - static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE + static constexpr int bias = 8; // FNUZ #endif using raw_type = uint8_t; raw_type data; @@ -113,10 +130,10 @@ struct alignas(1) float8_e5m2_t { static constexpr int exponent = 5; static constexpr int mantissa = 2; -#if defined(__gfx94__) - static constexpr int bias = 1 << (exponent - 1); // NANOO +#if CK_TILE_USE_OCP_FP8 + static constexpr int bias = 15; // OCP #else - static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE + static constexpr int bias = 16; // FNUZ #endif using raw_type = uint8_t; raw_type data; @@ -183,501 +200,727 @@ struct native_t }; #else + using fp8_t = _BitInt(8); using fp8_raw_t = uint8_t; using bf8_t = unsigned _BitInt(8); using bf8_raw_t = uint8_t; #endif -// below is sw fp8 conversion, not utilizing hw instruction -namespace impl { - -template -CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng) -{ - // fp8/bf8 exponent/mantissa layout - constexpr int out_exp = numeric_traits::exp; - constexpr int out_mant = numeric_traits::mant; - - // original type exponent/mantissa layout - constexpr int in_exp = numeric_traits::exp; - constexpr int in_mant = numeric_traits::mant; - - int exponent, bias; - uint32_t head, mantissa, sign; - // nan code is same for float and half -#if CK_TILE_USE_CUSTOM_DATA_TYPE - constexpr Y nan_code = - numeric::quiet_NaN(); // __builtin_bit_cast(Y, static_cast(0x80)); -#else - constexpr Y nan_code = 0x80; -#endif - - constexpr uint32_t nan_mask = numeric_traits::nan_mask; - - // convert to bitwise - using T_bitwise = typename numeric_traits::bitwise_type; - T_bitwise x_bitwise = *(reinterpret_cast(&x)); - - // unpack the input, depends on datatype - head = x_bitwise & numeric_traits::head_mask; - mantissa = x_bitwise & numeric_traits::mant_mask; - exponent = (head >> in_mant) & numeric_traits::exp_mask; - sign = head >> (in_exp + in_mant); - bias = numeric_traits::bias; - - uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant); - uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1; - constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2); - - if constexpr(negative_zero_nan) - { - if((x_bitwise & nan_mask) == nan_mask) - return nan_code; - } - else - { - if((x_bitwise & nan_mask) == nan_mask) - return signed_inf + (mantissa != 0 ? 1 : 0); - } - - // check if x is 0.0 - if(x_bitwise == 0) - return __builtin_bit_cast(Y, static_cast(0)); - - // First need to check if it is normal or denorm as there is a difference of implict 1 - // Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift - // The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for - // RNE, no need to add rng. Then probably need to check whether there is carry and adjust - // exponent and mantissa again3 - - // For IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of exponent bits - const int out_bias = (1 << (out_exp - 1)) - 1 + (negative_zero_nan ? 1 : 0); - const int out_denormal_act_exponent = 1 - out_bias; // actual exponent of f8 denormal - // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) - // out_exponent is the converted f8 exponent with bias encoding - // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, - // the difference needs to be adjusted and mantissa shifted - int act_exponent, out_exponent, exponent_diff; - - if(exponent == 0) - { // fp32/fp16 is in denormal. - /* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16 -here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has -exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in -fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers -where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal. -In this case, the fp16 mantissa should be shift left by 1 */ - act_exponent = exponent - bias + 1; - exponent_diff = out_denormal_act_exponent - - act_exponent; // actual exponent is exponent-bias+1 as it is denormal - } - else - { // fp32/fp16 is normal with implicit 1 - act_exponent = exponent - bias; - if(act_exponent <= out_denormal_act_exponent) - { - /* This is the case where fp32/fp16 is normal but it is in f8 denormal range. - For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16 - actual exponent is -7, it is actually larger due to the implict 1, - Therefore it needs to be adjust to -6 and mantissa shift right by 1. - So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ - exponent_diff = out_denormal_act_exponent - act_exponent; - } - else - { // both fp32/fp16 and f8 are in normal range - exponent_diff = - 0; // exponent_diff=0 does not mean there is no difference for this case, - // act_exponent could be larger. Just that it does not need shift mantissa - } - mantissa += (1 << in_mant); // Add the implicit 1 into mantissa - } - - bool midpoint = (mantissa & ((1 << (in_mant - out_mant + exponent_diff)) - 1)) == - (1 << (in_mant - out_mant + exponent_diff - 1)); - /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we - shift right as shift right could rip off some residual part and make something not midpoint look - like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than - midpoint, but after shift right by 4 bits, it would look like midpoint. */ - - if(exponent_diff > 0) - mantissa >>= exponent_diff; - else if(exponent_diff == -1) - mantissa <<= -exponent_diff; - bool implicit_one = mantissa & (1 << in_mant); - // if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent - out_exponent = - (act_exponent + exponent_diff) /*actual f8 exponent*/ + out_bias - (implicit_one ? 0 : 1); - - // Now we have the exponent and mantissa adjusted - bool odd = - mantissa & - (1 << (in_mant - out_mant)); // if the least significant bit that is not truncated is 1 - mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask; - - // Now we deal with overflow - if(out_exponent == 0) - { - if((1 << in_mant) & mantissa) - { - out_exponent = 1; // denormal overflow to become normal, promote exponent - // No need to make 1 implicit now as it will be addressed later - } - } - else - { - if((1 << (in_mant + 1)) & mantissa) - { - mantissa >>= 1; - out_exponent++; - // No need to make 1 implicit now as it will be addressed later - } - } - - mantissa >>= (in_mant - out_mant); - - if(out_exponent > max_exp) - { - if(clip) - { - mantissa = (1 << out_mant) - 1; - out_exponent = max_exp; - } - else - { - return __builtin_bit_cast(Y, static_cast(signed_inf)); - } - } - - // check if x is 0.0 or -0.0 - if(out_exponent == 0 && mantissa == 0) - return __builtin_bit_cast( - Y, static_cast(negative_zero_nan ? 0 : (sign << (out_exp + out_mant)))); - mantissa &= (1 << out_mant) - 1; - return __builtin_bit_cast(Y, - static_cast((sign << (out_exp + out_mant)) | - (out_exponent << out_mant) | mantissa)); -} - -template -CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x) -{ - // fp8/bf8 exponent/mantissa layout - constexpr int in_exp = numeric_traits::exp; - constexpr int in_mant = numeric_traits::mant; - - // resulting type exponent/mantissa layout - constexpr int out_exp = numeric_traits::exp; - constexpr int out_mant = numeric_traits::mant; - uint8_t x_raw = __builtin_bit_cast(uint8_t, x); - - // prepare the codes - constexpr uint8_t nan_code = 0x80; - Y Inf, NegInf, NaN, Neg0; - using T_bitwise = typename numeric_traits::bitwise_type; - - constexpr T_bitwise Inf_bitwise = numeric_traits::Inf; - constexpr T_bitwise NegInf_bitwise = numeric_traits::NegInf; - constexpr T_bitwise NaN_bitwise = numeric_traits::NaN; - constexpr T_bitwise Neg0_bitwise = numeric_traits::Neg0; - - Inf = *(reinterpret_cast(&Inf_bitwise)); - NegInf = *(reinterpret_cast(&NegInf_bitwise)); - NaN = *(reinterpret_cast(&NaN_bitwise)); - Neg0 = *(reinterpret_cast(&Neg0_bitwise)); - - // check if x is 0.0 - if(x_raw == 0) - return static_cast(0); - - // unpack the input - uint32_t sign = x_raw >> (in_exp + in_mant); - uint32_t mantissa = x_raw & ((1 << in_mant) - 1); - int exponent = (x_raw & 0x7F) >> in_mant; - - constexpr int exp_low_cutoff = - (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); - T_bitwise retval; - - if constexpr(negative_zero_nan) - { - if(x_raw == nan_code) - return NaN; - } - else - { - if(x_raw == nan_code) - return Neg0; - if(exponent == ((1 << in_exp) - 1)) - return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN; - } - - if((numeric_traits::mant == 10) && (numeric_traits::mant == 2) && !negative_zero_nan) - { - retval = x_raw; - retval <<= 8; - return *(reinterpret_cast(&retval)); - } - - // subnormal input - if(exponent == 0) - { - // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above - int sh = 1 + clz(mantissa) - (32 - in_mant); - mantissa <<= sh; - exponent += 1 - sh; - mantissa &= ((1 << in_mant) - 1); - } - exponent += exp_low_cutoff - 1; - mantissa <<= out_mant - in_mant; - - // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) - if(exponent <= 0) - { - mantissa |= 1 << out_mant; - mantissa >>= 1 - exponent; - exponent = 0; - } - - retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa; - return *(reinterpret_cast(&retval)); -} - -template -CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng) -{ - // check datatypes - constexpr bool is_half = std::is_same::value; - constexpr bool is_float = std::is_same::value; - static_assert(is_half || is_float, "Only half and float can be casted."); - - return run_cast_to_f8(x, rng); -} - -template -CK_TILE_HOST_DEVICE Y cast_from_f8(X x) -{ - // check datatype - constexpr bool is_half = std::is_same::value; - constexpr bool is_float = std::is_same::value; - static_assert(is_half || is_float, "only half and float are supported."); - - return run_cast_from_f8(x); -} -} // namespace impl - -CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_sr_raw(float x) -{ - constexpr int seed = 42; - uint32_t rng = prand_generator_t{}(reinterpret_cast(&x), x); -#if defined(__gfx94__) - float max_fp8 = 240.0f; - x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); - union - { - float fval; - uint32_t i32val; - uint8_t i8val[4]; // not endian independent - } val; - val.fval = x; - uint32_t ival = 0; - ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos - val.i32val = ival; - return val.i8val[0]; // little endian -#else - constexpr bool negative_zero_nan = true; - constexpr bool clip = true; - constexpr fp8_rounding_mode rm = fp8_rounding_mode::stochastic; - return bit_cast(impl::cast_to_f8(x, rng)); -#endif -} - -CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_sr_raw(float x) -{ - constexpr int seed = 42; - uint32_t rng = prand_generator_t{}(reinterpret_cast(&x), x); -#if defined(__gfx94__) - union - { - float fval; - uint32_t i32val; - uint8_t i8val[4]; // not endian independent - } val; - val.fval = x; - uint32_t ival = 0; - ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos - val.i32val = ival; - return val.i8val[0]; // little endian -#else - constexpr bool negative_zero_nan = true; - constexpr bool clip = true; - constexpr fp8_rounding_mode rm = fp8_rounding_mode::stochastic; - return bit_cast(impl::cast_to_f8(x, rng)); -#endif -} - -CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_rtn_raw(float x) -{ -#if defined(__gfx94__) - float max_fp8 = 240.0f; - x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); - union - { - float fval; - uint32_t i32val; - uint8_t i8val[4]; // not endian independent - } val; - val.fval = x; - uint32_t ival = 0; - ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0 - val.i32val = ival; - return val.i8val[0]; -#else - constexpr bool negative_zero_nan = true; - constexpr bool clip = true; - constexpr fp8_rounding_mode rm = fp8_rounding_mode::standard; - constexpr uint32_t rng = 0; - return bit_cast(impl::cast_to_f8(x, rng)); -#endif -} -CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_rtn_raw(float x) -{ -#if defined(__gfx94__) - union - { - float fval; - uint32_t i32val; - uint8_t i8val[4]; // not endian independent - } val; - val.fval = x; - uint32_t ival = 0; - ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0 - val.i32val = ival; - return val.i8val[0]; -#else - constexpr bool negative_zero_nan = true; - constexpr bool clip = true; - constexpr fp8_rounding_mode rm = fp8_rounding_mode::standard; - constexpr uint32_t rng = 0; - return bit_cast(impl::cast_to_f8(x, rng)); -#endif -} - -// clang-format off -template -CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_raw(float x, constant) -{ - if constexpr (rounding == fp8_rounding_mode::standard) return float_to_fp8_rtn_raw(x); - else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_fp8_sr_raw(x); - else return fp8_raw_t{0}; -} - -template -CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant) -{ - if constexpr (rounding == fp8_rounding_mode::standard) return float_to_bf8_rtn_raw(x); - else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_bf8_sr_raw(x); - else return bf8_raw_t{0}; -} - -CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x) -{ -#if defined(__gfx94__) - float fval; - uint32_t i32val = static_cast(x); - fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0); - // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); - return fval; -#else - constexpr bool negative_zero_nan = true; - return impl::cast_from_f8(bit_cast(x)); -#endif -} - -CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x) -{ -#if defined(__gfx94__) - float fval; - uint32_t i32val = static_cast(x); - fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0); - // asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); - return fval; -#else - constexpr bool negative_zero_nan = true; - return impl::cast_from_f8(bit_cast(x)); -#endif -} - -template(CK_TILE_FLOAT_TO_FP8_DEFAULT)> -CK_TILE_HOST_DEVICE fp8_t float_to_fp8(float x, constant = {}) -{ - return bit_cast(float_to_fp8_raw(x, constant{})); -} - -template(CK_TILE_FLOAT_TO_FP8_DEFAULT)> -CK_TILE_HOST_DEVICE bf8_t float_to_bf8(float x, constant = {}) -{ - return bit_cast(float_to_bf8_raw(x, constant{})); -} - -CK_TILE_HOST_DEVICE float fp8_to_float(fp8_t x) -{ - return fp8_to_float_raw(bit_cast(x)); -} - -CK_TILE_HOST_DEVICE float bf8_to_float(bf8_t x) -{ - return bf8_to_float_raw(bit_cast(x)); -} - -// clang-format on - template struct numeric_traits; template <> struct numeric_traits { + using bitwise_type = fp8_raw_t; + static constexpr int exp = 4; static constexpr int mant = 3; -#if defined(__gfx94__) - static constexpr int bias = 8; +#if CK_TILE_USE_OCP_FP8 + static constexpr int bias = 7; + static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E4M3_OCP; #else - static constexpr int bias = 7; + static constexpr int bias = 8; + static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E4M3_FNUZ; #endif + static constexpr uint8_t abs_mask = 0x7F; }; template <> struct numeric_traits { + using bitwise_type = bf8_raw_t; + static constexpr int exp = 5; static constexpr int mant = 2; -#if defined(__gfx94__) - static constexpr int bias = 16; +#if CK_TILE_USE_OCP_FP8 + static constexpr int bias = 15; + static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E5M2_OCP; #else - static constexpr int bias = 15; // IEEE + static constexpr int bias = 16; + static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E5M2_FNUZ; #endif + static constexpr uint8_t abs_mask = 0x7F; }; +// below is sw fp8 conversion, not utilizing hw instruction +namespace impl { + +template +CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0) +{ + static_assert(std::is_same::value || std::is_same::value, + "DstT type must be fp8 or bf8."); + + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; + static_assert(is_half || is_float, "Only half and float can be cast to f8"); + + // fp8/bf8 type exponent/mantissa layout + constexpr int DstT_exp = numeric_traits::exp; // exponent width of the destination type + constexpr int DstT_mant = numeric_traits::mant; // mantissa width of the destination type + constexpr bool is_fnuz = + (numeric_traits::f8_interpret == fp8_interpretation::E4M3_FNUZ) || + (numeric_traits::f8_interpret == fp8_interpretation::E5M2_FNUZ); + + constexpr int SrcT_exp = numeric_traits::exp; + constexpr int SrcT_mant = numeric_traits::mant; + + using SrcT_bitwise = typename numeric_traits::bitwise_type; + SrcT_bitwise src_bitwise = bit_cast(src); + + unsigned long long head, mantissa; + int exponent, bias; + unsigned int sign; + unsigned long long fInf, abs_mask; + + head = src_bitwise & numeric_traits::head_mask; + mantissa = src_bitwise & numeric_traits::mant_mask; + exponent = (head >> SrcT_mant) & numeric_traits::exp_mask; + sign = head >> (SrcT_exp + SrcT_mant); + bias = numeric_traits::bias; + fInf = numeric_traits::Inf; + abs_mask = numeric_traits::abs_mask; + + unsigned int signed_inf = 0; + unsigned int nan = 0; + if constexpr(is_fnuz) + { + signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80; + nan = 0x80; + } + else + { + if constexpr(DstT_exp == 4) + { // e4m3 + signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f); + } + else + { // e5m2 + signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c); + } + nan = (sign << 7) + 0x7f; + } + // Max values + unsigned long long ifmax = 0; + if constexpr(is_float) + { + if constexpr(DstT_exp == 5) + { + ifmax = 0x47600000; + } + else + { + if constexpr(is_fnuz) + { + ifmax = 0x43700000; + } + else + { + ifmax = 0x43E00000; + } + } + } + else if constexpr(is_half) + { + if constexpr(DstT_exp == 5) + { + ifmax = 0x7B00; + } + else + { + if constexpr(is_fnuz) + { + ifmax = 0x5B80; + } + else + { + ifmax = 0x5F00; + } + } + } + + // Deal with inf and NaNs + if((src_bitwise & fInf) == fInf) + { + if constexpr(is_fnuz) + return signed_inf; + + return mantissa != 0 ? nan : signed_inf; + } + + if((src_bitwise & abs_mask) > ifmax) + { + return signed_inf; + } + + if(src_bitwise == 0) + { + return 0; + } + + // First need to check if it is normal or denorm as there is a difference of + // implicit 1 Then need to adjust the exponent to align with the F8 exponent, + // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng + // to mantissa and truncate. And for RNE, no need to add rng. Then probably + // need to check whether there is carry and adjust exponent and mantissa again + + // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent + // bits + const int f8_bias = (1 << (DstT_exp - 1)) - 1 + (is_fnuz ? 1 : 0); + const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal + // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) + // f8_exponent is the converted f8 exponent with bias encoding + // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, + // the difference needs to be adjusted and mantissa shifted + int act_exponent, f8_exponent, exponent_diff; + + if(exponent == 0) + { // fp32/fp16 is in denormal. + /* fp32 denormal is below 2^-127 so it is usually not a concern here, we + mostly concern fp16 here. In this case, f8 is usually in denormal. But there + could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has + exponent bias 16. It means that there are some numbers in fp16 denormal but they + are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers + where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 + (NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */ + act_exponent = exponent - bias + 1; + exponent_diff = f8_denormal_act_exponent - + act_exponent; // actual exponent is exponent-bias+1 as it is denormal + } + else + { // fp32/fp16 is normal with implicit 1 + act_exponent = exponent - bias; + if(act_exponent <= f8_denormal_act_exponent) + { + /* This is the case where fp32/fp16 is normal but it is in f8 denormal + range. For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16 + actual exponent is -7, it is actually larger due to the implicit 1, + Therefore it needs to be adjust to -6 and mantissa shift right by 1. + So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ + exponent_diff = f8_denormal_act_exponent - act_exponent; + } + else + { // both fp32/fp16 and f8 are in normal range + exponent_diff = 0; // exponent_diff=0 does not mean there is no difference + // for this case, act_exponent could be larger. Just + // that it does not need shift mantissa + } + mantissa += (1ull << SrcT_mant); // Add the implicit 1 into mantissa + } + + bool midpoint = (mantissa & ((1ull << (SrcT_mant - DstT_mant + exponent_diff)) - 1)) == + (1ull << (SrcT_mant - DstT_mant + exponent_diff - 1)); + /* This part is a bit tricky. The judgment of whether it is a tie needs to be + done before we shift right as shift right could rip off some residual part and + make something not midpoint look like midpoint. For example, the fp16 number + 0x1002 (0 00100 0000000010), it is larger than midpoint, but after shift right + by 4 bits, it would look like midpoint. + */ + + if(exponent_diff > 0) + mantissa >>= exponent_diff; + else if(exponent_diff == -1) + mantissa <<= -exponent_diff; + bool implicit_one = mantissa & (1ull << SrcT_mant); + // if there is no implicit 1, it means the f8 is denormal and need to adjust + // to denorm exponent + f8_exponent = + (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1); + + // Now we have the exponent and mantissa adjusted + unsigned long long drop_mask = (1ull << (SrcT_mant - DstT_mant)) - 1; + bool odd = + mantissa & (1ull << (SrcT_mant - + DstT_mant)); // if the least significant bit that is not truncated is 1 + mantissa += + (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask; + + // Now we deal with overflow + if(f8_exponent == 0) + { + if((1ull << SrcT_mant) & mantissa) + { + f8_exponent = 1; // denormal overflow to become normal, promote exponent + } + } + else + { + if((1ull << (SrcT_mant + 1)) & mantissa) + { + mantissa >>= 1; + f8_exponent++; + } + } + + mantissa >>= (SrcT_mant - DstT_mant); + + // above range: quantize to maximum possible float of the same sign + const int max_exp = (1 << DstT_exp) - 1; + if(f8_exponent > max_exp) + { + if constexpr(clip) + { + mantissa = (1 << DstT_mant) - 1; + f8_exponent = max_exp; + } + else + { + return signed_inf; + } + } + + if(f8_exponent == 0 && mantissa == 0) + return is_fnuz ? 0 : (sign << 7); + mantissa &= (1 << DstT_mant) - 1; + return (sign << 7) | (f8_exponent << DstT_mant) | mantissa; +} + +template +CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x) +{ + static_assert(std::is_same::value || std::is_same::value, + "SrcT type must be fp8 or bf8."); + constexpr int SrcT_exp = numeric_traits::exp; + constexpr int SrcT_mant = numeric_traits::mant; + constexpr bool is_fnuz = + (numeric_traits::f8_interpret == fp8_interpretation::E4M3_FNUZ) || + (numeric_traits::f8_interpret == fp8_interpretation::E5M2_FNUZ); + + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; + static_assert(is_half || is_float, "DstT type must be half_t or float."); + + // destination type exponent/mantissa layout + constexpr int DstT_exp = numeric_traits::exp; // exponent width of the destination type + constexpr int DstT_mant = numeric_traits::mant; // mantissa width of the destination type + + constexpr DstT fInf = bit_cast(numeric_traits::Inf); + constexpr DstT fNegInf = bit_cast(numeric_traits::NegInf); + constexpr DstT fNaN = bit_cast(numeric_traits::NaN); + constexpr DstT fNeg0 = bit_cast(numeric_traits::Neg0); + + DstT fmax{0}, fmin{0}; + // Max number in e5m2 57344 + if constexpr(is_half) + { + fmax = bit_cast(static_cast::bitwise_type>(0x7B00)); + fmin = bit_cast(static_cast::bitwise_type>(0xFB00)); + } + else if constexpr(is_float) + { + fmax = bit_cast(static_cast::bitwise_type>(0x47600000)); + fmin = bit_cast(static_cast::bitwise_type>(0xC7600000)); + } + + if(x == 0) + { + return 0; + } + + unsigned long long sign = x >> 7; + unsigned long long mantissa = x & ((1 << SrcT_mant) - 1); + int exponent = (x & 0x7F) >> SrcT_mant; + if constexpr(is_fnuz) + { + if(x == 0x80) + { + return fNaN; + } + } + else + { + if(x == 0x80) + { + return fNeg0; + } + if constexpr(SrcT_exp == 4) + { // e4m3 + if((x & 0x7F) == 0x7F) + { + return fNaN; + } + } + else if((x & 0x7C) == 0x7C) + { // e5m2 + if((x & 0x3) == 0) + { + if constexpr(clip) + { + return sign ? fmin : fmax; + } + return sign ? fNegInf : fInf; + } + return fNaN; + } + } + + typename numeric_traits::bitwise_type retval; + + if constexpr(SrcT_exp == 5 && is_half && !is_fnuz) + { + retval = x << 8; + return bit_cast(retval); + } + + const int exp_low_cutoff = + (1 << (DstT_exp - 1)) - (1 << (SrcT_exp - 1)) + 1 - (is_fnuz ? 1 : 0); + + // subnormal input + if(exponent == 0) + { + int sh = 1 + clz(mantissa) - (32 - SrcT_mant); + mantissa <<= sh; + exponent += 1 - sh; + mantissa &= ((1ull << SrcT_mant) - 1); + } + exponent += exp_low_cutoff - 1; + mantissa <<= DstT_mant - SrcT_mant; + + // subnormal output (occurs when DstT is half_t, we=5, is_fnuz=true) + if(exponent <= 0) + { + mantissa |= 1 << DstT_mant; + mantissa >>= 1 - exponent; + exponent = 0; + } + + retval = (sign << (DstT_exp + DstT_mant)) | (exponent << DstT_mant) | mantissa; + + return bit_cast(retval); +} + +template +CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng) +{ + return bit_cast(run_cast_to_f8(x, rng)); +} + +#if CK_TILE_FP8_CVT_DEVICE +/** + * @brief Cast float to fp8/bf8 using device conversion instructions + */ +template +CK_TILE_DEVICE uint8_t cast_to_f8_from_f32(float v, unsigned int rng = 0) +{ + uint8_t i8data; + union + { + float fval; + unsigned int i32val; + unsigned char i8val[4]; // NOTE: not endian independent + } val; + + unsigned int ival = 0; + val.fval = v; + + if constexpr(saturate) + { + if constexpr(interpret == fp8_interpretation::E4M3_FNUZ) + { + if((val.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0); + } + } + else if constexpr(interpret == fp8_interpretation::E4M3_OCP) + { // OCP type + if((val.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0); + } + } + else + { + if((val.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0); + } + } + } + + if constexpr(stochastic_rounding) + { + ival = (interpret == fp8_interpretation::E4M3_FNUZ) || + (interpret == fp8_interpretation::E4M3_OCP) + ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0) + : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos + val.i32val = ival; + i8data = val.i8val[0]; // little endian + } + else + { // RNE CVT + ival = (interpret == fp8_interpretation::E4M3_FNUZ) || + (interpret == fp8_interpretation::E4M3_OCP) + ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false) + : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, + val.fval, + ival, + false); // false -> WORD0 + val.i32val = ival; + i8data = val.i8val[0]; + } + return i8data; +} +#endif // CK_TILE_FP8_CVT_DEVICE + +} // namespace impl + +/** + * @brief Converts a floating-point value to an 8-bit floating-point representation with stochastic + * rounding. + * + * This function converts a floating-point value (float or half_t) to an 8-bit floating-point + * representation of type fp8_t or bf8_t. The conversion process may + * involve clipping and uses a pseudo-random number generator for the stochastic rounding. + * + * @tparam DstT The destination type (fp8_t or bf8_t). + * @tparam SrcT The source type (float or half_t) to be converted. + * @param x The floating-point value to be converted. + * @return The 8-bit floating-point representation of the input value. + */ +template +CK_TILE_HOST_DEVICE typename numeric_traits::bitwise_type float_to_fp8_sr_raw(SrcT x) +{ + constexpr bool clip = true; + constexpr int seed = 42; + uint32_t rng = prand_generator_t{}(reinterpret_cast(&x), x); +#if CK_TILE_FP8_CVT_DEVICE + return impl::cast_to_f8_from_f32::f8_interpret, clip, true>(x, rng); +#else + return bit_cast::bitwise_type>( + impl::cast_to_f8(x, rng)); +#endif +} + +/** + * @brief Converts a floating-point value to an 8-bit floating-point representation with rounding to + * nearest even. + * + * This function converts a floating-point value (float or half_t) to an 8-bit floating-point + * representation of type fp8_t or bf8_t. The conversion process may involve clipping. + * + * @tparam DstT The destination type (fp8_t or bf8_t). + * @tparam SrcT The source type (float or half_t) to be converted. + * @param x The floating-point value to be converted. + * @return The 8-bit floating-point representation of the input value. + */ +template +CK_TILE_HOST_DEVICE typename numeric_traits::bitwise_type float_to_fp8_rtn_raw(SrcT x) +{ + constexpr bool clip = true; +#if CK_TILE_FP8_CVT_DEVICE + return impl::cast_to_f8_from_f32::f8_interpret, clip, false>(x, 0); +#else + return bit_cast::bitwise_type>( + impl::cast_to_f8(x, 0)); +#endif +} + +template +CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_raw(float x, constant) +{ + if constexpr(rounding == fp8_rounding_mode::standard) + { + return float_to_fp8_rtn_raw(x); + } + else if constexpr(rounding == fp8_rounding_mode::stochastic) + { + return float_to_fp8_sr_raw(x); + } + else + { + return fp8_raw_t{0}; + } +} + +template +CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant) +{ + if constexpr(rounding == fp8_rounding_mode::standard) + { + return float_to_fp8_rtn_raw(x); + } + else if constexpr(rounding == fp8_rounding_mode::stochastic) + { + return float_to_fp8_sr_raw(x); + } + else + { + return bf8_raw_t{0}; + } +} + +CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x) +{ +#if CK_TILE_FP8_CVT_DEVICE + float fval; + uint32_t i32val = static_cast(x); + fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0); + // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); + return fval; +#else + return impl::run_cast_from_f8(bit_cast(x)); +#endif +} + +CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x) +{ +#if CK_TILE_FP8_CVT_DEVICE + float fval; + uint32_t i32val = static_cast(x); + fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0); + // asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); + return fval; +#else + return impl::run_cast_from_f8(bit_cast(x)); +#endif +} + +template (CK_TILE_FLOAT_TO_FP8_DEFAULT)> +CK_TILE_HOST_DEVICE fp8_t float_to_fp8(float x, constant = {}) +{ + return bit_cast(float_to_fp8_raw(x, constant{})); +} + +template (CK_TILE_FLOAT_TO_FP8_DEFAULT)> +CK_TILE_HOST_DEVICE bf8_t float_to_bf8(float x, constant = {}) +{ + return bit_cast(float_to_bf8_raw(x, constant{})); +} + +CK_TILE_HOST_DEVICE float fp8_to_float(fp8_t x) { return fp8_to_float_raw(bit_cast(x)); } + +CK_TILE_HOST_DEVICE float bf8_to_float(bf8_t x) { return bf8_to_float_raw(bit_cast(x)); } + template struct numeric; +#if CK_TILE_USE_OCP_FP8 +template <> +struct numeric +{ + // minimum finite value, or minimum positive normal value + CK_TILE_HOST_DEVICE static constexpr fp8_t min() + { + return bit_cast(static_cast(0x08)); // 0b00001000 = 2^-6 + } + + // minumum finite value + CK_TILE_HOST_DEVICE static constexpr fp8_t lowest() + { + return bit_cast(static_cast(0xfe)); // 0b11111110 = -448 + } + + // maximum finite value + CK_TILE_HOST_DEVICE static constexpr fp8_t max() + { + return bit_cast(static_cast(0x7e)); // 0b01111110 = 448 + } + + // difference between 1.0 and next representable f8 value (1.125) + // returns fp8_t(0.125) + CK_TILE_HOST_DEVICE static constexpr fp8_t epsilon() + { + return bit_cast(static_cast(0x20)); // 0.125 + } + + // rounding error (0.0625) + // half of epsilon + CK_TILE_HOST_DEVICE static constexpr fp8_t round_error() + { + return bit_cast(static_cast(0x18)); // 0.0625 + } + + // quiet NaN + CK_TILE_HOST_DEVICE static constexpr fp8_t quiet_NaN() + { + return bit_cast(static_cast(0x7F)); // 0b01111111 + } + + // signaling NaN + CK_TILE_HOST_DEVICE static constexpr fp8_t signaling_NaN() + { + return bit_cast(static_cast(0xFF)); // 0b11111111 + } + + // smallest positive subnormal value + CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min() + { + return bit_cast(static_cast(0x01)); + } + + CK_TILE_HOST_DEVICE static constexpr fp8_t zero() + { + return bit_cast(static_cast(0)); + } +}; + +template <> +struct numeric +{ + // minimum finite value, or minimum positive normalized value for float + CK_TILE_HOST_DEVICE static constexpr bf8_t min() + { + return bit_cast(static_cast(0x04)); // 0b00000100 = 2^-14 + } + + // minumum finite value + CK_TILE_HOST_DEVICE static constexpr bf8_t lowest() + { + return bit_cast(static_cast(0xfb)); // 0b11111011 = -57344 + } + + // maximum finite value + CK_TILE_HOST_DEVICE static constexpr bf8_t max() + { + return bit_cast(static_cast(0x7b)); // 0b01111011 = 57344 + } + + // difference between 1.0 and next representable bf8 value (1.25) + CK_TILE_HOST_DEVICE static constexpr bf8_t epsilon() + { + return bit_cast(static_cast(0x34)); // 0.25 + } + + // rounding error (0.125) + // half of epsilon + CK_TILE_HOST_DEVICE static constexpr bf8_t round_error() + { + return bit_cast(static_cast(0x30)); // 0.125 + } + + // positive infinity value + CK_TILE_HOST_DEVICE static constexpr bf8_t infinity() + { + return bit_cast(static_cast(0x7c)); // 0b01111100 + } + + // quiet NaN + CK_TILE_HOST_DEVICE static constexpr bf8_t quiet_NaN() + { + return bit_cast(static_cast(0x7F)); // 0b01111111 + } + + // signaling NaN + CK_TILE_HOST_DEVICE static constexpr bf8_t signaling_NaN() + { + return bit_cast(static_cast(0xFF)); + } + + // smallest positive subnormal value + CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min() + { + return bit_cast(static_cast(0x01)); + } + + CK_TILE_HOST_DEVICE static constexpr bf8_t zero() + { + return bit_cast(static_cast(0)); + } +}; +#else template <> struct numeric { @@ -811,6 +1054,7 @@ struct numeric return bit_cast(static_cast(0)); } }; +#endif #if CK_TILE_USE_CUSTOM_DATA_TYPE CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, fp8_t) @@ -818,19 +1062,26 @@ CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bf8_t) #endif // math -CK_TILE_HOST_DEVICE -fp8_t abs(const fp8_t& x) +template +CK_TILE_HOST_DEVICE T abs(const T& x) { - return bit_cast(static_cast(bit_cast(x) & 0x7f)); + static_assert(std::is_same_v || std::is_same_v, + "Only fp8_t and bf8_t are supported"); + return bit_cast(static_cast(bit_cast(x) & numeric_traits::abs_mask)); } CK_TILE_HOST_DEVICE bool isnan(const fp8_t& x) { uint8_t xx = bit_cast(x); - return xx == 0x80; // TODO: NANOO -} +#if CK_TILE_USE_OCP_FP8 + return (xx & 0x7f) == 0x7f; +#else + return xx == 0x80; +#endif +} +#if CK_TILE_USE_CUSTOM_DATA_TYPE CK_TILE_DEVICE fp8_t sqrt(fp8_t x) { return static_cast(__builtin_amdgcn_sqrtf(static_cast(x))); }; @@ -842,20 +1093,21 @@ fp8_t exp2(fp8_t x) { return static_cast(exp2f(static_cast(x))); } CK_TILE_DEVICE fp8_t log(fp8_t x) { return static_cast(__logf(static_cast(x))); }; - -CK_TILE_HOST_DEVICE -bf8_t abs(const bf8_t& x) -{ - return bit_cast(static_cast(bit_cast(x) & 0x7f)); -} +#endif CK_TILE_HOST_DEVICE bool isnan(const bf8_t& x) { uint8_t xx = bit_cast(x); - return xx == 0x80; // TODO: NANOO + +#if CK_TILE_USE_OCP_FP8 + return (xx & 0x7f) > 0x7c; +#else + return xx == 0x80; +#endif } +#if CK_TILE_USE_CUSTOM_DATA_TYPE CK_TILE_DEVICE bf8_t sqrt(bf8_t x) { return static_cast(__builtin_amdgcn_sqrtf(static_cast(x))); }; @@ -867,5 +1119,6 @@ bf8_t exp2(bf8_t x) { return static_cast(exp2f(static_cast(x))); } CK_TILE_DEVICE bf8_t log(bf8_t x) { return static_cast(__logf(static_cast(x))); }; +#endif } // namespace ck_tile diff --git a/include/ck_tile/core/numeric/half.hpp b/include/ck_tile/core/numeric/half.hpp index acb6eb6c3e..5779b170b7 100644 --- a/include/ck_tile/core/numeric/half.hpp +++ b/include/ck_tile/core/numeric/half.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck_tile/core/config.hpp" #include "ck_tile/core/utility/bit_cast.hpp" @@ -236,10 +236,11 @@ struct numeric_traits static constexpr uint16_t head_mask = 0xFC00; static constexpr uint16_t mant_mask = 0x3FF; static constexpr uint16_t exp_mask = 0x1F; - static constexpr uint32_t Inf = 0x7C00; - static constexpr uint32_t NegInf = 0xFC00; - static constexpr uint32_t NaN = 0x7C01; - static constexpr uint32_t Neg0 = 0x8000; + static constexpr uint16_t abs_mask = 0x7FFF; + static constexpr uint16_t Inf = 0x7C00; + static constexpr uint16_t NegInf = 0xFC00; + static constexpr uint16_t NaN = 0x7C01; + static constexpr uint16_t Neg0 = 0x8000; using bitwise_type = uint16_t; }; diff --git a/include/ck_tile/core/numeric/numeric.hpp b/include/ck_tile/core/numeric/numeric.hpp index 35745b12d2..6b16485b48 100644 --- a/include/ck_tile/core/numeric/numeric.hpp +++ b/include/ck_tile/core/numeric/numeric.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -89,6 +89,7 @@ struct numeric_traits static constexpr uint32_t head_mask = 0xFF800000; static constexpr uint32_t mant_mask = 0x7FFFFF; static constexpr uint32_t exp_mask = 0xFF; + static constexpr uint32_t abs_mask = 0x7FFFFFFF; static constexpr uint32_t Inf = 0x7F800000; static constexpr uint32_t NegInf = 0xFF800000; static constexpr uint32_t NaN = 0x7F800001;