This commit is contained in:
carlushuang
2024-02-28 22:57:19 +00:00
parent e60c5aea4e
commit f69356b1d7
130 changed files with 28268 additions and 0 deletions

View File

@@ -0,0 +1,19 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
namespace ck_tile {
template <typename Y, typename X>
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X& x)
{
static_assert(__has_builtin(__builtin_bit_cast), "");
static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type");
return __builtin_bit_cast(Y, x);
}
} // namespace ck_tile

View File

@@ -0,0 +1,194 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include <stdint.h>
#include <utility>
namespace ck_tile {
namespace detail {
struct swallow
{
template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr swallow(Ts&&...)
{
}
};
template <class>
struct static_for_impl;
template <index_t... Is>
struct static_for_impl<sequence<Is...>>
{
template <class F>
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
{
swallow{(f(number<Is>{}), 0)...};
}
};
} // namespace detail
// F signature: F(number<Iter>)
template <index_t NBegin, index_t NEnd, index_t Increment>
struct static_for
{
CK_TILE_HOST_DEVICE constexpr static_for()
{
static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0,
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd),
"wrongs! should (Increment > 0 && NBegin <= NEnd) || (Increment < 0 && "
"NBegin >= NEnd)");
}
template <class F>
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
{
detail::static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::type>{}(
f);
}
};
struct identity
{
template <typename T>
CK_TILE_HOST_DEVICE constexpr T&& operator()(T&& arg) const noexcept
{
return std::forward<T>(arg);
}
};
namespace detail {
// RemainLengths: sequence<...>
// Orders: sequence<...>
template <class RemainLengths, class Orders>
struct static_ford_impl
{
CK_TILE_HOST_DEVICE constexpr static_ford_impl()
{
static_assert(RemainLengths::size() > 0, "wrong! should not get here");
}
// F signature: F(sequence<...>)
// CurrentOrderedId: sequence<...>
template <class F, class CurrentOrderedId>
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentOrderedId) const
{
static_for<0, RemainLengths::front(), 1>{}([=](auto I) {
static_ford_impl<decltype(RemainLengths::pop_front()), Orders>{}(
f, CurrentOrderedId::push_back(I));
});
}
};
template <class Orders>
struct static_ford_impl<sequence<>, Orders>
{
// F signature: F(sequence<...>)
// OrderedId: sequence<...>
template <class F, class OrderedId>
CK_TILE_HOST_DEVICE constexpr void operator()(F f, OrderedId) const
{
// retrive unordered Id
f(OrderedId::reorder_old_to_new(Orders{}));
}
};
} // namespace detail
// Lengths is sequence<...>, it is the length of each dimension for
// N-dimensional loop
// Orders is sequence<...>, it is the order of dimension in which static_ford
// will loop over each
// dimension
template <class Lengths,
class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
struct static_ford
{
CK_TILE_HOST_DEVICE constexpr static_ford()
{
static_assert(Lengths::size() > 0, "wrong! Lengths is empty");
static_assert(Lengths::size() == Orders::size(), "wrong! inconsistent size");
}
// F signature: F(sequence<...> multi_id)
// multi_id is the unordered multi-index
template <class F>
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
{
constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, sequence<>{});
}
};
namespace detail {
template <typename Indices>
struct unpack_impl;
template <index_t... Is>
struct unpack_impl<sequence<Is...>>
{
template <typename F, typename X>
CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x) const
{
#if 0
return std::forward<F>(f)(std::forward<X>(x).at(number<Is>{})...);
#else
return std::forward<F>(f)(std::forward<X>(x).template at<Is>()...);
#endif
}
};
template <typename Seq0, typename Seq1>
struct unpack2_impl;
// TODO: remove this, after properly implementing unpack that takes any number of containers
template <index_t... Is, index_t... Js>
struct unpack2_impl<sequence<Is...>, sequence<Js...>>
{
template <typename F, typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x, Y&& y) const
{
#if 0
return std::forward<F>(f)(std::forward<X>(x).at(number<Is>{})...,
std::forward<Y>(y).at(number<Js>{})...);
#else
return std::forward<F>(f)(std::forward<X>(x).template at<Is>()...,
std::forward<Y>(y).template at<Js>()...);
#endif
}
};
} // namespace detail
template <typename F, typename X>
CK_TILE_HOST_DEVICE constexpr auto unpack(F&& f, X&& x)
{
using X_ = remove_reference_t<X>;
return detail::unpack_impl<typename arithmetic_sequence_gen<0, X_::size(), 1>::type>{}(
std::forward<F>(f), std::forward<X>(x));
}
// TODO: properly implement unpack that takes any number of containers
template <typename F, typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto unpack2(F&& f, X&& x, Y&& y)
{
using X_ = remove_reference_t<X>;
using Y_ = remove_reference_t<Y>;
return detail::unpack2_impl<typename arithmetic_sequence_gen<0, X_::size(), 1>::type,
typename arithmetic_sequence_gen<0, Y_::size(), 1>::type>{}(
std::forward<F>(f), std::forward<X>(x), std::forward<Y>(y));
}
} // namespace ck_tile

View File

@@ -0,0 +1,75 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include <limits>
#include <stdint.h>
namespace ck_tile {
template <typename T>
struct numeric_limits
{
// minimum finite value, or minimum positive normalized value for float
CK_TILE_HOST_DEVICE static constexpr T min() { return std::numeric_limits<T>::min(); }
// minumum finite value
CK_TILE_HOST_DEVICE static constexpr T lowest() { return std::numeric_limits<T>::lowest(); }
// maximum finite value
CK_TILE_HOST_DEVICE static constexpr T max() { return std::numeric_limits<T>::max(); }
// difference between 1.0 and next value representable by float
CK_TILE_HOST_DEVICE static constexpr T epsilon() { return std::numeric_limits<T>::epsilon(); }
// maximum rounding error
CK_TILE_HOST_DEVICE static constexpr T round_error()
{
return std::numeric_limits<T>::round_error();
}
// positive infinity value
CK_TILE_HOST_DEVICE static constexpr T infinity() { return std::numeric_limits<T>::infinity(); }
// quiet NaN
CK_TILE_HOST_DEVICE static constexpr T quiet_NaN()
{
return std::numeric_limits<T>::quiet_NaN();
}
// signaling NaN
CK_TILE_HOST_DEVICE static constexpr T signaling_NaN()
{
return std::numeric_limits<T>::signaling_NaN();
}
// smallest positive subnormal value
CK_TILE_HOST_DEVICE static constexpr T denorm_min()
{
return std::numeric_limits<T>::denorm_min();
}
};
template <typename T>
struct numeric_utils;
template <>
struct numeric_utils<float>
{
static constexpr int exp = 8;
static constexpr int mant = 23;
static constexpr int bias = 127;
static constexpr uint32_t nan_mask = 0x7F800000;
static constexpr uint32_t head_mask = 0xFF800000;
static constexpr uint32_t mant_mask = 0x7FFFFF;
static constexpr uint32_t exp_mask = 0xFF;
static constexpr uint32_t Inf = 0x7F800000;
static constexpr uint32_t NegInf = 0xFF800000;
static constexpr uint32_t NaN = 0x7F800001;
static constexpr uint32_t Neg0 = 0x80000000;
using bitwise_type = uint32_t;
};
} // namespace ck_tile

View File

@@ -0,0 +1,261 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include <stdint.h>
namespace ck_tile {
// magic number division
// Caution:
// 1. For uint32_t as dividend: magic number division implementation being used would produce
// correct result if the dividend is uint32_t and its value is within 31-bit value range.
// 2. For int32_t as dividendd: magic number division for int32_t dividened has not been
// implemented, the int32_t dividend would be bit-wise interpreted as uint32_t and magic number
// division implementation for uint32_t is then used. Therefore, dividend value need to be
// non-negative.
// TODO:
// 1. Implement magic number divison for int32_t
// 2. Implement magic number divison for unit32_t with 32-bit value range
struct magic_division32_bit_range
{
// uint32_t
CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(uint32_t divisor)
{
// WARNING: magic division is only valid for division inside this range.
// assert(divisor >= 1 && divisor <= INT32_MAX)
uint32_t shift_u32 = 0;
while((1U << shift_u32) < divisor)
{
shift_u32++;
};
uint64_t tmp_u64 = ((1UL << shift_u32) - divisor) << 32;
uint32_t multiplier_u32 = tmp_u64 / divisor + 1;
return make_tuple(multiplier_u32, shift_u32);
}
// integral_constant<uint32_t, .>
template <uint32_t Divisor, typename = std::enable_if_t<(0 < Divisor)>>
CK_TILE_HOST_DEVICE static constexpr auto
calculate_magic_numbers(integral_constant<uint32_t, Divisor>)
{
constexpr auto tmp = calculate_magic_numbers(uint32_t{Divisor});
constexpr uint32_t multiplier = tmp[number<0>{}];
constexpr uint32_t shift = tmp[number<1>{}];
return make_tuple(integral_constant<uint32_t, multiplier>{},
integral_constant<uint32_t, shift>{});
}
// integral_constant<int32_t, .>
template <int32_t Divisor>
CK_TILE_HOST_DEVICE static constexpr auto
calculate_magic_numbers(integral_constant<int32_t, Divisor>)
{
return calculate_magic_numbers(integral_constant<uint32_t, Divisor>{});
}
// magic division for uint32_t
CK_TILE_DEVICE static constexpr uint32_t
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{
uint32_t tmp = __umulhi(dividend, multiplier);
return (tmp + dividend) >> shift;
}
CK_TILE_HOST static constexpr uint32_t
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{
uint32_t tmp = (static_cast<uint64_t>(dividend) * multiplier) >> 32;
return (tmp + dividend) >> shift;
}
// magic division for int32_t
// HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be
// non-negative for result to be correct
// TODO: figure out how to do magic number divison for int32_t as dividended
CK_TILE_DEVICE static constexpr int32_t
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = __umulhi(dividend_u32, multiplier);
return (tmp + dividend_u32) >> shift;
}
CK_TILE_HOST static constexpr int32_t
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = (static_cast<uint64_t>(dividend_u32) * multiplier) >> 32;
return (tmp + dividend_u32) >> shift;
}
};
// magic number division
// This version on works for divisor and dividended between [0, 1 << 16]
struct magic_division16_bit_range
{
// uint32_t
CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(uint32_t divisor)
{
// WARNING: magic division is only valid for division inside this range.
// assert(divisor >= 1 && divisor <= (1U << 16));
uint32_t shift_u32 = 0;
while((1U << shift_u32) < divisor)
{
shift_u32++;
};
uint32_t one = 1;
uint32_t multiplier_u32 = ((one << 16) * ((one << shift_u32) - divisor)) / divisor + 1;
return make_tuple(multiplier_u32, shift_u32);
}
// integral_constant<uint32_t, .>
template <uint32_t Divisor>
CK_TILE_HOST_DEVICE static constexpr auto
calculate_magic_numbers(integral_constant<uint32_t, Divisor>)
{
constexpr auto tmp = calculate_magic_numbers(uint32_t{Divisor});
constexpr uint32_t multiplier = tmp[number<0>{}];
constexpr uint32_t shift = tmp[number<1>{}];
return make_tuple(integral_constant<uint32_t, multiplier>{},
integral_constant<uint32_t, shift>{});
}
// integral_constant<int32_t, .>
template <int32_t Divisor>
CK_TILE_HOST_DEVICE static constexpr auto
calculate_magic_numbers(integral_constant<int32_t, Divisor>)
{
return calculate_magic_numbers(integral_constant<uint32_t, Divisor>{});
}
// magic division for uint32_t
CK_TILE_DEVICE static constexpr uint32_t
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{
uint32_t tmp = (dividend * multiplier) >> 16;
return (tmp + dividend) >> shift;
}
CK_TILE_HOST static constexpr uint32_t
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{
uint32_t tmp = (dividend * multiplier) >> 16;
return (tmp + dividend) >> shift;
}
// magic division for int32_t
// HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be
// non-negative for result to be correct
// TODO: figure out how to do magic number divison for int32_t as dividended
CK_TILE_DEVICE static constexpr int32_t
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = (dividend_u32 * multiplier) >> 16;
return (tmp + dividend_u32) >> shift;
}
CK_TILE_HOST static constexpr int32_t
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = (dividend_u32 * multiplier) >> 16;
return (tmp + dividend_u32) >> shift;
}
};
// use 32bit version
using magic_division = magic_division32_bit_range;
struct mdiv
{
// 1 dword -> 3 dword storage
uint32_t divisor;
uint32_t multiplier;
uint32_t shift; // TODO: 8 bit is enough
// prefer construct on host
CK_TILE_HOST_DEVICE mdiv(uint32_t divisor_) : divisor(divisor_)
{
auto tmp = magic_division::calculate_magic_numbers(divisor_);
multiplier = tmp[number<0>{}];
shift = tmp[number<1>{}];
}
CK_TILE_HOST_DEVICE mdiv() : divisor(0), multiplier(0), shift(0) {}
CK_TILE_HOST_DEVICE void update(uint32_t divisor_)
{
divisor = divisor_;
auto tmp = magic_division::calculate_magic_numbers(divisor_);
multiplier = tmp[number<0>{}];
shift = tmp[number<1>{}];
}
CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const
{
return magic_division::do_magic_division(dividend_, multiplier, shift);
}
CK_TILE_HOST_DEVICE void
divmod(uint32_t dividend_, uint32_t& quotient_, uint32_t& remainder_) const
{
quotient_ = div(dividend_);
remainder_ = dividend_ - (quotient_ * divisor);
}
CK_TILE_HOST_DEVICE uint32_t get() const { return divisor; }
};
struct mdiv2
{
// 1 dword -> 2 dword storage, divisor need compute from runtime
uint32_t multiplier;
uint32_t shift; // TODO: 8 bit is enough
// prefer construct on host
CK_TILE_HOST_DEVICE mdiv2(uint32_t divisor_)
{
auto tmp = magic_division::calculate_magic_numbers(divisor_);
multiplier = tmp[number<0>{}];
shift = tmp[number<1>{}];
}
CK_TILE_HOST_DEVICE mdiv2() : multiplier(0), shift(0) {}
CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const
{
return magic_division::do_magic_division(dividend_, multiplier, shift);
}
CK_TILE_HOST_DEVICE void
divmod(uint32_t dividend_, uint32_t divisor_, uint32_t& quotient_, uint32_t& remainder_) const
{
quotient_ = div(dividend_);
remainder_ = dividend_ - (quotient_ * divisor_);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,64 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include <stdint.h>
#include <tuple>
#include <type_traits>
namespace ck_tile {
// return 0 if data is not fp16 or fp32
template <typename T, uint32_t seed_>
struct prand_generator_t
{
CK_TILE_HOST_DEVICE uint32_t operator()(int id, T val, uint32_t seed = seed_)
{
std::ignore = id;
std::ignore = val;
std::ignore = seed;
return 0;
}
};
// version for fp32
template <uint32_t seed_>
struct prand_generator_t<float, seed_>
{
CK_TILE_HOST_DEVICE uint32_t operator()(int id, float val, uint32_t seed = seed_)
{
uint32_t x = *(reinterpret_cast<uint32_t*>(&val));
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
drop_bits ^= x >> 16;
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
drop_bits *= 0x7000149;
// NOTE: If id is in 64 bit, we are only using lower 32 bit.
// So, it can have an effect of using same id for multiple elements when the id is
// very large!
uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed);
return rng;
}
};
// version for fp16
template <uint32_t seed_>
struct prand_generator_t<half_t, seed_>
{
CK_TILE_HOST_DEVICE uint32_t operator()(int id, half_t val, uint32_t seed = seed_)
{
uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
drop_bits *= 0x7000149;
// NOTE: If id is in 64 bit, we are only using lower 32 bit.
// So, it can have an effect of using same id for multiple elements when the id is
// very large!
uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed);
return rng;
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,72 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/container/sequence.hpp"
// TODO: use c++20 nontype template with struct to implement this
#if 1
// clang happen to support this feature (__cpp_generic_lambdas >= 201707) in c++17 mode
#define TO_SEQUENCE(a, n) \
_Pragma("clang diagnostic push") \
_Pragma("clang diagnostic ignored \"-Wc++20-extensions\"")[a]<ck_tile::index_t... Is>( \
ck_tile::sequence<Is...>) \
{ \
return ck_tile::sequence<a.at(ck_tile::number<Is>{})...>{}; \
} \
(make_index_sequence<n>{}) _Pragma("clang diagnostic pop")
#else
// Macro function
// convert constexpr array to sequence, both a/n need to be constexpr (can't be a rvalue like 2)
#define TO_SEQUENCE(a, n) \
[a, n] { \
static_assert(a.size() >= n, "wrong! out of bound"); \
static_assert(n <= 10, "not implemented"); \
if constexpr(n == 0) \
{ \
return ck_tile::sequence<>{}; \
} \
else if constexpr(n == 1) \
{ \
return ck_tile::sequence<a[0]>{}; \
} \
else if constexpr(n == 2) \
{ \
return ck_tile::sequence<a[0], a[1]>{}; \
} \
else if constexpr(n == 3) \
{ \
return ck_tile::sequence<a[0], a[1], a[2]>{}; \
} \
else if constexpr(n == 4) \
{ \
return ck_tile::sequence<a[0], a[1], a[2], a[3]>{}; \
} \
else if constexpr(n == 5) \
{ \
return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4]>{}; \
} \
else if constexpr(n == 6) \
{ \
return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5]>{}; \
} \
else if constexpr(n == 7) \
{ \
return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6]>{}; \
} \
else if constexpr(n == 8) \
{ \
return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7]>{}; \
} \
else if constexpr(n == 9) \
{ \
return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8]>{}; \
} \
else if constexpr(n == 10) \
{ \
return ck_tile:: \
sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9]>{}; \
} \
}()
#endif

View File

@@ -0,0 +1,57 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <stdint.h>
#include <tuple>
#include <type_traits>
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"
namespace ck_tile {
// Convert X to Y, both X and Y are non-const data types.
template <typename Y,
typename X,
std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
return static_cast<Y>(x);
}
// Convert X to Y, either X or Y is a const data type.
template <typename Y,
typename X,
std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
using non_const_y = std::remove_const_t<Y>;
using non_const_x = std::remove_const_t<X>;
return static_cast<Y>(type_convert<non_const_y, non_const_x>(x));
}
#define CK_TILE_TYPE_CONVERT(dtype_, stype_) \
template <> \
inline CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
{ \
return stype_##_to_##dtype_(x); \
}
CK_TILE_TYPE_CONVERT(float, fp16_t)
CK_TILE_TYPE_CONVERT(float, bf16_t)
CK_TILE_TYPE_CONVERT(float, fp8_t)
CK_TILE_TYPE_CONVERT(float, bf8_t)
CK_TILE_TYPE_CONVERT(fp16_t, float)
CK_TILE_TYPE_CONVERT(bf16_t, float)
CK_TILE_TYPE_CONVERT(fp8_t, float)
CK_TILE_TYPE_CONVERT(bf8_t, float)
} // namespace ck_tile

View File

@@ -0,0 +1,46 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include <type_traits>
#include <stdint.h>
namespace ck_tile {
// remove_cvref_t
template <typename T>
using remove_reference_t = typename std::remove_reference<T>::type;
template <typename T>
using remove_cv_t = typename std::remove_cv<T>::type;
template <typename T>
using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;
template <typename T>
using remove_pointer_t = typename std::remove_pointer<T>::type;
namespace impl {
template <typename T>
struct is_static_impl
{
static constexpr bool value = std::is_arithmetic<T>::v ? false : T::is_static();
};
} // namespace impl
template <typename T>
using is_static = impl::is_static_impl<remove_cvref_t<T>>;
template <typename T>
inline constexpr bool is_static_v = is_static<T>::value;
// TODO: deprecate this
template <typename T>
using is_known_at_compile_time = is_static<T>;
// TODO: if evaluating a rvalue, e.g. a const integer
// , this helper will also return false, which is not good(?)
// do we need something like is_constexpr()?
} // namespace ck_tile