poc convert fnuz fp8 to non-native dtype similar to ocp (#2871)

This commit is contained in:
Max Podkorytov
2025-09-18 22:51:01 -07:00
committed by GitHub
parent 47cd0d5cff
commit e469fee046
7 changed files with 80 additions and 49 deletions

View File

@@ -33,8 +33,34 @@
namespace ck {
using f8_fnuz_t = _BitInt(8);
using bf8_fnuz_t = unsigned _BitInt(8);
struct f8_fnuz_t
{
using data_type = unsigned char;
data_type m_data;
__host__ __device__ explicit constexpr f8_fnuz_t(data_type in_data) : m_data(in_data) {}
__host__ __device__ explicit constexpr f8_fnuz_t() = default;
__host__ __device__ bool constexpr operator==(f8_fnuz_t other) const
{
return m_data == other.m_data;
}
__host__ __device__ explicit constexpr operator data_type() const { return m_data; }
};
struct bf8_fnuz_t
{
using data_type = unsigned char;
data_type m_data;
__host__ __device__ explicit constexpr bf8_fnuz_t(data_type in_data) : m_data(in_data) {}
__host__ __device__ explicit constexpr bf8_fnuz_t() = default;
__host__ __device__ bool constexpr operator==(bf8_fnuz_t other) const
{
return m_data == other.m_data;
}
__host__ __device__ explicit constexpr operator data_type() const { return m_data; }
};
static_assert(1 == sizeof(f8_fnuz_t));
static_assert(1 == sizeof(bf8_fnuz_t));
typedef unsigned char fp8_storage_t;

View File

@@ -205,7 +205,7 @@ inline constexpr bool is_native_type()
return is_same<T, double>::value || is_same<T, float>::value || is_same<T, half_t>::value ||
is_same<T, bhalf_t>::value || is_same<T, int32_t>::value ||
is_same<T, uint32_t>::value || is_same<T, int8_t>::value || is_same<T, uint8_t>::value ||
is_same<T, f8_fnuz_t>::value || is_same<T, bf8_fnuz_t>::value || is_same<T, bool>::value;
is_same_v<T, _BitInt(8)> || is_same_v<T, unsigned _BitInt(8)> || is_same<T, bool>::value;
}
// scalar_type
@@ -300,14 +300,14 @@ struct scalar_type<pk_i4_t>
template <>
struct scalar_type<f8_fnuz_t>
{
using type = f8_fnuz_t;
using type = f8_fnuz_t::data_type;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<bf8_fnuz_t>
{
using type = bf8_fnuz_t;
using type = bf8_fnuz_t::data_type;
static constexpr index_t vector_size = 1;
};

View File

@@ -1294,6 +1294,18 @@ struct nnvb_data_t_selector<bf8_ocp_t>
using type = bf8_ocp_t::data_type;
};
template <>
struct nnvb_data_t_selector<f8_fnuz_t>
{
using type = f8_fnuz_t::data_type;
};
template <>
struct nnvb_data_t_selector<bf8_fnuz_t>
{
using type = bf8_fnuz_t::data_type;
};
template <>
struct nnvb_data_t_selector<e8m0_bexp_t>
{

View File

@@ -39,7 +39,7 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
int exponent, bias;
uint32_t head, mantissa, sign;
// nan code is same for float and half
constexpr Y nan_code = 0x80;
constexpr uint8_t nan_code = 0x80;
constexpr uint32_t nan_mask = NumericUtils<X>::nan_mask;
// convert to bitwise
@@ -60,17 +60,17 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
if constexpr(negative_zero_nan)
{
if((x_bitwise & nan_mask) == nan_mask)
return nan_code;
return Y{nan_code};
}
else
{
if((x_bitwise & nan_mask) == nan_mask)
return signed_inf + (mantissa != 0 ? 1 : 0);
return Y{static_cast<uint8_t>(signed_inf + (mantissa != 0 ? 1 : 0))};
}
// check if x is 0.0
if(x_bitwise == 0)
return 0;
return Y{0};
// First need to check if it is normal or denorm as there is a difference of implict 1
// Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
@@ -178,9 +178,10 @@ In this case, the fp16 mantissa should be shift left by 1 */
// check if x is 0.0 or -0.0
if(out_exponent == 0 && mantissa == 0)
return negative_zero_nan ? 0 : (sign << (out_exp + out_mant));
return Y{negative_zero_nan ? 0 : static_cast<uint8_t>(sign << (out_exp + out_mant))};
mantissa &= (1 << out_mant) - 1;
return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa;
return Y{static_cast<uint8_t>((sign << (out_exp + out_mant)) | (out_exponent << out_mant) |
mantissa)};
}
template <typename X, typename Y, bool negative_zero_nan>
@@ -195,8 +196,8 @@ __host__ __device__ Y run_cast_from_f8(X x)
constexpr int out_mant = NumericUtils<Y>::mant;
// prepare the codes
constexpr X nan_code = 0x80;
using T_bitwise = typename NumericUtils<Y>::bitwise_type;
constexpr uint8_t nan_code = 0x80;
using T_bitwise = typename NumericUtils<Y>::bitwise_type;
constexpr T_bitwise Inf_bitwise = NumericUtils<Y>::Inf;
constexpr T_bitwise NegInf_bitwise = NumericUtils<Y>::NegInf;
@@ -209,13 +210,13 @@ __host__ __device__ Y run_cast_from_f8(X x)
constexpr Y Neg0 = bit_cast<Y>(Neg0_bitwise);
// check if x is 0.0
if(x == 0)
if(!static_cast<uint8_t>(x))
return static_cast<Y>(0);
// unpack the input
uint32_t sign = x >> (in_exp + in_mant);
uint32_t mantissa = x & ((1 << in_mant) - 1);
int exponent = (x & 0x7F) >> in_mant;
uint32_t sign = static_cast<uint8_t>(x) >> (in_exp + in_mant);
uint32_t mantissa = static_cast<uint8_t>(x) & ((1 << in_mant) - 1);
int exponent = (static_cast<uint8_t>(x) & 0x7F) >> in_mant;
constexpr int exp_low_cutoff =
(1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
@@ -223,12 +224,12 @@ __host__ __device__ Y run_cast_from_f8(X x)
if constexpr(negative_zero_nan)
{
if(x == nan_code)
if(static_cast<uint8_t>(x) == nan_code)
return NaN;
}
else
{
if(x == nan_code)
if(static_cast<uint8_t>(x) == nan_code)
return Neg0;
if(exponent == ((1 << in_exp) - 1))
return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;

View File

@@ -351,7 +351,7 @@ inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, float>(float x)
val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8);
ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
return val.i8val[0]; // little endian
return f8_t{val.i8val[0]}; // little endian
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
@@ -419,7 +419,7 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, float>(float x)
val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8);
ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
return val.i8val[0]; // little endian
return bf8_t{val.i8val[0]}; // little endian
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
@@ -655,7 +655,7 @@ inline __host__ __device__ f8_fnuz_t f8_convert_rne<f8_fnuz_t, float>(float x)
val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8);
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
val.i32val = ival;
return val.i8val[0];
return f8_t{val.i8val[0]};
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
@@ -707,7 +707,7 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_rne<bf8_fnuz_t, float>(float x)
val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8);
ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
val.i32val = ival;
return val.i8val[0];
return bf8_t{val.i8val[0]};
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
@@ -924,7 +924,7 @@ inline __host__ __device__ float type_convert<float, f8_fnuz_t>(f8_fnuz_t x)
{
#if defined(__gfx94__)
float fval;
uint32_t i32val = static_cast<uint32_t>(x);
uint32_t i32val = static_cast<uint32_t>(static_cast<uint8_t>(x));
fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
// asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
return fval;
@@ -1430,7 +1430,7 @@ inline __host__ __device__ float type_convert<float, bf8_fnuz_t>(bf8_fnuz_t x)
{
#if defined(__gfx94__)
float fval;
uint32_t i32val = static_cast<uint32_t>(x);
uint32_t i32val = static_cast<uint32_t>(static_cast<uint8_t>(x));
fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
// asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
return fval;