mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 20:09:25 +00:00
FP8 enablement - add a pseudorandom number generator, add conversion methods (#708)
* Add basic fp8 definitions and prn-generator
* Format
* Add fp8<->fp32 type_convert
* Format
* Split type_convert and cast_to/from_f8
* Format
* Minor fix
* Minor fix
* Move fp8 utils to a separate header
* Add elementwise ops
* Add fp8_convert_sr
* Format
* Add element op
* Eliminate magic numbers
* Split f8_convert_sr in host and device
* Format
* Add some constexpr
* Add a datatype test
* Format
* Another format
* Add fp8<->fp16 tests
* Update type_converts
* Format
* Add fp16 casting functions
* Format
* Use seed as a runtime arg
* Use element location for PRNG
* Format
* Add fp8<->fp16 to PassThrough element op
* Clean up
* Merge host and device implementations
* Add comments on rounding modes
* Remove leftover code
* Put type_converts into a separate header
* Put random number gen to a separate header
* Rearrange f8_utils' namespaces
* Refactor type_convert.hpp
* Move f8_t definition
[ROCm/composable_kernel commit: f0c620c42e]
This commit is contained in:
@@ -2,3 +2,6 @@ if (USE_BITINT_EXTENSION_INT4)
|
||||
add_gtest_executable(test_int4 int4.cpp)
|
||||
target_link_libraries(test_int4 PRIVATE utility)
|
||||
endif()
|
||||
|
||||
# FP8 data-type unit test built from fp8.cpp and linked PRIVATE against `utility`.
# NOTE(review): add_gtest_executable is a project macro not visible here; presumably it
# creates and registers the gtest binary - confirm against its definition.
add_gtest_executable(test_fp8 fp8.cpp)
|
||||
target_link_libraries(test_fp8 PRIVATE utility)
|
||||
|
||||
123
test/data_type/fp8.cpp
Normal file
123
test/data_type/fp8.cpp
Normal file
@@ -0,0 +1,123 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
|
||||
using ck::f8_convert_sr;
|
||||
using ck::f8_t;
|
||||
using ck::half_t;
|
||||
using ck::type_convert;
|
||||
|
||||
// Pins the raw values reported by ck::NumericLimits<f8_t> for the smallest,
// largest, lowest and quiet-NaN entries.
TEST(FP8, NumericLimits)
{
    // local shorthand so each check reads as "limit == expected encoding"
    using f8_limits = ck::NumericLimits<f8_t>;

    EXPECT_EQ(f8_limits::Min(), 0x08);
    EXPECT_EQ(f8_limits::Max(), 0x77);
    EXPECT_EQ(f8_limits::Lowest(), 0xF7);
    EXPECT_EQ(f8_limits::QuietNaN(), 0x80);
}
|
||||
|
||||
// Round-trips float values through f8_t with type_convert and checks the value
// that comes back, plus the f8_t encoding produced for an infinite input.
TEST(FP8, ConvertFP32Nearest)
{
    // absolute tolerance used by every floating-point comparison below
    constexpr float abs_tol = 1e-6f;
    // value this test expects for the largest finite f8_t
    constexpr float f8_max = 240.0f;

    // convert 0 float to fp8 and back, check if holds
    ASSERT_NEAR(0.0f, type_convert<float>(type_convert<f8_t>(0.0f)), abs_tol);

    // convert minimal float to fp8 and back; the smallest normal float is far
    // below abs_tol, so this only requires the result to stay within abs_tol
    // of zero and cannot tell an exact round trip from a flush to zero
    constexpr float min_float = std::numeric_limits<float>::min();
    ASSERT_NEAR(min_float, type_convert<float>(type_convert<f8_t>(min_float)), abs_tol);

    // convert the maximal f8_t value to fp8 and back, check if it is preserved
    ASSERT_NEAR(f8_max, type_convert<float>(type_convert<f8_t>(f8_max)), abs_tol);

    // convert maximal float to fp8 and back, check if clipped to f8_max
    ASSERT_NEAR(f8_max,
                type_convert<float>(type_convert<f8_t>(std::numeric_limits<float>::max())),
                abs_tol);

    // convert inf float to f8_t and check if it is 0x80, the QuietNaN value
    // pinned by the NumericLimits test; this is an encoding, so it is compared
    // exactly instead of with a floating-point tolerance
    ASSERT_EQ(0x80, type_convert<f8_t>(std::numeric_limits<float>::infinity()));

    // positive float value (2^-7) to fp8 and back, check if holds
    const float pos_float = 0.0078125f;
    ASSERT_NEAR(pos_float, type_convert<float>(type_convert<f8_t>(pos_float)), abs_tol);

    // negative float value (-2^-6) to fp8 and back, check if holds
    const float neg_float = -0.0156250f;
    ASSERT_NEAR(neg_float, type_convert<float>(type_convert<f8_t>(neg_float)), abs_tol);
}
|
||||
|
||||
// Same checks as the nearest-rounding float test, but the float -> f8_t step
// goes through f8_convert_sr; the conversion back to float uses type_convert.
TEST(FP8, ConvertFP32Stochastic)
{
    // absolute tolerance used by every floating-point comparison below
    constexpr float abs_tol = 1e-6f;
    // value this test expects for the largest finite f8_t
    constexpr float f8_max = 240.0f;

    // convert 0 float to fp8 and back, check if holds
    ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<f8_t>(0.0f)), abs_tol);

    // convert minimal float to fp8 and back; the smallest normal float is far
    // below abs_tol, so this only requires the result to stay within abs_tol
    // of zero and cannot tell an exact round trip from a flush to zero
    constexpr float min_float = std::numeric_limits<float>::min();
    ASSERT_NEAR(min_float, type_convert<float>(f8_convert_sr<f8_t>(min_float)), abs_tol);

    // convert the maximal f8_t value to fp8 and back, check if it is preserved
    ASSERT_NEAR(f8_max, type_convert<float>(f8_convert_sr<f8_t>(f8_max)), abs_tol);

    // convert maximal float to fp8 and back, check if clipped to f8_max
    ASSERT_NEAR(f8_max,
                type_convert<float>(f8_convert_sr<f8_t>(std::numeric_limits<float>::max())),
                abs_tol);

    // convert inf float to f8_t and check if it is 0x80, the QuietNaN value
    // pinned by the NumericLimits test; this is an encoding, so it is compared
    // exactly instead of with a floating-point tolerance
    ASSERT_EQ(0x80, f8_convert_sr<f8_t>(std::numeric_limits<float>::infinity()));

    // positive float value (2^-7) to fp8 and back, check if holds
    const float pos_float = 0.0078125f;
    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_t>(pos_float)), abs_tol);

    // negative float value (-2^-6) to fp8 and back, check if holds
    const float neg_float = -0.0156250f;
    ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_t>(neg_float)), abs_tol);
}
|
||||
|
||||
// Round-trips half_t values through f8_t with type_convert and checks the
// value that comes back, plus the f8_t encoding produced for a NaN input.
TEST(FP8, ConvertFP16Nearest)
{
    // absolute tolerance used by every floating-point comparison below
    constexpr float abs_tol = 1e-3f;
    // value this test expects for the largest finite f8_t, as fp16
    const half_t f8_max = half_t{240.0};

    // convert 0 fp16 to fp8 and back, check if holds
    ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(type_convert<f8_t>(half_t{0.0})), abs_tol);

    // convert minimal fp16 to fp8 and back, check if within abs_tol
    // NOTE(review): if NumericLimits<half_t>::Min() is smaller than abs_tol,
    // this check also passes when the value flushes to zero - confirm intent
    const half_t min_half = ck::NumericLimits<half_t>::Min();
    ASSERT_NEAR(min_half, type_convert<half_t>(type_convert<f8_t>(min_half)), abs_tol);

    // convert the maximal f8_t value to fp8 and back, check if it is preserved
    ASSERT_NEAR(f8_max, type_convert<half_t>(type_convert<f8_t>(f8_max)), abs_tol);

    // convert maximal fp16 to fp8 and back, check if clipped to f8_max
    ASSERT_NEAR(f8_max,
                type_convert<half_t>(type_convert<f8_t>(ck::NumericLimits<half_t>::Max())),
                abs_tol);

    // convert QuietNaN fp16 to f8_t and check if it is 0x80, the QuietNaN
    // value pinned by the NumericLimits test; this is an encoding, so it is
    // compared exactly instead of with a floating-point tolerance
    ASSERT_EQ(0x80, type_convert<f8_t>(ck::NumericLimits<half_t>::QuietNaN()));

    // positive fp16 value (2^-7) to fp8 and back, check if holds
    const half_t pos_half = half_t{0.0078125};
    ASSERT_NEAR(pos_half, type_convert<half_t>(type_convert<f8_t>(pos_half)), abs_tol);

    // negative fp16 value (-2^-6) to fp8 and back, check if holds
    const half_t neg_half = half_t{-0.0156250};
    ASSERT_NEAR(neg_half, type_convert<half_t>(type_convert<f8_t>(neg_half)), abs_tol);
}
|
||||
|
||||
// Same checks as the nearest-rounding fp16 test, but the half_t -> f8_t step
// goes through f8_convert_sr; the conversion back to half_t uses type_convert.
TEST(FP8, ConvertFP16Stochastic)
{
    // absolute tolerance used by every floating-point comparison below
    constexpr float abs_tol = 1e-3f;
    // value this test expects for the largest finite f8_t, as fp16
    const half_t f8_max = half_t{240.0};

    // convert 0 fp16 to fp8 and back, check if holds
    ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_sr<f8_t>(half_t{0.0})), abs_tol);

    // convert minimal fp16 to fp8 and back, check if within abs_tol
    // NOTE(review): if NumericLimits<half_t>::Min() is smaller than abs_tol,
    // this check also passes when the value flushes to zero - confirm intent
    const half_t min_half = ck::NumericLimits<half_t>::Min();
    ASSERT_NEAR(min_half, type_convert<half_t>(f8_convert_sr<f8_t>(min_half)), abs_tol);

    // convert the maximal f8_t value to fp8 and back, check if it is preserved
    ASSERT_NEAR(f8_max, type_convert<half_t>(f8_convert_sr<f8_t>(f8_max)), abs_tol);

    // convert maximal fp16 to fp8 and back, check if clipped to f8_max
    ASSERT_NEAR(f8_max,
                type_convert<half_t>(f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::Max())),
                abs_tol);

    // convert QuietNaN fp16 to f8_t and check if it is 0x80, the QuietNaN
    // value pinned by the NumericLimits test; this is an encoding, so it is
    // compared exactly instead of with a floating-point tolerance
    ASSERT_EQ(0x80, f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::QuietNaN()));

    // positive fp16 value (2^-7) to fp8 and back, check if holds
    const half_t pos_half = half_t{0.0078125};
    ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<f8_t>(pos_half)), abs_tol);

    // negative fp16 value (-2^-6) to fp8 and back, check if holds
    const half_t neg_half = half_t{-0.0156250};
    ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<f8_t>(neg_half)), abs_tol);
}
|
||||
Reference in New Issue
Block a user