Add FP16/BF16<->FP8/BF8 conversions (#2035)

* Move conversion functions and add missing conversions

* Add tests

* Add missing conversions

* Add missing conversions

* Add bf8 tests

* Update clipping for vectors

* Add missing conversions

* Add bf16 fp8 tests

* Add bf16 bf8 tests

* Fix device conversion

* Fix conversions

* Fix vector use

* Minor fix

* Add a workaround flag

* Add a workaround flag for bf16 conversion

* Add another workaround

* Add a workaround for fp16 to bf8 conversion

* Update type alias

* Add docstrings and missing wrappers

* Fix if defined macros

* Fix more if defined macros

* Add comments

* Remove __host__ specifier

* Add a gfx950 guard

* Update function naming

[ROCm/composable_kernel commit: 265af71a71]
This commit is contained in:
Rostyslav Geyyer
2025-04-03 12:42:03 -05:00
committed by GitHub
parent b7359bcfac
commit 7fbc128e83
7 changed files with 2628 additions and 110 deletions

View File

@@ -248,6 +248,12 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// workaround: compiler issue on gfx950
#define CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION 1
// workaround: compiler issue on gfx950
#define CK_WORKAROUND_FP16_TO_FP8_CONVERSION 1
// workaround: compiler issue on gfx950
#define CK_WORKAROUND_BF16_TO_FP8_CONVERSION 1
// denorm test fix, necessary for gfx90a
#ifndef CK_GFX90A_DENORM_WORKAROUND
#define CK_GFX90A_DENORM_WORKAROUND 0

View File

@@ -64,6 +64,9 @@ enum class ck_saturation_t
namespace fp8_impl {
typedef fp8_storage_t fp8x2_storage_t __attribute__((ext_vector_type(2)));
typedef _Float16 half2_t __attribute__((ext_vector_type(2)));
typedef ushort ushortx2_t __attribute__((ext_vector_type(2)));
typedef short shortx2_t __attribute__((ext_vector_type(2)));
typedef float float2_t __attribute__((ext_vector_type(2)));
__host__ __device__ static inline constexpr bool fnuz_f8_is_nan(f8_fnuz_t a)
@@ -270,7 +273,7 @@ static __host__ __device__ float cast_to_f32_from_f8(fp8_storage_t v)
}
template <ck_fp8_interpretation_t interpret>
static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v)
static __device__ float2_t cast_to_f32_from_f8(fp8x2_storage_t v)
{
const auto i16val = bit_cast<uint16_t>(v);
@@ -458,6 +461,510 @@ __is_interpret_supported([[maybe_unused]] ck_fp8_interpretation_t interp)
#endif
}
#if defined(__gfx950__)
// gfx950 only: convert one _Float16 to OCP E4M3 fp8 with stochastic rounding (SR).
// Selected when interpret == CK_E4M3_OCP and stochastic_rounding == true.
// \param v    half value to convert
// \param rng  random bits consumed by the SR conversion instruction
// \return the converted fp8 byte
template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP, bool> = false,
          ck::enable_if_t<stochastic_rounding == true, bool> = false>
static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0)
{
    // Union provides bit-level access: write the half, inspect/clamp its bits,
    // then read back the fp8 result byte produced by the intrinsic.
    union
    {
        unsigned int i32val;
        half2_t half_vec;
        fp8_storage_t i8val[4];
    } val;
    constexpr unsigned int i32val = 0; // "old value" destination operand for the intrinsic
    val.half_vec[0] = v;
    if constexpr(saturate)
    {
        // Clamp to the E4M3 finite range [-448, 448]; the all-ones low-15-bit
        // pattern (0x7FFF) is left untouched so it passes through unclamped.
        if((val.i32val & 0x7FFF) != 0x7FFF)
        {
            val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0);
        }
    }
    // SR f16 -> fp8 conversion with scale = 1.0 (i.e. unscaled); result byte 0 is returned.
    val.i32val =
        __builtin_amdgcn_cvt_scalef32_sr_fp8_f16(i32val, val.half_vec[0], rng, /* scale */ 1.f, 0);
    return val.i8val[0];
}
// gfx950 only: convert two _Float16 values to OCP E4M3 fp8 with stochastic rounding.
// No packed SR conversion instruction exists, so each lane is converted
// independently through the scalar overload (both lanes share the same rng bits).
template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP, bool> = false,
          ck::enable_if_t<stochastic_rounding == true, bool> = false>
static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0)
{
    const fp8_storage_t lane0 =
        cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng);
    const fp8_storage_t lane1 =
        cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng);
    return fp8x2_storage_t{lane0, lane1};
}
// gfx950 only: convert one _Float16 to OCP E5M2 (bf8) with stochastic rounding (SR).
// Selected when interpret == CK_E5M2_OCP and stochastic_rounding == true.
// \param v    half value to convert
// \param rng  random bits consumed by the SR conversion instruction
// \return the converted bf8 byte
template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP, bool> = false,
          ck::enable_if_t<stochastic_rounding == true, bool> = false>
static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0)
{
    // Union provides bit-level access to the half operand and the result byte.
    union
    {
        unsigned int i32val;
        half2_t half_vec;
        fp8_storage_t i8val[4];
    } val;
    constexpr unsigned int i32val = 0; // "old value" destination operand for the intrinsic
    val.half_vec[0] = v;
    if constexpr(saturate)
    {
        // Clamp to the E5M2 finite range [-57344, 57344]; the all-ones
        // low-15-bit pattern (0x7FFF) is left untouched.
        if((val.i32val & 0x7FFF) != 0x7FFF)
        {
            val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0);
        }
    }
    // SR f16 -> bf8 conversion with scale = 1.0 (i.e. unscaled).
    val.i32val =
        __builtin_amdgcn_cvt_scalef32_sr_bf8_f16(i32val, val.half_vec[0], rng, /* scale */ 1.f, 0);
    return val.i8val[0];
}
// gfx950 only: convert two _Float16 values to OCP E5M2 (bf8) with stochastic rounding.
// No packed SR conversion instruction exists, so each lane is converted
// independently through the scalar overload (both lanes share the same rng bits).
template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP, bool> = false,
          ck::enable_if_t<stochastic_rounding == true, bool> = false>
static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0)
{
    const fp8_storage_t lane0 =
        cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng);
    const fp8_storage_t lane1 =
        cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng);
    return fp8x2_storage_t{lane0, lane1};
}
// gfx950 only: convert one _Float16 to OCP E4M3 fp8, round-to-nearest-even (RNE).
// Selected when interpret == CK_E4M3_OCP and stochastic_rounding == false.
// \param v    half value to convert
// \param rng  unused; present to keep the overload set's signature uniform
// \return the converted fp8 byte
template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP, bool> = false,
          ck::enable_if_t<stochastic_rounding == false, bool> = false>
static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0)
{
    std::ignore = rng; // rng is only consumed by the stochastic-rounding overloads
    union
    {
        unsigned int i32val;
        half2_t half_vec;
        shortx2_t i16_vec;
        fp8_storage_t i8val[4];
    } val;
    constexpr shortx2_t i16x2val = {0, 0}; // "old value" destination operand for the pk intrinsic
    val.half_vec[0] = v;
    // NOTE(review): half_vec[1] is left uninitialized; the packed intrinsic
    // converts it too, but that lane's result byte is discarded below.
    if constexpr(saturate)
    {
        // Clamp to the E4M3 finite range [-448, 448]; the all-ones
        // low-15-bit pattern (0x7FFF) is left untouched.
        if((val.i32val & 0x7FFF) != 0x7FFF)
        {
            val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0);
        }
    }
    // Packed RNE f16 -> fp8 conversion with scale = 1.0; only lane 0's byte is returned.
    val.i16_vec =
        __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0);
    return val.i8val[0];
}
// gfx950 only: convert two _Float16 values to OCP E4M3 fp8, round-to-nearest-even.
// With CK_WORKAROUND_FP16_TO_FP8_CONVERSION enabled the packed path is bypassed
// and each lane goes through the scalar overload (compiler workaround).
template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP, bool> = false,
          ck::enable_if_t<stochastic_rounding == false, bool> = false>
static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0)
{
#if CK_WORKAROUND_FP16_TO_FP8_CONVERSION
    return fp8x2_storage_t{
        cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
        cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
#else
    std::ignore = rng; // rng is only consumed by the stochastic-rounding overloads
    union
    {
        half2_t half_vec;
        shortx2_t i16_vec;
        fp8_storage_t i8val[4];
    } val;
    constexpr shortx2_t i16x2val = {0, 0}; // "old value" destination operand for the pk intrinsic
    val.half_vec = v;
    if constexpr(saturate)
    {
        // Clamp each lane to the E4M3 finite range [-448, 448]; the all-ones
        // low-15-bit pattern (0x7FFF) is left untouched.
        if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
        {
            val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0);
        }
        if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
        {
            val.half_vec[1] = __builtin_amdgcn_fmed3h(val.half_vec[1], 448.0, -448.0);
        }
    }
    // Packed RNE f16 -> fp8 conversion with scale = 1.0; both result bytes are returned.
    val.i16_vec =
        __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0);
    return fp8x2_storage_t{val.i8val[0], val.i8val[1]};
#endif
}
// gfx950 only: convert one _Float16 to OCP E5M2 (bf8), round-to-nearest-even (RNE).
// Selected when interpret == CK_E5M2_OCP and stochastic_rounding == false.
// \param v    half value to convert
// \param rng  unused; present to keep the overload set's signature uniform
// \return the converted bf8 byte
template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP, bool> = false,
          ck::enable_if_t<stochastic_rounding == false, bool> = false>
static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0)
{
    std::ignore = rng; // rng is only consumed by the stochastic-rounding overloads
    union
    {
        unsigned int i32val;
        half2_t half_vec;
        shortx2_t i16_vec;
        fp8_storage_t i8val[4];
    } val;
    constexpr shortx2_t i16x2val = {0, 0}; // "old value" destination operand for the pk intrinsic
    val.half_vec[0] = v;
    if constexpr(saturate)
    {
        // Clamp to the E5M2 finite range [-57344, 57344]; the all-ones
        // low-15-bit pattern (0x7FFF) is left untouched.
        if((val.i32val & 0x7FFF) != 0x7FFF)
        {
            val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0);
        }
    }
    // FIX: store the intrinsic's packed shortx2_t result through i16_vec, as
    // every sibling overload does (see the fp8 RNE and packed bf8 variants).
    // The original assigned it to half_vec, which performs an element-wise
    // short -> _Float16 *value* conversion instead of a bit-level store and
    // corrupts the returned byte.
    val.i16_vec =
        __builtin_amdgcn_cvt_scalef32_pk_bf8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0);
    return val.i8val[0];
}
// gfx950 only: convert two _Float16 values to OCP E5M2 (bf8), round-to-nearest-even.
// With CK_WORKAROUND_FP16_TO_FP8_CONVERSION enabled the packed path is bypassed
// and each lane goes through the scalar overload (compiler workaround).
template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP, bool> = false,
          ck::enable_if_t<stochastic_rounding == false, bool> = false>
static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0)
{
#if CK_WORKAROUND_FP16_TO_FP8_CONVERSION
    return fp8x2_storage_t{
        cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
        cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
#else
    std::ignore = rng; // rng is only consumed by the stochastic-rounding overloads
    union
    {
        half2_t half_vec;
        shortx2_t i16_vec;
        fp8_storage_t i8val[4];
    } val;
    constexpr shortx2_t i16x2val = {0, 0}; // "old value" destination operand for the pk intrinsic
    val.half_vec = v;
    if constexpr(saturate)
    {
        // Clamp each lane to the E5M2 finite range [-57344, 57344]; the
        // all-ones low-15-bit pattern (0x7FFF) is left untouched.
        if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
        {
            val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0);
        }
        if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
        {
            val.half_vec[1] = __builtin_amdgcn_fmed3h(val.half_vec[1], 57344.0, -57344.0);
        }
    }
    // Packed RNE f16 -> bf8 conversion with scale = 1.0; both result bytes are returned.
    val.i16_vec =
        __builtin_amdgcn_cvt_scalef32_pk_bf8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0);
    return fp8x2_storage_t{val.i8val[0], val.i8val[1]};
#endif
}
// gfx950 only: convert one bf16 (passed as its ushort bit pattern) to OCP E4M3
// fp8 with stochastic rounding (SR).
// \param v    bf16 bit pattern to convert
// \param rng  random bits consumed by the SR conversion instruction
// \return the converted fp8 byte
template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP, bool> = false,
          ck::enable_if_t<stochastic_rounding == true, bool> = false>
static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0)
{
    union
    {
        unsigned int i32val;
        ushortx2_t bhalf_vec;
        fp8_storage_t i8val[4];
    } val;
    constexpr unsigned int i32val = 0; // "old value" destination operand for the intrinsic
    val.bhalf_vec[0] = v;
    if constexpr(saturate)
    {
        // Clamp to the E4M3 finite range [-448, 448] unless the low 15 bits
        // are all ones. The bf16 value is widened to float by shifting its
        // bits into the high half-word, clamped with fmed3f, then truncated
        // back to bf16 by dropping the low 16 bits.
        if((val.i32val & 0x7FFF) != 0x7FFF)
        {
            val.bhalf_vec[0] =
                ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
                            bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >>
                        16)); // convert to float and back
        }
    }
    // SR bf16 -> fp8 conversion with scale = 1.0 (i.e. unscaled).
    val.i32val = __builtin_amdgcn_cvt_scalef32_sr_fp8_bf16(
        i32val, val.bhalf_vec[0], rng, /* scale */ 1.f, 0);
    return val.i8val[0];
}
// gfx950 only: convert two bf16 values to OCP E4M3 fp8 with stochastic rounding.
// No packed SR conversion instruction exists, so each lane is converted
// independently through the scalar overload (both lanes share the same rng bits).
template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP, bool> = false,
          ck::enable_if_t<stochastic_rounding == true, bool> = false>
static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0)
{
    const fp8_storage_t lane0 =
        cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[0], rng);
    const fp8_storage_t lane1 =
        cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[1], rng);
    return fp8x2_storage_t{lane0, lane1};
}
// gfx950 only: convert one bf16 (passed as its ushort bit pattern) to OCP E5M2
// (bf8) with stochastic rounding (SR).
// \param v    bf16 bit pattern to convert
// \param rng  random bits consumed by the SR conversion instruction
// \return the converted bf8 byte
template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP, bool> = false,
          ck::enable_if_t<stochastic_rounding == true, bool> = false>
static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0)
{
    union
    {
        unsigned int i32val;
        ushortx2_t bhalf_vec;
        fp8_storage_t i8val[4];
    } val;
    constexpr unsigned int i32val = 0; // "old value" destination operand for the intrinsic
    val.bhalf_vec[0] = v;
    if constexpr(saturate)
    {
        // Clamp to the E5M2 finite range [-57344, 57344] unless the low 15
        // bits are all ones; clamping is done in float (bit-shift up, fmed3f,
        // truncate back to bf16).
        if((val.i32val & 0x7FFF) != 0x7FFF)
        {
            val.bhalf_vec[0] = ushort(
                (bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
                     bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >>
                 16)); // convert to float and back
        }
    }
    // SR bf16 -> bf8 conversion with scale = 1.0 (i.e. unscaled).
    val.i32val = __builtin_amdgcn_cvt_scalef32_sr_bf8_bf16(
        i32val, val.bhalf_vec[0], rng, /* scale */ 1.f, 0);
    return val.i8val[0];
}
// gfx950 only: convert two bf16 values to OCP E5M2 (bf8) with stochastic rounding.
// No packed SR conversion instruction exists, so each lane is converted
// independently through the scalar overload (both lanes share the same rng bits).
template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP, bool> = false,
          ck::enable_if_t<stochastic_rounding == true, bool> = false>
static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0)
{
    const fp8_storage_t lane0 =
        cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[0], rng);
    const fp8_storage_t lane1 =
        cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[1], rng);
    return fp8x2_storage_t{lane0, lane1};
}
// gfx950 only: convert one bf16 (passed as its ushort bit pattern) to OCP E4M3
// fp8, round-to-nearest-even (RNE).
// \param v    bf16 bit pattern to convert
// \param rng  unused; present to keep the overload set's signature uniform
// \return the converted fp8 byte
template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP, bool> = false,
          ck::enable_if_t<stochastic_rounding == false, bool> = false>
static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0)
{
    std::ignore = rng; // rng is only consumed by the stochastic-rounding overloads
    union
    {
        unsigned int i32val;
        ushortx2_t bhalf_vec;
        shortx2_t i16_vec;
        fp8_storage_t i8val[4];
    } val;
    constexpr shortx2_t i16x2val = {0, 0}; // "old value" destination operand for the pk intrinsic
    val.bhalf_vec[0] = v;
    // NOTE(review): bhalf_vec[1] is left uninitialized; the packed intrinsic
    // converts it too, but that lane's result byte is discarded below.
    if constexpr(saturate)
    {
        // Clamp to the E4M3 finite range [-448, 448] unless the low 15 bits
        // are all ones; clamping is done in float (bit-shift up, fmed3f,
        // truncate back to bf16).
        if((val.i32val & 0x7FFF) != 0x7FFF)
        {
            val.bhalf_vec[0] =
                ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
                            bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >>
                        16)); // convert to float and back
        }
    }
    // Packed RNE bf16 -> fp8 conversion with scale = 1.0; only lane 0's byte is returned.
    val.i16_vec =
        __builtin_amdgcn_cvt_scalef32_pk_fp8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0);
    return val.i8val[0];
}
// gfx950 only: convert two bf16 values to OCP E4M3 fp8, round-to-nearest-even.
// With CK_WORKAROUND_BF16_TO_FP8_CONVERSION enabled the packed path is bypassed
// and each lane goes through the scalar overload (compiler workaround).
template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E4M3_OCP, bool> = false,
          ck::enable_if_t<stochastic_rounding == false, bool> = false>
static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0)
{
#if CK_WORKAROUND_BF16_TO_FP8_CONVERSION
    return fp8x2_storage_t{
        cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[0], rng),
        cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[1], rng)};
#else
    std::ignore = rng; // rng is only consumed by the stochastic-rounding overloads
    union
    {
        ushortx2_t bhalf_vec;
        shortx2_t i16_vec;
        fp8_storage_t i8val[4];
    } val;
    constexpr shortx2_t i16x2val = {0, 0}; // "old value" destination operand for the pk intrinsic
    val.bhalf_vec = v;
    if constexpr(saturate)
    {
        // Clamp each lane to the E4M3 finite range [-448, 448] unless its low
        // 15 bits are all ones; clamping is done in float (bit-shift up,
        // fmed3f, truncate back to bf16).
        if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
        {
            val.bhalf_vec[0] =
                ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
                            bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >>
                        16)); // convert to float and back
        }
        if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
        {
            val.bhalf_vec[1] =
                ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
                            bit_cast<float>(uint32_t{val.bhalf_vec[1]} << 16), 448.0, -448.0)) >>
                        16)); // convert to float and back
        }
    }
    // Packed RNE bf16 -> fp8 conversion with scale = 1.0; both result bytes are returned.
    val.i16_vec =
        __builtin_amdgcn_cvt_scalef32_pk_fp8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0);
    return fp8x2_storage_t{val.i8val[0], val.i8val[1]};
#endif
}
// gfx950 only: convert one bf16 (passed as its ushort bit pattern) to OCP E5M2
// (bf8), round-to-nearest-even (RNE).
// \param v    bf16 bit pattern to convert
// \param rng  unused; present to keep the overload set's signature uniform
// \return the converted bf8 byte
template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP, bool> = false,
          ck::enable_if_t<stochastic_rounding == false, bool> = false>
static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0)
{
    std::ignore = rng; // rng is only consumed by the stochastic-rounding overloads
    union
    {
        unsigned int i32val;
        ushortx2_t bhalf_vec;
        shortx2_t i16_vec;
        fp8_storage_t i8val[4];
    } val;
    constexpr shortx2_t i16x2val = {0, 0}; // "old value" destination operand for the pk intrinsic
    val.bhalf_vec[0] = v;
    if constexpr(saturate)
    {
        // Clamp to the E5M2 finite range [-57344, 57344] unless the low 15
        // bits are all ones; clamping is done in float (bit-shift up, fmed3f,
        // truncate back to bf16).
        if((val.i32val & 0x7FFF) != 0x7FFF)
        {
            val.bhalf_vec[0] = ushort(
                (bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
                     bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >>
                 16)); // convert to float and back
        }
    }
    // Packed RNE bf16 -> bf8 conversion with scale = 1.0; only lane 0's byte is returned.
    val.i16_vec =
        __builtin_amdgcn_cvt_scalef32_pk_bf8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0);
    return val.i8val[0];
}
// gfx950 only: convert two bf16 values to OCP E5M2 (bf8), round-to-nearest-even.
// NOTE(review): unlike the E4M3 counterpart above, this overload has no
// CK_WORKAROUND_BF16_TO_FP8_CONVERSION fallback — confirm that is intentional
// (the caller-level workaround in cvt_bhalf_t_to_fp8 may cover it).
template <ck_fp8_interpretation_t interpret,
          bool saturate,
          bool stochastic_rounding = false,
          ck::enable_if_t<interpret == ck_fp8_interpretation_t::CK_E5M2_OCP, bool> = false,
          ck::enable_if_t<stochastic_rounding == false, bool> = false>
static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0)
{
    std::ignore = rng; // rng is only consumed by the stochastic-rounding overloads
    union
    {
        ushortx2_t bhalf_vec;
        shortx2_t i16_vec;
        fp8_storage_t i8val[4];
    } val;
    constexpr shortx2_t i16x2val = {0, 0}; // "old value" destination operand for the pk intrinsic
    val.bhalf_vec = v;
    if constexpr(saturate)
    {
        // Clamp each lane to the E5M2 finite range [-57344, 57344] unless its
        // low 15 bits are all ones; clamping is done in float (bit-shift up,
        // fmed3f, truncate back to bf16).
        if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
        {
            val.bhalf_vec[0] = ushort(
                (bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
                     bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >>
                 16)); // convert to float and back
        }
        if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
        {
            val.bhalf_vec[1] = ushort(
                (bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
                     bit_cast<float>(uint32_t{val.bhalf_vec[1]} << 16), 57344.0, -57344.0)) >>
                 16)); // convert to float and back
        }
    }
    // Packed RNE bf16 -> bf8 conversion with scale = 1.0; both result bytes are returned.
    val.i16_vec =
        __builtin_amdgcn_cvt_scalef32_pk_bf8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0);
    return fp8x2_storage_t{val.i8val[0], val.i8val[1]};
}
#endif // defined(__gfx950__)
#if CK_FP8_CVT_FAST_PATH
// The conversion function is from rocblas
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_float8.h#L79
@@ -523,6 +1030,84 @@ static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng =
}
return i8data;
}
// Convert two floats to fp8/bf8 (fast path). RNE uses the packed
// v_cvt_pk_{fp8,bf8}_f32 instructions; SR has no packed form and falls back to
// the per-element scalar converter.
// \param v    two float values
// \param rng  random bits, used only when stochastic_rounding is true
// \return the two converted bytes
template <ck_fp8_interpretation_t interpret, bool saturate, bool stochastic_rounding = false>
static __device__ fp8x2_storage_t cast_to_f8_from_f32(float2_t v, unsigned int rng = 0)
{
    if constexpr(stochastic_rounding)
    {
        // there is no packed conversion with SR, so convert one element at a time
        return fp8x2_storage_t{
            cast_to_f8_from_f32<interpret, saturate, stochastic_rounding>(v[0], rng),
            cast_to_f8_from_f32<interpret, saturate, stochastic_rounding>(v[1], rng)};
    }
    else
    {
        // Unions give bit-level access for the NaN/Inf test and the result bytes.
        union
        {
            float fval;
            unsigned int i32val;
            unsigned char i8val[4];
        } val0, val1;
        val0.fval = v[0];
        val1.fval = v[1];
        unsigned int ival = 0;
        if constexpr(saturate)
        {
            // Clamp each value to the destination format's finite range, but
            // only when the exponent bits are not all ones (NaN/Inf), so those
            // encodings propagate unclipped.
            if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
            {
                if((val0.i32val & 0x7F800000) != 0x7F800000)
                { /// propagate NAN/INF, no clipping
                    val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 240.0, -240.0);
                }
                if((val1.i32val & 0x7F800000) != 0x7F800000)
                { /// propagate NAN/INF, no clipping
                    val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 240.0, -240.0);
                }
            }
            else if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
            { // OCP type
                if((val0.i32val & 0x7F800000) != 0x7F800000)
                { /// propagate NAN/INF, no clipping
                    val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 448.0, -448.0);
                }
                if((val1.i32val & 0x7F800000) != 0x7F800000)
                { /// propagate NAN/INF, no clipping
                    val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 448.0, -448.0);
                }
            }
            else
            {
                if((val0.i32val & 0x7F800000) != 0x7F800000)
                { /// propagate NAN/INF, no clipping
                    val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 57344.0, -57344.0);
                }
                if((val1.i32val & 0x7F800000) != 0x7F800000)
                { /// propagate NAN/INF, no clipping
                    val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 57344.0, -57344.0);
                }
            }
        }
        // RNE CVT: pack both results into one dword, then return its low two bytes.
        if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
                     (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP))
        {
            ival = __builtin_amdgcn_cvt_pk_fp8_f32(val0.fval, val1.fval, ival, false);
        }
        else
        {
            ival = __builtin_amdgcn_cvt_pk_bf8_f32(val0.fval, val1.fval, ival, false);
        }
        val0.i32val = ival;
        return fp8x2_storage_t{val0.i8val[0], val0.i8val[1]};
    }
}
#endif // CK_FP8_CVT_FAST_PATH
// The conversion function is from rocblas
@@ -797,6 +1382,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
*
* \tparam interp interpretation of fp8
* \tparam sat saturation of fp8
* \tparam stochastic_rounding switch between RNE and SR
* \param f float number
* \return fp8_storage_t
*/
@@ -882,6 +1468,47 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
#endif // CK_FP8_CVT_FAST_PATH
}
/**
 * \brief convert vector of 2 floats to vector of 2 @p fp8_storage_t
 *
 * \tparam interp interpretation of fp8
 * \tparam sat saturation of fp8
 * \tparam stochastic_rounding switch between RNE and SR
 * \param f vector of 2 floats
 * \return fp8x2_storage_t
 */
template <ck_fp8_interpretation_t interp,
          ck_saturation_t sat = ck_saturation_t::CK_SATFINITE,
          bool stochastic_rounding = false>
#if CK_FP8_CVT_FAST_PATH
__device__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f)
{
    __is_interpret_supported(interp);
    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        // SR entropy is derived from the address of f and the first element's value.
        constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
#else
        rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f[0]);
#endif
    }
    return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
        f, rng);
#else
#if CK_USE_OCP_FP8
__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f)
{
#else
__host__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f)
{
#endif // CK_USE_OCP_FP8
    // Slow path: convert element-wise through the scalar overload.
    return fp8x2_storage_t{cvt_float_to_fp8<interp, sat, stochastic_rounding>(f[0]),
                           cvt_float_to_fp8<interp, sat, stochastic_rounding>(f[1])};
#endif // CK_FP8_CVT_FAST_PATH
}
/**
* \brief convert _Float16 to @p fp8_storage_t
*
@@ -900,87 +1527,168 @@ __host__ __device__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16
__host__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x)
#endif
{
return cvt_float_to_fp8<interp, sat, stochastic_rounding>(static_cast<float>(x));
{
__is_interpret_supported(interp);
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
}
#if defined(__gfx950__)
return cast_to_f8_from_f16<interp,
sat == ck_saturation_t::CK_SATFINITE,
stochastic_rounding>(x, rng);
#else
std::ignore = rng;
return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
static_cast<float>(x));
#endif // defined(__gfx950__)
}
}
/**
 * \brief convert vector of 2 _Float16 to vector of 2 @p fp8_storage_t
 *
 * \tparam sat saturation of fp8
 * \tparam interp interpretation of fp8
 * \tparam stochastic_rounding switch between RNE and SR
 * \param x vector of 2 _Float16
 * \return fp8x2_storage_t
 */
template <ck_fp8_interpretation_t interp,
          ck_saturation_t sat = ck_saturation_t::CK_SATFINITE,
          bool stochastic_rounding = false>
#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
__host__ __device__ static inline fp8x2_storage_t cvt_half_t_to_fp8(const half2_t x)
#else
__host__ static inline fp8x2_storage_t cvt_half_t_to_fp8(const half2_t x)
#endif
{
    { // nested scope kept as in original
        __is_interpret_supported(interp);
        uint32_t rng = 0;
        if constexpr(stochastic_rounding)
        {
            // SR entropy is derived from the address of x and the first element's value.
            constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
            rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#else
            rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
#endif
        }
#if defined(__gfx950__)
        // Direct f16 -> fp8 hardware conversion.
        return cast_to_f8_from_f16<interp,
                                   sat == ck_saturation_t::CK_SATFINITE,
                                   stochastic_rounding>(x, rng);
#else
        std::ignore = rng;
        // Fallback: widen to float and reuse the float converter.
        // NOTE(review): this path hard-codes CK_SATFINITE and ignores the
        // 'sat' template parameter — confirm that is intentional.
        return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
            float2_t{static_cast<float>(x[0]), static_cast<float>(x[1])});
#endif // defined(__gfx950__)
    }
}
/**
 * \brief convert bhalf_t to @p fp8_storage_t
 *
 * \tparam sat saturation of fp8
 * \tparam interp interpretation of fp8
 * \tparam stochastic_rounding switch between RNE and SR
 * \param x bhalf_t value (bf16 bit pattern in a ushort)
 * \return fp8_storage_t
 */
template <ck_fp8_interpretation_t interp,
          ck_saturation_t sat = ck_saturation_t::CK_SATFINITE,
          bool stochastic_rounding = false>
#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
__host__ __device__ static inline fp8_storage_t cvt_bhalf_t_to_fp8(const ushort x)
#else
__host__ static inline fp8_storage_t cvt_bhalf_t_to_fp8(const ushort x)
#endif
{
    { // nested scope kept as in original
        __is_interpret_supported(interp);
        uint32_t rng = 0;
        if constexpr(stochastic_rounding)
        {
            // SR entropy is derived from the address of x and its value.
            constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
            rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
                                               static_cast<float>(x));
#else
            rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), static_cast<float>(x));
#endif
        }
#if defined(__gfx950__)
        // Direct bf16 -> fp8 hardware conversion.
        return cast_to_f8_from_bf16<interp,
                                    sat == ck_saturation_t::CK_SATFINITE,
                                    stochastic_rounding>(x, rng);
#else
        std::ignore = rng;
        // Fallback: reinterpret the bf16 bits as the high half of a float.
        // NOTE(review): this path hard-codes CK_SATFINITE and ignores the
        // 'sat' template parameter — confirm that is intentional.
        return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
            bit_cast<float>(uint32_t{x} << 16)); // convert value to float
#endif // defined(__gfx950__)
    }
}
/**
 * \brief convert vector of 2 bhalf_t to vector of 2 @p fp8_storage_t
 *
 * \tparam sat saturation of fp8
 * \tparam interp interpretation of fp8
 * \tparam stochastic_rounding switch between RNE and SR
 * \param x vector of 2 bhalf_t (bf16 bit patterns)
 * \return fp8x2_storage_t
 */
template <ck_fp8_interpretation_t interp,
          ck_saturation_t sat = ck_saturation_t::CK_SATFINITE,
          bool stochastic_rounding = false>
#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
__host__ __device__ static inline fp8x2_storage_t cvt_bhalf_t_to_fp8(const ushortx2_t x)
#else
__host__ static inline fp8x2_storage_t cvt_bhalf_t_to_fp8(const ushortx2_t x)
#endif
{
#if CK_WORKAROUND_BF16_TO_FP8_CONVERSION
    // Compiler workaround: always widen both lanes to float and reuse the
    // float converter (skips the gfx950 bf16 path entirely).
    // NOTE(review): hard-codes CK_SATFINITE and ignores 'sat' — confirm intentional.
    return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
        float2_t{bit_cast<float>(uint32_t{x[0]} << 16),
                 bit_cast<float>(uint32_t{x[1]} << 16)}); // convert values to float
#else // CK_WORKAROUND_BF16_TO_FP8_CONVERSION
    {
        __is_interpret_supported(interp);
        uint32_t rng = 0;
        if constexpr(stochastic_rounding)
        {
            // SR entropy is derived from the address of x and the first element's value.
            constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
            rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
                                               static_cast<float>(x[0]));
#else
            rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x),
                                               static_cast<float>(x[0]));
#endif
        }
#if defined(__gfx950__)
        // Direct packed bf16 -> fp8 hardware conversion.
        return cast_to_f8_from_bf16<interp,
                                    sat == ck_saturation_t::CK_SATFINITE,
                                    stochastic_rounding>(x, rng);
#else
        std::ignore = rng;
        return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
            float2_t{bit_cast<float>(uint32_t{x[0]} << 16),
                     bit_cast<float>(uint32_t{x[1]} << 16)}); // convert values to float
#endif // defined(__gfx950__)
    }
#endif // CK_WORKAROUND_BF16_TO_FP8_CONVERSION
}
} // namespace fp8_impl
// Declare a template function for fp8 conversion using round-to-nearest-even (RNE).
template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_rne(X x);

// fp32 -> fp8 (E4M3 OCP), round-to-nearest-even
template <>
inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, float>(float x)
{
    constexpr auto interp = f8_ocp_t::default_interpret;
    constexpr auto sat    = f8_ocp_t::default_saturation;
    return f8_ocp_t{fp8_impl::cvt_float_to_fp8<interp, sat>(x)};
}

// fp32 -> bf8 (E5M2 OCP), round-to-nearest-even
template <>
inline __host__ __device__ bf8_ocp_t f8_convert_rne<bf8_ocp_t, float>(float x)
{
    constexpr auto interp = bf8_ocp_t::default_interpret;
    constexpr auto sat    = bf8_ocp_t::default_saturation;
    return bf8_ocp_t{fp8_impl::cvt_float_to_fp8<interp, sat>(x)};
}

// _Float16 -> fp8 (E4M3 OCP), round-to-nearest-even
template <>
inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, _Float16>(_Float16 x)
{
    constexpr auto interp = f8_ocp_t::default_interpret;
    constexpr auto sat    = f8_ocp_t::default_saturation;
    return f8_ocp_t{fp8_impl::cvt_half_t_to_fp8<interp, sat>(x)};
}

// _Float16 -> bf8 (E5M2 OCP), round-to-nearest-even
template <>
inline __host__ __device__ bf8_ocp_t f8_convert_rne<bf8_ocp_t, _Float16>(_Float16 x)
{
    constexpr auto interp = bf8_ocp_t::default_interpret;
    constexpr auto sat    = bf8_ocp_t::default_saturation;
    return bf8_ocp_t{fp8_impl::cvt_half_t_to_fp8<interp, sat>(x)};
}
// Declare a template function for fp8 conversion using stochastic rounding (SR).
// (Original comment said "RNE" — copy/paste slip; these are the SR conversions.)
template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_sr(X x);

// fp32 -> fp8 (E4M3 OCP), stochastic rounding
template <>
inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, float>(float x)
{
    constexpr auto interp = f8_ocp_t::default_interpret;
    constexpr auto sat    = f8_ocp_t::default_saturation;
    return f8_ocp_t{fp8_impl::cvt_float_to_fp8<interp, sat, /*stochastic_rounding=*/true>(x)};
}

// fp32 -> bf8 (E5M2 OCP), stochastic rounding
template <>
inline __host__ __device__ bf8_ocp_t f8_convert_sr<bf8_ocp_t, float>(float x)
{
    constexpr auto interp = bf8_ocp_t::default_interpret;
    constexpr auto sat    = bf8_ocp_t::default_saturation;
    return bf8_ocp_t{fp8_impl::cvt_float_to_fp8<interp, sat, /*stochastic_rounding=*/true>(x)};
}

// _Float16 -> fp8 (E4M3 OCP), stochastic rounding
template <>
inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, _Float16>(_Float16 x)
{
    constexpr auto interp = f8_ocp_t::default_interpret;
    constexpr auto sat    = f8_ocp_t::default_saturation;
    return f8_ocp_t{fp8_impl::cvt_half_t_to_fp8<interp, sat, /*stochastic_rounding=*/true>(x)};
}

// _Float16 -> bf8 (E5M2 OCP), stochastic rounding
template <>
inline __host__ __device__ bf8_ocp_t f8_convert_sr<bf8_ocp_t, _Float16>(_Float16 x)
{
    constexpr auto interp = bf8_ocp_t::default_interpret;
    constexpr auto sat    = bf8_ocp_t::default_saturation;
    return bf8_ocp_t{fp8_impl::cvt_half_t_to_fp8<interp, sat, /*stochastic_rounding=*/true>(x)};
}
#if CK_USE_OCP_FP8
using f8_t = f8_ocp_t;
using bf8_t = bf8_ocp_t;

View File

@@ -39,7 +39,7 @@ static __device__ float cast_to_f32_from_f8_scaled(float scale, fp8_storage_t v)
}
template <ck_fp8_interpretation_t interpret>
static __device__ float2_t cast_to_f32x2_from_f8x2_scaled(float scale, fp8x2_storage_t v)
static __device__ float2_t cast_to_f32_from_f8_scaled(float scale, fp8x2_storage_t v)
{
const auto i16val = bit_cast<uint16_t>(v);

View File

@@ -67,7 +67,7 @@ inline __host__ float2_t scaled_type_convert<float2_t, f8x2_ocp_t>(e8m0_bexp_t s
#endif
{
#if CK_MX_FP8_CVT_FAST_PATH
return fp8_impl::cast_to_f32x2_from_f8x2_scaled<f8_ocp_t::default_interpret>(
return fp8_impl::cast_to_f32_from_f8_scaled<f8_ocp_t::default_interpret>(
type_convert<float>(scale), x.AsType<fp8_impl::fp8x2_storage_t>()[Number<0>{}]);
#else
return float2_t{scaled_type_convert<float>(scale, x.AsType<f8_ocp_t>()[Number<0>{}]),
@@ -86,7 +86,7 @@ inline __host__ float2_t scaled_type_convert<float2_t, bf8x2_ocp_t>(e8m0_bexp_t
#endif
{
#if CK_MX_FP8_CVT_FAST_PATH
return fp8_impl::cast_to_f32x2_from_f8x2_scaled<bf8_ocp_t::default_interpret>(
return fp8_impl::cast_to_f32_from_f8_scaled<bf8_ocp_t::default_interpret>(
type_convert<float>(scale), x.AsType<fp8_impl::fp8x2_storage_t>()[Number<0>{}]);
#else
return float2_t{scaled_type_convert<float>(scale, x.AsType<bf8_ocp_t>()[Number<0>{}]),

View File

@@ -117,7 +117,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float
#if CK_USE_RNE_BF16_CONVERSION
return bf16_convert_rtn<bhalf_t>(x);
#else
return uint16_t(u.int32 >> 16);
return uint16_t(uint32_t{x} >> 16);
#endif
}
@@ -356,6 +356,180 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, half_t>(half_t x
#endif
}
/**
 * @brief Converts a float to an 8-bit float type (f8_ocp_t) using stochastic rounding.
 *
 * @param x The input float value.
 * @return The converted f8_ocp_t value.
 */
template <>
inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, float>(float x)
{
    constexpr auto interp = f8_ocp_t::default_interpret;
    constexpr auto sat    = f8_ocp_t::default_saturation;
    return f8_ocp_t{fp8_impl::cvt_float_to_fp8<interp, sat, /*stochastic_rounding=*/true>(x)};
}

/**
 * @brief Converts a vector of 2 floats to a vector of 2 8-bit float types (f8_ocp_t) using
 * stochastic rounding.
 *
 * @param x The input vector of 2 floats.
 * @return The converted vector of 2 f8_ocp_t.
 */
template <>
inline __host__ __device__ f8x2_ocp_t f8_convert_sr<f8x2_ocp_t, float2_t>(float2_t x)
{
    constexpr auto interp = f8_ocp_t::default_interpret;
    constexpr auto sat    = f8_ocp_t::default_saturation;
    return f8x2_ocp_t{fp8_impl::cvt_float_to_fp8<interp, sat, /*stochastic_rounding=*/true>(x)};
}

/**
 * @brief Converts a float to an 8-bit float type (bf8_ocp_t) using stochastic rounding.
 *
 * @param x The input float value.
 * @return The converted bf8_ocp_t value.
 */
template <>
inline __host__ __device__ bf8_ocp_t f8_convert_sr<bf8_ocp_t, float>(float x)
{
    constexpr auto interp = bf8_ocp_t::default_interpret;
    constexpr auto sat    = bf8_ocp_t::default_saturation;
    return bf8_ocp_t{fp8_impl::cvt_float_to_fp8<interp, sat, /*stochastic_rounding=*/true>(x)};
}

/**
 * @brief Converts a vector of 2 floats to a vector of 2 8-bit float types (bf8_ocp_t) using
 * stochastic rounding.
 *
 * @param x The input vector of 2 floats.
 * @return The converted vector of 2 bf8_ocp_t.
 */
template <>
inline __host__ __device__ bf8x2_ocp_t f8_convert_sr<bf8x2_ocp_t, float2_t>(float2_t x)
{
    constexpr auto interp = bf8_ocp_t::default_interpret;
    constexpr auto sat    = bf8_ocp_t::default_saturation;
    return bf8x2_ocp_t{fp8_impl::cvt_float_to_fp8<interp, sat, /*stochastic_rounding=*/true>(x)};
}
/**
 * @brief Converts a half_t to an 8-bit float type (f8_ocp_t) using stochastic rounding.
 *
 * @param x The input half_t value.
 * @return The converted f8_ocp_t value.
 */
template <>
inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, half_t>(half_t x)
{
    constexpr auto interp = f8_ocp_t::default_interpret;
    constexpr auto sat    = f8_ocp_t::default_saturation;
    return f8_ocp_t{fp8_impl::cvt_half_t_to_fp8<interp, sat, /*stochastic_rounding=*/true>(x)};
}

/**
 * @brief Converts a vector of 2 half_t to a vector of 2 8-bit float types (f8_ocp_t) using
 * stochastic rounding.
 *
 * @param x The input vector of 2 half_t.
 * @return The converted vector of 2 f8_ocp_t.
 */
template <>
inline __host__ __device__ f8x2_ocp_t f8_convert_sr<f8x2_ocp_t, half2_t>(half2_t x)
{
    constexpr auto interp = f8_ocp_t::default_interpret;
    constexpr auto sat    = f8_ocp_t::default_saturation;
    return f8x2_ocp_t{fp8_impl::cvt_half_t_to_fp8<interp, sat, /*stochastic_rounding=*/true>(x)};
}

/**
 * @brief Converts a half_t to an 8-bit float type (bf8_ocp_t) using stochastic rounding.
 *
 * @param x The input half_t value.
 * @return The converted bf8_ocp_t value.
 */
template <>
inline __host__ __device__ bf8_ocp_t f8_convert_sr<bf8_ocp_t, half_t>(half_t x)
{
    constexpr auto interp = bf8_ocp_t::default_interpret;
    constexpr auto sat    = bf8_ocp_t::default_saturation;
    return bf8_ocp_t{fp8_impl::cvt_half_t_to_fp8<interp, sat, /*stochastic_rounding=*/true>(x)};
}

/**
 * @brief Converts a vector of 2 half_t to a vector of 2 8-bit float types (bf8_ocp_t) using
 * stochastic rounding.
 *
 * @param x The input vector of 2 half_t.
 * @return The converted vector of 2 bf8_ocp_t.
 */
template <>
inline __host__ __device__ bf8x2_ocp_t f8_convert_sr<bf8x2_ocp_t, half2_t>(half2_t x)
{
    constexpr auto interp = bf8_ocp_t::default_interpret;
    constexpr auto sat    = bf8_ocp_t::default_saturation;
    return bf8x2_ocp_t{fp8_impl::cvt_half_t_to_fp8<interp, sat, /*stochastic_rounding=*/true>(x)};
}
/**
 * @brief Converts a bhalf_t to an 8-bit float type (f8_ocp_t) using stochastic rounding.
 *
 * @param x The input bhalf_t value.
 * @return The converted f8_ocp_t value.
 */
template <>
inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, bhalf_t>(bhalf_t x)
{
    // Request the stochastic-rounding path via the third template flag.
    constexpr bool stochastic = true;
    const auto bits           = fp8_impl::cvt_bhalf_t_to_fp8<f8_ocp_t::default_interpret,
                                                   f8_ocp_t::default_saturation,
                                                   stochastic>(x);
    return f8_ocp_t{bits};
}
/**
 * @brief Converts a vector of 2 bhalf_t to a vector of 2 8-bit float values (f8_ocp_t)
 *        using stochastic rounding.
 *
 * @param x The input vector of 2 bhalf_t.
 * @return The converted vector of 2 f8_ocp_t values.
 */
template <>
inline __host__ __device__ f8x2_ocp_t f8_convert_sr<f8x2_ocp_t, bhalf2_t>(bhalf2_t x)
{
    // Request the stochastic-rounding path via the third template flag.
    constexpr bool stochastic = true;
    const auto bits           = fp8_impl::cvt_bhalf_t_to_fp8<f8_ocp_t::default_interpret,
                                                   f8_ocp_t::default_saturation,
                                                   stochastic>(x);
    return f8x2_ocp_t{bits};
}
/**
 * @brief Converts a bhalf_t to an 8-bit float type (bf8_ocp_t) using stochastic rounding.
 *
 * @param x The input bhalf_t value.
 * @return The converted bf8_ocp_t value.
 */
template <>
inline __host__ __device__ bf8_ocp_t f8_convert_sr<bf8_ocp_t, bhalf_t>(bhalf_t x)
{
    // Request the stochastic-rounding path via the third template flag.
    constexpr bool stochastic = true;
    const auto bits           = fp8_impl::cvt_bhalf_t_to_fp8<bf8_ocp_t::default_interpret,
                                                   bf8_ocp_t::default_saturation,
                                                   stochastic>(x);
    return bf8_ocp_t{bits};
}
/**
 * @brief Converts a vector of 2 bhalf_t to a vector of 2 8-bit float values (bf8_ocp_t)
 *        using stochastic rounding.
 *
 * @param x The input vector of 2 bhalf_t.
 * @return The converted vector of 2 bf8_ocp_t values.
 */
template <>
inline __host__ __device__ bf8x2_ocp_t f8_convert_sr<bf8x2_ocp_t, bhalf2_t>(bhalf2_t x)
{
    // Request the stochastic-rounding path via the third template flag.
    constexpr bool stochastic = true;
    const auto bits           = fp8_impl::cvt_bhalf_t_to_fp8<bf8_ocp_t::default_interpret,
                                                   bf8_ocp_t::default_saturation,
                                                   stochastic>(x);
    return bf8x2_ocp_t{bits};
}
// Declare a template function for fp8 conversion using RNE
template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_rne(X x);
@@ -466,6 +640,172 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_rne<bf8_fnuz_t, half_t>(half_t
#endif
}
/**
 * @brief Converts a float to an 8-bit float type (f8_ocp_t) using rounding to nearest/even.
 *
 * @param x The input float value.
 * @return The converted f8_ocp_t value.
 */
template <>
inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, float>(float x)
{
    // The stochastic-rounding template flag is left at its default, selecting RNE.
    const auto bits =
        fp8_impl::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x);
    return f8_ocp_t{bits};
}
/**
 * @brief Converts a vector of 2 floats to a vector of 2 8-bit float values (f8_ocp_t)
 *        using rounding to nearest/even.
 *
 * @param x The input vector of 2 floats.
 * @return The converted vector of 2 f8_ocp_t values.
 */
template <>
inline __host__ __device__ f8x2_ocp_t f8_convert_rne<f8x2_ocp_t, float2_t>(float2_t x)
{
    // The stochastic-rounding template flag is left at its default, selecting RNE.
    const auto bits =
        fp8_impl::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x);
    return f8x2_ocp_t{bits};
}
/**
 * @brief Converts a float to an 8-bit float type (bf8_ocp_t) using rounding to nearest/even.
 *
 * @param x The input float value.
 * @return The converted bf8_ocp_t value.
 */
template <>
inline __host__ __device__ bf8_ocp_t f8_convert_rne<bf8_ocp_t, float>(float x)
{
    // The stochastic-rounding template flag is left at its default, selecting RNE.
    const auto bits =
        fp8_impl::cvt_float_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(x);
    return bf8_ocp_t{bits};
}
/**
 * @brief Converts a vector of 2 floats to a vector of 2 8-bit float values (bf8_ocp_t)
 *        using rounding to nearest/even.
 *
 * @param x The input vector of 2 floats.
 * @return The converted vector of 2 bf8_ocp_t values.
 */
template <>
inline __host__ __device__ bf8x2_ocp_t f8_convert_rne<bf8x2_ocp_t, float2_t>(float2_t x)
{
    // The stochastic-rounding template flag is left at its default, selecting RNE.
    const auto bits =
        fp8_impl::cvt_float_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(x);
    return bf8x2_ocp_t{bits};
}
/**
 * @brief Converts a half_t to an 8-bit float type (f8_ocp_t) using rounding to nearest/even.
 *
 * @param x The input half_t value.
 * @return The converted f8_ocp_t value.
 */
template <>
inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, half_t>(half_t x)
{
    // The stochastic-rounding template flag is left at its default, selecting RNE.
    const auto bits =
        fp8_impl::cvt_half_t_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x);
    return f8_ocp_t{bits};
}
/**
 * @brief Converts a vector of 2 half_t to a vector of 2 8-bit float values (f8_ocp_t)
 *        using rounding to nearest/even.
 *
 * @param x The input vector of 2 half_t.
 * @return The converted vector of 2 f8_ocp_t values.
 */
template <>
inline __host__ __device__ f8x2_ocp_t f8_convert_rne<f8x2_ocp_t, half2_t>(half2_t x)
{
    // The stochastic-rounding template flag is left at its default, selecting RNE.
    const auto bits =
        fp8_impl::cvt_half_t_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x);
    return f8x2_ocp_t{bits};
}
/**
 * @brief Converts a half_t to an 8-bit float type (bf8_ocp_t) using rounding to nearest/even.
 *
 * @param x The input half_t value.
 * @return The converted bf8_ocp_t value.
 */
template <>
inline __host__ __device__ bf8_ocp_t f8_convert_rne<bf8_ocp_t, half_t>(half_t x)
{
    // The stochastic-rounding template flag is left at its default, selecting RNE.
    const auto bits =
        fp8_impl::cvt_half_t_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(x);
    return bf8_ocp_t{bits};
}
/**
 * @brief Converts a vector of 2 half_t to a vector of 2 8-bit float values (bf8_ocp_t)
 *        using rounding to nearest/even.
 *
 * @param x The input vector of 2 half_t.
 * @return The converted vector of 2 bf8_ocp_t values.
 */
template <>
inline __host__ __device__ bf8x2_ocp_t f8_convert_rne<bf8x2_ocp_t, half2_t>(half2_t x)
{
    // The stochastic-rounding template flag is left at its default, selecting RNE.
    const auto bits =
        fp8_impl::cvt_half_t_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(x);
    return bf8x2_ocp_t{bits};
}
/**
 * @brief Converts a bhalf_t to an 8-bit float type (f8_ocp_t) using rounding to nearest/even.
 *
 * @param x The input bhalf_t value.
 * @return The converted f8_ocp_t value.
 */
template <>
inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, bhalf_t>(bhalf_t x)
{
    // The stochastic-rounding template flag is left at its default, selecting RNE.
    const auto bits =
        fp8_impl::cvt_bhalf_t_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x);
    return f8_ocp_t{bits};
}
/**
 * @brief Converts a vector of 2 bhalf_t to a vector of 2 8-bit float values (f8_ocp_t)
 *        using rounding to nearest/even.
 *
 * @param x The input vector of 2 bhalf_t.
 * @return The converted vector of 2 f8_ocp_t values.
 */
template <>
inline __host__ __device__ f8x2_ocp_t f8_convert_rne<f8x2_ocp_t, bhalf2_t>(bhalf2_t x)
{
    // The stochastic-rounding template flag is left at its default, selecting RNE.
    const auto bits =
        fp8_impl::cvt_bhalf_t_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x);
    return f8x2_ocp_t{bits};
}
/**
 * @brief Converts a bhalf_t to an 8-bit float type (bf8_ocp_t) using rounding to nearest/even.
 *
 * @param x The input bhalf_t value.
 * @return The converted bf8_ocp_t value.
 */
template <>
inline __host__ __device__ bf8_ocp_t f8_convert_rne<bf8_ocp_t, bhalf_t>(bhalf_t x)
{
    // The stochastic-rounding template flag is left at its default, selecting RNE.
    const auto bits =
        fp8_impl::cvt_bhalf_t_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(
            x);
    return bf8_ocp_t{bits};
}
/**
 * @brief Converts a vector of 2 bhalf_t to a vector of 2 8-bit float values (bf8_ocp_t)
 *        using rounding to nearest/even.
 *
 * @param x The input vector of 2 bhalf_t.
 * @return The converted vector of 2 bf8_ocp_t values.
 */
template <>
inline __host__ __device__ bf8x2_ocp_t f8_convert_rne<bf8x2_ocp_t, bhalf2_t>(bhalf2_t x)
{
    // The stochastic-rounding template flag is left at its default, selecting RNE.
    const auto bits =
        fp8_impl::cvt_bhalf_t_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(
            x);
    return bf8x2_ocp_t{bits};
}
// convert fp32 to fp8
template <>
inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, float>(float x)
@@ -477,17 +817,6 @@ inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, float>(float x)
#endif
}
/**
 * @brief Converts a float value to a f8_ocp_t value, with the rounding mode
 *        selected at compile time by the CK_USE_SR_F8_CONVERSION flag
 *        (stochastic rounding when set, round-to-nearest/even otherwise).
 *
 * @param x The input float value.
 * @return The converted f8_ocp_t value.
 */
template <>
inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, float>(float x)
{
#if CK_USE_SR_F8_CONVERSION
    return f8_convert_sr<f8_ocp_t>(x);
#else
    return f8_convert_rne<f8_ocp_t>(x);
#endif
}
// convert fp8 to fp32
template <>
inline __host__ __device__ float type_convert<float, f8_fnuz_t>(f8_fnuz_t x)
@@ -524,12 +853,39 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_fnuz_t>(f8x2_fnu
#endif
}
/**
 * @brief Converts a f8_ocp_t value to a float value.
 *
 * On targets with CK_OCP_FP8_CVT_FAST_PATH the conversion uses the hardware
 * fp8->f32 builtin; otherwise it falls back to the software bit-level cast.
 *
 * @param x The input f8_ocp_t value.
 * @return The converted float value.
 */
template <>
inline __host__ __device__ float type_convert<float, f8_ocp_t>(f8_ocp_t x)
{
#if CK_OCP_FP8_CVT_FAST_PATH
    // Place the single fp8 byte into byte 0 of the 32-bit operand the builtin
    // expects; the trailing 0 selects which byte of the operand to convert.
    union
    {
        unsigned int i32val;
        fp8_storage_t i8val[4];
    } val;
    val.i8val[0] = x.data;
    return __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0);
#else
    // Software path: decode the fp8 bit pattern directly.
    return fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(x.data);
#endif
}
/**
* @brief Converts a vector of 2 f8_ocp_t values to a vector of 2 float values.
*
* @param x The input vector of 2 f8_ocp_t values.
* @return The converted vector of 2 float values.
*/
template <>
inline __host__ __device__ float2_t type_convert<float2_t, f8x2_ocp_t>(f8x2_ocp_t x)
{
#if CK_OCP_FP8_CVT_FAST_PATH
return fp8_impl::cast_to_f32x2_from_f8x2<f8_ocp_t::default_interpret>(
x.AsType<fp8_impl::fp8x2_storage_t>()[Number<0>{}]);
return __builtin_amdgcn_cvt_pk_f32_fp8(bit_cast<uint16_t>(x), false);
#else
return float2_t{fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(
x.AsType<fp8_storage_t>()[Number<0>{}]),
@@ -538,6 +894,229 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_ocp_t>(f8x2_ocp_
#endif
}
/**
 * @brief Converts a f8_ocp_t value to a half_t value.
 *
 * On gfx950 the packed fp8->f16 builtin is used with an identity scale
 * (1.f) and only lane 0 of the packed result is kept; elsewhere the software
 * bit-level cast is used.
 *
 * @param x The input f8_ocp_t value.
 * @return The converted half_t value.
 */
template <>
inline __host__ __device__ half_t type_convert<half_t, f8_ocp_t>(f8_ocp_t x)
{
#if defined(__gfx950__)
    // Pack the single fp8 byte into the low byte of the 16-bit builtin operand.
    union
    {
        uint16_t i16val;
        fp8_storage_t i8val[2];
    } input;
    input.i8val[0] = x.data;
    // The builtin converts a pair; extract the element corresponding to byte 0.
    union
    {
        half2_t half_vec;
        half_t half_arr[2];
    } output;
    output.half_vec = __builtin_amdgcn_cvt_scalef32_pk_f16_fp8(input.i16val, /*scale*/ 1.f, 0);
    return output.half_arr[0];
#else
    return fp8_impl::cast_from_f8<half_t, f8_ocp_t::wm, f8_ocp_t::we, false>(x.data);
#endif
}
/**
 * @brief Converts a vector of 2 f8_ocp_t values to a vector of 2 half_t values.
 *
 * On gfx950 the packed fp8->f16 builtin converts both lanes at once with an
 * identity scale (1.f); elsewhere each lane goes through float as an
 * intermediate.
 *
 * @param x The input vector of 2 f8_ocp_t values.
 * @return The converted vector of 2 half_t values.
 */
template <>
inline __host__ __device__ half2_t type_convert<half2_t, f8x2_ocp_t>(f8x2_ocp_t x)
{
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_pk_f16_fp8(bit_cast<uint16_t>(x), /*scale*/ 1.f, 0);
#else
    // Fallback: widen each fp8 lane to float, then narrow to half_t.
    return half2_t{type_convert<half_t>(float(x.AsType<f8_ocp_t>()[Number<0>{}])),
                   type_convert<half_t>(float(x.AsType<f8_ocp_t>()[Number<1>{}]))};
#endif
}
/**
 * @brief Converts a f8_ocp_t value to a bhalf_t value.
 *
 * On gfx950 the packed fp8->bf16 builtin is used with an identity scale
 * (1.f) and only lane 0 of the packed result is kept; elsewhere the value is
 * widened to float and then narrowed to bhalf_t.
 *
 * @param x The input f8_ocp_t value.
 * @return The converted bhalf_t value.
 */
template <>
inline __host__ __device__ bhalf_t type_convert<bhalf_t, f8_ocp_t>(f8_ocp_t x)
{
#if defined(__gfx950__)
    // Pack the single fp8 byte into the low byte of the 16-bit builtin operand.
    union
    {
        uint16_t i16val;
        fp8_storage_t i8val[2];
    } input;
    input.i8val[0] = x.data;
    // The builtin converts a pair; extract the element corresponding to byte 0.
    union
    {
        bhalf2_t bhalf_vec;
        bhalf_t bhalf_arr[2];
    } output;
    output.bhalf_vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp8(input.i16val, /*scale*/ 1.f, 0);
    return output.bhalf_arr[0];
#else
    return type_convert<bhalf_t>(
        fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(x.data));
#endif
}
/**
 * @brief Converts a vector of 2 f8_ocp_t values to a vector of 2 bhalf_t values.
 *
 * On gfx950 the packed fp8->bf16 builtin converts both lanes at once with an
 * identity scale (1.f); elsewhere each lane goes through float as an
 * intermediate.
 *
 * @param x The input vector of 2 f8_ocp_t values.
 * @return The converted vector of 2 bhalf_t values.
 */
template <>
inline __host__ __device__ bhalf2_t type_convert<bhalf2_t, f8x2_ocp_t>(f8x2_ocp_t x)
{
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp8(bit_cast<uint16_t>(x), /*scale*/ 1.f, 0);
#else
    // Fallback: widen each fp8 lane to float, then narrow to bhalf_t.
    return bhalf2_t{type_convert<bhalf_t>(float(x.AsType<f8_ocp_t>()[Number<0>{}])),
                    type_convert<bhalf_t>(float(x.AsType<f8_ocp_t>()[Number<1>{}]))};
#endif
}
/**
 * @brief Converts a bf8_ocp_t value to a float value.
 *
 * On targets with CK_OCP_FP8_CVT_FAST_PATH the conversion uses the hardware
 * bf8->f32 builtin; otherwise it falls back to the software bit-level cast.
 *
 * @param x The input bf8_ocp_t value.
 * @return The converted float value.
 */
template <>
inline __host__ __device__ float type_convert<float, bf8_ocp_t>(bf8_ocp_t x)
{
#if CK_OCP_FP8_CVT_FAST_PATH
    // Place the single bf8 byte into byte 0 of the 32-bit operand the builtin
    // expects; the trailing 0 selects which byte of the operand to convert.
    union
    {
        unsigned int i32val;
        fp8_storage_t i8val[4];
    } val;
    val.i8val[0] = x.data;
    return __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0);
#else
    // Software path: decode the bf8 bit pattern directly.
    return fp8_impl::cast_from_f8<float, bf8_ocp_t::wm, bf8_ocp_t::we, false>(x.data);
#endif
}
/**
 * @brief Converts a vector of 2 bf8_ocp_t values to a vector of 2 float values.
 *
 * On targets with CK_OCP_FP8_CVT_FAST_PATH the packed bf8->f32 builtin
 * converts both lanes at once; otherwise each lane is decoded in software.
 *
 * @param x The input vector of 2 bf8_ocp_t values.
 * @return The converted vector of 2 float values.
 */
template <>
inline __host__ __device__ float2_t type_convert<float2_t, bf8x2_ocp_t>(bf8x2_ocp_t x)
{
#if CK_OCP_FP8_CVT_FAST_PATH
    // The boolean operand selects the word half to convert (false = low pair).
    return __builtin_amdgcn_cvt_pk_f32_bf8(bit_cast<uint16_t>(x), false);
#else
    return float2_t{fp8_impl::cast_from_f8<float, bf8_ocp_t::wm, bf8_ocp_t::we, false>(
                        x.AsType<fp8_storage_t>()[Number<0>{}]),
                    fp8_impl::cast_from_f8<float, bf8_ocp_t::wm, bf8_ocp_t::we, false>(
                        x.AsType<fp8_storage_t>()[Number<1>{}])};
#endif
}
/**
 * @brief Converts a bf8_ocp_t value to a half_t value.
 *
 * On gfx950 the packed bf8->f16 builtin is used with an identity scale
 * (1.f) and only lane 0 of the packed result is kept; elsewhere the software
 * bit-level cast is used.
 *
 * @param x The input bf8_ocp_t value.
 * @return The converted half_t value.
 */
template <>
inline __host__ __device__ half_t type_convert<half_t, bf8_ocp_t>(bf8_ocp_t x)
{
#if defined(__gfx950__)
    // Pack the single bf8 byte into the low byte of the 16-bit builtin operand.
    union
    {
        uint16_t i16val;
        fp8_storage_t i8val[2];
    } val;
    val.i8val[0] = x.data;
    // The builtin converts a pair; take the element corresponding to byte 0.
    return __builtin_amdgcn_cvt_scalef32_pk_f16_bf8(val.i16val, /*scale*/ 1.f, 0)[0];
#else
    return fp8_impl::cast_from_f8<half_t, bf8_ocp_t::wm, bf8_ocp_t::we, false>(x.data);
#endif
}
/**
 * @brief Converts a vector of 2 bf8_ocp_t values to a vector of 2 half_t values.
 *
 * On gfx950 the packed bf8->f16 builtin converts both lanes at once with an
 * identity scale (1.f); elsewhere each lane goes through float as an
 * intermediate.
 *
 * @param x The input vector of 2 bf8_ocp_t values.
 * @return The converted vector of 2 half_t values.
 */
template <>
inline __host__ __device__ half2_t type_convert<half2_t, bf8x2_ocp_t>(bf8x2_ocp_t x)
{
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_pk_f16_bf8(bit_cast<uint16_t>(x), /*scale*/ 1.f, 0);
#else
    // Fallback: widen each bf8 lane to float, then narrow to half_t.
    return half2_t{type_convert<half_t>(float(x.AsType<bf8_ocp_t>()[Number<0>{}])),
                   type_convert<half_t>(float(x.AsType<bf8_ocp_t>()[Number<1>{}]))};
#endif
}
/**
 * @brief Converts a bf8_ocp_t value to a bhalf_t value.
 *
 * On gfx950 the packed bf8->bf16 builtin is used with an identity scale
 * (1.f) and only lane 0 of the packed result is kept; elsewhere the value is
 * widened to float and then narrowed to bhalf_t.
 *
 * @param x The input bf8_ocp_t value.
 * @return The converted bhalf_t value.
 */
template <>
inline __host__ __device__ bhalf_t type_convert<bhalf_t, bf8_ocp_t>(bf8_ocp_t x)
{
#if defined(__gfx950__)
    // Pack the single bf8 byte into the low byte of the 16-bit builtin operand.
    union
    {
        uint16_t i16val;
        fp8_storage_t i8val[2];
    } input;
    input.i8val[0] = x.data;
    // The builtin converts a pair; extract the element corresponding to byte 0.
    union
    {
        bhalf2_t bhalf_vec;
        bhalf_t bhalf_arr[2];
    } output;
    output.bhalf_vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(input.i16val, /*scale*/ 1.f, 0);
    return output.bhalf_arr[0];
#else
    return type_convert<bhalf_t>(
        fp8_impl::cast_from_f8<float, bf8_ocp_t::wm, bf8_ocp_t::we, false>(x.data));
#endif
}
/**
 * @brief Converts a vector of 2 bf8_ocp_t values to a vector of 2 bhalf_t values.
 *
 * On gfx950 the packed bf8->bf16 builtin converts both lanes at once with an
 * identity scale (1.f); elsewhere each lane goes through float as an
 * intermediate.
 *
 * @param x The input vector of 2 bf8_ocp_t values.
 * @return The converted vector of 2 bhalf_t values.
 */
template <>
inline __host__ __device__ bhalf2_t type_convert<bhalf2_t, bf8x2_ocp_t>(bf8x2_ocp_t x)
{
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(bit_cast<uint16_t>(x), /*scale*/ 1.f, 0);
#else
    // Fallback: widen each bf8 lane to float, then narrow to bhalf_t.
    return bhalf2_t{type_convert<bhalf_t>(float(x.AsType<bf8_ocp_t>()[Number<0>{}])),
                    type_convert<bhalf_t>(float(x.AsType<bf8_ocp_t>()[Number<1>{}]))};
#endif
}
template <>
inline __host__ __device__ float2_t type_convert<float2_t, pk_i4_t>(pk_i4_t x)
{
@@ -610,7 +1189,12 @@ inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, half_t>(half_t x)
#endif
}
// convert fp16 to fp8
/**
* @brief Converts a half_t value to a f8_ocp_t value with rounding determined by a flag.
*
* @param x The input half_t value.
* @return The converted f8_ocp_t value.
*/
template <>
inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, half_t>(half_t x)
{
@@ -621,6 +1205,22 @@ inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, half_t>(half_t x)
#endif
}
/**
 * @brief Converts a half_t value to a bf8_ocp_t value, with the rounding mode
 *        selected at compile time by the CK_USE_SR_F8_CONVERSION flag
 *        (stochastic rounding when set, round-to-nearest/even otherwise).
 *
 * @param x The input half_t value.
 * @return The converted bf8_ocp_t value.
 */
template <>
inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, half_t>(half_t x)
{
#if CK_USE_SR_F8_CONVERSION
    return f8_convert_sr<bf8_ocp_t>(x);
#else
    return f8_convert_rne<bf8_ocp_t>(x);
#endif
}
// convert fp8 to fp16
template <>
inline __host__ __device__ half_t type_convert<half_t, f8_fnuz_t>(f8_fnuz_t x)
@@ -645,7 +1245,28 @@ inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, float>(float x)
#endif
}
// convert fp32 to bf8
/**
 * @brief Converts a float value to a f8_ocp_t value, with the rounding mode
 *        selected at compile time by the CK_USE_SR_F8_CONVERSION flag
 *        (stochastic rounding when set, round-to-nearest/even otherwise).
 *
 * @param x The input float value.
 * @return The converted f8_ocp_t value.
 */
template <>
inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, float>(float x)
{
#if CK_USE_SR_F8_CONVERSION
    return f8_convert_sr<f8_ocp_t>(x);
#else
    return f8_convert_rne<f8_ocp_t>(x);
#endif
}
/**
* @brief Converts a float value to a bf8_ocp_t value with rounding determined by a flag.
*
* @param x The input float value.
* @return The converted bf8_ocp_t value.
*/
template <>
inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, float>(float x)
{
@@ -656,6 +1277,38 @@ inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, float>(float x)
#endif
}
/**
 * @brief Converts a bhalf_t value to a f8_ocp_t value, with the rounding mode
 *        selected at compile time by the CK_USE_SR_F8_CONVERSION flag
 *        (stochastic rounding when set, round-to-nearest/even otherwise).
 *
 * @param x The input bhalf_t value.
 * @return The converted f8_ocp_t value.
 */
template <>
inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, bhalf_t>(bhalf_t x)
{
#if CK_USE_SR_F8_CONVERSION
    return f8_convert_sr<f8_ocp_t>(x);
#else
    return f8_convert_rne<f8_ocp_t>(x);
#endif
}
/**
 * @brief Converts a bhalf_t value to a bf8_ocp_t value, with the rounding mode
 *        selected at compile time by the CK_USE_SR_F8_CONVERSION flag
 *        (stochastic rounding when set, round-to-nearest/even otherwise).
 *
 * @param x The input bhalf_t value.
 * @return The converted bf8_ocp_t value.
 */
template <>
inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, bhalf_t>(bhalf_t x)
{
#if CK_USE_SR_F8_CONVERSION
    return f8_convert_sr<bf8_ocp_t>(x);
#else
    return f8_convert_rne<bf8_ocp_t>(x);
#endif
}
// convert bf8 to fp32
template <>
inline __host__ __device__ float type_convert<float, bf8_fnuz_t>(bf8_fnuz_t x)
@@ -683,17 +1336,6 @@ inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, half_t>(half_t x)
#endif
}
/**
 * @brief Converts a half_t value to a bf8_ocp_t value, with the rounding mode
 *        selected at compile time by the CK_USE_SR_F8_CONVERSION flag
 *        (stochastic rounding when set, round-to-nearest/even otherwise).
 *
 * @param x The input half_t value.
 * @return The converted bf8_ocp_t value.
 */
template <>
inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, half_t>(half_t x)
{
#if CK_USE_SR_F8_CONVERSION
    return f8_convert_sr<bf8_ocp_t>(x);
#else
    return f8_convert_rne<bf8_ocp_t>(x);
#endif
}
// convert bf8 to fp16
template <>
inline __host__ __device__ half_t type_convert<half_t, bf8_fnuz_t>(bf8_fnuz_t x)

View File

@@ -1,13 +1,19 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/library/utility/device_memory.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using ck::bf8_ocp_t;
using ck::bf8x2_ocp_t;
using ck::bhalf2_t;
using ck::bhalf_t;
using ck::f8_convert_rne;
using ck::f8_convert_sr;
using ck::float2_t;
using ck::half2_t;
using ck::half_t;
using ck::type_convert;
@@ -266,3 +272,590 @@ TEST(BF8OCP, ConvertFP16Stochastic)
const auto bf8_nan = f8_convert_sr<bf8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
ASSERT_TRUE(ck::fp8_impl::ocp_bf8_is_nan(bf8_nan.data));
}
// Number of values each exhaustive-conversion test produces:
// 256 scalar conversions (one per 8-bit pattern) + 6 vector-conversion results.
constexpr uint64_t test_size = 256 + 6;
// Exhaustively converts every bf8 bit pattern to float, then exercises the
// 2-wide vector paths (bf8x2 -> fp32x2, and fp32x2 -> bf8x2 with both RNE and
// SR rounding). Runs on host and device; results go to p_test and the count of
// written values is reported through *p_completed so callers can detect early
// exits. Writes at most N values.
__host__ __device__ void
test_fp32_bf8_type_convert(uint64_t N, float* p_test, uint64_t* p_completed)
{
    if(p_completed == nullptr)
    {
        return;
    }
    // i aliases the output counter so progress is visible even on early return.
    uint64_t& i = *p_completed;
    i = 0;
    if(p_test == nullptr)
    {
        return;
    }
    // Scalar path: convert every possible bf8 bit pattern.
    for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
    {
        uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
        auto v = type_convert<float>(bf8_ocp_t{bf8_uid});
        p_test[i] = v;
        i++;
        if(i >= N)
        {
            return;
        }
    }
    /// Test vector conversion
    // bf8x2 -> fp32x2
    bf8x2_ocp_t bf8x2{bf8x2_ocp_t::data_v{0b10000100, 0b00000001}}; //-2^-14, 2^-16
    float2_t f32x2 = type_convert<float2_t>(bf8x2);
    p_test[i++] = f32x2[0];
    if(i >= N)
    {
        return;
    }
    p_test[i++] = f32x2[1];
    if(i >= N)
    {
        return;
    }
    // fp32x2 -> bf8x2
    f32x2 = {-4.0f, 2.0f};
    // Both values are exactly representable in bf8, so RNE and SR must agree.
    bf8x2 = f8_convert_rne<bf8x2_ocp_t>(f32x2); // expect {-4, 2}
    p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})); //-4f
    if(i >= N)
    {
        return;
    }
    p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{})); // 2f
    if(i >= N)
    {
        return;
    }
    bf8x2 = f8_convert_sr<bf8x2_ocp_t>(f32x2); // expect {-4, 2}
    p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})); //-4f
    if(i >= N)
    {
        return;
    }
    p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{})); // 2f
    if(i >= N)
    {
        return;
    }
}
// Host-side check of the exhaustive bf8 <-> fp32 conversion test:
// NaN encodings must decode to NaN, every other encoding must round-trip
// against type_convert, and the 6 vector-conversion results must match.
TEST(BF8OCP, HostFP32BF8Convert)
{
    std::vector<float> out(test_size, -1.0f);
    uint64_t completed = 0;
    test_fp32_bf8_type_convert(test_size, out.data(), &completed);
    // The bf8 (e5m2) NaN bit patterns: all-ones exponent with nonzero mantissa.
    std::set<uint8_t> bf8_nan_ids;
    bf8_nan_ids.insert(0b11111111);
    bf8_nan_ids.insert(0b01111111);
    bf8_nan_ids.insert(0b11111101);
    bf8_nan_ids.insert(0b01111101);
    bf8_nan_ids.insert(0b11111110);
    bf8_nan_ids.insert(0b01111110);
    for(auto bf8_nan_id : bf8_nan_ids)
    {
        auto idx = bf8_nan_id;
        ASSERT_TRUE(std::isnan(out[idx]));
    }
    // Every non-NaN encoding must agree with a fresh host-side conversion.
    for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
    {
        if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end())
            continue;
        uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
        auto idx = bf8_uid;
        ASSERT_FLOAT_EQ(out[idx], type_convert<float>(bf8_ocp_t{bf8_uid}))
            << " bf8_id: " << bf8_id << std::endl
            << type_convert<float>(bf8_ocp_t{bf8_uid});
    }
    /// Test vector conversions
    auto i = 256;
    // bf8x2 -> fp32x2
    EXPECT_EQ(out[i++], -powf(2.0f, -14.0f));
    EXPECT_EQ(out[i++], powf(2.0f, -16.0f));
    // fp32x2 -> bf8x2
    // RNE
    EXPECT_EQ(out[i++], -4.0f);
    EXPECT_EQ(out[i++], 2.0f);
    // SR
    EXPECT_EQ(out[i++], -4.0f);
    EXPECT_EQ(out[i++], 2.0f);
    EXPECT_EQ(test_size, completed);
    EXPECT_EQ(test_size, i);
}
// Device entry point for the fp32 <-> bf8 conversion test; the test body is
// serial, so it is launched with a single thread (<<<1, 1>>>).
__global__ void device_test_fp32_bf8_type_convert(uint64_t N, float* p_test, uint64_t* p_completed)
{
    test_fp32_bf8_type_convert(N, p_test, p_completed);
}
// Device-side variant of HostFP32BF8Convert: runs the same exhaustive test in
// a single-thread kernel and validates the copied-back results against
// host-side conversions.
TEST(BF8OCP, DeviceFP32BF8Convert)
{
    std::vector<float> out(test_size, -1.0f);
    DeviceMem device_out(test_size * sizeof(float));
    DeviceMem device_completed(sizeof(uint64_t));
    // Poison the buffers so untouched slots are detectable.
    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);
    device_test_fp32_bf8_type_convert<<<1, 1>>>(
        test_size,
        static_cast<float*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
    uint64_t completed = 0;
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());
    // The bf8 (e5m2) NaN bit patterns: all-ones exponent with nonzero mantissa.
    std::set<uint8_t> bf8_nan_ids;
    bf8_nan_ids.insert(0b11111111);
    bf8_nan_ids.insert(0b01111111);
    bf8_nan_ids.insert(0b11111101);
    bf8_nan_ids.insert(0b01111101);
    bf8_nan_ids.insert(0b11111110);
    bf8_nan_ids.insert(0b01111110);
    for(auto bf8_nan_id : bf8_nan_ids)
    {
        auto idx = bf8_nan_id;
        ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx];
    }
    // Every non-NaN encoding must agree with a fresh host-side conversion.
    for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
    {
        if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end())
            continue;
        uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
        auto idx = bf8_uid;
        ASSERT_FLOAT_EQ(out[idx], type_convert<float>(bf8_ocp_t{bf8_uid}))
            << " bf8_id: " << bf8_id << std::endl
            << type_convert<float>(bf8_ocp_t{bf8_uid});
    }
    /// Test vector conversions
    auto i = 256;
    // bf8x2 -> fp32x2
    EXPECT_EQ(out[i++], -powf(2.0f, -14.0f));
    EXPECT_EQ(out[i++], powf(2.0f, -16.0f));
    // fp32x2 -> bf8x2
    // RNE
    EXPECT_EQ(out[i++], -4.0f);
    EXPECT_EQ(out[i++], 2.0f);
    // SR
    EXPECT_EQ(out[i++], -4.0f);
    EXPECT_EQ(out[i++], 2.0f);
    EXPECT_EQ(test_size, completed);
    EXPECT_EQ(test_size, i);
}
// Exhaustively converts every bf8 bit pattern to half_t, then exercises the
// 2-wide vector paths (bf8x2 -> fp16x2, and fp16x2 -> bf8x2 with both RNE and
// SR rounding). Runs on host and device; results go to p_test and the count of
// written values is reported through *p_completed. Writes at most N values.
__host__ __device__ void
test_fp16_bf8_type_convert(uint64_t N, half_t* p_test, uint64_t* p_completed)
{
    if(p_completed == nullptr)
    {
        return;
    }
    // i aliases the output counter so progress is visible even on early return.
    uint64_t& i = *p_completed;
    i = 0;
    if(p_test == nullptr)
    {
        return;
    }
    // Scalar path: convert every possible bf8 bit pattern.
    for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
    {
        uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
        auto v = type_convert<half_t>(bf8_ocp_t{bf8_uid});
        p_test[i] = v;
        i++;
        if(i >= N)
        {
            return;
        }
    }
    /// Test vector conversion
    // bf8x2 -> fp16x2
    bf8x2_ocp_t bf8x2{bf8x2_ocp_t::data_v{0b10000100, 0b00000001}}; //-2^-14, 2^-16
    half2_t f16x2 = type_convert<half2_t>(bf8x2);
    p_test[i++] = f16x2[0];
    if(i >= N)
    {
        return;
    }
    p_test[i++] = f16x2[1];
    if(i >= N)
    {
        return;
    }
    // fp16x2 -> bf8x2
    f16x2 = {-4.0f, 2.0f};
    // Both values are exactly representable in bf8, so RNE and SR must agree.
    bf8x2 = f8_convert_rne<bf8x2_ocp_t>(f16x2); // expect {-4, 2}
    p_test[i++] = type_convert<half_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})); //-4f
    if(i >= N)
    {
        return;
    }
    p_test[i++] = type_convert<half_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{})); // 2f
    if(i >= N)
    {
        return;
    }
    bf8x2 = f8_convert_sr<bf8x2_ocp_t>(f16x2); // expect {-4, 2}
    p_test[i++] = type_convert<half_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})); //-4f
    if(i >= N)
    {
        return;
    }
    p_test[i++] = type_convert<half_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{})); // 2f
    if(i >= N)
    {
        return;
    }
}
// Host-side check of the exhaustive bf8 <-> fp16 conversion test:
// NaN encodings must decode to NaN, every other encoding must round-trip
// against type_convert, and the 6 vector-conversion results must match.
TEST(BF8OCP, HostFP16BF8Convert)
{
    std::vector<half_t> out(test_size, -1.0f);
    uint64_t completed = 0;
    test_fp16_bf8_type_convert(test_size, out.data(), &completed);
    // The bf8 (e5m2) NaN bit patterns: all-ones exponent with nonzero mantissa.
    std::set<uint8_t> bf8_nan_ids;
    bf8_nan_ids.insert(0b11111111);
    bf8_nan_ids.insert(0b01111111);
    bf8_nan_ids.insert(0b11111101);
    bf8_nan_ids.insert(0b01111101);
    bf8_nan_ids.insert(0b11111110);
    bf8_nan_ids.insert(0b01111110);
    for(auto bf8_nan_id : bf8_nan_ids)
    {
        auto idx = bf8_nan_id;
        ASSERT_TRUE(std::isnan(type_convert<float>(out[idx])));
    }
    // Every non-NaN encoding must agree with a fresh host-side conversion.
    for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
    {
        if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end())
            continue;
        uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
        auto idx = bf8_uid;
        ASSERT_FLOAT_EQ(out[idx], type_convert<half_t>(bf8_ocp_t{bf8_uid}))
            << " bf8_id: " << bf8_id << std::endl
            << type_convert<float>(type_convert<half_t>(bf8_ocp_t{bf8_uid}));
    }
    /// Test vector conversions
    auto i = 256;
    // bf8x2 -> fp16x2
    EXPECT_EQ(out[i++], type_convert<half_t>(-powf(2.0f, -14.0f)));
    EXPECT_EQ(out[i++], type_convert<half_t>(powf(2.0f, -16.0f)));
    // fp16x2 -> bf8x2
    // RNE
    EXPECT_EQ(out[i++], type_convert<half_t>(-4.0f));
    EXPECT_EQ(out[i++], type_convert<half_t>(2.0f));
    // SR
    EXPECT_EQ(out[i++], type_convert<half_t>(-4.0f));
    EXPECT_EQ(out[i++], type_convert<half_t>(2.0f));
    EXPECT_EQ(test_size, completed);
    EXPECT_EQ(test_size, i);
}
// Device entry point for the fp16 <-> bf8 conversion test; the test body is
// serial, so it is launched with a single thread (<<<1, 1>>>).
__global__ void device_test_fp16_bf8_type_convert(uint64_t N, half_t* p_test, uint64_t* p_completed)
{
    test_fp16_bf8_type_convert(N, p_test, p_completed);
}
// Device-side variant of HostFP16BF8Convert: runs the same exhaustive test in
// a single-thread kernel and validates the copied-back results against
// host-side conversions.
TEST(BF8OCP, DeviceFP16BF8Convert)
{
    std::vector<half_t> out(test_size, -1.0f);
    DeviceMem device_out(test_size * sizeof(half_t));
    DeviceMem device_completed(sizeof(uint64_t));
    // Poison the buffers so untouched slots are detectable.
    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);
    device_test_fp16_bf8_type_convert<<<1, 1>>>(
        test_size,
        static_cast<half_t*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
    uint64_t completed = 0;
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());
    // The bf8 (e5m2) NaN bit patterns: all-ones exponent with nonzero mantissa.
    std::set<uint8_t> bf8_nan_ids;
    bf8_nan_ids.insert(0b11111111);
    bf8_nan_ids.insert(0b01111111);
    bf8_nan_ids.insert(0b11111101);
    bf8_nan_ids.insert(0b01111101);
    bf8_nan_ids.insert(0b11111110);
    bf8_nan_ids.insert(0b01111110);
    for(auto bf8_nan_id : bf8_nan_ids)
    {
        auto idx = bf8_nan_id;
        ASSERT_TRUE(std::isnan(type_convert<float>(out[idx])))
            << "idx: " << idx << " out[idx]: " << type_convert<float>(out[idx]);
    }
    // Every non-NaN encoding must agree with a fresh host-side conversion.
    for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
    {
        if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end())
            continue;
        uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
        auto idx = bf8_uid;
        ASSERT_FLOAT_EQ(out[idx], type_convert<half_t>(bf8_ocp_t{bf8_uid}))
            << " bf8_id: " << bf8_id << std::endl
            << type_convert<float>(type_convert<half_t>(bf8_ocp_t{bf8_uid}));
    }
    /// Test vector conversions
    auto i = 256;
    // bf8x2 -> fp16x2
    EXPECT_EQ(out[i++], type_convert<half_t>(-powf(2.0f, -14.0f)));
    EXPECT_EQ(out[i++], type_convert<half_t>(powf(2.0f, -16.0f)));
    // fp16x2 -> bf8x2
    // RNE
    EXPECT_EQ(out[i++], type_convert<half_t>(-4.0f));
    EXPECT_EQ(out[i++], type_convert<half_t>(2.0f));
    // SR
    EXPECT_EQ(out[i++], type_convert<half_t>(-4.0f));
    EXPECT_EQ(out[i++], type_convert<half_t>(2.0f));
    EXPECT_EQ(test_size, completed);
    EXPECT_EQ(test_size, i);
}
// Exhaustively converts every bf8 bit pattern to bhalf_t, then exercises the
// 2-wide vector paths (bf8x2 -> bf16x2, and bf16x2 -> bf8x2 with both RNE and
// SR rounding). Runs on host and device; results go to p_test and the count of
// written values is reported through *p_completed. Writes at most N values.
__host__ __device__ void
test_bf16_bf8_type_convert(uint64_t N, bhalf_t* p_test, uint64_t* p_completed)
{
    if(p_completed == nullptr)
    {
        return;
    }
    // i aliases the output counter so progress is visible even on early return.
    uint64_t& i = *p_completed;
    i = 0;
    if(p_test == nullptr)
    {
        return;
    }
    // Scalar path: convert every possible bf8 bit pattern.
    for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
    {
        uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
        auto v = type_convert<bhalf_t>(bf8_ocp_t{bf8_uid});
        p_test[i] = v;
        i++;
        if(i >= N)
        {
            return;
        }
    }
    /// Test vector conversion
    // bf8x2 -> bf16x2
    bf8x2_ocp_t bf8x2{bf8x2_ocp_t::data_v{0b10000100, 0b00000001}}; //-2^-14, 2^-16
    bhalf2_t bf16x2 = type_convert<bhalf2_t>(bf8x2);
    p_test[i++] = bf16x2[0];
    if(i >= N)
    {
        return;
    }
    p_test[i++] = bf16x2[1];
    if(i >= N)
    {
        return;
    }
    // bf16x2 -> bf8x2
    bf16x2 = {type_convert<bhalf_t>(-4.0f), type_convert<bhalf_t>(2.0f)};
    // Both values are exactly representable in bf8, so RNE and SR must agree.
    bf8x2 = f8_convert_rne<bf8x2_ocp_t>(bf16x2); // expect {-4, 2}
    p_test[i++] = type_convert<bhalf_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})); //-4f
    if(i >= N)
    {
        return;
    }
    p_test[i++] = type_convert<bhalf_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{})); // 2f
    if(i >= N)
    {
        return;
    }
    bf8x2 = f8_convert_sr<bf8x2_ocp_t>(bf16x2); // expect {-4, 2}
    p_test[i++] = type_convert<bhalf_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})); //-4f
    if(i >= N)
    {
        return;
    }
    p_test[i++] = type_convert<bhalf_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{})); // 2f
    if(i >= N)
    {
        return;
    }
}
// Host-side check of the exhaustive bf8 <-> bf16 conversion test:
// NaN encodings must decode to NaN, every other encoding must round-trip
// against type_convert, and the 6 vector-conversion results must match.
TEST(BF8OCP, HostBF16BF8Convert)
{
    std::vector<bhalf_t> out(test_size, -1.0f);
    uint64_t completed = 0;
    test_bf16_bf8_type_convert(test_size, out.data(), &completed);
    // The bf8 (e5m2) NaN bit patterns: all-ones exponent with nonzero mantissa.
    std::set<uint8_t> bf8_nan_ids;
    bf8_nan_ids.insert(0b11111111);
    bf8_nan_ids.insert(0b01111111);
    bf8_nan_ids.insert(0b11111101);
    bf8_nan_ids.insert(0b01111101);
    bf8_nan_ids.insert(0b11111110);
    bf8_nan_ids.insert(0b01111110);
    for(auto bf8_nan_id : bf8_nan_ids)
    {
        auto idx = bf8_nan_id;
        ASSERT_TRUE(std::isnan(type_convert<float>(out[idx])));
    }
    // Every non-NaN encoding must agree with a fresh host-side conversion.
    for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
    {
        if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end())
            continue;
        uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
        auto idx = bf8_uid;
        ASSERT_FLOAT_EQ(out[idx], type_convert<bhalf_t>(bf8_ocp_t{bf8_uid}))
            << " bf8_id: " << bf8_id << std::endl
            << type_convert<float>(type_convert<bhalf_t>(bf8_ocp_t{bf8_uid}));
    }
    /// Test vector conversions
    auto i = 256;
    // bf8x2 -> bf16x2
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(-powf(2.0f, -14.0f)));
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(powf(2.0f, -16.0f)));
    // bf16x2 -> bf8x2
    // RNE
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));
    // SR
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));
    EXPECT_EQ(test_size, completed);
    EXPECT_EQ(test_size, i);
}
// Device entry point for the bf16 <-> bf8 conversion test; the test body is
// serial, so it is launched with a single thread (<<<1, 1>>>).
__global__ void
device_test_bf16_bf8_type_convert(uint64_t N, bhalf_t* p_test, uint64_t* p_completed)
{
    test_bf16_bf8_type_convert(N, p_test, p_completed);
}
// Device-side variant of HostBF16BF8Convert: runs the same exhaustive test in
// a single-thread kernel and validates the copied-back results against
// host-side conversions.
TEST(BF8OCP, DeviceBF16BF8Convert)
{
    std::vector<bhalf_t> out(test_size, -1.0f);
    DeviceMem device_out(test_size * sizeof(bhalf_t));
    DeviceMem device_completed(sizeof(uint64_t));
    // Poison the buffers so untouched slots are detectable.
    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);
    device_test_bf16_bf8_type_convert<<<1, 1>>>(
        test_size,
        static_cast<bhalf_t*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
    uint64_t completed = 0;
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());
    // The bf8 (e5m2) NaN bit patterns: all-ones exponent with nonzero mantissa.
    std::set<uint8_t> bf8_nan_ids;
    bf8_nan_ids.insert(0b11111111);
    bf8_nan_ids.insert(0b01111111);
    bf8_nan_ids.insert(0b11111101);
    bf8_nan_ids.insert(0b01111101);
    bf8_nan_ids.insert(0b11111110);
    bf8_nan_ids.insert(0b01111110);
    for(auto bf8_nan_id : bf8_nan_ids)
    {
        auto idx = bf8_nan_id;
        ASSERT_TRUE(std::isnan(type_convert<float>(out[idx])))
            << "idx: " << idx << " out[idx]: " << type_convert<float>(out[idx]);
    }
    // Every non-NaN encoding must agree with a fresh host-side conversion.
    for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
    {
        if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end())
            continue;
        uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
        auto idx = bf8_uid;
        ASSERT_FLOAT_EQ(out[idx], type_convert<bhalf_t>(bf8_ocp_t{bf8_uid}))
            << " bf8_id: " << bf8_id << std::endl
            << type_convert<float>(type_convert<bhalf_t>(bf8_ocp_t{bf8_uid}));
    }
    /// Test vector conversions
    auto i = 256;
    // bf8x2 -> bf16x2
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(-powf(2.0f, -14.0f)));
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(powf(2.0f, -16.0f)));
    // bf16x2 -> bf8x2
    // RNE
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));
    // SR
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));
    EXPECT_EQ(test_size, completed);
    EXPECT_EQ(test_size, i);
}

View File

@@ -1,13 +1,19 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/library/utility/device_memory.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using ck::bhalf2_t;
using ck::bhalf_t;
using ck::f8_convert_rne;
using ck::f8_convert_sr;
using ck::f8_ocp_t;
using ck::f8x2_ocp_t;
using ck::float2_t;
using ck::half2_t;
using ck::half_t;
using ck::type_convert;
@@ -248,3 +254,566 @@ TEST(FP8OCP, ConvertFP16Stochastic)
auto f8_nan = f8_convert_sr<f8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
ASSERT_TRUE(ck::fp8_impl::ocp_f8_is_nan(f8_nan.data));
}
// Number of results each conversion test sequence produces:
// 256 scalar conversions (one per 8-bit code) + 6 vector-conversion results.
constexpr uint64_t test_size = 256 + 6;
/// Exercises fp32 <-> fp8 (OCP) conversions, storing each result in p_test.
/// Result layout:
///   [0, 255]  : type_convert<float>(f8_ocp_t{code}) for every 8-bit code
///   [256, 257]: f8x2 -> float2 vector conversion (-2^-6, 2^-9)
///   [258, 259]: float2 -> f8x2 round-to-nearest-even, read back as fp32
///   [260, 261]: float2 -> f8x2 stochastic rounding, read back as fp32
/// *p_completed counts the slots written; the walk stops once it reaches N.
__host__ __device__ void
test_fp32_fp8_type_convert(uint64_t N, float* p_test, uint64_t* p_completed)
{
    if(p_completed == nullptr)
    {
        return;
    }
    uint64_t& count = *p_completed;
    count           = 0;
    if(p_test == nullptr)
    {
        return;
    }
    // Stores one result and reports whether there is room for another.
    auto emit = [&](float v) {
        p_test[count++] = v;
        return count < N;
    };
    // Scalar path: every possible fp8 bit pattern.
    for(ck::index_t code = 0; code < 256; code++)
    {
        if(!emit(type_convert<float>(f8_ocp_t{static_cast<uint8_t>(code)})))
        {
            return;
        }
    }
    /// Vector conversions
    // fp8x2 -> fp32x2
    f8x2_ocp_t packed{f8x2_ocp_t::data_v{0b10001000, 0b00000001}}; //-2^-6, 2^-9
    float2_t f32x2 = type_convert<float2_t>(packed);
    if(!emit(f32x2[0]) || !emit(f32x2[1]))
    {
        return;
    }
    // fp32x2 -> fp8x2, round-to-nearest-even
    f32x2  = {-4.0f, 2.0f};
    packed = f8_convert_rne<f8x2_ocp_t>(f32x2); // expect {-4, 2}
    if(!emit(type_convert<float>(packed.AsType<f8_ocp_t>()(ck::Number<0>{}))) ||
       !emit(type_convert<float>(packed.AsType<f8_ocp_t>()(ck::Number<1>{}))))
    {
        return;
    }
    // fp32x2 -> fp8x2, stochastic rounding
    packed = f8_convert_sr<f8x2_ocp_t>(f32x2); // expect {-4, 2}
    if(!emit(type_convert<float>(packed.AsType<f8_ocp_t>()(ck::Number<0>{}))) ||
       !emit(type_convert<float>(packed.AsType<f8_ocp_t>()(ck::Number<1>{}))))
    {
        return;
    }
}
// Runs the fp32 <-> fp8 conversion sequence on the host and validates every
// result slot: scalar conversions for all 256 codes, then the packed vector
// conversions (RNE and SR).
TEST(FP8OCP, HostFP32FP8Convert)
{
    std::vector<float> results(test_size, -1.0f);
    uint64_t written = 0;
    test_fp32_fp8_type_convert(test_size, results.data(), &written);
    // fp8 NaN encodings: the two all-ones codes (either sign).
    const std::set<uint8_t> nan_codes{0b11111111 /*-NaN*/, 0b01111111 /*+NaN*/};
    for(auto code : nan_codes)
    {
        auto idx = code;
        ASSERT_TRUE(std::isnan(results[idx]));
    }
    // Every non-NaN code must round-trip through the reference conversion.
    for(ck::index_t id = 0; id < 256; id++)
    {
        if(nan_codes.count(id) != 0)
            continue;
        const uint8_t code = static_cast<uint8_t>(id);
        auto idx           = code;
        ASSERT_FLOAT_EQ(results[idx], type_convert<float>(f8_ocp_t{code}))
            << " fp8_id: " << id << std::endl
            << type_convert<float>(f8_ocp_t{code});
    }
    /// Vector conversions
    auto slot = 256;
    // fp8x2 -> fp32x2
    EXPECT_EQ(results[slot++], -powf(2.0f, -6.0f));
    EXPECT_EQ(results[slot++], powf(2.0f, -9.0f));
    // fp32x2 -> fp8x2, RNE then SR: both are expected to produce {-4, 2}
    EXPECT_EQ(results[slot++], -4.0f);
    EXPECT_EQ(results[slot++], 2.0f);
    EXPECT_EQ(results[slot++], -4.0f);
    EXPECT_EQ(results[slot++], 2.0f);
    // All slots written, and the walk above ends exactly at test_size.
    EXPECT_EQ(test_size, written);
    EXPECT_EQ(test_size, slot);
}
/// Device entry point for the fp32 <-> fp8 conversion checks; forwards to the
/// shared __host__ __device__ test body. Intended for a <<<1, 1>>> launch:
/// the body writes result slots by absolute index, so extra threads would race.
__global__ void device_test_fp32_fp8_type_convert(uint64_t num_results, float* p_out, uint64_t* p_done)
{
    test_fp32_fp8_type_convert(num_results, p_out, p_done);
}
// Runs the fp32 <-> fp8 conversion sequence on the device (single thread) and
// validates the copied-back results on the host:
//   slots [0, 255]   fp8 -> fp32 scalar conversion for every 8-bit code
//   slots [256, 261] packed fp8x2/fp32x2 vector conversions (RNE and SR)
TEST(FP8OCP, DeviceFP32FP8Convert)
{
std::vector<float> out(test_size, -1.0f);
DeviceMem device_out(test_size * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
// Poison device buffers so slots the kernel never writes are detectable.
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
device_test_fp32_fp8_type_convert<<<1, 1>>>(
test_size,
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
// NOTE(review): no HIP error check after the launch; a failed launch would
// only surface as sentinel values failing the comparisons below.
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
// fp8 NaN encodings: the two all-ones codes (either sign).
std::set<uint8_t> fp8_nan_ids;
fp8_nan_ids.insert(0b11111111); //-NaN
fp8_nan_ids.insert(0b01111111); // +NaN
// Every NaN encoding must convert to an fp32 NaN on the device.
for(auto fp8_nan_id : fp8_nan_ids)
{
auto idx = fp8_nan_id;
ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx];
}
// Every non-NaN code: device result must match the host-side conversion.
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
{
if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
continue;
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
auto idx = fp8_uid;
ASSERT_FLOAT_EQ(out[idx], type_convert<float>(f8_ocp_t{fp8_uid}))
<< " fp8_id: " << fp8_id << std::endl
<< type_convert<float>(f8_ocp_t{fp8_uid});
}
/// Test vector conversions
auto i = 256;
// fp8x2 -> fp32x2
EXPECT_EQ(out[i++], -powf(2.0f, -6.0f));
EXPECT_EQ(out[i++], powf(2.0f, -9.0f));
// fp32x2 -> fp8x2
// RNE
EXPECT_EQ(out[i++], -4.0f);
EXPECT_EQ(out[i++], 2.0f);
// SR
EXPECT_EQ(out[i++], -4.0f);
EXPECT_EQ(out[i++], 2.0f);
// The kernel must have filled every slot, and the walk above must end exactly
// at test_size.
EXPECT_EQ(test_size, completed);
EXPECT_EQ(test_size, i);
}
/// Exercises fp16 <-> fp8 (OCP) conversions, storing each result in p_test.
/// Result layout:
///   [0, 255]  : type_convert<half_t>(f8_ocp_t{code}) for every 8-bit code
///   [256, 257]: f8x2 -> half2 vector conversion (-2^-6, 2^-9)
///   [258, 259]: half2 -> f8x2 round-to-nearest-even, read back as fp16
///   [260, 261]: half2 -> f8x2 stochastic rounding, read back as fp16
/// *p_completed counts the slots written; the walk stops once it reaches N.
__host__ __device__ void
test_fp16_fp8_type_convert(uint64_t N, half_t* p_test, uint64_t* p_completed)
{
    if(p_completed == nullptr)
    {
        return;
    }
    uint64_t& count = *p_completed;
    count           = 0;
    if(p_test == nullptr)
    {
        return;
    }
    // Stores one result and reports whether there is room for another.
    auto emit = [&](half_t v) {
        p_test[count++] = v;
        return count < N;
    };
    // Scalar path: every possible fp8 bit pattern.
    for(ck::index_t code = 0; code < 256; code++)
    {
        if(!emit(type_convert<half_t>(f8_ocp_t{static_cast<uint8_t>(code)})))
        {
            return;
        }
    }
    /// Vector conversions
    // fp8x2 -> fp16x2
    f8x2_ocp_t packed{f8x2_ocp_t::data_v{0b10001000, 0b00000001}}; //-2^-6, 2^-9
    half2_t f16x2 = type_convert<half2_t>(packed);
    if(!emit(f16x2[0]) || !emit(f16x2[1]))
    {
        return;
    }
    // fp16x2 -> fp8x2, round-to-nearest-even
    f16x2  = {-4.0f, 2.0f};
    packed = f8_convert_rne<f8x2_ocp_t>(f16x2); // expect {-4, 2}
    if(!emit(type_convert<half_t>(packed.AsType<f8_ocp_t>()(ck::Number<0>{}))) ||
       !emit(type_convert<half_t>(packed.AsType<f8_ocp_t>()(ck::Number<1>{}))))
    {
        return;
    }
    // fp16x2 -> fp8x2, stochastic rounding
    packed = f8_convert_sr<f8x2_ocp_t>(f16x2); // expect {-4, 2}
    if(!emit(type_convert<half_t>(packed.AsType<f8_ocp_t>()(ck::Number<0>{}))) ||
       !emit(type_convert<half_t>(packed.AsType<f8_ocp_t>()(ck::Number<1>{}))))
    {
        return;
    }
}
// Runs the fp16 <-> fp8 conversion sequence on the host and validates every
// result slot: scalar conversions for all 256 codes, then the packed vector
// conversions (RNE and SR).
TEST(FP8OCP, HostFP16FP8Convert)
{
std::vector<half_t> out(test_size, -1.0f);
uint64_t completed = 0;
test_fp16_fp8_type_convert(test_size, out.data(), &completed);
// fp8 NaN encodings: the two all-ones codes (either sign).
std::set<uint8_t> fp8_nan_ids;
fp8_nan_ids.insert(0b11111111); //-NaN
fp8_nan_ids.insert(0b01111111); // +NaN
// Every NaN encoding must convert to an fp16 NaN.
for(auto fp8_nan_id : fp8_nan_ids)
{
auto idx = fp8_nan_id;
ASSERT_TRUE(std::isnan(type_convert<float>(out[idx])));
}
// Every non-NaN code must match the reference conversion.
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
{
if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
continue;
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
auto idx = fp8_uid;
ASSERT_FLOAT_EQ(out[idx], type_convert<half_t>(f8_ocp_t{fp8_uid}))
<< " fp8_id: " << fp8_id << std::endl
<< type_convert<float>(type_convert<half_t>(f8_ocp_t{fp8_uid}));
}
/// Test vector conversions
auto i = 256;
// fp8x2 -> fp16x2
EXPECT_EQ(out[i++], type_convert<half_t>(-powf(2.0f, -6.0f)));
EXPECT_EQ(out[i++], type_convert<half_t>(powf(2.0f, -9.0f)));
// fp16x2 -> fp8x2
// RNE
EXPECT_EQ(out[i++], type_convert<half_t>(-4.0f));
EXPECT_EQ(out[i++], type_convert<half_t>(2.0f));
// SR
EXPECT_EQ(out[i++], type_convert<half_t>(-4.0f));
EXPECT_EQ(out[i++], type_convert<half_t>(2.0f));
// All slots written, and the walk above ends exactly at test_size.
EXPECT_EQ(test_size, completed);
EXPECT_EQ(test_size, i);
}
/// Device entry point for the fp16 <-> fp8 conversion checks; forwards to the
/// shared __host__ __device__ test body. Intended for a <<<1, 1>>> launch:
/// the body writes result slots by absolute index, so extra threads would race.
__global__ void device_test_fp16_fp8_type_convert(uint64_t num_results, half_t* p_out, uint64_t* p_done)
{
    test_fp16_fp8_type_convert(num_results, p_out, p_done);
}
// Runs the fp16 <-> fp8 conversion sequence on the device (single thread) and
// validates the copied-back results on the host:
//   slots [0, 255]   fp8 -> fp16 scalar conversion for every 8-bit code
//   slots [256, 261] packed fp8x2/fp16x2 vector conversions (RNE and SR)
TEST(FP8OCP, DeviceFP16FP8Convert)
{
std::vector<half_t> out(test_size, -1.0f);
DeviceMem device_out(test_size * sizeof(half_t));
DeviceMem device_completed(sizeof(uint64_t));
// Poison device buffers so slots the kernel never writes are detectable.
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
device_test_fp16_fp8_type_convert<<<1, 1>>>(
test_size,
static_cast<half_t*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
// NOTE(review): no HIP error check after the launch; a failed launch would
// only surface as sentinel values failing the comparisons below.
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
// fp8 NaN encodings: the two all-ones codes (either sign).
std::set<uint8_t> fp8_nan_ids;
fp8_nan_ids.insert(0b11111111); //-NaN
fp8_nan_ids.insert(0b01111111); // +NaN
// Every NaN encoding must convert to an fp16 NaN on the device.
for(auto fp8_nan_id : fp8_nan_ids)
{
auto idx = fp8_nan_id;
ASSERT_TRUE(std::isnan(type_convert<float>(out[idx])))
<< "idx: " << idx << " out[idx]: " << type_convert<float>(out[idx]);
}
// Every non-NaN code: device result must match the host-side conversion.
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
{
if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
continue;
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
auto idx = fp8_uid;
ASSERT_FLOAT_EQ(out[idx], type_convert<half_t>(f8_ocp_t{fp8_uid}))
<< " fp8_id: " << fp8_id << std::endl
<< type_convert<float>(type_convert<half_t>(f8_ocp_t{fp8_uid}));
}
/// Test vector conversions
auto i = 256;
// fp8x2 -> fp16x2
EXPECT_EQ(out[i++], type_convert<half_t>(-powf(2.0f, -6.0f)));
EXPECT_EQ(out[i++], type_convert<half_t>(powf(2.0f, -9.0f)));
// fp16x2 -> fp8x2
// RNE
EXPECT_EQ(out[i++], type_convert<half_t>(-4.0f));
EXPECT_EQ(out[i++], type_convert<half_t>(2.0f));
// SR
EXPECT_EQ(out[i++], type_convert<half_t>(-4.0f));
EXPECT_EQ(out[i++], type_convert<half_t>(2.0f));
// The kernel must have filled every slot, and the walk above must end exactly
// at test_size.
EXPECT_EQ(test_size, completed);
EXPECT_EQ(test_size, i);
}
/// Exercises bf16 <-> fp8 (OCP) conversions, storing each result in p_test.
/// Result layout:
///   [0, 255]  : type_convert<bhalf_t>(f8_ocp_t{code}) for every 8-bit code
///   [256, 257]: f8x2 -> bhalf2 vector conversion (-2^-6, 2^-9)
///   [258, 259]: bhalf2 -> f8x2 round-to-nearest-even, read back as bf16
///   [260, 261]: bhalf2 -> f8x2 stochastic rounding, read back as bf16
/// *p_completed counts the slots written; the walk stops once it reaches N.
__host__ __device__ void
test_bf16_fp8_type_convert(uint64_t N, bhalf_t* p_test, uint64_t* p_completed)
{
    if(p_completed == nullptr)
    {
        return;
    }
    uint64_t& count = *p_completed;
    count           = 0;
    if(p_test == nullptr)
    {
        return;
    }
    // Stores one result and reports whether there is room for another.
    auto emit = [&](bhalf_t v) {
        p_test[count++] = v;
        return count < N;
    };
    // Scalar path: every possible fp8 bit pattern.
    for(ck::index_t code = 0; code < 256; code++)
    {
        if(!emit(type_convert<bhalf_t>(f8_ocp_t{static_cast<uint8_t>(code)})))
        {
            return;
        }
    }
    /// Vector conversions
    // fp8x2 -> bf16x2
    f8x2_ocp_t packed{f8x2_ocp_t::data_v{0b10001000, 0b00000001}}; //-2^-6, 2^-9
    bhalf2_t bf16x2 = type_convert<bhalf2_t>(packed);
    if(!emit(bf16x2[0]) || !emit(bf16x2[1]))
    {
        return;
    }
    // bf16x2 -> fp8x2, round-to-nearest-even
    bf16x2 = {type_convert<bhalf_t>(-4.0f), type_convert<bhalf_t>(2.0f)};
    packed = f8_convert_rne<f8x2_ocp_t>(bf16x2); // expect {-4, 2}
    if(!emit(type_convert<bhalf_t>(packed.AsType<f8_ocp_t>()(ck::Number<0>{}))) ||
       !emit(type_convert<bhalf_t>(packed.AsType<f8_ocp_t>()(ck::Number<1>{}))))
    {
        return;
    }
    // bf16x2 -> fp8x2, stochastic rounding
    packed = f8_convert_sr<f8x2_ocp_t>(bf16x2); // expect {-4, 2}
    if(!emit(type_convert<bhalf_t>(packed.AsType<f8_ocp_t>()(ck::Number<0>{}))) ||
       !emit(type_convert<bhalf_t>(packed.AsType<f8_ocp_t>()(ck::Number<1>{}))))
    {
        return;
    }
}
// Runs the bf16 <-> fp8 conversion sequence on the host and validates every
// result slot: scalar conversions for all 256 codes, then the packed vector
// conversions (RNE and SR).
TEST(FP8OCP, HostBF16FP8Convert)
{
std::vector<bhalf_t> out(test_size, -1.0f);
uint64_t completed = 0;
test_bf16_fp8_type_convert(test_size, out.data(), &completed);
// fp8 NaN encodings: the two all-ones codes (either sign).
std::set<uint8_t> fp8_nan_ids;
fp8_nan_ids.insert(0b11111111); //-NaN
fp8_nan_ids.insert(0b01111111); // +NaN
// Every NaN encoding must convert to a bf16 NaN.
for(auto fp8_nan_id : fp8_nan_ids)
{
auto idx = fp8_nan_id;
ASSERT_TRUE(std::isnan(type_convert<float>(out[idx])));
}
// Every non-NaN code must match the reference conversion.
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
{
if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
continue;
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
auto idx = fp8_uid;
ASSERT_FLOAT_EQ(out[idx], type_convert<bhalf_t>(f8_ocp_t{fp8_uid}))
<< " fp8_id: " << fp8_id << std::endl
<< type_convert<float>(type_convert<bhalf_t>(f8_ocp_t{fp8_uid}));
}
/// Test vector conversions
auto i = 256;
// fp8x2 -> bf16x2
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-powf(2.0f, -6.0f)));
EXPECT_EQ(out[i++], type_convert<bhalf_t>(powf(2.0f, -9.0f)));
// bf16x2 -> fp8x2
// RNE
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));
// SR
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));
// All slots written, and the walk above ends exactly at test_size.
EXPECT_EQ(test_size, completed);
EXPECT_EQ(test_size, i);
}
/// Device entry point for the bf16 <-> fp8 conversion checks; forwards to the
/// shared __host__ __device__ test body. Intended for a <<<1, 1>>> launch:
/// the body writes result slots by absolute index, so extra threads would race.
__global__ void
device_test_bf16_fp8_type_convert(uint64_t num_results, bhalf_t* p_out, uint64_t* p_done)
{
    test_bf16_fp8_type_convert(num_results, p_out, p_done);
}
// Runs the bf16 <-> fp8 conversion sequence on the device (single thread) and
// validates the copied-back results on the host:
//   slots [0, 255]   fp8 -> bf16 scalar conversion for every 8-bit code
//   slots [256, 261] packed fp8x2/bf16x2 vector conversions (RNE and SR)
TEST(FP8OCP, DeviceBF16FP8Convert)
{
std::vector<bhalf_t> out(test_size, -1.0f);
DeviceMem device_out(test_size * sizeof(bhalf_t));
DeviceMem device_completed(sizeof(uint64_t));
// Poison device buffers so slots the kernel never writes are detectable.
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
device_test_bf16_fp8_type_convert<<<1, 1>>>(
test_size,
static_cast<bhalf_t*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
// NOTE(review): no HIP error check after the launch; a failed launch would
// only surface as sentinel values failing the comparisons below.
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
// fp8 NaN encodings: the two all-ones codes (either sign).
std::set<uint8_t> fp8_nan_ids;
fp8_nan_ids.insert(0b11111111); //-NaN
fp8_nan_ids.insert(0b01111111); // +NaN
// Every NaN encoding must convert to a bf16 NaN on the device.
for(auto fp8_nan_id : fp8_nan_ids)
{
auto idx = fp8_nan_id;
ASSERT_TRUE(std::isnan(type_convert<float>(out[idx])))
<< "idx: " << idx << " out[idx]: " << type_convert<float>(out[idx]);
}
// Every non-NaN code: device result must match the host-side conversion.
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
{
if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
continue;
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
auto idx = fp8_uid;
ASSERT_FLOAT_EQ(out[idx], type_convert<bhalf_t>(f8_ocp_t{fp8_uid}))
<< " fp8_id: " << fp8_id << std::endl
<< type_convert<float>(type_convert<bhalf_t>(f8_ocp_t{fp8_uid}));
}
/// Test vector conversions
auto i = 256;
// fp8x2 -> bf16x2
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-powf(2.0f, -6.0f)));
EXPECT_EQ(out[i++], type_convert<bhalf_t>(powf(2.0f, -9.0f)));
// bf16x2 -> fp8x2
// RNE
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));
// SR
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));
// The kernel must have filled every slot, and the walk above must end exactly
// at test_size.
EXPECT_EQ(test_size, completed);
EXPECT_EQ(test_size, i);
}