mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
FP8 enablement - add a pseudorandom number generator, add conversion methods (#708)
* Add basic fp8 definitions and prn-generator
* Format
* Add fp8<->fp32 type_convert
* Format
* Split type_convert and cast_to/from_f8
* Format
* Minor fix
* Minor fix
* Move fp8 utils to a separate header
* Add elementwise ops
* Add fp8_convert_sr
* Format
* Add element op
* Eliminate magic numbers
* Split f8_convert_sr in host and device
* Format
* Add some constexpr
* Add a datatype test
* Format
* Another format
* Add fp8<->fp16 tests
* Update type_converts
* Format
* Add fp16 casting functions
* Format
* Use seed as a runtime arg
* Use element location for PRNG
* Format
* Add fp8<->fp16 to PassThrough element op
* Clean up
* Merge host and device implementations
* Add comments on rounding modes
* Remove leftover code
* Put type_converts into a separate header
* Put random number gen to a separate header
* Rearrange f8_utils' namespaces
* Refactor type_convert.hpp
* Move f8_t definition
[ROCm/composable_kernel commit: f0c620c42e]
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/math.hpp"
|
||||
#include "ck/utility/math_v2.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -81,6 +82,36 @@ struct PassThrough
|
||||
y = x;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<float, f8_t>(float& y, const f8_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<f8_t, float>(f8_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<f8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<half_t, f8_t>(half_t& y, const f8_t& x) const
|
||||
{
|
||||
y = type_convert<half_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<f8_t, half_t>(f8_t& y, const half_t& x) const
|
||||
{
|
||||
y = type_convert<f8_t>(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct UnaryConvert
|
||||
@@ -109,6 +140,23 @@ struct ConvertBF16RTN
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertF8SR
|
||||
{
|
||||
// convert to fp8 using stochastic rounding (SR)
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const
|
||||
{
|
||||
// check Y datatype
|
||||
static_assert(is_same<Y, f8_t>::value, "Data type is not supported by this operation!");
|
||||
|
||||
// check X datatype
|
||||
static_assert(is_same<X, float>::value || is_same<X, half_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = f8_convert_sr<Y>(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Scale
|
||||
{
|
||||
__host__ __device__ Scale(float scale) : scale_(scale) {}
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/utility/tuple_helper.hpp"
|
||||
#include "ck/utility/type.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
#include "ck/utility/magic_division.hpp"
|
||||
#include "ck/utility/c_style_pointer_cast.hpp"
|
||||
#include "ck/utility/is_known_at_compile_time.hpp"
|
||||
|
||||
@@ -12,6 +12,7 @@ using half_t = _Float16;
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
using int4_t = _BitInt(4);
|
||||
#endif
|
||||
using f8_t = uint8_t;
|
||||
|
||||
// vector_type
|
||||
template <typename T, index_t N>
|
||||
@@ -142,6 +143,13 @@ struct scalar_type<int4_t>
|
||||
};
|
||||
#endif
|
||||
|
||||
template <>
|
||||
struct scalar_type<f8_t>
|
||||
{
|
||||
using type = f8_t;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
//
|
||||
template <typename T>
|
||||
struct vector_type<T, 1>
|
||||
@@ -944,151 +952,13 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
|
||||
using int8x32_t = typename vector_type<int8_t, 32>::type;
|
||||
using int8x64_t = typename vector_type<int8_t, 64>::type;
|
||||
|
||||
// Convert X to Y
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr Y type_convert(X x)
|
||||
{
|
||||
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
|
||||
|
||||
return static_cast<Y>(x);
|
||||
}
|
||||
|
||||
// convert bfp16 to fp32
|
||||
template <>
|
||||
inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x)
|
||||
{
|
||||
union
|
||||
{
|
||||
uint32_t int32;
|
||||
float fp32;
|
||||
} u = {uint32_t(x) << 16};
|
||||
|
||||
return u.fp32;
|
||||
}
|
||||
|
||||
// convert fp32 to bfp16
|
||||
template <>
|
||||
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
|
||||
{
|
||||
union
|
||||
{
|
||||
float fp32;
|
||||
uint32_t int32;
|
||||
} u = {x};
|
||||
|
||||
return uint16_t(u.int32 >> 16);
|
||||
}
|
||||
|
||||
// convert bfp16 to fp16 via fp32
|
||||
template <>
|
||||
inline __host__ __device__ constexpr half_t type_convert<half_t, bhalf_t>(bhalf_t x)
|
||||
{
|
||||
float x_fp32 = type_convert<float>(x);
|
||||
|
||||
return static_cast<half_t>(x_fp32);
|
||||
}
|
||||
|
||||
// convert fp16 to bfp16 via fp32
|
||||
template <>
|
||||
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, half_t>(half_t x)
|
||||
{
|
||||
float x_fp32 = static_cast<float>(x);
|
||||
|
||||
return type_convert<bhalf_t>(x_fp32);
|
||||
}
|
||||
|
||||
// convert bfp16 to int32 via fp32
|
||||
template <>
|
||||
inline __host__ __device__ constexpr int32_t type_convert<int32_t, bhalf_t>(bhalf_t x)
|
||||
{
|
||||
float x_fp32 = type_convert<float>(x);
|
||||
|
||||
return static_cast<int32_t>(x_fp32);
|
||||
}
|
||||
|
||||
// convert int32 to bfp16 via fp32
|
||||
template <>
|
||||
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int32_t>(int32_t x)
|
||||
{
|
||||
float x_fp32 = static_cast<float>(x);
|
||||
|
||||
return type_convert<bhalf_t>(x_fp32);
|
||||
}
|
||||
|
||||
// convert bfp16 to int8 via fp32
|
||||
template <>
|
||||
inline __host__ __device__ constexpr int8_t type_convert<int8_t, bhalf_t>(bhalf_t x)
|
||||
{
|
||||
float x_fp32 = type_convert<float>(x);
|
||||
|
||||
return static_cast<int8_t>(x_fp32);
|
||||
}
|
||||
|
||||
// convert int8 to bfp16 via fp32
|
||||
template <>
|
||||
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_t x)
|
||||
{
|
||||
float x_fp32 = static_cast<float>(x);
|
||||
|
||||
return type_convert<bhalf_t>(x_fp32);
|
||||
}
|
||||
|
||||
// Declare a template function for bf16 conversion using RTN
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
|
||||
|
||||
// Convert fp32 to bf16 with RTN if higher precision is needed
|
||||
template <>
|
||||
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
|
||||
{
|
||||
union
|
||||
{
|
||||
float fp32;
|
||||
uint32_t int32;
|
||||
} u = {x};
|
||||
|
||||
// When the exponent bits are not all 1s, then the value is zero, normal,
|
||||
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
|
||||
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
|
||||
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
|
||||
// least significant bits of the float mantissa are greater than 0x8000,
|
||||
// or if they are equal to 0x8000 and the least significant bit of the
|
||||
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
|
||||
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
|
||||
// has the value 0x7f, then incrementing it causes it to become 0x00 and
|
||||
// the exponent is incremented by one, which is the next higher FP value
|
||||
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
|
||||
// with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
|
||||
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
|
||||
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
|
||||
// incrementing it causes it to become an exponent of 0xFF and a mantissa
|
||||
// of 0x00, which is Inf, the next higher value to the unrounded value.
|
||||
bool flag0 = ~u.int32 & 0x7f800000;
|
||||
|
||||
// When all of the exponent bits are 1, the value is Inf or NaN.
|
||||
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
|
||||
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
|
||||
// bit being 1. Signaling NaN is indicated by the most significant
|
||||
// mantissa bit being 0 but some other bit(s) being 1. If any of the
|
||||
// lower 16 bits of the mantissa are 1, we set the least significant bit
|
||||
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
|
||||
// the bfloat16's mantissa bits are all 0.
|
||||
bool flag1 = !flag0 && (u.int32 & 0xffff);
|
||||
|
||||
u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
|
||||
u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN
|
||||
|
||||
return uint16_t(u.int32 >> 16);
|
||||
}
|
||||
|
||||
// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
|
||||
template <>
|
||||
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
|
||||
{
|
||||
float x_fp32 = static_cast<float>(x);
|
||||
|
||||
return bf16_convert_rtn<bhalf_t>(x_fp32);
|
||||
}
|
||||
// f8
|
||||
using f8x2_t = typename vector_type<f8_t, 2>::type;
|
||||
using f8x4_t = typename vector_type<f8_t, 4>::type;
|
||||
using f8x8_t = typename vector_type<f8_t, 8>::type;
|
||||
using f8x16_t = typename vector_type<f8_t, 16>::type;
|
||||
using f8x32_t = typename vector_type<f8_t, 32>::type;
|
||||
using f8x64_t = typename vector_type<f8_t, 64>::type;
|
||||
|
||||
template <typename T>
|
||||
struct NumericLimits
|
||||
@@ -1136,4 +1006,21 @@ struct NumericLimits<int4_t>
|
||||
};
|
||||
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
|
||||
template <>
|
||||
struct NumericLimits<f8_t>
|
||||
{
|
||||
static constexpr uint8_t binary_min = 0x08; // 0b00001000
|
||||
static constexpr uint8_t binary_max = 0x77; // 0b01110111
|
||||
static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
|
||||
static constexpr uint8_t binary_qnan = 0x80; // 0b10000000
|
||||
|
||||
__host__ __device__ static constexpr f8_t Min() { return bit_cast<f8_t>(binary_min); }
|
||||
|
||||
__host__ __device__ static constexpr f8_t Max() { return bit_cast<f8_t>(binary_max); }
|
||||
|
||||
__host__ __device__ static constexpr f8_t Lowest() { return bit_cast<f8_t>(binary_lowest); }
|
||||
|
||||
__host__ __device__ static constexpr f8_t QuietNaN() { return bit_cast<f8_t>(binary_qnan); }
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
250
include/ck/utility/f8_utils.hpp
Normal file
250
include/ck/utility/f8_utils.hpp
Normal file
@@ -0,0 +1,250 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// fp8 rounding modes
|
||||
// use standard for rounding to nearest, the faster one
|
||||
// use stochastic for stochastic rounding, helps to avoid error accumulation
|
||||
enum class f8_rounding_mode
|
||||
{
|
||||
standard,
|
||||
stochastic
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
namespace ck::utils {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T, bool negative_zero_nan, bool clip, bool stoch>
|
||||
__host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
|
||||
{
|
||||
// check data type
|
||||
constexpr bool is_half = std::is_same<T, half_t>::value;
|
||||
constexpr bool is_float = std::is_same<T, float>::value;
|
||||
|
||||
// fp8 exponent/mantissa layout
|
||||
constexpr int f8_exp = 4;
|
||||
constexpr int f8_mant = 3;
|
||||
|
||||
// resulting type exponent/mantissa layout
|
||||
constexpr int type_exp = is_half ? 5 : 8;
|
||||
constexpr int type_mant = is_half ? 10 : 23;
|
||||
|
||||
int exponent;
|
||||
uint32_t head, mantissa, sign;
|
||||
// nan code is same for float and half
|
||||
constexpr uint8_t nan_code = 0x80;
|
||||
constexpr uint32_t nan_mask = is_half ? 0x7C00 : 0x7F800000;
|
||||
|
||||
// convert to bitwise
|
||||
typedef typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type
|
||||
T_bitwise;
|
||||
T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));
|
||||
|
||||
// unpack the input, depends on datatype
|
||||
if constexpr(is_float)
|
||||
{
|
||||
head = x_bitwise & 0xFF800000;
|
||||
mantissa = x_bitwise & 0x7FFFFF;
|
||||
exponent = (head >> type_mant) & 0xFF;
|
||||
sign = head >> (type_exp + type_mant);
|
||||
}
|
||||
else if constexpr(is_half)
|
||||
{
|
||||
head = x_bitwise & 0xFC00;
|
||||
mantissa = x_bitwise & 0x3FF;
|
||||
exponent = (head >> type_mant) & 0x1F;
|
||||
sign = head >> (type_exp + type_mant);
|
||||
}
|
||||
|
||||
uint32_t signed_inf = (sign << (type_exp + type_mant)) + (((1 << type_exp) - 1) << type_mant);
|
||||
uint32_t drop_mask = (1 << (type_mant - f8_mant)) - 1;
|
||||
constexpr int max_exp = (1 << f8_exp) - (negative_zero_nan ? 1 : 2);
|
||||
constexpr int exp_low_cutoff =
|
||||
(1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
|
||||
|
||||
if constexpr(negative_zero_nan)
|
||||
{
|
||||
if((x_bitwise & nan_mask) == nan_mask)
|
||||
return nan_code;
|
||||
}
|
||||
else
|
||||
{
|
||||
if((x_bitwise & nan_mask) == nan_mask)
|
||||
return signed_inf + (mantissa != 0 ? 1 : 0);
|
||||
}
|
||||
|
||||
// check if x is 0.0
|
||||
if(x_bitwise == 0)
|
||||
return 0;
|
||||
|
||||
exponent -= exp_low_cutoff - 1;
|
||||
if(exponent <= 0)
|
||||
drop_mask = (1 << (type_mant - f8_mant + 1 - exponent)) - 1;
|
||||
mantissa += 1 << type_mant;
|
||||
// apply random number if needed
|
||||
mantissa += (stoch ? rng : mantissa) & drop_mask;
|
||||
if(mantissa >= (2 << type_mant))
|
||||
{
|
||||
mantissa >>= 1;
|
||||
exponent++;
|
||||
}
|
||||
mantissa >>= (type_mant - f8_mant);
|
||||
|
||||
// check negative exponent
|
||||
if(exponent <= 0)
|
||||
{
|
||||
if(x_bitwise == 0)
|
||||
return 0;
|
||||
else
|
||||
{
|
||||
// subnormal range; represented by a subnormal float8 (exponent 0)
|
||||
// and involves loss of accuracy
|
||||
mantissa >>= 1 - exponent;
|
||||
exponent = 0;
|
||||
}
|
||||
}
|
||||
// above range: quantize to maximum possible float of the same sign
|
||||
else if(exponent > max_exp)
|
||||
{
|
||||
if(clip)
|
||||
{
|
||||
mantissa = (1 << f8_mant) - 1;
|
||||
exponent = max_exp;
|
||||
}
|
||||
else
|
||||
{
|
||||
return signed_inf;
|
||||
}
|
||||
}
|
||||
|
||||
// check if x is 0.0 or -0.0
|
||||
if(exponent == 0 && mantissa == 0)
|
||||
return negative_zero_nan ? 0 : (sign << (f8_exp + f8_mant));
|
||||
mantissa &= (1 << f8_mant) - 1;
|
||||
return (sign << (f8_exp + f8_mant)) | (exponent << f8_mant) | mantissa;
|
||||
}
|
||||
|
||||
template <typename T, bool negative_zero_nan>
|
||||
__host__ __device__ T run_cast_from_f8(f8_t x)
|
||||
{
|
||||
// check data type
|
||||
constexpr bool is_half = std::is_same<T, half_t>::value;
|
||||
constexpr bool is_float = std::is_same<T, float>::value;
|
||||
|
||||
// fp8 exponent/mantissa layout
|
||||
constexpr int f8_exp = 4;
|
||||
constexpr int f8_mant = 3;
|
||||
|
||||
// resulting type exponent/mantissa layout
|
||||
constexpr int type_exp = is_half ? 5 : 8;
|
||||
constexpr int type_mant = is_half ? 10 : 23;
|
||||
|
||||
// prepare the codes
|
||||
constexpr uint8_t nan_code = 0x80;
|
||||
T fInf, fNegInf, fNaN, fNeg0;
|
||||
if constexpr(is_half)
|
||||
{
|
||||
constexpr uint16_t ihInf = 0x7C00;
|
||||
constexpr uint16_t ihNegInf = 0xFC00;
|
||||
constexpr uint16_t ihNaN = 0x7C01;
|
||||
constexpr uint16_t ihNeg0 = 0x8000;
|
||||
fInf = *(reinterpret_cast<const half_t*>(&ihInf));
|
||||
fNegInf = *(reinterpret_cast<const half_t*>(&ihNegInf));
|
||||
fNaN = *(reinterpret_cast<const half_t*>(&ihNaN));
|
||||
fNeg0 = *(reinterpret_cast<const half_t*>(&ihNeg0));
|
||||
}
|
||||
else if constexpr(is_float)
|
||||
{
|
||||
constexpr uint32_t ifInf = 0x7F800000;
|
||||
constexpr uint32_t ifNegInf = 0xFF800000;
|
||||
constexpr uint32_t ifNaN = 0x7F800001;
|
||||
constexpr uint32_t ifNeg0 = 0x80000000;
|
||||
fInf = *(reinterpret_cast<const float*>(&ifInf));
|
||||
fNegInf = *(reinterpret_cast<const float*>(&ifNegInf));
|
||||
fNaN = *(reinterpret_cast<const float*>(&ifNaN));
|
||||
fNeg0 = *(reinterpret_cast<const float*>(&ifNeg0));
|
||||
}
|
||||
|
||||
// unpack the input
|
||||
uint32_t sign = x >> (f8_exp + f8_mant);
|
||||
uint32_t mantissa = x & ((1 << f8_mant) - 1);
|
||||
int exponent = (x & 0x7F) >> f8_mant;
|
||||
|
||||
constexpr int exp_low_cutoff =
|
||||
(1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
|
||||
typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type retval;
|
||||
|
||||
if constexpr(negative_zero_nan)
|
||||
{
|
||||
if(x == nan_code)
|
||||
return fNaN;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(x == nan_code)
|
||||
return fNeg0;
|
||||
if(exponent == ((1 << f8_exp) - 1))
|
||||
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
|
||||
}
|
||||
|
||||
// subnormal input
|
||||
if(exponent == 0)
|
||||
{
|
||||
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
|
||||
int sh = 1 + __builtin_clz(mantissa) - ((1 + type_exp + type_mant) - f8_mant);
|
||||
mantissa <<= sh;
|
||||
mantissa &= ((1 << f8_mant) - 1);
|
||||
exponent += 1 - sh;
|
||||
}
|
||||
exponent += exp_low_cutoff - 1;
|
||||
mantissa <<= type_mant - f8_mant;
|
||||
|
||||
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
|
||||
if(exponent <= 0)
|
||||
{
|
||||
mantissa |= 1 << type_mant;
|
||||
mantissa >>= 1 - exponent;
|
||||
exponent = 0;
|
||||
}
|
||||
|
||||
retval = (sign << (type_exp + type_mant)) | (exponent << type_mant) | mantissa;
|
||||
return *(reinterpret_cast<const T*>(&retval));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename T, bool negative_zero_nan, bool clip, bool stoch>
|
||||
__host__ __device__ f8_t cast_to_f8(T x, uint32_t rng)
|
||||
{
|
||||
// check datatype
|
||||
constexpr bool is_half = std::is_same<T, half_t>::value;
|
||||
constexpr bool is_float = std::is_same<T, float>::value;
|
||||
static_assert(is_half || is_float, "Only half and float can be casted to f8.");
|
||||
|
||||
return run_cast_to_f8<T, negative_zero_nan, clip, stoch>(x, rng);
|
||||
}
|
||||
|
||||
template <typename T, bool negative_zero_nan>
|
||||
__host__ __device__ T cast_from_f8(f8_t x)
|
||||
{
|
||||
// check datatype
|
||||
constexpr bool is_half = std::is_same<T, half_t>::value;
|
||||
constexpr bool is_float = std::is_same<T, float>::value;
|
||||
static_assert(is_half || is_float, "only half and float are supported.");
|
||||
|
||||
// check if x is 0.0
|
||||
if(x == 0)
|
||||
return static_cast<T>(0);
|
||||
|
||||
return run_cast_from_f8<T, negative_zero_nan>(x);
|
||||
}
|
||||
|
||||
} // namespace ck::utils
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
#include "data_type.hpp"
|
||||
#include "type_convert.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
|
||||
53
include/ck/utility/random_gen.hpp
Normal file
53
include/ck/utility/random_gen.hpp
Normal file
@@ -0,0 +1,53 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Pseudo random number generator
|
||||
// version for fp32
|
||||
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<float, T>{}, bool> = false>
|
||||
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
|
||||
{
|
||||
uint32_t x = *(reinterpret_cast<uint32_t*>(&val));
|
||||
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
|
||||
drop_bits ^= x >> 16;
|
||||
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
|
||||
drop_bits *= 0x7000149;
|
||||
// NOTE: If id is in 64 bit, we are only using lower 32 bit.
|
||||
// So, it can have an effect of using same id for multiple elements when the id is very
|
||||
// large!
|
||||
uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed);
|
||||
return rng;
|
||||
}
|
||||
|
||||
// version for fp16
|
||||
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<half_t, T>{}, bool> = false>
|
||||
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
|
||||
{
|
||||
uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
|
||||
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
|
||||
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
|
||||
drop_bits *= 0x7000149;
|
||||
// NOTE: If id is in 64 bit, we are only using lower 32 bit.
|
||||
// So, it can have an effect of using same id for multiple elements when the id is very
|
||||
// large!
|
||||
uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed);
|
||||
return rng;
|
||||
}
|
||||
|
||||
// return 0 if data is not fp16 or fp32
|
||||
template <typename T,
|
||||
uint32_t seed_t,
|
||||
std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<half_t, T>{}), bool> = false>
|
||||
__host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t)
|
||||
{
|
||||
std::ignore = id;
|
||||
std::ignore = val;
|
||||
std::ignore = seed;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/type.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
|
||||
230
include/ck/utility/type_convert.hpp
Normal file
230
include/ck/utility/type_convert.hpp
Normal file
@@ -0,0 +1,230 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/f8_utils.hpp"
|
||||
#include "ck/utility/random_gen.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Convert X to Y
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr Y type_convert(X x)
|
||||
{
|
||||
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
|
||||
|
||||
return static_cast<Y>(x);
|
||||
}
|
||||
|
||||
// convert bfp16 to fp32
|
||||
template <>
|
||||
inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x)
|
||||
{
|
||||
union
|
||||
{
|
||||
uint32_t int32;
|
||||
float fp32;
|
||||
} u = {uint32_t(x) << 16};
|
||||
|
||||
return u.fp32;
|
||||
}
|
||||
|
||||
// convert fp32 to bfp16
|
||||
template <>
|
||||
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
|
||||
{
|
||||
union
|
||||
{
|
||||
float fp32;
|
||||
uint32_t int32;
|
||||
} u = {x};
|
||||
|
||||
return uint16_t(u.int32 >> 16);
|
||||
}
|
||||
|
||||
// convert bfp16 to fp16 via fp32
|
||||
template <>
|
||||
inline __host__ __device__ constexpr half_t type_convert<half_t, bhalf_t>(bhalf_t x)
|
||||
{
|
||||
float x_fp32 = type_convert<float>(x);
|
||||
|
||||
return static_cast<half_t>(x_fp32);
|
||||
}
|
||||
|
||||
// convert fp16 to bfp16 via fp32
|
||||
template <>
|
||||
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, half_t>(half_t x)
|
||||
{
|
||||
float x_fp32 = static_cast<float>(x);
|
||||
|
||||
return type_convert<bhalf_t>(x_fp32);
|
||||
}
|
||||
|
||||
// convert bfp16 to int32 via fp32
|
||||
template <>
|
||||
inline __host__ __device__ constexpr int32_t type_convert<int32_t, bhalf_t>(bhalf_t x)
|
||||
{
|
||||
float x_fp32 = type_convert<float>(x);
|
||||
|
||||
return static_cast<int32_t>(x_fp32);
|
||||
}
|
||||
|
||||
// convert int32 to bfp16 via fp32
|
||||
template <>
|
||||
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int32_t>(int32_t x)
|
||||
{
|
||||
float x_fp32 = static_cast<float>(x);
|
||||
|
||||
return type_convert<bhalf_t>(x_fp32);
|
||||
}
|
||||
|
||||
// convert bfp16 to int8 via fp32
|
||||
template <>
|
||||
inline __host__ __device__ constexpr int8_t type_convert<int8_t, bhalf_t>(bhalf_t x)
|
||||
{
|
||||
float x_fp32 = type_convert<float>(x);
|
||||
|
||||
return static_cast<int8_t>(x_fp32);
|
||||
}
|
||||
|
||||
// convert int8 to bfp16 via fp32
|
||||
template <>
|
||||
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_t x)
|
||||
{
|
||||
float x_fp32 = static_cast<float>(x);
|
||||
|
||||
return type_convert<bhalf_t>(x_fp32);
|
||||
}
|
||||
|
||||
// convert fp32 to fp8
|
||||
template <>
|
||||
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
|
||||
{
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
|
||||
constexpr uint32_t rng = 0;
|
||||
return utils::cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
}
|
||||
|
||||
// convert fp8 to fp32
|
||||
template <>
|
||||
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
|
||||
{
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return utils::cast_from_f8<float, negative_zero_nan>(x);
|
||||
}
|
||||
|
||||
// convert fp16 to fp8
|
||||
template <>
|
||||
inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
|
||||
{
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
|
||||
constexpr uint32_t rng = 0;
|
||||
return utils::cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
}
|
||||
|
||||
// convert fp8 to fp16
|
||||
template <>
|
||||
inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
|
||||
{
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return utils::cast_from_f8<half_t, negative_zero_nan>(x);
|
||||
}
|
||||
|
||||
// Declare a template function for bf16 conversion using RTN
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
|
||||
|
||||
// Convert fp32 to bf16 with RTN if higher precision is needed
|
||||
template <>
|
||||
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
|
||||
{
|
||||
union
|
||||
{
|
||||
float fp32;
|
||||
uint32_t int32;
|
||||
} u = {x};
|
||||
|
||||
// When the exponent bits are not all 1s, then the value is zero, normal,
|
||||
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
|
||||
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
|
||||
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
|
||||
// least significant bits of the float mantissa are greater than 0x8000,
|
||||
// or if they are equal to 0x8000 and the least significant bit of the
|
||||
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
|
||||
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
|
||||
// has the value 0x7f, then incrementing it causes it to become 0x00 and
|
||||
// the exponent is incremented by one, which is the next higher FP value
|
||||
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
|
||||
// with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
|
||||
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
|
||||
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
|
||||
// incrementing it causes it to become an exponent of 0xFF and a mantissa
|
||||
// of 0x00, which is Inf, the next higher value to the unrounded value.
|
||||
bool flag0 = ~u.int32 & 0x7f800000;
|
||||
|
||||
// When all of the exponent bits are 1, the value is Inf or NaN.
|
||||
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
|
||||
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
|
||||
// bit being 1. Signaling NaN is indicated by the most significant
|
||||
// mantissa bit being 0 but some other bit(s) being 1. If any of the
|
||||
// lower 16 bits of the mantissa are 1, we set the least significant bit
|
||||
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
|
||||
// the bfloat16's mantissa bits are all 0.
|
||||
bool flag1 = !flag0 && (u.int32 & 0xffff);
|
||||
|
||||
u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
|
||||
u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN
|
||||
|
||||
return uint16_t(u.int32 >> 16);
|
||||
}
|
||||
|
||||
// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
|
||||
template <>
|
||||
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
|
||||
{
|
||||
float x_fp32 = static_cast<float>(x);
|
||||
|
||||
return bf16_convert_rtn<bhalf_t>(x_fp32);
|
||||
}
|
||||
|
||||
// Declare a template function for fp8 conversion using SR
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr Y f8_convert_sr(X x);
|
||||
|
||||
// convert fp32 to fp8 with stochastic rounding
|
||||
template <>
|
||||
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
|
||||
{
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
|
||||
constexpr int seed = 42;
|
||||
// as thread id is not available on host, use 0 for prn generation
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
return utils::cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
}
|
||||
|
||||
// convert fp16 to fp8 with stochastic rounding
|
||||
template <>
|
||||
inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
|
||||
{
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
|
||||
constexpr int seed = 42;
|
||||
// as thread id is not available on host, use 0 for prn generation
|
||||
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
return utils::cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
@@ -13,6 +13,7 @@
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/span.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/ranges.hpp"
|
||||
|
||||
@@ -2,3 +2,6 @@ if (USE_BITINT_EXTENSION_INT4)
|
||||
add_gtest_executable(test_int4 int4.cpp)
|
||||
target_link_libraries(test_int4 PRIVATE utility)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_fp8 fp8.cpp)
|
||||
target_link_libraries(test_fp8 PRIVATE utility)
|
||||
|
||||
123
test/data_type/fp8.cpp
Normal file
123
test/data_type/fp8.cpp
Normal file
@@ -0,0 +1,123 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
|
||||
using ck::f8_convert_sr;
|
||||
using ck::f8_t;
|
||||
using ck::half_t;
|
||||
using ck::type_convert;
|
||||
|
||||
// Sanity-check the fp8 numeric-limit bit patterns exposed by ck::NumericLimits.
// NOTE(review): expected values assume the negative-zero-is-NaN fp8 encoding
// used elsewhere in this commit — Max (0x77) corresponds to 240.0 and
// QuietNaN is the 0x80 bit pattern, per the conversion tests below; confirm
// Min (0x08) and Lowest (0xF7) against the f8_t definition in data_type.hpp.
TEST(FP8, NumericLimits)
{
    EXPECT_EQ(ck::NumericLimits<f8_t>::Min(), 0x08);
    EXPECT_EQ(ck::NumericLimits<f8_t>::Max(), 0x77);
    EXPECT_EQ(ck::NumericLimits<f8_t>::Lowest(), 0xF7);
    EXPECT_EQ(ck::NumericLimits<f8_t>::QuietNaN(), 0x80);
}
|
||||
|
||||
// Round-trip and boundary checks for fp32 <-> fp8 conversion using the
// default (nearest) rounding mode of type_convert.
TEST(FP8, ConvertFP32Nearest)
{
    // Tolerance for all round-trip comparisons.
    const float tolerance = 1e-6;

    // Zero must survive a round trip.
    ASSERT_NEAR(0.0f, type_convert<float>(type_convert<f8_t>(0.0f)), tolerance);

    // The smallest normal float round-trips within tolerance.
    ASSERT_NEAR(std::numeric_limits<float>::min(),
                type_convert<float>(type_convert<f8_t>(std::numeric_limits<float>::min())),
                tolerance);

    // 240.0 is the largest fp8-representable value and must round-trip unchanged.
    ASSERT_NEAR(240.0f, type_convert<float>(type_convert<f8_t>(240.0f)), tolerance);

    // The largest float is clipped to the fp8 maximum, 240.0.
    ASSERT_NEAR(240.0f,
                type_convert<float>(type_convert<f8_t>(std::numeric_limits<float>::max())),
                tolerance);

    // Infinity maps to the fp8 quiet-NaN bit pattern (0x80).
    ASSERT_NEAR(0x80, type_convert<f8_t>(std::numeric_limits<float>::infinity()), tolerance);

    // Exactly fp8-representable positive and negative values round-trip.
    const float positive_value = 0.0078125f;
    ASSERT_NEAR(positive_value, type_convert<float>(type_convert<f8_t>(positive_value)), tolerance);

    const float negative_value = -0.0156250f;
    ASSERT_NEAR(negative_value, type_convert<float>(type_convert<f8_t>(negative_value)), tolerance);
}
|
||||
|
||||
// Round-trip and boundary checks for fp32 -> fp8 conversion with stochastic
// rounding (f8_convert_sr); decoding back to fp32 uses type_convert.
TEST(FP8, ConvertFP32Stochastic)
{
    // Tolerance for all round-trip comparisons.
    const float tolerance = 1e-6;

    // Zero must survive a round trip.
    ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<f8_t>(0.0f)), tolerance);

    // The smallest normal float round-trips within tolerance.
    ASSERT_NEAR(std::numeric_limits<float>::min(),
                type_convert<float>(f8_convert_sr<f8_t>(std::numeric_limits<float>::min())),
                tolerance);

    // 240.0 is the largest fp8-representable value and must round-trip unchanged.
    ASSERT_NEAR(240.0f, type_convert<float>(f8_convert_sr<f8_t>(240.0f)), tolerance);

    // The largest float is clipped to the fp8 maximum, 240.0.
    ASSERT_NEAR(240.0f,
                type_convert<float>(f8_convert_sr<f8_t>(std::numeric_limits<float>::max())),
                tolerance);

    // Infinity maps to the fp8 quiet-NaN bit pattern (0x80).
    ASSERT_NEAR(0x80, f8_convert_sr<f8_t>(std::numeric_limits<float>::infinity()), tolerance);

    // Exactly fp8-representable positive and negative values round-trip.
    const float positive_value = 0.0078125f;
    ASSERT_NEAR(positive_value, type_convert<float>(f8_convert_sr<f8_t>(positive_value)), tolerance);

    const float negative_value = -0.0156250f;
    ASSERT_NEAR(negative_value, type_convert<float>(f8_convert_sr<f8_t>(negative_value)), tolerance);
}
|
||||
|
||||
// Round-trip and boundary checks for fp16 <-> fp8 conversion using the
// default (nearest) rounding mode of type_convert.
TEST(FP8, ConvertFP16Nearest)
{
    // Looser tolerance than the fp32 tests: fp16 itself has ~1e-3 precision.
    const float tolerance = 1e-3;

    // Zero must survive a round trip.
    ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(type_convert<f8_t>(half_t{0.0})), tolerance);

    // The smallest fp16 value round-trips within tolerance.
    ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
                type_convert<half_t>(type_convert<f8_t>(ck::NumericLimits<half_t>::Min())),
                tolerance);

    // 240.0 is the largest fp8-representable value and must round-trip unchanged.
    ASSERT_NEAR(half_t{240.0}, type_convert<half_t>(type_convert<f8_t>(half_t{240.0})), tolerance);

    // The largest fp16 value is clipped to the fp8 maximum, 240.0.
    ASSERT_NEAR(half_t{240.0},
                type_convert<half_t>(type_convert<f8_t>(ck::NumericLimits<half_t>::Max())),
                tolerance);

    // fp16 quiet NaN maps to the fp8 quiet-NaN bit pattern (0x80).
    ASSERT_NEAR(0x80, type_convert<f8_t>(ck::NumericLimits<half_t>::QuietNaN()), tolerance);

    // Exactly fp8-representable positive and negative values round-trip.
    const half_t positive_value = half_t{0.0078125};
    ASSERT_NEAR(positive_value, type_convert<half_t>(type_convert<f8_t>(positive_value)), tolerance);

    const half_t negative_value = half_t{-0.0156250};
    ASSERT_NEAR(negative_value, type_convert<half_t>(type_convert<f8_t>(negative_value)), tolerance);
}
|
||||
|
||||
// Round-trip and boundary checks for fp16 -> fp8 conversion with stochastic
// rounding (f8_convert_sr); decoding back to fp16 uses type_convert.
TEST(FP8, ConvertFP16Stochastic)
{
    // Looser tolerance than the fp32 tests: fp16 itself has ~1e-3 precision.
    const float tolerance = 1e-3;

    // Zero must survive a round trip.
    ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_sr<f8_t>(half_t{0.0})), tolerance);

    // The smallest fp16 value round-trips within tolerance.
    ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
                type_convert<half_t>(f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::Min())),
                tolerance);

    // 240.0 is the largest fp8-representable value and must round-trip unchanged.
    ASSERT_NEAR(half_t{240.0}, type_convert<half_t>(f8_convert_sr<f8_t>(half_t{240.0})), tolerance);

    // The largest fp16 value is clipped to the fp8 maximum, 240.0.
    ASSERT_NEAR(half_t{240.0},
                type_convert<half_t>(f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::Max())),
                tolerance);

    // fp16 quiet NaN maps to the fp8 quiet-NaN bit pattern (0x80).
    ASSERT_NEAR(0x80, f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::QuietNaN()), tolerance);

    // Exactly fp8-representable positive and negative values round-trip.
    const half_t positive_value = half_t{0.0078125};
    ASSERT_NEAR(positive_value, type_convert<half_t>(f8_convert_sr<f8_t>(positive_value)), tolerance);

    const half_t negative_value = half_t{-0.0156250};
    ASSERT_NEAR(negative_value, type_convert<half_t>(f8_convert_sr<f8_t>(negative_value)), tolerance);
}
|
||||
Reference in New Issue
Block a user