mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Add native conversions fp8<->fp32 (#908)
* Add native conversions
* Add bf8 conversions
[ROCm/composable_kernel commit: f17af2e9ed]
This commit is contained in:
@@ -5,6 +5,8 @@
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
// these conversions are disabled if native conversions available
|
||||
#if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
|
||||
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
|
||||
namespace ck {
|
||||
|
||||
@@ -242,4 +244,5 @@ __host__ __device__ Y cast_from_f8(X x)
|
||||
}
|
||||
|
||||
} // namespace ck::utils
|
||||
#endif
|
||||
#endif // #if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
|
||||
#endif // #if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
|
||||
|
||||
@@ -85,6 +85,19 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
|
||||
template <>
|
||||
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
union
|
||||
{
|
||||
float fval;
|
||||
uint32_t i32val;
|
||||
uint8_t i8val[4]; // not endian independent
|
||||
} val;
|
||||
val.fval = x;
|
||||
uint32_t ival = 0;
|
||||
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
|
||||
val.i32val = ival;
|
||||
return val.i8val[0];
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
|
||||
@@ -92,20 +105,33 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
|
||||
return utils::
|
||||
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
|
||||
rng);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp8 to fp32
|
||||
template <>
|
||||
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
float fval;
|
||||
uint32_t i32val = static_cast<uint32_t>(x);
|
||||
fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
|
||||
// asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
|
||||
return fval;
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return utils::cast_from_f8<f8_t, float, negative_zero_nan>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp16 to fp8
|
||||
template <>
|
||||
inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
// convert to float and use native converion
|
||||
return type_convert<f8_t>(type_convert<float>(x));
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
|
||||
@@ -113,14 +139,20 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
|
||||
return utils::
|
||||
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp8 to fp16
|
||||
template <>
|
||||
inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
// use native conversion to float and convert to fp16
|
||||
return type_convert<half_t>(type_convert<float>(x));
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -129,6 +161,19 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
|
||||
template <>
|
||||
inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
union
|
||||
{
|
||||
float fval;
|
||||
uint32_t i32val;
|
||||
uint8_t i8val[4]; // not endian independent
|
||||
} val;
|
||||
val.fval = x;
|
||||
uint32_t ival = 0;
|
||||
ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
|
||||
val.i32val = ival;
|
||||
return val.i8val[0];
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
|
||||
@@ -136,20 +181,33 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
|
||||
return utils::
|
||||
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert bf8 to fp32
|
||||
template <>
|
||||
inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
float fval;
|
||||
uint32_t i32val = static_cast<uint32_t>(x);
|
||||
fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
|
||||
// asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
|
||||
return fval;
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return utils::cast_from_f8<bf8_t, float, negative_zero_nan>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp16 to bf8
|
||||
template <>
|
||||
inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
// convert to float and use native converion
|
||||
return type_convert<f8_t>(type_convert<float>(x));
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
|
||||
@@ -157,14 +215,20 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
|
||||
return utils::
|
||||
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert bf8 to fp16
|
||||
template <>
|
||||
inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
// use native conversion to float and convert to fp16
|
||||
return type_convert<half_t>(type_convert<float>(x));
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x);
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -234,30 +298,47 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
|
||||
template <>
|
||||
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
|
||||
{
|
||||
constexpr int seed = 42;
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
union
|
||||
{
|
||||
float fval;
|
||||
uint32_t i32val;
|
||||
uint8_t i8val[4]; // not endian independent
|
||||
} val;
|
||||
val.fval = x;
|
||||
uint32_t ival = 0;
|
||||
ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
|
||||
val.i32val = ival;
|
||||
return val.i8val[0]; // little endian
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
|
||||
constexpr int seed = 42;
|
||||
// as thread id is not available on host, use 0 for prn generation
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
return utils::
|
||||
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
|
||||
rng);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp16 to fp8 with stochastic rounding
|
||||
template <>
|
||||
inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
// convert to float and use native converion
|
||||
return f8_convert_sr<f8_t>(type_convert<float>(x));
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
|
||||
constexpr int seed = 42;
|
||||
// as thread id is not available on host, use 0 for prn generation
|
||||
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
return utils::
|
||||
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -266,21 +347,38 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
|
||||
template <>
|
||||
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
|
||||
{
|
||||
constexpr int seed = 42;
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
union
|
||||
{
|
||||
float fval;
|
||||
uint32_t i32val;
|
||||
uint8_t i8val[4]; // not endian independent
|
||||
} val;
|
||||
val.fval = x;
|
||||
uint32_t ival = 0;
|
||||
ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
|
||||
val.i32val = ival;
|
||||
return val.i8val[0]; // little endian
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
|
||||
constexpr int seed = 42;
|
||||
// as thread id is not available on host, use 0 for prn generation
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
return utils::
|
||||
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp16 to bf8 with stochastic rounding
|
||||
template <>
|
||||
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
// convert to float and use native converion
|
||||
return f8_convert_sr<f8_t>(type_convert<float>(x));
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
|
||||
@@ -290,6 +388,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
|
||||
return utils::
|
||||
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
Reference in New Issue
Block a user