mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Add FP16/BF16<->FP8/BF8 conversions (#2035)
* Move conversion functions and add missing conversions
* Add tests
* Add missing conversions
* Add missing conversions
* Add bf8 tests
* Update clipping for vectors
* Add missing conversions
* Add bf16 fp8 tests
* Add bf16 bf8 tests
* Fix device conversion
* Fix conversions
* Fix vector use
* Minor fix
* Add a workaround flag
* Add a workaround flag for bf16 conversion
* Add another workaround
* Add a workaround for fp16 to bf8 conversion
* Update type alias
* Add docstrings and missing wrappers
* Fix if defined macros
* Fix more if defined macros
* Add comments
* Remove __host__ specifier
* Add a gfx950 guard
* Update function naming
[ROCm/composable_kernel commit: 265af71a71]
This commit is contained in:
@@ -248,6 +248,12 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
|
||||
// workaround: compiler issue on gfx950
|
||||
#define CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION 1
|
||||
|
||||
// workaround: compiler issue on gfx950
|
||||
#define CK_WORKAROUND_FP16_TO_FP8_CONVERSION 1
|
||||
|
||||
// workaround: compiler issue on gfx950
|
||||
#define CK_WORKAROUND_BF16_TO_FP8_CONVERSION 1
|
||||
|
||||
// denorm test fix, necessary for gfx90a
|
||||
#ifndef CK_GFX90A_DENORM_WORKAROUND
|
||||
#define CK_GFX90A_DENORM_WORKAROUND 0
|
||||
|
||||
@@ -64,6 +64,9 @@ enum class ck_saturation_t
|
||||
namespace fp8_impl {
|
||||
|
||||
typedef fp8_storage_t fp8x2_storage_t __attribute__((ext_vector_type(2)));
|
||||
typedef _Float16 half2_t __attribute__((ext_vector_type(2)));
|
||||
typedef ushort ushortx2_t __attribute__((ext_vector_type(2)));
|
||||
typedef short shortx2_t __attribute__((ext_vector_type(2)));
|
||||
typedef float float2_t __attribute__((ext_vector_type(2)));
|
||||
|
||||
__host__ __device__ static inline constexpr bool fnuz_f8_is_nan(f8_fnuz_t a)
|
||||
@@ -270,7 +273,7 @@ static __host__ __device__ float cast_to_f32_from_f8(fp8_storage_t v)
|
||||
}
|
||||
|
||||
template <ck_fp8_interpretation_t interpret>
|
||||
static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v)
|
||||
static __device__ float2_t cast_to_f32_from_f8(fp8x2_storage_t v)
|
||||
{
|
||||
const auto i16val = bit_cast<uint16_t>(v);
|
||||
|
||||
@@ -458,6 +461,510 @@ __is_interpret_supported([[maybe_unused]] ck_fp8_interpretation_t interp)
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <ck_fp8_interpretation_t interpret,
|
||||
bool saturate,
|
||||
bool stochastic_rounding = false,
|
||||
ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP, bool> = false,
|
||||
ck::enable_if_t<stochastic_rounding == true, bool> = false>
|
||||
static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0)
|
||||
{
|
||||
union
|
||||
{
|
||||
unsigned int i32val;
|
||||
half2_t half_vec;
|
||||
fp8_storage_t i8val[4];
|
||||
} val;
|
||||
|
||||
constexpr unsigned int i32val = 0;
|
||||
val.half_vec[0] = v;
|
||||
|
||||
if constexpr(saturate)
|
||||
{
|
||||
if((val.i32val & 0x7FFF) != 0x7FFF)
|
||||
{
|
||||
val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0);
|
||||
}
|
||||
}
|
||||
|
||||
val.i32val =
|
||||
__builtin_amdgcn_cvt_scalef32_sr_fp8_f16(i32val, val.half_vec[0], rng, /* scale */ 1.f, 0);
|
||||
|
||||
return val.i8val[0];
|
||||
}
|
||||
|
||||
template <ck_fp8_interpretation_t interpret,
|
||||
bool saturate,
|
||||
bool stochastic_rounding = false,
|
||||
ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP, bool> = false,
|
||||
ck::enable_if_t<stochastic_rounding == true, bool> = false>
|
||||
static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0)
|
||||
{
|
||||
// there is no packed conversion with SR, so convert one element at a time
|
||||
return fp8x2_storage_t{
|
||||
cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
|
||||
cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
|
||||
}
|
||||
|
||||
template <ck_fp8_interpretation_t interpret,
|
||||
bool saturate,
|
||||
bool stochastic_rounding = false,
|
||||
ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP, bool> = false,
|
||||
ck::enable_if_t<stochastic_rounding == true, bool> = false>
|
||||
static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0)
|
||||
{
|
||||
union
|
||||
{
|
||||
unsigned int i32val;
|
||||
half2_t half_vec;
|
||||
fp8_storage_t i8val[4];
|
||||
} val;
|
||||
|
||||
constexpr unsigned int i32val = 0;
|
||||
val.half_vec[0] = v;
|
||||
|
||||
if constexpr(saturate)
|
||||
{
|
||||
if((val.i32val & 0x7FFF) != 0x7FFF)
|
||||
{
|
||||
val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0);
|
||||
}
|
||||
}
|
||||
|
||||
val.i32val =
|
||||
__builtin_amdgcn_cvt_scalef32_sr_bf8_f16(i32val, val.half_vec[0], rng, /* scale */ 1.f, 0);
|
||||
|
||||
return val.i8val[0];
|
||||
}
|
||||
|
||||
template <ck_fp8_interpretation_t interpret,
|
||||
bool saturate,
|
||||
bool stochastic_rounding = false,
|
||||
ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP, bool> = false,
|
||||
ck::enable_if_t<stochastic_rounding == true, bool> = false>
|
||||
static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0)
|
||||
{
|
||||
// there is no packed conversion with SR, so convert one element at a time
|
||||
return fp8x2_storage_t{
|
||||
cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
|
||||
cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
|
||||
}
|
||||
|
||||
template <ck_fp8_interpretation_t interpret,
|
||||
bool saturate,
|
||||
bool stochastic_rounding = false,
|
||||
ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP, bool> = false,
|
||||
ck::enable_if_t<stochastic_rounding == false, bool> = false>
|
||||
static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0)
|
||||
{
|
||||
std::ignore = rng;
|
||||
|
||||
union
|
||||
{
|
||||
unsigned int i32val;
|
||||
half2_t half_vec;
|
||||
shortx2_t i16_vec;
|
||||
fp8_storage_t i8val[4];
|
||||
} val;
|
||||
|
||||
constexpr shortx2_t i16x2val = {0, 0};
|
||||
val.half_vec[0] = v;
|
||||
|
||||
if constexpr(saturate)
|
||||
{
|
||||
if((val.i32val & 0x7FFF) != 0x7FFF)
|
||||
{
|
||||
val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0);
|
||||
}
|
||||
}
|
||||
|
||||
val.i16_vec =
|
||||
__builtin_amdgcn_cvt_scalef32_pk_fp8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0);
|
||||
|
||||
return val.i8val[0];
|
||||
}
|
||||
|
||||
template <ck_fp8_interpretation_t interpret,
|
||||
bool saturate,
|
||||
bool stochastic_rounding = false,
|
||||
ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP, bool> = false,
|
||||
ck::enable_if_t<stochastic_rounding == false, bool> = false>
|
||||
static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0)
|
||||
{
|
||||
#if CK_WORKAROUND_FP16_TO_FP8_CONVERSION
|
||||
return fp8x2_storage_t{
|
||||
cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
|
||||
cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
|
||||
#else
|
||||
std::ignore = rng;
|
||||
|
||||
union
|
||||
{
|
||||
half2_t half_vec;
|
||||
shortx2_t i16_vec;
|
||||
fp8_storage_t i8val[4];
|
||||
} val;
|
||||
|
||||
constexpr shortx2_t i16x2val = {0, 0};
|
||||
val.half_vec = v;
|
||||
|
||||
if constexpr(saturate)
|
||||
{
|
||||
if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
|
||||
{
|
||||
val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0);
|
||||
}
|
||||
if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
|
||||
{
|
||||
val.half_vec[1] = __builtin_amdgcn_fmed3h(val.half_vec[1], 448.0, -448.0);
|
||||
}
|
||||
}
|
||||
|
||||
val.i16_vec =
|
||||
__builtin_amdgcn_cvt_scalef32_pk_fp8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0);
|
||||
|
||||
return fp8x2_storage_t{val.i8val[0], val.i8val[1]};
|
||||
#endif
|
||||
}
|
||||
|
||||
template <ck_fp8_interpretation_t interpret,
|
||||
bool saturate,
|
||||
bool stochastic_rounding = false,
|
||||
ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP, bool> = false,
|
||||
ck::enable_if_t<stochastic_rounding == false, bool> = false>
|
||||
static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0)
|
||||
{
|
||||
std::ignore = rng;
|
||||
|
||||
union
|
||||
{
|
||||
unsigned int i32val;
|
||||
half2_t half_vec;
|
||||
shortx2_t i16_vec;
|
||||
fp8_storage_t i8val[4];
|
||||
} val;
|
||||
|
||||
constexpr shortx2_t i16x2val = {0, 0};
|
||||
val.half_vec[0] = v;
|
||||
|
||||
if constexpr(saturate)
|
||||
{
|
||||
if((val.i32val & 0x7FFF) != 0x7FFF)
|
||||
{
|
||||
val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0);
|
||||
}
|
||||
}
|
||||
|
||||
val.half_vec =
|
||||
__builtin_amdgcn_cvt_scalef32_pk_bf8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0);
|
||||
|
||||
return val.i8val[0];
|
||||
}
|
||||
|
||||
template <ck_fp8_interpretation_t interpret,
|
||||
bool saturate,
|
||||
bool stochastic_rounding = false,
|
||||
ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP, bool> = false,
|
||||
ck::enable_if_t<stochastic_rounding == false, bool> = false>
|
||||
static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0)
|
||||
{
|
||||
#if CK_WORKAROUND_FP16_TO_FP8_CONVERSION
|
||||
return fp8x2_storage_t{
|
||||
cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
|
||||
cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
|
||||
#else
|
||||
std::ignore = rng;
|
||||
|
||||
union
|
||||
{
|
||||
half2_t half_vec;
|
||||
shortx2_t i16_vec;
|
||||
fp8_storage_t i8val[4];
|
||||
} val;
|
||||
|
||||
constexpr shortx2_t i16x2val = {0, 0};
|
||||
val.half_vec = v;
|
||||
|
||||
if constexpr(saturate)
|
||||
{
|
||||
if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
|
||||
{
|
||||
val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0);
|
||||
}
|
||||
if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
|
||||
{
|
||||
val.half_vec[1] = __builtin_amdgcn_fmed3h(val.half_vec[1], 57344.0, -57344.0);
|
||||
}
|
||||
}
|
||||
|
||||
val.i16_vec =
|
||||
__builtin_amdgcn_cvt_scalef32_pk_bf8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0);
|
||||
|
||||
return fp8x2_storage_t{val.i8val[0], val.i8val[1]};
|
||||
#endif
|
||||
}
|
||||
|
||||
template <ck_fp8_interpretation_t interpret,
|
||||
bool saturate,
|
||||
bool stochastic_rounding = false,
|
||||
ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP, bool> = false,
|
||||
ck::enable_if_t<stochastic_rounding == true, bool> = false>
|
||||
static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0)
|
||||
{
|
||||
union
|
||||
{
|
||||
unsigned int i32val;
|
||||
ushortx2_t bhalf_vec;
|
||||
fp8_storage_t i8val[4];
|
||||
} val;
|
||||
|
||||
constexpr unsigned int i32val = 0;
|
||||
val.bhalf_vec[0] = v;
|
||||
|
||||
if constexpr(saturate)
|
||||
{
|
||||
if((val.i32val & 0x7FFF) != 0x7FFF)
|
||||
{
|
||||
val.bhalf_vec[0] =
|
||||
ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
|
||||
bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >>
|
||||
16)); // convert to float and back
|
||||
}
|
||||
}
|
||||
|
||||
val.i32val = __builtin_amdgcn_cvt_scalef32_sr_fp8_bf16(
|
||||
i32val, val.bhalf_vec[0], rng, /* scale */ 1.f, 0);
|
||||
|
||||
return val.i8val[0];
|
||||
}
|
||||
|
||||
template <ck_fp8_interpretation_t interpret,
|
||||
bool saturate,
|
||||
bool stochastic_rounding = false,
|
||||
ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP, bool> = false,
|
||||
ck::enable_if_t<stochastic_rounding == true, bool> = false>
|
||||
static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0)
|
||||
{
|
||||
// there is no packed conversion with SR, so convert one element at a time
|
||||
return fp8x2_storage_t{
|
||||
cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[0], rng),
|
||||
cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[1], rng)};
|
||||
}
|
||||
|
||||
template <ck_fp8_interpretation_t interpret,
|
||||
bool saturate,
|
||||
bool stochastic_rounding = false,
|
||||
ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP, bool> = false,
|
||||
ck::enable_if_t<stochastic_rounding == true, bool> = false>
|
||||
static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0)
|
||||
{
|
||||
union
|
||||
{
|
||||
unsigned int i32val;
|
||||
ushortx2_t bhalf_vec;
|
||||
fp8_storage_t i8val[4];
|
||||
} val;
|
||||
|
||||
constexpr unsigned int i32val = 0;
|
||||
val.bhalf_vec[0] = v;
|
||||
|
||||
if constexpr(saturate)
|
||||
{
|
||||
if((val.i32val & 0x7FFF) != 0x7FFF)
|
||||
{
|
||||
val.bhalf_vec[0] = ushort(
|
||||
(bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
|
||||
bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >>
|
||||
16)); // convert to float and back
|
||||
}
|
||||
}
|
||||
|
||||
val.i32val = __builtin_amdgcn_cvt_scalef32_sr_bf8_bf16(
|
||||
i32val, val.bhalf_vec[0], rng, /* scale */ 1.f, 0);
|
||||
|
||||
return val.i8val[0];
|
||||
}
|
||||
|
||||
template <ck_fp8_interpretation_t interpret,
|
||||
bool saturate,
|
||||
bool stochastic_rounding = false,
|
||||
ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP, bool> = false,
|
||||
ck::enable_if_t<stochastic_rounding == true, bool> = false>
|
||||
static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0)
|
||||
{
|
||||
// there is no packed conversion with SR, so convert one element at a time
|
||||
return fp8x2_storage_t{
|
||||
cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[0], rng),
|
||||
cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[1], rng)};
|
||||
}
|
||||
|
||||
template <ck_fp8_interpretation_t interpret,
|
||||
bool saturate,
|
||||
bool stochastic_rounding = false,
|
||||
ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP, bool> = false,
|
||||
ck::enable_if_t<stochastic_rounding == false, bool> = false>
|
||||
static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0)
|
||||
{
|
||||
std::ignore = rng;
|
||||
|
||||
union
|
||||
{
|
||||
unsigned int i32val;
|
||||
ushortx2_t bhalf_vec;
|
||||
shortx2_t i16_vec;
|
||||
fp8_storage_t i8val[4];
|
||||
} val;
|
||||
|
||||
constexpr shortx2_t i16x2val = {0, 0};
|
||||
val.bhalf_vec[0] = v;
|
||||
|
||||
if constexpr(saturate)
|
||||
{
|
||||
if((val.i32val & 0x7FFF) != 0x7FFF)
|
||||
{
|
||||
val.bhalf_vec[0] =
|
||||
ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
|
||||
bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >>
|
||||
16)); // convert to float and back
|
||||
}
|
||||
}
|
||||
|
||||
val.i16_vec =
|
||||
__builtin_amdgcn_cvt_scalef32_pk_fp8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0);
|
||||
|
||||
return val.i8val[0];
|
||||
}
|
||||
|
||||
template <ck_fp8_interpretation_t interpret,
|
||||
bool saturate,
|
||||
bool stochastic_rounding = false,
|
||||
ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP, bool> = false,
|
||||
ck::enable_if_t<stochastic_rounding == false, bool> = false>
|
||||
static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0)
|
||||
{
|
||||
#if CK_WORKAROUND_BF16_TO_FP8_CONVERSION
|
||||
return fp8x2_storage_t{
|
||||
cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[0], rng),
|
||||
cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[1], rng)};
|
||||
#else
|
||||
std::ignore = rng;
|
||||
|
||||
union
|
||||
{
|
||||
ushortx2_t bhalf_vec;
|
||||
shortx2_t i16_vec;
|
||||
fp8_storage_t i8val[4];
|
||||
} val;
|
||||
|
||||
constexpr shortx2_t i16x2val = {0, 0};
|
||||
val.bhalf_vec = v;
|
||||
|
||||
if constexpr(saturate)
|
||||
{
|
||||
if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
|
||||
{
|
||||
val.bhalf_vec[0] =
|
||||
ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
|
||||
bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >>
|
||||
16)); // convert to float and back
|
||||
}
|
||||
if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
|
||||
{
|
||||
val.bhalf_vec[1] =
|
||||
ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
|
||||
bit_cast<float>(uint32_t{val.bhalf_vec[1]} << 16), 448.0, -448.0)) >>
|
||||
16)); // convert to float and back
|
||||
}
|
||||
}
|
||||
|
||||
val.i16_vec =
|
||||
__builtin_amdgcn_cvt_scalef32_pk_fp8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0);
|
||||
|
||||
return fp8x2_storage_t{val.i8val[0], val.i8val[1]};
|
||||
#endif
|
||||
}
|
||||
|
||||
template <ck_fp8_interpretation_t interpret,
|
||||
bool saturate,
|
||||
bool stochastic_rounding = false,
|
||||
ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP, bool> = false,
|
||||
ck::enable_if_t<stochastic_rounding == false, bool> = false>
|
||||
static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0)
|
||||
{
|
||||
std::ignore = rng;
|
||||
|
||||
union
|
||||
{
|
||||
unsigned int i32val;
|
||||
ushortx2_t bhalf_vec;
|
||||
shortx2_t i16_vec;
|
||||
fp8_storage_t i8val[4];
|
||||
} val;
|
||||
|
||||
constexpr shortx2_t i16x2val = {0, 0};
|
||||
val.bhalf_vec[0] = v;
|
||||
|
||||
if constexpr(saturate)
|
||||
{
|
||||
if((val.i32val & 0x7FFF) != 0x7FFF)
|
||||
{
|
||||
val.bhalf_vec[0] = ushort(
|
||||
(bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
|
||||
bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >>
|
||||
16)); // convert to float and back
|
||||
}
|
||||
}
|
||||
|
||||
val.i16_vec =
|
||||
__builtin_amdgcn_cvt_scalef32_pk_bf8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0);
|
||||
|
||||
return val.i8val[0];
|
||||
}
|
||||
|
||||
template <ck_fp8_interpretation_t interpret,
|
||||
bool saturate,
|
||||
bool stochastic_rounding = false,
|
||||
ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP, bool> = false,
|
||||
ck::enable_if_t<stochastic_rounding == false, bool> = false>
|
||||
static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0)
|
||||
{
|
||||
std::ignore = rng;
|
||||
|
||||
union
|
||||
{
|
||||
ushortx2_t bhalf_vec;
|
||||
shortx2_t i16_vec;
|
||||
fp8_storage_t i8val[4];
|
||||
} val;
|
||||
|
||||
constexpr shortx2_t i16x2val = {0, 0};
|
||||
val.bhalf_vec = v;
|
||||
|
||||
if constexpr(saturate)
|
||||
{
|
||||
if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
|
||||
{
|
||||
val.bhalf_vec[0] = ushort(
|
||||
(bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
|
||||
bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >>
|
||||
16)); // convert to float and back
|
||||
}
|
||||
if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
|
||||
{
|
||||
val.bhalf_vec[1] = ushort(
|
||||
(bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
|
||||
bit_cast<float>(uint32_t{val.bhalf_vec[1]} << 16), 57344.0, -57344.0)) >>
|
||||
16)); // convert to float and back
|
||||
}
|
||||
}
|
||||
|
||||
val.i16_vec =
|
||||
__builtin_amdgcn_cvt_scalef32_pk_bf8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0);
|
||||
|
||||
return fp8x2_storage_t{val.i8val[0], val.i8val[1]};
|
||||
}
|
||||
#endif // defined(__gfx950__)
|
||||
|
||||
#if CK_FP8_CVT_FAST_PATH
|
||||
// The conversion function is from rocblas
|
||||
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_float8.h#L79
|
||||
@@ -523,6 +1030,84 @@ static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng =
|
||||
}
|
||||
return i8data;
|
||||
}
|
||||
|
||||
template <ck_fp8_interpretation_t interpret, bool saturate, bool stochastic_rounding = false>
|
||||
static __device__ fp8x2_storage_t cast_to_f8_from_f32(float2_t v, unsigned int rng = 0)
|
||||
{
|
||||
if constexpr(stochastic_rounding)
|
||||
{
|
||||
// there is no packed conversion with SR, so convert one element at a time
|
||||
return fp8x2_storage_t{
|
||||
cast_to_f8_from_f32<interpret, saturate, stochastic_rounding>(v[0], rng),
|
||||
cast_to_f8_from_f32<interpret, saturate, stochastic_rounding>(v[1], rng)};
|
||||
}
|
||||
else
|
||||
{
|
||||
union
|
||||
{
|
||||
float fval;
|
||||
unsigned int i32val;
|
||||
unsigned char i8val[4];
|
||||
} val0, val1;
|
||||
|
||||
val0.fval = v[0];
|
||||
val1.fval = v[1];
|
||||
|
||||
unsigned int ival = 0;
|
||||
|
||||
if constexpr(saturate)
|
||||
{
|
||||
if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
|
||||
{
|
||||
if((val0.i32val & 0x7F800000) != 0x7F800000)
|
||||
{ /// propagate NAN/INF, no clipping
|
||||
val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 240.0, -240.0);
|
||||
}
|
||||
if((val1.i32val & 0x7F800000) != 0x7F800000)
|
||||
{ /// propagate NAN/INF, no clipping
|
||||
val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 240.0, -240.0);
|
||||
}
|
||||
}
|
||||
else if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
|
||||
{ // OCP type
|
||||
if((val0.i32val & 0x7F800000) != 0x7F800000)
|
||||
{ /// propagate NAN/INF, no clipping
|
||||
val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 448.0, -448.0);
|
||||
}
|
||||
if((val1.i32val & 0x7F800000) != 0x7F800000)
|
||||
{ /// propagate NAN/INF, no clipping
|
||||
val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 448.0, -448.0);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if((val0.i32val & 0x7F800000) != 0x7F800000)
|
||||
{ /// propagate NAN/INF, no clipping
|
||||
val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 57344.0, -57344.0);
|
||||
}
|
||||
if((val1.i32val & 0x7F800000) != 0x7F800000)
|
||||
{ /// propagate NAN/INF, no clipping
|
||||
val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 57344.0, -57344.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RNE CVT
|
||||
if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
|
||||
(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP))
|
||||
{
|
||||
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val0.fval, val1.fval, ival, false);
|
||||
}
|
||||
else
|
||||
{
|
||||
ival = __builtin_amdgcn_cvt_pk_bf8_f32(val0.fval, val1.fval, ival, false);
|
||||
}
|
||||
|
||||
val0.i32val = ival;
|
||||
|
||||
return fp8x2_storage_t{val0.i8val[0], val0.i8val[1]};
|
||||
}
|
||||
}
|
||||
#endif // CK_FP8_CVT_FAST_PATH
|
||||
|
||||
// The conversion function is from rocblas
|
||||
@@ -797,6 +1382,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
|
||||
*
|
||||
* \tparam interp interpretation of fp8
|
||||
* \tparam sat saturation of fp8
|
||||
* \tparam stochastic_rounding switch between RNE and SR
|
||||
* \param f float number
|
||||
* \return fp8_storage_t
|
||||
*/
|
||||
@@ -882,6 +1468,47 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
|
||||
#endif // CK_FP8_CVT_FAST_PATH
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief convert vector of 2 floats to vector of 2 @p fp8_storage_t
|
||||
*
|
||||
* \tparam interp interpretation of fp8
|
||||
* \tparam sat saturation of fp8
|
||||
* \tparam stochastic_rounding switch between RNE and SR
|
||||
* \param f vector of 2 floats
|
||||
* \return fp8x2_storage_t
|
||||
*/
|
||||
template <ck_fp8_interpretation_t interp,
|
||||
ck_saturation_t sat = ck_saturation_t::CK_SATFINITE,
|
||||
bool stochastic_rounding = false>
|
||||
#if CK_FP8_CVT_FAST_PATH
|
||||
__device__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f)
|
||||
{
|
||||
__is_interpret_supported(interp);
|
||||
uint32_t rng = 0;
|
||||
if constexpr(stochastic_rounding)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
|
||||
#else
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f[0]);
|
||||
#endif
|
||||
}
|
||||
return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
|
||||
f, rng);
|
||||
#else
|
||||
#if CK_USE_OCP_FP8
|
||||
__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f)
|
||||
{
|
||||
#else
|
||||
__host__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f)
|
||||
{
|
||||
#endif // CK_USE_OCP_FP8
|
||||
return fp8x2_storage_t{cvt_float_to_fp8<interp, sat, stochastic_rounding>(f[0]),
|
||||
cvt_float_to_fp8<interp, sat, stochastic_rounding>(f[1])};
|
||||
#endif // CK_FP8_CVT_FAST_PATH
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief convert _Float16 to @p fp8_storage_t
|
||||
*
|
||||
@@ -900,87 +1527,168 @@ __host__ __device__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16
|
||||
__host__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x)
|
||||
#endif
|
||||
{
|
||||
return cvt_float_to_fp8<interp, sat, stochastic_rounding>(static_cast<float>(x));
|
||||
{
|
||||
__is_interpret_supported(interp);
|
||||
uint32_t rng = 0;
|
||||
if constexpr(stochastic_rounding)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#else
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
|
||||
#endif
|
||||
}
|
||||
#if defined(__gfx950__)
|
||||
return cast_to_f8_from_f16<interp,
|
||||
sat == ck_saturation_t::CK_SATFINITE,
|
||||
stochastic_rounding>(x, rng);
|
||||
#else
|
||||
std::ignore = rng;
|
||||
return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
|
||||
static_cast<float>(x));
|
||||
#endif // defined(__gfx950__)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief convert vector of 2 _Float16 to vector of 2 @p fp8_storage_t
|
||||
*
|
||||
* \tparam sat saturation of fp8
|
||||
* \tparam interp interpretation of fp8
|
||||
* \tparam stochastic_rounding switch between RNE and SR
|
||||
* \param x vector of 2 _Float16
|
||||
* \return fp8x2_storage_t
|
||||
*/
|
||||
template <ck_fp8_interpretation_t interp,
|
||||
ck_saturation_t sat = ck_saturation_t::CK_SATFINITE,
|
||||
bool stochastic_rounding = false>
|
||||
#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
|
||||
__host__ __device__ static inline fp8x2_storage_t cvt_half_t_to_fp8(const half2_t x)
|
||||
#else
|
||||
__host__ static inline fp8x2_storage_t cvt_half_t_to_fp8(const half2_t x)
|
||||
#endif
|
||||
{
|
||||
{
|
||||
__is_interpret_supported(interp);
|
||||
uint32_t rng = 0;
|
||||
if constexpr(stochastic_rounding)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
|
||||
#else
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
|
||||
#endif
|
||||
}
|
||||
#if defined(__gfx950__)
|
||||
return cast_to_f8_from_f16<interp,
|
||||
sat == ck_saturation_t::CK_SATFINITE,
|
||||
stochastic_rounding>(x, rng);
|
||||
#else
|
||||
std::ignore = rng;
|
||||
return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
|
||||
float2_t{static_cast<float>(x[0]), static_cast<float>(x[1])});
|
||||
#endif // defined(__gfx950__)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief convert bhalf_t to @p fp8_storage_t
|
||||
*
|
||||
* \tparam sat saturation of fp8
|
||||
* \tparam interp interpretation of fp8
|
||||
* \tparam stochastic_rounding switch between RNE and SR
|
||||
* \param x bhalf_t value
|
||||
* \return fp8_storage_t
|
||||
*/
|
||||
template <ck_fp8_interpretation_t interp,
|
||||
ck_saturation_t sat = ck_saturation_t::CK_SATFINITE,
|
||||
bool stochastic_rounding = false>
|
||||
#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
|
||||
__host__ __device__ static inline fp8_storage_t cvt_bhalf_t_to_fp8(const ushort x)
|
||||
#else
|
||||
__host__ static inline fp8_storage_t cvt_bhalf_t_to_fp8(const ushort x)
|
||||
#endif
|
||||
{
|
||||
{
|
||||
__is_interpret_supported(interp);
|
||||
uint32_t rng = 0;
|
||||
if constexpr(stochastic_rounding)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
|
||||
static_cast<float>(x));
|
||||
#else
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), static_cast<float>(x));
|
||||
#endif
|
||||
}
|
||||
#if defined(__gfx950__)
|
||||
return cast_to_f8_from_bf16<interp,
|
||||
sat == ck_saturation_t::CK_SATFINITE,
|
||||
stochastic_rounding>(x, rng);
|
||||
#else
|
||||
std::ignore = rng;
|
||||
return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
|
||||
bit_cast<float>(uint32_t{x} << 16)); // convert value to float
|
||||
#endif // defined(__gfx950__)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief convert vector of 2 bhalf_t to vector of 2 @p fp8_storage_t
|
||||
*
|
||||
* \tparam sat saturation of fp8
|
||||
* \tparam interp interpretation of fp8
|
||||
* \tparam stochastic_rounding switch between RNE and SR
|
||||
* \param x vector of 2 bhalf_t
|
||||
* \return fp8x2_storage_t
|
||||
*/
|
||||
template <ck_fp8_interpretation_t interp,
|
||||
ck_saturation_t sat = ck_saturation_t::CK_SATFINITE,
|
||||
bool stochastic_rounding = false>
|
||||
#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
|
||||
__host__ __device__ static inline fp8x2_storage_t cvt_bhalf_t_to_fp8(const ushortx2_t x)
|
||||
#else
|
||||
__host__ static inline fp8x2_storage_t cvt_bhalf_t_to_fp8(const ushortx2_t x)
|
||||
#endif
|
||||
{
|
||||
#if CK_WORKAROUND_BF16_TO_FP8_CONVERSION
|
||||
return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
|
||||
float2_t{bit_cast<float>(uint32_t{x[0]} << 16),
|
||||
bit_cast<float>(uint32_t{x[1]} << 16)}); // convert values to float
|
||||
#else // CK_WORKAROUND_BF16_TO_FP8_CONVERSION
|
||||
{
|
||||
__is_interpret_supported(interp);
|
||||
uint32_t rng = 0;
|
||||
if constexpr(stochastic_rounding)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
|
||||
static_cast<float>(x[0]));
|
||||
#else
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x),
|
||||
static_cast<float>(x[0]));
|
||||
#endif
|
||||
}
|
||||
#if defined(__gfx950__)
|
||||
return cast_to_f8_from_bf16<interp,
|
||||
sat == ck_saturation_t::CK_SATFINITE,
|
||||
stochastic_rounding>(x, rng);
|
||||
#else
|
||||
std::ignore = rng;
|
||||
return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
|
||||
float2_t{bit_cast<float>(uint32_t{x[0]} << 16),
|
||||
bit_cast<float>(uint32_t{x[1]} << 16)}); // convert values to float
|
||||
#endif // defined(__gfx950__)
|
||||
}
|
||||
#endif // CK_WORKAROUND_BF16_TO_FP8_CONVERSION
|
||||
}
|
||||
|
||||
} // namespace fp8_impl
|
||||
|
||||
// Declare a template function for fp8 conversion using RNE
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr Y f8_convert_rne(X x);
|
||||
|
||||
// convert fp32 to fp8 with rounding to nearest even
|
||||
template <>
|
||||
inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, float>(float x)
|
||||
{
|
||||
return f8_ocp_t{
|
||||
fp8_impl::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)};
|
||||
}
|
||||
|
||||
// convert fp32 to bf8 with rounding to nearest even
|
||||
template <>
|
||||
inline __host__ __device__ bf8_ocp_t f8_convert_rne<bf8_ocp_t, float>(float x)
|
||||
{
|
||||
return bf8_ocp_t{
|
||||
fp8_impl::cvt_float_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(x)};
|
||||
}
|
||||
|
||||
// convert _Float16 to fp8 with rounding to nearest even
|
||||
template <>
|
||||
inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, _Float16>(_Float16 x)
|
||||
{
|
||||
return f8_ocp_t{
|
||||
fp8_impl::cvt_half_t_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)};
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ bf8_ocp_t f8_convert_rne<bf8_ocp_t, _Float16>(_Float16 x)
|
||||
{
|
||||
return bf8_ocp_t{
|
||||
fp8_impl::cvt_half_t_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(
|
||||
x)};
|
||||
}
|
||||
|
||||
// Declare a template function for fp8 conversion using RNE
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr Y f8_convert_sr(X x);
|
||||
|
||||
// convert fp32 to fp8 with stochastic rounding
|
||||
template <>
|
||||
inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, float>(float x)
|
||||
{
|
||||
return f8_ocp_t{
|
||||
fp8_impl::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation, true>(
|
||||
x)};
|
||||
}
|
||||
|
||||
// convert fp32 to bf8 with stochastic rounding
|
||||
template <>
|
||||
inline __host__ __device__ bf8_ocp_t f8_convert_sr<bf8_ocp_t, float>(float x)
|
||||
{
|
||||
return bf8_ocp_t{fp8_impl::cvt_float_to_fp8<bf8_ocp_t::default_interpret,
|
||||
bf8_ocp_t::default_saturation,
|
||||
true>(x)};
|
||||
}
|
||||
|
||||
// convert _Float16 to fp8 with stochastic rounding
|
||||
template <>
|
||||
inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, _Float16>(_Float16 x)
|
||||
{
|
||||
return f8_ocp_t{fp8_impl::cvt_half_t_to_fp8<f8_ocp_t::default_interpret,
|
||||
f8_ocp_t::default_saturation,
|
||||
true>(x)};
|
||||
}
|
||||
|
||||
// convert _Float16 to bf8 with stochastic rounding
|
||||
template <>
|
||||
inline __host__ __device__ bf8_ocp_t f8_convert_sr<bf8_ocp_t, _Float16>(_Float16 x)
|
||||
{
|
||||
return bf8_ocp_t{fp8_impl::cvt_half_t_to_fp8<bf8_ocp_t::default_interpret,
|
||||
bf8_ocp_t::default_saturation,
|
||||
true>(x)};
|
||||
}
|
||||
|
||||
#if CK_USE_OCP_FP8
|
||||
using f8_t = f8_ocp_t;
|
||||
using bf8_t = bf8_ocp_t;
|
||||
|
||||
@@ -39,7 +39,7 @@ static __device__ float cast_to_f32_from_f8_scaled(float scale, fp8_storage_t v)
|
||||
}
|
||||
|
||||
template <ck_fp8_interpretation_t interpret>
|
||||
static __device__ float2_t cast_to_f32x2_from_f8x2_scaled(float scale, fp8x2_storage_t v)
|
||||
static __device__ float2_t cast_to_f32_from_f8_scaled(float scale, fp8x2_storage_t v)
|
||||
{
|
||||
const auto i16val = bit_cast<uint16_t>(v);
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ inline __host__ float2_t scaled_type_convert<float2_t, f8x2_ocp_t>(e8m0_bexp_t s
|
||||
#endif
|
||||
{
|
||||
#if CK_MX_FP8_CVT_FAST_PATH
|
||||
return fp8_impl::cast_to_f32x2_from_f8x2_scaled<f8_ocp_t::default_interpret>(
|
||||
return fp8_impl::cast_to_f32_from_f8_scaled<f8_ocp_t::default_interpret>(
|
||||
type_convert<float>(scale), x.AsType<fp8_impl::fp8x2_storage_t>()[Number<0>{}]);
|
||||
#else
|
||||
return float2_t{scaled_type_convert<float>(scale, x.AsType<f8_ocp_t>()[Number<0>{}]),
|
||||
@@ -86,7 +86,7 @@ inline __host__ float2_t scaled_type_convert<float2_t, bf8x2_ocp_t>(e8m0_bexp_t
|
||||
#endif
|
||||
{
|
||||
#if CK_MX_FP8_CVT_FAST_PATH
|
||||
return fp8_impl::cast_to_f32x2_from_f8x2_scaled<bf8_ocp_t::default_interpret>(
|
||||
return fp8_impl::cast_to_f32_from_f8_scaled<bf8_ocp_t::default_interpret>(
|
||||
type_convert<float>(scale), x.AsType<fp8_impl::fp8x2_storage_t>()[Number<0>{}]);
|
||||
#else
|
||||
return float2_t{scaled_type_convert<float>(scale, x.AsType<bf8_ocp_t>()[Number<0>{}]),
|
||||
|
||||
@@ -117,7 +117,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float
|
||||
#if CK_USE_RNE_BF16_CONVERSION
|
||||
return bf16_convert_rtn<bhalf_t>(x);
|
||||
#else
|
||||
return uint16_t(u.int32 >> 16);
|
||||
return uint16_t(uint32_t{x} >> 16);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -356,6 +356,180 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, half_t>(half_t x
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a float to a 8-bit float type (f8_ocp_t) using stochastic rounding.
|
||||
*
|
||||
* @param x The input float value.
|
||||
* @return The converted f8_ocp_t value.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, float>(float x)
|
||||
{
|
||||
return f8_ocp_t{
|
||||
fp8_impl::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation, true>(
|
||||
x)};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a vector of 2 floats to a vector of 2 8-bit float types (f8_ocp_t) using
|
||||
* stochastic rounding.
|
||||
*
|
||||
* @param x The input vector of 2 floats.
|
||||
* @return The converted vector of 2 f8_ocp_t.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ f8x2_ocp_t f8_convert_sr<f8x2_ocp_t, float2_t>(float2_t x)
|
||||
{
|
||||
return f8x2_ocp_t{
|
||||
fp8_impl::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation, true>(
|
||||
x)};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a float to a 8-bit float type (bf8_ocp_t) using stochastic rounding.
|
||||
*
|
||||
* @param x The input float value.
|
||||
* @return The converted bf8_ocp_t value.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ bf8_ocp_t f8_convert_sr<bf8_ocp_t, float>(float x)
|
||||
{
|
||||
return bf8_ocp_t{fp8_impl::cvt_float_to_fp8<bf8_ocp_t::default_interpret,
|
||||
bf8_ocp_t::default_saturation,
|
||||
true>(x)};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a vector of 2 floats to a vector of 2 8-bit float types (bf8_ocp_t) using
|
||||
* stochastic rounding.
|
||||
*
|
||||
* @param x The input vector of 2 floats.
|
||||
* @return The converted vector of 2 bf8_ocp_t.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ bf8x2_ocp_t f8_convert_sr<bf8x2_ocp_t, float2_t>(float2_t x)
|
||||
{
|
||||
return bf8x2_ocp_t{fp8_impl::cvt_float_to_fp8<bf8_ocp_t::default_interpret,
|
||||
bf8_ocp_t::default_saturation,
|
||||
true>(x)};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a half_t to a 8-bit float type (f8_ocp_t) using stochastic rounding.
|
||||
*
|
||||
* @param x The input half_t value.
|
||||
* @return The converted f8_ocp_t value.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, half_t>(half_t x)
|
||||
{
|
||||
return f8_ocp_t{fp8_impl::cvt_half_t_to_fp8<f8_ocp_t::default_interpret,
|
||||
f8_ocp_t::default_saturation,
|
||||
true>(x)};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a vector of 2 half_t to a vector of 2 8-bit float types (f8_ocp_t) using
|
||||
* stochastic rounding.
|
||||
*
|
||||
* @param x The input vector of 2 half_t.
|
||||
* @return The converted vector of 2 f8_ocp_t.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ f8x2_ocp_t f8_convert_sr<f8x2_ocp_t, half2_t>(half2_t x)
|
||||
{
|
||||
return f8x2_ocp_t{fp8_impl::cvt_half_t_to_fp8<f8_ocp_t::default_interpret,
|
||||
f8_ocp_t::default_saturation,
|
||||
true>(x)};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a half_t to a 8-bit half_t type (bf8_ocp_t) using stochastic rounding.
|
||||
*
|
||||
* @param x The input half_t value.
|
||||
* @return The converted bf8_ocp_t value.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ bf8_ocp_t f8_convert_sr<bf8_ocp_t, half_t>(half_t x)
|
||||
{
|
||||
return bf8_ocp_t{fp8_impl::cvt_half_t_to_fp8<bf8_ocp_t::default_interpret,
|
||||
bf8_ocp_t::default_saturation,
|
||||
true>(x)};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a vector of 2 half_t to a vector of 2 8-bit float types (bf8_ocp_t) using
|
||||
* stochastic rounding.
|
||||
*
|
||||
* @param x The input vector of 2 half_t.
|
||||
* @return The converted vector of 2 bf8_ocp_t.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ bf8x2_ocp_t f8_convert_sr<bf8x2_ocp_t, half2_t>(half2_t x)
|
||||
{
|
||||
return bf8x2_ocp_t{fp8_impl::cvt_half_t_to_fp8<bf8_ocp_t::default_interpret,
|
||||
bf8_ocp_t::default_saturation,
|
||||
true>(x)};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a bhalf_t to a 8-bit float type (f8_ocp_t) using stochastic rounding.
|
||||
*
|
||||
* @param x The input bhalf_t value.
|
||||
* @return The converted f8_ocp_t value.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, bhalf_t>(bhalf_t x)
|
||||
{
|
||||
return f8_ocp_t{fp8_impl::cvt_bhalf_t_to_fp8<f8_ocp_t::default_interpret,
|
||||
f8_ocp_t::default_saturation,
|
||||
true>(x)};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a vector of 2 bhalf_t to a vector of 2 8-bit float types (f8_ocp_t) using
|
||||
* stochastic rounding.
|
||||
*
|
||||
* @param x The input vector of 2 bhalf_t.
|
||||
* @return The converted vector of 2 f8_ocp_t.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ f8x2_ocp_t f8_convert_sr<f8x2_ocp_t, bhalf2_t>(bhalf2_t x)
|
||||
{
|
||||
return f8x2_ocp_t{fp8_impl::cvt_bhalf_t_to_fp8<f8_ocp_t::default_interpret,
|
||||
f8_ocp_t::default_saturation,
|
||||
true>(x)};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a bhalf_t to a 8-bit half_t type (bf8_ocp_t) using stochastic rounding.
|
||||
*
|
||||
* @param x The input bhalf_t value.
|
||||
* @return The converted bf8_ocp_t value.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ bf8_ocp_t f8_convert_sr<bf8_ocp_t, bhalf_t>(bhalf_t x)
|
||||
{
|
||||
return bf8_ocp_t{fp8_impl::cvt_bhalf_t_to_fp8<bf8_ocp_t::default_interpret,
|
||||
bf8_ocp_t::default_saturation,
|
||||
true>(x)};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a vector of 2 bhalf_t to a vector of 2 8-bit float types (bf8_ocp_t) using
|
||||
* stochastic rounding.
|
||||
*
|
||||
* @param x The input vector of 2 bhalf_t.
|
||||
* @return The converted vector of 2 bf8_ocp_t.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ bf8x2_ocp_t f8_convert_sr<bf8x2_ocp_t, bhalf2_t>(bhalf2_t x)
|
||||
{
|
||||
return bf8x2_ocp_t{fp8_impl::cvt_bhalf_t_to_fp8<bf8_ocp_t::default_interpret,
|
||||
bf8_ocp_t::default_saturation,
|
||||
true>(x)};
|
||||
}
|
||||
|
||||
// Declare a template function for fp8 conversion using RNE
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr Y f8_convert_rne(X x);
|
||||
@@ -466,6 +640,172 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_rne<bf8_fnuz_t, half_t>(half_t
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a float to a 8-bit float type (f8_ocp_t) using rounding to nearest/even.
|
||||
*
|
||||
* @param x The input float value.
|
||||
* @return The converted f8_ocp_t value.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, float>(float x)
|
||||
{
|
||||
return f8_ocp_t{
|
||||
fp8_impl::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a vector of 2 floats to a vector of 2 8-bit float types (f8_ocp_t) using rounding
|
||||
* to nearest/even.
|
||||
*
|
||||
* @param x The input vector of 2 floats.
|
||||
* @return The converted vector of 2 f8_ocp_t.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ f8x2_ocp_t f8_convert_rne<f8x2_ocp_t, float2_t>(float2_t x)
|
||||
{
|
||||
return f8x2_ocp_t{
|
||||
fp8_impl::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a float to a 8-bit float type (bf8_ocp_t) using rounding to nearest/even.
|
||||
*
|
||||
* @param x The input float value.
|
||||
* @return The converted bf8_ocp_t value.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ bf8_ocp_t f8_convert_rne<bf8_ocp_t, float>(float x)
|
||||
{
|
||||
return bf8_ocp_t{
|
||||
fp8_impl::cvt_float_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(x)};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a vector of 2 floats to a vector of 2 8-bit float types (bf8_ocp_t) using
|
||||
* rounding to nearest/even.
|
||||
*
|
||||
* @param x The input vector of 2 floats.
|
||||
* @return The converted vector of 2 bf8_ocp_t.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ bf8x2_ocp_t f8_convert_rne<bf8x2_ocp_t, float2_t>(float2_t x)
|
||||
{
|
||||
return bf8x2_ocp_t{
|
||||
fp8_impl::cvt_float_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(x)};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a half_t to a 8-bit float type (f8_ocp_t) using rounding to nearest/even.
|
||||
*
|
||||
* @param x The input half_t value.
|
||||
* @return The converted f8_ocp_t value.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, half_t>(half_t x)
|
||||
{
|
||||
return f8_ocp_t{
|
||||
fp8_impl::cvt_half_t_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a vector of 2 half_t to a vector of 2 8-bit float types (f8_ocp_t) using rounding
|
||||
* to nearest/even.
|
||||
*
|
||||
* @param x The input vector of 2 half_t.
|
||||
* @return The converted vector of 2 f8_ocp_t.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ f8x2_ocp_t f8_convert_rne<f8x2_ocp_t, half2_t>(half2_t x)
|
||||
{
|
||||
return f8x2_ocp_t{
|
||||
fp8_impl::cvt_half_t_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a half_t to a 8-bit half_t type (bf8_ocp_t) using rounding to nearest/even.
|
||||
*
|
||||
* @param x The input half_t value.
|
||||
* @return The converted bf8_ocp_t value.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ bf8_ocp_t f8_convert_rne<bf8_ocp_t, half_t>(half_t x)
|
||||
{
|
||||
return bf8_ocp_t{
|
||||
fp8_impl::cvt_half_t_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(
|
||||
x)};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a vector of 2 half_t to a vector of 2 8-bit float types (bf8_ocp_t) using
|
||||
* rounding to nearest/even.
|
||||
*
|
||||
* @param x The input vector of 2 half_t.
|
||||
* @return The converted vector of 2 bf8_ocp_t.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ bf8x2_ocp_t f8_convert_rne<bf8x2_ocp_t, half2_t>(half2_t x)
|
||||
{
|
||||
return bf8x2_ocp_t{
|
||||
fp8_impl::cvt_half_t_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(
|
||||
x)};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a bhalf_t to a 8-bit float type (f8_ocp_t) using rounding to nearest/even.
|
||||
*
|
||||
* @param x The input bhalf_t value.
|
||||
* @return The converted f8_ocp_t value.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, bhalf_t>(bhalf_t x)
|
||||
{
|
||||
return f8_ocp_t{
|
||||
fp8_impl::cvt_bhalf_t_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a vector of 2 bhalf_t to a vector of 2 8-bit float types (f8_ocp_t) using
|
||||
* rounding to nearest/even.
|
||||
*
|
||||
* @param x The input vector of 2 bhalf_t.
|
||||
* @return The converted vector of 2 f8_ocp_t.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ f8x2_ocp_t f8_convert_rne<f8x2_ocp_t, bhalf2_t>(bhalf2_t x)
|
||||
{
|
||||
return f8x2_ocp_t{
|
||||
fp8_impl::cvt_bhalf_t_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a bhalf_t to a 8-bit half_t type (bf8_ocp_t) using rounding to nearest/even.
|
||||
*
|
||||
* @param x The input bhalf_t value.
|
||||
* @return The converted bf8_ocp_t value.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ bf8_ocp_t f8_convert_rne<bf8_ocp_t, bhalf_t>(bhalf_t x)
|
||||
{
|
||||
return bf8_ocp_t{
|
||||
fp8_impl::cvt_bhalf_t_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(
|
||||
x)};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a vector of 2 bhalf_t to a vector of 2 8-bit float types (bf8_ocp_t) using
|
||||
* rounding to nearest/even.
|
||||
*
|
||||
* @param x The input vector of 2 bhalf_t.
|
||||
* @return The converted vector of 2 bf8_ocp_t.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ bf8x2_ocp_t f8_convert_rne<bf8x2_ocp_t, bhalf2_t>(bhalf2_t x)
|
||||
{
|
||||
return bf8x2_ocp_t{
|
||||
fp8_impl::cvt_bhalf_t_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(
|
||||
x)};
|
||||
}
|
||||
|
||||
// convert fp32 to fp8
|
||||
template <>
|
||||
inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, float>(float x)
|
||||
@@ -477,17 +817,6 @@ inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, float>(float x)
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp32 to fp8
|
||||
template <>
|
||||
inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, float>(float x)
|
||||
{
|
||||
#if CK_USE_SR_F8_CONVERSION
|
||||
return f8_convert_sr<f8_ocp_t>(x);
|
||||
#else
|
||||
return f8_convert_rne<f8_ocp_t>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp8 to fp32
|
||||
template <>
|
||||
inline __host__ __device__ float type_convert<float, f8_fnuz_t>(f8_fnuz_t x)
|
||||
@@ -524,12 +853,39 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_fnuz_t>(f8x2_fnu
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a f8_ocp_t value to a float value.
|
||||
*
|
||||
* @param x The input f8_ocp_t value.
|
||||
* @return The converted float value.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ float type_convert<float, f8_ocp_t>(f8_ocp_t x)
|
||||
{
|
||||
#if CK_OCP_FP8_CVT_FAST_PATH
|
||||
union
|
||||
{
|
||||
unsigned int i32val;
|
||||
fp8_storage_t i8val[4];
|
||||
} val;
|
||||
val.i8val[0] = x.data;
|
||||
return __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0);
|
||||
#else
|
||||
return fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(x.data);
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a vector of 2 f8_ocp_t values to a vector of 2 float values.
|
||||
*
|
||||
* @param x The input vector of 2 f8_ocp_t values.
|
||||
* @return The converted vector of 2 float values.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ float2_t type_convert<float2_t, f8x2_ocp_t>(f8x2_ocp_t x)
|
||||
{
|
||||
#if CK_OCP_FP8_CVT_FAST_PATH
|
||||
return fp8_impl::cast_to_f32x2_from_f8x2<f8_ocp_t::default_interpret>(
|
||||
x.AsType<fp8_impl::fp8x2_storage_t>()[Number<0>{}]);
|
||||
return __builtin_amdgcn_cvt_pk_f32_fp8(bit_cast<uint16_t>(x), false);
|
||||
#else
|
||||
return float2_t{fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(
|
||||
x.AsType<fp8_storage_t>()[Number<0>{}]),
|
||||
@@ -538,6 +894,229 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_ocp_t>(f8x2_ocp_
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a f8_ocp_t value to a half_t value.
|
||||
*
|
||||
* @param x The input f8_ocp_t value.
|
||||
* @return The converted half_t value.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ half_t type_convert<half_t, f8_ocp_t>(f8_ocp_t x)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
union
|
||||
{
|
||||
uint16_t i16val;
|
||||
fp8_storage_t i8val[2];
|
||||
} input;
|
||||
input.i8val[0] = x.data;
|
||||
|
||||
union
|
||||
{
|
||||
half2_t half_vec;
|
||||
half_t half_arr[2];
|
||||
} output;
|
||||
output.half_vec = __builtin_amdgcn_cvt_scalef32_pk_f16_fp8(input.i16val, /*scale*/ 1.f, 0);
|
||||
|
||||
return output.half_arr[0];
|
||||
#else
|
||||
return fp8_impl::cast_from_f8<half_t, f8_ocp_t::wm, f8_ocp_t::we, false>(x.data);
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a vector of 2 f8_ocp_t values to a vector of 2 half_t values.
|
||||
*
|
||||
* @param x The input vector of 2 f8_ocp_t values.
|
||||
* @return The converted vector of 2 half_t values.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ half2_t type_convert<half2_t, f8x2_ocp_t>(f8x2_ocp_t x)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
return __builtin_amdgcn_cvt_scalef32_pk_f16_fp8(bit_cast<uint16_t>(x), /*scale*/ 1.f, 0);
|
||||
#else
|
||||
return half2_t{type_convert<half_t>(float(x.AsType<f8_ocp_t>()[Number<0>{}])),
|
||||
type_convert<half_t>(float(x.AsType<f8_ocp_t>()[Number<1>{}]))};
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a f8_ocp_t value to a bhalf_t value.
|
||||
*
|
||||
* @param x The input f8_ocp_t value.
|
||||
* @return The converted bhalf_t value.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ bhalf_t type_convert<bhalf_t, f8_ocp_t>(f8_ocp_t x)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
union
|
||||
{
|
||||
uint16_t i16val;
|
||||
fp8_storage_t i8val[2];
|
||||
} input;
|
||||
input.i8val[0] = x.data;
|
||||
|
||||
union
|
||||
{
|
||||
bhalf2_t bhalf_vec;
|
||||
bhalf_t bhalf_arr[2];
|
||||
} output;
|
||||
output.bhalf_vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp8(input.i16val, /*scale*/ 1.f, 0);
|
||||
|
||||
return output.bhalf_arr[0];
|
||||
#else
|
||||
return type_convert<bhalf_t>(
|
||||
fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(x.data));
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a vector of 2 f8_ocp_t values to a vector of 2 bhalf_t values.
|
||||
*
|
||||
* @param x The input vector of 2 f8_ocp_t values.
|
||||
* @return The converted vector of 2 bhalf_t values.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ bhalf2_t type_convert<bhalf2_t, f8x2_ocp_t>(f8x2_ocp_t x)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp8(bit_cast<uint16_t>(x), /*scale*/ 1.f, 0);
|
||||
#else
|
||||
return bhalf2_t{type_convert<bhalf_t>(float(x.AsType<f8_ocp_t>()[Number<0>{}])),
|
||||
type_convert<bhalf_t>(float(x.AsType<f8_ocp_t>()[Number<1>{}]))};
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a bf8_ocp_t value to a float value.
|
||||
*
|
||||
* @param x The input bf8_ocp_t value.
|
||||
* @return The converted float value.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ float type_convert<float, bf8_ocp_t>(bf8_ocp_t x)
|
||||
{
|
||||
#if CK_OCP_FP8_CVT_FAST_PATH
|
||||
union
|
||||
{
|
||||
unsigned int i32val;
|
||||
fp8_storage_t i8val[4];
|
||||
} val;
|
||||
val.i8val[0] = x.data;
|
||||
return __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0);
|
||||
#else
|
||||
return fp8_impl::cast_from_f8<float, bf8_ocp_t::wm, bf8_ocp_t::we, false>(x.data);
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a vector of 2 bf8_ocp_t values to a vector of 2 float values.
|
||||
*
|
||||
* @param x The input vector of 2 bf8_ocp_t values.
|
||||
* @return The converted vector of 2 float values.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ float2_t type_convert<float2_t, bf8x2_ocp_t>(bf8x2_ocp_t x)
|
||||
{
|
||||
#if CK_OCP_FP8_CVT_FAST_PATH
|
||||
return __builtin_amdgcn_cvt_pk_f32_bf8(bit_cast<uint16_t>(x), false);
|
||||
#else
|
||||
return float2_t{fp8_impl::cast_from_f8<float, bf8_ocp_t::wm, bf8_ocp_t::we, false>(
|
||||
x.AsType<fp8_storage_t>()[Number<0>{}]),
|
||||
fp8_impl::cast_from_f8<float, bf8_ocp_t::wm, bf8_ocp_t::we, false>(
|
||||
x.AsType<fp8_storage_t>()[Number<1>{}])};
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a bf8_ocp_t value to a half_t value.
|
||||
*
|
||||
* @param x The input bf8_ocp_t value.
|
||||
* @return The converted half_t value.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ half_t type_convert<half_t, bf8_ocp_t>(bf8_ocp_t x)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
union
|
||||
{
|
||||
uint16_t i16val;
|
||||
fp8_storage_t i8val[2];
|
||||
} val;
|
||||
val.i8val[0] = x.data;
|
||||
return __builtin_amdgcn_cvt_scalef32_pk_f16_bf8(val.i16val, /*scale*/ 1.f, 0)[0];
|
||||
#else
|
||||
return fp8_impl::cast_from_f8<half_t, bf8_ocp_t::wm, bf8_ocp_t::we, false>(x.data);
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a vector of 2 bf8_ocp_t values to a vector of 2 half_t values.
|
||||
*
|
||||
* @param x The input vector of 2 bf8_ocp_t values.
|
||||
* @return The converted vector of 2 half_t values.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ half2_t type_convert<half2_t, bf8x2_ocp_t>(bf8x2_ocp_t x)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
return __builtin_amdgcn_cvt_scalef32_pk_f16_bf8(bit_cast<uint16_t>(x), /*scale*/ 1.f, 0);
|
||||
#else
|
||||
return half2_t{type_convert<half_t>(float(x.AsType<bf8_ocp_t>()[Number<0>{}])),
|
||||
type_convert<half_t>(float(x.AsType<bf8_ocp_t>()[Number<1>{}]))};
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a bf8_ocp_t value to a bhalf_t value.
|
||||
*
|
||||
* @param x The input bf8_ocp_t value.
|
||||
* @return The converted bhalf_t value.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ bhalf_t type_convert<bhalf_t, bf8_ocp_t>(bf8_ocp_t x)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
union
|
||||
{
|
||||
uint16_t i16val;
|
||||
fp8_storage_t i8val[2];
|
||||
} input;
|
||||
input.i8val[0] = x.data;
|
||||
|
||||
union
|
||||
{
|
||||
bhalf2_t bhalf_vec;
|
||||
bhalf_t bhalf_arr[2];
|
||||
} output;
|
||||
output.bhalf_vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(input.i16val, /*scale*/ 1.f, 0);
|
||||
|
||||
return output.bhalf_arr[0];
|
||||
#else
|
||||
return type_convert<bhalf_t>(
|
||||
fp8_impl::cast_from_f8<float, bf8_ocp_t::wm, bf8_ocp_t::we, false>(x.data));
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a vector of 2 bf8_ocp_t values to a vector of 2 bhalf_t values.
|
||||
*
|
||||
* @param x The input vector of 2 bf8_ocp_t values.
|
||||
* @return The converted vector of 2 bhalf_t values.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ bhalf2_t type_convert<bhalf2_t, bf8x2_ocp_t>(bf8x2_ocp_t x)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
return __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(bit_cast<uint16_t>(x), /*scale*/ 1.f, 0);
|
||||
#else
|
||||
return bhalf2_t{type_convert<bhalf_t>(float(x.AsType<bf8_ocp_t>()[Number<0>{}])),
|
||||
type_convert<bhalf_t>(float(x.AsType<bf8_ocp_t>()[Number<1>{}]))};
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ float2_t type_convert<float2_t, pk_i4_t>(pk_i4_t x)
|
||||
{
|
||||
@@ -610,7 +1189,12 @@ inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, half_t>(half_t x)
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp16 to fp8
|
||||
/**
|
||||
* @brief Converts a half_t value to a f8_ocp_t value with rounding determined by a flag.
|
||||
*
|
||||
* @param x The input half_t value.
|
||||
* @return The converted f8_ocp_t value.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, half_t>(half_t x)
|
||||
{
|
||||
@@ -621,6 +1205,22 @@ inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, half_t>(half_t x)
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a half_t value to a bf8_ocp_t value with rounding determined by a flag.
|
||||
*
|
||||
* @param x The input half_t value.
|
||||
* @return The converted bf8_ocp_t value.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, half_t>(half_t x)
|
||||
{
|
||||
#if CK_USE_SR_F8_CONVERSION
|
||||
return f8_convert_sr<bf8_ocp_t>(x);
|
||||
#else
|
||||
return f8_convert_rne<bf8_ocp_t>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp8 to fp16
|
||||
template <>
|
||||
inline __host__ __device__ half_t type_convert<half_t, f8_fnuz_t>(f8_fnuz_t x)
|
||||
@@ -645,7 +1245,28 @@ inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, float>(float x)
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp32 to bf8
|
||||
/**
|
||||
* @brief Converts a float value to a f8_ocp_t value with rounding determined by a flag.
|
||||
*
|
||||
* @param x The input float value.
|
||||
* @return The converted f8_ocp_t value.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, float>(float x)
|
||||
{
|
||||
#if CK_USE_SR_F8_CONVERSION
|
||||
return f8_convert_sr<f8_ocp_t>(x);
|
||||
#else
|
||||
return f8_convert_rne<f8_ocp_t>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a float value to a bf8_ocp_t value with rounding determined by a flag.
|
||||
*
|
||||
* @param x The input float value.
|
||||
* @return The converted bf8_ocp_t value.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, float>(float x)
|
||||
{
|
||||
@@ -656,6 +1277,38 @@ inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, float>(float x)
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a bhalf_t value to a f8_ocp_t value with rounding determined by a flag.
|
||||
*
|
||||
* @param x The input bhalf_t value.
|
||||
* @return The converted f8_ocp_t value.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, bhalf_t>(bhalf_t x)
|
||||
{
|
||||
#if CK_USE_SR_F8_CONVERSION
|
||||
return f8_convert_sr<f8_ocp_t>(x);
|
||||
#else
|
||||
return f8_convert_rne<f8_ocp_t>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a bhalf_t value to a bf8_ocp_t value with rounding determined by a flag.
|
||||
*
|
||||
* @param x The input bhalf_t value.
|
||||
* @return The converted bf8_ocp_t value.
|
||||
*/
|
||||
template <>
|
||||
inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, bhalf_t>(bhalf_t x)
|
||||
{
|
||||
#if CK_USE_SR_F8_CONVERSION
|
||||
return f8_convert_sr<bf8_ocp_t>(x);
|
||||
#else
|
||||
return f8_convert_rne<bf8_ocp_t>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert bf8 to fp32
|
||||
template <>
|
||||
inline __host__ __device__ float type_convert<float, bf8_fnuz_t>(bf8_fnuz_t x)
|
||||
@@ -683,17 +1336,6 @@ inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, half_t>(half_t x)
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp16 to bf8
|
||||
template <>
|
||||
inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, half_t>(half_t x)
|
||||
{
|
||||
#if CK_USE_SR_F8_CONVERSION
|
||||
return f8_convert_sr<bf8_ocp_t>(x);
|
||||
#else
|
||||
return f8_convert_rne<bf8_ocp_t>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert bf8 to fp16
|
||||
template <>
|
||||
inline __host__ __device__ half_t type_convert<half_t, bf8_fnuz_t>(bf8_fnuz_t x)
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
|
||||
using ck::bf8_ocp_t;
|
||||
using ck::bf8x2_ocp_t;
|
||||
using ck::bhalf2_t;
|
||||
using ck::bhalf_t;
|
||||
using ck::f8_convert_rne;
|
||||
using ck::f8_convert_sr;
|
||||
using ck::float2_t;
|
||||
using ck::half2_t;
|
||||
using ck::half_t;
|
||||
using ck::type_convert;
|
||||
|
||||
@@ -266,3 +272,590 @@ TEST(BF8OCP, ConvertFP16Stochastic)
|
||||
const auto bf8_nan = f8_convert_sr<bf8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
|
||||
ASSERT_TRUE(ck::fp8_impl::ocp_bf8_is_nan(bf8_nan.data));
|
||||
}
|
||||
|
||||
constexpr uint64_t test_size = 256 + 6;
|
||||
|
||||
__host__ __device__ void
|
||||
test_fp32_bf8_type_convert(uint64_t N, float* p_test, uint64_t* p_completed)
|
||||
{
|
||||
if(p_completed == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t& i = *p_completed;
|
||||
i = 0;
|
||||
|
||||
if(p_test == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
|
||||
{
|
||||
uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
|
||||
auto v = type_convert<float>(bf8_ocp_t{bf8_uid});
|
||||
p_test[i] = v;
|
||||
i++;
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/// Test vector conversion
|
||||
// bf8x2 -> fp32x2
|
||||
bf8x2_ocp_t bf8x2{bf8x2_ocp_t::data_v{0b10000100, 0b00000001}}; //-2^-14, 2^-16
|
||||
|
||||
float2_t f32x2 = type_convert<float2_t>(bf8x2);
|
||||
p_test[i++] = f32x2[0];
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = f32x2[1];
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// fp32x2 -> bf8x2
|
||||
f32x2 = {-4.0f, 2.0f};
|
||||
bf8x2 = f8_convert_rne<bf8x2_ocp_t>(f32x2); // expect {-4, 2}
|
||||
|
||||
p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})); //-4f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{})); // 2f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
bf8x2 = f8_convert_sr<bf8x2_ocp_t>(f32x2); // expect {-4, 2}
|
||||
|
||||
p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})); //-4f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{})); // 2f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BF8OCP, HostFP32BF8Convert)
|
||||
{
|
||||
std::vector<float> out(test_size, -1.0f);
|
||||
uint64_t completed = 0;
|
||||
|
||||
test_fp32_bf8_type_convert(test_size, out.data(), &completed);
|
||||
|
||||
std::set<uint8_t> bf8_nan_ids;
|
||||
bf8_nan_ids.insert(0b11111111);
|
||||
bf8_nan_ids.insert(0b01111111);
|
||||
bf8_nan_ids.insert(0b11111101);
|
||||
bf8_nan_ids.insert(0b01111101);
|
||||
bf8_nan_ids.insert(0b11111110);
|
||||
bf8_nan_ids.insert(0b01111110);
|
||||
for(auto bf8_nan_id : bf8_nan_ids)
|
||||
{
|
||||
auto idx = bf8_nan_id;
|
||||
ASSERT_TRUE(std::isnan(out[idx]));
|
||||
}
|
||||
|
||||
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
|
||||
{
|
||||
if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end())
|
||||
continue;
|
||||
|
||||
uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
|
||||
auto idx = bf8_uid;
|
||||
ASSERT_FLOAT_EQ(out[idx], type_convert<float>(bf8_ocp_t{bf8_uid}))
|
||||
<< " bf8_id: " << bf8_id << std::endl
|
||||
<< type_convert<float>(bf8_ocp_t{bf8_uid});
|
||||
}
|
||||
|
||||
// /// Test vector conversions
|
||||
|
||||
auto i = 256;
|
||||
|
||||
// bf8x2 -> fp32x2
|
||||
EXPECT_EQ(out[i++], -powf(2.0f, -14.0f));
|
||||
EXPECT_EQ(out[i++], powf(2.0f, -16.0f));
|
||||
|
||||
// fp32x2 -> bf8x2
|
||||
// RNE
|
||||
EXPECT_EQ(out[i++], -4.0f);
|
||||
EXPECT_EQ(out[i++], 2.0f);
|
||||
// SR
|
||||
EXPECT_EQ(out[i++], -4.0f);
|
||||
EXPECT_EQ(out[i++], 2.0f);
|
||||
|
||||
EXPECT_EQ(test_size, completed);
|
||||
EXPECT_EQ(test_size, i);
|
||||
}
|
||||
|
||||
__global__ void device_test_fp32_bf8_type_convert(uint64_t N, float* p_test, uint64_t* p_completed)
|
||||
{
|
||||
test_fp32_bf8_type_convert(N, p_test, p_completed);
|
||||
}
|
||||
|
||||
TEST(BF8OCP, DeviceFP32BF8Convert)
|
||||
{
|
||||
std::vector<float> out(test_size, -1.0f);
|
||||
|
||||
DeviceMem device_out(test_size * sizeof(float));
|
||||
DeviceMem device_completed(sizeof(uint64_t));
|
||||
|
||||
device_out.SetValue(-21.0f);
|
||||
device_completed.SetValue(-21.0f);
|
||||
|
||||
device_test_fp32_bf8_type_convert<<<1, 1>>>(
|
||||
test_size,
|
||||
static_cast<float*>(device_out.GetDeviceBuffer()),
|
||||
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
|
||||
|
||||
uint64_t completed = 0;
|
||||
device_completed.FromDevice(&completed);
|
||||
device_out.FromDevice(out.data());
|
||||
|
||||
std::set<uint8_t> bf8_nan_ids;
|
||||
bf8_nan_ids.insert(0b11111111);
|
||||
bf8_nan_ids.insert(0b01111111);
|
||||
bf8_nan_ids.insert(0b11111101);
|
||||
bf8_nan_ids.insert(0b01111101);
|
||||
bf8_nan_ids.insert(0b11111110);
|
||||
bf8_nan_ids.insert(0b01111110);
|
||||
for(auto bf8_nan_id : bf8_nan_ids)
|
||||
{
|
||||
auto idx = bf8_nan_id;
|
||||
ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx];
|
||||
}
|
||||
|
||||
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
|
||||
{
|
||||
if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end())
|
||||
continue;
|
||||
|
||||
uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
|
||||
auto idx = bf8_uid;
|
||||
ASSERT_FLOAT_EQ(out[idx], type_convert<float>(bf8_ocp_t{bf8_uid}))
|
||||
<< " bf8_id: " << bf8_id << std::endl
|
||||
<< type_convert<float>(bf8_ocp_t{bf8_uid});
|
||||
}
|
||||
|
||||
/// Test vector conversions
|
||||
|
||||
auto i = 256;
|
||||
|
||||
// bf8x2 -> fp32x2
|
||||
EXPECT_EQ(out[i++], -powf(2.0f, -14.0f));
|
||||
EXPECT_EQ(out[i++], powf(2.0f, -16.0f));
|
||||
|
||||
// fp32x2 -> bf8x2
|
||||
// RNE
|
||||
EXPECT_EQ(out[i++], -4.0f);
|
||||
EXPECT_EQ(out[i++], 2.0f);
|
||||
// SR
|
||||
EXPECT_EQ(out[i++], -4.0f);
|
||||
EXPECT_EQ(out[i++], 2.0f);
|
||||
|
||||
EXPECT_EQ(test_size, completed);
|
||||
EXPECT_EQ(test_size, i);
|
||||
}
|
||||
|
||||
__host__ __device__ void
|
||||
test_fp16_bf8_type_convert(uint64_t N, half_t* p_test, uint64_t* p_completed)
|
||||
{
|
||||
if(p_completed == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t& i = *p_completed;
|
||||
i = 0;
|
||||
|
||||
if(p_test == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
|
||||
{
|
||||
uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
|
||||
auto v = type_convert<half_t>(bf8_ocp_t{bf8_uid});
|
||||
p_test[i] = v;
|
||||
i++;
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/// Test vector conversion
|
||||
// bf8x2 -> fp16x2
|
||||
bf8x2_ocp_t bf8x2{bf8x2_ocp_t::data_v{0b10000100, 0b00000001}}; //-2^-14, 2^-16
|
||||
|
||||
half2_t f16x2 = type_convert<half2_t>(bf8x2);
|
||||
p_test[i++] = f16x2[0];
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = f16x2[1];
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// fp16x2 -> bf8x2
|
||||
f16x2 = {-4.0f, 2.0f};
|
||||
bf8x2 = f8_convert_rne<bf8x2_ocp_t>(f16x2); // expect {-4, 2}
|
||||
|
||||
p_test[i++] = type_convert<half_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})); //-4f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = type_convert<half_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{})); // 2f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
bf8x2 = f8_convert_sr<bf8x2_ocp_t>(f16x2); // expect {-4, 2}
|
||||
|
||||
p_test[i++] = type_convert<half_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})); //-4f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = type_convert<half_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{})); // 2f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BF8OCP, HostFP16BF8Convert)
|
||||
{
|
||||
std::vector<half_t> out(test_size, -1.0f);
|
||||
uint64_t completed = 0;
|
||||
|
||||
test_fp16_bf8_type_convert(test_size, out.data(), &completed);
|
||||
|
||||
std::set<uint8_t> bf8_nan_ids;
|
||||
bf8_nan_ids.insert(0b11111111);
|
||||
bf8_nan_ids.insert(0b01111111);
|
||||
bf8_nan_ids.insert(0b11111101);
|
||||
bf8_nan_ids.insert(0b01111101);
|
||||
bf8_nan_ids.insert(0b11111110);
|
||||
bf8_nan_ids.insert(0b01111110);
|
||||
for(auto bf8_nan_id : bf8_nan_ids)
|
||||
{
|
||||
auto idx = bf8_nan_id;
|
||||
ASSERT_TRUE(std::isnan(type_convert<float>(out[idx])));
|
||||
}
|
||||
|
||||
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
|
||||
{
|
||||
if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end())
|
||||
continue;
|
||||
|
||||
uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
|
||||
auto idx = bf8_uid;
|
||||
ASSERT_FLOAT_EQ(out[idx], type_convert<half_t>(bf8_ocp_t{bf8_uid}))
|
||||
<< " bf8_id: " << bf8_id << std::endl
|
||||
<< type_convert<float>(type_convert<half_t>(bf8_ocp_t{bf8_uid}));
|
||||
}
|
||||
|
||||
// /// Test vector conversions
|
||||
|
||||
auto i = 256;
|
||||
|
||||
// bf8x2 -> fp16x2
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(-powf(2.0f, -14.0f)));
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(powf(2.0f, -16.0f)));
|
||||
|
||||
// fp16x2 -> bf8x2
|
||||
// RNE
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(2.0f));
|
||||
// SR
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(2.0f));
|
||||
|
||||
EXPECT_EQ(test_size, completed);
|
||||
EXPECT_EQ(test_size, i);
|
||||
}
|
||||
|
||||
__global__ void device_test_fp16_bf8_type_convert(uint64_t N, half_t* p_test, uint64_t* p_completed)
|
||||
{
|
||||
test_fp16_bf8_type_convert(N, p_test, p_completed);
|
||||
}
|
||||
|
||||
TEST(BF8OCP, DeviceFP16BF8Convert)
|
||||
{
|
||||
std::vector<half_t> out(test_size, -1.0f);
|
||||
|
||||
DeviceMem device_out(test_size * sizeof(half_t));
|
||||
DeviceMem device_completed(sizeof(uint64_t));
|
||||
|
||||
device_out.SetValue(-21.0f);
|
||||
device_completed.SetValue(-21.0f);
|
||||
|
||||
device_test_fp16_bf8_type_convert<<<1, 1>>>(
|
||||
test_size,
|
||||
static_cast<half_t*>(device_out.GetDeviceBuffer()),
|
||||
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
|
||||
|
||||
uint64_t completed = 0;
|
||||
device_completed.FromDevice(&completed);
|
||||
device_out.FromDevice(out.data());
|
||||
|
||||
std::set<uint8_t> bf8_nan_ids;
|
||||
bf8_nan_ids.insert(0b11111111);
|
||||
bf8_nan_ids.insert(0b01111111);
|
||||
bf8_nan_ids.insert(0b11111101);
|
||||
bf8_nan_ids.insert(0b01111101);
|
||||
bf8_nan_ids.insert(0b11111110);
|
||||
bf8_nan_ids.insert(0b01111110);
|
||||
for(auto bf8_nan_id : bf8_nan_ids)
|
||||
{
|
||||
auto idx = bf8_nan_id;
|
||||
ASSERT_TRUE(std::isnan(type_convert<float>(out[idx])))
|
||||
<< "idx: " << idx << " out[idx]: " << type_convert<float>(out[idx]);
|
||||
}
|
||||
|
||||
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
|
||||
{
|
||||
if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end())
|
||||
continue;
|
||||
|
||||
uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
|
||||
auto idx = bf8_uid;
|
||||
ASSERT_FLOAT_EQ(out[idx], type_convert<half_t>(bf8_ocp_t{bf8_uid}))
|
||||
<< " bf8_id: " << bf8_id << std::endl
|
||||
<< type_convert<float>(type_convert<half_t>(bf8_ocp_t{bf8_uid}));
|
||||
}
|
||||
|
||||
/// Test vector conversions
|
||||
|
||||
auto i = 256;
|
||||
|
||||
// bf8x2 -> fp16x2
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(-powf(2.0f, -14.0f)));
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(powf(2.0f, -16.0f)));
|
||||
|
||||
// fp16x2 -> bf8x2
|
||||
// RNE
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(2.0f));
|
||||
// SR
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(2.0f));
|
||||
|
||||
EXPECT_EQ(test_size, completed);
|
||||
EXPECT_EQ(test_size, i);
|
||||
}
|
||||
|
||||
__host__ __device__ void
|
||||
test_bf16_bf8_type_convert(uint64_t N, bhalf_t* p_test, uint64_t* p_completed)
|
||||
{
|
||||
if(p_completed == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t& i = *p_completed;
|
||||
i = 0;
|
||||
|
||||
if(p_test == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
|
||||
{
|
||||
uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
|
||||
auto v = type_convert<bhalf_t>(bf8_ocp_t{bf8_uid});
|
||||
p_test[i] = v;
|
||||
i++;
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/// Test vector conversion
|
||||
// bf8x2 -> bf16x2
|
||||
bf8x2_ocp_t bf8x2{bf8x2_ocp_t::data_v{0b10000100, 0b00000001}}; //-2^-14, 2^-16
|
||||
|
||||
bhalf2_t bf16x2 = type_convert<bhalf2_t>(bf8x2);
|
||||
p_test[i++] = bf16x2[0];
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = bf16x2[1];
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// bf16x2 -> bf8x2
|
||||
bf16x2 = {type_convert<bhalf_t>(-4.0f), type_convert<bhalf_t>(2.0f)};
|
||||
bf8x2 = f8_convert_rne<bf8x2_ocp_t>(bf16x2); // expect {-4, 2}
|
||||
|
||||
p_test[i++] = type_convert<bhalf_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})); //-4f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = type_convert<bhalf_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{})); // 2f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
bf8x2 = f8_convert_sr<bf8x2_ocp_t>(bf16x2); // expect {-4, 2}
|
||||
|
||||
p_test[i++] = type_convert<bhalf_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})); //-4f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = type_convert<bhalf_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{})); // 2f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BF8OCP, HostBF16BF8Convert)
|
||||
{
|
||||
std::vector<bhalf_t> out(test_size, -1.0f);
|
||||
uint64_t completed = 0;
|
||||
|
||||
test_bf16_bf8_type_convert(test_size, out.data(), &completed);
|
||||
|
||||
std::set<uint8_t> bf8_nan_ids;
|
||||
bf8_nan_ids.insert(0b11111111);
|
||||
bf8_nan_ids.insert(0b01111111);
|
||||
bf8_nan_ids.insert(0b11111101);
|
||||
bf8_nan_ids.insert(0b01111101);
|
||||
bf8_nan_ids.insert(0b11111110);
|
||||
bf8_nan_ids.insert(0b01111110);
|
||||
for(auto bf8_nan_id : bf8_nan_ids)
|
||||
{
|
||||
auto idx = bf8_nan_id;
|
||||
ASSERT_TRUE(std::isnan(type_convert<float>(out[idx])));
|
||||
}
|
||||
|
||||
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
|
||||
{
|
||||
if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end())
|
||||
continue;
|
||||
|
||||
uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
|
||||
auto idx = bf8_uid;
|
||||
ASSERT_FLOAT_EQ(out[idx], type_convert<bhalf_t>(bf8_ocp_t{bf8_uid}))
|
||||
<< " bf8_id: " << bf8_id << std::endl
|
||||
<< type_convert<float>(type_convert<bhalf_t>(bf8_ocp_t{bf8_uid}));
|
||||
}
|
||||
|
||||
// /// Test vector conversions
|
||||
|
||||
auto i = 256;
|
||||
|
||||
// bf8x2 -> bf16x2
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-powf(2.0f, -14.0f)));
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(powf(2.0f, -16.0f)));
|
||||
|
||||
// bf16x2 -> bf8x2
|
||||
// RNE
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));
|
||||
// SR
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));
|
||||
|
||||
EXPECT_EQ(test_size, completed);
|
||||
EXPECT_EQ(test_size, i);
|
||||
}
|
||||
|
||||
__global__ void
|
||||
device_test_bf16_bf8_type_convert(uint64_t N, bhalf_t* p_test, uint64_t* p_completed)
|
||||
{
|
||||
test_bf16_bf8_type_convert(N, p_test, p_completed);
|
||||
}
|
||||
|
||||
TEST(BF8OCP, DeviceBF16BF8Convert)
|
||||
{
|
||||
std::vector<bhalf_t> out(test_size, -1.0f);
|
||||
|
||||
DeviceMem device_out(test_size * sizeof(bhalf_t));
|
||||
DeviceMem device_completed(sizeof(uint64_t));
|
||||
|
||||
device_out.SetValue(-21.0f);
|
||||
device_completed.SetValue(-21.0f);
|
||||
|
||||
device_test_bf16_bf8_type_convert<<<1, 1>>>(
|
||||
test_size,
|
||||
static_cast<bhalf_t*>(device_out.GetDeviceBuffer()),
|
||||
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
|
||||
|
||||
uint64_t completed = 0;
|
||||
device_completed.FromDevice(&completed);
|
||||
device_out.FromDevice(out.data());
|
||||
|
||||
std::set<uint8_t> bf8_nan_ids;
|
||||
bf8_nan_ids.insert(0b11111111);
|
||||
bf8_nan_ids.insert(0b01111111);
|
||||
bf8_nan_ids.insert(0b11111101);
|
||||
bf8_nan_ids.insert(0b01111101);
|
||||
bf8_nan_ids.insert(0b11111110);
|
||||
bf8_nan_ids.insert(0b01111110);
|
||||
for(auto bf8_nan_id : bf8_nan_ids)
|
||||
{
|
||||
auto idx = bf8_nan_id;
|
||||
ASSERT_TRUE(std::isnan(type_convert<float>(out[idx])))
|
||||
<< "idx: " << idx << " out[idx]: " << type_convert<float>(out[idx]);
|
||||
}
|
||||
|
||||
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
|
||||
{
|
||||
if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end())
|
||||
continue;
|
||||
|
||||
uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
|
||||
auto idx = bf8_uid;
|
||||
ASSERT_FLOAT_EQ(out[idx], type_convert<bhalf_t>(bf8_ocp_t{bf8_uid}))
|
||||
<< " bf8_id: " << bf8_id << std::endl
|
||||
<< type_convert<float>(type_convert<bhalf_t>(bf8_ocp_t{bf8_uid}));
|
||||
}
|
||||
|
||||
/// Test vector conversions
|
||||
|
||||
auto i = 256;
|
||||
|
||||
// bf8x2 -> bf16x2
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-powf(2.0f, -14.0f)));
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(powf(2.0f, -16.0f)));
|
||||
|
||||
// bf16x2 -> bf8x2
|
||||
// RNE
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));
|
||||
// SR
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));
|
||||
|
||||
EXPECT_EQ(test_size, completed);
|
||||
EXPECT_EQ(test_size, i);
|
||||
}
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
|
||||
using ck::bhalf2_t;
|
||||
using ck::bhalf_t;
|
||||
using ck::f8_convert_rne;
|
||||
using ck::f8_convert_sr;
|
||||
using ck::f8_ocp_t;
|
||||
using ck::f8x2_ocp_t;
|
||||
using ck::float2_t;
|
||||
using ck::half2_t;
|
||||
using ck::half_t;
|
||||
using ck::type_convert;
|
||||
|
||||
@@ -248,3 +254,566 @@ TEST(FP8OCP, ConvertFP16Stochastic)
|
||||
auto f8_nan = f8_convert_sr<f8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
|
||||
ASSERT_TRUE(ck::fp8_impl::ocp_f8_is_nan(f8_nan.data));
|
||||
}
|
||||
|
||||
constexpr uint64_t test_size = 256 + 6;
|
||||
|
||||
__host__ __device__ void
|
||||
test_fp32_fp8_type_convert(uint64_t N, float* p_test, uint64_t* p_completed)
|
||||
{
|
||||
if(p_completed == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t& i = *p_completed;
|
||||
i = 0;
|
||||
|
||||
if(p_test == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
|
||||
{
|
||||
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
|
||||
auto v = type_convert<float>(f8_ocp_t{fp8_uid});
|
||||
p_test[i] = v;
|
||||
i++;
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/// Test vector conversion
|
||||
// fp8x2 -> fp32x2
|
||||
f8x2_ocp_t fp8x2{f8x2_ocp_t::data_v{0b10001000, 0b00000001}}; //-2^-6, 2^-9
|
||||
|
||||
float2_t f32x2 = type_convert<float2_t>(fp8x2);
|
||||
p_test[i++] = f32x2[0];
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = f32x2[1];
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// fp32x2 -> fp8x2
|
||||
f32x2 = {-4.0f, 2.0f};
|
||||
fp8x2 = f8_convert_rne<f8x2_ocp_t>(f32x2); // expect {-4, 2}
|
||||
|
||||
p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{})); //-4f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{})); // 2f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
fp8x2 = f8_convert_sr<f8x2_ocp_t>(f32x2); // expect {-4, 2}
|
||||
|
||||
p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{})); //-4f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{})); // 2f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(FP8OCP, HostFP32FP8Convert)
|
||||
{
|
||||
std::vector<float> out(test_size, -1.0f);
|
||||
uint64_t completed = 0;
|
||||
|
||||
test_fp32_fp8_type_convert(test_size, out.data(), &completed);
|
||||
|
||||
std::set<uint8_t> fp8_nan_ids;
|
||||
fp8_nan_ids.insert(0b11111111); //-NaN
|
||||
fp8_nan_ids.insert(0b01111111); // +NaN
|
||||
for(auto fp8_nan_id : fp8_nan_ids)
|
||||
{
|
||||
auto idx = fp8_nan_id;
|
||||
ASSERT_TRUE(std::isnan(out[idx]));
|
||||
}
|
||||
|
||||
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
|
||||
{
|
||||
if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
|
||||
continue;
|
||||
|
||||
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
|
||||
auto idx = fp8_uid;
|
||||
ASSERT_FLOAT_EQ(out[idx], type_convert<float>(f8_ocp_t{fp8_uid}))
|
||||
<< " fp8_id: " << fp8_id << std::endl
|
||||
<< type_convert<float>(f8_ocp_t{fp8_uid});
|
||||
}
|
||||
|
||||
// /// Test vector conversions
|
||||
|
||||
auto i = 256;
|
||||
|
||||
// fp8x2 -> fp32x2
|
||||
EXPECT_EQ(out[i++], -powf(2.0f, -6.0f));
|
||||
EXPECT_EQ(out[i++], powf(2.0f, -9.0f));
|
||||
|
||||
// fp32x2 -> fp8x2
|
||||
// RNE
|
||||
EXPECT_EQ(out[i++], -4.0f);
|
||||
EXPECT_EQ(out[i++], 2.0f);
|
||||
// SR
|
||||
EXPECT_EQ(out[i++], -4.0f);
|
||||
EXPECT_EQ(out[i++], 2.0f);
|
||||
|
||||
EXPECT_EQ(test_size, completed);
|
||||
EXPECT_EQ(test_size, i);
|
||||
}
|
||||
|
||||
__global__ void device_test_fp32_fp8_type_convert(uint64_t N, float* p_test, uint64_t* p_completed)
|
||||
{
|
||||
test_fp32_fp8_type_convert(N, p_test, p_completed);
|
||||
}
|
||||
|
||||
TEST(FP8OCP, DeviceFP32FP8Convert)
|
||||
{
|
||||
std::vector<float> out(test_size, -1.0f);
|
||||
|
||||
DeviceMem device_out(test_size * sizeof(float));
|
||||
DeviceMem device_completed(sizeof(uint64_t));
|
||||
|
||||
device_out.SetValue(-21.0f);
|
||||
device_completed.SetValue(-21.0f);
|
||||
|
||||
device_test_fp32_fp8_type_convert<<<1, 1>>>(
|
||||
test_size,
|
||||
static_cast<float*>(device_out.GetDeviceBuffer()),
|
||||
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
|
||||
|
||||
uint64_t completed = 0;
|
||||
device_completed.FromDevice(&completed);
|
||||
device_out.FromDevice(out.data());
|
||||
|
||||
std::set<uint8_t> fp8_nan_ids;
|
||||
fp8_nan_ids.insert(0b11111111); //-NaN
|
||||
fp8_nan_ids.insert(0b01111111); // +NaN
|
||||
for(auto fp8_nan_id : fp8_nan_ids)
|
||||
{
|
||||
auto idx = fp8_nan_id;
|
||||
ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx];
|
||||
}
|
||||
|
||||
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
|
||||
{
|
||||
if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
|
||||
continue;
|
||||
|
||||
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
|
||||
auto idx = fp8_uid;
|
||||
ASSERT_FLOAT_EQ(out[idx], type_convert<float>(f8_ocp_t{fp8_uid}))
|
||||
<< " fp8_id: " << fp8_id << std::endl
|
||||
<< type_convert<float>(f8_ocp_t{fp8_uid});
|
||||
}
|
||||
|
||||
/// Test vector conversions
|
||||
|
||||
auto i = 256;
|
||||
|
||||
// fp8x2 -> fp32x2
|
||||
EXPECT_EQ(out[i++], -powf(2.0f, -6.0f));
|
||||
EXPECT_EQ(out[i++], powf(2.0f, -9.0f));
|
||||
|
||||
// fp32x2 -> fp8x2
|
||||
// RNE
|
||||
EXPECT_EQ(out[i++], -4.0f);
|
||||
EXPECT_EQ(out[i++], 2.0f);
|
||||
// SR
|
||||
EXPECT_EQ(out[i++], -4.0f);
|
||||
EXPECT_EQ(out[i++], 2.0f);
|
||||
|
||||
EXPECT_EQ(test_size, completed);
|
||||
EXPECT_EQ(test_size, i);
|
||||
}
|
||||
|
||||
__host__ __device__ void
|
||||
test_fp16_fp8_type_convert(uint64_t N, half_t* p_test, uint64_t* p_completed)
|
||||
{
|
||||
if(p_completed == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t& i = *p_completed;
|
||||
i = 0;
|
||||
|
||||
if(p_test == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
|
||||
{
|
||||
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
|
||||
auto v = type_convert<half_t>(f8_ocp_t{fp8_uid});
|
||||
p_test[i] = v;
|
||||
i++;
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/// Test vector conversion
|
||||
// fp8x2 -> fp16x2
|
||||
f8x2_ocp_t fp8x2{f8x2_ocp_t::data_v{0b10001000, 0b00000001}}; //-2^-6, 2^-9
|
||||
|
||||
half2_t f16x2 = type_convert<half2_t>(fp8x2);
|
||||
p_test[i++] = f16x2[0];
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = f16x2[1];
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// fp16x2 -> fp8x2
|
||||
f16x2 = {-4.0f, 2.0f};
|
||||
fp8x2 = f8_convert_rne<f8x2_ocp_t>(f16x2); // expect {-4, 2}
|
||||
|
||||
p_test[i++] = type_convert<half_t>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{})); //-4f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = type_convert<half_t>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{})); // 2f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
fp8x2 = f8_convert_sr<f8x2_ocp_t>(f16x2); // expect {-4, 2}
|
||||
|
||||
p_test[i++] = type_convert<half_t>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{})); //-4f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = type_convert<half_t>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{})); // 2f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(FP8OCP, HostFP16FP8Convert)
|
||||
{
|
||||
std::vector<half_t> out(test_size, -1.0f);
|
||||
uint64_t completed = 0;
|
||||
|
||||
test_fp16_fp8_type_convert(test_size, out.data(), &completed);
|
||||
|
||||
std::set<uint8_t> fp8_nan_ids;
|
||||
fp8_nan_ids.insert(0b11111111); //-NaN
|
||||
fp8_nan_ids.insert(0b01111111); // +NaN
|
||||
for(auto fp8_nan_id : fp8_nan_ids)
|
||||
{
|
||||
auto idx = fp8_nan_id;
|
||||
ASSERT_TRUE(std::isnan(type_convert<float>(out[idx])));
|
||||
}
|
||||
|
||||
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
|
||||
{
|
||||
if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
|
||||
continue;
|
||||
|
||||
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
|
||||
auto idx = fp8_uid;
|
||||
ASSERT_FLOAT_EQ(out[idx], type_convert<half_t>(f8_ocp_t{fp8_uid}))
|
||||
<< " fp8_id: " << fp8_id << std::endl
|
||||
<< type_convert<float>(type_convert<half_t>(f8_ocp_t{fp8_uid}));
|
||||
}
|
||||
|
||||
// /// Test vector conversions
|
||||
|
||||
auto i = 256;
|
||||
|
||||
// fp8x2 -> fp16x2
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(-powf(2.0f, -6.0f)));
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(powf(2.0f, -9.0f)));
|
||||
|
||||
// fp16x2 -> fp8x2
|
||||
// RNE
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(2.0f));
|
||||
// SR
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(2.0f));
|
||||
|
||||
EXPECT_EQ(test_size, completed);
|
||||
EXPECT_EQ(test_size, i);
|
||||
}
|
||||
|
||||
__global__ void device_test_fp16_fp8_type_convert(uint64_t N, half_t* p_test, uint64_t* p_completed)
|
||||
{
|
||||
test_fp16_fp8_type_convert(N, p_test, p_completed);
|
||||
}
|
||||
|
||||
TEST(FP8OCP, DeviceFP16FP8Convert)
|
||||
{
|
||||
std::vector<half_t> out(test_size, -1.0f);
|
||||
|
||||
DeviceMem device_out(test_size * sizeof(half_t));
|
||||
DeviceMem device_completed(sizeof(uint64_t));
|
||||
|
||||
device_out.SetValue(-21.0f);
|
||||
device_completed.SetValue(-21.0f);
|
||||
|
||||
device_test_fp16_fp8_type_convert<<<1, 1>>>(
|
||||
test_size,
|
||||
static_cast<half_t*>(device_out.GetDeviceBuffer()),
|
||||
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
|
||||
|
||||
uint64_t completed = 0;
|
||||
device_completed.FromDevice(&completed);
|
||||
device_out.FromDevice(out.data());
|
||||
|
||||
std::set<uint8_t> fp8_nan_ids;
|
||||
fp8_nan_ids.insert(0b11111111); //-NaN
|
||||
fp8_nan_ids.insert(0b01111111); // +NaN
|
||||
for(auto fp8_nan_id : fp8_nan_ids)
|
||||
{
|
||||
auto idx = fp8_nan_id;
|
||||
ASSERT_TRUE(std::isnan(type_convert<float>(out[idx])))
|
||||
<< "idx: " << idx << " out[idx]: " << type_convert<float>(out[idx]);
|
||||
}
|
||||
|
||||
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
|
||||
{
|
||||
if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
|
||||
continue;
|
||||
|
||||
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
|
||||
auto idx = fp8_uid;
|
||||
ASSERT_FLOAT_EQ(out[idx], type_convert<half_t>(f8_ocp_t{fp8_uid}))
|
||||
<< " fp8_id: " << fp8_id << std::endl
|
||||
<< type_convert<float>(type_convert<half_t>(f8_ocp_t{fp8_uid}));
|
||||
}
|
||||
|
||||
/// Test vector conversions
|
||||
|
||||
auto i = 256;
|
||||
|
||||
// fp8x2 -> fp16x2
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(-powf(2.0f, -6.0f)));
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(powf(2.0f, -9.0f)));
|
||||
|
||||
// fp16x2 -> fp8x2
|
||||
// RNE
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(2.0f));
|
||||
// SR
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(2.0f));
|
||||
|
||||
EXPECT_EQ(test_size, completed);
|
||||
EXPECT_EQ(test_size, i);
|
||||
}
|
||||
|
||||
__host__ __device__ void
|
||||
test_bf16_fp8_type_convert(uint64_t N, bhalf_t* p_test, uint64_t* p_completed)
|
||||
{
|
||||
if(p_completed == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t& i = *p_completed;
|
||||
i = 0;
|
||||
|
||||
if(p_test == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
|
||||
{
|
||||
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
|
||||
auto v = type_convert<bhalf_t>(f8_ocp_t{fp8_uid});
|
||||
p_test[i] = v;
|
||||
i++;
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/// Test vector conversion
|
||||
// fp8x2 -> bf16x2
|
||||
f8x2_ocp_t fp8x2{f8x2_ocp_t::data_v{0b10001000, 0b00000001}}; //-2^-6, 2^-9
|
||||
|
||||
bhalf2_t bf16x2 = type_convert<bhalf2_t>(fp8x2);
|
||||
p_test[i++] = bf16x2[0];
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = bf16x2[1];
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// bf16x2 -> fp8x2
|
||||
bf16x2 = {type_convert<bhalf_t>(-4.0f), type_convert<bhalf_t>(2.0f)};
|
||||
fp8x2 = f8_convert_rne<f8x2_ocp_t>(bf16x2); // expect {-4, 2}
|
||||
|
||||
p_test[i++] = type_convert<bhalf_t>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{})); //-4f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = type_convert<bhalf_t>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{})); // 2f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
fp8x2 = f8_convert_sr<f8x2_ocp_t>(bf16x2); // expect {-4, 2}
|
||||
|
||||
p_test[i++] = type_convert<bhalf_t>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{})); //-4f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = type_convert<bhalf_t>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{})); // 2f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(FP8OCP, HostBF16FP8Convert)
|
||||
{
|
||||
std::vector<bhalf_t> out(test_size, -1.0f);
|
||||
uint64_t completed = 0;
|
||||
|
||||
test_bf16_fp8_type_convert(test_size, out.data(), &completed);
|
||||
|
||||
std::set<uint8_t> fp8_nan_ids;
|
||||
fp8_nan_ids.insert(0b11111111); //-NaN
|
||||
fp8_nan_ids.insert(0b01111111); // +NaN
|
||||
for(auto fp8_nan_id : fp8_nan_ids)
|
||||
{
|
||||
auto idx = fp8_nan_id;
|
||||
ASSERT_TRUE(std::isnan(type_convert<float>(out[idx])));
|
||||
}
|
||||
|
||||
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
|
||||
{
|
||||
if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
|
||||
continue;
|
||||
|
||||
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
|
||||
auto idx = fp8_uid;
|
||||
ASSERT_FLOAT_EQ(out[idx], type_convert<bhalf_t>(f8_ocp_t{fp8_uid}))
|
||||
<< " fp8_id: " << fp8_id << std::endl
|
||||
<< type_convert<float>(type_convert<bhalf_t>(f8_ocp_t{fp8_uid}));
|
||||
}
|
||||
|
||||
// /// Test vector conversions
|
||||
|
||||
auto i = 256;
|
||||
|
||||
// fp8x2 -> bf16x2
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-powf(2.0f, -6.0f)));
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(powf(2.0f, -9.0f)));
|
||||
|
||||
// bf16x2 -> fp8x2
|
||||
// RNE
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));
|
||||
// SR
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));
|
||||
|
||||
EXPECT_EQ(test_size, completed);
|
||||
EXPECT_EQ(test_size, i);
|
||||
}
|
||||
|
||||
__global__ void
|
||||
device_test_bf16_fp8_type_convert(uint64_t N, bhalf_t* p_test, uint64_t* p_completed)
|
||||
{
|
||||
test_bf16_fp8_type_convert(N, p_test, p_completed);
|
||||
}
|
||||
|
||||
TEST(FP8OCP, DeviceBF16FP8Convert)
|
||||
{
|
||||
std::vector<bhalf_t> out(test_size, -1.0f);
|
||||
|
||||
DeviceMem device_out(test_size * sizeof(bhalf_t));
|
||||
DeviceMem device_completed(sizeof(uint64_t));
|
||||
|
||||
device_out.SetValue(-21.0f);
|
||||
device_completed.SetValue(-21.0f);
|
||||
|
||||
device_test_bf16_fp8_type_convert<<<1, 1>>>(
|
||||
test_size,
|
||||
static_cast<bhalf_t*>(device_out.GetDeviceBuffer()),
|
||||
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
|
||||
|
||||
uint64_t completed = 0;
|
||||
device_completed.FromDevice(&completed);
|
||||
device_out.FromDevice(out.data());
|
||||
|
||||
std::set<uint8_t> fp8_nan_ids;
|
||||
fp8_nan_ids.insert(0b11111111); //-NaN
|
||||
fp8_nan_ids.insert(0b01111111); // +NaN
|
||||
for(auto fp8_nan_id : fp8_nan_ids)
|
||||
{
|
||||
auto idx = fp8_nan_id;
|
||||
ASSERT_TRUE(std::isnan(type_convert<float>(out[idx])))
|
||||
<< "idx: " << idx << " out[idx]: " << type_convert<float>(out[idx]);
|
||||
}
|
||||
|
||||
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
|
||||
{
|
||||
if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
|
||||
continue;
|
||||
|
||||
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
|
||||
auto idx = fp8_uid;
|
||||
ASSERT_FLOAT_EQ(out[idx], type_convert<bhalf_t>(f8_ocp_t{fp8_uid}))
|
||||
<< " fp8_id: " << fp8_id << std::endl
|
||||
<< type_convert<float>(type_convert<bhalf_t>(f8_ocp_t{fp8_uid}));
|
||||
}
|
||||
|
||||
/// Test vector conversions
|
||||
|
||||
auto i = 256;
|
||||
|
||||
// fp8x2 -> bf16x2
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-powf(2.0f, -6.0f)));
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(powf(2.0f, -9.0f)));
|
||||
|
||||
// bf16x2 -> fp8x2
|
||||
// RNE
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));
|
||||
// SR
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));
|
||||
|
||||
EXPECT_EQ(test_size, completed);
|
||||
EXPECT_EQ(test_size, i);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user