mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 13:29:20 +00:00
FP8 enablement - add a pseudorandom number generator, add conversion methods (#708)
* Add basic fp8 definitions and prn-generator
* Format
* Add fp8<->fp32 type_convert
* Format
* Split type_convert and cast_to/from_f8
* Format
* Minor fix
* Minor fix
* Move fp8 utils to a separate header
* Add elementwise ops
* Add fp8_convert_sr
* Format
* Add element op
* Eliminate magic numbers
* Split f8_convert_sr in host and device
* Format
* Add some constexpr
* Add a datatype test
* Format
* Another format
* Add fp8<->fp16 tests
* Update type_converts
* Format
* Add fp16 casting functions
* Format
* Use seed as a runtime arg
* Use element location for PRNG
* Format
* Add fp8<->fp16 to PassThrough element op
* Clean up
* Merge host and device implementations
* Add comments on rounding modes
* Remove leftover code
* Put type_converts into a separate header
* Put random number gen to a separate header
* Rearrange f8_utils' namespaces
* Refactor type_convert.hpp
* Move f8_t definition
[ROCm/composable_kernel commit: f0c620c42e]
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/math.hpp"
|
||||
#include "ck/utility/math_v2.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -81,6 +82,36 @@ struct PassThrough
|
||||
y = x;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<float, f8_t>(float& y, const f8_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<f8_t, float>(f8_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<f8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<half_t, f8_t>(half_t& y, const f8_t& x) const
|
||||
{
|
||||
y = type_convert<half_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<f8_t, half_t>(f8_t& y, const half_t& x) const
|
||||
{
|
||||
y = type_convert<f8_t>(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct UnaryConvert
|
||||
@@ -109,6 +140,23 @@ struct ConvertBF16RTN
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertF8SR
|
||||
{
|
||||
// convert to fp8 using stochastic rounding (SR)
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const
|
||||
{
|
||||
// check Y datatype
|
||||
static_assert(is_same<Y, f8_t>::value, "Data type is not supported by this operation!");
|
||||
|
||||
// check X datatype
|
||||
static_assert(is_same<X, float>::value || is_same<X, half_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = f8_convert_sr<Y>(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Scale
|
||||
{
|
||||
__host__ __device__ Scale(float scale) : scale_(scale) {}
|
||||
|
||||
Reference in New Issue
Block a user