mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Fix bf8 conversion issues (#1003)
* Fix the conversion
* Add bf8 functionality
* Enable example on MI200 as well
[ROCm/composable_kernel commit: 1fd27d520f]
This commit is contained in:
@@ -10,10 +10,8 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16)
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942")
|
||||
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8)
|
||||
endif()
|
||||
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8)
|
||||
set(target 1)
|
||||
endif()
|
||||
|
||||
|
||||
@@ -207,7 +207,8 @@ struct ConvertF8SR
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const
|
||||
{
|
||||
// check Y datatype
|
||||
static_assert(is_same<Y, f8_t>::value, "Data type is not supported by this operation!");
|
||||
static_assert(is_same<Y, f8_t>::value || is_same<Y, bf8_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
// check X datatype
|
||||
static_assert(is_same<X, float>::value || is_same<X, half_t>::value,
|
||||
|
||||
@@ -1075,6 +1075,7 @@ struct NumericUtils<float>
|
||||
{
|
||||
static constexpr int exp = 8;
|
||||
static constexpr int mant = 23;
|
||||
static constexpr int bias = 127;
|
||||
static constexpr uint32_t nan_mask = 0x7F800000;
|
||||
static constexpr uint32_t head_mask = 0xFF800000;
|
||||
static constexpr uint32_t mant_mask = 0x7FFFFF;
|
||||
@@ -1091,6 +1092,7 @@ struct NumericUtils<half_t>
|
||||
{
|
||||
static constexpr int exp = 5;
|
||||
static constexpr int mant = 10;
|
||||
static constexpr int bias = 15;
|
||||
static constexpr uint16_t nan_mask = 0x7C00;
|
||||
static constexpr uint16_t head_mask = 0xFC00;
|
||||
static constexpr uint16_t mant_mask = 0x3FF;
|
||||
@@ -1107,6 +1109,8 @@ struct NumericUtils<f8_t>
|
||||
{
|
||||
static constexpr int exp = 4;
|
||||
static constexpr int mant = 3;
|
||||
static constexpr int bias = 8; // negative zero nan mode
|
||||
// static constexpr int bias = 7; // ieee mode
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -1114,6 +1118,7 @@ struct NumericUtils<bf8_t>
|
||||
{
|
||||
static constexpr int exp = 5;
|
||||
static constexpr int mant = 2;
|
||||
static constexpr int bias = 16; // negative zero nan mode
|
||||
// static constexpr int bias = 15; // ieee mode
|
||||
};
|
||||
//
|
||||
} // namespace ck
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
// these conversions are disabled if native conversions available
|
||||
namespace ck {
|
||||
|
||||
// fp8 rounding modes
|
||||
@@ -17,6 +16,9 @@ enum class f8_rounding_mode
|
||||
stochastic
|
||||
};
|
||||
|
||||
// Host-side overload: returns the number of leading zero bits in the 32-bit value x.
// __builtin_clz has an undefined result when its argument is 0, so zero is handled
// explicitly and reported as 32 (all 32 bits are leading zeros). Non-zero inputs are
// forwarded to the builtin unchanged.
__host__ inline int clz(uint32_t x) { return x == 0 ? 32 : __builtin_clz(x); }
|
||||
// Device-side overload: forwards x to the __clz intrinsic and returns its result as int.
__device__ inline int clz(uint32_t x)
{
    const int leading_zero_count = __clz(x);
    return leading_zero_count;
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
namespace ck::utils {
|
||||
@@ -34,7 +36,7 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
|
||||
constexpr int in_exp = NumericUtils<X>::exp;
|
||||
constexpr int in_mant = NumericUtils<X>::mant;
|
||||
|
||||
int exponent;
|
||||
int exponent, bias;
|
||||
uint32_t head, mantissa, sign;
|
||||
// nan code is same for float and half
|
||||
constexpr Y nan_code = 0x80;
|
||||
@@ -49,12 +51,11 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
|
||||
mantissa = x_bitwise & NumericUtils<X>::mant_mask;
|
||||
exponent = (head >> in_mant) & NumericUtils<X>::exp_mask;
|
||||
sign = head >> (in_exp + in_mant);
|
||||
bias = NumericUtils<X>::bias;
|
||||
|
||||
uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant);
|
||||
uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1;
|
||||
constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2);
|
||||
constexpr int exp_low_cutoff =
|
||||
(1 << (in_exp - 1)) - (1 << (out_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
|
||||
|
||||
if constexpr(negative_zero_nan)
|
||||
{
|
||||
@@ -67,56 +68,107 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
|
||||
return signed_inf + (mantissa != 0 ? 1 : 0);
|
||||
}
|
||||
|
||||
// if input is half and output is bf8
|
||||
if((NumericUtils<X>::mant == 10) && (NumericUtils<Y>::mant == 2) && negative_zero_nan &&
|
||||
exponent == 0)
|
||||
{
|
||||
exponent += 1;
|
||||
while(mantissa < (1 << in_mant))
|
||||
{
|
||||
mantissa <<= 1;
|
||||
exponent -= 1;
|
||||
}
|
||||
mantissa &= ~(1 << in_mant);
|
||||
}
|
||||
|
||||
// check if x is 0.0
|
||||
if(x_bitwise == 0)
|
||||
return 0;
|
||||
|
||||
exponent -= exp_low_cutoff - 1;
|
||||
if(exponent <= 0)
|
||||
drop_mask = (1 << (in_mant - out_mant + 1 - exponent)) - 1;
|
||||
mantissa += 1 << in_mant;
|
||||
// apply random number if needed
|
||||
mantissa += (stoch ? rng : mantissa) & drop_mask;
|
||||
if(mantissa >= (2 << in_mant))
|
||||
{
|
||||
mantissa >>= 1;
|
||||
exponent++;
|
||||
}
|
||||
mantissa >>= (in_mant - out_mant);
|
||||
// First need to check if it is normal or denorm as there is a difference of implicit 1
|
||||
// Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
|
||||
// The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for
|
||||
// RNE, no need to add rng. Then probably need to check whether there is carry and adjust
|
||||
// exponent and mantissa again
|
||||
|
||||
// check negative exponent
|
||||
if(exponent <= 0)
|
||||
{
|
||||
if(x_bitwise == 0)
|
||||
return 0;
|
||||
else
|
||||
// For IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of exponent bits
|
||||
const int out_bias = (1 << (out_exp - 1)) - 1 + (negative_zero_nan ? 1 : 0);
|
||||
const int out_denormal_act_exponent = 1 - out_bias; // actual exponent of f8 denormal
|
||||
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
|
||||
// out_exponent is the converted f8 exponent with bias encoding
|
||||
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
|
||||
// the difference needs to be adjusted and mantissa shifted
|
||||
int act_exponent, out_exponent, exponent_diff;
|
||||
|
||||
if(exponent == 0)
|
||||
{ // fp32/fp16 is in denormal.
|
||||
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16
|
||||
here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has
|
||||
exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in
|
||||
fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
|
||||
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal.
|
||||
In this case, the fp16 mantissa should be shift left by 1 */
|
||||
act_exponent = exponent - bias + 1;
|
||||
exponent_diff = out_denormal_act_exponent -
|
||||
act_exponent; // actual exponent is exponent-bias+1 as it is denormal
|
||||
}
|
||||
else
|
||||
{ // fp32/fp16 is normal with implicit 1
|
||||
act_exponent = exponent - bias;
|
||||
if(act_exponent <= out_denormal_act_exponent)
|
||||
{
|
||||
// subnormal range; represented by a subnormal float8 (exponent 0)
|
||||
// and involves loss of accuracy
|
||||
mantissa >>= 1 - exponent;
|
||||
exponent = 0;
|
||||
/* This is the case where fp32/fp16 is normal but it is in f8 denormal range.
|
||||
For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
|
||||
actual exponent is -7, it is actually larger due to the implicit 1,
|
||||
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
|
||||
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
|
||||
exponent_diff = out_denormal_act_exponent - act_exponent;
|
||||
}
|
||||
else
|
||||
{ // both fp32/fp16 and f8 are in normal range
|
||||
exponent_diff =
|
||||
0; // exponent_diff=0 does not mean there is no difference for this case,
|
||||
// act_exponent could be larger. Just that it does not need shift mantissa
|
||||
}
|
||||
mantissa += (1 << in_mant); // Add the implicit 1 into mantissa
|
||||
}
|
||||
|
||||
bool midpoint = (mantissa & ((1 << (in_mant - out_mant + exponent_diff)) - 1)) ==
|
||||
(1 << (in_mant - out_mant + exponent_diff - 1));
|
||||
/* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we
|
||||
shift right as shift right could rip off some residual part and make something not midpoint look
|
||||
like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than
|
||||
midpoint, but after shift right by 4 bits, it would look like midpoint. */
|
||||
|
||||
if(exponent_diff > 0)
|
||||
mantissa >>= exponent_diff;
|
||||
else if(exponent_diff == -1)
|
||||
mantissa <<= -exponent_diff;
|
||||
bool implicit_one = mantissa & (1 << in_mant);
|
||||
// if there is no implicit 1, it means the f8 is denormal and needs to be adjusted to the denorm exponent
|
||||
out_exponent =
|
||||
(act_exponent + exponent_diff) /*actual f8 exponent*/ + out_bias - (implicit_one ? 0 : 1);
|
||||
|
||||
// Now we have the exponent and mantissa adjusted
|
||||
bool odd =
|
||||
mantissa &
|
||||
(1 << (in_mant - out_mant)); // if the least significant bit that is not truncated is 1
|
||||
mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;
|
||||
|
||||
// Now we deal with overflow
|
||||
if(out_exponent == 0)
|
||||
{
|
||||
if((1 << in_mant) & mantissa)
|
||||
{
|
||||
out_exponent = 1; // denormal overflow to become normal, promote exponent
|
||||
// No need to make 1 implicit now as it will be addressed later
|
||||
}
|
||||
}
|
||||
// above range: quantize to maximum possible float of the same sign
|
||||
else if(exponent > max_exp)
|
||||
else
|
||||
{
|
||||
if((1 << (in_mant + 1)) & mantissa)
|
||||
{
|
||||
mantissa >>= 1;
|
||||
out_exponent++;
|
||||
// No need to make 1 implicit now as it will be addressed later
|
||||
}
|
||||
}
|
||||
|
||||
mantissa >>= (in_mant - out_mant);
|
||||
|
||||
if(out_exponent > max_exp)
|
||||
{
|
||||
if(clip)
|
||||
{
|
||||
mantissa = (1 << out_mant) - 1;
|
||||
exponent = max_exp;
|
||||
mantissa = (1 << out_mant) - 1;
|
||||
out_exponent = max_exp;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -125,10 +177,10 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
|
||||
}
|
||||
|
||||
// check if x is 0.0 or -0.0
|
||||
if(exponent == 0 && mantissa == 0)
|
||||
if(out_exponent == 0 && mantissa == 0)
|
||||
return negative_zero_nan ? 0 : (sign << (out_exp + out_mant));
|
||||
mantissa &= (1 << out_mant) - 1;
|
||||
return (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
|
||||
return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa;
|
||||
}
|
||||
|
||||
template <typename X, typename Y, bool negative_zero_nan>
|
||||
@@ -194,12 +246,9 @@ __host__ __device__ Y run_cast_from_f8(X x)
|
||||
if(exponent == 0)
|
||||
{
|
||||
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
|
||||
exponent++;
|
||||
while(mantissa < (1 << in_mant))
|
||||
{
|
||||
mantissa <<= 1;
|
||||
exponent--;
|
||||
}
|
||||
int sh = 1 + clz(mantissa) - (32 - in_mant);
|
||||
mantissa <<= sh;
|
||||
exponent += 1 - sh;
|
||||
mantissa &= ((1 << in_mant) - 1);
|
||||
}
|
||||
exponent += exp_low_cutoff - 1;
|
||||
|
||||
@@ -145,7 +145,7 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
// convert to float and use native conversion
|
||||
return type_convert<f8_t>(type_convert<float>(x));
|
||||
#elif 0
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
|
||||
@@ -153,8 +153,6 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
|
||||
return utils::
|
||||
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
#else
|
||||
return type_convert<f8_t>(type_convert<float>(x));
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -165,11 +163,9 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
// use native conversion to float and convert to fp16
|
||||
return type_convert<half_t>(type_convert<float>(x));
|
||||
#elif 0
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
|
||||
#else
|
||||
return type_convert<half_t>(type_convert<float>(x));
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -223,7 +219,7 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
// convert to float and use native conversion
|
||||
return type_convert<bf8_t>(type_convert<float>(x));
|
||||
#elif 0
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
|
||||
@@ -231,8 +227,6 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
|
||||
return utils::
|
||||
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
#else
|
||||
return type_convert<bf8_t>(type_convert<float>(x));
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -243,11 +237,9 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
// use native conversion to float and convert to fp16
|
||||
return type_convert<half_t>(type_convert<float>(x));
|
||||
#elif 0
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x);
|
||||
#else
|
||||
return type_convert<half_t>(type_convert<float>(x));
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -347,7 +339,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
// convert to float and use native conversion
|
||||
return f8_convert_sr<f8_t>(type_convert<float>(x));
|
||||
#elif 0
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
|
||||
@@ -356,8 +348,6 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
|
||||
return utils::
|
||||
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
#else
|
||||
return f8_convert_sr<f8_t>(type_convert<float>(x));
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -396,7 +386,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
// convert to float and use native conversion
|
||||
// This specialization returns bf8_t, so the stochastic-rounding conversion of the
// intermediate float must target bf8_t; routing through f8_t would produce a value in
// the f8 bit layout (different exponent/mantissa split) and reinterpret it as bf8.
return f8_convert_sr<bf8_t>(type_convert<float>(x));
|
||||
#elif 0
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
|
||||
@@ -406,8 +396,6 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
|
||||
return utils::
|
||||
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
#else
|
||||
return f8_convert_sr<bf8_t>(type_convert<float>(x));
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user