mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 01:36:06 +00:00
now can build
This commit is contained in:
@@ -4,44 +4,36 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#define CK_TILE_ARITHMETIC_USING_FLOAT(type_) \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
bool operator==(const type_& x, const type_& y) \
|
||||
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \
|
||||
attr_ bool operator==(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) == static_cast<float>(y); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
bool operator!=(const type_& x, const type_& y) \
|
||||
attr_ bool operator!=(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) != static_cast<float>(y); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
bool operator<(const type_& x, const type_& y) \
|
||||
attr_ bool operator<(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) < static_cast<float>(y); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
bool operator<=(const type_& x, const type_& y) \
|
||||
attr_ bool operator<=(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) <= static_cast<float>(y); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
bool operator>(const type_& x, const type_& y) \
|
||||
attr_ bool operator>(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) > static_cast<float>(y); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
bool operator>=(const type_& x, const type_& y) \
|
||||
attr_ bool operator>=(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) >= static_cast<float>(y); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_ operator+(const type_& x, const type_& y) \
|
||||
attr_ type_ operator+(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) + static_cast<float>(y)); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_ operator-(const type_& x) \
|
||||
attr_ type_ operator-(const type_& x) \
|
||||
{ \
|
||||
constexpr uint32_t bits = sizeof(type_) * 8; \
|
||||
constexpr uint32_t mask = 1 << (bits - 1); \
|
||||
@@ -49,66 +41,55 @@
|
||||
y.data ^= static_cast<typename type_::raw_type>(mask); \
|
||||
return y; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_ operator-(const type_& x, const type_& y) \
|
||||
attr_ type_ operator-(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) - static_cast<float>(y)); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_ operator*(const type_& x, const type_& y) \
|
||||
attr_ type_ operator*(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) * static_cast<float>(y)); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_ operator/(const type_& x, const type_& y) \
|
||||
attr_ type_ operator/(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) / static_cast<float>(y)); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_& operator+=(type_& x, const type_& y) \
|
||||
attr_ type_& operator+=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) + static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_& operator-=(type_& x, const type_& y) \
|
||||
attr_ type_& operator-=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) - static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_& operator*=(type_& x, const type_& y) \
|
||||
attr_ type_& operator*=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) * static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_& operator/=(type_& x, const type_& y) \
|
||||
attr_ type_& operator/=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) / static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_& operator++(type_& x) \
|
||||
attr_ type_& operator++(type_& x) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) + 1.f); \
|
||||
return x; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_& operator--(type_& x) \
|
||||
attr_ type_& operator--(type_& x) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) - 1.f); \
|
||||
return x; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_ operator++(type_& x, int) \
|
||||
attr_ type_ operator++(type_& x, int) \
|
||||
{ \
|
||||
type_ y(x); \
|
||||
x = type_(static_cast<float>(x) + 1.f); \
|
||||
return y; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_ operator--(type_& x, int) \
|
||||
attr_ type_ operator--(type_& x, int) \
|
||||
{ \
|
||||
type_ y(x); \
|
||||
x = type_(static_cast<float>(x) - 1.f); \
|
||||
|
||||
@@ -24,9 +24,16 @@ template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding> = {});
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE uint16_t double_to_bf16_raw(double f, constant<rounding> = {});
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float bf16_to_float_raw(uint16_t x);
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
double bf16_to_double_raw(uint16_t x);
|
||||
|
||||
// HIP use __hip_bfloat16 as struct
|
||||
struct alignas(2) bfloat16_t
|
||||
{
|
||||
@@ -48,6 +55,10 @@ struct alignas(2) bfloat16_t
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr bfloat16_t(const float& x) : data(float_to_bf16_raw(x)) {}
|
||||
|
||||
// construct from double
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr bfloat16_t(const double& x) : data(double_to_bf16_raw(x)) {}
|
||||
|
||||
// construct from int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr bfloat16_t(const int& x) : data(float_to_bf16_raw(static_cast<float>(x))) {}
|
||||
@@ -63,6 +74,10 @@ struct alignas(2) bfloat16_t
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator float() const { return bf16_to_float_raw(data); }
|
||||
|
||||
// cast to float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator double() const { return bf16_to_double_raw(data); }
|
||||
|
||||
// cast to int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator int() const { return static_cast<int>(bf16_to_float_raw(data)); }
|
||||
@@ -157,6 +172,12 @@ CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding>)
|
||||
return float_to_bf16_truc_raw(f);
|
||||
}
|
||||
|
||||
template <bf16_rounding_mode rounding>
|
||||
CK_TILE_HOST_DEVICE uint16_t double_to_bf16_raw(double f, constant<rounding>)
|
||||
{
|
||||
return float_to_bf16_raw(static_cast<float>(f), constant<rounding>{});
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float bf16_to_float_raw(uint16_t x)
|
||||
{
|
||||
@@ -168,6 +189,9 @@ float bf16_to_float_raw(uint16_t x)
|
||||
return u.fp32;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
double bf16_to_double_raw(uint16_t x) { return static_cast<double>(bf16_to_float_raw(x)); }
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant<rounding>)
|
||||
@@ -175,9 +199,19 @@ CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant<rounding>)
|
||||
return bfloat16_t::bit_cast(float_to_bf16_raw(f, constant<rounding>{}));
|
||||
}
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE bfloat16_t double_to_bf16(double f, constant<rounding>)
|
||||
{
|
||||
return bfloat16_t::bit_cast(double_to_bf16_raw(f, constant<rounding>{}));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float bf16_to_float(bfloat16_t x) { return static_cast<float>(x); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
double bf16_to_double(bfloat16_t x) { return static_cast<double>(x); }
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE bfloat16_t fp16_to_bf16(half_t f, constant<rounding> = {})
|
||||
@@ -240,7 +274,7 @@ struct numeric_limits<bfloat16_t>
|
||||
}
|
||||
};
|
||||
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(bfloat16_t)
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bfloat16_t)
|
||||
|
||||
// math
|
||||
CK_TILE_HOST_DEVICE
|
||||
|
||||
@@ -184,7 +184,7 @@ CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng)
|
||||
int exponent, bias;
|
||||
uint32_t head, mantissa, sign;
|
||||
// nan code is same for float and half
|
||||
constexpr Y nan_code = 0x80;
|
||||
constexpr Y nan_code = __builtin_bit_cast(Y, static_cast<uint8_t>(0x80));
|
||||
constexpr uint32_t nan_mask = numeric_utils<X>::nan_mask;
|
||||
|
||||
// convert to bitwise
|
||||
@@ -215,7 +215,7 @@ CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng)
|
||||
|
||||
// check if x is 0.0
|
||||
if(x_bitwise == 0)
|
||||
return 0;
|
||||
return __builtin_bit_cast(Y, static_cast<uint8_t>(0));
|
||||
|
||||
// First need to check if it is normal or denorm as there is a difference of implict 1
|
||||
// Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
|
||||
@@ -317,15 +317,18 @@ In this case, the fp16 mantissa should be shift left by 1 */
|
||||
}
|
||||
else
|
||||
{
|
||||
return signed_inf;
|
||||
return __builtin_bit_cast(Y, static_cast<uint8_t>(signed_inf));
|
||||
}
|
||||
}
|
||||
|
||||
// check if x is 0.0 or -0.0
|
||||
if(out_exponent == 0 && mantissa == 0)
|
||||
return negative_zero_nan ? 0 : (sign << (out_exp + out_mant));
|
||||
return __builtin_bit_cast(
|
||||
Y, static_cast<uint8_t>(negative_zero_nan ? 0 : (sign << (out_exp + out_mant))));
|
||||
mantissa &= (1 << out_mant) - 1;
|
||||
return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa;
|
||||
return __builtin_bit_cast(Y,
|
||||
static_cast<uint8_t>((sign << (out_exp + out_mant)) |
|
||||
(out_exponent << out_mant) | mantissa));
|
||||
}
|
||||
|
||||
template <typename X, typename Y, bool negative_zero_nan>
|
||||
@@ -338,9 +341,10 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
|
||||
// resulting type exponent/mantissa layout
|
||||
constexpr int out_exp = numeric_utils<Y>::exp;
|
||||
constexpr int out_mant = numeric_utils<Y>::mant;
|
||||
uint8_t x_raw = __builtin_bit_cast(uint8_t, x);
|
||||
|
||||
// prepare the codes
|
||||
constexpr X nan_code = 0x80;
|
||||
constexpr uint8_t nan_code = 0x80;
|
||||
Y Inf, NegInf, NaN, Neg0;
|
||||
using T_bitwise = typename numeric_utils<Y>::bitwise_type;
|
||||
|
||||
@@ -355,13 +359,13 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
|
||||
Neg0 = *(reinterpret_cast<const Y*>(&Neg0_bitwise));
|
||||
|
||||
// check if x is 0.0
|
||||
if(x == 0)
|
||||
if(x_raw == 0)
|
||||
return static_cast<Y>(0);
|
||||
|
||||
// unpack the input
|
||||
uint32_t sign = x >> (in_exp + in_mant);
|
||||
uint32_t mantissa = x & ((1 << in_mant) - 1);
|
||||
int exponent = (x & 0x7F) >> in_mant;
|
||||
uint32_t sign = x_raw >> (in_exp + in_mant);
|
||||
uint32_t mantissa = x_raw & ((1 << in_mant) - 1);
|
||||
int exponent = (x_raw & 0x7F) >> in_mant;
|
||||
|
||||
constexpr int exp_low_cutoff =
|
||||
(1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
|
||||
@@ -369,12 +373,12 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
|
||||
|
||||
if constexpr(negative_zero_nan)
|
||||
{
|
||||
if(x == nan_code)
|
||||
if(x_raw == nan_code)
|
||||
return NaN;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(x == nan_code)
|
||||
if(x_raw == nan_code)
|
||||
return Neg0;
|
||||
if(exponent == ((1 << in_exp) - 1))
|
||||
return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
|
||||
@@ -382,7 +386,7 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
|
||||
|
||||
if((numeric_utils<Y>::mant == 10) && (numeric_utils<X>::mant == 2) && !negative_zero_nan)
|
||||
{
|
||||
retval = x;
|
||||
retval = x_raw;
|
||||
retval <<= 8;
|
||||
return *(reinterpret_cast<const Y*>(&retval));
|
||||
}
|
||||
@@ -700,8 +704,8 @@ struct numeric_limits<bf8_t>
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min() { return bf8_t::bit_cast(0x01); }
|
||||
};
|
||||
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(fp8_t)
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(bf8_t)
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, fp8_t)
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bf8_t)
|
||||
|
||||
// math
|
||||
CK_TILE_HOST_DEVICE
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/arithmetic.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/limits.hpp"
|
||||
#include <hip/hip_fp16.h>
|
||||
@@ -15,9 +16,15 @@ using fp16_hip_t = __half; // most of hip internal function use this type
|
||||
CK_TILE_HOST_DEVICE
|
||||
float fp16_to_float_hip(const fp16_hip_t& x);
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
double fp16_to_double_hip(const fp16_hip_t& x);
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
fp16_hip_t float_to_fp16_hip(const float& x);
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
fp16_hip_t double_to_fp16_hip(const double& x);
|
||||
|
||||
// HIP use fp16_hip_t as interchangable data type for float16
|
||||
struct alignas(2) half_t
|
||||
{
|
||||
@@ -46,6 +53,10 @@ struct alignas(2) half_t
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr half_t(const float& x) : half_t(float_to_fp16_hip(x)) {}
|
||||
|
||||
// construct from double
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr half_t(const double& x) : half_t(double_to_fp16_hip(x)) {}
|
||||
|
||||
// construct from int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr half_t(const int& x) : half_t(static_cast<fp16_hip_t>(__int2half_rn(x))) {}
|
||||
@@ -61,6 +72,10 @@ struct alignas(2) half_t
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator float() const { return fp16_to_float_hip(to_fp16()); }
|
||||
|
||||
// cast to double
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator double() const { return fp16_to_double_hip(to_fp16()); }
|
||||
|
||||
// cast to int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator int() const
|
||||
@@ -87,6 +102,9 @@ float fp16_to_float_hip(const fp16_hip_t& x)
|
||||
return static_cast<float>(x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
double fp16_to_double_hip(const fp16_hip_t& x) { return static_cast<double>(fp16_to_float_hip(x)); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
fp16_hip_t float_to_fp16_hip(const float& x)
|
||||
{
|
||||
@@ -94,12 +112,25 @@ fp16_hip_t float_to_fp16_hip(const float& x)
|
||||
return static_cast<fp16_hip_t>(x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
fp16_hip_t double_to_fp16_hip(const double& x)
|
||||
{
|
||||
// return __float2half(x);
|
||||
return static_cast<fp16_hip_t>(x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float fp16_to_float(const half_t& x) { return static_cast<float>(x); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float fp16_to_double(const half_t& x) { return static_cast<float>(x); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
half_t float_to_fp16(const float& x) { return half_t{x}; }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
half_t double_to_fp16(const double& x) { return half_t{x}; }
|
||||
|
||||
// limits
|
||||
template <class T>
|
||||
struct numeric_limits;
|
||||
@@ -156,94 +187,94 @@ struct numeric_utils<half_t>
|
||||
};
|
||||
|
||||
// arithmetic
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
bool operator==(const half_t& x, const half_t& y) { return __heq(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
bool operator!=(const half_t& x, const half_t& y) { return __hne(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
bool operator<(const half_t& x, const half_t& y) { return __hlt(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
bool operator<=(const half_t& x, const half_t& y) { return __hle(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
bool operator>(const half_t& x, const half_t& y) { return __hgt(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
bool operator>=(const half_t& x, const half_t& y) { return __hge(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t operator+(const half_t& x, const half_t& y)
|
||||
{
|
||||
return half_t(__hadd(x.to_fp16(), y.to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t operator-(const half_t& x) { return half_t(__hneg(x.to_fp16())); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t operator-(const half_t& x, const half_t& y)
|
||||
{
|
||||
return half_t(__hsub(x.to_fp16(), y.to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t operator*(const half_t& x, const half_t& y)
|
||||
{
|
||||
return half_t(__hmul(x.to_fp16(), y.to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t operator/(const half_t& x, const half_t& y)
|
||||
{
|
||||
return half_t(__hdiv(x.to_fp16(), y.to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator+=(half_t& x, const half_t& y)
|
||||
{
|
||||
x = half_t(__hadd(x.to_fp16(), y.to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator-=(half_t& x, const half_t& y)
|
||||
{
|
||||
x = half_t(__hsub(x.to_fp16(), y.to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator*=(half_t& x, const half_t& y)
|
||||
{
|
||||
x = half_t(__hmul(x.to_fp16(), y.to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator/=(half_t& x, const half_t& y)
|
||||
{
|
||||
x = half_t(__hdiv(x.to_fp16(), y.to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator++(half_t& x)
|
||||
{
|
||||
x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator--(half_t& x)
|
||||
{
|
||||
x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t operator++(half_t& x, int)
|
||||
{
|
||||
half_t y(x);
|
||||
@@ -251,7 +282,7 @@ half_t operator++(half_t& x, int)
|
||||
return y;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t operator--(half_t& x, int)
|
||||
{
|
||||
half_t y(x);
|
||||
@@ -259,6 +290,8 @@ half_t operator--(half_t& x, int)
|
||||
return y;
|
||||
}
|
||||
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST, half_t)
|
||||
|
||||
// math
|
||||
CK_TILE_HOST_DEVICE
|
||||
half_t abs(const half_t& x) { return half_t::bit_cast(x.get() & 0x7fff); }
|
||||
|
||||
@@ -14,8 +14,9 @@ struct constant
|
||||
using value_type = decltype(v);
|
||||
using type = constant; // using injected-class-name
|
||||
static constexpr value_type value = v;
|
||||
constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; }
|
||||
constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; }
|
||||
CK_TILE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; }
|
||||
CK_TILE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; }
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
|
||||
};
|
||||
|
||||
template <typename T, T v>
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include <type_traits>
|
||||
#include <stdint.h>
|
||||
#include <cmath>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -147,8 +148,8 @@ CK_TILE_HOST_DEVICE constexpr T clamp(const T& x, const T& lowerbound, const T&
|
||||
return min(max(x, lowerbound), upperbound);
|
||||
}
|
||||
|
||||
CK_TILE_HOST inline int clz(uint32_t x) { return __builtin_clz(x); }
|
||||
CK_TILE_DEVICE inline int clz(uint32_t x) { return __clz(x); }
|
||||
CK_TILE_HOST int clz(uint32_t x) { return __builtin_clz(x); }
|
||||
CK_TILE_DEVICE int clz(uint32_t x) { return __clz(x); }
|
||||
|
||||
// greatest common divisor, aka highest common factor
|
||||
CK_TILE_HOST_DEVICE constexpr index_t gcd(index_t x, index_t y)
|
||||
@@ -246,7 +247,7 @@ CK_TILE_HOST_DEVICE constexpr int32_t integer_log2_floor(int32_t x)
|
||||
{
|
||||
// TODO: x need to be 1 ~ 0x7fffffff
|
||||
// __builtin_clz will produce unexpected result if x is 0;
|
||||
return 31 - clz(x);
|
||||
return 31 - __builtin_clz(x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr bool is_power_of_two_integer(int32_t x)
|
||||
@@ -275,7 +276,7 @@ struct log2e<float>
|
||||
};
|
||||
|
||||
template <typename T = double>
|
||||
inline constexpr T log2e_v = log2e<T>::value;
|
||||
constexpr T log2e_v = log2e<T>::value;
|
||||
|
||||
// math
|
||||
CK_TILE_HOST_DEVICE
|
||||
@@ -298,16 +299,32 @@ bool isnan(const float& x)
|
||||
return (xx & 0x7fffffff) > 0x7F800000;
|
||||
}
|
||||
|
||||
CK_TILE_HOST float sqrt(float x) { return std::sqrt(x); };
|
||||
|
||||
CK_TILE_HOST double sqrt(double x) { return std::sqrt(x); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
float exp(float x) { return __expf(x); };
|
||||
|
||||
CK_TILE_HOST
|
||||
float exp(float x) { return std::expf(x); }
|
||||
|
||||
CK_TILE_DEVICE
|
||||
float exp2(float x) { return exp2f(x); };
|
||||
|
||||
CK_TILE_HOST
|
||||
float exp2(float x) { return std::exp2f(x); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
float log(float x) { return __logf(x); };
|
||||
|
||||
CK_TILE_HOST
|
||||
float log(float x) { return std::logf(x); };
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -43,11 +43,11 @@ CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
|
||||
return static_cast<Y>(type_convert<non_const_y, non_const_x>(x));
|
||||
}
|
||||
|
||||
#define CK_TILE_TYPE_CONVERT(dtype_, stype_) \
|
||||
template <> \
|
||||
inline CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
|
||||
{ \
|
||||
return stype_##_to_##dtype_(x); \
|
||||
#define CK_TILE_TYPE_CONVERT(dtype_, stype_) \
|
||||
template <> \
|
||||
CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
|
||||
{ \
|
||||
return stype_##_to_##dtype_(x); \
|
||||
}
|
||||
|
||||
CK_TILE_TYPE_CONVERT(float, fp16_t)
|
||||
|
||||
@@ -63,12 +63,12 @@ using fp32x32_t = float __attribute__((ext_vector_type(32)));
|
||||
using fp32x64_t = float __attribute__((ext_vector_type(64)));
|
||||
|
||||
// fp16
|
||||
using fp16x2_t = fp16_raw_t __attribute__((ext_vector_type(2)));
|
||||
using fp16x4_t = fp16_raw_t __attribute__((ext_vector_type(4)));
|
||||
using fp16x8_t = fp16_raw_t __attribute__((ext_vector_type(8)));
|
||||
using fp16x16_t = fp16_raw_t __attribute__((ext_vector_type(16)));
|
||||
using fp16x32_t = fp16_raw_t __attribute__((ext_vector_type(32)));
|
||||
using fp16x64_t = fp16_raw_t __attribute__((ext_vector_type(64)));
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
using fp16x4_t = _Float16 __attribute__((ext_vector_type(4)));
|
||||
using fp16x8_t = _Float16 __attribute__((ext_vector_type(8)));
|
||||
using fp16x16_t = _Float16 __attribute__((ext_vector_type(16)));
|
||||
using fp16x32_t = _Float16 __attribute__((ext_vector_type(32)));
|
||||
using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));
|
||||
|
||||
// bfp16
|
||||
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
|
||||
@@ -94,6 +94,14 @@ using int16x16_t = int16_t __attribute__((ext_vector_type(16)));
|
||||
using int16x32_t = int16_t __attribute__((ext_vector_type(32)));
|
||||
using int16x64_t = int16_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// u16
|
||||
using uint16x2_t = uint16_t __attribute__((ext_vector_type(2)));
|
||||
using uint16x4_t = uint16_t __attribute__((ext_vector_type(4)));
|
||||
using uint16x8_t = uint16_t __attribute__((ext_vector_type(8)));
|
||||
using uint16x16_t = uint16_t __attribute__((ext_vector_type(16)));
|
||||
using uint16x32_t = uint16_t __attribute__((ext_vector_type(32)));
|
||||
using uint16x64_t = uint16_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// i8
|
||||
using int8x2_t = int8_t __attribute((ext_vector_type(2)));
|
||||
using int8x4_t = int8_t __attribute((ext_vector_type(4)));
|
||||
|
||||
Reference in New Issue
Block a user