This commit is contained in:
carlushuang
2024-02-28 22:57:19 +00:00
parent e60c5aea4e
commit f69356b1d7
130 changed files with 28268 additions and 0 deletions

View File

@@ -0,0 +1,201 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
namespace ck_tile {
// use aggregate initialization for this type
// e.g. array<index_t, 4> buf {0}; => {0, 0, 0, 0}, clean
// array<index_t, 4> buf {3, 2}; => {3, 2, 2, 2} (not {3,2,0,0})
// use make_array_with({...}) to construct an array with compatible behavior as old ck
// TODO: manually added constructor same as old ck
template <typename T_, index_t N_>
struct array
{
using value_type = T_;
static constexpr index_t N = N_;
value_type data[N];
CK_TILE_HOST_DEVICE constexpr array() : data{} {}
// TODO: will initialize the data[] with the last value repeatedly
// behavior different from std
CK_TILE_HOST_DEVICE constexpr array(std::initializer_list<value_type> ilist)
{
constexpr index_t list_size = std::initializer_list<value_type>{}.size();
static_assert(list_size <= N, "out of bound");
index_t i = 0;
value_type vlast = value_type{};
for(const value_type& val : ilist)
{
data[i] = val;
vlast = val;
++i;
}
for(; i < N; ++i)
{
data[i] = vlast;
}
}
CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<value_type>; }
// clang-format off
CK_TILE_HOST_DEVICE constexpr auto& get() { return data; }
CK_TILE_HOST_DEVICE constexpr const auto& get() const { return data; }
CK_TILE_HOST_DEVICE constexpr auto& get(index_t i) { return data[i]; }
CK_TILE_HOST_DEVICE constexpr const auto& get(index_t i) const { return data[i]; }
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& get() { return data[I]; }
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& get() const { return data[I]; }
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& get(number<I>) { return data[I]; }
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& get(number<I>) const { return data[I]; }
CK_TILE_HOST_DEVICE constexpr auto& at(index_t i) { return data[i]; }
CK_TILE_HOST_DEVICE constexpr const auto& at(index_t i) const { return data[i]; }
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at() { return data[I]; }
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return data[I]; }
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return data[I]; }
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return data[I]; }
CK_TILE_HOST_DEVICE constexpr const value_type& operator[](index_t i) const { return data[i]; }
CK_TILE_HOST_DEVICE constexpr value_type& operator[](index_t i) { return data[i]; }
CK_TILE_HOST_DEVICE constexpr value_type& operator()(index_t i) { return data[i]; } // TODO: compatible
template <typename T>
CK_TILE_HOST_DEVICE constexpr auto operator=(const T& a)
{
static_assert(T::size() == size(), "wrong! size not the same");
for(index_t i = 0; i < size(); ++i)
{
data[i] = a[i];
}
return *this;
}
// type punning (strict aliasing) member functions for read/write
// aliasing this array of type "T", "N" elements
// as array of type "Tx", sizeof(T)*N/sizeof(Tx) elements
#define AR_AS_COM_() \
static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
template <typename Tx> CK_TILE_HOST_DEVICE constexpr auto& get_as()
{ AR_AS_COM_(); return reinterpret_cast<array<Tx, vx>&>(data); }
template <typename Tx> CK_TILE_HOST_DEVICE constexpr const auto& get_as() const
{ AR_AS_COM_(); return reinterpret_cast<const array<Tx, vx>&>(data); }
// below index is for index *AFTER* type convert, not before
template <typename Tx> CK_TILE_HOST_DEVICE constexpr auto& get_as(index_t i)
{ AR_AS_COM_(); return reinterpret_cast<array<Tx, vx>&>(data).at(i); }
template <typename Tx> CK_TILE_HOST_DEVICE constexpr const auto& get_as(index_t i) const
{ AR_AS_COM_(); return reinterpret_cast<const array<Tx, vx>&>(data).at(i); }
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr auto& get_as(number<I>)
{ AR_AS_COM_(); return reinterpret_cast<array<Tx, vx>&>(data).at(number<I>{}); }
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr const auto& get_as(number<I>) const
{ AR_AS_COM_(); return reinterpret_cast<const array<Tx, vx>&>(data).at(number<I>{}); }
template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x)
{ AR_AS_COM_(); reinterpret_cast<array<Tx, vx>&>(data).at(i) = x; }
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x)
{ AR_AS_COM_(); reinterpret_cast<array<Tx, vx>&>(data).at(number<I>{}) = x; }
#undef AR_AS_COM_
// clang-format on
};
// empty Array
template <typename T>
struct array<T, 0>
{
using value_type = T;
CK_TILE_HOST_DEVICE constexpr array() {}
CK_TILE_HOST_DEVICE static constexpr index_t size() { return 0; }
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<T>; };
CK_TILE_HOST_DEVICE void print() const { printf("array{size: 0, data: []}"); }
};
template <typename T, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_array(T&& x, Ts&&... xs)
{
using value_type = remove_cvref_t<T>;
return array<value_type, sizeof...(Ts) + 1>{std::forward<T>(x), std::forward<Ts>(xs)...};
}
// make empty array
template <typename T>
CK_TILE_HOST_DEVICE constexpr auto make_array()
{
return array<T, 0>{};
}
// compatible with old ck's initializer, make an array and fill it withe the last element from
// initializer_list
#include <initializer_list>
template <typename T, index_t Size>
CK_TILE_HOST_DEVICE constexpr auto make_array_with(std::initializer_list<T> ilist)
{
constexpr index_t list_size = std::initializer_list<T>{}.size();
static_assert(list_size <= Size, "out of bound");
index_t i = 0;
T vlast = T{};
array<T, Size> arr;
for(const T& val : ilist)
{
arr.data[i] = val;
vlast = val;
++i;
}
for(; i < Size; ++i)
{
arr.data[i] = vlast;
}
return arr;
}
template <typename T, index_t Size>
CK_TILE_HOST_DEVICE constexpr bool operator==(const array<T, Size>& a, const array<T, Size>& b)
{
bool same = true;
for(index_t i = 0; i < Size; ++i)
{
if(a[i] != b[i])
{
same = false;
break;
}
}
return same;
}
template <typename T, index_t Size>
CK_TILE_HOST_DEVICE constexpr bool operator!=(const array<T, Size>& a, const array<T, Size>& b)
{
return !(a == b);
}
template <typename T, index_t N, typename X>
CK_TILE_HOST_DEVICE constexpr auto to_array(const X& x)
{
STATIC_ASSERT(N <= X::size(), "");
array<T, N> arr;
static_for<0, N, 1>{}([&x, &arr](auto i) { arr(i) = x[i]; });
return arr;
}
} // namespace ck_tile

View File

@@ -0,0 +1,483 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/map.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/utility/functional.hpp"
namespace ck_tile {
template <typename TData, index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto container_push_back(const array<TData, NSize>& a, const TData& x)
{
array<TData, NSize + 1> r;
static_for<0, NSize, 1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; });
r[number<NSize>{}] = x;
return r;
}
template <typename... Ts, typename T>
CK_TILE_HOST_DEVICE constexpr auto container_push_front(const tuple<Ts...>& a, const T& x)
{
return container_concat(make_tuple(x), a);
}
template <typename... Ts, typename T>
CK_TILE_HOST_DEVICE constexpr auto container_push_back(const tuple<Ts...>& a, const T& x)
{
return container_concat(a, make_tuple(x));
}
// reorder array
template <typename TData, index_t NSize, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_new2old(const array<TData, NSize>& old_array, sequence<IRs...> /*new2old*/)
{
static_assert(NSize == sizeof...(IRs), "wrong! size not consistent");
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
return make_array<remove_cvref_t<TData>>(old_array[IRs]...);
}
template <typename TData, index_t NSize, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_old2new(const array<TData, NSize>& old_array, sequence<IRs...> old2new)
{
return container_reorder_given_new2old(
old_array, typename sequence_map_inverse<decltype(old2new)>::type{});
}
// reorder array
template <typename TData, index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_new2old(const array<TData, NSize>& old_array,
const map<index_t, index_t>& new2old)
{
array<TData, NSize> new_array;
for(const auto& [new_pos, old_pos] : new2old)
{
new_array(new_pos) = old_array[old_pos];
}
return new_array;
}
template <typename TData, index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_old2new(const array<TData, NSize>& old_array,
const map<index_t, index_t>& old2new)
{
array<TData, NSize> new_array;
for(const auto& [old_pos, new_pos] : old2new)
{
new_array(new_pos) = old_array[old_pos];
}
return new_array;
}
// reorder tuple
template <typename... Ts, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(const tuple<Ts...>& old_tuple,
sequence<IRs...> /*new2old*/)
{
static_assert(sizeof...(Ts) == sizeof...(IRs), "wrong! size not consistent");
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
return make_tuple(old_tuple[number<IRs>{}]...);
}
template <typename... Ts, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_old2new(const tuple<Ts...>& old_tuple,
sequence<IRs...> old2new)
{
return container_reorder_given_new2old(
old_tuple, typename sequence_map_inverse<decltype(old2new)>::type{});
}
// reorder sequence
template <index_t... Is, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(sequence<Is...> /* old_seq */,
sequence<IRs...> /*new2old*/)
{
static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
return sequence<sequence<Is...>::at(number<IRs>{})...>{};
}
template <index_t... Is, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_old2new(sequence<Is...> old_seq,
sequence<IRs...> /* old2new */)
{
static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
constexpr auto new2old = typename sequence_map_inverse<sequence<IRs...>>::type{};
return container_reorder_given_new2old(old_seq, new2old);
}
#if 0
// rocm-4.1 compiler would crash for recursive lambda
template <typename Container,
typename Reduce,
typename Init,
index_t IBegin = 0,
index_t IEnd = Container::size(),
index_t IStep = 1>
CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container& x,
Reduce reduce,
Init init,
number<IBegin> = number<0>{},
number<IEnd> = number<Container::size()>{},
number<IStep> = number<1>{})
{
static_assert((IEnd - IBegin) % IStep == 0, "wrong!");
// f is recursive function, fs is a dummy of f
// i is index, y_old is current scan, r_old is current reduction
auto f = [&](auto fs, auto i, auto r_old) {
auto r_new = reduce(x[i], r_old);
if constexpr(i.value < IEnd - IStep)
{
// recursively call f/fs
return fs(fs, i + number<IStep>{}, r_new);
}
else
{
return r_new;
}
};
// start recursion
return f(f, number<IBegin>{}, init);
}
#else
// i is index, y_old is current scan, r_old is current reduction
template <typename Container,
typename Reduce,
typename ROld,
index_t I,
index_t IEnd,
index_t IStep>
CK_TILE_HOST_DEVICE constexpr auto container_reduce_impl(
const Container& x, Reduce reduce, ROld r_old, number<I> i, number<IEnd>, number<IStep>)
{
auto r_new = reduce(x[i], r_old);
if constexpr(i.value < IEnd - IStep)
{
return container_reduce_impl(
x, reduce, r_new, i + number<IStep>{}, number<IEnd>{}, number<IStep>{});
}
else
{
return r_new;
}
}
// rocm-4.1 compiler would crash for recursive lambda
// container reduce with initial value
template <typename Container,
typename Reduce,
typename Init,
index_t IBegin = 0,
index_t IEnd = Container::size(),
index_t IStep = 1>
CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container& x,
Reduce reduce,
Init init,
number<IBegin> = number<0>{},
number<IEnd> = number<Container::size()>{},
number<IStep> = number<1>{})
{
static_assert((IEnd - IBegin) % IStep == 0, "wrong!");
if constexpr(IEnd > IBegin)
{
return container_reduce_impl(
x, reduce, init, number<IBegin>{}, number<IEnd>{}, number<IStep>{});
}
else
{
return init;
}
}
#endif
template <typename TData, index_t NSize, typename Reduce>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_inclusive_scan(const array<TData, NSize>& x, Reduce f, TData init)
{
array<TData, NSize> y;
TData r = init;
static_for<NSize - 1, 0, -1>{}([&](auto i) {
r = f(r, x[i]);
y(i) = r;
});
r = f(r, x[number<0>{}]);
y(number<0>{}) = r;
return y;
}
template <typename TData, index_t NSize, typename Reduce, typename Init>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_exclusive_scan(const array<TData, NSize>& x, Reduce f, Init init)
{
#if 0
array<TData, NSize> y;
TData r = init;
static_for<NSize - 1, 0, -1>{}([&](auto i) {
y(i) = r;
r = f(r, x[i]);
});
y(number<0>{}) = r;
return y;
#else
array<TData, NSize> y;
TData r = init;
for(index_t i = NSize - 1; i > 0; --i)
{
y(i) = r;
r = f(r, x[i]);
}
y(0) = r;
return y;
#endif
}
template <index_t... Is, typename Reduce, index_t Init>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_exclusive_scan(const sequence<Is...>& seq, Reduce f, number<Init>)
{
return reverse_exclusive_scan_sequence(seq, f, number<Init>{});
}
#if 0
// rocm4.1 compiler would crash with recursive lambda
template <typename... Xs, typename Reduce, typename Init>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_exclusive_scan(const tuple<Xs...>& x, Reduce reduce, Init init)
{
constexpr index_t NSize = sizeof...(Xs);
// f is recursive function, fs is a dummy of f
// i is index, y_old is current scan, r_old is current reduction
auto f = [&](auto fs, auto i, auto y_old, auto r_old) {
auto r_new = reduce(x[i], r_old);
auto y_new = container_push_front(y_old, r_new);
if constexpr(i.value > 1)
{
// recursively call f/fs
return fs(fs, i - number<1>{}, y_new, r_new);
}
else
{
return y_new;
}
};
// start recursion
return f(f, number<NSize - 1>{}, make_tuple(init), init);
}
#else
// i is index, y_old is current scan, r_old is current reduction
template <typename... Xs, typename Reduce, index_t I, typename YOld, typename ROld>
CK_TILE_HOST_DEVICE constexpr auto container_reverse_exclusive_scan_impl(
const tuple<Xs...>& x, Reduce reduce, number<I> i, YOld y_old, ROld r_old)
{
auto r_new = reduce(x[i], r_old);
auto y_new = container_push_front(y_old, r_new);
if constexpr(i.value > 1)
{
// recursively call f/fs
return container_reverse_exclusive_scan_impl(x, reduce, i - number<1>{}, y_new, r_new);
}
else
{
return y_new;
}
}
template <typename... Xs, typename Reduce, typename Init>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_exclusive_scan(const tuple<Xs...>& x, Reduce reduce, Init init)
{
constexpr index_t NSize = sizeof...(Xs);
return container_reverse_exclusive_scan_impl(
x, reduce, number<NSize - 1>{}, make_tuple(init), init);
}
#endif
// TODO: update to like container_reverse_exclusive_scan to deal with tuple of Numebr<>
template <typename... Xs, typename Reduce, typename TData>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_inclusive_scan(const tuple<Xs...>& x, Reduce f, TData init)
{
constexpr index_t NSize = sizeof...(Xs);
tuple<Xs...> y;
TData r = init;
static_for<NSize - 1, 0, -1>{}([&](auto i) {
r = f(r, x[i]);
y(i) = r;
});
r = f(r, x[number<0>{}]);
y(number<0>{}) = r;
return y;
}
template <typename X, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto container_concat(const X& x, const Ys&... ys)
{
return container_concat(x, container_concat(ys...));
}
template <typename T, index_t NX, index_t NY>
CK_TILE_HOST_DEVICE constexpr auto container_concat(const array<T, NX>& ax, const array<T, NY>& ay)
{
return unpack2(
[&](auto&&... zs) { return make_array<T>(std::forward<decltype(zs)>(zs)...); }, ax, ay);
}
template <typename... X, typename... Y>
CK_TILE_HOST_DEVICE constexpr auto container_concat(const tuple<X...>& tx, const tuple<Y...>& ty)
{
return unpack2(
[&](auto&&... zs) { return make_tuple(std::forward<decltype(zs)>(zs)...); }, tx, ty);
}
template <typename Container>
CK_TILE_HOST_DEVICE constexpr auto container_concat(const Container& x)
{
return x;
}
template <typename T, index_t N, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto get_container_subset(const array<T, N>& arr, sequence<Is...>)
{
static_assert(N >= sizeof...(Is), "wrong! size");
if constexpr(sizeof...(Is) > 0)
{
return make_array<T>(arr[Is]...);
}
else
{
return array<T, 0>{};
}
}
template <typename... Ts, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto get_container_subset(const tuple<Ts...>& tup, sequence<Is...>)
{
static_assert(sizeof...(Ts) >= sizeof...(Is), "wrong! size");
if constexpr(sizeof...(Is) > 0)
{
return make_tuple(tup[number<Is>{}]...);
}
else
{
return tuple<>{};
}
}
template <typename T, index_t N, index_t... Is>
CK_TILE_HOST_DEVICE constexpr void
set_container_subset(array<T, N>& y, sequence<Is...> picks, const array<T, sizeof...(Is)>& x)
{
static_assert(N >= sizeof...(Is), "wrong! size");
if constexpr(sizeof...(Is) > 0)
{
for(index_t i = 0; i < picks.size(); ++i)
{
y(picks[i]) = x[i];
}
}
}
template <typename Y, typename X, index_t... Is>
CK_TILE_HOST_DEVICE constexpr void set_container_subset(Y& y, sequence<Is...> picks, const X& x)
{
static_assert(Y::size() >= sizeof...(Is) && X::size() == sizeof...(Is), "wrong! size");
if constexpr(sizeof...(Is) > 0)
{
static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
}
}
// return the index of first occurance in the sequence.
// return seq.size(), if not found
template <index_t... Is>
constexpr index_t container_find(sequence<Is...> seq, index_t value)
{
for(auto i = 0; i < seq.size(); i++)
{
if(seq[i] == value)
return i;
}
return seq.size();
}
template <index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence<Is...>)
{
using Seq = sequence<Is...>;
return generate_tuple(
[&](auto i) {
constexpr index_t tmp = Seq::at(i);
return number<tmp>{};
},
number<Seq::size()>{});
}
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes) \
[a_of_b_impl, a_size, bs_sizes] { \
return ck_tile::generate_tuple( \
[=](auto i) { \
constexpr auto b_impl = a_of_b_impl[i]; \
constexpr index_t b_size = bs_sizes[i]; \
constexpr auto b = TO_SEQUENCE(b_impl, b_size); \
return b; \
}, \
ck_tile::number<a_size>{}); \
}()
} // namespace ck_tile

View File

@@ -0,0 +1,164 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
namespace ck_tile {
// naive map
template <typename key, typename data, index_t max_size = 128>
struct map
{
using pair_type = tuple<key, data>;
using impl_type = array<pair_type, max_size>;
impl_type impl_;
index_t size_;
struct iterator
{
impl_type& impl_;
index_t pos_;
CK_TILE_HOST_DEVICE constexpr iterator(impl_type& impl, index_t pos)
: impl_{impl}, pos_{pos}
{
}
CK_TILE_HOST_DEVICE constexpr iterator& operator++()
{
pos_++;
return *this;
}
CK_TILE_HOST_DEVICE constexpr bool operator!=(const iterator& other) const
{
return other.pos_ != pos_;
}
CK_TILE_HOST_DEVICE constexpr pair_type& operator*() { return impl_.at(pos_); }
};
struct const_iterator
{
const impl_type& impl_;
index_t pos_;
CK_TILE_HOST_DEVICE constexpr const_iterator(const impl_type& impl, index_t pos)
: impl_{impl}, pos_{pos}
{
}
CK_TILE_HOST_DEVICE constexpr const_iterator& operator++()
{
pos_++;
return *this;
}
CK_TILE_HOST_DEVICE constexpr bool operator!=(const const_iterator& other) const
{
return other.pos_ != pos_;
}
CK_TILE_HOST_DEVICE constexpr const pair_type& operator*() const { return impl_.at(pos_); }
};
CK_TILE_HOST_DEVICE constexpr map() : impl_{}, size_{0} {}
CK_TILE_HOST_DEVICE constexpr index_t size() const { return size_; }
CK_TILE_HOST_DEVICE void clear() { size_ = 0; }
CK_TILE_HOST_DEVICE constexpr index_t find_position(const key& key) const
{
for(index_t i = 0; i < size(); i++)
{
if(impl_[i].template at<0>() == key)
{
return i;
}
}
return size_;
}
CK_TILE_HOST_DEVICE constexpr const_iterator find(const key& key) const
{
return const_iterator{impl_, find_position(key)};
}
CK_TILE_HOST_DEVICE constexpr iterator find(const key& key)
{
return iterator{impl_, find_position(key)};
}
CK_TILE_HOST_DEVICE constexpr const data& operator[](const key& key) const
{
const auto it = find(key);
// FIXME
assert(it.pos_ < size());
return impl_[it.pos_].template at<1>();
}
CK_TILE_HOST_DEVICE constexpr data& operator()(const key& key)
{
auto it = find(key);
// if entry not found
if(it.pos_ == size())
{
impl_(it.pos_).template at<0>() = key;
size_++;
}
// FIXME
assert(size_ <= max_size);
return impl_(it.pos_).template at<1>();
}
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
CK_TILE_HOST_DEVICE constexpr const_iterator begin() const { return const_iterator{impl_, 0}; }
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
CK_TILE_HOST_DEVICE constexpr const_iterator end() const
{
return const_iterator{impl_, size_};
}
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
CK_TILE_HOST_DEVICE constexpr iterator begin() { return iterator{impl_, 0}; }
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
CK_TILE_HOST_DEVICE constexpr iterator end() { return iterator{impl_, size_}; }
CK_TILE_HOST_DEVICE void print() const
{
printf("map{size_: %d, ", size_);
//
printf("impl_: [");
//
for(const auto& [key, data] : *this)
{
printf("{key: ");
print(key);
printf(", data: ");
print(data);
printf("}, ");
}
//
printf("]");
//
printf("}");
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,99 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include <cstddef>
namespace ck_tile {
// TODO: this structure is not intented to be used by user
template <index_t MaxSize>
struct meta_data_buffer
{
CK_TILE_HOST_DEVICE constexpr meta_data_buffer() : buffer_{}, size_{0} {}
template <typename X, typename... Xs>
CK_TILE_HOST_DEVICE constexpr meta_data_buffer(const X& x, const Xs&... xs)
: buffer_{}, size_{0}
{
push(x, xs...);
}
template <typename T>
CK_TILE_HOST_DEVICE constexpr void push(const T& data)
{
if constexpr(!std::is_empty_v<T>)
{
constexpr index_t size = sizeof(T);
auto tmp = bit_cast<array<std::byte, size>>(data);
for(int i = 0; i < size; i++)
{
buffer_(size_) = tmp[i];
size_++;
}
}
}
template <typename X, typename... Xs>
CK_TILE_HOST_DEVICE constexpr void push(const X& x, const Xs&... xs)
{
push(x);
push(xs...);
}
template <typename T>
CK_TILE_HOST_DEVICE constexpr T pop(index_t& pos) const
{
T data;
if constexpr(!std::is_empty_v<T>)
{
constexpr index_t size = sizeof(T);
array<std::byte, size> tmp;
for(int i = 0; i < size; i++)
{
tmp(i) = buffer_[pos];
pos++;
}
data = bit_cast<T>(tmp);
}
return data;
}
template <typename T>
CK_TILE_HOST_DEVICE constexpr T get(index_t pos) const
{
constexpr index_t size = sizeof(T);
array<std::byte, size> tmp;
for(int i = 0; i < size; i++)
{
tmp(i) = buffer_[pos];
pos++;
}
auto data = bit_cast<T>(tmp);
return data;
}
//
array<std::byte, MaxSize> buffer_;
index_t size_ = 0;
};
} // namespace ck_tile

View File

@@ -0,0 +1,99 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/utility/functional.hpp"
namespace ck_tile {
// deprecated, always use array instead
template <index_t N>
using multi_index = array<index_t, N>;
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_multi_index(Xs&&... xs)
{
return make_array<index_t>(index_t{xs}...);
}
template <index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto make_zero_multi_index()
{
return unpack([](auto... xs) { return make_multi_index(xs...); },
typename uniform_sequence_gen<NSize, 0>::type{});
}
template <typename T>
CK_TILE_HOST_DEVICE constexpr auto to_multi_index(const T& x)
{
return unpack([](auto... ys) { return make_multi_index(ys...); }, x);
}
template <index_t NSize, typename X>
CK_TILE_HOST_DEVICE constexpr auto operator+=(multi_index<NSize>& y, const X& x)
{
static_assert(X::size() == NSize, "wrong! size not the same");
static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; });
return y;
}
template <index_t NSize, typename X>
CK_TILE_HOST_DEVICE constexpr auto operator-=(multi_index<NSize>& y, const X& x)
{
static_assert(X::size() == NSize, "wrong! size not the same");
static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; });
return y;
}
template <index_t NSize, typename T>
CK_TILE_HOST_DEVICE constexpr auto operator+(const multi_index<NSize>& a, const T& b)
{
using type = multi_index<NSize>;
static_assert(T::size() == NSize, "wrong! size not the same");
type r;
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] + b[i]; });
return r;
}
template <index_t NSize, typename T>
CK_TILE_HOST_DEVICE constexpr auto operator-(const multi_index<NSize>& a, const T& b)
{
using type = multi_index<NSize>;
static_assert(T::size() == NSize, "wrong! size not the same");
type r;
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] - b[i]; });
return r;
}
template <index_t NSize, typename T>
CK_TILE_HOST_DEVICE constexpr auto operator*(const multi_index<NSize>& a, const T& b)
{
using type = multi_index<NSize>;
static_assert(T::size() == NSize, "wrong! size not the same");
type r;
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] * b[i]; });
return r;
}
// multi_index = index_t * multi_index
template <index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto operator*(index_t a, const multi_index<NSize>& x)
{
multi_index<NSize> r;
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a * x[i]; });
return r;
}
// multi_index = multi_index * index_t
template <index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto operator*(const multi_index<NSize>& x, index_t a)
{
return a * x;
}
} // namespace ck_tile

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,78 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include <cstddef>
#include <array>
#include <type_traits>
namespace ck_tile {
// implement the c++20 std::span, lightweight, non-owning reference to a sequence
// weather it is dynamic or static range. Or can be seen as a view of a contiguous sequence
// TODO: do we need in device consider this is pointer?
template <typename T>
class span
{
public:
using element_type = T;
using value_type = std::remove_cv_t<element_type>;
using size_type = std::size_t;
using difference_type = std::ptrdiff_t;
using pointer = element_type*;
using const_pointer = const element_type*;
using reference = element_type&;
using const_reference = const element_type&;
using iterator = pointer;
using const_iterator = pointer;
CK_TILE_HOST_DEVICE constexpr span() : span(nullptr, size_type{0}) {}
CK_TILE_HOST_DEVICE constexpr span(pointer first, size_type count) : ptr_(first), size_(count)
{
}
CK_TILE_HOST_DEVICE constexpr span(pointer first, pointer last) : span(first, last - first) {}
template <std::size_t N>
CK_TILE_HOST_DEVICE constexpr span(element_type (&arr)[N]) noexcept : span(arr, N)
{
}
template <std::size_t N>
CK_TILE_HOST_DEVICE constexpr span(std::array<value_type, N>& arr) noexcept
: span(arr.data(), N)
{
}
template <typename Container>
CK_TILE_HOST_DEVICE constexpr span(const Container& container)
: span(container.data(), container.size())
{
}
CK_TILE_HOST_DEVICE constexpr iterator begin() const noexcept { return ptr_; }
CK_TILE_HOST_DEVICE constexpr const_iterator cbegin() const noexcept { return begin(); }
CK_TILE_HOST_DEVICE constexpr iterator end() const noexcept { return begin() + size(); }
CK_TILE_HOST_DEVICE constexpr const_iterator cend() const noexcept { return end(); }
CK_TILE_HOST_DEVICE constexpr reference front() const { return *begin(); }
CK_TILE_HOST_DEVICE constexpr reference back() const { return *(--end()); }
CK_TILE_HOST_DEVICE constexpr reference operator[](size_type idx) const
{
return *(begin() + idx);
}
CK_TILE_HOST_DEVICE constexpr pointer data() const noexcept { return ptr_; }
CK_TILE_HOST_DEVICE constexpr size_type size() const noexcept { return size_; }
private:
pointer ptr_;
size_type size_;
};
} // namespace ck_tile

View File

@@ -0,0 +1,70 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/numeric/integer.hpp"
namespace ck_tile {
#if CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT == CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE
namespace detail {
template <typename X, typename Y>
struct tuple_concat;
template <typename... Xs, typename... Ys>
struct tuple_concat<tuple<Xs...>, tuple<Ys...>>
{
using type = tuple<Xs..., Ys...>;
};
template <typename T, index_t N>
struct statically_indexed_array_impl
{
using type =
typename tuple_concat<typename statically_indexed_array_impl<T, N / 2>::type,
typename statically_indexed_array_impl<T, N - N / 2>::type>::type;
};
template <typename T>
struct statically_indexed_array_impl<T, 0>
{
using type = tuple<>;
};
template <typename T>
struct statically_indexed_array_impl<T, 1>
{
using type = tuple<T>;
};
} // namespace detail
template <typename T, index_t N>
using statically_indexed_array = typename detail::statically_indexed_array_impl<T, N>::type;
#else
// consider mark this struct as deprecated
template <typename T, index_t N>
using statically_indexed_array = array<T, N>;
#endif
// consider always use ck_tile::array for this purpose
template <typename X, typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs)
{
return statically_indexed_array<X, sizeof...(Xs) + 1>(x, static_cast<X>(xs)...);
}
// make empty statically_indexed_array
template <typename X>
CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array()
{
return statically_indexed_array<X, 0>();
}
} // namespace ck_tile

View File

@@ -0,0 +1,483 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include <utility>
namespace ck_tile {
namespace impl {
template <index_t idx, typename T, bool is_empty = std::is_empty_v<T>>
struct tuple_element
{
};
template <index_t idx, typename T>
struct tuple_element<idx, T, true>
{
CK_TILE_HOST_DEVICE constexpr tuple_element() {}
CK_TILE_HOST_DEVICE constexpr tuple_element(const T&) {}
};
template <index_t idx, typename T>
struct tuple_element<idx, T, false>
{
CK_TILE_HOST_DEVICE constexpr tuple_element() {}
CK_TILE_HOST_DEVICE constexpr tuple_element(const T& e) : element(e) {}
T element;
};
template <std::size_t I, class T>
CK_TILE_HOST_DEVICE constexpr T const& getv(tuple_element<I, T, false> const& x)
{
return x.element;
}
template <std::size_t I, class T>
CK_TILE_HOST_DEVICE constexpr T& getv(tuple_element<I, T, false>& x)
{
return x.element;
}
template <std::size_t I, class T>
CK_TILE_HOST_DEVICE constexpr T&& getv(tuple_element<I, T, false>&& x)
{
return static_cast<T&&>(x.element);
}
template <typename index_seq, typename... T>
struct tuple_base;
template <index_t... I, typename... T>
struct tuple_base<sequence<I...>, T...> : public tuple_element<I, T>...
{
CK_TILE_HOST_DEVICE constexpr tuple_base() {}
template <class... U>
CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U const&... u) : tuple_element<I, T>(u)...
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple_base(tuple_base<sequence<I...>, U...> const& u)
: tuple_element<I, T>(getv(static_cast<tuple_element<I, U> const&>(u)))...
{
}
};
} // namespace impl
template <class... T>
struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
{
CK_TILE_HOST_DEVICE
static constexpr auto size() { return sizeof...(T); }
using base = impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>;
CK_TILE_HOST_DEVICE constexpr tuple() {}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple(U const&... u)
: impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>(u...)
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple(tuple<U...> const& u)
: impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>(
static_cast<impl::tuple_base<U...> const&>(u))
{
}
CK_TILE_HOST_DEVICE static constexpr bool is_static()
{
bool flag = true;
static_for<0, sizeof...(Xs), 1>{}([&flag](auto i) {
flag &= is_static_v<remove_cvref_t<__type_pack_element<i.value, Xs...>>>;
});
return flag;
}
#define TP_COM_() static_assert(I < size(), "wrong! out of range")
// clang-format off
template<index_t I> CK_TILE_HOST_DEVICE constexpr const auto & get() const { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr const auto & get(number<I>) const { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & get() { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & get(number<I>) { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr const auto & at() const { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr const auto & at(number<I>) const { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & at() { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & at(number<I>) { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & operator[](number<I>) { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr const auto & operator[](number<I>) const { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & operator()(number<I>) { TP_COM_(); return get<I>(); } // TODO: compatible
// clang-format on
#undef TP_COM_
};
// template <class... T>
// CK_TILE_HOST_DEVICE constexpr
// tuple<T...>
// make_tuple(T const&... t)
// {
// return {t...};
// }
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs&&... xs)
{
return tuple<remove_cvref_t<Xs>...>(std::forward<Xs>(xs)...);
}
// https://en.cppreference.com/w/cpp/utility/tuple/tie
template <typename... Args>
constexpr tuple<Args&...> tie(Args&... args) noexcept
{
return {args...};
}
template <typename X, typename Y>
struct tuple_concat;
template <typename... Xs, typename... Ys>
struct tuple_concat<tuple<Xs...>, tuple<Ys...>>
{
using type = tuple<Xs..., Ys...>;
};
template <typename F, index_t N>
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F&& f, number<N>)
{
return unpack([&f](auto&&... is) { return make_tuple(f(is)...); },
typename arithmetic_sequence_gen<0, N, 1>::type{});
}
template <typename F, index_t N>
CK_TILE_HOST_DEVICE constexpr auto generate_tie(F&& f, number<N>)
{
return unpack([&f](auto&&... is) { return tie(f(is)...); },
typename arithmetic_sequence_gen<0, N, 1>::type{});
}
// tx and ty are tuple of references, return type of will tuple of referennce (not rvalue)
template <typename... X, typename... Y>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple_of_reference(const tuple<X&...>& tx,
const tuple<Y&...>& ty)
{
return unpack2(
[&](auto&&... zs) { return tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
tx,
ty);
}
template <typename... X, typename... Y>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx, const tuple<Y...>& ty)
{
return unpack2(
[&](auto... zs) { return tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
tx,
ty);
}
// Support any number of tuples to concat (also 1)
template <typename... X>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx)
{
return tx;
}
template <typename... X, typename... Tuples>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx, const Tuples&... tuples)
{
return concat_tuple(tx, concat_tuple(tuples...));
}
namespace detail {
template <typename F, typename X, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples_impl(F f, const X& x, sequence<Is...>)
{
return make_tuple(f(x.at(number<Is>{}))...);
}
template <typename F, typename X, typename Y, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto
transform_tuples_impl(F f, const X& x, const Y& y, sequence<Is...>)
{
return make_tuple(f(x.at(number<Is>{}), y.at(number<Is>{}))...);
}
template <typename F, typename X, typename Y, typename Z, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto
transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, sequence<Is...>)
{
return make_tuple(f(x.at(number<Is>{}), y.at(number<Is>{}), z.at(number<Is>{}))...);
}
} // namespace detail
template <typename F, typename X>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x)
{
return detail::transform_tuples_impl(
f, x, typename arithmetic_sequence_gen<0, X::size()(), 1>::type{});
}
template <typename F, typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y)
{
return detail::transform_tuples_impl(
f, x, y, typename arithmetic_sequence_gen<0, X::size()(), 1>::type{});
}
template <typename F, typename X, typename Y, typename Z>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z)
{
return detail::transform_tuples_impl(
f, x, y, z, typename arithmetic_sequence_gen<0, X::size()(), 1>::type{});
}
// By default unroll to the flatten
template <index_t Depth = 0, index_t MaxDepth = -1>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<>& element)
{
return element;
}
template <index_t Depth = 0, index_t MaxDepth = -1, typename T>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const T& element)
{
return make_tuple(element);
}
template <index_t Depth = 0, index_t MaxDepth = -1, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<Ts...>& tuple)
{
if constexpr(Depth == MaxDepth)
{
return tuple;
}
else
{
return unpack(
[&](auto&&... ts) {
return concat_tuple(unroll_nested_tuple<Depth + 1, MaxDepth>(ts)...);
},
tuple);
}
}
template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto tuple_reverse(const tuple<Ts...>& tuple)
{
return generate_tuple(
[&](auto i) {
using Idx = number<tuple<Ts...>::size()() - i - 1>;
return tuple.at(Idx{});
},
number<tuple<Ts...>::size()()>{});
}
// Reduce tuple values in specific range using Function
template <index_t Idx, index_t End, typename F, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto tuple_reduce(F&& f, const tuple<Ts...>& tuple)
{
static_assert(Idx < End, "Wrong parameters for tuple_reduce");
if constexpr(Idx + 1 == End)
{
return tuple.at(number<Idx>{});
}
else
{
return f(tuple.at(number<Idx>{}), tuple_reduce<Idx + 1, End>(f, tuple));
}
}
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto is_nested_tuple(const tuple<Ts...>&)
{
return (is_detected<is_tuple, Ts>::value || ...);
}
template <index_t depth = 0, typename T>
CK_TILE_HOST_DEVICE constexpr auto tuple_depth(const T&)
{
return depth;
}
template <index_t depth = 0, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto tuple_depth(const tuple<Ts...>&)
{
return math::max(tuple_depth<depth + 1>(Ts{})...);
}
template <typename... Seqs>
CK_TILE_HOST_DEVICE constexpr auto to_array_of_array(tuple<Seqs...> t_of_s)
{
constexpr index_t n0 = sizeof...(Seqs);
constexpr index_t max_n1 = [&] {
index_t max_n1_ = 0;
static_for<0, n0, 1>{}([&](auto i0) {
constexpr index_t n1 = t_of_s[i0].size()();
max_n1_ = max_n1_ < n1 ? n1 : max_n1_;
});
return max_n1_;
}();
array<array<index_t, max_n1>, n0> a_of_a{{-1}};
static_for<0, n0, 1>{}([&](auto i0) {
constexpr index_t n1 = t_of_s[i0].size()();
static_for<0, n1, 1>{}([&](auto i1) { a_of_a(i0)(i1) = t_of_s[i0][i1]; });
});
return a_of_a;
}
// Here should use MultiIndex<NSize>, instead of tuple<Ys...>, although the former
// is the alias of the latter. This is because compiler cannot infer the NSize if
// using MultiIndex<NSize>
// TODO: how to fix this?
template <typename... Ys,
typename X,
std::enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> =
false>
CK_TILE_HOST_DEVICE constexpr auto operator+=(tuple<Ys...>& y, const X& x)
{
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Ys);
static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; });
return y;
}
template <typename... Ys,
typename X,
std::enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> =
false>
CK_TILE_HOST_DEVICE constexpr auto operator-=(tuple<Ys...>& y, const X& x)
{
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Ys);
static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; });
return y;
}
template <typename... Xs,
typename Y,
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
false>
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
{
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Xs);
tuple<Xs...> r;
static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] + y[i]; });
return r;
}
template <typename... Xs,
typename Y,
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
false>
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
{
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Xs);
tuple<Xs...> r;
static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] - y[i]; });
return r;
}
template <typename... Xs,
typename Y,
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
false>
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const Y& y)
{
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Xs);
tuple<Xs...> r;
static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] * y[i]; });
return r;
}
// MultiIndex = scalar * MultiIndex
template <
typename... Xs,
typename Y,
std::enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
CK_TILE_HOST_DEVICE constexpr auto operator*(Y a, const tuple<Xs...>& x)
{
constexpr index_t NSize = sizeof...(Xs);
tuple<Xs...> r;
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a * x[i]; });
return r;
}
// MultiIndex = MultiIndex * scalar
template <
typename... Xs,
typename Y,
std::enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, Y a)
{
return a * x;
}
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator/(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
constexpr index_t NSize = sizeof...(Xs);
return generate_tuple([&](auto i) { return x[i] / y[i]; }, number<NSize>{});
}
} // namespace ck_tile
// WARNING: needed by compiler for C++ structured binding support only, don't use this
namespace std {
template <typename... Ts>
struct tuple_size<ck_tile::tuple<Ts...>> : std::integral_constant<std::size_t, sizeof...(Ts)>
{
};
template <std::size_t I, typename... Ts>
struct tuple_element<I, ck_tile::tuple<Ts...>> : ck_tile::tuple_element<I, ck_tile::tuple<Ts...>>
{
};
template <typename... Ts>
struct tuple_size<const ck_tile::tuple<Ts...>> : std::integral_constant<std::size_t, sizeof...(Ts)>
{
};
template <std::size_t I, typename... Ts>
struct tuple_element<I, const ck_tile::tuple<Ts...>>
: ck_tile::tuple_element<I, const ck_tile::tuple<Ts...>>
{
};
} // namespace std