Commit f69356b1d7 (parent e60c5aea4e)
Author: carlushuang
Date: 2024-02-28 22:57:19 +00:00
130 changed files with 28268 additions and 0 deletions

View File

@@ -0,0 +1,116 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <stdint.h>
#pragma once
#define CK_TILE_ARITHMETIC_USING_FLOAT(type_) \
CK_TILE_HOST_DEVICE \
bool operator==(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) == static_cast<float>(y); \
} \
CK_TILE_HOST_DEVICE \
bool operator!=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) != static_cast<float>(y); \
} \
CK_TILE_HOST_DEVICE \
bool operator<(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) < static_cast<float>(y); \
} \
CK_TILE_HOST_DEVICE \
bool operator<=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) <= static_cast<float>(y); \
} \
CK_TILE_HOST_DEVICE \
bool operator>(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) > static_cast<float>(y); \
} \
CK_TILE_HOST_DEVICE \
bool operator>=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) >= static_cast<float>(y); \
} \
CK_TILE_HOST_DEVICE \
type_ operator+(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) + static_cast<float>(y)); \
} \
CK_TILE_HOST_DEVICE \
type_ operator-(const type_& x) \
{ \
constexpr uint32_t bits = sizeof(type_) * 8; \
constexpr uint32_t mask = 1 << (bits - 1); \
type_ y = x; \
y.data ^= static_cast<typename type_::raw_type>(mask); \
return y; \
} \
CK_TILE_HOST_DEVICE \
type_ operator-(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) - static_cast<float>(y)); \
} \
CK_TILE_HOST_DEVICE \
type_ operator*(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) * static_cast<float>(y)); \
} \
CK_TILE_HOST_DEVICE \
type_ operator/(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) / static_cast<float>(y)); \
} \
CK_TILE_HOST_DEVICE \
type_& operator+=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) + static_cast<float>(y)); \
return x; \
} \
CK_TILE_HOST_DEVICE \
type_& operator-=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) - static_cast<float>(y)); \
return x; \
} \
CK_TILE_HOST_DEVICE \
type_& operator*=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) * static_cast<float>(y)); \
return x; \
} \
CK_TILE_HOST_DEVICE \
type_& operator/=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) / static_cast<float>(y)); \
return x; \
} \
CK_TILE_HOST_DEVICE \
type_& operator++(type_& x) \
{ \
x = type_(static_cast<float>(x) + 1.f); \
return x; \
} \
CK_TILE_HOST_DEVICE \
type_& operator--(type_& x) \
{ \
x = type_(static_cast<float>(x) - 1.f); \
return x; \
} \
CK_TILE_HOST_DEVICE \
type_ operator++(type_& x, int) \
{ \
type_ y(x); \
x = type_(static_cast<float>(x) + 1.f); \
return y; \
} \
CK_TILE_HOST_DEVICE \
type_ operator--(type_& x, int) \
{ \
type_ y(x); \
x = type_(static_cast<float>(x) - 1.f); \
return y; \
}
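The macro above gives any small float-like wrapper a complete operator set by round-tripping every operation through float; the only requirements on the type are the ones visible in the macro body: an explicit float constructor, an explicit conversion to float, and a raw_type/data pair used for the sign flip in unary minus. A minimal illustrative sketch with a made-up type (my_fp is not part of this commit):
// illustrative only: a made-up 16-bit wrapper satisfying the macro's requirements
struct alignas(2) my_fp
{
    using raw_type = uint16_t;
    raw_type data;
    my_fp() = default;
    // placeholder encoding/decoding, valid only for small non-negative values
    CK_TILE_HOST_DEVICE explicit my_fp(const float& f) { data = static_cast<raw_type>(f); }
    CK_TILE_HOST_DEVICE explicit operator float() const { return static_cast<float>(data); }
};
CK_TILE_ARITHMETIC_USING_FLOAT(my_fp)
// my_fp now has ==, !=, <, <=, >, >=, +, -, *, /, +=, -=, *=, /=, ++ and --,
// all computed in float; unary operator- just flips the top (sign) bit of my_fp::data.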

View File

@@ -0,0 +1,263 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/numeric/arithmetic.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include <stdint.h>
#pragma once
namespace ck_tile {
enum class bf16_rounding_mode
{
standard = 0, // rtn
truncate_with_nan,
truncate,
};
template <bf16_rounding_mode rounding = CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT>
CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding> = {});
CK_TILE_HOST_DEVICE
float bf16_to_float_raw(uint16_t x);
// HIP uses __hip_bfloat16, which is a struct
struct alignas(2) bfloat16_t
{
using raw_type = uint16_t;
raw_type data;
CK_TILE_HOST_DEVICE
static bfloat16_t bit_cast(raw_type x)
{
bfloat16_t y;
y.data = x;
return y;
}
// constructor
bfloat16_t() = default;
// construct from float
CK_TILE_HOST_DEVICE
explicit bfloat16_t(const float& x) { data = float_to_bf16_raw(x); }
// construct from int
CK_TILE_HOST_DEVICE
explicit bfloat16_t(const int& x) { data = float_to_bf16_raw(static_cast<float>(x)); }
// construct from unsigned int
CK_TILE_HOST_DEVICE
explicit bfloat16_t(const unsigned int& x) { data = float_to_bf16_raw(static_cast<float>(x)); }
// cast to float
CK_TILE_HOST_DEVICE
explicit operator float() const { return bf16_to_float_raw(data); }
// cast to int
CK_TILE_HOST_DEVICE
explicit operator int() const { return static_cast<int>(bf16_to_float_raw(data)); }
// internal access
CK_TILE_HOST_DEVICE
raw_type& get() { return data; }
CK_TILE_HOST_DEVICE
raw_type get() const { return data; }
};
// round to nearest
CK_TILE_HOST_DEVICE
uint16_t float_to_bf16_rtn_raw(float f)
{
union
{
float fp32;
uint32_t int32;
} u = {f};
if(~u.int32 & 0x7f800000)
{
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// least significant bits of the float mantissa are greater than 0x8000,
// or if they are equal to 0x8000 and the least significant bit of the
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// the exponent is incremented by one, which is the next higher FP value
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even
}
else if(u.int32 & 0xffff)
{
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// bit being 1. Signaling NaN is indicated by the most significant
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bfloat16's mantissa bits are all 0.
u.int32 |= 0x10000; // Preserve signaling NaN
}
return uint16_t(u.int32 >> 16);
}
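// Worked example for the round-to-nearest path above: f = 1.00390625f has bits 0x3F808000.
// The dropped low 16 bits are exactly 0x8000 and the would-be bf16 LSB (bit 16) is 0, so the
// add of 0x7fff + 0 does not carry and the value rounds down to 0x3F80 (1.0, tie to even).
// With low bits 0x8001 instead, the add carries into bit 16 and the result becomes 0x3F81.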
// Truncate instead of rounding, preserving SNaN
CK_TILE_HOST_DEVICE
uint16_t float_to_bf16_truc_nan_raw(float f)
{
union
{
float fp32;
uint32_t int32;
} u = {f};
return uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff));
}
// Fast truncate instead of rounding, RTZ
CK_TILE_HOST_DEVICE
uint16_t float_to_bf16_truc_raw(float f)
{
union
{
float fp32;
uint32_t int32;
} u = {f};
return uint16_t(u.int32 >> 16);
}
template <bf16_rounding_mode rounding>
CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding>)
{
if constexpr(rounding == bf16_rounding_mode::standard)
return float_to_bf16_rtn_raw(f);
else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
return float_to_bf16_truc_nan_raw(f);
else
return float_to_bf16_truc_raw(f);
}
CK_TILE_HOST_DEVICE
float bf16_to_float_raw(uint16_t x)
{
union
{
uint32_t int32;
float fp32;
} u = {uint32_t(x) << 16};
return u.fp32;
}
template <bf16_rounding_mode rounding = CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT>
CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant<rounding> = {})
{
return bfloat16_t::bit_cast(float_to_bf16_raw(f, constant<rounding>{}));
}
CK_TILE_HOST_DEVICE
float bf16_to_float(bfloat16_t x) { return static_cast<float>(x); }
template <bf16_rounding_mode rounding = CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT>
CK_TILE_HOST_DEVICE bfloat16_t fp16_to_bf16(half_t f, constant<rounding> = {})
{
return bfloat16_t::bit_cast(float_to_bf16_raw(static_cast<float>(f), constant<rounding>{}));
}
CK_TILE_HOST_DEVICE
half_t bf16_to_fp16(bfloat16_t x) { return float_to_fp16(static_cast<float>(x)); }
template <class T>
struct numeric_limits;
template <>
struct numeric_limits<bfloat16_t>
{
// minimum finite value, or minimum positive normalized value for float
CK_TILE_HOST_DEVICE static constexpr bfloat16_t min() { return bfloat16_t::bit_cast(0x0080); }
// minimum finite value
CK_TILE_HOST_DEVICE static constexpr bfloat16_t lowest()
{
return bfloat16_t::bit_cast(0xff7f);
}
// maximum finite value
CK_TILE_HOST_DEVICE static constexpr bfloat16_t max() { return bfloat16_t::bit_cast(0x7f7f); }
// difference between 1.0 and next value representable by float
CK_TILE_HOST_DEVICE static constexpr bfloat16_t epsilon()
{
return bfloat16_t::bit_cast(0x1000);
}
// maximum rounding error
CK_TILE_HOST_DEVICE static constexpr bfloat16_t round_error() { return bfloat16_t(0.5f); }
// positive infinity value
CK_TILE_HOST_DEVICE static constexpr bfloat16_t infinity()
{
return bfloat16_t::bit_cast(0x7f80);
}
// quiet NaN
CK_TILE_HOST_DEVICE static constexpr bfloat16_t quiet_NaN()
{
return bfloat16_t::bit_cast(0x7FFF);
}
// signaling NaN
CK_TILE_HOST_DEVICE static constexpr bfloat16_t signaling_NaN()
{
return bfloat16_t::bit_cast(0x7FFF);
}
// smallest positive subnormal value
CK_TILE_HOST_DEVICE static constexpr bfloat16_t denorm_min()
{
return bfloat16_t::bit_cast(0x0001);
}
};
CK_TILE_ARITHMETIC_USING_FLOAT(bfloat16_t)
// math
CK_TILE_HOST_DEVICE
bfloat16_t abs(const bfloat16_t& x) { return bfloat16_t::bit_cast(x.get() & 0x7fff); }
CK_TILE_HOST_DEVICE
bool isnan(const bfloat16_t& x)
{
uint16_t xx = x.get();
return (xx & 0x7FFF) > 0x7F80;
}
CK_TILE_DEVICE
bfloat16_t sqrt(bfloat16_t x)
{
return static_cast<bfloat16_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
};
CK_TILE_DEVICE
bfloat16_t exp(bfloat16_t x) { return static_cast<bfloat16_t>(__expf(static_cast<float>(x))); };
CK_TILE_DEVICE
bfloat16_t exp2(bfloat16_t x) { return static_cast<bfloat16_t>(exp2f(static_cast<float>(x))); };
CK_TILE_DEVICE
bfloat16_t log(bfloat16_t x) { return static_cast<bfloat16_t>(__logf(static_cast<float>(x))); };
using bf16_t = bfloat16_t;
} // namespace ck_tile
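The conversion entry points take the rounding mode as a compile-time constant, so each call site can choose round-to-nearest-even, truncation with NaN preservation, or plain truncation. A small illustrative usage sketch (not part of the commit):
#include "ck_tile/core/numeric/bfloat16.hpp"
// illustrative sketch
void bf16_examples()
{
    using namespace ck_tile;
    bfloat16_t a = float_to_bf16(1.5f); // default rounding (CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)
    bfloat16_t b = float_to_bf16(1.5f, constant<bf16_rounding_mode::truncate>{}); // explicit RTZ
    float back   = bf16_to_float(a);    // widen again; 1.5f is exactly representable
    bfloat16_t c = a + b;               // routed through float by CK_TILE_ARITHMETIC_USING_FLOAT
    (void)back; (void)c;
}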

View File

@@ -0,0 +1,735 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/limits.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/numeric/arithmetic.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include <stdint.h>
#include <type_traits>
#pragma once
namespace ck_tile {
// fp8 rounding modes
// use standard for rounding to nearest, the faster one
// use stochastic for stochastic rounding, helps to avoid error accumulation
enum class fp8_rounding_mode
{
standard = 0,
stochastic
};
/*
 *             ________________NANOO________________ | ________________IEEE_________________
 *             e4m3                 e5m2              | e4m3                e5m2
 * bias      : 8                    16                | 7                   15
 * inf       : 1.0000.000           1.00000.00        | N/A                 s.11111.00
 * NaN       : 1.0000.000           1.00000.00        | s.1111.111          s.11111.{01, 10, 11}
 * zero      : 0.0000.000           0.00000.00        | s.0000.000          s.00000.00
 * Max(norm) : s.1111.111 (240)     s.11111.11 (57344)| s.1111.110 (448)    s.11110.11 (57344)
 * Max(snorm): s.0000.111           s.00000.11        | s.0000.111          s.00000.11
 *             0.0068359375         2.288818e-05      | 0.013671875         4.57763671875e-05
 * Min(norm) : s.0001.000           s.00001.00        | s.0001.000          s.00001.00
 *             2^-7 (0.00078125)    2^-15 (3.05176e-05)| 2^-6 (0.015625)    2^-14 (6.10352e-05)
 * Min(snorm): s.0000.001           s.00000.01        | s.0000.001          s.00000.01
 *             2^-10 (0.00097656)   2^-17 (7.629395e-06)| 2^-9 (0.001953125) 2^-16 (1.52588e-05)
 */
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float, constant<rounding> = {});
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float, constant<rounding> = {});
CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t);
CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t);
struct alignas(1) float8_e4m3_t
{
static constexpr int exponent = 4;
static constexpr int mantissa = 3;
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
static constexpr int bias = 1 << (exponent - 1); // NANOO
#else
static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
#endif
using raw_type = uint8_t;
raw_type data;
CK_TILE_HOST_DEVICE
static float8_e4m3_t bit_cast(raw_type x)
{
float8_e4m3_t y;
y.data = x;
return y;
}
// constructor
float8_e4m3_t() = default;
// construct from float
CK_TILE_HOST_DEVICE
explicit float8_e4m3_t(const float& x) { data = float_to_fp8_raw(x); }
// construct from int
CK_TILE_HOST_DEVICE
explicit float8_e4m3_t(const int& x) { data = float_to_fp8_raw(static_cast<float>(x)); }
// construct from unsigned int
CK_TILE_HOST_DEVICE
explicit float8_e4m3_t(const unsigned int& x)
{
data = float_to_fp8_raw(static_cast<float>(x));
}
// cast to float
CK_TILE_HOST_DEVICE
explicit operator float() const { return fp8_to_float_raw(data); }
// cast to int
CK_TILE_HOST_DEVICE
explicit operator int() const { return static_cast<int>(fp8_to_float_raw(data)); }
// internal access
CK_TILE_HOST_DEVICE
raw_type& get() { return data; }
CK_TILE_HOST_DEVICE
raw_type get() const { return data; }
};
struct alignas(1) float8_e5m2_t
{
static constexpr int exponent = 5;
static constexpr int mantissa = 2;
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
static constexpr int bias = 1 << (exponent - 1); // NANOO
#else
static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
#endif
using raw_type = uint8_t;
raw_type data;
CK_TILE_HOST_DEVICE
static float8_e5m2_t bit_cast(raw_type x)
{
float8_e5m2_t y;
y.data = x;
return y;
}
// constructor
float8_e5m2_t() = default;
// construct from float
CK_TILE_HOST_DEVICE
explicit float8_e5m2_t(const float& x) { data = float_to_bf8_raw(x); }
// construct from int
CK_TILE_HOST_DEVICE
explicit float8_e5m2_t(const int& x) { data = float_to_bf8_raw(static_cast<float>(x)); }
// construct from unsigned int
CK_TILE_HOST_DEVICE
explicit float8_e5m2_t(const unsigned int& x)
{
data = float_to_bf8_raw(static_cast<float>(x));
}
// cast to float
CK_TILE_HOST_DEVICE
explicit operator float() const { return bf8_to_float_raw(data); }
// cast to int
CK_TILE_HOST_DEVICE
explicit operator int() const { return static_cast<int>(bf8_to_float_raw(data)); }
// internal access
CK_TILE_HOST_DEVICE
raw_type& get() { return data; }
CK_TILE_HOST_DEVICE
raw_type get() const { return data; }
};
using fp8_t = float8_e4m3_t;
using bf8_t = float8_e5m2_t;
// below is the sw fp8 conversion, not utilizing hw instructions
namespace impl {
template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng)
{
// fp8/bf8 exponent/mantissa layout
constexpr int out_exp = numeric_utils<Y>::exp;
constexpr int out_mant = numeric_utils<Y>::mant;
// original type exponent/mantissa layout
constexpr int in_exp = numeric_utils<X>::exp;
constexpr int in_mant = numeric_utils<X>::mant;
int exponent, bias;
uint32_t head, mantissa, sign;
// nan code is same for float and half
constexpr Y nan_code = 0x80;
constexpr uint32_t nan_mask = numeric_utils<X>::nan_mask;
// convert to bitwise
using T_bitwise = typename numeric_utils<X>::bitwise_type;
T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));
// unpack the input, depends on datatype
head = x_bitwise & numeric_utils<X>::head_mask;
mantissa = x_bitwise & numeric_utils<X>::mant_mask;
exponent = (head >> in_mant) & numeric_utils<X>::exp_mask;
sign = head >> (in_exp + in_mant);
bias = numeric_utils<X>::bias;
uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant);
uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1;
constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2);
if constexpr(negative_zero_nan)
{
if((x_bitwise & nan_mask) == nan_mask)
return nan_code;
}
else
{
if((x_bitwise & nan_mask) == nan_mask)
return signed_inf + (mantissa != 0 ? 1 : 0);
}
// check if x is 0.0
if(x_bitwise == 0)
return 0;
// First we need to check if the value is normal or denormal, as there is a difference of the
// implicit 1. Then the exponent needs to be adjusted to align with the F8 exponent while the
// mantissa is shifted accordingly. For stochastic rounding, rng is added to the mantissa before
// truncation; for RNE nothing is added. Finally we may need to check whether there is a carry
// and adjust the exponent and mantissa again.
// For IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of exponent bits
const int out_bias = (1 << (out_exp - 1)) - 1 + (negative_zero_nan ? 1 : 0);
const int out_denormal_act_exponent = 1 - out_bias; // actual exponent of f8 denormal
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// out_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
// the difference needs to be adjusted and mantissa shifted
int act_exponent, out_exponent, exponent_diff;
if(exponent == 0)
{ // fp32/fp16 is in denormal.
/* fp32 denormals are below 2^-127, so they are usually not a concern here; we are mostly
concerned with fp16. In this case, f8 is usually denormal, but there can be exceptions. fp16
denormals have exponent bias 15 while bf8 with NANOO has exponent bias 16, which means some
fp16 denormals are actually bf8 (NANOO) normals - the smallest bf8 (NANOO) normal is 2^-15.
fp16 numbers where exponent==0 (actual exponent -14) and the highest mantissa bit is 1 are
bf8 (NANOO) normals. In this case, the fp16 mantissa should be shifted left by 1 */
act_exponent = exponent - bias + 1;
exponent_diff = out_denormal_act_exponent -
act_exponent; // actual exponent is exponent-bias+1 as it is denormal
}
else
{ // fp32/fp16 is normal with implicit 1
act_exponent = exponent - bias;
if(act_exponent <= out_denormal_act_exponent)
{
/* This is the case where fp32/fp16 is normal but falls in the f8 denormal range.
For example, in fp8 NANOO mode the denormal exponent is -7, but if the fp32/fp16
actual exponent is -7, the value is actually larger due to the implicit 1, so it
needs to be adjusted to -6 and the mantissa shifted right by 1. Thus for fp32/fp16,
exponent -8 is the cut point to convert to fp8 NANOO */
exponent_diff = out_denormal_act_exponent - act_exponent;
}
else
{ // both fp32/fp16 and f8 are in normal range
exponent_diff = 0; // exponent_diff=0 does not mean there is no difference for this case;
// act_exponent could be larger, it just does not require a mantissa shift
}
mantissa += (1 << in_mant); // Add the implicit 1 into mantissa
}
bool midpoint = (mantissa & ((1 << (in_mant - out_mant + exponent_diff)) - 1)) ==
(1 << (in_mant - out_mant + exponent_diff - 1));
/* This part is a bit tricky. The judgment of whether the value is a tie needs to be done before
we shift right, as shifting right could drop some residual bits and make a value that is not a
midpoint look like a midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010) is larger
than the midpoint, but after a right shift of 4 bits it would look like one. */
if(exponent_diff > 0)
mantissa >>= exponent_diff;
else if(exponent_diff == -1)
mantissa <<= -exponent_diff;
bool implicit_one = mantissa & (1 << in_mant);
// if there is no implicit 1, it means the f8 is denormal and needs to be adjusted to the denorm exponent
out_exponent =
(act_exponent + exponent_diff) /*actual f8 exponent*/ + out_bias - (implicit_one ? 0 : 1);
// Now we have the exponent and mantissa adjusted
bool odd =
mantissa &
(1 << (in_mant - out_mant)); // if the least significant bit that is not truncated is 1
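// For stochastic rounding the dropped bits are perturbed by rng before truncation. Otherwise,
// adding the dropped bits back onto the mantissa carries into the kept LSB exactly when they
// exceed half of the drop range (round to nearest); at the midpoint, adding either half (kept
// LSB odd) or half-1 (kept LSB even) makes ties round to even.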
mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;
// Now we deal with overflow
if(out_exponent == 0)
{
if((1 << in_mant) & mantissa)
{
out_exponent = 1; // denormal overflow to become normal, promote exponent
// No need to make 1 implicit now as it will be addressed later
}
}
else
{
if((1 << (in_mant + 1)) & mantissa)
{
mantissa >>= 1;
out_exponent++;
// No need to make 1 implicit now as it will be addressed later
}
}
mantissa >>= (in_mant - out_mant);
if(out_exponent > max_exp)
{
if(clip)
{
mantissa = (1 << out_mant) - 1;
out_exponent = max_exp;
}
else
{
return signed_inf;
}
}
// check if x is 0.0 or -0.0
if(out_exponent == 0 && mantissa == 0)
return negative_zero_nan ? 0 : (sign << (out_exp + out_mant));
mantissa &= (1 << out_mant) - 1;
return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa;
}
template <typename X, typename Y, bool negative_zero_nan>
CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
{
// fp8/bf8 exponent/mantissa layout
constexpr int in_exp = numeric_utils<X>::exp;
constexpr int in_mant = numeric_utils<X>::mant;
// resulting type exponent/mantissa layout
constexpr int out_exp = numeric_utils<Y>::exp;
constexpr int out_mant = numeric_utils<Y>::mant;
// prepare the codes
constexpr X nan_code = 0x80;
Y Inf, NegInf, NaN, Neg0;
using T_bitwise = typename numeric_utils<Y>::bitwise_type;
constexpr T_bitwise Inf_bitwise = numeric_utils<Y>::Inf;
constexpr T_bitwise NegInf_bitwise = numeric_utils<Y>::NegInf;
constexpr T_bitwise NaN_bitwise = numeric_utils<Y>::NaN;
constexpr T_bitwise Neg0_bitwise = numeric_utils<Y>::Neg0;
Inf = *(reinterpret_cast<const Y*>(&Inf_bitwise));
NegInf = *(reinterpret_cast<const Y*>(&NegInf_bitwise));
NaN = *(reinterpret_cast<const Y*>(&NaN_bitwise));
Neg0 = *(reinterpret_cast<const Y*>(&Neg0_bitwise));
// check if x is 0.0
if(x == 0)
return static_cast<Y>(0);
// unpack the input
uint32_t sign = x >> (in_exp + in_mant);
uint32_t mantissa = x & ((1 << in_mant) - 1);
int exponent = (x & 0x7F) >> in_mant;
constexpr int exp_low_cutoff =
(1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
T_bitwise retval;
if constexpr(negative_zero_nan)
{
if(x == nan_code)
return NaN;
}
else
{
if(x == nan_code)
return Neg0;
if(exponent == ((1 << in_exp) - 1))
return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
}
if((numeric_utils<Y>::mant == 10) && (numeric_utils<X>::mant == 2) && !negative_zero_nan)
{
retval = x;
retval <<= 8;
return *(reinterpret_cast<const Y*>(&retval));
}
// subnormal input
if(exponent == 0)
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int sh = 1 + clz(mantissa) - (32 - in_mant);
mantissa <<= sh;
exponent += 1 - sh;
mantissa &= ((1 << in_mant) - 1);
}
exponent += exp_low_cutoff - 1;
mantissa <<= out_mant - in_mant;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
if(exponent <= 0)
{
mantissa |= 1 << out_mant;
mantissa >>= 1 - exponent;
exponent = 0;
}
retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
return *(reinterpret_cast<const Y*>(&retval));
}
template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng)
{
// check datatypes
constexpr bool is_half = std::is_same<X, half_t>::value;
constexpr bool is_float = std::is_same<X, float>::value;
static_assert(is_half || is_float, "Only half and float can be casted.");
return run_cast_to_f8<X, Y, negative_zero_nan, clip, stoch>(x, rng);
}
template <typename X, typename Y, bool negative_zero_nan>
CK_TILE_HOST_DEVICE Y cast_from_f8(X x)
{
// check datatype
constexpr bool is_half = std::is_same<Y, half_t>::value;
constexpr bool is_float = std::is_same<Y, float>::value;
static_assert(is_half || is_float, "only half and float are supported.");
return run_cast_from_f8<X, Y, negative_zero_nan>(x);
}
} // namespace impl
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_sr_raw(float x)
{
constexpr int seed = 42;
uint32_t rng = prand_generator<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float max_fp8 = 240.0f;
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
return val.i8val[0]; // little endian
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr fp8_rounding_mode rm = fp8_rounding_mode::stochastic;
return impl::
cast_to_f8<float, fp8_t, negative_zero_nan, clip, (rm == fp8_rounding_mode::stochastic)>(
x, rng);
#endif
}
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_sr_raw(float x)
{
constexpr int seed = 42;
uint32_t rng = prand_generator<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
return val.i8val[0]; // little endian
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr fp8_rounding_mode rm = fp8_rounding_mode::stochastic;
return impl::
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == fp8_rounding_mode::stochastic)>(
x, rng);
#endif
}
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_rtn_raw(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float max_fp8 = 240.0f;
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
val.i32val = ival;
return val.i8val[0];
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr fp8_rounding_mode rm = fp8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return impl::
cast_to_f8<float, fp8_t, negative_zero_nan, clip, (rm == fp8_rounding_mode::stochastic)>(
x, rng);
#endif
}
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_rtn_raw(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
val.i32val = ival;
return val.i8val[0];
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr fp8_rounding_mode rm = fp8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return impl::
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == fp8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// clang-format off
template<fp8_rounding_mode rounding>
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float x, constant<rounding>)
{
if constexpr (rounding == fp8_rounding_mode::standard) return float_to_fp8_rtn_raw(x);
else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_fp8_sr_raw(x);
else return uint8_t{0};
}
template<fp8_rounding_mode rounding>
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float x, constant<rounding>)
{
if constexpr (rounding == fp8_rounding_mode::standard) return float_to_bf8_rtn_raw(x);
else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_bf8_sr_raw(x);
else return uint8_t{0};
}
CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float fval;
uint32_t i32val = static_cast<uint32_t>(x);
fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
// asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
return fval;
#else
constexpr bool negative_zero_nan = true;
return impl::cast_from_f8<fp8_t, float, negative_zero_nan>(x);
#endif
}
CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float fval;
uint32_t i32val = static_cast<uint32_t>(x);
fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
// asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
return fval;
#else
constexpr bool negative_zero_nan = true;
return impl::cast_from_f8<bf8_t, float, negative_zero_nan>(x);
#endif
}
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE float8_e4m3_t float_to_fp8(float x, constant<rounding> = {})
{
return float8_e4m3_t::bit_cast(float_to_fp8_raw(x, constant<rounding>{}));
}
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE float8_e5m2_t float_to_bf8(float x, constant<rounding> = {})
{
return float8_e5m2_t::bit_cast(float_to_bf8_raw(x, constant<rounding>{}));
}
CK_TILE_HOST_DEVICE float fp8_to_float(float8_e4m3_t x)
{
return fp8_to_float_raw(x.get());
}
CK_TILE_HOST_DEVICE float bf8_to_float(float8_e5m2_t x)
{
return bf8_to_float_raw(x.get());
}
// clang-format on
template <typename T>
struct numeric_utils;
template <>
struct numeric_utils<fp8_t>
{
static constexpr int exp = fp8_t::exponent;
static constexpr int mant = fp8_t::mantissa;
static constexpr int bias = fp8_t::bias;
};
template <>
struct numeric_utils<bf8_t>
{
static constexpr int exp = bf8_t::exponent;
static constexpr int mant = bf8_t::mantissa;
static constexpr int bias = bf8_t::bias;
};
template <class T>
struct numeric_limits;
template <>
struct numeric_limits<fp8_t>
{
// minimum finite value, or minimum positive normalized value for float
CK_TILE_HOST_DEVICE static constexpr fp8_t min() { return fp8_t::bit_cast(0x08); }
// minimum finite value
CK_TILE_HOST_DEVICE static constexpr fp8_t lowest() { return fp8_t::bit_cast(0xff); }
// maximum finite value
CK_TILE_HOST_DEVICE static constexpr fp8_t max() { return fp8_t::bit_cast(0x7f); }
// difference between 1.0 and next value representable by float
CK_TILE_HOST_DEVICE static constexpr fp8_t epsilon() { return fp8_t::bit_cast(0x20); }
// maximum rounding error
CK_TILE_HOST_DEVICE static constexpr fp8_t round_error() { return fp8_t(0.5f); }
// positive infinity value
CK_TILE_HOST_DEVICE static constexpr fp8_t infinity() { return fp8_t::bit_cast(0x80); }
// quiet NaN
CK_TILE_HOST_DEVICE static constexpr fp8_t quiet_NaN() { return fp8_t::bit_cast(0x80); }
// signaling NaN
CK_TILE_HOST_DEVICE static constexpr fp8_t signaling_NaN() { return fp8_t::bit_cast(0x80); }
// smallest positive subnormal value
CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min() { return fp8_t::bit_cast(0x01); }
};
template <>
struct numeric_limits<bf8_t>
{
// minimum finite value, or minimum positive normalized value for float
CK_TILE_HOST_DEVICE static constexpr bf8_t min() { return bf8_t::bit_cast(0x04); }
// minimum finite value
CK_TILE_HOST_DEVICE static constexpr bf8_t lowest() { return bf8_t::bit_cast(0xff); }
// maximum finite value
CK_TILE_HOST_DEVICE static constexpr bf8_t max() { return bf8_t::bit_cast(0x7f); }
// difference between 1.0 and next value representable by float
CK_TILE_HOST_DEVICE static constexpr bf8_t epsilon() { return bf8_t::bit_cast(0x34); }
// maximum rounding error
CK_TILE_HOST_DEVICE static constexpr bf8_t round_error() { return bf8_t(0.5f); }
// positive infinity value
CK_TILE_HOST_DEVICE static constexpr bf8_t infinity() { return bf8_t::bit_cast(0x80); }
// quiet NaN
CK_TILE_HOST_DEVICE static constexpr bf8_t quiet_NaN() { return bf8_t::bit_cast(0x80); }
// signaling NaN
CK_TILE_HOST_DEVICE static constexpr bf8_t signaling_NaN() { return bf8_t::bit_cast(0x80); }
// smallest positive subnormal value
CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min() { return bf8_t::bit_cast(0x01); }
};
CK_TILE_ARITHMETIC_USING_FLOAT(fp8_t)
CK_TILE_ARITHMETIC_USING_FLOAT(bf8_t)
// math
CK_TILE_HOST_DEVICE
fp8_t abs(const fp8_t& x) { return fp8_t::bit_cast(x.get() & 0x7f); }
CK_TILE_HOST_DEVICE
bool isnan(const fp8_t& x)
{
uint8_t xx = x.get();
return xx == 0x80; // TODO: NANOO
}
CK_TILE_DEVICE
fp8_t sqrt(fp8_t x) { return static_cast<fp8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };
CK_TILE_DEVICE
fp8_t exp(fp8_t x) { return static_cast<fp8_t>(__expf(static_cast<float>(x))); };
CK_TILE_DEVICE
fp8_t exp2(fp8_t x) { return static_cast<fp8_t>(exp2f(static_cast<float>(x))); };
CK_TILE_DEVICE
fp8_t log(fp8_t x) { return static_cast<fp8_t>(__logf(static_cast<float>(x))); };
CK_TILE_HOST_DEVICE
bf8_t abs(const bf8_t& x) { return bf8_t::bit_cast(x.get() & 0x7f); }
CK_TILE_HOST_DEVICE
bool isnan(const bf8_t& x)
{
uint8_t xx = x.get();
return xx == 0x80; // TODO: NANOO
}
CK_TILE_DEVICE
bf8_t sqrt(bf8_t x) { return static_cast<bf8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };
CK_TILE_DEVICE
bf8_t exp(bf8_t x) { return static_cast<bf8_t>(__expf(static_cast<float>(x))); };
CK_TILE_DEVICE
bf8_t exp2(bf8_t x) { return static_cast<bf8_t>(exp2f(static_cast<float>(x))); };
CK_TILE_DEVICE
bf8_t log(bf8_t x) { return static_cast<bf8_t>(__logf(static_cast<float>(x))); };
} // namespace ck_tile
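On gfx940/941/942 the raw conversions map to the hardware cvt builtins; on other targets they fall back to the software cast in the impl namespace, but the public interface is the same either way. A small illustrative usage sketch (not part of the commit):
#include "ck_tile/core/numeric/float8.hpp"
// illustrative sketch
void fp8_examples()
{
    using namespace ck_tile;
    fp8_t q  = float_to_fp8(0.375f);                                            // default rounding
    fp8_t qs = float_to_fp8(0.375f, constant<fp8_rounding_mode::stochastic>{}); // stochastic rounding
    float x  = fp8_to_float(q);   // dequantize back to float
    bf8_t w  = float_to_bf8(3.0f);
    float y  = bf8_to_float(w);
    (void)qs; (void)x; (void)y;
}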

View File

@@ -0,0 +1,278 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include <hip/hip_fp16.h>
#pragma once
namespace ck_tile {
CK_TILE_HOST_DEVICE
float fp16_to_float_hip(const _Float16& x);
CK_TILE_HOST_DEVICE
_Float16 float_to_fp16_hip(const float& x);
// HIP uses _Float16 as the interchangeable data type for float16
struct alignas(2) half_t
{
using raw_type = uint16_t;
raw_type data;
CK_TILE_HOST_DEVICE
static half_t bit_cast(raw_type x)
{
half_t y;
y.data = x;
return y;
}
CK_TILE_HOST_DEVICE
_Float16 to_fp16() const { return reinterpret_cast<const _Float16&>(data); }
// constructor
half_t() = default;
// construct from HIP half
CK_TILE_HOST_DEVICE
explicit half_t(const _Float16& x) : data(reinterpret_cast<const raw_type&>(x)) {}
// construct from float
CK_TILE_HOST_DEVICE
explicit half_t(const float& x) : half_t(float_to_fp16_hip(x)) {}
// construct from int
CK_TILE_HOST_DEVICE
explicit half_t(const int& x) : half_t(__int2half_rn(x)) {}
// construct from unsigned int
CK_TILE_HOST_DEVICE
explicit half_t(const unsigned int& x) : half_t(__uint2half_rn(x)) {}
// cast to float
CK_TILE_HOST_DEVICE
explicit operator float() const { return fp16_to_float_hip(to_fp16()); }
// cast to int
CK_TILE_HOST_DEVICE
explicit operator int() const { return static_cast<int>(fp16_to_float_hip(to_fp16())); }
// internal access
CK_TILE_HOST_DEVICE
raw_type& get() { return data; }
CK_TILE_HOST_DEVICE
raw_type get() const { return data; }
};
// conversions
CK_TILE_HOST_DEVICE
float fp16_to_float_hip(const _Float16& x)
{
// return __half2float(x);
return static_cast<float>(x);
}
CK_TILE_HOST_DEVICE
_Float16 float_to_fp16_hip(const float& x)
{
// return __float2half(x);
return static_cast<_Float16>(x);
}
CK_TILE_HOST_DEVICE
float fp16_to_float(const half_t& x) { return static_cast<float>(x); }
CK_TILE_HOST_DEVICE
half_t float_to_fp16(const float& x) { return half_t{x}; }
// limits
template <class T>
struct numeric_limits;
template <>
struct numeric_limits<half_t>
{
// minimum finite value, or minimum positive normalized value for float
CK_TILE_HOST_DEVICE static constexpr half_t min() { return half_t::bit_cast(0x0400); }
// minimum finite value
CK_TILE_HOST_DEVICE static constexpr half_t lowest() { return half_t::bit_cast(0xFBFF); }
// maximum finite value
CK_TILE_HOST_DEVICE static constexpr half_t max() { return half_t::bit_cast(0x7BFF); }
// difference between 1.0 and next value representable by float
CK_TILE_HOST_DEVICE static constexpr half_t epsilon() { return half_t::bit_cast(0x1800); }
// maximum rounding error
CK_TILE_HOST_DEVICE static constexpr half_t round_error() { return half_t(0.5f); }
// positive infinity value
CK_TILE_HOST_DEVICE static constexpr half_t infinity() { return half_t::bit_cast(0x7C00); }
// quiet NaN
CK_TILE_HOST_DEVICE static constexpr half_t quiet_NaN() { return half_t::bit_cast(0x7FFF); }
// signaling NaN
CK_TILE_HOST_DEVICE static constexpr half_t signaling_NaN() { return half_t::bit_cast(0x7FFF); }
// smallest positive subnormal value
CK_TILE_HOST_DEVICE static constexpr half_t denorm_min() { return half_t::bit_cast(0x0001); }
};
template <typename T>
struct numeric_utils;
template <>
struct numeric_utils<half_t>
{
static constexpr int exp = 5;
static constexpr int mant = 10;
static constexpr int bias = 15;
static constexpr uint16_t nan_mask = 0x7C00;
static constexpr uint16_t head_mask = 0xFC00;
static constexpr uint16_t mant_mask = 0x3FF;
static constexpr uint16_t exp_mask = 0x1F;
static constexpr uint32_t Inf = 0x7C00;
static constexpr uint32_t NegInf = 0xFC00;
static constexpr uint32_t NaN = 0x7C01;
static constexpr uint32_t Neg0 = 0x8000;
using bitwise_type = uint16_t;
};
// arithmetic
CK_TILE_HOST_DEVICE
bool operator==(const half_t& x, const half_t& y) { return __heq(x.to_fp16(), y.to_fp16()); }
CK_TILE_HOST_DEVICE
bool operator!=(const half_t& x, const half_t& y) { return __hne(x.to_fp16(), y.to_fp16()); }
CK_TILE_HOST_DEVICE
bool operator<(const half_t& x, const half_t& y) { return __hlt(x.to_fp16(), y.to_fp16()); }
CK_TILE_HOST_DEVICE
bool operator<=(const half_t& x, const half_t& y) { return __hle(x.to_fp16(), y.to_fp16()); }
CK_TILE_HOST_DEVICE
bool operator>(const half_t& x, const half_t& y) { return __hgt(x.to_fp16(), y.to_fp16()); }
CK_TILE_HOST_DEVICE
bool operator>=(const half_t& x, const half_t& y) { return __hge(x.to_fp16(), y.to_fp16()); }
CK_TILE_HOST_DEVICE
half_t operator+(const half_t& x, const half_t& y)
{
return half_t(__hadd(x.to_fp16(), y.to_fp16()));
}
CK_TILE_HOST_DEVICE
half_t operator-(const half_t& x) { return half_t(__hneg(x.to_fp16())); }
CK_TILE_HOST_DEVICE
half_t operator-(const half_t& x, const half_t& y)
{
return half_t(__hsub(x.to_fp16(), y.to_fp16()));
}
CK_TILE_HOST_DEVICE
half_t operator*(const half_t& x, const half_t& y)
{
return half_t(__hmul(x.to_fp16(), y.to_fp16()));
}
CK_TILE_HOST_DEVICE
half_t operator/(const half_t& x, const half_t& y)
{
return half_t(__hdiv(x.to_fp16(), y.to_fp16()));
}
CK_TILE_HOST_DEVICE
half_t& operator+=(half_t& x, const half_t& y)
{
x = half_t(__hadd(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_HOST_DEVICE
half_t& operator-=(half_t& x, const half_t& y)
{
x = half_t(__hsub(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_HOST_DEVICE
half_t& operator*=(half_t& x, const half_t& y)
{
x = half_t(__hmul(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_HOST_DEVICE
half_t& operator/=(half_t& x, const half_t& y)
{
x = half_t(__hdiv(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_HOST_DEVICE
half_t& operator++(half_t& x)
{
x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
return x;
}
CK_TILE_HOST_DEVICE
half_t& operator--(half_t& x)
{
x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
return x;
}
CK_TILE_HOST_DEVICE
half_t operator++(half_t& x, int)
{
half_t y(x);
x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
return y;
}
CK_TILE_HOST_DEVICE
half_t operator--(half_t& x, int)
{
half_t y(x);
x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
return y;
}
// math
CK_TILE_HOST_DEVICE
half_t abs(const half_t& x) { return half_t::bit_cast(x.get() & 0x7fff); }
CK_TILE_HOST_DEVICE
bool isnan(const half_t& x)
{
uint16_t xx = x.get();
return (xx & 0x7FFF) > 0x7C00;
}
CK_TILE_DEVICE
half_t sqrt(half_t x)
{
return static_cast<half_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
};
CK_TILE_DEVICE
half_t exp(half_t x) { return static_cast<half_t>(__expf(static_cast<float>(x))); };
CK_TILE_DEVICE
half_t exp2(half_t x) { return static_cast<half_t>(exp2f(static_cast<float>(x))); };
CK_TILE_DEVICE
half_t log(half_t x) { return static_cast<half_t>(__logf(static_cast<float>(x))); };
using fp16_t = half_t;
} // namespace ck_tile
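Unlike bfloat16_t, half_t does not go through the float round-trip macro: its operators call the __h* intrinsics from hip_fp16.h directly. A small illustrative usage sketch (not part of the commit):
#include "ck_tile/core/numeric/half.hpp"
// illustrative sketch
void fp16_examples()
{
    using namespace ck_tile;
    half_t a(1.25f), b(2.0f);
    half_t c   = a * b + half_t(0.5f);            // __hmul / __hadd under the hood
    bool nan   = isnan(half_t::bit_cast(0x7E00)); // a quiet-NaN bit pattern
    half_t eps = numeric_limits<half_t>::epsilon();
    (void)c; (void)nan; (void)eps;
}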

View File

@@ -0,0 +1,13 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <stdint.h>
namespace ck_tile {
using index_t = int32_t;
using long_index_t = int64_t;
using int8_t = ::int8_t;
} // namespace ck_tile

View File

@@ -0,0 +1,82 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
namespace ck_tile {
template <auto v>
struct constant
{
using value_type = decltype(v);
using type = constant; // using injected-class-name
static constexpr value_type value = v;
constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; }
constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; }
};
template <typename T, T v>
struct integral_constant : constant<v>
{
using value_type = T;
using type = integral_constant; // using injected-class-name
static constexpr T value = v;
// constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; }
// constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; } //
};
template <index_t v>
using number = constant<v>;
template <long_index_t v>
using long_number = integral_constant<long_index_t, v>;
template <bool b>
using bool_constant = constant<b>;
#define CK_TILE_LEFT_UNARY_OP(OP) \
template <auto x> \
CK_TILE_HOST_DEVICE constexpr auto operator OP(constant<x>) \
{ \
return constant<(OP x)>{}; \
}
#define CK_TILE_BINARY_OP(OP) \
template <auto x, auto y> \
CK_TILE_HOST_DEVICE constexpr auto operator OP(constant<x>, constant<y>) \
{ \
return constant<(x OP y)>{}; \
}
CK_TILE_LEFT_UNARY_OP(+)
CK_TILE_LEFT_UNARY_OP(-)
CK_TILE_LEFT_UNARY_OP(~)
CK_TILE_LEFT_UNARY_OP(!)
CK_TILE_LEFT_UNARY_OP(*)
CK_TILE_BINARY_OP(+)
CK_TILE_BINARY_OP(-)
CK_TILE_BINARY_OP(*)
CK_TILE_BINARY_OP(/)
CK_TILE_BINARY_OP(%)
CK_TILE_BINARY_OP(&)
CK_TILE_BINARY_OP(|)
CK_TILE_BINARY_OP(^)
CK_TILE_BINARY_OP(<<)
CK_TILE_BINARY_OP(>>)
CK_TILE_BINARY_OP(&&)
CK_TILE_BINARY_OP(||)
CK_TILE_BINARY_OP(==)
CK_TILE_BINARY_OP(!=)
CK_TILE_BINARY_OP(>)
CK_TILE_BINARY_OP(<)
CK_TILE_BINARY_OP(>=)
CK_TILE_BINARY_OP(<=)
#undef CK_TILE_LEFT_UNARY_OP
#undef CK_TILE_BINARY_OP
} // namespace ck_tile
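constant<v> carries a value in the type system, and the operator macros keep arithmetic on such constants in the type domain, so number<3>{} + number<4>{} has type constant<7> rather than producing a runtime int. A brief illustrative sketch (not part of the commit):
#include "ck_tile/core/numeric/integral_constant.hpp"
// illustrative sketch
namespace {
using ck_tile::number;
constexpr auto seven = number<3>{} + number<4>{};                // type is constant<7>
static_assert(seven == 7);                                       // implicit conversion to index_t
static_assert(decltype(seven)::value == 7);
constexpr auto flag = number<8>{} % number<2>{} == number<0>{};  // constant<true>
static_assert(flag);
} // namespace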

View File

@@ -0,0 +1,309 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include <type_traits>
#include <stdint.h>
namespace ck_tile {
template <typename T, T s>
struct scales
{
CK_TILE_HOST_DEVICE constexpr T operator()(T a) const { return s * a; }
};
template <typename T>
struct plus
{
CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a + b; }
};
template <typename T>
struct minus
{
CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a - b; }
};
struct multiplies
{
template <typename A, typename B>
CK_TILE_HOST_DEVICE constexpr auto operator()(const A& a, const B& b) const
{
return a * b;
}
};
template <typename T>
struct maximize
{
CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a >= b ? a : b; }
};
template <typename T>
struct minimize
{
CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a <= b ? a : b; }
};
template <typename T>
struct integer_divide_ceiler
{
CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const
{
static_assert(std::is_same<T, index_t>{} || std::is_same<T, int>{}, "wrong type");
return (a + b - number<1>{}) / b;
}
};
template <typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto integer_divide_floor(X x, Y y)
{
return x / y;
}
template <typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
{
return (x + y - number<1>{}) / y;
}
template <typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto integer_least_multiple(X x, Y y)
{
return y * integer_divide_ceil(x, y);
}
template <typename T>
CK_TILE_HOST_DEVICE constexpr T max(T x)
{
return x;
}
template <typename T>
CK_TILE_HOST_DEVICE constexpr T max(T x, T y)
{
return x > y ? x : y;
}
template <index_t X>
CK_TILE_HOST_DEVICE constexpr index_t max(number<X>, index_t y)
{
return X > y ? X : y;
}
template <index_t Y>
CK_TILE_HOST_DEVICE constexpr index_t max(index_t x, number<Y>)
{
return x > Y ? x : Y;
}
template <typename X, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto max(X x, Ys... ys)
{
static_assert(sizeof...(Ys) > 0, "not enough argument");
return max(x, max(ys...));
}
template <typename T>
CK_TILE_HOST_DEVICE constexpr T min(T x)
{
return x;
}
template <typename T>
CK_TILE_HOST_DEVICE constexpr T min(T x, T y)
{
return x < y ? x : y;
}
template <index_t X>
CK_TILE_HOST_DEVICE constexpr index_t min(number<X>, index_t y)
{
return X < y ? X : y;
}
template <index_t Y>
CK_TILE_HOST_DEVICE constexpr index_t min(index_t x, number<Y>)
{
return x < Y ? x : Y;
}
template <typename X, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto min(X x, Ys... ys)
{
static_assert(sizeof...(Ys) > 0, "not enough argument");
return min(x, min(ys...));
}
template <typename T>
CK_TILE_HOST_DEVICE constexpr T clamp(const T& x, const T& lowerbound, const T& upperbound)
{
return min(max(x, lowerbound), upperbound);
}
// greatest common divisor, aka highest common factor
CK_TILE_HOST_DEVICE constexpr index_t gcd(index_t x, index_t y)
{
if(x < 0)
{
return gcd(-x, y);
}
else if(y < 0)
{
return gcd(x, -y);
}
else if(x == y || x == 0)
{
return y;
}
else if(y == 0)
{
return x;
}
else if(x > y)
{
return gcd(x % y, y);
}
else
{
return gcd(x, y % x);
}
}
template <index_t X, index_t Y>
CK_TILE_HOST_DEVICE constexpr auto gcd(number<X>, number<Y>)
{
constexpr auto r = gcd(X, Y);
return number<r>{};
}
template <typename X,
typename... Ys,
typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto gcd(X x, Ys... ys)
{
return gcd(x, gcd(ys...));
}
// least common multiple
template <typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Y y)
{
return (x * y) / gcd(x, y);
}
template <typename X,
typename... Ys,
typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Ys... ys)
{
return lcm(x, lcm(ys...));
}
template <typename T>
struct equal
{
CK_TILE_HOST_DEVICE constexpr bool operator()(T x, T y) const { return x == y; }
};
template <typename T>
struct less
{
CK_TILE_HOST_DEVICE constexpr bool operator()(T x, T y) const { return x < y; }
};
CK_TILE_HOST_DEVICE constexpr int32_t next_power_of_two(int32_t x)
{
// TODO: x needs to be in the range 2 ~ 0x7fffffff; 0, 1, or values larger than 0x7fffffff will fail to compile
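// e.g. x = 5: __builtin_clz(4) == 29, so the result is 1 << 3 == 8; powers of two map to themselves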
return 1 << (32 - __builtin_clz(x - 1));
}
template <index_t X>
CK_TILE_HOST_DEVICE constexpr auto next_power_of_two()
{
constexpr index_t y = next_power_of_two(X);
return number<y>{};
}
template <index_t X>
CK_TILE_HOST_DEVICE constexpr auto next_power_of_two(number<X>)
{
constexpr index_t y = next_power_of_two(X);
return number<y>{};
}
CK_TILE_HOST_DEVICE constexpr int32_t integer_log2_floor(int32_t x)
{
// TODO: x needs to be in the range 1 ~ 0x7fffffff;
// __builtin_clz produces an unexpected result if x is 0
return 31 - __builtin_clz(x);
}
CK_TILE_HOST_DEVICE constexpr bool is_power_of_two_integer(int32_t x)
{
// TODO: x needs to be in the range 1 ~ 0x7fffffff
return x == (1 << integer_log2_floor(x));
}
#ifndef C_LOG2E
#define C_LOG2E 1.44269504088896340736 // log2(e)
#endif
template <typename T>
struct log2e;
template <>
struct log2e<double>
{
static constexpr double value = C_LOG2E;
};
template <>
struct log2e<float>
{
static constexpr float value = C_LOG2E;
};
template <typename T = double>
inline constexpr T log2e_v = log2e<T>::value;
// math
CK_TILE_HOST_DEVICE
float abs(const float& x)
{
union
{
float f32;
uint32_t u32;
} y;
y.f32 = x;
y.u32 = y.u32 & 0x7fffffff;
return y.f32;
}
CK_TILE_HOST_DEVICE
bool isnan(const float& x)
{
uint32_t xx = reinterpret_cast<const uint32_t&>(x);
return (xx & 0x7fffffff) > 0x7F800000;
}
CK_TILE_DEVICE
float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); };
CK_TILE_DEVICE
float exp(float x) { return __expf(x); };
CK_TILE_DEVICE
float exp2(float x) { return exp2f(x); };
CK_TILE_DEVICE
float log(float x) { return __logf(x); };
} // namespace ck_tile
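Most of these helpers are constexpr and accept either plain integers or number<> constants, so tile-size arithmetic can be done entirely at compile time. A brief illustrative sketch (not part of the commit, assuming the header above is included):
// illustrative sketch
namespace {
using namespace ck_tile;
static_assert(integer_divide_ceil(10, 4) == 3);      // ceil(10/4)
static_assert(integer_least_multiple(10, 4) == 12);  // smallest multiple of 4 >= 10
static_assert(gcd(12, 18) == 6 && lcm(12, 18) == 36);
static_assert(gcd(number<12>{}, number<18>{}) == 6); // number<> overload, result is number<6>{}
static_assert(clamp(5, 0, 3) == 3);
} // namespace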

View File

@@ -0,0 +1,45 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include <type_traits>
namespace ck_tile {
#if 0
// Convert X to Y, both X and Y are non-const data types.
template <typename Y,
typename X,
std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
return static_cast<Y>(x);
}
// TODO: const version never called, we may never need
// Convert X to Y, either X or Y is a const data type.
template <typename Y,
typename X,
std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
using NonConstY = std::remove_const_t<Y>;
using NonConstX = std::remove_const_t<X>;
return static_cast<Y>(type_convert<NonConstY, NonConstX>(x));
}
#else
// compatible way to call conversion operator and constructor of each custom data type
template <typename Y, typename X>
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
return static_cast<Y>(x);
}
#endif
} // namespace ck_tile
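type_convert is a thin, uniform wrapper over static_cast so that generic code can convert between the custom numeric types, which only expose explicit constructors and conversion operators, without spelling the cast differently per type. Illustrative usage (not part of the commit, assuming the header above plus half.hpp and bfloat16.hpp are included):
// illustrative sketch
void convert_examples()
{
    using namespace ck_tile;
    float f      = 0.75f;
    half_t h     = type_convert<half_t>(f);     // float -> fp16 via half_t's explicit ctor
    bfloat16_t b = type_convert<bfloat16_t>(f); // float -> bf16
    float back   = type_convert<float>(h);      // fp16 -> float via the explicit conversion operator
    (void)b; (void)back;
}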

View File

@@ -0,0 +1,304 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
namespace ck_tile {
// TODO: the whole content of this file should be considered deprecated!
template <typename T_, index_t N_>
struct vector_type
{
static constexpr index_t N = N_;
using value_type = T_;
using type = value_type __attribute__((ext_vector_type(N))); // this is dangerous
CK_HOST_DEVICE constexpr vector_type()
{
for(auto i = 0; i < N; i++)
data[i] = static_cast<value_type>(0);
}
CK_HOST_DEVICE constexpr vector_type(type v)
{
auto& r = reinterpret_cast<const array<value_type, N>&>(v);
for(auto i = 0; i < N; i++)
data[i] = r.get(i);
}
value_type data[N];
CK_HOST_DEVICE static constexpr auto size() { return N; }
CK_HOST_DEVICE auto& get() { return data; }
CK_HOST_DEVICE const auto& get() const { return data; }
CK_HOST_DEVICE auto& get(index_t i) { return data[i]; }
CK_HOST_DEVICE const auto& get(index_t i) const { return data[i]; }
template <index_t I>
CK_HOST_DEVICE auto& operator[](number<I>)
{
return data[I];
}
template <index_t I>
CK_HOST_DEVICE const auto& operator[](number<I>) const
{
return data[I];
}
template <index_t I>
CK_HOST_DEVICE auto& operator()(number<I>)
{
return data[I];
}
CK_HOST_DEVICE auto& at(index_t i) { return data[i]; }
CK_HOST_DEVICE const auto& at(index_t i) const { return data[i]; }
template <index_t I>
CK_HOST_DEVICE auto& at()
{
return data[I];
}
template <index_t I>
CK_HOST_DEVICE const auto& at() const
{
return data[I];
}
template <index_t I>
CK_HOST_DEVICE auto& at(number<I>)
{
return data[I];
}
template <index_t I>
CK_HOST_DEVICE const auto& at(number<I>) const
{
return data[I];
}
#define _VT_COMMON_AS() \
static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
template <typename Tx>
CK_HOST_DEVICE auto& get_as()
{
_VT_COMMON_AS();
return reinterpret_cast<array<Tx, vx>&>(data);
}
template <typename Tx>
CK_HOST_DEVICE const auto& get_as() const
{
_VT_COMMON_AS();
return reinterpret_cast<const array<Tx, vx>&>(data);
}
template <typename Tx>
CK_HOST_DEVICE auto& get_as(index_t i)
{
_VT_COMMON_AS();
return reinterpret_cast<array<Tx, vx>&>(data).get(i);
}
template <typename Tx>
CK_HOST_DEVICE const auto& get_as(index_t i) const
{
_VT_COMMON_AS();
return reinterpret_cast<const array<Tx, vx>&>(data).get(i);
}
#undef _VT_COMMON_AS
};
template <typename T, index_t N>
struct vector_type_maker
{
using type = vector_type<T, N>;
};
template <typename T, index_t N0, index_t N1>
struct vector_type_maker<T __attribute__((ext_vector_type(N1))), N0>
{
using type = vector_type<T, N0 * N1>;
};
template <typename T, index_t N0, index_t N1>
struct vector_type_maker<vector_type<T, N1>, N0>
{
using type = vector_type<T, N0 * N1>;
};
template <typename T, index_t N>
using vector_type_maker_t = typename vector_type_maker<T, N>::type;
template <typename T, index_t N>
CK_HOST_DEVICE constexpr auto make_vector_type(number<N>)
{
return typename vector_type_maker<T, N>::type{};
}
// scalar_type
template <typename TV>
struct scalar_type;
// is_scalar_type
template <typename TV>
struct is_scalar_type
{
static constexpr bool value = (scalar_type<remove_cvref_t<TV>>::vector_size == 1);
};
// has_same_scalar_type
template <typename X, typename Y>
using has_same_scalar_type = is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<Y>>::type>;
template <typename T, index_t N>
struct scalar_type<T __attribute__((ext_vector_type(N)))>
{
using type = T;
static constexpr index_t vector_size = N;
};
template <typename T, index_t N>
struct scalar_type<vector_type<T, N>>
{
using type = T;
static constexpr index_t vector_size = N;
};
//
template <>
struct scalar_type<double>
{
using type = double;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<float>
{
using type = float;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<half_t>
{
using type = half_t;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<bhalf_t>
{
using type = bhalf_t;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<int64_t>
{
using type = int64_t;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<int32_t>
{
using type = int32_t;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<int8_t>
{
using type = int8_t;
static constexpr index_t vector_size = 1;
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
struct scalar_type<int4_t>
{
using type = int4_t;
static constexpr index_t vector_size = 1;
};
#endif
template <>
struct scalar_type<fp8_t>
{
using type = fp8_t;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<bf8_t>
{
using type = bf8_t;
static constexpr index_t vector_size = 1;
};
// below are some pre-defines of ext_vector_type
// fp64
using double2_t = typename vector_type<double, 2>::type;
using double4_t = typename vector_type<double, 4>::type;
// fp32
using float2_t = typename vector_type<float, 2>::type;
using float4_t = typename vector_type<float, 4>::type;
using float8_t = typename vector_type<float, 8>::type;
using float16_t = typename vector_type<float, 16>::type;
using float32_t = typename vector_type<float, 32>::type;
using float64_t = typename vector_type<float, 64>::type;
// fp16
using half2_t = typename vector_type<half_t, 2>::type;
using half4_t = typename vector_type<half_t, 4>::type;
using half8_t = typename vector_type<half_t, 8>::type;
using half16_t = typename vector_type<half_t, 16>::type;
using half32_t = typename vector_type<half_t, 32>::type;
using half64_t = typename vector_type<half_t, 64>::type;
// bfp16
using bhalf2_t = typename vector_type<bhalf_t, 2>::type;
using bhalf4_t = typename vector_type<bhalf_t, 4>::type;
using bhalf8_t = typename vector_type<bhalf_t, 8>::type;
using bhalf16_t = typename vector_type<bhalf_t, 16>::type;
using bhalf32_t = typename vector_type<bhalf_t, 32>::type;
using bhalf64_t = typename vector_type<bhalf_t, 64>::type;
// i32
using int32x2_t = typename vector_type<int32_t, 2>::type;
using int32x4_t = typename vector_type<int32_t, 4>::type;
using int32x8_t = typename vector_type<int32_t, 8>::type;
using int32x16_t = typename vector_type<int32_t, 16>::type;
using int32x32_t = typename vector_type<int32_t, 32>::type;
using int32x64_t = typename vector_type<int32_t, 64>::type;
// i8
using int8x2_t = typename vector_type<int8_t, 2>::type;
using int8x4_t = typename vector_type<int8_t, 4>::type;
using int8x8_t = typename vector_type<int8_t, 8>::type;
using int8x16_t = typename vector_type<int8_t, 16>::type;
using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;
// f8
using fp8x2_t = typename vector_type<fp8_t, 2>::type;
using fp8x4_t = typename vector_type<fp8_t, 4>::type;
using fp8x8_t = typename vector_type<fp8_t, 8>::type;
using fp8x16_t = typename vector_type<fp8_t, 16>::type;
using fp8x32_t = typename vector_type<fp8_t, 32>::type;
using fp8x64_t = typename vector_type<fp8_t, 64>::type;
// bf8
using bf8x2_t = typename vector_type<bf8_t, 2>::type;
using bf8x4_t = typename vector_type<bf8_t, 4>::type;
using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
} // namespace ck_tile
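vector_type<T, N> stores N scalars in a plain array alongside the corresponding ext_vector_type alias, and get_as<Tx>() reinterprets that storage as an array of a different element type of the same total size. A short illustrative sketch (not part of the commit, assuming the header above is included):
// illustrative sketch
void vector_type_examples()
{
    using namespace ck_tile;
    vector_type<float, 4> v;                 // elements set to 0 by the default constructor
    v.at(0) = 1.0f;
    v(number<1>{}) = 2.0f;                   // compile-time indexed access
    auto& as_u32 = v.get_as<uint32_t>();     // reinterpret the 16 bytes as 4 x uint32_t
    float4_t raw{0.f, 1.f, 2.f, 3.f};        // the underlying ext_vector_type alias
    (void)as_u32; (void)raw;
}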