mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
poc convert fnuz fp8 to non-native dtype similar to ocp (#2871)
This commit is contained in:
@@ -33,8 +33,34 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
using f8_fnuz_t = _BitInt(8);
|
||||
using bf8_fnuz_t = unsigned _BitInt(8);
|
||||
struct f8_fnuz_t
|
||||
{
|
||||
using data_type = unsigned char;
|
||||
data_type m_data;
|
||||
__host__ __device__ explicit constexpr f8_fnuz_t(data_type in_data) : m_data(in_data) {}
|
||||
__host__ __device__ explicit constexpr f8_fnuz_t() = default;
|
||||
__host__ __device__ bool constexpr operator==(f8_fnuz_t other) const
|
||||
{
|
||||
return m_data == other.m_data;
|
||||
}
|
||||
__host__ __device__ explicit constexpr operator data_type() const { return m_data; }
|
||||
};
|
||||
|
||||
struct bf8_fnuz_t
|
||||
{
|
||||
using data_type = unsigned char;
|
||||
data_type m_data;
|
||||
__host__ __device__ explicit constexpr bf8_fnuz_t(data_type in_data) : m_data(in_data) {}
|
||||
__host__ __device__ explicit constexpr bf8_fnuz_t() = default;
|
||||
__host__ __device__ bool constexpr operator==(bf8_fnuz_t other) const
|
||||
{
|
||||
return m_data == other.m_data;
|
||||
}
|
||||
__host__ __device__ explicit constexpr operator data_type() const { return m_data; }
|
||||
};
|
||||
|
||||
static_assert(1 == sizeof(f8_fnuz_t));
|
||||
static_assert(1 == sizeof(bf8_fnuz_t));
|
||||
|
||||
typedef unsigned char fp8_storage_t;
|
||||
|
||||
|
||||
@@ -205,7 +205,7 @@ inline constexpr bool is_native_type()
|
||||
return is_same<T, double>::value || is_same<T, float>::value || is_same<T, half_t>::value ||
|
||||
is_same<T, bhalf_t>::value || is_same<T, int32_t>::value ||
|
||||
is_same<T, uint32_t>::value || is_same<T, int8_t>::value || is_same<T, uint8_t>::value ||
|
||||
is_same<T, f8_fnuz_t>::value || is_same<T, bf8_fnuz_t>::value || is_same<T, bool>::value;
|
||||
is_same_v<T, _BitInt(8)> || is_same_v<T, unsigned _BitInt(8)> || is_same<T, bool>::value;
|
||||
}
|
||||
|
||||
// scalar_type
|
||||
@@ -300,14 +300,14 @@ struct scalar_type<pk_i4_t>
|
||||
// scalar_type specialization for f8_fnuz_t; reports a vector_size of 1.
// NOTE(review): two `type` aliases appear below (f8_fnuz_t and f8_fnuz_t::data_type).
// This looks like the before/after lines of a diff shown together — confirm which one
// the file is meant to keep; a struct cannot declare `type` twice with different targets.
template <>
|
||||
struct scalar_type<f8_fnuz_t>
|
||||
{
|
||||
using type = f8_fnuz_t;
|
||||
using type = f8_fnuz_t::data_type;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
// scalar_type specialization for bf8_fnuz_t; reports a vector_size of 1.
// NOTE(review): two `type` aliases appear below (bf8_fnuz_t and bf8_fnuz_t::data_type).
// This looks like the before/after lines of a diff shown together — confirm which one
// the file is meant to keep; a struct cannot declare `type` twice with different targets.
template <>
|
||||
struct scalar_type<bf8_fnuz_t>
|
||||
{
|
||||
using type = bf8_fnuz_t;
|
||||
using type = bf8_fnuz_t::data_type;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
|
||||
@@ -1294,6 +1294,18 @@ struct nnvb_data_t_selector<bf8_ocp_t>
|
||||
using type = bf8_ocp_t::data_type;
|
||||
};
|
||||
|
||||
// nnvb_data_t_selector specialization for f8_fnuz_t: exposes the struct's own
// data_type alias as `type`.
template <>
|
||||
struct nnvb_data_t_selector<f8_fnuz_t>
|
||||
{
|
||||
using type = f8_fnuz_t::data_type;
|
||||
};
|
||||
|
||||
// nnvb_data_t_selector specialization for bf8_fnuz_t: exposes the struct's own
// data_type alias as `type`.
template <>
|
||||
struct nnvb_data_t_selector<bf8_fnuz_t>
|
||||
{
|
||||
using type = bf8_fnuz_t::data_type;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct nnvb_data_t_selector<e8m0_bexp_t>
|
||||
{
|
||||
|
||||
@@ -39,7 +39,7 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
|
||||
int exponent, bias;
|
||||
uint32_t head, mantissa, sign;
|
||||
// nan code is same for float and half
|
||||
constexpr Y nan_code = 0x80;
|
||||
constexpr uint8_t nan_code = 0x80;
|
||||
constexpr uint32_t nan_mask = NumericUtils<X>::nan_mask;
|
||||
|
||||
// convert to bitwise
|
||||
@@ -60,17 +60,17 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
|
||||
if constexpr(negative_zero_nan)
|
||||
{
|
||||
if((x_bitwise & nan_mask) == nan_mask)
|
||||
return nan_code;
|
||||
return Y{nan_code};
|
||||
}
|
||||
else
|
||||
{
|
||||
if((x_bitwise & nan_mask) == nan_mask)
|
||||
return signed_inf + (mantissa != 0 ? 1 : 0);
|
||||
return Y{static_cast<uint8_t>(signed_inf + (mantissa != 0 ? 1 : 0))};
|
||||
}
|
||||
|
||||
// check if x is 0.0
|
||||
if(x_bitwise == 0)
|
||||
return 0;
|
||||
return Y{0};
|
||||
|
||||
// First need to check if it is normal or denorm as there is a difference of implicit 1
|
||||
// Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
|
||||
@@ -178,9 +178,10 @@ In this case, the fp16 mantissa should be shift left by 1 */
|
||||
|
||||
// check if x is 0.0 or -0.0
|
||||
if(out_exponent == 0 && mantissa == 0)
|
||||
return negative_zero_nan ? 0 : (sign << (out_exp + out_mant));
|
||||
return Y{negative_zero_nan ? 0 : static_cast<uint8_t>(sign << (out_exp + out_mant))};
|
||||
mantissa &= (1 << out_mant) - 1;
|
||||
return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa;
|
||||
return Y{static_cast<uint8_t>((sign << (out_exp + out_mant)) | (out_exponent << out_mant) |
|
||||
mantissa)};
|
||||
}
|
||||
|
||||
template <typename X, typename Y, bool negative_zero_nan>
|
||||
@@ -195,8 +196,8 @@ __host__ __device__ Y run_cast_from_f8(X x)
|
||||
constexpr int out_mant = NumericUtils<Y>::mant;
|
||||
|
||||
// prepare the codes
|
||||
constexpr X nan_code = 0x80;
|
||||
using T_bitwise = typename NumericUtils<Y>::bitwise_type;
|
||||
constexpr uint8_t nan_code = 0x80;
|
||||
using T_bitwise = typename NumericUtils<Y>::bitwise_type;
|
||||
|
||||
constexpr T_bitwise Inf_bitwise = NumericUtils<Y>::Inf;
|
||||
constexpr T_bitwise NegInf_bitwise = NumericUtils<Y>::NegInf;
|
||||
@@ -209,13 +210,13 @@ __host__ __device__ Y run_cast_from_f8(X x)
|
||||
constexpr Y Neg0 = bit_cast<Y>(Neg0_bitwise);
|
||||
|
||||
// check if x is 0.0
|
||||
if(x == 0)
|
||||
if(!static_cast<uint8_t>(x))
|
||||
return static_cast<Y>(0);
|
||||
|
||||
// unpack the input
|
||||
uint32_t sign = x >> (in_exp + in_mant);
|
||||
uint32_t mantissa = x & ((1 << in_mant) - 1);
|
||||
int exponent = (x & 0x7F) >> in_mant;
|
||||
uint32_t sign = static_cast<uint8_t>(x) >> (in_exp + in_mant);
|
||||
uint32_t mantissa = static_cast<uint8_t>(x) & ((1 << in_mant) - 1);
|
||||
int exponent = (static_cast<uint8_t>(x) & 0x7F) >> in_mant;
|
||||
|
||||
constexpr int exp_low_cutoff =
|
||||
(1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
|
||||
@@ -223,12 +224,12 @@ __host__ __device__ Y run_cast_from_f8(X x)
|
||||
|
||||
if constexpr(negative_zero_nan)
|
||||
{
|
||||
if(x == nan_code)
|
||||
if(static_cast<uint8_t>(x) == nan_code)
|
||||
return NaN;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(x == nan_code)
|
||||
if(static_cast<uint8_t>(x) == nan_code)
|
||||
return Neg0;
|
||||
if(exponent == ((1 << in_exp) - 1))
|
||||
return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
|
||||
|
||||
@@ -351,7 +351,7 @@ inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, float>(float x)
|
||||
val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8);
|
||||
ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
|
||||
val.i32val = ival;
|
||||
return val.i8val[0]; // little endian
|
||||
return f8_t{val.i8val[0]}; // little endian
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
@@ -419,7 +419,7 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, float>(float x)
|
||||
val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8);
|
||||
ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
|
||||
val.i32val = ival;
|
||||
return val.i8val[0]; // little endian
|
||||
return bf8_t{val.i8val[0]}; // little endian
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
@@ -655,7 +655,7 @@ inline __host__ __device__ f8_fnuz_t f8_convert_rne<f8_fnuz_t, float>(float x)
|
||||
val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8);
|
||||
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
|
||||
val.i32val = ival;
|
||||
return val.i8val[0];
|
||||
return f8_t{val.i8val[0]};
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
@@ -707,7 +707,7 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_rne<bf8_fnuz_t, float>(float x)
|
||||
val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8);
|
||||
ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
|
||||
val.i32val = ival;
|
||||
return val.i8val[0];
|
||||
return bf8_t{val.i8val[0]};
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
@@ -924,7 +924,7 @@ inline __host__ __device__ float type_convert<float, f8_fnuz_t>(f8_fnuz_t x)
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
float fval;
|
||||
uint32_t i32val = static_cast<uint32_t>(x);
|
||||
uint32_t i32val = static_cast<uint32_t>(static_cast<uint8_t>(x));
|
||||
fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
|
||||
// asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
|
||||
return fval;
|
||||
@@ -1430,7 +1430,7 @@ inline __host__ __device__ float type_convert<float, bf8_fnuz_t>(bf8_fnuz_t x)
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
float fval;
|
||||
uint32_t i32val = static_cast<uint32_t>(x);
|
||||
uint32_t i32val = static_cast<uint32_t>(static_cast<uint8_t>(x));
|
||||
fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
|
||||
// asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
|
||||
return fval;
|
||||
|
||||
Reference in New Issue
Block a user