Reorganize project folders (#6)

This commit is contained in:
Joseph Macaranas
2025-04-30 13:46:39 -04:00
committed by GitHub
commit 1eb2e57380
3952 changed files with 654944 additions and 0 deletions

View File

@@ -0,0 +1,423 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include <stdint.h>
#pragma once
namespace ck_tile {
enum class bf16_rounding_mode
{
standard = 0, // rtn
truncate_with_nan,
truncate,
standard_asm,
rta_asm, // round to nearest away
};
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<rounding> = {});
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr uint16_t double_to_bf16_raw(double f, constant<rounding> = {});
CK_TILE_HOST_DEVICE
constexpr float bf16_to_float_raw(uint16_t x);
CK_TILE_HOST_DEVICE
constexpr double bf16_to_double_raw(uint16_t x);
#if CK_TILE_USE_CUSTOM_DATA_TYPE
// HIP use __hip_bfloat16 as struct
struct alignas(2) bfloat16_t
{
using raw_type = uint16_t;
raw_type data;
CK_TILE_HOST_DEVICE
static constexpr bfloat16_t bit_cast(raw_type x)
{
bfloat16_t y;
y.data = x;
return y;
}
// constructor
constexpr bfloat16_t() : data() {}
// construct from float
CK_TILE_HOST_DEVICE
explicit constexpr bfloat16_t(const float& x) : data(float_to_bf16_raw(x)) {}
// construct from double
CK_TILE_HOST_DEVICE
explicit constexpr bfloat16_t(const double& x) : data(double_to_bf16_raw(x)) {}
// construct from int
CK_TILE_HOST_DEVICE
explicit constexpr bfloat16_t(const int& x) : data(float_to_bf16_raw(static_cast<float>(x))) {}
// construct from unsigned int
CK_TILE_HOST_DEVICE
explicit constexpr bfloat16_t(const unsigned int& x)
: data(float_to_bf16_raw(static_cast<float>(x)))
{
}
// cast to float
CK_TILE_HOST_DEVICE
explicit constexpr operator float() const { return bf16_to_float_raw(data); }
// cast to float
CK_TILE_HOST_DEVICE
explicit constexpr operator double() const { return bf16_to_double_raw(data); }
// cast to int
CK_TILE_HOST_DEVICE
explicit constexpr operator int() const { return static_cast<int>(bf16_to_float_raw(data)); }
// internal access
CK_TILE_HOST_DEVICE
constexpr raw_type& get() { return data; }
CK_TILE_HOST_DEVICE
constexpr raw_type get() const { return data; }
};
template <typename>
struct native_t;
template <>
struct native_t<bfloat16_t>
{
using type = ushort;
};
using bf16_t = bfloat16_t;
using bf16_raw_t = typename bf16_t::raw_type;
#else
using bfloat16_t = ushort;
using bf16_t = bfloat16_t;
using bf16_raw_t = uint16_t;
#endif
// round to nearest
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_rtn_raw(float f)
{
union
{
float fp32;
uint32_t int32;
} u = {f};
if(~u.int32 & 0x7f800000)
{
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// least significant bits of the float mantissa are greater than 0x8000,
// or if they are equal to 0x8000 and the least significant bit of the
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// the exponent is incremented by one, which is the next higher FP value
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even
}
else if(u.int32 & 0xffff)
{
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// bit being 1. Signaling NaN is indicated by the most significant
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bloat16's mantissa bits are all 0.
u.int32 |= 0x10000; // Preserve signaling NaN
}
return uint16_t(u.int32 >> 16);
}
CK_TILE_HOST
constexpr uint16_t float_to_bf16_rtn_asm(float f) { return float_to_bf16_rtn_raw(f); }
CK_TILE_DEVICE
uint16_t float_to_bf16_rtn_asm(float f)
{
union
{
float fp32;
uint32_t int32;
} u = {f};
static constexpr uint32_t FP32_NAN = 0x7fff0000;
static constexpr uint32_t ROUND_BIAS_FOR_BF16 = 0x7fff;
using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
uint32x2_t check_nan;
uint32_t tmp;
asm volatile("\n \
v_cmp_u_f32 %0, %2, %2 \n \
v_bfe_u32 %1, %2, 16, 1 \n \
v_add3_u32 %1, %2, %1, %3 \n \
v_cndmask_b32 %2, %1, %4, %0 \n \
v_lshrrev_b32 %2, 16, %2 \n \
"
: "=s"(check_nan), "+v"(tmp), "+v"(u.fp32)
: "v"(ROUND_BIAS_FOR_BF16), "v"(FP32_NAN));
return uint16_t(u.int32);
}
// TODO: do we need this on host?
CK_TILE_HOST
uint16_t float_to_bf16_rta_asm(float f) { return float_to_bf16_rtn_raw(f); }
CK_TILE_DEVICE
uint16_t float_to_bf16_rta_asm(float f)
{
union
{
float fp32;
struct
{
uint16_t lo;
uint16_t hi;
};
} u = {f};
const uint32_t low_nan = 0x7fff;
const uint32_t hi_nan = 0x7fff0000;
using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
uint32x2_t check_nan;
asm volatile("v_cmp_u_f32 %[s_cnan], %[v_x], %[v_x] \n"
"v_add3_u32 %[v_x], %[v_x], %[v_blo], 1 \n"
"v_cndmask_b32 %[v_x], %[v_x], %[v_bhi], %[s_cnan]"
: [s_cnan] "+s"(check_nan), [v_x] "+v"(u.fp32)
: [v_blo] "v"(low_nan), [v_bhi] "v"(hi_nan));
// Note: in above code snipet, we use hi 16 bit
return u.hi;
}
// Truncate instead of rounding, preserving SNaN
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
{
union
{
float fp32;
uint32_t int32;
} u = {f};
return uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff));
}
// Fast truncate instead of rounding, RTZ
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_raw(float f)
{
union
{
float fp32;
uint32_t int32;
} u = {f};
return uint16_t(u.int32 >> 16);
}
template <bf16_rounding_mode rounding>
CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<rounding>)
{
if constexpr(rounding == bf16_rounding_mode::standard)
return float_to_bf16_rtn_raw(f);
else if constexpr(rounding == bf16_rounding_mode::standard_asm)
return float_to_bf16_rtn_asm(f);
else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
return float_to_bf16_truc_nan_raw(f);
else if constexpr(rounding == bf16_rounding_mode::rta_asm)
return float_to_bf16_rta_asm(f);
else
return float_to_bf16_truc_raw(f);
}
template <bf16_rounding_mode rounding>
CK_TILE_HOST_DEVICE constexpr uint16_t double_to_bf16_raw(double f, constant<rounding>)
{
return float_to_bf16_raw(static_cast<float>(f), constant<rounding>{});
}
CK_TILE_HOST_DEVICE
constexpr float bf16_to_float_raw(uint16_t x)
{
union
{
uint32_t int32;
float fp32;
} u = {uint32_t(x) << 16};
return u.fp32;
}
CK_TILE_HOST_DEVICE
constexpr double bf16_to_double_raw(uint16_t x)
{
return static_cast<double>(bf16_to_float_raw(x));
}
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant<rounding> = {})
{
return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));
}
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr bfloat16_t double_to_bf16(double f, constant<rounding> = {})
{
return bit_cast<bfloat16_t>(double_to_bf16_raw(f, constant<rounding>{}));
}
CK_TILE_HOST_DEVICE
constexpr float bf16_to_float(bfloat16_t x) { return bf16_to_float_raw(bit_cast<uint16_t>(x)); }
CK_TILE_HOST_DEVICE
constexpr double bf16_to_double(bfloat16_t x) { return static_cast<double>(bf16_to_float_raw(x)); }
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE bfloat16_t constexpr fp16_to_bf16(half_t f, constant<rounding> = {})
{
return bit_cast<bfloat16_t>(float_to_bf16_raw(static_cast<float>(f), constant<rounding>{}));
}
CK_TILE_HOST_DEVICE
constexpr half_t bf16_to_fp16(bfloat16_t x) { return static_cast<fp16_t>(static_cast<float>(x)); }
template <class T>
struct numeric;
template <>
struct numeric<bfloat16_t>
{
// minimum finite value, or minimum positive normalized value for float
CK_TILE_HOST_DEVICE static constexpr bfloat16_t min()
{
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0080));
}
// minumum finite value
CK_TILE_HOST_DEVICE static constexpr bfloat16_t lowest()
{
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0xff7f));
}
// maximum finite value
CK_TILE_HOST_DEVICE static constexpr bfloat16_t max()
{
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f7f));
}
// difference between 1.0 and next value representable by float
CK_TILE_HOST_DEVICE static constexpr bfloat16_t epsilon()
{
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x1000));
}
// maximum rounding error
// maximum rounding error
// bin : f edcba 9876543210
// bits: s eeeeeeee mmmmmmm
// 0 01111110 0000000 (0.5)
//
CK_TILE_HOST_DEVICE static constexpr bfloat16_t round_error()
{
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x3f00));
}
// positive infinity value
CK_TILE_HOST_DEVICE static constexpr bfloat16_t infinity()
{
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f80));
}
// quiet NaN
CK_TILE_HOST_DEVICE static constexpr bfloat16_t quiet_NaN()
{
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FFF));
}
// signaling NaN
CK_TILE_HOST_DEVICE static constexpr bfloat16_t signaling_NaN()
{
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FFF));
}
// smallest positive subnormal value
CK_TILE_HOST_DEVICE static constexpr bfloat16_t denorm_min()
{
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0001));
}
CK_TILE_HOST_DEVICE static constexpr bfloat16_t zero()
{
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0));
}
};
template <>
struct numeric_traits<bfloat16_t>
{
static constexpr int exp = 8;
static constexpr int mant = 7;
static constexpr int PackedSize = 1;
};
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bfloat16_t)
#endif
// math
CK_TILE_HOST_DEVICE
bfloat16_t abs(const bfloat16_t& x)
{
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(bit_cast<bf16_raw_t>(x) & 0x7fff));
}
CK_TILE_HOST_DEVICE
bool isnan(const bfloat16_t& x)
{
uint16_t xx = bit_cast<bf16_raw_t>(x);
return (xx & 0x7FFF) > 0x7C00;
}
CK_TILE_DEVICE
bfloat16_t sqrt(bfloat16_t x)
{
return static_cast<bfloat16_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
};
CK_TILE_DEVICE
bfloat16_t exp(bfloat16_t x)
{
return static_cast<bfloat16_t>(__ocml_exp_f32(static_cast<float>(x)));
};
CK_TILE_DEVICE
bfloat16_t exp2(bfloat16_t x) { return static_cast<bfloat16_t>(exp2f(static_cast<float>(x))); };
CK_TILE_DEVICE
bfloat16_t log(bfloat16_t x) { return static_cast<bfloat16_t>(__logf(static_cast<float>(x))); };
} // namespace ck_tile

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,404 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include <hip/hip_fp16.h>
#pragma once
namespace ck_tile {
using fp16_hip_t = _Float16; // most of hip internal function use this type
using fp16_raw_t = uint16_t;
CK_TILE_HOST_DEVICE
constexpr float fp16_to_float_hip(const fp16_hip_t& x);
CK_TILE_HOST_DEVICE
constexpr double fp16_to_double_hip(const fp16_hip_t& x);
CK_TILE_HOST_DEVICE
constexpr fp16_hip_t float_to_fp16_hip(const float& x);
CK_TILE_HOST_DEVICE
constexpr fp16_hip_t double_to_fp16_hip(const double& x);
#if CK_TILE_USE_CUSTOM_DATA_TYPE
// HIP use fp16_hip_t as interchangable data type for float16
struct alignas(2) half_t
{
using raw_type = fp16_raw_t;
raw_type data;
CK_TILE_HOST_DEVICE
static constexpr half_t bit_cast(raw_type x)
{
half_t y;
y.data = x;
return y;
}
CK_TILE_HOST_DEVICE
constexpr fp16_hip_t to_fp16() const { return ck_tile::bit_cast<fp16_hip_t>(data); }
// constructor
constexpr half_t() : data{} {}
// construct from HIP half
CK_TILE_HOST_DEVICE
explicit constexpr half_t(const fp16_hip_t& x) : data(ck_tile::bit_cast<raw_type>(x)) {}
// construct from float
CK_TILE_HOST_DEVICE
explicit constexpr half_t(const float& x) : half_t(float_to_fp16_hip(x)) {}
// construct from double
CK_TILE_HOST_DEVICE
explicit constexpr half_t(const double& x) : half_t(double_to_fp16_hip(x)) {}
// construct from int
CK_TILE_HOST_DEVICE
explicit constexpr half_t(const int& x) : half_t(static_cast<fp16_hip_t>(__int2half_rn(x))) {}
// construct from unsigned int
CK_TILE_HOST_DEVICE
explicit constexpr half_t(const unsigned int& x)
: half_t(static_cast<fp16_hip_t>(__uint2half_rn(x)))
{
}
// cast to float
CK_TILE_HOST_DEVICE
explicit constexpr operator float() const { return fp16_to_float_hip(to_fp16()); }
// cast to double
CK_TILE_HOST_DEVICE
explicit constexpr operator double() const { return fp16_to_double_hip(to_fp16()); }
// cast to int
CK_TILE_HOST_DEVICE
explicit constexpr operator int() const
{
return static_cast<int>(fp16_to_float_hip(to_fp16()));
}
CK_TILE_HOST_DEVICE
explicit constexpr operator fp16_hip_t() const { return ck_tile::bit_cast<fp16_hip_t>(data); }
// internal access
CK_TILE_HOST_DEVICE
constexpr raw_type& get() { return data; }
CK_TILE_HOST_DEVICE
constexpr raw_type get() const { return data; }
};
template <typename>
struct native_t;
template <>
struct native_t<half_t>
{
using type = _Float16;
};
using fp16_t = half_t;
using fp16_raw_t = typename half_t::raw_type;
#else
using fp16_t = _Float16;
using half_t = _Float16;
using fp16_raw_t = ushort;
#endif
// conversions
CK_TILE_HOST_DEVICE
constexpr float fp16_to_float_hip(const fp16_hip_t& x)
{
// return __half2float(x);
return static_cast<float>(x);
}
CK_TILE_HOST_DEVICE
constexpr double fp16_to_double_hip(const fp16_hip_t& x)
{
return static_cast<double>(fp16_to_float_hip(x));
}
CK_TILE_HOST_DEVICE
constexpr fp16_hip_t float_to_fp16_hip(const float& x)
{
// return __float2half(x);
return static_cast<fp16_hip_t>(x);
}
CK_TILE_HOST_DEVICE
constexpr fp16_hip_t double_to_fp16_hip(const double& x)
{
// return __float2half(x);
return static_cast<fp16_hip_t>(x);
}
CK_TILE_HOST_DEVICE
constexpr float fp16_to_float(const half_t& x) { return static_cast<float>(x); }
CK_TILE_HOST_DEVICE
constexpr float fp16_to_double(const half_t& x) { return static_cast<float>(x); }
CK_TILE_HOST_DEVICE
constexpr half_t float_to_fp16(const float& x) { return static_cast<half_t>(x); }
CK_TILE_HOST_DEVICE
constexpr half_t double_to_fp16(const double& x) { return static_cast<half_t>(x); }
// limits
template <class T>
struct numeric;
template <>
struct numeric<half_t>
{
// minimum finite value, or minimum positive normalized value for float
CK_TILE_HOST_DEVICE static constexpr half_t min()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x0400));
}
// minumum finite value
CK_TILE_HOST_DEVICE static constexpr half_t lowest()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0xFBFF));
}
// maximum finite value
CK_TILE_HOST_DEVICE static constexpr half_t max()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7BFF));
}
// difference between 1.0 and next value representable by float
CK_TILE_HOST_DEVICE static constexpr half_t epsilon()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x1800));
}
// maximum rounding error
// bin : f edcba 9876543210
// bits: s eeeee mmmmmmmmmm
// 0 01110 0000000000 (0.5)
//
CK_TILE_HOST_DEVICE static constexpr half_t round_error()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x3800));
}
// positive infinity value
CK_TILE_HOST_DEVICE static constexpr half_t infinity()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7C00));
}
// quiet NaN
CK_TILE_HOST_DEVICE static constexpr half_t quiet_NaN()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7FFF));
}
// signaling NaN
CK_TILE_HOST_DEVICE static constexpr half_t signaling_NaN()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7FFF));
}
// smallest positive subnormal value
CK_TILE_HOST_DEVICE static constexpr half_t denorm_min()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x0001));
}
CK_TILE_HOST_DEVICE static constexpr half_t zero()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0));
}
};
template <>
struct numeric_traits<half_t>
{
static constexpr int exp = 5;
static constexpr int mant = 10;
static constexpr int bias = 15;
static constexpr uint16_t nan_mask = 0x7C00;
static constexpr uint16_t head_mask = 0xFC00;
static constexpr uint16_t mant_mask = 0x3FF;
static constexpr uint16_t exp_mask = 0x1F;
static constexpr uint16_t abs_mask = 0x7FFF;
static constexpr uint16_t Inf = 0x7C00;
static constexpr uint16_t NegInf = 0xFC00;
static constexpr uint16_t NaN = 0x7C01;
static constexpr uint16_t Neg0 = 0x8000;
static constexpr int PackedSize = 1;
using bitwise_type = uint16_t;
};
#if CK_TILE_USE_CUSTOM_DATA_TYPE
// arithmetic
CK_TILE_DEVICE bool operator==(const half_t& x, const half_t& y)
{
return __heq(x.to_fp16(), y.to_fp16());
}
CK_TILE_DEVICE
bool operator!=(const half_t& x, const half_t& y) { return __hne(x.to_fp16(), y.to_fp16()); }
CK_TILE_DEVICE
bool operator<(const half_t& x, const half_t& y) { return __hlt(x.to_fp16(), y.to_fp16()); }
CK_TILE_DEVICE
bool operator<=(const half_t& x, const half_t& y) { return __hle(x.to_fp16(), y.to_fp16()); }
CK_TILE_DEVICE
bool operator>(const half_t& x, const half_t& y) { return __hgt(x.to_fp16(), y.to_fp16()); }
CK_TILE_DEVICE
bool operator>=(const half_t& x, const half_t& y) { return __hge(x.to_fp16(), y.to_fp16()); }
#if 0
CK_TILE_DEVICE
half_t operator+(const half_t& x, const half_t& y)
{
return half_t(__hadd(x.to_fp16(), y.to_fp16()));
}
CK_TILE_DEVICE
half_t operator-(const half_t& x) { return half_t(__hneg(x.to_fp16())); }
CK_TILE_DEVICE
half_t operator-(const half_t& x, const half_t& y)
{
return half_t(__hsub(x.to_fp16(), y.to_fp16()));
}
CK_TILE_DEVICE
half_t operator*(const half_t& x, const half_t& y)
{
return half_t(__hmul(x.to_fp16(), y.to_fp16()));
}
CK_TILE_DEVICE
half_t operator/(const half_t& x, const half_t& y)
{
return half_t(__hdiv(x.to_fp16(), y.to_fp16()));
}
CK_TILE_DEVICE
half_t& operator+=(half_t& x, const half_t& y)
{
x = half_t(__hadd(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t& operator-=(half_t& x, const half_t& y)
{
x = half_t(__hsub(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t& operator*=(half_t& x, const half_t& y)
{
x = half_t(__hmul(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t& operator/=(half_t& x, const half_t& y)
{
x = half_t(__hdiv(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t& operator++(half_t& x)
{
x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t& operator--(half_t& x)
{
x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t operator++(half_t& x, int)
{
half_t y(x);
x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
return y;
}
CK_TILE_DEVICE
half_t operator--(half_t& x, int)
{
half_t y(x);
x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
return y;
}
#endif
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST, half_t)
#endif
// math
CK_TILE_HOST_DEVICE
half_t abs(const half_t& x) { return bit_cast<half_t>(x.get() & 0x7fff); }
CK_TILE_HOST_DEVICE
bool isnan(const half_t& x)
{
uint16_t xx = x.get();
return (xx & 0x7FFF) > 0x7C00;
}
CK_TILE_DEVICE
half_t sqrt(half_t x)
{
return static_cast<half_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
};
CK_TILE_DEVICE
half_t exp(half_t x) { return static_cast<half_t>(__ocml_exp_f32(static_cast<float>(x))); };
CK_TILE_DEVICE
half_t exp2(half_t x) { return static_cast<half_t>(exp2f(static_cast<float>(x))); };
CK_TILE_DEVICE
half_t log(half_t x) { return static_cast<half_t>(__logf(static_cast<float>(x))); };
#endif
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
CK_TILE_HOST fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
{
fp16x2_t vector_res;
vector_res.x = x.x + y.x;
vector_res.y = x.y + y.y;
return vector_res;
}
CK_TILE_DEVICE fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
{
fp16x2_t c;
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(x), "v"(y));
return c;
}
} // namespace ck_tile

View File

@@ -0,0 +1,103 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/random.hpp"
#include <stdint.h>
#include <type_traits>
#pragma once
namespace ck_tile {
// use int8_t directly for int8 arithemetic
// here one can use ck_tile::int8_t to access original int8_t
using int8_t = int8_t;
// limits
template <class T>
struct numeric;
template <>
struct numeric<int8_t>
{
// minimum finite value, or minimum positive normalized value for float
CK_TILE_HOST_DEVICE static constexpr int8_t min() { return int8_t(-128); }
// minumum finite value
CK_TILE_HOST_DEVICE static constexpr int8_t lowest() { return int8_t(-128); }
// maximum finite value
CK_TILE_HOST_DEVICE static constexpr int8_t max() { return int8_t(127); }
// difference between 1.0 and next value representable by float
CK_TILE_HOST_DEVICE static constexpr int8_t epsilon()
{
return 1; // not used
}
CK_TILE_HOST_DEVICE static constexpr int8_t round_error()
{
return 1; // not used
}
// positive infinity value
CK_TILE_HOST_DEVICE static constexpr int8_t infinity()
{
return 1; // not used
}
// quiet NaN
CK_TILE_HOST_DEVICE static constexpr int8_t quiet_NaN()
{
return 1; // not used
}
// signaling NaN
CK_TILE_HOST_DEVICE static constexpr int8_t signaling_NaN()
{
return 1; // not used
}
// smallest positive subnormal value
CK_TILE_HOST_DEVICE static constexpr int8_t denorm_min()
{
return 1; // not used
}
CK_TILE_HOST_DEVICE static constexpr int8_t zero() { return 0; }
};
#if 0
template <>
struct numeric_traits<int8_t>
{
static constexpr int exp = 5;
static constexpr int mant = 10;
static constexpr int bias = 15;
static constexpr uint16_t nan_mask = 0x7C00;
static constexpr uint16_t head_mask = 0xFC00;
static constexpr uint16_t mant_mask = 0x3FF;
static constexpr uint16_t exp_mask = 0x1F;
static constexpr uint32_t Inf = 0x7C00;
static constexpr uint32_t NegInf = 0xFC00;
static constexpr uint32_t NaN = 0x7C01;
static constexpr uint32_t Neg0 = 0x8000;
static constexpr int PackedSize = 1;
using bitwise_type = uint16_t;
};
#endif
CK_TILE_HOST_DEVICE
constexpr float int8_to_float(const int8_t& x) { return static_cast<float>(x); }
CK_TILE_HOST_DEVICE
constexpr int8_t float_to_int8(const float& x) { return static_cast<int8_t>(x); }
} // namespace ck_tile

View File

@@ -0,0 +1,13 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <stdint.h>
namespace ck_tile {
using index_t = int32_t;
using long_index_t = int64_t;
using int8_t = int8_t;
} // namespace ck_tile

View File

@@ -0,0 +1,82 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
namespace ck_tile {
template <auto v>
struct constant
{
using value_type = decltype(v);
using type = constant; // using injected-class-name
static constexpr value_type value = v;
CK_TILE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; }
CK_TILE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; }
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
};
template <typename T, T v>
struct integral_constant : constant<v>
{
using value_type = T;
using type = integral_constant; // using injected-class-name
static constexpr T value = v;
// constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; }
// constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; } //
};
template <index_t v>
using number = constant<v>;
template <long_index_t v>
using long_number = constant<v>;
template <bool b>
using bool_constant = constant<b>;
#define CK_TILE_LEFT_UNARY_OP(OP) \
template <auto x> \
CK_TILE_HOST_DEVICE constexpr auto operator OP(constant<x>) \
{ \
return constant<(OP x)>{}; \
}
#define CK_TILE_BINARY_OP(OP) \
template <auto x, auto y> \
CK_TILE_HOST_DEVICE constexpr auto operator OP(constant<x>, constant<y>) \
{ \
return constant<(x OP y)>{}; \
}
CK_TILE_LEFT_UNARY_OP(+)
CK_TILE_LEFT_UNARY_OP(-)
CK_TILE_LEFT_UNARY_OP(~)
CK_TILE_LEFT_UNARY_OP(!)
CK_TILE_BINARY_OP(+)
CK_TILE_BINARY_OP(-)
CK_TILE_BINARY_OP(*)
CK_TILE_BINARY_OP(/)
CK_TILE_BINARY_OP(%)
CK_TILE_BINARY_OP(&)
CK_TILE_BINARY_OP(|)
CK_TILE_BINARY_OP(^)
CK_TILE_BINARY_OP(<<)
CK_TILE_BINARY_OP(>>)
CK_TILE_BINARY_OP(&&)
CK_TILE_BINARY_OP(||)
CK_TILE_BINARY_OP(==)
CK_TILE_BINARY_OP(!=)
CK_TILE_BINARY_OP(>)
CK_TILE_BINARY_OP(<)
CK_TILE_BINARY_OP(>=)
CK_TILE_BINARY_OP(<=)
#undef CK_TILE_LEFT_UNARY_OP
#undef CK_TILE_BINARY_OP
} // namespace ck_tile

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,13 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <stdint.h>
namespace ck_tile {
struct null_type
{
};
} // namespace ck_tile

View File

@@ -0,0 +1,196 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include <limits>
#include <stdint.h>
namespace ck_tile {
// this struct has the information of
// 1. limit of a certain type, simliar to std::numeric_limits
// 2. some pre-defined value, zero, one...
//
template <typename T>
struct numeric
{
// minimum finite value, or minimum positive normalized value for float
CK_TILE_HOST_DEVICE static constexpr T min() { return std::numeric_limits<T>::min(); }
// minumum finite value
CK_TILE_HOST_DEVICE static constexpr T lowest() { return std::numeric_limits<T>::lowest(); }
// maximum finite value
CK_TILE_HOST_DEVICE static constexpr T max() { return std::numeric_limits<T>::max(); }
// difference between 1.0 and next value representable by float
CK_TILE_HOST_DEVICE static constexpr T epsilon() { return std::numeric_limits<T>::epsilon(); }
// maximum rounding error
CK_TILE_HOST_DEVICE static constexpr T round_error()
{
return std::numeric_limits<T>::round_error();
}
// positive infinity value
CK_TILE_HOST_DEVICE static constexpr T infinity() { return std::numeric_limits<T>::infinity(); }
// quiet NaN
CK_TILE_HOST_DEVICE static constexpr T quiet_NaN()
{
return std::numeric_limits<T>::quiet_NaN();
}
// signaling NaN
CK_TILE_HOST_DEVICE static constexpr T signaling_NaN()
{
return std::numeric_limits<T>::signaling_NaN();
}
// smallest positive subnormal value
CK_TILE_HOST_DEVICE static constexpr T denorm_min()
{
return std::numeric_limits<T>::denorm_min();
}
CK_TILE_HOST_DEVICE static constexpr T zero() { return static_cast<T>(0); }
CK_TILE_HOST_DEVICE static constexpr T one() { return static_cast<T>(1); }
#ifndef C_LOG2E
#define C_LOG2E 1.44269504088896340736 // log2(e)
#endif
CK_TILE_HOST_DEVICE static constexpr T log2e()
{
if constexpr(std::is_same_v<T, float> || std::is_same_v<T, double>)
{
return static_cast<T>(C_LOG2E);
}
else
{
return 0; // TODO: integer?
}
}
};
template <typename T>
struct numeric_traits
{
static constexpr int PackedSize = 1;
};
template <>
struct numeric_traits<float>
{
static constexpr int exp = 8;
static constexpr int mant = 23;
static constexpr int bias = 127;
static constexpr uint32_t nan_mask = 0x7F800000;
static constexpr uint32_t head_mask = 0xFF800000;
static constexpr uint32_t mant_mask = 0x7FFFFF;
static constexpr uint32_t exp_mask = 0xFF;
static constexpr uint32_t abs_mask = 0x7FFFFFFF;
static constexpr uint32_t Inf = 0x7F800000;
static constexpr uint32_t NegInf = 0xFF800000;
static constexpr uint32_t NaN = 0x7F800001;
static constexpr uint32_t Neg0 = 0x80000000;
static constexpr int PackedSize = 1;
using bitwise_type = uint32_t;
};
} // namespace ck_tile
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \
attr_ bool operator==(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) == static_cast<float>(y); \
} \
attr_ bool operator!=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) != static_cast<float>(y); \
} \
attr_ bool operator<(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) < static_cast<float>(y); \
} \
attr_ bool operator<=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) <= static_cast<float>(y); \
} \
attr_ bool operator>(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) > static_cast<float>(y); \
} \
attr_ bool operator>=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) >= static_cast<float>(y); \
} \
attr_ type_ operator+(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) + static_cast<float>(y)); \
} \
attr_ type_ operator-(const type_& x) \
{ \
constexpr uint32_t bits = sizeof(type_) * 8; \
constexpr uint32_t mask = 1 << (bits - 1); \
type_ y = x; \
y.data ^= static_cast<typename type_::raw_type>(mask); \
return y; \
} \
attr_ type_ operator-(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) - static_cast<float>(y)); \
} \
attr_ type_ operator*(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) * static_cast<float>(y)); \
} \
attr_ type_ operator/(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) / static_cast<float>(y)); \
} \
attr_ type_& operator+=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) + static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator-=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) - static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator*=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) * static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator/=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) / static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator++(type_& x) \
{ \
x = type_(static_cast<float>(x) + 1.f); \
return x; \
} \
attr_ type_& operator--(type_& x) \
{ \
x = type_(static_cast<float>(x) - 1.f); \
return x; \
} \
attr_ type_ operator++(type_& x, int) \
{ \
type_ y(x); \
x = type_(static_cast<float>(x) + 1.f); \
return y; \
} \
attr_ type_ operator--(type_& x, int) \
{ \
type_ y(x); \
x = type_(static_cast<float>(x) - 1.f); \
return y; \
}

View File

@@ -0,0 +1,150 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/random.hpp"
#include <stdint.h>
#include <type_traits>
#include "ck_tile/core/numeric/int8.hpp"
#pragma once
namespace ck_tile {
// Packed 2xint4
struct pk_int4_t
{
using type = int8_t;
type data;
CK_TILE_HOST_DEVICE constexpr pk_int4_t() : data{type{}} {}
CK_TILE_HOST_DEVICE constexpr pk_int4_t(type init) : data{init} {}
};
// limits
template <class T>
struct numeric;
template <>
struct numeric<pk_int4_t>
{
// minimum finite value, or minimum positive normalized value for float
CK_TILE_HOST_DEVICE static constexpr pk_int4_t min()
{
constexpr uint8_t val = 0b10001000;
return pk_int4_t(bit_cast<int8_t>(val));
}
// minumum finite value
CK_TILE_HOST_DEVICE static constexpr pk_int4_t lowest()
{
constexpr uint8_t val = 0b10001000;
return pk_int4_t(bit_cast<int8_t>(val));
}
// maximum finite value
CK_TILE_HOST_DEVICE static constexpr pk_int4_t max()
{
constexpr uint8_t val = 0b01110111;
return pk_int4_t(bit_cast<int8_t>(val));
}
// difference between 1.0 and next value representable by float
CK_TILE_HOST_DEVICE static constexpr pk_int4_t epsilon()
{
return 1; // not used
}
CK_TILE_HOST_DEVICE static constexpr pk_int4_t round_error()
{
return 1; // not used
}
// positive infinity value
CK_TILE_HOST_DEVICE static constexpr pk_int4_t infinity()
{
return 1; // not used
}
// quiet NaN
CK_TILE_HOST_DEVICE static constexpr pk_int4_t quiet_NaN()
{
return 1; // not used
}
// signaling NaN
CK_TILE_HOST_DEVICE static constexpr pk_int4_t signaling_NaN()
{
return 1; // not used
}
// smallest positive subnormal value
CK_TILE_HOST_DEVICE static constexpr pk_int4_t denorm_min()
{
return 1; // not used
}
CK_TILE_HOST_DEVICE static constexpr pk_int4_t zero() { return 0; }
};
template <>
struct numeric_traits<pk_int4_t>
{
static constexpr int PackedSize = 2;
};
using fp32x2_t = float __attribute__((ext_vector_type(2)));
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
{
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
fp32x2_t res = {x_h, x_l};
#elif
fp32x2_t res = {x_l, x_h};
#endif
return res;
}
CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x)
{
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
#elif
uint32_t i4s = ((x_u8 & 0xf0) << 12) | (x_u8 & 0xf);
#endif
const int EX = 0x64006400;
const int SUB = 0xE408E408; //-8
int lo = i4s | EX;
return pk_add_f16(bit_cast<fp16x2_t>(lo), bit_cast<fp16x2_t>(SUB));
}
CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x)
{
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
bf16x2_t res = {type_convert<bf16_t>(x_h), type_convert<bf16_t>(x_l)};
#elif
bf16x2_t res = {type_convert<bf16_t>(x_l), type_convert<bf16_t>(x_h)};
#endif
return res;
}
} // namespace ck_tile

View File

@@ -0,0 +1,70 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <stdint.h>
#include <tuple>
#include <type_traits>
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/int8.hpp"
namespace ck_tile {
#if CK_TILE_USE_CUSTOM_DATA_TYPE
template <typename Y, typename X>
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<Y> type_convert(const X& x)
{
return static_cast<Y>(x);
}
#else
// Convert X to Y, both X and Y are non-const data types.
template <typename Y,
typename X,
std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
return static_cast<Y>(x);
}
// Convert X to Y, either X or Y is a const data type.
template <typename Y,
typename X,
std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
using non_const_y = std::remove_const_t<Y>;
using non_const_x = std::remove_const_t<X>;
return static_cast<Y>(type_convert<non_const_y, non_const_x>(x));
}
#define CK_TILE_TYPE_CONVERT(dtype_, dname_, stype_, sname_) \
template <> \
CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
{ \
return sname_##_to_##dname_(x); \
}
CK_TILE_TYPE_CONVERT(float, float, fp16_t, fp16)
CK_TILE_TYPE_CONVERT(float, float, bf16_t, bf16)
CK_TILE_TYPE_CONVERT(float, float, fp8_t, fp8)
CK_TILE_TYPE_CONVERT(float, float, bf8_t, bf8)
CK_TILE_TYPE_CONVERT(fp16_t, fp16, float, float)
CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float)
CK_TILE_TYPE_CONVERT(fp8_t, fp8, float, float)
CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float)
CK_TILE_TYPE_CONVERT(float, float, int8_t, int8)
CK_TILE_TYPE_CONVERT(int8_t, int8, float, float)
#undef CK_TILE_TYPE_CONVERT
#endif
} // namespace ck_tile

View File

@@ -0,0 +1,240 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// this structure is used to pick up the <base> type inside
// using xxx = <base> __attribute__((ext_vector_type(N)));
// because clang only allow native type + bool in this term (custom type will fail)
// overload this structure to let proper <base> type
template <typename T>
struct native_t
{
using type = remove_cvref_t<T>;
};
// we name this as ext_vector purposely, because clang ext_vector_type extention only accept literay
// basic type to construct a ext_vector_type you must be very careful using this, or will have lot
// of compiler errors e.g. struct A; using Ax2_t = A __attribute__((ext_vector_type(2))); -> will
// have compiler error
namespace impl {
template <typename T_, index_t N_, typename = void>
struct ext_vector;
template <typename T_, index_t N_>
struct ext_vector<T_, N_, std::enable_if_t<!std::is_class_v<typename native_t<T_>::type>>>
{
static constexpr index_t N = N_;
// struct type is not supported for ext_vector
using value_type = typename native_t<T_>::type;
static_assert(!std::is_class_v<value_type>);
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
};
template <typename T_, index_t N_>
struct ext_vector<T_, N_, std::enable_if_t<std::is_class_v<typename native_t<T_>::type>>>
{
static constexpr index_t N = N_;
// struct type is not supported for ext_vector
using value_type = typename native_t<T_>::type::type;
static_assert(!std::is_class_v<value_type>);
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
};
template <typename V_, index_t Vs_, index_t N_>
struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))),
N_,
std::enable_if_t<!std::is_class_v<typename native_t<V_>::type>>>
{
static constexpr index_t N = Vs_ * N_;
using value_type = typename native_t<remove_cvref_t<V_>>::type;
static_assert(!std::is_class_v<value_type>);
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
};
template <typename V_, index_t Vs_, index_t N_>
struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))),
N_,
std::enable_if_t<std::is_class_v<typename native_t<V_>::type>>>
{
static constexpr index_t N = Vs_ * N_;
using value_type = typename native_t<remove_cvref_t<V_>>::type::type;
static_assert(!std::is_class_v<value_type>);
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
};
} // namespace impl
template <typename T, index_t N>
using ext_vector_t = typename impl::ext_vector<T, N>::type;
// by default, any type will result in a vector_size=1 with scalar_type=T traits.
// ... unless we have other vector_traits specialization
template <typename T, typename>
struct vector_traits
{
using scalar_type =
std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_int4_t>, int8_t, remove_cvref_t<T>>;
static constexpr index_t vector_size = 1;
};
// specialization for ext_vector_type()
template <typename T, index_t N>
struct vector_traits<T __attribute__((ext_vector_type(N)))>
{
using scalar_type = std::conditional_t<std::is_same_v<T, pk_int4_t>, int8_t, T>;
static constexpr index_t vector_size = N;
};
template <typename X, typename Y>
using has_same_scalar_type = std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<Y>>::scalar_type>;
// below are some pre-defines of ext_vector_type
// attention! 2 vector type could be just the same type
// fp64
using fp64_t = double;
using fp64x2_t = double __attribute__((ext_vector_type(2)));
using fp64x4_t = double __attribute__((ext_vector_type(4)));
// fp32
using fp32_t = float;
using fp32x2_t = float __attribute__((ext_vector_type(2)));
using fp32x4_t = float __attribute__((ext_vector_type(4)));
using fp32x8_t = float __attribute__((ext_vector_type(8)));
using fp32x16_t = float __attribute__((ext_vector_type(16)));
using fp32x32_t = float __attribute__((ext_vector_type(32)));
using fp32x64_t = float __attribute__((ext_vector_type(64)));
// fp16
// using fp16_t = ...
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
using fp16x4_t = _Float16 __attribute__((ext_vector_type(4)));
using fp16x8_t = _Float16 __attribute__((ext_vector_type(8)));
using fp16x16_t = _Float16 __attribute__((ext_vector_type(16)));
using fp16x32_t = _Float16 __attribute__((ext_vector_type(32)));
using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));
// bf16
// using bf16_t = ...
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
using bf16x4_t = bf16_raw_t __attribute__((ext_vector_type(4)));
using bf16x8_t = bf16_raw_t __attribute__((ext_vector_type(8)));
using bf16x16_t = bf16_raw_t __attribute__((ext_vector_type(16)));
using bf16x32_t = bf16_raw_t __attribute__((ext_vector_type(32)));
using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64)));
// i32
// using int32_t = ...
using int32x2_t = int32_t __attribute__((ext_vector_type(2)));
using int32x4_t = int32_t __attribute__((ext_vector_type(4)));
using int32x8_t = int32_t __attribute__((ext_vector_type(8)));
using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
using int32x32_t = int32_t __attribute__((ext_vector_type(32)));
using int32x64_t = int32_t __attribute__((ext_vector_type(64)));
// u32
// using uint32_t = ...
using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
using uint32x4_t = uint32_t __attribute__((ext_vector_type(4)));
using uint32x8_t = uint32_t __attribute__((ext_vector_type(8)));
using uint32x16_t = uint32_t __attribute__((ext_vector_type(16)));
using uint32x32_t = uint32_t __attribute__((ext_vector_type(32)));
using uint32x64_t = uint32_t __attribute__((ext_vector_type(64)));
// i16
// using int16_t = ...
using int16x2_t = int16_t __attribute__((ext_vector_type(2)));
using int16x4_t = int16_t __attribute__((ext_vector_type(4)));
using int16x8_t = int16_t __attribute__((ext_vector_type(8)));
using int16x16_t = int16_t __attribute__((ext_vector_type(16)));
using int16x32_t = int16_t __attribute__((ext_vector_type(32)));
using int16x64_t = int16_t __attribute__((ext_vector_type(64)));
// u16
// using uint16_t
using uint16x2_t = uint16_t __attribute__((ext_vector_type(2)));
using uint16x4_t = uint16_t __attribute__((ext_vector_type(4)));
using uint16x8_t = uint16_t __attribute__((ext_vector_type(8)));
using uint16x16_t = uint16_t __attribute__((ext_vector_type(16)));
using uint16x32_t = uint16_t __attribute__((ext_vector_type(32)));
using uint16x64_t = uint16_t __attribute__((ext_vector_type(64)));
// i8
// using int8_t
using int8x2_t = int8_t __attribute((ext_vector_type(2)));
using int8x4_t = int8_t __attribute((ext_vector_type(4)));
using int8x8_t = int8_t __attribute((ext_vector_type(8)));
using int8x16_t = int8_t __attribute((ext_vector_type(16)));
using int8x32_t = int8_t __attribute((ext_vector_type(32)));
using int8x64_t = int8_t __attribute((ext_vector_type(64)));
// ui8
// using uint8_t
using uint8x2_t = uint8_t __attribute((ext_vector_type(2)));
using uint8x4_t = uint8_t __attribute((ext_vector_type(4)));
using uint8x8_t = uint8_t __attribute((ext_vector_type(8)));
using uint8x16_t = uint8_t __attribute((ext_vector_type(16)));
using uint8x32_t = uint8_t __attribute((ext_vector_type(32)));
using uint8x64_t = uint8_t __attribute((ext_vector_type(64)));
#if CK_TILE_USE_CUSTOM_DATA_TYPE
// f8
// using fp8_t
using fp8x2_t = fp8_raw_t __attribute((ext_vector_type(2)));
using fp8x4_t = fp8_raw_t __attribute((ext_vector_type(4)));
using fp8x8_t = fp8_raw_t __attribute((ext_vector_type(8)));
using fp8x16_t = fp8_raw_t __attribute((ext_vector_type(16)));
using fp8x32_t = fp8_raw_t __attribute((ext_vector_type(32)));
using fp8x64_t = fp8_raw_t __attribute((ext_vector_type(64)));
// bf8
// using bf8_t
using bf8x2_t = bf8_raw_t __attribute((ext_vector_type(2)));
using bf8x4_t = bf8_raw_t __attribute((ext_vector_type(4)));
using bf8x8_t = bf8_raw_t __attribute((ext_vector_type(8)));
using bf8x16_t = bf8_raw_t __attribute((ext_vector_type(16)));
using bf8x32_t = bf8_raw_t __attribute((ext_vector_type(32)));
using bf8x64_t = bf8_raw_t __attribute((ext_vector_type(64)));
#else
// f8
// using fp8_t
using fp8x2_t = fp8_t __attribute((ext_vector_type(2)));
using fp8x4_t = fp8_t __attribute((ext_vector_type(4)));
using fp8x8_t = fp8_t __attribute((ext_vector_type(8)));
using fp8x16_t = fp8_t __attribute((ext_vector_type(16)));
using fp8x32_t = fp8_t __attribute((ext_vector_type(32)));
using fp8x64_t = fp8_t __attribute((ext_vector_type(64)));
// bf8
// using bf8_t
using bf8x2_t = bf8_t __attribute((ext_vector_type(2)));
using bf8x4_t = bf8_t __attribute((ext_vector_type(4)));
using bf8x8_t = bf8_t __attribute((ext_vector_type(8)));
using bf8x16_t = bf8_t __attribute((ext_vector_type(16)));
using bf8x32_t = bf8_t __attribute((ext_vector_type(32)));
using bf8x64_t = bf8_t __attribute((ext_vector_type(64)));
#endif
// pk_int4_t
// using pk_int4_t
using pk_int4x2_t = int8_t __attribute((ext_vector_type(2)));
using pk_int4x4_t = int8_t __attribute((ext_vector_type(4)));
using pk_int4x8_t = int8_t __attribute((ext_vector_type(8)));
using pk_int4x16_t = int8_t __attribute((ext_vector_type(16)));
using pk_int4x32_t = int8_t __attribute((ext_vector_type(32)));
} // namespace ck_tile