mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Refactor f8_t, add bf8_t (#792)
* Refactor f8_t to add bf8_t * Add check_err impl for f8_t * Update fp8 test * Format * Revert the fix * Update vector_type implementation * Add bf8 test * Add bf8, use BitInt types * Add bf8 conversion methods * Update type_convert for fp8/bf8 * Add check_err fp8/bf8 support * Add subnorm fp8 tests * Add subnorm bf8 tests * Fix conversion * Add bf8 cmake bindings * Add macros to enable build with disabled fp8/bf8 * Remove is_native method * Update flag combination for mixed precision instances * Add more flag checks * Add another flag to a client example * Add type traits, decouple f8/bf8 casting * Clean up * Decouple fp8 and bf8 flags * Remove more redundant flags * Remove leftover comments
This commit is contained in:
@@ -15,6 +15,12 @@ if (DTYPES)
|
||||
if (DTYPES MATCHES "fp8")
|
||||
add_definitions(-DCK_ENABLE_FP8)
|
||||
set(CK_ENABLE_FP8 "ON")
|
||||
add_compile_options(-Wno-bit-int-extension)
|
||||
endif()
|
||||
if (DTYPES MATCHES "bf8")
|
||||
add_definitions(-DCK_ENABLE_BF8)
|
||||
set(CK_ENABLE_BF8 "ON")
|
||||
add_compile_options(-Wno-bit-int-extension)
|
||||
endif()
|
||||
if (DTYPES MATCHES "fp16")
|
||||
add_definitions(-DCK_ENABLE_FP16)
|
||||
@@ -34,8 +40,9 @@ if (DTYPES)
|
||||
endif()
|
||||
message("DTYPES macro set to ${DTYPES}")
|
||||
else()
|
||||
add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
|
||||
add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
|
||||
set(CK_ENABLE_ALL_DTYPES "ON")
|
||||
add_compile_options(-Wno-bit-int-extension) # enable fp8 and bf8
|
||||
endif()
|
||||
|
||||
if(DL_KERNELS)
|
||||
@@ -365,6 +372,10 @@ IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu
|
||||
#message("fp8 instance found!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf8\" " AND DTYPES MATCHES "bf8")
|
||||
#message("bf8 instance found!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16")
|
||||
#message("fp16 instance found!")
|
||||
set(add_inst 1)
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp)
|
||||
target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_operations)
|
||||
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
|
||||
add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp)
|
||||
target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_operations)
|
||||
endif()
|
||||
|
||||
@@ -69,5 +69,7 @@ if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_example_executable(example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp)
|
||||
add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_f8)
|
||||
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp)
|
||||
add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_f8)
|
||||
endif()
|
||||
|
||||
@@ -43,6 +43,9 @@
|
||||
#ifndef CK_ENABLE_FP8
|
||||
#define CK_ENABLE_FP8 "ON"
|
||||
#endif
|
||||
#ifndef CK_ENABLE_BF8
|
||||
#define CK_ENABLE_BF8 "ON"
|
||||
#endif
|
||||
#ifndef CK_ENABLE_FP16
|
||||
#define CK_ENABLE_FP16 "ON"
|
||||
#endif
|
||||
@@ -66,6 +69,10 @@
|
||||
#cmakedefine CK_ENABLE_FP8 @CK_ENABLE_FP8@
|
||||
#endif
|
||||
|
||||
#ifndef CK_ENABLE_BF8
|
||||
#cmakedefine CK_ENABLE_BF8 @CK_ENABLE_BF8@
|
||||
#endif
|
||||
|
||||
#ifndef CK_ENABLE_FP16
|
||||
#cmakedefine CK_ENABLE_FP16 @CK_ENABLE_FP16@
|
||||
#endif
|
||||
|
||||
@@ -89,6 +89,7 @@ struct PassThrough
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined CK_ENABLE_FP8
|
||||
template <>
|
||||
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
|
||||
{
|
||||
@@ -118,6 +119,7 @@ struct PassThrough
|
||||
{
|
||||
y = type_convert<f8_t>(x);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
struct UnaryConvert
|
||||
@@ -146,6 +148,7 @@ struct ConvertBF16RTN
|
||||
}
|
||||
};
|
||||
|
||||
#if defined CK_ENABLE_FP8
|
||||
struct ConvertF8SR
|
||||
{
|
||||
// convert to fp8 using stochastic rounding (SR)
|
||||
@@ -162,6 +165,7 @@ struct ConvertF8SR
|
||||
y = f8_convert_sr<Y>(x);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
struct Scale
|
||||
{
|
||||
|
||||
@@ -456,6 +456,7 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
|
||||
}
|
||||
};
|
||||
|
||||
#if defined CK_ENABLE_FP8
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8>
|
||||
{
|
||||
@@ -499,6 +500,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
|
||||
intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
|
||||
struct MfmaSelector
|
||||
@@ -640,6 +642,7 @@ struct MfmaSelector
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined CK_ENABLE_FP8
|
||||
template <>
|
||||
static constexpr auto GetMfma<f8_t, 32, 32>()
|
||||
{
|
||||
@@ -651,6 +654,7 @@ struct MfmaSelector
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x32f8f8;
|
||||
}
|
||||
#endif
|
||||
|
||||
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
|
||||
|
||||
@@ -852,7 +856,11 @@ struct XdlopsGemm
|
||||
{
|
||||
static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
|
||||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
|
||||
is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value,
|
||||
is_same<base_type, int8_t>::value
|
||||
#if defined CK_ENABLE_FP8
|
||||
|| is_same<base_type, f8_t>::value
|
||||
#endif
|
||||
,
|
||||
"base base_type must be double, float, half, bfloat16, and int8_t!");
|
||||
|
||||
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
|
||||
|
||||
@@ -1127,7 +1127,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
|
||||
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;
|
||||
|
||||
#if defined CK_ENABLE_FP8
|
||||
if constexpr(is_same<scalar_t, f8_t>::value)
|
||||
{
|
||||
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
|
||||
@@ -1136,10 +1136,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
|
||||
}
|
||||
else
|
||||
{
|
||||
#endif
|
||||
return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
|
||||
#if defined CK_ENABLE_FP8
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
#if defined CK_ENABLE_FP8
|
||||
if constexpr(is_same<scalar_t, f8_t>::value)
|
||||
{
|
||||
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
|
||||
@@ -1148,11 +1152,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
|
||||
}
|
||||
else
|
||||
{
|
||||
#endif
|
||||
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
return src_thread_element_valid ? tmp : vector_t(0);
|
||||
#if defined CK_ENABLE_FP8
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
// buffer_load requires:
|
||||
@@ -1209,7 +1216,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
|
||||
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
|
||||
|
||||
#if defined CK_ENABLE_FP8
|
||||
if constexpr(is_same<scalar_t, f8_t>::value)
|
||||
{
|
||||
auto tmp =
|
||||
@@ -1219,12 +1226,16 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
|
||||
}
|
||||
else
|
||||
{
|
||||
#endif
|
||||
amd_buffer_store_impl<scalar_t, vector_size, coherence>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
|
||||
#if defined CK_ENABLE_FP8
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
if(dst_thread_element_valid)
|
||||
{
|
||||
#if defined CK_ENABLE_FP8
|
||||
if constexpr(is_same<scalar_t, f8_t>::value)
|
||||
{
|
||||
auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(
|
||||
@@ -1234,9 +1245,12 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
|
||||
}
|
||||
else
|
||||
{
|
||||
#endif
|
||||
amd_buffer_store_impl<scalar_t, vector_size, coherence>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
|
||||
#if defined CK_ENABLE_FP8
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -355,6 +355,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
|
||||
}
|
||||
};
|
||||
|
||||
#if defined CK_ENABLE_FP8
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x16f8f8;
|
||||
|
||||
@@ -417,5 +418,6 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
|
||||
#endif
|
||||
}
|
||||
};
|
||||
#endif
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -12,7 +12,12 @@ using half_t = _Float16;
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
using int4_t = _BitInt(4);
|
||||
#endif
|
||||
using f8_t = uint8_t;
|
||||
#if defined CK_ENABLE_FP8
|
||||
using f8_t = _BitInt(8);
|
||||
#endif
|
||||
#if defined CK_ENABLE_BF8
|
||||
using bf8_t = unsigned _BitInt(8);
|
||||
#endif
|
||||
|
||||
// vector_type
|
||||
template <typename T, index_t N>
|
||||
@@ -143,14 +148,24 @@ struct scalar_type<int4_t>
|
||||
};
|
||||
#endif
|
||||
|
||||
#if defined CK_ENABLE_FP8
|
||||
template <>
|
||||
struct scalar_type<f8_t>
|
||||
{
|
||||
using type = f8_t;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
#endif
|
||||
|
||||
#if defined CK_ENABLE_BF8
|
||||
template <>
|
||||
struct scalar_type<bf8_t>
|
||||
{
|
||||
using type = bf8_t;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
#endif
|
||||
|
||||
//
|
||||
template <typename T>
|
||||
struct vector_type<T, 1>
|
||||
{
|
||||
@@ -953,12 +968,24 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
|
||||
using int8x64_t = typename vector_type<int8_t, 64>::type;
|
||||
|
||||
// f8
|
||||
#if defined CK_ENABLE_FP8
|
||||
using f8x2_t = typename vector_type<f8_t, 2>::type;
|
||||
using f8x4_t = typename vector_type<f8_t, 4>::type;
|
||||
using f8x8_t = typename vector_type<f8_t, 8>::type;
|
||||
using f8x16_t = typename vector_type<f8_t, 16>::type;
|
||||
using f8x32_t = typename vector_type<f8_t, 32>::type;
|
||||
using f8x64_t = typename vector_type<f8_t, 64>::type;
|
||||
#endif
|
||||
|
||||
// bf8
|
||||
#if defined CK_ENABLE_BF8
|
||||
using bf8x2_t = typename vector_type<bf8_t, 2>::type;
|
||||
using bf8x4_t = typename vector_type<bf8_t, 4>::type;
|
||||
using bf8x8_t = typename vector_type<bf8_t, 8>::type;
|
||||
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
|
||||
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
|
||||
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
struct NumericLimits
|
||||
@@ -1006,21 +1033,109 @@ struct NumericLimits<int4_t>
|
||||
};
|
||||
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
|
||||
#if defined CK_ENABLE_FP8
|
||||
template <>
|
||||
struct NumericLimits<f8_t>
|
||||
{
|
||||
// negative zero nan mode with exp bias = 8
|
||||
static constexpr uint8_t binary_min = 0x08; // 0b00001000
|
||||
static constexpr uint8_t binary_max = 0x77; // 0b01110111
|
||||
static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
|
||||
static constexpr uint8_t binary_max = 0x7F; // 0b01111111
|
||||
static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
|
||||
static constexpr uint8_t binary_qnan = 0x80; // 0b10000000
|
||||
// ieee mode with exp bias = 7
|
||||
// static constexpr uint8_t binary_min = 0x08; // 0b00001000
|
||||
// static constexpr uint8_t binary_max = 0x77; // 0b01110111
|
||||
// static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
|
||||
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
|
||||
|
||||
__host__ __device__ static constexpr f8_t Min() { return bit_cast<f8_t>(binary_min); }
|
||||
__host__ __device__ static constexpr f8_t Min() { return f8_t(binary_min); }
|
||||
|
||||
__host__ __device__ static constexpr f8_t Max() { return bit_cast<f8_t>(binary_max); }
|
||||
__host__ __device__ static constexpr f8_t Max() { return f8_t(binary_max); }
|
||||
|
||||
__host__ __device__ static constexpr f8_t Lowest() { return bit_cast<f8_t>(binary_lowest); }
|
||||
__host__ __device__ static constexpr f8_t Lowest() { return f8_t(binary_lowest); }
|
||||
|
||||
__host__ __device__ static constexpr f8_t QuietNaN() { return bit_cast<f8_t>(binary_qnan); }
|
||||
__host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); }
|
||||
};
|
||||
#endif
|
||||
|
||||
#if defined CK_ENABLE_BF8
|
||||
template <>
|
||||
struct NumericLimits<bf8_t>
|
||||
{
|
||||
// negative zero nan mode with exp bias = 16
|
||||
static constexpr uint8_t binary_min = 0x04; // 0b00000100
|
||||
static constexpr uint8_t binary_max = 0x7F; // 0b01111111
|
||||
static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
|
||||
static constexpr uint8_t binary_qnan = 0x80; // 0b10000000
|
||||
// ieee mode with exp bias = 15
|
||||
// static constexpr uint8_t binary_min = 0x04; // 0b00000100
|
||||
// static constexpr uint8_t binary_max = 0x7B; // 0b01111011
|
||||
// static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
|
||||
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=
|
||||
|
||||
__host__ __device__ static constexpr bf8_t Min() { return bf8_t(binary_min); }
|
||||
|
||||
__host__ __device__ static constexpr bf8_t Max() { return bf8_t(binary_max); }
|
||||
|
||||
__host__ __device__ static constexpr bf8_t Lowest() { return bf8_t(binary_lowest); }
|
||||
|
||||
__host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); }
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
struct NumericUtils
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericUtils<float>
|
||||
{
|
||||
static constexpr int exp = 8;
|
||||
static constexpr int mant = 23;
|
||||
static constexpr uint32_t nan_mask = 0x7F800000;
|
||||
static constexpr uint32_t head_mask = 0xFF800000;
|
||||
static constexpr uint32_t mant_mask = 0x7FFFFF;
|
||||
static constexpr uint32_t exp_mask = 0xFF;
|
||||
static constexpr uint32_t Inf = 0x7F800000;
|
||||
static constexpr uint32_t NegInf = 0xFF800000;
|
||||
static constexpr uint32_t NaN = 0x7F800001;
|
||||
static constexpr uint32_t Neg0 = 0x80000000;
|
||||
using bitwise_type = uint32_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericUtils<half_t>
|
||||
{
|
||||
static constexpr int exp = 5;
|
||||
static constexpr int mant = 10;
|
||||
static constexpr uint16_t nan_mask = 0x7C00;
|
||||
static constexpr uint16_t head_mask = 0xFC00;
|
||||
static constexpr uint16_t mant_mask = 0x3FF;
|
||||
static constexpr uint16_t exp_mask = 0x1F;
|
||||
static constexpr uint32_t Inf = 0x7C00;
|
||||
static constexpr uint32_t NegInf = 0xFC00;
|
||||
static constexpr uint32_t NaN = 0x7C01;
|
||||
static constexpr uint32_t Neg0 = 0x8000;
|
||||
using bitwise_type = uint16_t;
|
||||
};
|
||||
|
||||
#if defined CK_ENABLE_FP8
|
||||
template <>
|
||||
struct NumericUtils<f8_t>
|
||||
{
|
||||
static constexpr int exp = 4;
|
||||
static constexpr int mant = 3;
|
||||
};
|
||||
#endif
|
||||
|
||||
#if defined CK_ENABLE_BF8
|
||||
template <>
|
||||
struct NumericUtils<bf8_t>
|
||||
{
|
||||
static constexpr int exp = 5;
|
||||
static constexpr int mant = 2;
|
||||
};
|
||||
#endif
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
|
||||
namespace ck {
|
||||
|
||||
// fp8 rounding modes
|
||||
@@ -22,53 +23,38 @@ namespace ck::utils {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T, bool negative_zero_nan, bool clip, bool stoch>
|
||||
__host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
|
||||
template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
|
||||
__host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
|
||||
{
|
||||
// check data type
|
||||
constexpr bool is_half = std::is_same<T, half_t>::value;
|
||||
constexpr bool is_float = std::is_same<T, float>::value;
|
||||
// fp8/bf8 exponent/mantissa layout
|
||||
constexpr int out_exp = NumericUtils<Y>::exp;
|
||||
constexpr int out_mant = NumericUtils<Y>::mant;
|
||||
|
||||
// fp8 exponent/mantissa layout
|
||||
constexpr int f8_exp = 4;
|
||||
constexpr int f8_mant = 3;
|
||||
|
||||
// resulting type exponent/mantissa layout
|
||||
constexpr int type_exp = is_half ? 5 : 8;
|
||||
constexpr int type_mant = is_half ? 10 : 23;
|
||||
// original type exponent/mantissa layout
|
||||
constexpr int in_exp = NumericUtils<X>::exp;
|
||||
constexpr int in_mant = NumericUtils<X>::mant;
|
||||
|
||||
int exponent;
|
||||
uint32_t head, mantissa, sign;
|
||||
// nan code is same for float and half
|
||||
constexpr uint8_t nan_code = 0x80;
|
||||
constexpr uint32_t nan_mask = is_half ? 0x7C00 : 0x7F800000;
|
||||
constexpr Y nan_code = 0x80;
|
||||
constexpr uint32_t nan_mask = NumericUtils<X>::nan_mask;
|
||||
|
||||
// convert to bitwise
|
||||
typedef typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type
|
||||
T_bitwise;
|
||||
using T_bitwise = typename NumericUtils<X>::bitwise_type;
|
||||
T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));
|
||||
|
||||
// unpack the input, depends on datatype
|
||||
if constexpr(is_float)
|
||||
{
|
||||
head = x_bitwise & 0xFF800000;
|
||||
mantissa = x_bitwise & 0x7FFFFF;
|
||||
exponent = (head >> type_mant) & 0xFF;
|
||||
sign = head >> (type_exp + type_mant);
|
||||
}
|
||||
else if constexpr(is_half)
|
||||
{
|
||||
head = x_bitwise & 0xFC00;
|
||||
mantissa = x_bitwise & 0x3FF;
|
||||
exponent = (head >> type_mant) & 0x1F;
|
||||
sign = head >> (type_exp + type_mant);
|
||||
}
|
||||
head = x_bitwise & NumericUtils<X>::head_mask;
|
||||
mantissa = x_bitwise & NumericUtils<X>::mant_mask;
|
||||
exponent = (head >> in_mant) & NumericUtils<X>::exp_mask;
|
||||
sign = head >> (in_exp + in_mant);
|
||||
|
||||
uint32_t signed_inf = (sign << (type_exp + type_mant)) + (((1 << type_exp) - 1) << type_mant);
|
||||
uint32_t drop_mask = (1 << (type_mant - f8_mant)) - 1;
|
||||
constexpr int max_exp = (1 << f8_exp) - (negative_zero_nan ? 1 : 2);
|
||||
uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant);
|
||||
uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1;
|
||||
constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2);
|
||||
constexpr int exp_low_cutoff =
|
||||
(1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
|
||||
(1 << (in_exp - 1)) - (1 << (out_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
|
||||
|
||||
if constexpr(negative_zero_nan)
|
||||
{
|
||||
@@ -81,22 +67,35 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
|
||||
return signed_inf + (mantissa != 0 ? 1 : 0);
|
||||
}
|
||||
|
||||
// if input is half and output is bf8
|
||||
if((NumericUtils<X>::mant == 10) && (NumericUtils<Y>::mant == 2) && negative_zero_nan &&
|
||||
exponent == 0)
|
||||
{
|
||||
exponent += 1;
|
||||
while(mantissa < (1 << in_mant))
|
||||
{
|
||||
mantissa <<= 1;
|
||||
exponent -= 1;
|
||||
}
|
||||
mantissa &= ~(1 << in_mant);
|
||||
}
|
||||
|
||||
// check if x is 0.0
|
||||
if(x_bitwise == 0)
|
||||
return 0;
|
||||
|
||||
exponent -= exp_low_cutoff - 1;
|
||||
if(exponent <= 0)
|
||||
drop_mask = (1 << (type_mant - f8_mant + 1 - exponent)) - 1;
|
||||
mantissa += 1 << type_mant;
|
||||
drop_mask = (1 << (in_mant - out_mant + 1 - exponent)) - 1;
|
||||
mantissa += 1 << in_mant;
|
||||
// apply random number if needed
|
||||
mantissa += (stoch ? rng : mantissa) & drop_mask;
|
||||
if(mantissa >= (2 << type_mant))
|
||||
if(mantissa >= (2 << in_mant))
|
||||
{
|
||||
mantissa >>= 1;
|
||||
exponent++;
|
||||
}
|
||||
mantissa >>= (type_mant - f8_mant);
|
||||
mantissa >>= (in_mant - out_mant);
|
||||
|
||||
// check negative exponent
|
||||
if(exponent <= 0)
|
||||
@@ -116,7 +115,7 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
|
||||
{
|
||||
if(clip)
|
||||
{
|
||||
mantissa = (1 << f8_mant) - 1;
|
||||
mantissa = (1 << out_mant) - 1;
|
||||
exponent = max_exp;
|
||||
}
|
||||
else
|
||||
@@ -127,124 +126,120 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
|
||||
|
||||
// check if x is 0.0 or -0.0
|
||||
if(exponent == 0 && mantissa == 0)
|
||||
return negative_zero_nan ? 0 : (sign << (f8_exp + f8_mant));
|
||||
mantissa &= (1 << f8_mant) - 1;
|
||||
return (sign << (f8_exp + f8_mant)) | (exponent << f8_mant) | mantissa;
|
||||
return negative_zero_nan ? 0 : (sign << (out_exp + out_mant));
|
||||
mantissa &= (1 << out_mant) - 1;
|
||||
return (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
|
||||
}
|
||||
|
||||
template <typename T, bool negative_zero_nan>
|
||||
__host__ __device__ T run_cast_from_f8(f8_t x)
|
||||
template <typename X, typename Y, bool negative_zero_nan>
|
||||
__host__ __device__ Y run_cast_from_f8(X x)
|
||||
{
|
||||
// check data type
|
||||
constexpr bool is_half = std::is_same<T, half_t>::value;
|
||||
constexpr bool is_float = std::is_same<T, float>::value;
|
||||
|
||||
// fp8 exponent/mantissa layout
|
||||
constexpr int f8_exp = 4;
|
||||
constexpr int f8_mant = 3;
|
||||
// fp8/bf8 exponent/mantissa layout
|
||||
constexpr int in_exp = NumericUtils<X>::exp;
|
||||
constexpr int in_mant = NumericUtils<X>::mant;
|
||||
|
||||
// resulting type exponent/mantissa layout
|
||||
constexpr int type_exp = is_half ? 5 : 8;
|
||||
constexpr int type_mant = is_half ? 10 : 23;
|
||||
constexpr int out_exp = NumericUtils<Y>::exp;
|
||||
constexpr int out_mant = NumericUtils<Y>::mant;
|
||||
|
||||
// prepare the codes
|
||||
constexpr uint8_t nan_code = 0x80;
|
||||
T fInf, fNegInf, fNaN, fNeg0;
|
||||
if constexpr(is_half)
|
||||
{
|
||||
constexpr uint16_t ihInf = 0x7C00;
|
||||
constexpr uint16_t ihNegInf = 0xFC00;
|
||||
constexpr uint16_t ihNaN = 0x7C01;
|
||||
constexpr uint16_t ihNeg0 = 0x8000;
|
||||
fInf = *(reinterpret_cast<const half_t*>(&ihInf));
|
||||
fNegInf = *(reinterpret_cast<const half_t*>(&ihNegInf));
|
||||
fNaN = *(reinterpret_cast<const half_t*>(&ihNaN));
|
||||
fNeg0 = *(reinterpret_cast<const half_t*>(&ihNeg0));
|
||||
}
|
||||
else if constexpr(is_float)
|
||||
{
|
||||
constexpr uint32_t ifInf = 0x7F800000;
|
||||
constexpr uint32_t ifNegInf = 0xFF800000;
|
||||
constexpr uint32_t ifNaN = 0x7F800001;
|
||||
constexpr uint32_t ifNeg0 = 0x80000000;
|
||||
fInf = *(reinterpret_cast<const float*>(&ifInf));
|
||||
fNegInf = *(reinterpret_cast<const float*>(&ifNegInf));
|
||||
fNaN = *(reinterpret_cast<const float*>(&ifNaN));
|
||||
fNeg0 = *(reinterpret_cast<const float*>(&ifNeg0));
|
||||
}
|
||||
constexpr X nan_code = 0x80;
|
||||
Y Inf, NegInf, NaN, Neg0;
|
||||
using T_bitwise = typename NumericUtils<Y>::bitwise_type;
|
||||
|
||||
constexpr T_bitwise Inf_bitwise = NumericUtils<Y>::Inf;
|
||||
constexpr T_bitwise NegInf_bitwise = NumericUtils<Y>::NegInf;
|
||||
constexpr T_bitwise NaN_bitwise = NumericUtils<Y>::NaN;
|
||||
constexpr T_bitwise Neg0_bitwise = NumericUtils<Y>::Neg0;
|
||||
|
||||
Inf = *(reinterpret_cast<const Y*>(&Inf_bitwise));
|
||||
NegInf = *(reinterpret_cast<const Y*>(&NegInf_bitwise));
|
||||
NaN = *(reinterpret_cast<const Y*>(&NaN_bitwise));
|
||||
Neg0 = *(reinterpret_cast<const Y*>(&Neg0_bitwise));
|
||||
|
||||
// check if x is 0.0
|
||||
if(x == 0)
|
||||
return static_cast<Y>(0);
|
||||
|
||||
// unpack the input
|
||||
uint32_t sign = x >> (f8_exp + f8_mant);
|
||||
uint32_t mantissa = x & ((1 << f8_mant) - 1);
|
||||
int exponent = (x & 0x7F) >> f8_mant;
|
||||
uint32_t sign = x >> (in_exp + in_mant);
|
||||
uint32_t mantissa = x & ((1 << in_mant) - 1);
|
||||
int exponent = (x & 0x7F) >> in_mant;
|
||||
|
||||
constexpr int exp_low_cutoff =
|
||||
(1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
|
||||
typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type retval;
|
||||
(1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
|
||||
T_bitwise retval;
|
||||
|
||||
if constexpr(negative_zero_nan)
|
||||
{
|
||||
if(x == nan_code)
|
||||
return fNaN;
|
||||
return NaN;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(x == nan_code)
|
||||
return fNeg0;
|
||||
if(exponent == ((1 << f8_exp) - 1))
|
||||
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
|
||||
return Neg0;
|
||||
if(exponent == ((1 << in_exp) - 1))
|
||||
return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
|
||||
}
|
||||
|
||||
if((NumericUtils<Y>::mant == 10) && (NumericUtils<X>::mant == 2) && !negative_zero_nan)
|
||||
{
|
||||
retval = x;
|
||||
retval <<= 8;
|
||||
return *(reinterpret_cast<const Y*>(&retval));
|
||||
}
|
||||
|
||||
// subnormal input
|
||||
if(exponent == 0)
|
||||
{
|
||||
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
|
||||
int sh = 1 + __builtin_clz(mantissa) - ((1 + type_exp + type_mant) - f8_mant);
|
||||
mantissa <<= sh;
|
||||
mantissa &= ((1 << f8_mant) - 1);
|
||||
exponent += 1 - sh;
|
||||
exponent++;
|
||||
while(mantissa < (1 << in_mant))
|
||||
{
|
||||
mantissa <<= 1;
|
||||
exponent--;
|
||||
}
|
||||
mantissa &= ((1 << in_mant) - 1);
|
||||
}
|
||||
exponent += exp_low_cutoff - 1;
|
||||
mantissa <<= type_mant - f8_mant;
|
||||
mantissa <<= out_mant - in_mant;
|
||||
|
||||
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
|
||||
if(exponent <= 0)
|
||||
{
|
||||
mantissa |= 1 << type_mant;
|
||||
mantissa |= 1 << out_mant;
|
||||
mantissa >>= 1 - exponent;
|
||||
exponent = 0;
|
||||
}
|
||||
|
||||
retval = (sign << (type_exp + type_mant)) | (exponent << type_mant) | mantissa;
|
||||
return *(reinterpret_cast<const T*>(&retval));
|
||||
retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
|
||||
return *(reinterpret_cast<const Y*>(&retval));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename T, bool negative_zero_nan, bool clip, bool stoch>
|
||||
__host__ __device__ f8_t cast_to_f8(T x, uint32_t rng)
|
||||
template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
|
||||
__host__ __device__ Y cast_to_f8(X x, uint32_t rng)
|
||||
{
|
||||
// check datatype
|
||||
constexpr bool is_half = std::is_same<T, half_t>::value;
|
||||
constexpr bool is_float = std::is_same<T, float>::value;
|
||||
static_assert(is_half || is_float, "Only half and float can be casted to f8.");
|
||||
// check datatypes
|
||||
constexpr bool is_half = std::is_same<X, half_t>::value;
|
||||
constexpr bool is_float = std::is_same<X, float>::value;
|
||||
static_assert(is_half || is_float, "Only half and float can be casted.");
|
||||
|
||||
return run_cast_to_f8<T, negative_zero_nan, clip, stoch>(x, rng);
|
||||
return run_cast_to_f8<X, Y, negative_zero_nan, clip, stoch>(x, rng);
|
||||
}
|
||||
|
||||
template <typename T, bool negative_zero_nan>
|
||||
__host__ __device__ T cast_from_f8(f8_t x)
|
||||
template <typename X, typename Y, bool negative_zero_nan>
|
||||
__host__ __device__ Y cast_from_f8(X x)
|
||||
{
|
||||
// check datatype
|
||||
constexpr bool is_half = std::is_same<T, half_t>::value;
|
||||
constexpr bool is_float = std::is_same<T, float>::value;
|
||||
constexpr bool is_half = std::is_same<Y, half_t>::value;
|
||||
constexpr bool is_float = std::is_same<Y, float>::value;
|
||||
static_assert(is_half || is_float, "only half and float are supported.");
|
||||
|
||||
// check if x is 0.0
|
||||
if(x == 0)
|
||||
return static_cast<T>(0);
|
||||
|
||||
return run_cast_from_f8<T, negative_zero_nan>(x);
|
||||
return run_cast_from_f8<X, Y, negative_zero_nan>(x);
|
||||
}
|
||||
|
||||
} // namespace ck::utils
|
||||
#endif
|
||||
|
||||
@@ -80,6 +80,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
|
||||
return type_convert<bhalf_t>(x_fp32);
|
||||
}
|
||||
|
||||
#if defined CK_ENABLE_FP8
|
||||
// convert fp32 to fp8
|
||||
template <>
|
||||
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
|
||||
@@ -88,8 +89,9 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
|
||||
constexpr uint32_t rng = 0;
|
||||
return utils::cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
return utils::
|
||||
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
|
||||
rng);
|
||||
}
|
||||
|
||||
// convert fp8 to fp32
|
||||
@@ -97,7 +99,7 @@ template <>
|
||||
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
|
||||
{
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return utils::cast_from_f8<float, negative_zero_nan>(x);
|
||||
return utils::cast_from_f8<f8_t, float, negative_zero_nan>(x);
|
||||
}
|
||||
|
||||
// convert fp16 to fp8
|
||||
@@ -108,8 +110,9 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
|
||||
constexpr uint32_t rng = 0;
|
||||
return utils::cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
return utils::
|
||||
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
}
|
||||
|
||||
// convert fp8 to fp16
|
||||
@@ -117,8 +120,53 @@ template <>
|
||||
inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
|
||||
{
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return utils::cast_from_f8<half_t, negative_zero_nan>(x);
|
||||
return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined CK_ENABLE_BF8
|
||||
// convert fp32 to bf8
|
||||
template <>
|
||||
inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
|
||||
{
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
|
||||
constexpr uint32_t rng = 0;
|
||||
return utils::
|
||||
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
}
|
||||
|
||||
// convert bf8 to fp32
|
||||
template <>
|
||||
inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
|
||||
{
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return utils::cast_from_f8<bf8_t, float, negative_zero_nan>(x);
|
||||
}
|
||||
|
||||
// convert fp16 to bf8
|
||||
template <>
|
||||
inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
|
||||
{
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
|
||||
constexpr uint32_t rng = 0;
|
||||
return utils::
|
||||
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
}
|
||||
|
||||
// convert bf8 to fp16
|
||||
template <>
|
||||
inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
|
||||
{
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Declare a template function for bf16 conversion using RTN
|
||||
template <typename Y, typename X>
|
||||
@@ -181,6 +229,7 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr Y f8_convert_sr(X x);
|
||||
|
||||
#if defined CK_ENABLE_FP8
|
||||
// convert fp32 to fp8 with stochastic rounding
|
||||
template <>
|
||||
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
|
||||
@@ -191,8 +240,9 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
|
||||
constexpr int seed = 42;
|
||||
// as thread id is not available on host, use 0 for prn generation
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
return utils::cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
return utils::
|
||||
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
|
||||
rng);
|
||||
}
|
||||
|
||||
// convert fp16 to fp8 with stochastic rounding
|
||||
@@ -205,8 +255,42 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
|
||||
constexpr int seed = 42;
|
||||
// as thread id is not available on host, use 0 for prn generation
|
||||
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
return utils::cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
return utils::
|
||||
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined CK_ENABLE_BF8
|
||||
// convert fp32 to bf8 with stochastic rounding
|
||||
template <>
|
||||
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
|
||||
{
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
|
||||
constexpr int seed = 42;
|
||||
// as thread id is not available on host, use 0 for prn generation
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
return utils::
|
||||
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
}
|
||||
|
||||
// convert fp16 to bf8 with stochastic rounding
|
||||
template <>
|
||||
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
|
||||
{
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
|
||||
constexpr int seed = 42;
|
||||
// as thread id is not available on host, use 0 for prn generation
|
||||
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
return utils::
|
||||
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -17,10 +17,15 @@ namespace instance {
|
||||
using F64 = double;
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using F8 = ck::f8_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
#if defined CK_ENABLE_FP8
|
||||
using F8 = ck::f8_t;
|
||||
#endif
|
||||
#if defined CK_ENABLE_BF8
|
||||
using BF8 = ck::bf8_t;
|
||||
#endif
|
||||
|
||||
using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
|
||||
@@ -45,6 +45,7 @@ void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_
|
||||
PassThrough,
|
||||
MultiplyAdd>>>&);
|
||||
|
||||
#if defined CK_ENABLE_FP8
|
||||
void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
|
||||
Row,
|
||||
@@ -70,6 +71,7 @@ void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_m
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAdd>>>&);
|
||||
#endif
|
||||
|
||||
// GEMM + Multiply + Add
|
||||
template <typename ALayout,
|
||||
@@ -131,6 +133,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
}
|
||||
}
|
||||
|
||||
#if defined CK_ENABLE_FP8
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
|
||||
is_same_v<D0DataType, float> && is_same_v<D1DataType, float> &&
|
||||
is_same_v<EDataType, half_t>)
|
||||
@@ -150,6 +153,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
|
||||
@@ -57,6 +57,7 @@ void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(
|
||||
DeviceGemmSplitK<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
#if defined CK_ENABLE_FP8
|
||||
void add_device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmSplitK<Col, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
@@ -96,6 +97,7 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmSplitK<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
@@ -176,6 +178,7 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#if defined CK_ENABLE_FP8
|
||||
else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<CDataType, half_t>)
|
||||
{
|
||||
@@ -224,6 +227,7 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
|
||||
@@ -230,5 +230,99 @@ check_err(const Range& out,
|
||||
return res;
|
||||
}
|
||||
|
||||
#if defined CK_ENABLE_FP8
|
||||
template <typename Range, typename RefRange>
|
||||
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
std::is_same_v<ranges::range_value_t<Range>, f8_t>),
|
||||
bool>
|
||||
check_err(const Range& out,
|
||||
const RefRange& ref,
|
||||
const std::string& msg = "Error: Incorrect results!",
|
||||
double rtol = 1e-3,
|
||||
double atol = 1e-3)
|
||||
{
|
||||
if(out.size() != ref.size())
|
||||
{
|
||||
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool res{true};
|
||||
int err_count = 0;
|
||||
double err = 0;
|
||||
double max_err = std::numeric_limits<float>::min();
|
||||
for(std::size_t i = 0; i < ref.size(); ++i)
|
||||
{
|
||||
const double o = type_convert<float>(*std::next(std::begin(out), i));
|
||||
const double r = type_convert<float>(*std::next(std::begin(ref), i));
|
||||
err = std::abs(o - r);
|
||||
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
|
||||
{
|
||||
max_err = err > max_err ? err : max_err;
|
||||
err_count++;
|
||||
if(err_count < 5)
|
||||
{
|
||||
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
|
||||
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
|
||||
}
|
||||
res = false;
|
||||
}
|
||||
}
|
||||
if(!res)
|
||||
{
|
||||
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined CK_ENABLE_BF8
|
||||
template <typename Range, typename RefRange>
|
||||
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
std::is_same_v<ranges::range_value_t<Range>, bf8_t>),
|
||||
bool>
|
||||
check_err(const Range& out,
|
||||
const RefRange& ref,
|
||||
const std::string& msg = "Error: Incorrect results!",
|
||||
double rtol = 1e-3,
|
||||
double atol = 1e-3)
|
||||
{
|
||||
if(out.size() != ref.size())
|
||||
{
|
||||
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool res{true};
|
||||
int err_count = 0;
|
||||
double err = 0;
|
||||
double max_err = std::numeric_limits<float>::min();
|
||||
for(std::size_t i = 0; i < ref.size(); ++i)
|
||||
{
|
||||
const double o = type_convert<float>(*std::next(std::begin(out), i));
|
||||
const double r = type_convert<float>(*std::next(std::begin(ref), i));
|
||||
err = std::abs(o - r);
|
||||
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
|
||||
{
|
||||
max_err = err > max_err ? err : max_err;
|
||||
err_count++;
|
||||
if(err_count < 5)
|
||||
{
|
||||
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
|
||||
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
|
||||
}
|
||||
res = false;
|
||||
}
|
||||
}
|
||||
if(!res)
|
||||
{
|
||||
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace utils
|
||||
} // namespace ck
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
add_instance_library(device_gemm_multiply_add_instance
|
||||
device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
|
||||
device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
|
||||
set(GEMM_MULTIPLY_ADD_INSTANCES)
|
||||
|
||||
device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp
|
||||
device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp
|
||||
)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp)
|
||||
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp)
|
||||
endif()
|
||||
|
||||
if((DTYPES MATCHES "fp16" AND DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES)
|
||||
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp)
|
||||
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp)
|
||||
endif()
|
||||
|
||||
add_instance_library(device_gemm_multiply_add_instance ${GEMM_MULTIPLY_ADD_INSTANCES})
|
||||
|
||||
@@ -14,7 +14,7 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp)
|
||||
endif()
|
||||
|
||||
if(DTYPES MATCHES "fp16" OR DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES)
|
||||
if((DTYPES MATCHES "fp16" AND DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES)
|
||||
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_mk_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instance.cpp)
|
||||
|
||||
@@ -214,6 +214,7 @@ bool profile_gemm_splitk_impl(int do_verification,
|
||||
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch "
|
||||
<< kbatch_curr << std::endl;
|
||||
|
||||
#if defined CK_ENABLE_FP8
|
||||
// set softer tolerances for fp8
|
||||
if constexpr(is_same_v<ADataType, f8_t> || is_same_v<BDataType, f8_t> ||
|
||||
is_same_v<CDataType, f8_t>)
|
||||
@@ -226,8 +227,11 @@ bool profile_gemm_splitk_impl(int do_verification,
|
||||
}
|
||||
else
|
||||
{
|
||||
#endif
|
||||
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
|
||||
#if defined CK_ENABLE_FP8
|
||||
}
|
||||
#endif
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
|
||||
@@ -59,9 +59,11 @@ int profile_gemm_multiply_add(int argc, char* argv[])
|
||||
const int StrideD1 = std::stoi(argv[14]);
|
||||
const int StrideE = std::stoi(argv[15]);
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
#if defined CK_ENABLE_FP8
|
||||
using F8 = ck::f8_t;
|
||||
#endif
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
@@ -132,6 +134,7 @@ int profile_gemm_multiply_add(int argc, char* argv[])
|
||||
{
|
||||
return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Row{}, Col{}, Row{}, Row{}, Row{});
|
||||
}
|
||||
#if defined CK_ENABLE_FP8
|
||||
else if(data_type == MatrixDataType::F16_F8_F32_F32_F16 &&
|
||||
layout == MatrixLayout::MK_KN_MN_MN_MN)
|
||||
{
|
||||
@@ -142,6 +145,7 @@ int profile_gemm_multiply_add(int argc, char* argv[])
|
||||
{
|
||||
return profile(F16{}, F8{}, F32{}, F32{}, F32{}, F16{}, Row{}, Col{}, Row{}, Row{}, Row{});
|
||||
}
|
||||
#endif
|
||||
else
|
||||
{
|
||||
std::cout << "this data_type & layout is not implemented" << std::endl;
|
||||
|
||||
@@ -67,7 +67,9 @@ int profile_gemm_splitk(int argc, char* argv[])
|
||||
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using F8 = ck::f8_t;
|
||||
#if defined CK_ENABLE_FP8
|
||||
using F8 = ck::f8_t;
|
||||
#endif
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
@@ -146,6 +148,7 @@ int profile_gemm_splitk(int argc, char* argv[])
|
||||
{
|
||||
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
#if defined CK_ENABLE_FP8
|
||||
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{});
|
||||
@@ -178,6 +181,7 @@ int profile_gemm_splitk(int argc, char* argv[])
|
||||
{
|
||||
return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
#endif
|
||||
else
|
||||
{
|
||||
std::cout << "this data_type & layout is not implemented" << std::endl;
|
||||
|
||||
@@ -3,5 +3,12 @@ if (USE_BITINT_EXTENSION_INT4)
|
||||
target_link_libraries(test_int4 PRIVATE utility)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_fp8 fp8.cpp)
|
||||
target_link_libraries(test_fp8 PRIVATE utility)
|
||||
if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES)
|
||||
add_gtest_executable(test_f8 f8.cpp)
|
||||
target_link_libraries(test_f8 PRIVATE utility)
|
||||
endif()
|
||||
|
||||
if(DTYPES MATCHES "bf8" OR NOT DEFINED DTYPES)
|
||||
add_gtest_executable(test_bf8 bf8.cpp)
|
||||
target_link_libraries(test_bf8 PRIVATE utility)
|
||||
endif()
|
||||
|
||||
158
test/data_type/bf8.cpp
Normal file
158
test/data_type/bf8.cpp
Normal file
@@ -0,0 +1,158 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
|
||||
using ck::bf8_t;
|
||||
using ck::f8_convert_sr;
|
||||
using ck::half_t;
|
||||
using ck::type_convert;
|
||||
|
||||
TEST(BF8, NumericLimits)
|
||||
{
|
||||
// constants given for negative zero nan mode
|
||||
EXPECT_EQ(ck::NumericLimits<bf8_t>::Min(), type_convert<bf8_t>(0x04));
|
||||
EXPECT_EQ(ck::NumericLimits<bf8_t>::Max(), type_convert<bf8_t>(0x7F));
|
||||
EXPECT_EQ(ck::NumericLimits<bf8_t>::Lowest(), type_convert<bf8_t>(0xFF));
|
||||
EXPECT_EQ(ck::NumericLimits<bf8_t>::QuietNaN(), type_convert<bf8_t>(0x80));
|
||||
}
|
||||
|
||||
TEST(BF8, ConvertFP32Nearest)
|
||||
{
|
||||
// fix the tolerance value
|
||||
float abs_tol = 1e-6;
|
||||
// convert 0 float to bf8 and back, check if holds
|
||||
ASSERT_NEAR(0.0f, type_convert<float>(type_convert<bf8_t>(0.0f)), abs_tol);
|
||||
// convert minimal float to bf8 and back, check if holds
|
||||
ASSERT_NEAR(std::numeric_limits<float>::min(),
|
||||
type_convert<float>(type_convert<bf8_t>(std::numeric_limits<float>::min())),
|
||||
abs_tol);
|
||||
// convert maximal bf8_t to float and check if equal to 57344.0
|
||||
ASSERT_NEAR(57344.0f, type_convert<float>(type_convert<bf8_t>(57344.0f)), abs_tol);
|
||||
// convert maximal float to bf8 and back, check if clipped to 57344.0
|
||||
ASSERT_NEAR(57344.0f,
|
||||
type_convert<float>(type_convert<bf8_t>(std::numeric_limits<float>::max())),
|
||||
abs_tol);
|
||||
// convert inf float to bf8_t and check if it is qNan
|
||||
ASSERT_NEAR(type_convert<bf8_t>(0x80),
|
||||
type_convert<bf8_t>(std::numeric_limits<float>::infinity()),
|
||||
abs_tol);
|
||||
// positive norm float value to bf8 and back, check if holds
|
||||
float pos_float = 0.0000762939f;
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(type_convert<bf8_t>(pos_float)), abs_tol);
|
||||
// negative norm float value to bf8 and back, check if holds
|
||||
float neg_float = -0.0000610351f;
|
||||
ASSERT_NEAR(neg_float, type_convert<float>(type_convert<bf8_t>(neg_float)), abs_tol);
|
||||
// positive subnorm float value to bf8 and back, check if holds
|
||||
pos_float = 0.0000305175f;
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(type_convert<bf8_t>(pos_float)), abs_tol);
|
||||
// negative subnorm float value to bf8 and back, check if holds
|
||||
neg_float = -0.0000152587f;
|
||||
ASSERT_NEAR(neg_float, type_convert<float>(type_convert<bf8_t>(neg_float)), abs_tol);
|
||||
}
|
||||
|
||||
TEST(BF8, ConvertFP32Stochastic)
|
||||
{
|
||||
// fix the tolerance value
|
||||
float abs_tol = 1e-6;
|
||||
// convert 0 float to bf8 and back, check if holds
|
||||
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<bf8_t>(0.0f)), abs_tol);
|
||||
// convert minimal float to bf8 and back, check if holds
|
||||
ASSERT_NEAR(std::numeric_limits<float>::min(),
|
||||
type_convert<float>(f8_convert_sr<bf8_t>(std::numeric_limits<float>::min())),
|
||||
abs_tol);
|
||||
// convert maximal bf8_t to float and check if equal to 57344.0
|
||||
ASSERT_NEAR(57344.0f, type_convert<float>(f8_convert_sr<bf8_t>(57344.0f)), abs_tol);
|
||||
// convert maximal float to bf8 and back, check if clipped to 57344.0
|
||||
ASSERT_NEAR(57344.0f,
|
||||
type_convert<float>(f8_convert_sr<bf8_t>(std::numeric_limits<float>::max())),
|
||||
abs_tol);
|
||||
// convert inf float to bf8_t and check if it is qNan
|
||||
ASSERT_NEAR(type_convert<bf8_t>(0x80),
|
||||
f8_convert_sr<bf8_t>(std::numeric_limits<float>::infinity()),
|
||||
abs_tol);
|
||||
// positive norm float value to bf8 and back, check if holds
|
||||
float pos_float = 0.0000762939f;
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<bf8_t>(pos_float)), abs_tol);
|
||||
// negative norm float value to bf8 and back, check if holds
|
||||
float neg_float = -0.0000610351f;
|
||||
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<bf8_t>(neg_float)), abs_tol);
|
||||
// positive subnorm float value to bf8 and back, check if holds
|
||||
pos_float = 0.0000305175f;
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<bf8_t>(pos_float)), abs_tol);
|
||||
// negative subnorm float value to bf8 and back, check if holds
|
||||
neg_float = -0.0000152587f;
|
||||
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<bf8_t>(neg_float)), abs_tol);
|
||||
}
|
||||
|
||||
TEST(BF8, ConvertFP16Nearest)
|
||||
{
|
||||
// fix the tolerance value
|
||||
float abs_tol = 1e-3;
|
||||
// convert 0 fp16 to bf8 and back, check if holds
|
||||
ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(type_convert<bf8_t>(half_t{0.0})), abs_tol);
|
||||
// convert minimal fp16 to bf8 and back, check if holds
|
||||
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
|
||||
type_convert<half_t>(type_convert<bf8_t>(ck::NumericLimits<half_t>::Min())),
|
||||
abs_tol);
|
||||
// convert maximal bf8_t to fp16 and check if equal to 57344.0
|
||||
ASSERT_NEAR(
|
||||
half_t{57344.0}, type_convert<half_t>(type_convert<bf8_t>(half_t{57344.0})), abs_tol);
|
||||
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0
|
||||
ASSERT_NEAR(half_t{57344.0},
|
||||
type_convert<half_t>(type_convert<bf8_t>(ck::NumericLimits<half_t>::Max())),
|
||||
abs_tol);
|
||||
// convert QuietNaN fp16 to bf8_t and check if it is QuietNaN
|
||||
ASSERT_NEAR(type_convert<bf8_t>(0x80),
|
||||
type_convert<bf8_t>(ck::NumericLimits<half_t>::QuietNaN()),
|
||||
abs_tol);
|
||||
// positive norm fp16 value to bf8 and back, check if holds
|
||||
half_t pos_half = half_t{0.0000762939};
|
||||
ASSERT_NEAR(pos_half, type_convert<half_t>(type_convert<bf8_t>(pos_half)), abs_tol);
|
||||
// negative norm fp16 value to bf8 and back, check if holds
|
||||
half_t neg_half = half_t{-0.0000610351};
|
||||
ASSERT_NEAR(neg_half, type_convert<half_t>(type_convert<bf8_t>(neg_half)), abs_tol);
|
||||
// positive subnorm fp16 value to bf8 and back, check if holds
|
||||
pos_half = half_t{0.0000305175};
|
||||
ASSERT_NEAR(pos_half, type_convert<half_t>(type_convert<bf8_t>(pos_half)), abs_tol);
|
||||
// negative subnorm fp16 value to bf8 and back, check if holds
|
||||
neg_half = half_t{-0.0000152587};
|
||||
ASSERT_NEAR(neg_half, type_convert<half_t>(type_convert<bf8_t>(neg_half)), abs_tol);
|
||||
}
|
||||
|
||||
TEST(BF8, ConvertFP16Stochastic)
|
||||
{
|
||||
// fix the tolerance value
|
||||
float abs_tol = 1e-3;
|
||||
// convert 0 fp16 to bf8 and back, check if holds
|
||||
ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_sr<bf8_t>(half_t{0.0})), abs_tol);
|
||||
// convert minimal fp16 to bf8 and back, check if holds
|
||||
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
|
||||
type_convert<half_t>(f8_convert_sr<bf8_t>(ck::NumericLimits<half_t>::Min())),
|
||||
abs_tol);
|
||||
// convert maximal bf8_t to fp16 and check if equal to 57344.0
|
||||
ASSERT_NEAR(
|
||||
half_t{57344.0}, type_convert<half_t>(f8_convert_sr<bf8_t>(half_t{57344.0})), abs_tol);
|
||||
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0
|
||||
ASSERT_NEAR(half_t{57344.0},
|
||||
type_convert<half_t>(f8_convert_sr<bf8_t>(ck::NumericLimits<half_t>::Max())),
|
||||
abs_tol);
|
||||
// convert QuietNaN fp16 to bf8_t and check if it is QuietNaN
|
||||
ASSERT_NEAR(type_convert<bf8_t>(0x80),
|
||||
f8_convert_sr<bf8_t>(ck::NumericLimits<half_t>::QuietNaN()),
|
||||
abs_tol);
|
||||
// positive norm fp16 value to bf8 and back, check if holds
|
||||
half_t pos_half = half_t{0.0000762939};
|
||||
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<bf8_t>(pos_half)), abs_tol);
|
||||
// negative norm fp16 value to bf8 and back, check if holds
|
||||
half_t neg_half = half_t{-0.0000610351};
|
||||
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<bf8_t>(neg_half)), abs_tol);
|
||||
// positive subnorm fp16 value to bf8 and back, check if holds
|
||||
pos_half = half_t{0.0000305175};
|
||||
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<bf8_t>(pos_half)), abs_tol);
|
||||
// negative subnorm fp16 value to bf8 and back, check if holds
|
||||
neg_half = half_t{-0.0000152587};
|
||||
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<bf8_t>(neg_half)), abs_tol);
|
||||
}
|
||||
@@ -12,10 +12,11 @@ using ck::type_convert;
|
||||
|
||||
TEST(FP8, NumericLimits)
|
||||
{
|
||||
EXPECT_EQ(ck::NumericLimits<f8_t>::Min(), 0x08);
|
||||
EXPECT_EQ(ck::NumericLimits<f8_t>::Max(), 0x77);
|
||||
EXPECT_EQ(ck::NumericLimits<f8_t>::Lowest(), 0xF7);
|
||||
EXPECT_EQ(ck::NumericLimits<f8_t>::QuietNaN(), 0x80);
|
||||
// constants given for negative zero nan mode
|
||||
EXPECT_EQ(ck::NumericLimits<f8_t>::Min(), type_convert<f8_t>(0x08));
|
||||
EXPECT_EQ(ck::NumericLimits<f8_t>::Max(), type_convert<f8_t>(0x7F));
|
||||
EXPECT_EQ(ck::NumericLimits<f8_t>::Lowest(), type_convert<f8_t>(0xFF));
|
||||
EXPECT_EQ(ck::NumericLimits<f8_t>::QuietNaN(), type_convert<f8_t>(0x80));
|
||||
}
|
||||
|
||||
TEST(FP8, ConvertFP32Nearest)
|
||||
@@ -35,12 +36,20 @@ TEST(FP8, ConvertFP32Nearest)
|
||||
type_convert<float>(type_convert<f8_t>(std::numeric_limits<float>::max())),
|
||||
abs_tol);
|
||||
// convert inf float to f8_t and check if it is qNan
|
||||
ASSERT_NEAR(0x80, type_convert<f8_t>(std::numeric_limits<float>::infinity()), abs_tol);
|
||||
// positive float value to fp8 and back, check if holds
|
||||
float pos_float = 0.0078125f;
|
||||
ASSERT_NEAR(type_convert<f8_t>(0x80),
|
||||
type_convert<f8_t>(std::numeric_limits<float>::infinity()),
|
||||
abs_tol);
|
||||
// positive norm float value to fp8 and back, check if holds
|
||||
float pos_float = 0.017578125f;
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(type_convert<f8_t>(pos_float)), abs_tol);
|
||||
// negative float value to fp8 and back, check if holds
|
||||
float neg_float = -0.0156250f;
|
||||
// negative norm float value to fp8 and back, check if holds
|
||||
float neg_float = -0.015625f;
|
||||
ASSERT_NEAR(neg_float, type_convert<float>(type_convert<f8_t>(neg_float)), abs_tol);
|
||||
// positive subnorm float value to fp8 and back, check if holds
|
||||
pos_float = 0.00390625f;
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(type_convert<f8_t>(pos_float)), abs_tol);
|
||||
// negative subnorm float value to fp8 and back, check if holds
|
||||
neg_float = -0.001953125f;
|
||||
ASSERT_NEAR(neg_float, type_convert<float>(type_convert<f8_t>(neg_float)), abs_tol);
|
||||
}
|
||||
|
||||
@@ -61,12 +70,20 @@ TEST(FP8, ConvertFP32Stochastic)
|
||||
type_convert<float>(f8_convert_sr<f8_t>(std::numeric_limits<float>::max())),
|
||||
abs_tol);
|
||||
// convert inf float to f8_t and check if it is qNan
|
||||
ASSERT_NEAR(0x80, f8_convert_sr<f8_t>(std::numeric_limits<float>::infinity()), abs_tol);
|
||||
// positive float value to fp8 and back, check if holds
|
||||
float pos_float = 0.0078125f;
|
||||
ASSERT_NEAR(type_convert<f8_t>(0x80),
|
||||
f8_convert_sr<f8_t>(std::numeric_limits<float>::infinity()),
|
||||
abs_tol);
|
||||
// positive norm float value to fp8 and back, check if holds
|
||||
float pos_float = 0.017578125f;
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_t>(pos_float)), abs_tol);
|
||||
// negative float value to fp8 and back, check if holds
|
||||
float neg_float = -0.0156250f;
|
||||
// negative norm float value to fp8 and back, check if holds
|
||||
float neg_float = -0.015625f;
|
||||
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_t>(neg_float)), abs_tol);
|
||||
// positive subnorm float value to fp8 and back, check if holds
|
||||
pos_float = 0.00390625f;
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_t>(pos_float)), abs_tol);
|
||||
// negative subnorm float value to fp8 and back, check if holds
|
||||
neg_float = -0.001953125f;
|
||||
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_t>(neg_float)), abs_tol);
|
||||
}
|
||||
|
||||
@@ -87,12 +104,20 @@ TEST(FP8, ConvertFP16Nearest)
|
||||
type_convert<half_t>(type_convert<f8_t>(ck::NumericLimits<half_t>::Max())),
|
||||
abs_tol);
|
||||
// convert QuietNaN fp16 to f8_t and check if it is QuietNaN
|
||||
ASSERT_NEAR(0x80, type_convert<f8_t>(ck::NumericLimits<half_t>::QuietNaN()), abs_tol);
|
||||
// positive fp16 value to fp8 and back, check if holds
|
||||
half_t pos_half = half_t{0.0078125};
|
||||
ASSERT_NEAR(type_convert<f8_t>(0x80),
|
||||
type_convert<f8_t>(ck::NumericLimits<half_t>::QuietNaN()),
|
||||
abs_tol);
|
||||
// positive norm fp16 value to fp8 and back, check if holds
|
||||
half_t pos_half = half_t{0.017578125};
|
||||
ASSERT_NEAR(pos_half, type_convert<half_t>(type_convert<f8_t>(pos_half)), abs_tol);
|
||||
// negative fp16 value to fp8 and back, check if holds
|
||||
half_t neg_half = half_t{-0.0156250};
|
||||
// negative norm fp16 value to fp8 and back, check if holds
|
||||
half_t neg_half = half_t{-0.015625};
|
||||
ASSERT_NEAR(neg_half, type_convert<half_t>(type_convert<f8_t>(neg_half)), abs_tol);
|
||||
// positive subnorm fp16 value to fp8 and back, check if holds
|
||||
pos_half = half_t{0.00390625};
|
||||
ASSERT_NEAR(pos_half, type_convert<half_t>(type_convert<f8_t>(pos_half)), abs_tol);
|
||||
// negative subnorm fp16 value to fp8 and back, check if holds
|
||||
neg_half = half_t{-0.001953125};
|
||||
ASSERT_NEAR(neg_half, type_convert<half_t>(type_convert<f8_t>(neg_half)), abs_tol);
|
||||
}
|
||||
|
||||
@@ -113,11 +138,19 @@ TEST(FP8, ConvertFP16Stochastic)
|
||||
type_convert<half_t>(f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::Max())),
|
||||
abs_tol);
|
||||
// convert QuietNaN fp16 to f8_t and check if it is QuietNaN
|
||||
ASSERT_NEAR(0x80, f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::QuietNaN()), abs_tol);
|
||||
// positive fp16 value to fp8 and back, check if holds
|
||||
half_t pos_half = half_t{0.0078125};
|
||||
ASSERT_NEAR(type_convert<f8_t>(0x80),
|
||||
f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::QuietNaN()),
|
||||
abs_tol);
|
||||
// positive norm fp16 value to fp8 and back, check if holds
|
||||
half_t pos_half = half_t{0.017578125};
|
||||
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<f8_t>(pos_half)), abs_tol);
|
||||
// negative fp16 value to fp8 and back, check if holds
|
||||
half_t neg_half = half_t{-0.0156250};
|
||||
// negative norm fp16 value to fp8 and back, check if holds
|
||||
half_t neg_half = half_t{-0.015625};
|
||||
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<f8_t>(neg_half)), abs_tol);
|
||||
// positive subnorm fp16 value to fp8 and back, check if holds
|
||||
pos_half = half_t{0.00390625};
|
||||
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<f8_t>(pos_half)), abs_tol);
|
||||
// negative subnorm fp16 value to fp8 and back, check if holds
|
||||
neg_half = half_t{-0.001953125};
|
||||
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<f8_t>(neg_half)), abs_tol);
|
||||
}
|
||||
Reference in New Issue
Block a user