mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
FP8 enablement - add a pseudorandom number generator, add conversion methods (#708)
* Add basic fp8 definitions and prn-generator * Format * Add fp8<->fp32 type_convert * Format * Split type_convert and cast_to/from_f8 * Format * Minor fix * Minor fix * Move fp8 utils to a separate header * Add elementwise ops * Add fp8_convert_sr * Format * Add element op * Eliminate magic numbers * Split f8_convert_sr in host and device * Format * Add some constexpr * Add a datatype test * Format * Another format * Add fp8<->fp16 tests * Update type_converts * Format * Add fp16 casting functions * Format * Use seed as a runtime arg * Use element location for PRNG * Format * Add fp8<->fp16 to PassThrough element op * Clean up * Merge host and device implementations * Add comments on rounding modes * Remove leftover code * Put type_converts into a separate header * Put random number gen to a separate header * Rearrange f8_utils' namespaces * Refactor type_convert.hpp * Move f8_t definition
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/math.hpp"
|
||||
#include "ck/utility/math_v2.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -81,6 +82,36 @@ struct PassThrough
|
||||
y = x;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<float, f8_t>(float& y, const f8_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<f8_t, float>(f8_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<f8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<half_t, f8_t>(half_t& y, const f8_t& x) const
|
||||
{
|
||||
y = type_convert<half_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<f8_t, half_t>(f8_t& y, const half_t& x) const
|
||||
{
|
||||
y = type_convert<f8_t>(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct UnaryConvert
|
||||
@@ -109,6 +140,23 @@ struct ConvertBF16RTN
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertF8SR
|
||||
{
|
||||
// convert to fp8 using stochastic rounding (SR)
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const
|
||||
{
|
||||
// check Y datatype
|
||||
static_assert(is_same<Y, f8_t>::value, "Data type is not supported by this operation!");
|
||||
|
||||
// check X datatype
|
||||
static_assert(is_same<X, float>::value || is_same<X, half_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = f8_convert_sr<Y>(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Scale
|
||||
{
|
||||
__host__ __device__ Scale(float scale) : scale_(scale) {}
|
||||
|
||||
Reference in New Issue
Block a user