mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
Switch default f8 conversion to stochastic rounding (#1048)
* Switch default f8 conversion to stochastic rounding
* Refactor f8-related type_converts
* Add an element-wise op
[ROCm/composable_kernel commit: 6ef034f6ca]
This commit is contained in:
@@ -134,6 +134,9 @@
|
||||
// inner product using V_DOT with DPP8 modifiers
|
||||
#define CK_USE_AMD_V_DOT_DPP8_INLINE_ASM 1
|
||||
|
||||
// set stochastic rounding as default for f8 conversions
|
||||
#define CK_USE_SR_F8_CONVERSION 1
|
||||
|
||||
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
|
||||
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
|
||||
|
||||
|
||||
@@ -281,6 +281,24 @@ struct ConvertF8SR
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertF8RNE
|
||||
{
|
||||
// convert to fp8 using rounding to nearest even
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const
|
||||
{
|
||||
// check Y datatype
|
||||
static_assert(is_same<Y, f8_t>::value || is_same<Y, bf8_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
// check X datatype
|
||||
static_assert(is_same<X, float>::value || is_same<X, half_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = f8_convert_rne<Y>(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Scale
|
||||
{
|
||||
__host__ __device__ Scale(float scale) : scale_(scale) {}
|
||||
|
||||
@@ -95,243 +95,6 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
|
||||
return type_convert<bhalf_t>(x_fp32);
|
||||
}
|
||||
|
||||
// convert fp32 to fp8
|
||||
template <>
|
||||
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
float max_fp8 = 240.0f;
|
||||
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
|
||||
union
|
||||
{
|
||||
float fval;
|
||||
uint32_t i32val;
|
||||
uint8_t i8val[4]; // not endian independent
|
||||
} val;
|
||||
val.fval = x;
|
||||
uint32_t ival = 0;
|
||||
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
|
||||
val.i32val = ival;
|
||||
return val.i8val[0];
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
|
||||
constexpr uint32_t rng = 0;
|
||||
return utils::
|
||||
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
|
||||
rng);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp8 to fp32
|
||||
template <>
|
||||
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
float fval;
|
||||
uint32_t i32val = static_cast<uint32_t>(x);
|
||||
fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
|
||||
// asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
|
||||
return fval;
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return utils::cast_from_f8<f8_t, float, negative_zero_nan>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
const auto i16val = bit_cast<uint16_t>(x);
|
||||
return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0);
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
const auto f8x2_v = vector_type<f8_t, 2>(x);
|
||||
vector_type<float, 2> f32x2_v;
|
||||
f32x2_v.template AsType<float>()(Number<0>{}) =
|
||||
utils::cast_from_f8<f8_t, float, negative_zero_nan>(
|
||||
f8x2_v.template AsType<f8_t>()[Number<0>{}]);
|
||||
f32x2_v.template AsType<float>()(Number<1>{}) =
|
||||
utils::cast_from_f8<f8_t, float, negative_zero_nan>(
|
||||
f8x2_v.template AsType<f8_t>()[Number<1>{}]);
|
||||
return f32x2_v.template AsType<float2_t>()[Number<0>{}];
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
|
||||
{
|
||||
|
||||
const vector_type<float, 2> f32x2_v(x);
|
||||
const auto y = __builtin_amdgcn_cvt_pkrtz(f32x2_v.template AsType<float>()[Number<0>{}],
|
||||
f32x2_v.template AsType<float>()[Number<1>{}]);
|
||||
return bit_cast<half2_t>(y);
|
||||
}
|
||||
|
||||
// convert fp16 to fp8
|
||||
template <>
|
||||
inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
// convert to float and use native converion
|
||||
return type_convert<f8_t>(type_convert<float>(x));
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
|
||||
constexpr uint32_t rng = 0;
|
||||
return utils::
|
||||
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp8 to fp16
|
||||
template <>
|
||||
inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
// use native conversion to float and convert to fp16
|
||||
return type_convert<half_t>(type_convert<float>(x));
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp32 to bf8
|
||||
template <>
|
||||
inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
union
|
||||
{
|
||||
float fval;
|
||||
uint32_t i32val;
|
||||
uint8_t i8val[4]; // not endian independent
|
||||
} val;
|
||||
val.fval = x;
|
||||
uint32_t ival = 0;
|
||||
ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
|
||||
val.i32val = ival;
|
||||
return val.i8val[0];
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
|
||||
constexpr uint32_t rng = 0;
|
||||
return utils::
|
||||
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert bf8 to fp32
|
||||
template <>
|
||||
inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
float fval;
|
||||
uint32_t i32val = static_cast<uint32_t>(x);
|
||||
fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
|
||||
// asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
|
||||
return fval;
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return utils::cast_from_f8<bf8_t, float, negative_zero_nan>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp16 to bf8
|
||||
template <>
|
||||
inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
// convert to float and use native converion
|
||||
return type_convert<bf8_t>(type_convert<float>(x));
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
|
||||
constexpr uint32_t rng = 0;
|
||||
return utils::
|
||||
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert bf8 to fp16
|
||||
template <>
|
||||
inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
// use native conversion to float and convert to fp16
|
||||
return type_convert<half_t>(type_convert<float>(x));
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Declare a template function for bf16 conversion using RTN
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
|
||||
|
||||
// Convert fp32 to bf16 with RTN if higher precision is needed
|
||||
template <>
|
||||
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
|
||||
{
|
||||
union
|
||||
{
|
||||
float fp32;
|
||||
uint32_t int32;
|
||||
} u = {x};
|
||||
|
||||
// When the exponent bits are not all 1s, then the value is zero, normal,
|
||||
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
|
||||
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
|
||||
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
|
||||
// least significant bits of the float mantissa are greater than 0x8000,
|
||||
// or if they are equal to 0x8000 and the least significant bit of the
|
||||
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
|
||||
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
|
||||
// has the value 0x7f, then incrementing it causes it to become 0x00 and
|
||||
// the exponent is incremented by one, which is the next higher FP value
|
||||
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
|
||||
// with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
|
||||
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
|
||||
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
|
||||
// incrementing it causes it to become an exponent of 0xFF and a mantissa
|
||||
// of 0x00, which is Inf, the next higher value to the unrounded value.
|
||||
bool flag0 = ~u.int32 & 0x7f800000;
|
||||
|
||||
// When all of the exponent bits are 1, the value is Inf or NaN.
|
||||
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
|
||||
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
|
||||
// bit being 1. Signaling NaN is indicated by the most significant
|
||||
// mantissa bit being 0 but some other bit(s) being 1. If any of the
|
||||
// lower 16 bits of the mantissa are 1, we set the least significant bit
|
||||
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
|
||||
// the bfloat16's mantissa bits are all 0.
|
||||
bool flag1 = !flag0 && (u.int32 & 0xffff);
|
||||
|
||||
u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
|
||||
u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN
|
||||
|
||||
return uint16_t(u.int32 >> 16);
|
||||
}
|
||||
|
||||
// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
|
||||
template <>
|
||||
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
|
||||
{
|
||||
float x_fp32 = static_cast<float>(x);
|
||||
|
||||
return bf16_convert_rtn<bhalf_t>(x_fp32);
|
||||
}
|
||||
|
||||
// Declare a template function for fp8 conversion using SR
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr Y f8_convert_sr(X x);
|
||||
@@ -343,6 +106,8 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
|
||||
constexpr int seed = 42;
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
float max_fp8 = 240.0f;
|
||||
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
|
||||
union
|
||||
{
|
||||
float fval;
|
||||
@@ -423,7 +188,6 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
|
||||
constexpr int seed = 42;
|
||||
// as thread id is not available on host, use 0 for prn generation
|
||||
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
return utils::
|
||||
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
@@ -431,4 +195,288 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
|
||||
#endif
|
||||
}
|
||||
|
||||
// Declare a template function for fp8 conversion using RNE
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr Y f8_convert_rne(X x);
|
||||
|
||||
// convert fp32 to fp8 with rounding to nearest even
|
||||
template <>
|
||||
inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
float max_fp8 = 240.0f;
|
||||
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
|
||||
union
|
||||
{
|
||||
float fval;
|
||||
uint32_t i32val;
|
||||
uint8_t i8val[4]; // not endian independent
|
||||
} val;
|
||||
val.fval = x;
|
||||
uint32_t ival = 0;
|
||||
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
|
||||
val.i32val = ival;
|
||||
return val.i8val[0];
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
|
||||
constexpr uint32_t rng = 0;
|
||||
return utils::
|
||||
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
|
||||
rng);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp16 to fp8 with rounding to nearest even
|
||||
template <>
|
||||
inline __host__ __device__ f8_t f8_convert_rne<f8_t, half_t>(half_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
// convert to float and use native converion
|
||||
return f8_convert_rne<f8_t>(type_convert<float>(x));
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
|
||||
constexpr uint32_t rng = 0;
|
||||
return utils::
|
||||
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp32 to bf8 with rounding to nearest even
|
||||
template <>
|
||||
inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
union
|
||||
{
|
||||
float fval;
|
||||
uint32_t i32val;
|
||||
uint8_t i8val[4]; // not endian independent
|
||||
} val;
|
||||
val.fval = x;
|
||||
uint32_t ival = 0;
|
||||
ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
|
||||
val.i32val = ival;
|
||||
return val.i8val[0];
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
|
||||
constexpr uint32_t rng = 0;
|
||||
return utils::
|
||||
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp16 to bf8 with rounding to nearest even
|
||||
template <>
|
||||
inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, half_t>(half_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
// convert to float and use native converion
|
||||
return f8_convert_rne<bf8_t>(type_convert<float>(x));
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
|
||||
constexpr uint32_t rng = 0;
|
||||
return utils::
|
||||
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp32 to fp8
|
||||
template <>
|
||||
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
|
||||
{
|
||||
#if defined CK_USE_SR_F8_CONVERSION
|
||||
return f8_convert_sr<f8_t>(x);
|
||||
#else
|
||||
return f8_convert_rne<f8_t>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp8 to fp32
|
||||
template <>
|
||||
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
float fval;
|
||||
uint32_t i32val = static_cast<uint32_t>(x);
|
||||
fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
|
||||
// asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
|
||||
return fval;
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return utils::cast_from_f8<f8_t, float, negative_zero_nan>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
const auto i16val = bit_cast<uint16_t>(x);
|
||||
return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0);
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
const auto f8x2_v = vector_type<f8_t, 2>(x);
|
||||
vector_type<float, 2> f32x2_v;
|
||||
f32x2_v.template AsType<float>()(Number<0>{}) =
|
||||
utils::cast_from_f8<f8_t, float, negative_zero_nan>(
|
||||
f8x2_v.template AsType<f8_t>()[Number<0>{}]);
|
||||
f32x2_v.template AsType<float>()(Number<1>{}) =
|
||||
utils::cast_from_f8<f8_t, float, negative_zero_nan>(
|
||||
f8x2_v.template AsType<f8_t>()[Number<1>{}]);
|
||||
return f32x2_v.template AsType<float2_t>()[Number<0>{}];
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
|
||||
{
|
||||
|
||||
const vector_type<float, 2> f32x2_v(x);
|
||||
const auto y = __builtin_amdgcn_cvt_pkrtz(f32x2_v.template AsType<float>()[Number<0>{}],
|
||||
f32x2_v.template AsType<float>()[Number<1>{}]);
|
||||
return bit_cast<half2_t>(y);
|
||||
}
|
||||
|
||||
// convert fp16 to fp8
|
||||
template <>
|
||||
inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
|
||||
{
|
||||
#if defined CK_USE_SR_F8_CONVERSION
|
||||
return f8_convert_sr<f8_t>(x);
|
||||
#else
|
||||
return f8_convert_nre<f8_t>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp8 to fp16
|
||||
template <>
|
||||
inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
// use native conversion to float and convert to fp16
|
||||
return type_convert<half_t>(type_convert<float>(x));
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp32 to bf8
|
||||
template <>
|
||||
inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
|
||||
{
|
||||
#if defined CK_USE_SR_F8_CONVERSION
|
||||
return f8_convert_sr<bf8_t>(x);
|
||||
#else
|
||||
return f8_convert_rne<bf8_t>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert bf8 to fp32
|
||||
template <>
|
||||
inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
float fval;
|
||||
uint32_t i32val = static_cast<uint32_t>(x);
|
||||
fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
|
||||
// asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
|
||||
return fval;
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return utils::cast_from_f8<bf8_t, float, negative_zero_nan>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp16 to bf8
|
||||
template <>
|
||||
inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
|
||||
{
|
||||
#if defined CK_USE_SR_F8_CONVERSION
|
||||
return f8_convert_sr<bf8_t>(x);
|
||||
#else
|
||||
return f8_convert_rne<bf8_t>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert bf8 to fp16
|
||||
template <>
|
||||
inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
// use native conversion to float and convert to fp16
|
||||
return type_convert<half_t>(type_convert<float>(x));
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Declare a template function for bf16 conversion using RTN
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
|
||||
|
||||
// Convert fp32 to bf16 with RTN if higher precision is needed
|
||||
template <>
|
||||
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
|
||||
{
|
||||
union
|
||||
{
|
||||
float fp32;
|
||||
uint32_t int32;
|
||||
} u = {x};
|
||||
|
||||
// When the exponent bits are not all 1s, then the value is zero, normal,
|
||||
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
|
||||
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
|
||||
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
|
||||
// least significant bits of the float mantissa are greater than 0x8000,
|
||||
// or if they are equal to 0x8000 and the least significant bit of the
|
||||
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
|
||||
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
|
||||
// has the value 0x7f, then incrementing it causes it to become 0x00 and
|
||||
// the exponent is incremented by one, which is the next higher FP value
|
||||
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
|
||||
// with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
|
||||
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
|
||||
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
|
||||
// incrementing it causes it to become an exponent of 0xFF and a mantissa
|
||||
// of 0x00, which is Inf, the next higher value to the unrounded value.
|
||||
bool flag0 = ~u.int32 & 0x7f800000;
|
||||
|
||||
// When all of the exponent bits are 1, the value is Inf or NaN.
|
||||
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
|
||||
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
|
||||
// bit being 1. Signaling NaN is indicated by the most significant
|
||||
// mantissa bit being 0 but some other bit(s) being 1. If any of the
|
||||
// lower 16 bits of the mantissa are 1, we set the least significant bit
|
||||
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
|
||||
// the bfloat16's mantissa bits are all 0.
|
||||
bool flag1 = !flag0 && (u.int32 & 0xffff);
|
||||
|
||||
u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
|
||||
u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN
|
||||
|
||||
return uint16_t(u.int32 >> 16);
|
||||
}
|
||||
|
||||
// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
|
||||
template <>
|
||||
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
|
||||
{
|
||||
float x_fp32 = static_cast<float>(x);
|
||||
|
||||
return bf16_convert_rtn<bhalf_t>(x_fp32);
|
||||
}
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user