mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
add code
This commit is contained in:
116
include/ck_tile/core/numeric/arithmetic.hpp
Normal file
116
include/ck_tile/core/numeric/arithmetic.hpp
Normal file
@@ -0,0 +1,116 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include <stdint.h>
|
||||
|
||||
#pragma once
|
||||
|
||||
#define CK_TILE_ARITHMETIC_USING_FLOAT(type_) \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
bool operator==(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) == static_cast<float>(y); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
bool operator!=(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) != static_cast<float>(y); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
bool operator<(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) < static_cast<float>(y); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
bool operator<=(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) <= static_cast<float>(y); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
bool operator>(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) > static_cast<float>(y); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
bool operator>=(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) >= static_cast<float>(y); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_ operator+(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) + static_cast<float>(y)); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_ operator-(const type_& x) \
|
||||
{ \
|
||||
constexpr uint32_t bits = sizeof(type_) * 8; \
|
||||
constexpr uint32_t mask = 1 << (bits - 1); \
|
||||
type_ y = x; \
|
||||
y.data ^= static_cast<typename type_::raw_type>(mask); \
|
||||
return y; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_ operator-(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) - static_cast<float>(y)); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_ operator*(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) * static_cast<float>(y)); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_ operator/(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) / static_cast<float>(y)); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_& operator+=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) + static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_& operator-=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) - static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_& operator*=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) * static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_& operator/=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) / static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_& operator++(type_& x) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) + 1.f); \
|
||||
return x; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_& operator--(type_& x) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) - 1.f); \
|
||||
return x; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_ operator++(type_& x, int) \
|
||||
{ \
|
||||
type_ y(x); \
|
||||
x = type_(static_cast<float>(x) + 1.f); \
|
||||
return y; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_ operator--(type_& x, int) \
|
||||
{ \
|
||||
type_ y(x); \
|
||||
x = type_(static_cast<float>(x) - 1.f); \
|
||||
return y; \
|
||||
}
|
||||
263
include/ck_tile/core/numeric/bfloat16.hpp
Normal file
263
include/ck_tile/core/numeric/bfloat16.hpp
Normal file
@@ -0,0 +1,263 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/numeric/arithmetic.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include <stdint.h>
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
enum class bf16_rounding_mode
|
||||
{
|
||||
standard = 0, // rtn
|
||||
truncate_with_nan,
|
||||
truncate,
|
||||
};
|
||||
|
||||
template <bf16_rounding_mode rounding = CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT>
|
||||
CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding> = {});
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float bf16_to_float_raw(uint16_t x);
|
||||
|
||||
// HIP use __hip_bfloat16 as struct
|
||||
struct alignas(2) bfloat16_t
|
||||
{
|
||||
using raw_type = uint16_t;
|
||||
raw_type data;
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
static bfloat16_t bit_cast(raw_type x)
|
||||
{
|
||||
bfloat16_t y;
|
||||
y.data = x;
|
||||
return y;
|
||||
}
|
||||
|
||||
// constructor
|
||||
bfloat16_t() = default;
|
||||
|
||||
// construct from float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit bfloat16_t(const float& x) { data = float_to_bf16_raw(x); }
|
||||
|
||||
// construct from int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit bfloat16_t(const int& x) { data = float_to_bf16_raw(static_cast<float>(x)); }
|
||||
|
||||
// construct from unsigned int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit bfloat16_t(const unsigned int& x) { data = float_to_bf16_raw(static_cast<float>(x)); }
|
||||
|
||||
// cast to float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit operator float() const { return bf16_to_float_raw(data); }
|
||||
|
||||
// cast to int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit operator int() const { return static_cast<int>(bf16_to_float_raw(data)); }
|
||||
|
||||
// internal access
|
||||
CK_TILE_HOST_DEVICE
|
||||
raw_type& get() { return data; }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
raw_type get() const { return data; }
|
||||
};
|
||||
|
||||
// round to nearest
|
||||
CK_TILE_HOST_DEVICE
|
||||
uint16_t float_to_bf16_rtn_raw(float f)
|
||||
{
|
||||
union
|
||||
{
|
||||
float fp32;
|
||||
uint32_t int32;
|
||||
} u = {f};
|
||||
if(~u.int32 & 0x7f800000)
|
||||
{
|
||||
// When the exponent bits are not all 1s, then the value is zero, normal,
|
||||
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
|
||||
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
|
||||
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
|
||||
// least significant bits of the float mantissa are greater than 0x8000,
|
||||
// or if they are equal to 0x8000 and the least significant bit of the
|
||||
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
|
||||
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
|
||||
// has the value 0x7f, then incrementing it causes it to become 0x00 and
|
||||
// the exponent is incremented by one, which is the next higher FP value
|
||||
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
|
||||
// with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
|
||||
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
|
||||
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
|
||||
// incrementing it causes it to become an exponent of 0xFF and a mantissa
|
||||
// of 0x00, which is Inf, the next higher value to the unrounded value.
|
||||
u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even
|
||||
}
|
||||
else if(u.int32 & 0xffff)
|
||||
{
|
||||
// When all of the exponent bits are 1, the value is Inf or NaN.
|
||||
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
|
||||
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
|
||||
// bit being 1. Signaling NaN is indicated by the most significant
|
||||
// mantissa bit being 0 but some other bit(s) being 1. If any of the
|
||||
// lower 16 bits of the mantissa are 1, we set the least significant bit
|
||||
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
|
||||
// the bloat16's mantissa bits are all 0.
|
||||
u.int32 |= 0x10000; // Preserve signaling NaN
|
||||
}
|
||||
return uint16_t(u.int32 >> 16);
|
||||
}
|
||||
|
||||
// Truncate instead of rounding, preserving SNaN
|
||||
CK_TILE_HOST_DEVICE
|
||||
uint16_t float_to_bf16_truc_nan_raw(float f)
|
||||
{
|
||||
union
|
||||
{
|
||||
float fp32;
|
||||
uint32_t int32;
|
||||
} u = {f};
|
||||
return uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff));
|
||||
}
|
||||
|
||||
// Fast truncate instead of rounding, RTZ
|
||||
CK_TILE_HOST_DEVICE
|
||||
uint16_t float_to_bf16_truc_raw(float f)
|
||||
{
|
||||
union
|
||||
{
|
||||
float fp32;
|
||||
uint32_t int32;
|
||||
} u = {f};
|
||||
return uint16_t(u.int32 >> 16);
|
||||
}
|
||||
|
||||
template <bf16_rounding_mode rounding = CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT>
|
||||
CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding> = {})
|
||||
{
|
||||
if constexpr(rounding == bf16_rounding_mode::standard)
|
||||
return float_to_bf16_rtn_raw(f);
|
||||
else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
|
||||
return float_to_bf16_truc_nan_raw(f);
|
||||
else
|
||||
return float_to_bf16_truc_raw(f);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float bf16_to_float_raw(uint16_t x)
|
||||
{
|
||||
union
|
||||
{
|
||||
uint32_t int32;
|
||||
float fp32;
|
||||
} u = {uint32_t(x) << 16};
|
||||
return u.fp32;
|
||||
}
|
||||
|
||||
template <bf16_rounding_mode rounding = CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT>
|
||||
CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant<rounding> = {})
|
||||
{
|
||||
return bfloat16_t::bit_cast(float_to_bf16_raw(f, constant<rounding>{}));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float bf16_to_float(bfloat16_t x) { return static_cast<float>(x); }
|
||||
|
||||
template <bf16_rounding_mode rounding = CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT>
|
||||
CK_TILE_HOST_DEVICE bfloat16_t fp16_to_bf16(half_t f, constant<rounding> = {})
|
||||
{
|
||||
return bfloat16_t::bit_cast(float_to_bf16_raw(static_cast<float>(f), constant<rounding>{}));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float bf16_to_fp16(bfloat16_t x) { return float_to_fp16(static_cast<float>(x)); }
|
||||
|
||||
template <class T>
|
||||
struct numeric_limits;
|
||||
|
||||
template <>
|
||||
struct numeric_limits<bfloat16_t>
|
||||
{
|
||||
// minimum finite value, or minimum positive normalized value for float
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t min() { return bfloat16_t::bit_cast(0x0080); }
|
||||
|
||||
// minumum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t lowest()
|
||||
{
|
||||
return bfloat16_t::bit_cast(0xff7f);
|
||||
}
|
||||
|
||||
// maximum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t max() { return bfloat16_t::bit_cast(0x7f7f); }
|
||||
|
||||
// difference between 1.0 and next value representable by float
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t epsilon()
|
||||
{
|
||||
return bfloat16_t::bit_cast(0x1000);
|
||||
}
|
||||
|
||||
// maximum rounding error
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t round_error() { return bfloat16_t(0.5f); }
|
||||
|
||||
// positive infinity value
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t infinity()
|
||||
{
|
||||
return bfloat16_t::bit_cast(0x7f80);
|
||||
}
|
||||
|
||||
// quiet NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t quiet_NaN()
|
||||
{
|
||||
return bfloat16_t::bit_cast(0x7FFF);
|
||||
}
|
||||
|
||||
// signaling NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t signaling_NaN()
|
||||
{
|
||||
return bfloat16_t::bit_cast(0x7FFF);
|
||||
}
|
||||
|
||||
// smallest positive subnormal value
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t denorm_min()
|
||||
{
|
||||
return bfloat16_t::bit_cast(0x0001);
|
||||
}
|
||||
};
|
||||
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(bfloat16_t)
|
||||
|
||||
// math
|
||||
CK_TILE_HOST_DEVICE
|
||||
bfloat16_t abs(const bfloat16_t& x) { return bfloat16_t::bit_cast(x.get() & 0x7fff); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
bool isnan(const bfloat16_t& x)
|
||||
{
|
||||
uint16_t xx = x.get();
|
||||
return (xx & 0x7FFF) > 0x7C00;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
bfloat16_t sqrt(bfloat16_t x)
|
||||
{
|
||||
return static_cast<bfloat16_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
|
||||
};
|
||||
|
||||
CK_TILE_DEVICE
|
||||
bfloat16_t exp(bfloat16_t x) { return static_cast<bfloat16_t>(__expf(static_cast<float>(x))); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
bfloat16_t exp2(bfloat16_t x) { return static_cast<bfloat16_t>(exp2f(static_cast<float>(x))); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
bfloat16_t log(bfloat16_t x) { return static_cast<bfloat16_t>(__logf(static_cast<float>(x))); };
|
||||
|
||||
using bf16_t = bfloat16_t;
|
||||
|
||||
} // namespace ck_tile
|
||||
735
include/ck_tile/core/numeric/float8.hpp
Normal file
735
include/ck_tile/core/numeric/float8.hpp
Normal file
@@ -0,0 +1,735 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/limits.hpp"
|
||||
#include "ck_tile/core/utility/random.hpp"
|
||||
#include "ck_tile/core/numeric/arithmetic.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include <stdint.h>
|
||||
#include <type_traits>
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// fp8 rounding modes
|
||||
// use standard for rounding to nearest, the faster one
|
||||
// use stochastic for stochastic rounding, helps to avoid error accumulation
|
||||
enum class fp8_rounding_mode
|
||||
{
|
||||
standard = 0,
|
||||
stochastic
|
||||
};
|
||||
|
||||
/*
|
||||
* ______________NANOO_________________ | ______________IEEE________________
|
||||
* e4m3 e5m2 | e4m3 e5m2
|
||||
* bias : 8 16 | 7 15
|
||||
* inf : 1.0000.000 1.00000.00 | N/A s.11111.00
|
||||
* Nan : 1.0000.000 1.00000.00 | s.1111.111 s.11111.{01, 10, 11}
|
||||
* zero : 0.0000.000 0.00000.00 | s.0000.000 s.00000.00
|
||||
* Max(norm) : s.1111.111 (240) s.11111.11(57344) | s.1111.110(448) s.11110.11(57344)
|
||||
* Max(snorm): s.0000.111 s.00000.11 | s.0000.111(448) s.00000.11(57344)
|
||||
* 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05
|
||||
* Min(norm) : s.0001.000 s.00001.00 | s.0001.000 s.00001.00
|
||||
* 2^-7(0.00078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05)
|
||||
* Min(snorm): s.0000.001 s.00000.01 | s.0000.001 s.00000.01
|
||||
* 2^-10(0.00097656) 2^-17(7.629395e-06)| 2^-9(0.001953125) 2^-16(1.52588e-05)
|
||||
*/
|
||||
|
||||
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float, constant<rounding> = {});
|
||||
|
||||
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float, constant<rounding> = {});
|
||||
|
||||
CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t);
|
||||
CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t);
|
||||
|
||||
struct alignas(1) float8_e4m3_t
|
||||
{
|
||||
static constexpr int exponent = 4;
|
||||
static constexpr int mantissa = 3;
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
static constexpr int bias = 1 << (exponent - 1); // NANOO
|
||||
#else
|
||||
static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
|
||||
#endif
|
||||
using raw_type = uint8_t;
|
||||
raw_type data;
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
static float8_e4m3_t bit_cast(raw_type x)
|
||||
{
|
||||
float8_e4m3_t y;
|
||||
y.data = x;
|
||||
return y;
|
||||
}
|
||||
|
||||
// constructor
|
||||
float8_e4m3_t() = default;
|
||||
|
||||
// construct from float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit float8_e4m3_t(const float& x) { data = float_to_fp8_raw(x); }
|
||||
|
||||
// construct from int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit float8_e4m3_t(const int& x) { data = float_to_fp8_raw(static_cast<float>(x)); }
|
||||
|
||||
// construct from unsigned int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit float8_e4m3_t(const unsigned int& x)
|
||||
{
|
||||
data = float_to_fp8_raw(static_cast<float>(x));
|
||||
}
|
||||
|
||||
// cast to float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit operator float() const { return fp8_to_float_raw(data); }
|
||||
|
||||
// cast to int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit operator int() const { return static_cast<int>(fp8_to_float_raw(data)); }
|
||||
|
||||
// internal access
|
||||
CK_TILE_HOST_DEVICE
|
||||
raw_type& get() { return data; }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
raw_type get() const { return data; }
|
||||
};
|
||||
|
||||
struct alignas(1) float8_e5m2_t
|
||||
{
|
||||
static constexpr int exponent = 5;
|
||||
static constexpr int mantissa = 2;
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
static constexpr int bias = 1 << (exponent - 1); // NANOO
|
||||
#else
|
||||
static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
|
||||
#endif
|
||||
using raw_type = uint8_t;
|
||||
raw_type data;
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
static float8_e5m2_t bit_cast(raw_type x)
|
||||
{
|
||||
float8_e5m2_t y;
|
||||
y.data = x;
|
||||
return y;
|
||||
}
|
||||
|
||||
// constructor
|
||||
float8_e5m2_t() = default;
|
||||
|
||||
// construct from float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit float8_e5m2_t(const float& x) { data = float_to_bf8_raw(x); }
|
||||
|
||||
// construct from int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit float8_e5m2_t(const int& x) { data = float_to_bf8_raw(static_cast<float>(x)); }
|
||||
|
||||
// construct from unsigned int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit float8_e5m2_t(const unsigned int& x)
|
||||
{
|
||||
data = float_to_bf8_raw(static_cast<float>(x));
|
||||
}
|
||||
|
||||
// cast to float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit operator float() const { return bf8_to_float_raw(data); }
|
||||
|
||||
// cast to int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit operator int() const { return static_cast<int>(bf8_to_float_raw(data)); }
|
||||
|
||||
// internal access
|
||||
CK_TILE_HOST_DEVICE
|
||||
raw_type& get() { return data; }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
raw_type get() const { return data; }
|
||||
};
|
||||
|
||||
// below is sw fp8 conversion, not utilizing hw instruction
|
||||
namespace impl {
|
||||
|
||||
template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
|
||||
CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng)
|
||||
{
|
||||
// fp8/bf8 exponent/mantissa layout
|
||||
constexpr int out_exp = numeric_utils<Y>::exp;
|
||||
constexpr int out_mant = numeric_utils<Y>::mant;
|
||||
|
||||
// original type exponent/mantissa layout
|
||||
constexpr int in_exp = numeric_utils<X>::exp;
|
||||
constexpr int in_mant = numeric_utils<X>::mant;
|
||||
|
||||
int exponent, bias;
|
||||
uint32_t head, mantissa, sign;
|
||||
// nan code is same for float and half
|
||||
constexpr Y nan_code = 0x80;
|
||||
constexpr uint32_t nan_mask = numeric_utils<X>::nan_mask;
|
||||
|
||||
// convert to bitwise
|
||||
using T_bitwise = typename numeric_utils<X>::bitwise_type;
|
||||
T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));
|
||||
|
||||
// unpack the input, depends on datatype
|
||||
head = x_bitwise & numeric_utils<X>::head_mask;
|
||||
mantissa = x_bitwise & numeric_utils<X>::mant_mask;
|
||||
exponent = (head >> in_mant) & numeric_utils<X>::exp_mask;
|
||||
sign = head >> (in_exp + in_mant);
|
||||
bias = numeric_utils<X>::bias;
|
||||
|
||||
uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant);
|
||||
uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1;
|
||||
constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2);
|
||||
|
||||
if constexpr(negative_zero_nan)
|
||||
{
|
||||
if((x_bitwise & nan_mask) == nan_mask)
|
||||
return nan_code;
|
||||
}
|
||||
else
|
||||
{
|
||||
if((x_bitwise & nan_mask) == nan_mask)
|
||||
return signed_inf + (mantissa != 0 ? 1 : 0);
|
||||
}
|
||||
|
||||
// check if x is 0.0
|
||||
if(x_bitwise == 0)
|
||||
return 0;
|
||||
|
||||
// First need to check if it is normal or denorm as there is a difference of implict 1
|
||||
// Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
|
||||
// The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for
|
||||
// RNE, no need to add rng. Then probably need to check whether there is carry and adjust
|
||||
// exponent and mantissa again3
|
||||
|
||||
// For IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of exponent bits
|
||||
const int out_bias = (1 << (out_exp - 1)) - 1 + (negative_zero_nan ? 1 : 0);
|
||||
const int out_denormal_act_exponent = 1 - out_bias; // actual exponent of f8 denormal
|
||||
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
|
||||
// out_exponent is the converted f8 exponent with bias encoding
|
||||
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
|
||||
// the difference needs to be adjusted and mantissa shifted
|
||||
int act_exponent, out_exponent, exponent_diff;
|
||||
|
||||
if(exponent == 0)
|
||||
{ // fp32/fp16 is in denormal.
|
||||
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16
|
||||
here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has
|
||||
exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in
|
||||
fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
|
||||
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal.
|
||||
In this case, the fp16 mantissa should be shift left by 1 */
|
||||
act_exponent = exponent - bias + 1;
|
||||
exponent_diff = out_denormal_act_exponent -
|
||||
act_exponent; // actual exponent is exponent-bias+1 as it is denormal
|
||||
}
|
||||
else
|
||||
{ // fp32/fp16 is normal with implicit 1
|
||||
act_exponent = exponent - bias;
|
||||
if(act_exponent <= out_denormal_act_exponent)
|
||||
{
|
||||
/* This is the case where fp32/fp16 is normal but it is in f8 denormal range.
|
||||
For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
|
||||
actual exponent is -7, it is actually larger due to the implict 1,
|
||||
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
|
||||
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
|
||||
exponent_diff = out_denormal_act_exponent - act_exponent;
|
||||
}
|
||||
else
|
||||
{ // both fp32/fp16 and f8 are in normal range
|
||||
exponent_diff =
|
||||
0; // exponent_diff=0 does not mean there is no difference for this case,
|
||||
// act_exponent could be larger. Just that it does not need shift mantissa
|
||||
}
|
||||
mantissa += (1 << in_mant); // Add the implicit 1 into mantissa
|
||||
}
|
||||
|
||||
bool midpoint = (mantissa & ((1 << (in_mant - out_mant + exponent_diff)) - 1)) ==
|
||||
(1 << (in_mant - out_mant + exponent_diff - 1));
|
||||
/* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we
|
||||
shift right as shift right could rip off some residual part and make something not midpoint look
|
||||
like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than
|
||||
midpoint, but after shift right by 4 bits, it would look like midpoint. */
|
||||
|
||||
if(exponent_diff > 0)
|
||||
mantissa >>= exponent_diff;
|
||||
else if(exponent_diff == -1)
|
||||
mantissa <<= -exponent_diff;
|
||||
bool implicit_one = mantissa & (1 << in_mant);
|
||||
// if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent
|
||||
out_exponent =
|
||||
(act_exponent + exponent_diff) /*actual f8 exponent*/ + out_bias - (implicit_one ? 0 : 1);
|
||||
|
||||
// Now we have the exponent and mantissa adjusted
|
||||
bool odd =
|
||||
mantissa &
|
||||
(1 << (in_mant - out_mant)); // if the least significant bit that is not truncated is 1
|
||||
mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;
|
||||
|
||||
// Now we deal with overflow
|
||||
if(out_exponent == 0)
|
||||
{
|
||||
if((1 << in_mant) & mantissa)
|
||||
{
|
||||
out_exponent = 1; // denormal overflow to become normal, promote exponent
|
||||
// No need to make 1 implicit now as it will be addressed later
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if((1 << (in_mant + 1)) & mantissa)
|
||||
{
|
||||
mantissa >>= 1;
|
||||
out_exponent++;
|
||||
// No need to make 1 implicit now as it will be addressed later
|
||||
}
|
||||
}
|
||||
|
||||
mantissa >>= (in_mant - out_mant);
|
||||
|
||||
if(out_exponent > max_exp)
|
||||
{
|
||||
if(clip)
|
||||
{
|
||||
mantissa = (1 << out_mant) - 1;
|
||||
out_exponent = max_exp;
|
||||
}
|
||||
else
|
||||
{
|
||||
return signed_inf;
|
||||
}
|
||||
}
|
||||
|
||||
// check if x is 0.0 or -0.0
|
||||
if(out_exponent == 0 && mantissa == 0)
|
||||
return negative_zero_nan ? 0 : (sign << (out_exp + out_mant));
|
||||
mantissa &= (1 << out_mant) - 1;
|
||||
return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa;
|
||||
}
|
||||
|
||||
template <typename X, typename Y, bool negative_zero_nan>
|
||||
CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
|
||||
{
|
||||
// fp8/bf8 exponent/mantissa layout
|
||||
constexpr int in_exp = numeric_utils<X>::exp;
|
||||
constexpr int in_mant = numeric_utils<X>::mant;
|
||||
|
||||
// resulting type exponent/mantissa layout
|
||||
constexpr int out_exp = numeric_utils<Y>::exp;
|
||||
constexpr int out_mant = numeric_utils<Y>::mant;
|
||||
|
||||
// prepare the codes
|
||||
constexpr X nan_code = 0x80;
|
||||
Y Inf, NegInf, NaN, Neg0;
|
||||
using T_bitwise = typename numeric_utils<Y>::bitwise_type;
|
||||
|
||||
constexpr T_bitwise Inf_bitwise = numeric_utils<Y>::Inf;
|
||||
constexpr T_bitwise NegInf_bitwise = numeric_utils<Y>::NegInf;
|
||||
constexpr T_bitwise NaN_bitwise = numeric_utils<Y>::NaN;
|
||||
constexpr T_bitwise Neg0_bitwise = numeric_utils<Y>::Neg0;
|
||||
|
||||
Inf = *(reinterpret_cast<const Y*>(&Inf_bitwise));
|
||||
NegInf = *(reinterpret_cast<const Y*>(&NegInf_bitwise));
|
||||
NaN = *(reinterpret_cast<const Y*>(&NaN_bitwise));
|
||||
Neg0 = *(reinterpret_cast<const Y*>(&Neg0_bitwise));
|
||||
|
||||
// check if x is 0.0
|
||||
if(x == 0)
|
||||
return static_cast<Y>(0);
|
||||
|
||||
// unpack the input
|
||||
uint32_t sign = x >> (in_exp + in_mant);
|
||||
uint32_t mantissa = x & ((1 << in_mant) - 1);
|
||||
int exponent = (x & 0x7F) >> in_mant;
|
||||
|
||||
constexpr int exp_low_cutoff =
|
||||
(1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
|
||||
T_bitwise retval;
|
||||
|
||||
if constexpr(negative_zero_nan)
|
||||
{
|
||||
if(x == nan_code)
|
||||
return NaN;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(x == nan_code)
|
||||
return Neg0;
|
||||
if(exponent == ((1 << in_exp) - 1))
|
||||
return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
|
||||
}
|
||||
|
||||
if((numeric_utils<Y>::mant == 10) && (numeric_utils<X>::mant == 2) && !negative_zero_nan)
|
||||
{
|
||||
retval = x;
|
||||
retval <<= 8;
|
||||
return *(reinterpret_cast<const Y*>(&retval));
|
||||
}
|
||||
|
||||
// subnormal input
|
||||
if(exponent == 0)
|
||||
{
|
||||
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
|
||||
int sh = 1 + clz(mantissa) - (32 - in_mant);
|
||||
mantissa <<= sh;
|
||||
exponent += 1 - sh;
|
||||
mantissa &= ((1 << in_mant) - 1);
|
||||
}
|
||||
exponent += exp_low_cutoff - 1;
|
||||
mantissa <<= out_mant - in_mant;
|
||||
|
||||
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
|
||||
if(exponent <= 0)
|
||||
{
|
||||
mantissa |= 1 << out_mant;
|
||||
mantissa >>= 1 - exponent;
|
||||
exponent = 0;
|
||||
}
|
||||
|
||||
retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
|
||||
return *(reinterpret_cast<const Y*>(&retval));
|
||||
}
|
||||
|
||||
template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
|
||||
CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng)
|
||||
{
|
||||
// check datatypes
|
||||
constexpr bool is_half = std::is_same<X, half_t>::value;
|
||||
constexpr bool is_float = std::is_same<X, float>::value;
|
||||
static_assert(is_half || is_float, "Only half and float can be casted.");
|
||||
|
||||
return run_cast_to_f8<X, Y, negative_zero_nan, clip, stoch>(x, rng);
|
||||
}
|
||||
|
||||
template <typename X, typename Y, bool negative_zero_nan>
|
||||
CK_TILE_HOST_DEVICE Y cast_from_f8(X x)
|
||||
{
|
||||
// check datatype
|
||||
constexpr bool is_half = std::is_same<Y, half_t>::value;
|
||||
constexpr bool is_float = std::is_same<Y, float>::value;
|
||||
static_assert(is_half || is_float, "only half and float are supported.");
|
||||
|
||||
return run_cast_from_f8<X, Y, negative_zero_nan>(x);
|
||||
}
|
||||
} // namespace impl
|
||||
|
||||
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_sr_raw(float x)
|
||||
{
|
||||
constexpr int seed = 42;
|
||||
uint32_t rng = prand_generator<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
float max_fp8 = 240.0f;
|
||||
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
|
||||
union
|
||||
{
|
||||
float fval;
|
||||
uint32_t i32val;
|
||||
uint8_t i8val[4]; // not endian independent
|
||||
} val;
|
||||
val.fval = x;
|
||||
uint32_t ival = 0;
|
||||
ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
|
||||
val.i32val = ival;
|
||||
return val.i8val[0]; // little endian
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr fp8_rounding_mode rm = fp8_rounding_mode::stochastic;
|
||||
return impl::
|
||||
cast_to_f8<float, fp8_t, negative_zero_nan, clip, (rm == fp8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_sr_raw(float x)
|
||||
{
|
||||
constexpr int seed = 42;
|
||||
uint32_t rng = prand_generator<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
union
|
||||
{
|
||||
float fval;
|
||||
uint32_t i32val;
|
||||
uint8_t i8val[4]; // not endian independent
|
||||
} val;
|
||||
val.fval = x;
|
||||
uint32_t ival = 0;
|
||||
ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
|
||||
val.i32val = ival;
|
||||
return val.i8val[0]; // little endian
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr fp8_rounding_mode rm = fp8_rounding_mode::stochastic;
|
||||
return impl::
|
||||
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == fp8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_rtn_raw(float x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
float max_fp8 = 240.0f;
|
||||
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
|
||||
union
|
||||
{
|
||||
float fval;
|
||||
uint32_t i32val;
|
||||
uint8_t i8val[4]; // not endian independent
|
||||
} val;
|
||||
val.fval = x;
|
||||
uint32_t ival = 0;
|
||||
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
|
||||
val.i32val = ival;
|
||||
return val.i8val[0];
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr fp8_rounding_mode rm = fp8_rounding_mode::standard;
|
||||
constexpr uint32_t rng = 0;
|
||||
return impl::
|
||||
cast_to_f8<float, fp8_t, negative_zero_nan, clip, (rm == fp8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_rtn_raw(float x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
union
|
||||
{
|
||||
float fval;
|
||||
uint32_t i32val;
|
||||
uint8_t i8val[4]; // not endian independent
|
||||
} val;
|
||||
val.fval = x;
|
||||
uint32_t ival = 0;
|
||||
ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
|
||||
val.i32val = ival;
|
||||
return val.i8val[0];
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr fp8_rounding_mode rm = fp8_rounding_mode::standard;
|
||||
constexpr uint32_t rng = 0;
|
||||
return impl::
|
||||
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == fp8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
#endif
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float x, constant<rounding> = {})
|
||||
{
|
||||
if constexpr (rounding == fp8_rounding_mode::standard) return float_to_fp8_rtn_raw(x);
|
||||
else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_fp8_sr_raw(x);
|
||||
else return uint8_t{0};
|
||||
}
|
||||
|
||||
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float x, constant<rounding> = {})
|
||||
{
|
||||
if constexpr (rounding == fp8_rounding_mode::standard) return float_to_bf8_rtn_raw(x);
|
||||
else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_bf8_sr_raw(x);
|
||||
else return uint8_t{0};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
float fval;
|
||||
uint32_t i32val = static_cast<uint32_t>(x);
|
||||
fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
|
||||
// asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
|
||||
return fval;
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return impl::cast_from_f8<fp8_t, float, negative_zero_nan>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
float fval;
|
||||
uint32_t i32val = static_cast<uint32_t>(x);
|
||||
fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
|
||||
// asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
|
||||
return fval;
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return impl::cast_from_f8<bf8_t, float, negative_zero_nan>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE float8_e4m3_t float_to_fp8(float x, constant<rounding> = {})
|
||||
{
|
||||
return float8_e4m3_t::bit_cast(float_to_fp8_raw(x, constant<rounding>{}));
|
||||
}
|
||||
|
||||
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE float8_e5m2_t float_to_bf8(float x, constant<rounding> = {})
|
||||
{
|
||||
return float8_e5m2_t::bit_cast(float_to_bf8_raw(x, constant<rounding>{}));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE float fp8_to_float(float8_e4m3_t x)
|
||||
{
|
||||
return fp8_to_float_raw(x.get());
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE float bf8_to_float(float8_e5m2_t x)
|
||||
{
|
||||
return bf8_to_float_raw(x.get());
|
||||
}
|
||||
|
||||
// clang-format on
|
||||
using fp8_t = float8_e4m3_t;
|
||||
using bf8_t = float8_e5m2_t;
|
||||
|
||||
template <typename T>
|
||||
struct numeric_utils;
|
||||
|
||||
template <>
|
||||
struct numeric_utils<fp8_t>
|
||||
{
|
||||
static constexpr int exp = fp8_t::exponent;
|
||||
static constexpr int mant = fp8_t::mantissa;
|
||||
static constexpr int bias = fp8_t::bias;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_utils<bf8_t>
|
||||
{
|
||||
static constexpr int exp = bf8_t::exponent;
|
||||
static constexpr int mant = bf8_t::mantissa;
|
||||
static constexpr int bias = bf8_t::bias;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct numeric_limits;
|
||||
|
||||
template <>
|
||||
struct numeric_limits<fp8_t>
|
||||
{
|
||||
// minimum finite value, or minimum positive normalized value for float
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t min() { return fp8_t::bit_cast(0x08); }
|
||||
|
||||
// minumum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t lowest() { return fp8_t::bit_cast(0xff); }
|
||||
|
||||
// maximum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t max() { return fp8_t::bit_cast(0x7f); }
|
||||
|
||||
// difference between 1.0 and next value representable by float
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t epsilon() { return fp8_t::bit_cast(0x20); }
|
||||
|
||||
// maximum rounding error
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t round_error() { return fp8_t(0.5f); }
|
||||
|
||||
// positive infinity value
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t infinity() { return fp8_t::bit_cast(0x80); }
|
||||
|
||||
// quiet NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t quiet_NaN() { return fp8_t::bit_cast(0x80); }
|
||||
|
||||
// signaling NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t signaling_NaN() { return fp8_t::bit_cast(0x80); }
|
||||
|
||||
// smallest positive subnormal value
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min() { return fp8_t::bit_cast(0x01); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_limits<bf8_t>
|
||||
{
|
||||
// minimum finite value, or minimum positive normalized value for float
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t min() { return bf8_t::bit_cast(0x04); }
|
||||
|
||||
// minumum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t lowest() { return bf8_t::bit_cast(0xff); }
|
||||
|
||||
// maximum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t max() { return bf8_t::bit_cast(0x7f); }
|
||||
|
||||
// difference between 1.0 and next value representable by float
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t epsilon() { return bf8_t::bit_cast(0x34); }
|
||||
|
||||
// maximum rounding error
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t round_error() { return bf8_t(0.5f); }
|
||||
|
||||
// positive infinity value
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t infinity() { return bf8_t::bit_cast(0x80); }
|
||||
|
||||
// quiet NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t quiet_NaN() { return bf8_t::bit_cast(0x80); }
|
||||
|
||||
// signaling NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t signaling_NaN() { return bf8_t::bit_cast(0x80); }
|
||||
|
||||
// smallest positive subnormal value
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min() { return bf8_t::bit_cast(0x01); }
|
||||
};
|
||||
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(fp8_t)
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(bf8_t)
|
||||
|
||||
// math
|
||||
CK_TILE_HOST_DEVICE
|
||||
fp8_t abs(const fp8_t& x) { return fp8_t::bit_cast(x.get() & 0x7f); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
bool isnan(const fp8_t& x)
|
||||
{
|
||||
uint8_t xx = x.get();
|
||||
return xx == 0x80; // TODO: NANOO
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
fp8_t sqrt(fp8_t x) { return static_cast<fp8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
fp8_t exp(fp8_t x) { return static_cast<fp8_t>(__expf(static_cast<float>(x))); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
fp8_t exp2(fp8_t x) { return static_cast<fp8_t>(exp2f(static_cast<float>(x))); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
fp8_t log(fp8_t x) { return static_cast<fp8_t>(__logf(static_cast<float>(x))); };
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
bf8_t abs(const bf8_t& x) { return bf8_t::bit_cast(x.get() & 0x7f); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
bool isnan(const bf8_t& x)
|
||||
{
|
||||
uint8_t xx = x.get();
|
||||
return xx == 0x80; // TODO: NANOO
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
bf8_t sqrt(bf8_t x) { return static_cast<bf8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
bf8_t exp(bf8_t x) { return static_cast<bf8_t>(__expf(static_cast<float>(x))); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
bf8_t exp2(bf8_t x) { return static_cast<bf8_t>(exp2f(static_cast<float>(x))); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
bf8_t log(bf8_t x) { return static_cast<bf8_t>(__logf(static_cast<float>(x))); };
|
||||
|
||||
} // namespace ck_tile
|
||||
278
include/ck_tile/core/numeric/half.hpp
Normal file
278
include/ck_tile/core/numeric/half.hpp
Normal file
@@ -0,0 +1,278 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float fp16_to_float_hip(const _Float16& x);
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
_Float16 float_to_fp16_hip(const float& x);
|
||||
|
||||
// HIP use _Float16 as interchangable data type for float16
|
||||
struct alignas(2) half_t
|
||||
{
|
||||
using raw_type = uint16_t;
|
||||
raw_type data;
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
static half_t bit_cast(raw_type x)
|
||||
{
|
||||
half_t y;
|
||||
y.data = x;
|
||||
return y;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
_Float16 to_fp16() const { return reinterpret_cast<const raw_type&>(data); }
|
||||
|
||||
// constructor
|
||||
half_t() = default;
|
||||
|
||||
// construct from HIP half
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit half_t(const _Float16& x) : data(reinterpret_cast<const raw_type&>(x)) {}
|
||||
|
||||
// construct from float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit half_t(const float& x) : half_t(float_to_fp16_hip(x)) {}
|
||||
|
||||
// construct from int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit half_t(const int& x) : half_t(__int2half_rn(x)) {}
|
||||
|
||||
// construct from unsigned int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit half_t(const unsigned int& x) : half_t(__uint2half_rn(x)) {}
|
||||
|
||||
// cast to float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit operator float() const { return fp16_to_float_hip(to_fp16()); }
|
||||
|
||||
// cast to int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit operator int() const { return static_cast<int>(fp16_to_float_hip(to_fp16())); }
|
||||
|
||||
// internal access
|
||||
CK_TILE_HOST_DEVICE
|
||||
raw_type& get() { return data; }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
raw_type get() const { return data; }
|
||||
};
|
||||
|
||||
// conversions
|
||||
CK_TILE_HOST_DEVICE
|
||||
float fp16_to_float_hip(const _Float16& x)
|
||||
{
|
||||
// return __half2float(x);
|
||||
return static_cast<float>(x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
_Float16 float_to_fp16_hip(const float& x)
|
||||
{
|
||||
// return __float2half(x);
|
||||
return static_cast<_Float16>(x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float fp16_to_float(const half_t& x) { return static_cast<float>(x); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
half_t float_to_fp16(const float& x) { return half_t{x}; }
|
||||
|
||||
// limits
|
||||
template <class T>
|
||||
struct numeric_limits;
|
||||
|
||||
template <>
|
||||
struct numeric_limits<half_t>
|
||||
{
|
||||
// minimum finite value, or minimum positive normalized value for float
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t min() { return half_t::bit_cast(0x0400); }
|
||||
|
||||
// minumum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t lowest() { return half_t::bit_cast(0xFBFF); }
|
||||
|
||||
// maximum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t max() { return half_t::bit_cast(0x7BFF); }
|
||||
|
||||
// difference between 1.0 and next value representable by float
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t epsilon() { return half_t::bit_cast(0x1800); }
|
||||
|
||||
// maximum rounding error
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t round_error() { return half_t(0.5f); }
|
||||
|
||||
// positive infinity value
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t infinity() { return half_t::bit_cast(0x7C00); }
|
||||
|
||||
// quiet NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t quiet_NaN() { return half_t::bit_cast(0x7FFF); }
|
||||
|
||||
// signaling NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t signaling_NaN() { return half_t::bit_cast(0x7FFF); }
|
||||
|
||||
// smallest positive subnormal value
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t denorm_min() { return half_t::bit_cast(0x0001); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct numeric_utils;
|
||||
|
||||
template <>
|
||||
struct numeric_utils<half_t>
|
||||
{
|
||||
static constexpr int exp = 5;
|
||||
static constexpr int mant = 10;
|
||||
static constexpr int bias = 15;
|
||||
static constexpr uint16_t nan_mask = 0x7C00;
|
||||
static constexpr uint16_t head_mask = 0xFC00;
|
||||
static constexpr uint16_t mant_mask = 0x3FF;
|
||||
static constexpr uint16_t exp_mask = 0x1F;
|
||||
static constexpr uint32_t Inf = 0x7C00;
|
||||
static constexpr uint32_t NegInf = 0xFC00;
|
||||
static constexpr uint32_t NaN = 0x7C01;
|
||||
static constexpr uint32_t Neg0 = 0x8000;
|
||||
using bitwise_type = uint16_t;
|
||||
};
|
||||
|
||||
// arithmetic
|
||||
CK_TILE_HOST_DEVICE
|
||||
bool operator==(const half_t& x, const half_t& y) { return __heq(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
bool operator!=(const half_t& x, const half_t& y) { return __hne(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
bool operator<(const half_t& x, const half_t& y) { return __hlt(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
bool operator<=(const half_t& x, const half_t& y) { return __hle(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
bool operator>(const half_t& x, const half_t& y) { return __hgt(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
bool operator>=(const half_t& x, const half_t& y) { return __hge(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
half_t operator+(const half_t& x, const half_t& y)
|
||||
{
|
||||
return half_t(__hadd(x.to_fp16(), y.to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
half_t operator-(const half_t& x) { return half_t(__hneg(x.to_fp16())); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
half_t operator-(const half_t& x, const half_t& y)
|
||||
{
|
||||
return half_t(__hsub(x.to_fp16(), y.to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
half_t operator*(const half_t& x, const half_t& y)
|
||||
{
|
||||
return half_t(__hmul(x.to_fp16(), y.to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
half_t operator/(const half_t& x, const half_t& y)
|
||||
{
|
||||
return half_t(__hdiv(x.to_fp16(), y.to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
half_t& operator+=(half_t& x, const half_t& y)
|
||||
{
|
||||
x = half_t(__hadd(x.to_fp16(), y.to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
half_t& operator-=(half_t& x, const half_t& y)
|
||||
{
|
||||
x = half_t(__hsub(x.to_fp16(), y.to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
half_t& operator*=(half_t& x, const half_t& y)
|
||||
{
|
||||
x = half_t(__hmul(x.to_fp16(), y.to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
half_t& operator/=(half_t& x, const half_t& y)
|
||||
{
|
||||
x = half_t(__hdiv(x.to_fp16(), y.to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
half_t& operator++(half_t& x)
|
||||
{
|
||||
x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
half_t& operator--(half_t& x)
|
||||
{
|
||||
x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
half_t operator++(half_t& x, int)
|
||||
{
|
||||
half_t y(x);
|
||||
x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
|
||||
return y;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
half_t operator--(half_t& x, int)
|
||||
{
|
||||
half_t y(x);
|
||||
x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
|
||||
return y;
|
||||
}
|
||||
|
||||
// math
|
||||
CK_TILE_HOST_DEVICE
|
||||
half_t abs(const half_t& x) { return half_t::bit_cast(x.get() & 0x7fff); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
bool isnan(const half_t& x)
|
||||
{
|
||||
uint16_t xx = x.get();
|
||||
return (xx & 0x7FFF) > 0x7C00;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t sqrt(half_t x)
|
||||
{
|
||||
return static_cast<half_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
|
||||
};
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t exp(half_t x) { return static_cast<half_t>(__expf(static_cast<float>(x))); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t exp2(half_t x) { return static_cast<half_t>(exp2f(static_cast<float>(x))); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t log(half_t x) { return static_cast<half_t>(__logf(static_cast<float>(x))); };
|
||||
|
||||
using fp16_t = half_t;
|
||||
|
||||
} // namespace ck_tile
|
||||
13
include/ck_tile/core/numeric/integer.hpp
Normal file
13
include/ck_tile/core/numeric/integer.hpp
Normal file
@@ -0,0 +1,13 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <stdint.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using index_t = int32_t;
|
||||
using long_index_t = int64_t;
|
||||
using int8_t = int8_t;
|
||||
|
||||
} // namespace ck_tile
|
||||
82
include/ck_tile/core/numeric/integral_constant.hpp
Normal file
82
include/ck_tile/core/numeric/integral_constant.hpp
Normal file
@@ -0,0 +1,82 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <auto v>
|
||||
struct constant
|
||||
{
|
||||
using value_type = decltype(v);
|
||||
using type = constant; // using injected-class-name
|
||||
static constexpr value_type value = v;
|
||||
constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; }
|
||||
constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; }
|
||||
};
|
||||
|
||||
template <typename T, T v>
|
||||
struct integral_constant : constant<v>
|
||||
{
|
||||
using value_type = T;
|
||||
using type = integral_constant; // using injected-class-name
|
||||
static constexpr T value = v;
|
||||
// constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; }
|
||||
// constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; } //
|
||||
};
|
||||
|
||||
template <index_t v>
|
||||
using number = constant<v>;
|
||||
|
||||
template <long_index_t v>
|
||||
using long_number = integral_constant<long_index_t, v>;
|
||||
|
||||
template <bool b>
|
||||
using bool_constant = constant<b>;
|
||||
|
||||
#define CK_TILE_LEFT_UNARY_OP(OP) \
|
||||
template <auto x> \
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator OP(constant<x>) \
|
||||
{ \
|
||||
return constant<(OP x)>{}; \
|
||||
}
|
||||
|
||||
#define CK_TILE_BINARY_OP(OP) \
|
||||
template <auto x, auto y> \
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator OP(constant<x>, constant<y>) \
|
||||
{ \
|
||||
return constant<(x OP y)>{}; \
|
||||
}
|
||||
|
||||
CK_TILE_LEFT_UNARY_OP(+)
|
||||
CK_TILE_LEFT_UNARY_OP(-)
|
||||
CK_TILE_LEFT_UNARY_OP(~)
|
||||
CK_TILE_LEFT_UNARY_OP(!)
|
||||
CK_TILE_LEFT_UNARY_OP(*)
|
||||
|
||||
CK_TILE_BINARY_OP(+)
|
||||
CK_TILE_BINARY_OP(-)
|
||||
CK_TILE_BINARY_OP(*)
|
||||
CK_TILE_BINARY_OP(/)
|
||||
CK_TILE_BINARY_OP(%)
|
||||
CK_TILE_BINARY_OP(&)
|
||||
CK_TILE_BINARY_OP(|)
|
||||
CK_TILE_BINARY_OP(^)
|
||||
CK_TILE_BINARY_OP(<<)
|
||||
CK_TILE_BINARY_OP(>>)
|
||||
CK_TILE_BINARY_OP(&&)
|
||||
CK_TILE_BINARY_OP(||)
|
||||
CK_TILE_BINARY_OP(==)
|
||||
CK_TILE_BINARY_OP(!=)
|
||||
CK_TILE_BINARY_OP(>)
|
||||
CK_TILE_BINARY_OP(<)
|
||||
CK_TILE_BINARY_OP(>=)
|
||||
CK_TILE_BINARY_OP(<=)
|
||||
|
||||
#undef CK_TILE_LEFT_UNARY_OP
|
||||
#undef CK_TILE_BINARY_OP
|
||||
|
||||
} // namespace ck_tile
|
||||
309
include/ck_tile/core/numeric/math.hpp
Normal file
309
include/ck_tile/core/numeric/math.hpp
Normal file
@@ -0,0 +1,309 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include <type_traits>
|
||||
#include <stdint.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename T, T s>
|
||||
struct scales
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(T a) const { return s * a; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct plus
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a + b; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct minus
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a - b; }
|
||||
};
|
||||
|
||||
struct multiplies
|
||||
{
|
||||
template <typename A, typename B>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const A& a, const B& b) const
|
||||
{
|
||||
return a * b;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct maximize
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a >= b ? a : b; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct minimize
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a <= b ? a : b; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct integer_divide_ceiler
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const
|
||||
{
|
||||
static_assert(std::is_same<T, index_t>{} || std::is_same<T, int>{}, "wrong type");
|
||||
return (a + b - number<1>{}) / b;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename X, typename Y>
|
||||
CK_TILE_HOST_DEVICE constexpr auto integer_divide_floor(X x, Y y)
|
||||
{
|
||||
return x / y;
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
|
||||
{
|
||||
return (x + y - number<1>{}) / y;
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
CK_TILE_HOST_DEVICE constexpr auto integer_least_multiple(X x, Y y)
|
||||
{
|
||||
return y * integer_divide_ceil(x, y);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr T max(T x)
|
||||
{
|
||||
return x;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr T max(T x, T y)
|
||||
{
|
||||
return x > y ? x : y;
|
||||
}
|
||||
|
||||
template <index_t X>
|
||||
CK_TILE_HOST_DEVICE constexpr index_t max(number<X>, index_t y)
|
||||
{
|
||||
return X > y ? X : y;
|
||||
}
|
||||
|
||||
template <index_t Y>
|
||||
CK_TILE_HOST_DEVICE constexpr index_t max(index_t x, number<Y>)
|
||||
{
|
||||
return x > Y ? x : Y;
|
||||
}
|
||||
|
||||
template <typename X, typename... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto max(X x, Ys... ys)
|
||||
{
|
||||
static_assert(sizeof...(Ys) > 0, "not enough argument");
|
||||
return max(x, max(ys...));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr T min(T x)
|
||||
{
|
||||
return x;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr T min(T x, T y)
|
||||
{
|
||||
return x < y ? x : y;
|
||||
}
|
||||
|
||||
template <index_t X>
|
||||
CK_TILE_HOST_DEVICE constexpr index_t min(number<X>, index_t y)
|
||||
{
|
||||
return X < y ? X : y;
|
||||
}
|
||||
|
||||
template <index_t Y>
|
||||
CK_TILE_HOST_DEVICE constexpr index_t min(index_t x, number<Y>)
|
||||
{
|
||||
return x < Y ? x : Y;
|
||||
}
|
||||
|
||||
template <typename X, typename... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto min(X x, Ys... ys)
|
||||
{
|
||||
static_assert(sizeof...(Ys) > 0, "not enough argument");
|
||||
return min(x, min(ys...));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr T clamp(const T& x, const T& lowerbound, const T& upperbound)
|
||||
{
|
||||
return min(max(x, lowerbound), upperbound);
|
||||
}
|
||||
|
||||
// greatest common divisor, aka highest common factor
|
||||
CK_TILE_HOST_DEVICE constexpr index_t gcd(index_t x, index_t y)
|
||||
{
|
||||
if(x < 0)
|
||||
{
|
||||
return gcd(-x, y);
|
||||
}
|
||||
else if(y < 0)
|
||||
{
|
||||
return gcd(x, -y);
|
||||
}
|
||||
else if(x == y || x == 0)
|
||||
{
|
||||
return y;
|
||||
}
|
||||
else if(y == 0)
|
||||
{
|
||||
return x;
|
||||
}
|
||||
else if(x > y)
|
||||
{
|
||||
return gcd(x % y, y);
|
||||
}
|
||||
else
|
||||
{
|
||||
return gcd(x, y % x);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t X, index_t Y>
|
||||
CK_TILE_HOST_DEVICE constexpr auto gcd(number<X>, number<Y>)
|
||||
{
|
||||
constexpr auto r = gcd(X, Y);
|
||||
|
||||
return number<r>{};
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename... Ys,
|
||||
typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto gcd(X x, Ys... ys)
|
||||
{
|
||||
return gcd(x, gcd(ys...));
|
||||
}
|
||||
|
||||
// least common multiple
|
||||
template <typename X, typename Y>
|
||||
CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Y y)
|
||||
{
|
||||
return (x * y) / gcd(x, y);
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename... Ys,
|
||||
typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Ys... ys)
|
||||
{
|
||||
return lcm(x, lcm(ys...));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct equal
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr bool operator()(T x, T y) const { return x == y; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct less
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr bool operator()(T x, T y) const { return x < y; }
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr int32_t next_power_of_two(int32_t x)
|
||||
{
|
||||
// TODO: x need to be 2 ~ 0x7fffffff. 0, 1, or larger than 0x7fffffff will compile fail
|
||||
return 1 << (32 - __builtin_clz(x - 1));
|
||||
}
|
||||
|
||||
template <index_t X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto next_power_of_two()
|
||||
{
|
||||
constexpr index_t y = next_power_of_two(X);
|
||||
return number<y>{};
|
||||
}
|
||||
|
||||
template <index_t X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto next_power_of_two(number<X>)
|
||||
{
|
||||
constexpr index_t y = next_power_of_two(X);
|
||||
return number<y>{};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr int32_t integer_log2_floor(int32_t x)
|
||||
{
|
||||
// TODO: x need to be 1 ~ 0x7fffffff
|
||||
// __builtin_clz will produce unexpected result if x is 0;
|
||||
return 31 - __builtin_clz(x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr bool is_power_of_two_integer(int32_t x)
|
||||
{
|
||||
// TODO: x need to be 1 ~ 0x7fffffff
|
||||
return x == (1 << integer_log2_floor(x));
|
||||
}
|
||||
|
||||
#ifndef C_LOG2E
|
||||
#define C_LOG2E 1.44269504088896340736 // log2(e)
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
struct log2e;
|
||||
|
||||
template <>
|
||||
struct log2e<double>
|
||||
{
|
||||
static constexpr double value = C_LOG2E;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct log2e<float>
|
||||
{
|
||||
static constexpr float value = C_LOG2E;
|
||||
};
|
||||
|
||||
template <typename T = double>
|
||||
inline constexpr T log2e_v = log2e<T>::value;
|
||||
|
||||
// math
|
||||
CK_TILE_HOST_DEVICE
|
||||
float abs(const float& x)
|
||||
{
|
||||
union
|
||||
{
|
||||
float f32;
|
||||
uint32_t u32;
|
||||
} y;
|
||||
y.f32 = x;
|
||||
y.u32 = y.u32 & 0x7fffffff;
|
||||
return y.f32;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
bool isnan(const float& x)
|
||||
{
|
||||
uint32_t xx = reinterpret_cast<const uint32_t&>(x);
|
||||
return (xx & 0x7fffffff) > 0x7F800000;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
float exp(float x) { return __expf(x); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
float exp2(float x) { return exp2f(x); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
float log(float x) { return __logf(x); };
|
||||
|
||||
} // namespace ck_tile
|
||||
45
include/ck_tile/core/numeric/type_convert.hpp
Normal file
45
include/ck_tile/core/numeric/type_convert.hpp
Normal file
@@ -0,0 +1,45 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
#if 0
|
||||
// Convert X to Y, both X and Y are non-const data types.
|
||||
template <typename Y,
|
||||
typename X,
|
||||
std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
|
||||
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
|
||||
{
|
||||
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
|
||||
|
||||
return static_cast<Y>(x);
|
||||
}
|
||||
|
||||
// TODO: const version never called, we may never need
|
||||
// Convert X to Y, either X or Y is a const data type.
|
||||
template <typename Y,
|
||||
typename X,
|
||||
std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
|
||||
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
|
||||
{
|
||||
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
|
||||
|
||||
using NonConstY = std::remove_const_t<Y>;
|
||||
using NonConstX = std::remove_const_t<X>;
|
||||
return static_cast<Y>(type_convert<NonConstY, NonConstX>(x));
|
||||
}
|
||||
#else
|
||||
// compatible way to call conversion operator and constructor of each custom data type
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
|
||||
{
|
||||
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
|
||||
|
||||
return static_cast<Y>(x);
|
||||
}
|
||||
#endif
|
||||
} // namespace ck_tile
|
||||
304
include/ck_tile/core/numeric/vector_type.hpp
Normal file
304
include/ck_tile/core/numeric/vector_type.hpp
Normal file
@@ -0,0 +1,304 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// TODO: the whole content of this file should consider deprecated!
|
||||
template <typename T_, index_t N_>
|
||||
struct vector_type
|
||||
{
|
||||
static constexpr index_t N = N_;
|
||||
using value_type = T_;
|
||||
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
|
||||
|
||||
CK_HOST_DEVICE constexpr vector_type()
|
||||
{
|
||||
for(auto i = 0; i < N; i++)
|
||||
data[i] = static_cast<value_type>(0);
|
||||
}
|
||||
CK_HOST_DEVICE constexpr vector_type(type v)
|
||||
{
|
||||
auto& r = reinterpret_cast<const array<value_type, N>&>(v);
|
||||
for(auto i = 0; i < N; i++)
|
||||
data[i] = r.get(i);
|
||||
}
|
||||
|
||||
value_type data[N];
|
||||
CK_HOST_DEVICE static constexpr auto size() { return N; }
|
||||
CK_HOST_DEVICE auto& get() { return data; }
|
||||
CK_HOST_DEVICE const auto& get() const { return data; }
|
||||
CK_HOST_DEVICE auto& get(index_t i) { return data[i]; }
|
||||
CK_HOST_DEVICE const auto& get(index_t i) const { return data[i]; }
|
||||
|
||||
template <index_t I>
|
||||
CK_HOST_DEVICE auto& operator[](number<I>)
|
||||
{
|
||||
return data[I];
|
||||
}
|
||||
template <index_t I>
|
||||
CK_HOST_DEVICE const auto& operator[](number<I>) const
|
||||
{
|
||||
return data[I];
|
||||
}
|
||||
template <index_t I>
|
||||
CK_HOST_DEVICE auto& operator()(number<I>)
|
||||
{
|
||||
return data[I];
|
||||
}
|
||||
|
||||
CK_HOST_DEVICE auto& at(index_t i) { return data[i]; }
|
||||
CK_HOST_DEVICE const auto& at(index_t i) const { return data[i]; }
|
||||
template <index_t I>
|
||||
CK_HOST_DEVICE auto& at()
|
||||
{
|
||||
return data[I];
|
||||
}
|
||||
template <index_t I>
|
||||
CK_HOST_DEVICE const auto& at() const
|
||||
{
|
||||
return data[I];
|
||||
}
|
||||
template <index_t I>
|
||||
CK_HOST_DEVICE auto& at(number<I>)
|
||||
{
|
||||
return data[I];
|
||||
}
|
||||
template <index_t I>
|
||||
CK_HOST_DEVICE const auto& at(number<I>) const
|
||||
{
|
||||
return data[I];
|
||||
}
|
||||
|
||||
#define _VT_COMMON_AS() \
|
||||
static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
|
||||
constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
|
||||
|
||||
template <typename Tx>
|
||||
CK_HOST_DEVICE auto& get_as()
|
||||
{
|
||||
_VT_COMMON_AS();
|
||||
return reinterpret_cast<array<Tx, vx>&>(data);
|
||||
}
|
||||
template <typename Tx>
|
||||
CK_HOST_DEVICE const auto& get_as() const
|
||||
{
|
||||
_VT_COMMON_AS();
|
||||
return reinterpret_cast<const array<Tx, vx>&>(data);
|
||||
}
|
||||
template <typename Tx>
|
||||
CK_HOST_DEVICE auto& get_as(index_t i)
|
||||
{
|
||||
_VT_COMMON_AS();
|
||||
return reinterpret_cast<array<Tx, vx>&>(data).get(i);
|
||||
}
|
||||
template <typename Tx>
|
||||
CK_HOST_DEVICE const auto& get_as(index_t i) const
|
||||
{
|
||||
_VT_COMMON_AS();
|
||||
return reinterpret_cast<const array<Tx, vx>&>(data).get(i);
|
||||
}
|
||||
#undef _VT_COMMON_AS
|
||||
};
|
||||
|
||||
template <typename T, index_t N>
|
||||
struct vector_type_maker
|
||||
{
|
||||
using type = vector_type<T, N>;
|
||||
};
|
||||
|
||||
template <typename T, index_t N0, index_t N1>
|
||||
struct vector_type_maker<T __attribute__((ext_vector_type(N1))), N0>
|
||||
{
|
||||
using type = vector_type<T, N0 * N1>;
|
||||
};
|
||||
|
||||
template <typename T, index_t N0, index_t N1>
|
||||
struct vector_type_maker<vector_type<T, N1>, N0>
|
||||
{
|
||||
using type = vector_type<T, N0 * N1>;
|
||||
};
|
||||
|
||||
template <typename T, index_t N>
|
||||
using vector_type_maker_t = typename vector_type_maker<T, N>::type;
|
||||
|
||||
template <typename T, index_t N>
|
||||
CK_HOST_DEVICE constexpr auto make_vector_type(number<N>)
|
||||
{
|
||||
return typename vector_type_maker<T, N>::type{};
|
||||
}
|
||||
|
||||
// scalar_type
|
||||
template <typename TV>
|
||||
struct scalar_type;
|
||||
|
||||
// is_scalar_type
|
||||
template <typename TV>
|
||||
struct is_scalar_type
|
||||
{
|
||||
static constexpr bool value = (scalar_type<remove_cvref_t<TV>>::vector_size == 1);
|
||||
};
|
||||
|
||||
// has_same_scalar_type
|
||||
template <typename X, typename Y>
|
||||
using has_same_scalar_type = is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<Y>>::type>;
|
||||
|
||||
template <typename T, index_t N>
|
||||
struct scalar_type<T __attribute__((ext_vector_type(N)))>
|
||||
{
|
||||
using type = T;
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
template <typename T, index_t N>
|
||||
struct scalar_type<vector_type<T, N>>
|
||||
{
|
||||
using type = T;
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
//
|
||||
template <>
|
||||
struct scalar_type<double>
|
||||
{
|
||||
using type = double;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<float>
|
||||
{
|
||||
using type = float;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<half_t>
|
||||
{
|
||||
using type = half_t;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<bhalf_t>
|
||||
{
|
||||
using type = bhalf_t;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<int64_t>
|
||||
{
|
||||
using type = int64_t;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<int32_t>
|
||||
{
|
||||
using type = int32_t;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<int8_t>
|
||||
{
|
||||
using type = int8_t;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
template <>
|
||||
struct scalar_type<int4_t>
|
||||
{
|
||||
using type = int4_t;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <>
|
||||
struct scalar_type<fp8_t>
|
||||
{
|
||||
using type = fp8_t;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<bf8_t>
|
||||
{
|
||||
using type = bf8_t;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
// below are some pre-defines of ext_vector_type
|
||||
// fp64
|
||||
using double2_t = typename vector_type<double, 2>::type;
|
||||
using double4_t = typename vector_type<double, 4>::type;
|
||||
|
||||
// fp32
|
||||
using float2_t = typename vector_type<float, 2>::type;
|
||||
using float4_t = typename vector_type<float, 4>::type;
|
||||
using float8_t = typename vector_type<float, 8>::type;
|
||||
using float16_t = typename vector_type<float, 16>::type;
|
||||
using float32_t = typename vector_type<float, 32>::type;
|
||||
using float64_t = typename vector_type<float, 64>::type;
|
||||
|
||||
// fp16
|
||||
using half2_t = typename vector_type<half_t, 2>::type;
|
||||
using half4_t = typename vector_type<half_t, 4>::type;
|
||||
using half8_t = typename vector_type<half_t, 8>::type;
|
||||
using half16_t = typename vector_type<half_t, 16>::type;
|
||||
using half32_t = typename vector_type<half_t, 32>::type;
|
||||
using half64_t = typename vector_type<half_t, 64>::type;
|
||||
|
||||
// bfp16
|
||||
using bhalf2_t = typename vector_type<bhalf_t, 2>::type;
|
||||
using bhalf4_t = typename vector_type<bhalf_t, 4>::type;
|
||||
using bhalf8_t = typename vector_type<bhalf_t, 8>::type;
|
||||
using bhalf16_t = typename vector_type<bhalf_t, 16>::type;
|
||||
using bhalf32_t = typename vector_type<bhalf_t, 32>::type;
|
||||
using bhalf64_t = typename vector_type<bhalf_t, 64>::type;
|
||||
|
||||
// i32
|
||||
using int32x2_t = typename vector_type<int32_t, 2>::type;
|
||||
using int32x4_t = typename vector_type<int32_t, 4>::type;
|
||||
using int32x8_t = typename vector_type<int32_t, 8>::type;
|
||||
using int32x16_t = typename vector_type<int32_t, 16>::type;
|
||||
using int32x32_t = typename vector_type<int32_t, 32>::type;
|
||||
using int32x64_t = typename vector_type<int32_t, 64>::type;
|
||||
|
||||
// i8
|
||||
using int8x2_t = typename vector_type<int8_t, 2>::type;
|
||||
using int8x4_t = typename vector_type<int8_t, 4>::type;
|
||||
using int8x8_t = typename vector_type<int8_t, 8>::type;
|
||||
using int8x16_t = typename vector_type<int8_t, 16>::type;
|
||||
using int8x32_t = typename vector_type<int8_t, 32>::type;
|
||||
using int8x64_t = typename vector_type<int8_t, 64>::type;
|
||||
|
||||
// f8
|
||||
using fp8x2_t = typename vector_type<fp8_t, 2>::type;
|
||||
using fp8x4_t = typename vector_type<fp8_t, 4>::type;
|
||||
using fp8x8_t = typename vector_type<fp8_t, 8>::type;
|
||||
using fp8x16_t = typename vector_type<fp8_t, 16>::type;
|
||||
using fp8x32_t = typename vector_type<fp8_t, 32>::type;
|
||||
using fp8x64_t = typename vector_type<fp8_t, 64>::type;
|
||||
|
||||
// bf8
|
||||
using bf8x2_t = typename vector_type<bf8_t, 2>::type;
|
||||
using bf8x4_t = typename vector_type<bf8_t, 4>::type;
|
||||
using bf8x8_t = typename vector_type<bf8_t, 8>::type;
|
||||
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
|
||||
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
|
||||
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user