introducing ck_tile! (#1216)

* enable gfx940

* switch between intrinsic mfma routines on mi100/200 and mi300

* fix mfma_int8 on MI300

* disable 2 int8 examples on MI300

* Update cmake-ck-dev.sh

* restore gitignore file

* modify Jenkinsfile to the internal repo

* Bump rocm-docs-core from 0.24.0 to 0.29.0 in /docs/sphinx

Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.24.0 to 0.29.0.
- [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases)
- [Changelog](https://github.com/RadeonOpenCompute/rocm-docs-core/blob/develop/CHANGELOG.md)
- [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v0.24.0...v0.29.0)

---
updated-dependencies:
- dependency-name: rocm-docs-core
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* initial enablement of gfx950

* fix clang format

* disable examples 31 and 41 int8 on gfx950

* add code

* fix build wip

* fix xx

* now can build

* naming

* minor fix

* wip fix

* fix macro for exp2; fix warpgemm a/b in transposedC

* unify as tuple_array

* Update the required Python version to 3.9

* Update executable name in test scripts

* re-structure tuple/array to avoid spill

* Merge function templates

* Fix format

* Add constraint to array<> ctor

* Re-use function

* Some minor changes

* remove wrong code in store_raw()

* fix compile issue in transpose

* Rename enum
Rename 'cood_transform_enum' to 'coord_transform_enum'

* let more integral_constant->constant, and formatting

* make sure thread_buffer can be tuple/array

* temp fix buffer_store spill

* do not use custom data type by default; now we can generate ISA-level identical code to opt_padding

* fix compile error, fp8 not ready now

* fix fp8 duplicated move/shift/and/or problem

* Default use CK_TILE_FLOAT_TO_FP8_STOCHASTIC rounding mode

* fix scratch in fp8 kernel

* update some readme

* fix merge from upstream

* sync with upstream

* sync upstream again

* sync 22

* remove unused

* fix clang-format

* update README of ck_tile example

* fix several issues

* set the minimal Python version to 3.8

* remove ck_tile example from default cmake target like all/install/check

* remove mistake

* 1) support recipe in generate.py 2) use simplified mask type 3) pass left/right into karg

* fix some bugs in group-mode masking and codegen; update README

* F8 quantization for FMHA forward (#1224)

* Add SAccElementFunction, PComputeElementFunction, OAccElementFunction in pipeline

* Add element function to fmha api

* Adjust P elementwise function

* Fix bug of elementwise op; our elementwise op is not inout

* Add some elementwise op, prepare to quantization

* Let generate.py generate different elementwise functions

* To prevent compiler issue, remove the elementwise function we have not used.

* Remove f8 pipeline, we should share the same pipeline even in f8

* Remove remove_cvref_t

* Avoid warning

* Fix wrong fp8 QK/KV block gemm setting

* Check fp8 rounding error in check_err()

* Set fp8 rounding error for check_err()

* Use CK_TILE_FLOAT_TO_FP8_STANDARD as default fp8 rounding mode

* 1. codegen the f8 api and kernel
2. f8 host code

* prevent warning in filter mode

* Remove not-in-use elementwise function kargs

* Remove more not-in-use elementwise function kargs

* Small refinements in C++ source files

* Use conditional_t<> to simplify code

* Support heterogeneous argument for binary function types

* Re-use already-existing scales<> functor template

* Fix wrong value produced by saturating

* Generalize the composes<> template

* Unify saturates<> implementation

* Fix type errors in composes<>

* Extend less_equal<>

* Reuse the existing template less_equal<> in check_err()

* Add equal<float> & equal<double>

* Rename check_err() parameter

* Rename check_err() parameter

* Add FIXME comment for adding new macro in future

* Remove unnecessary cast to void

* Eliminate duplicated code

* Avoid dividing api pool into more than 2 groups

* Use clearer variable names

* Use affirmative condition in if stmt

* Remove blank lines

* Do not use perfect forwarding in composes<>

* To fix compile error, revert generate.py back to 4439cc107d

* Fix bug of p element function

* Add compute element op to host softmax

* Remove element function in api interface

* Extract user parameter

* Rename pscale and oscale variable

* rename f8 to fp8

* rename more f8 to fp8

* Add pipeline::operator() without element_functor

* 1. Remove deprecated pipeline enum
2. Refine host code parameter

* Use quantization range as input

* 1. Rename max_dtype to dtype_max
2. Rename scale to scale_s
3. Add init description

* Refine description

* prevent early return

* unify _squant kernel name in cpp, update README

* Adjust the default range.

* Refine error message and bias range

* Add fp8 benchmark and smoke test

* fix fp8 swizzle_factor=4 case

---------

Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: carlushuang <carlus.huang@amd.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: illsilin <Illia.Silin@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: Jing Zhang <jizha@amd.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Po-Yen, Chen <PoYen.Chen@amd.com>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
Author: carlushuang
Date: 2024-04-16 08:27:12 +08:00 (committed by GitHub)
Parent: dd34ab6e64
Commit: db376dd8a4
141 changed files with 30623 additions and 2 deletions


@@ -0,0 +1,251 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <initializer_list>
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/functional.hpp"
namespace ck_tile {
// use aggregate initialization for this type
// e.g. array<index_t, 4> buf {0}; => {0, 0, 0, 0}, clean
// array<index_t, 4> buf {3, 2}; => {3, 2, 2, 2} (not {3,2,0,0})
// use make_array_with({...}) to construct an array with behavior compatible with old ck
// TODO: manually added constructor same as old ck
template <typename T_, index_t N_>
struct array
{
using value_type = T_;
static constexpr index_t N = N_;
// TODO: do we need this?
// using bulk_type = uint8_t __attribute__((ext_vector_type(N * sizeof(value_type))));
// union {
value_type data[N];
// bulk_type __content;
//};
CK_TILE_HOST_DEVICE constexpr array() : data{} {}
// NOTE: unlike std::array, missing trailing elements are filled with the last
// value in the list (not zero); extra elements beyond N are ignored
CK_TILE_HOST_DEVICE constexpr array(std::initializer_list<value_type> ilist)
{
// ilist.size() is not a constant expression here, so the bound is enforced by
// clamping at runtime rather than by a (vacuous) static_assert on an empty list
index_t i = 0;
value_type vlast = value_type{};
for(const value_type& val : ilist)
{
if(i >= N)
break;
data[i] = val;
vlast = val;
++i;
}
for(; i < N; ++i)
{
data[i] = vlast;
}
}
template <typename Y,
typename = std::enable_if_t<std::is_convertible_v<Y, value_type> ||
std::is_constructible_v<value_type, Y>>>
CK_TILE_HOST_DEVICE explicit constexpr array(Y c)
{
for(auto i = 0; i < size(); i++)
data[i] = static_cast<value_type>(c);
}
// template <typename Y>
// CK_TILE_HOST_DEVICE constexpr array(const array& o)
// {
// // static_assert(ArrayType::size() == size(), "wrong! size not the same");
// __content = o.__content;
// }
// CK_TILE_HOST_DEVICE constexpr array& operator=(const array& o)
// {
// // static_assert(ArrayType::size() == size(), "wrong! size not the same");
// __content = o.__content;
// return *this;
// }
CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<value_type>; }
// clang-format off
CK_TILE_HOST_DEVICE constexpr auto& get() { return data; }
CK_TILE_HOST_DEVICE constexpr const auto& get() const { return data; }
CK_TILE_HOST_DEVICE constexpr auto& get(index_t i) { return data[i]; }
CK_TILE_HOST_DEVICE constexpr const auto& get(index_t i) const { return data[i]; }
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& get() { return data[I]; }
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& get() const { return data[I]; }
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& get(number<I>) { return data[I]; }
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& get(number<I>) const { return data[I]; }
CK_TILE_HOST_DEVICE constexpr auto& at(index_t i) { return get(i); }
CK_TILE_HOST_DEVICE constexpr const auto& at(index_t i) const { return get(i); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at() { return get(I); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return get(I); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return get(I); }
CK_TILE_HOST_DEVICE constexpr const value_type& operator[](index_t i) const { return get(i); }
CK_TILE_HOST_DEVICE constexpr value_type& operator[](index_t i) { return get(i); }
CK_TILE_HOST_DEVICE constexpr value_type& operator()(index_t i) { return get(i); } // TODO: compatible
#if 0
template <typename ArrayLike>
CK_TILE_HOST_DEVICE constexpr auto operator=(const ArrayLike& arr)
{
static_assert(ArrayLike::size() == size(), "wrong! size not the same");
for(index_t i = 0; i < size(); ++i)
{
data[i] = arr[i];
}
return *this;
}
#endif
// type punning (strict aliasing) member functions for read/write
// aliasing this array of type "T", "N" elements
// as array of type "Tx", sizeof(T)*N/sizeof(Tx) elements
#define AR_AS_COM_() \
static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
template <typename Tx> CK_TILE_HOST_DEVICE constexpr auto& get_as()
{ AR_AS_COM_(); return reinterpret_cast<array<Tx, vx>&>(data); }
template <typename Tx> CK_TILE_HOST_DEVICE constexpr const auto& get_as() const
{ AR_AS_COM_(); return reinterpret_cast<const array<Tx, vx>&>(data); }
// the index below is the index *AFTER* the type conversion, not before
template <typename Tx> CK_TILE_HOST_DEVICE constexpr auto& get_as(index_t i)
{ AR_AS_COM_(); return reinterpret_cast<array<Tx, vx>&>(data).at(i); }
template <typename Tx> CK_TILE_HOST_DEVICE constexpr const auto& get_as(index_t i) const
{ AR_AS_COM_(); return reinterpret_cast<const array<Tx, vx>&>(data).at(i); }
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr auto& get_as(number<I>)
{ AR_AS_COM_(); return reinterpret_cast<array<Tx, vx>&>(data).at(number<I>{}); }
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr const auto& get_as(number<I>) const
{ AR_AS_COM_(); return reinterpret_cast<const array<Tx, vx>&>(data).at(number<I>{}); }
template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x)
{ AR_AS_COM_(); reinterpret_cast<array<Tx, vx>&>(data).at(i) = x; }
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x)
{ AR_AS_COM_(); reinterpret_cast<array<Tx, vx>&>(data).at(number<I>{}) = x; }
#undef AR_AS_COM_
// clang-format on
};
// empty Array
template <typename T>
struct array<T, 0>
{
using value_type = T;
CK_TILE_HOST_DEVICE constexpr array() {}
CK_TILE_HOST_DEVICE static constexpr index_t size() { return 0; }
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<T>; };
CK_TILE_HOST_DEVICE void print() const { printf("array{size: 0, data: []}"); }
};
template <typename>
struct vector_traits;
// specialization for array
template <typename T, index_t N>
struct vector_traits<array<T, N>>
{
using scalar_type = T;
static constexpr index_t vector_size = N;
};
namespace details {
template <class>
struct is_ref_wrapper : std::false_type
{
};
template <class T>
struct is_ref_wrapper<std::reference_wrapper<T>> : std::true_type
{
};
template <class T>
using not_ref_wrapper = std::negation<is_ref_wrapper<std::decay_t<T>>>;
template <class D, class...>
struct return_type_helper
{
using type = D;
};
template <class... Ts>
struct return_type_helper<void, Ts...> : std::common_type<Ts...>
{
static_assert(std::conjunction_v<not_ref_wrapper<Ts>...>,
"Ts cannot contain reference_wrappers when D is void");
};
template <class D, class... Ts>
using return_type = array<typename return_type_helper<D, Ts...>::type, sizeof...(Ts)>;
} // namespace details
template <typename D = void, typename... Ts>
CK_TILE_HOST_DEVICE constexpr details::return_type<D, Ts...> make_array(Ts&&... ts)
{
return {std::forward<Ts>(ts)...};
}
// // make empty array
// template <typename T>
// CK_TILE_HOST_DEVICE constexpr auto make_array()
// {
// return array<T, 0>{};
// }
// compatible with old ck's initializer: make an array and fill it with the
// last element from the initializer_list
template <typename T, index_t Size>
CK_TILE_HOST_DEVICE constexpr auto make_array_with(std::initializer_list<T> ilist)
{
return array<T, Size>(ilist);
}
template <typename T, index_t Size>
CK_TILE_HOST_DEVICE constexpr bool operator==(const array<T, Size>& a, const array<T, Size>& b)
{
bool same = true;
for(index_t i = 0; i < Size; ++i)
{
if(a[i] != b[i])
{
same = false;
break;
}
}
return same;
}
template <typename T, index_t Size>
CK_TILE_HOST_DEVICE constexpr bool operator!=(const array<T, Size>& a, const array<T, Size>& b)
{
return !(a == b);
}
template <typename T, index_t N, typename X>
CK_TILE_HOST_DEVICE constexpr auto to_array(const X& x)
{
static_assert(N <= X::size(), "");
array<T, N> arr;
static_for<0, N, 1>{}([&x, &arr](auto i) { arr(i) = x[i]; });
return arr;
}
} // namespace ck_tile
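The fill-with-last-value initializer semantics above differ from std::array's zero-fill, so they are worth seeing in isolation. Below is a minimal standalone sketch of just that constructor behavior; `mini_array` is a hypothetical stand-in (the real `ck_tile::array` needs the full ck_tile headers and the CK_TILE_HOST_DEVICE macros to compile):

```cpp
#include <cassert>
#include <cstddef>
#include <initializer_list>

// Hypothetical stand-in for ck_tile::array, showing only the initializer-list
// constructor: missing trailing elements repeat the last provided value,
// and extra elements beyond N are ignored.
template <typename T, std::size_t N>
struct mini_array
{
    T data[N];

    constexpr mini_array(std::initializer_list<T> ilist) : data{}
    {
        std::size_t i = 0;
        T last{};
        for(const T& v : ilist)
        {
            if(i >= N)
                break; // ignore extra elements
            data[i] = v;
            last    = v;
            ++i;
        }
        for(; i < N; ++i)
            data[i] = last; // repeat last value, unlike std::array's zero-fill
    }
};
```

With these semantics, `mini_array<int, 4>{3, 2}` holds `{3, 2, 2, 2}` rather than `{3, 2, 0, 0}`, matching the comment at the top of the real array.hpp.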


@@ -0,0 +1,499 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/map.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/utility/functional.hpp"
namespace ck_tile {
template <typename TData, index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto container_push_back(const array<TData, NSize>& a, const TData& x)
{
array<TData, NSize + 1> r;
static_for<0, NSize, 1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; });
r[number<NSize>{}] = x;
return r;
}
template <typename... Ts, typename T>
CK_TILE_HOST_DEVICE constexpr auto container_push_front(const tuple<Ts...>& a, const T& x)
{
return container_concat(make_tuple(x), a);
}
template <typename... Ts, typename T>
CK_TILE_HOST_DEVICE constexpr auto container_push_back(const tuple<Ts...>& a, const T& x)
{
return container_concat(a, make_tuple(x));
}
// reorder array
template <typename TData, index_t NSize, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_new2old(const array<TData, NSize>& old_array, sequence<IRs...> /*new2old*/)
{
static_assert(NSize == sizeof...(IRs), "wrong! size not consistent");
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
return make_array<remove_cvref_t<TData>>(old_array[IRs]...);
}
template <typename TData, index_t NSize, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_old2new(const array<TData, NSize>& old_array, sequence<IRs...> old2new)
{
return container_reorder_given_new2old(
old_array, typename sequence_map_inverse<decltype(old2new)>::type{});
}
// reorder array
template <typename TData, index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_new2old(const array<TData, NSize>& old_array,
const map<index_t, index_t>& new2old)
{
array<TData, NSize> new_array;
for(const auto& [new_pos, old_pos] : new2old)
{
new_array(new_pos) = old_array[old_pos];
}
return new_array;
}
template <typename TData, index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_old2new(const array<TData, NSize>& old_array,
const map<index_t, index_t>& old2new)
{
array<TData, NSize> new_array;
for(const auto& [old_pos, new_pos] : old2new)
{
new_array(new_pos) = old_array[old_pos];
}
return new_array;
}
// reorder tuple
template <typename... Ts, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(const tuple<Ts...>& old_tuple,
sequence<IRs...> /*new2old*/)
{
static_assert(sizeof...(Ts) == sizeof...(IRs), "wrong! size not consistent");
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
return make_tuple(old_tuple[number<IRs>{}]...);
}
template <typename... Ts, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_old2new(const tuple<Ts...>& old_tuple,
sequence<IRs...> old2new)
{
return container_reorder_given_new2old(
old_tuple, typename sequence_map_inverse<decltype(old2new)>::type{});
}
// reorder sequence
template <index_t... Is, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(sequence<Is...> /* old_seq */,
sequence<IRs...> /*new2old*/)
{
static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
return sequence<sequence<Is...>::at(number<IRs>{})...>{};
}
template <index_t... Is, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_old2new(sequence<Is...> old_seq,
sequence<IRs...> /* old2new */)
{
static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
constexpr auto new2old = typename sequence_map_inverse<sequence<IRs...>>::type{};
return container_reorder_given_new2old(old_seq, new2old);
}
#if 0
// rocm-4.1 compiler would crash for recursive lambda
template <typename Container,
typename Reduce,
typename Init,
index_t IBegin = 0,
index_t IEnd = Container::size(),
index_t IStep = 1>
CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container& x,
Reduce reduce,
Init init,
number<IBegin> = number<0>{},
number<IEnd> = number<Container::size()>{},
number<IStep> = number<1>{})
{
static_assert((IEnd - IBegin) % IStep == 0, "wrong!");
// f is recursive function, fs is a dummy of f
// i is index, y_old is current scan, r_old is current reduction
auto f = [&](auto fs, auto i, auto r_old) {
auto r_new = reduce(x[i], r_old);
if constexpr(i.value < IEnd - IStep)
{
// recursively call f/fs
return fs(fs, i + number<IStep>{}, r_new);
}
else
{
return r_new;
}
};
// start recursion
return f(f, number<IBegin>{}, init);
}
#else
// i is index, r_old is the running reduction
template <typename Container,
typename Reduce,
typename ROld,
index_t I,
index_t IEnd,
index_t IStep>
CK_TILE_HOST_DEVICE constexpr auto container_reduce_impl(
const Container& x, Reduce reduce, ROld r_old, number<I> i, number<IEnd>, number<IStep>)
{
auto r_new = reduce(x[i], r_old);
if constexpr(i.value < IEnd - IStep)
{
return container_reduce_impl(
x, reduce, r_new, i + number<IStep>{}, number<IEnd>{}, number<IStep>{});
}
else
{
return r_new;
}
}
// rocm-4.1 compiler would crash for recursive lambda
// container reduce with initial value
template <typename Container,
typename Reduce,
typename Init,
index_t IBegin = 0,
index_t IEnd = Container::size(),
index_t IStep = 1>
CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container& x,
Reduce reduce,
Init init,
number<IBegin> = number<0>{},
number<IEnd> = number<Container::size()>{},
number<IStep> = number<1>{})
{
static_assert((IEnd - IBegin) % IStep == 0, "wrong!");
if constexpr(IEnd > IBegin)
{
return container_reduce_impl(
x, reduce, init, number<IBegin>{}, number<IEnd>{}, number<IStep>{});
}
else
{
return init;
}
}
#endif
template <typename TData, index_t NSize, typename Reduce>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_inclusive_scan(const array<TData, NSize>& x, Reduce f, TData init)
{
array<TData, NSize> y;
TData r = init;
static_for<NSize - 1, 0, -1>{}([&](auto i) {
r = f(r, x[i]);
y(i) = r;
});
r = f(r, x[number<0>{}]);
y(number<0>{}) = r;
return y;
}
template <typename TData, index_t NSize, typename Reduce, typename Init>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_exclusive_scan(const array<TData, NSize>& x, Reduce f, Init init)
{
#if 0
array<TData, NSize> y;
TData r = init;
static_for<NSize - 1, 0, -1>{}([&](auto i) {
y(i) = r;
r = f(r, x[i]);
});
y(number<0>{}) = r;
return y;
#else
array<TData, NSize> y;
TData r = init;
for(index_t i = NSize - 1; i > 0; --i)
{
y(i) = r;
r = f(r, x[i]);
}
y(0) = r;
return y;
#endif
}
template <index_t... Is, typename Reduce, index_t Init>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_exclusive_scan(const sequence<Is...>& seq, Reduce f, number<Init>)
{
return reverse_exclusive_scan_sequence(seq, f, number<Init>{});
}
#if 0
// rocm4.1 compiler would crash with recursive lambda
template <typename... Xs, typename Reduce, typename Init>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_exclusive_scan(const tuple<Xs...>& x, Reduce reduce, Init init)
{
constexpr index_t NSize = sizeof...(Xs);
// f is recursive function, fs is a dummy of f
// i is index, y_old is current scan, r_old is current reduction
auto f = [&](auto fs, auto i, auto y_old, auto r_old) {
auto r_new = reduce(x[i], r_old);
auto y_new = container_push_front(y_old, r_new);
if constexpr(i.value > 1)
{
// recursively call f/fs
return fs(fs, i - number<1>{}, y_new, r_new);
}
else
{
return y_new;
}
};
// start recursion
return f(f, number<NSize - 1>{}, make_tuple(init), init);
}
#else
// i is index, y_old is current scan, r_old is current reduction
template <typename... Xs, typename Reduce, index_t I, typename YOld, typename ROld>
CK_TILE_HOST_DEVICE constexpr auto container_reverse_exclusive_scan_impl(
const tuple<Xs...>& x, Reduce reduce, number<I> i, YOld y_old, ROld r_old)
{
auto r_new = reduce(x[i], r_old);
auto y_new = container_push_front(y_old, r_new);
if constexpr(i.value > 1)
{
// recursively call f/fs
return container_reverse_exclusive_scan_impl(x, reduce, i - number<1>{}, y_new, r_new);
}
else
{
return y_new;
}
}
template <typename... Xs, typename Reduce, typename Init>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_exclusive_scan(const tuple<Xs...>& x, Reduce reduce, Init init)
{
constexpr index_t NSize = sizeof...(Xs);
return container_reverse_exclusive_scan_impl(
x, reduce, number<NSize - 1>{}, make_tuple(init), init);
}
#endif
// TODO: update to be like container_reverse_exclusive_scan to deal with tuple of number<>
template <typename... Xs, typename Reduce, typename TData>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_inclusive_scan(const tuple<Xs...>& x, Reduce f, TData init)
{
constexpr index_t NSize = sizeof...(Xs);
tuple<Xs...> y;
TData r = init;
static_for<NSize - 1, 0, -1>{}([&](auto i) {
r = f(r, x[i]);
y(i) = r;
});
r = f(r, x[number<0>{}]);
y(number<0>{}) = r;
return y;
}
template <typename X, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto container_concat(const X& x, const Ys&... ys)
{
return container_concat(x, container_concat(ys...));
}
template <typename T, index_t NX, index_t NY>
CK_TILE_HOST_DEVICE constexpr auto container_concat(const array<T, NX>& ax, const array<T, NY>& ay)
{
return unpack2(
[&](auto&&... zs) { return make_array<T>(std::forward<decltype(zs)>(zs)...); }, ax, ay);
}
template <typename... X, typename... Y>
CK_TILE_HOST_DEVICE constexpr auto container_concat(const tuple<X...>& tx, const tuple<Y...>& ty)
{
return unpack2(
[&](auto&&... zs) { return make_tuple(std::forward<decltype(zs)>(zs)...); }, tx, ty);
}
template <typename Container>
CK_TILE_HOST_DEVICE constexpr auto container_concat(const Container& x)
{
return x;
}
template <typename T, index_t N, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto get_container_subset(const array<T, N>& arr, sequence<Is...>)
{
static_assert(N >= sizeof...(Is), "wrong! size");
if constexpr(sizeof...(Is) > 0)
{
return make_array<T>(arr[Is]...);
}
else
{
return array<T, 0>{};
}
}
template <typename... Ts, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto get_container_subset(const tuple<Ts...>& tup, sequence<Is...>)
{
static_assert(sizeof...(Ts) >= sizeof...(Is), "wrong! size");
if constexpr(sizeof...(Is) > 0)
{
return make_tuple(tup[number<Is>{}]...);
}
else
{
return tuple<>{};
}
}
template <typename T, index_t N, index_t... Is>
CK_TILE_HOST_DEVICE constexpr void
set_container_subset(array<T, N>& y, sequence<Is...> picks, const array<T, sizeof...(Is)>& x)
{
static_assert(N >= sizeof...(Is), "wrong! size");
if constexpr(sizeof...(Is) > 0)
{
for(index_t i = 0; i < picks.size(); ++i)
{
y(picks[i]) = x[i];
}
}
}
template <typename Y, typename X, index_t... Is>
CK_TILE_HOST_DEVICE constexpr void set_container_subset(Y& y, sequence<Is...> picks, const X& x)
{
static_assert(Y::size() >= sizeof...(Is) && X::size() == sizeof...(Is), "wrong! size");
if constexpr(sizeof...(Is) > 0)
{
static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
}
}
// return the index of the first occurrence in the sequence;
// return seq.size() if not found
template <index_t... Is>
constexpr index_t container_find(sequence<Is...> seq, index_t value)
{
for(auto i = 0; i < seq.size(); i++)
{
if(seq[i] == value)
return i;
}
return seq.size();
}
template <index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence<Is...>)
{
using Seq = sequence<Is...>;
return generate_tuple(
[&](auto i) {
constexpr index_t tmp = Seq::at(i);
return number<tmp>{};
},
number<Seq::size()>{});
}
#if 0
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes) \
[a_of_b_impl, a_size, bs_sizes] { \
return ck_tile::generate_tuple( \
[=](auto i) { \
constexpr auto b_impl = a_of_b_impl[i]; \
constexpr index_t b_size = bs_sizes[i]; \
constexpr auto b = TO_SEQUENCE(b_impl, b_size); \
return b; \
}, \
ck_tile::number<a_size>{}); \
}()
#else
// a constexpr index_t need not be captured; capturing it triggers "-Wunused-lambda-capture"
// TODO: this is ugly
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes) \
[a_of_b_impl, bs_sizes] { \
return ck_tile::generate_tuple( \
[=](auto i) { \
constexpr auto b_impl = a_of_b_impl[i]; \
constexpr index_t b_size = bs_sizes[i]; \
constexpr auto b = TO_SEQUENCE(b_impl, b_size); \
return b; \
}, \
ck_tile::number<a_size>{}); \
}()
#endif
} // namespace ck_tile
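A common use of the reverse exclusive scan above in tensor code is turning dimension lengths into packed (row-major) strides. The following is a hypothetical host-only rewrite over std::array, mirroring the non-recursive loop body of `container_reverse_exclusive_scan` for arrays:

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical standalone version of the array overload of
// container_reverse_exclusive_scan: y[i] collects the reduction of all
// elements *after* position i, and y[N-1] gets the initial value.
template <std::size_t N, typename F>
std::array<int, N> reverse_exclusive_scan(const std::array<int, N>& x, F f, int init)
{
    std::array<int, N> y{};
    int r = init;
    for(std::size_t i = N; i-- > 1;) // i = N-1 down to 1, as in the original
    {
        y[i] = r;
        r    = f(r, x[i]);
    }
    y[0] = r;
    return y;
}
```

Scanning lengths {2, 3, 4} with multiplication and init = 1 yields the packed strides {12, 4, 1} of a 2x3x4 tensor.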


@@ -0,0 +1,164 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
namespace ck_tile {
// naive map
template <typename key, typename data, index_t max_size = 128>
struct map
{
using pair_type = tuple<key, data>;
using impl_type = array<pair_type, max_size>;
impl_type impl_;
index_t size_;
struct iterator
{
impl_type& impl_;
index_t pos_;
CK_TILE_HOST_DEVICE constexpr iterator(impl_type& impl, index_t pos)
: impl_{impl}, pos_{pos}
{
}
CK_TILE_HOST_DEVICE constexpr iterator& operator++()
{
pos_++;
return *this;
}
CK_TILE_HOST_DEVICE constexpr bool operator!=(const iterator& other) const
{
return other.pos_ != pos_;
}
CK_TILE_HOST_DEVICE constexpr pair_type& operator*() { return impl_.at(pos_); }
};
struct const_iterator
{
const impl_type& impl_;
index_t pos_;
CK_TILE_HOST_DEVICE constexpr const_iterator(const impl_type& impl, index_t pos)
: impl_{impl}, pos_{pos}
{
}
CK_TILE_HOST_DEVICE constexpr const_iterator& operator++()
{
pos_++;
return *this;
}
CK_TILE_HOST_DEVICE constexpr bool operator!=(const const_iterator& other) const
{
return other.pos_ != pos_;
}
CK_TILE_HOST_DEVICE constexpr const pair_type& operator*() const { return impl_.at(pos_); }
};
CK_TILE_HOST_DEVICE constexpr map() : impl_{}, size_{0} {}
CK_TILE_HOST_DEVICE constexpr index_t size() const { return size_; }
CK_TILE_HOST_DEVICE void clear() { size_ = 0; }
CK_TILE_HOST_DEVICE constexpr index_t find_position(const key& k) const
{
for(index_t i = 0; i < size(); i++)
{
if(impl_[i].template at<0>() == k)
{
return i;
}
}
return size_;
}
CK_TILE_HOST_DEVICE constexpr const_iterator find(const key& k) const
{
return const_iterator{impl_, find_position(k)};
}
CK_TILE_HOST_DEVICE constexpr iterator find(const key& k)
{
return iterator{impl_, find_position(k)};
}
CK_TILE_HOST_DEVICE constexpr const data& operator[](const key& k) const
{
const auto it = find(k);
// FIXME
// assert(it.pos_ < size());
return impl_[it.pos_].template at<1>();
}
CK_TILE_HOST_DEVICE constexpr data& operator()(const key& k)
{
auto it = find(k);
// if entry not found
if(it.pos_ == size())
{
impl_(it.pos_).template at<0>() = k;
size_++;
}
// FIXME
// assert(size_ <= max_size);
return impl_(it.pos_).template at<1>();
}
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
CK_TILE_HOST_DEVICE constexpr const_iterator begin() const { return const_iterator{impl_, 0}; }
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
CK_TILE_HOST_DEVICE constexpr const_iterator end() const
{
return const_iterator{impl_, size_};
}
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
CK_TILE_HOST_DEVICE constexpr iterator begin() { return iterator{impl_, 0}; }
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
CK_TILE_HOST_DEVICE constexpr iterator end() { return iterator{impl_, size_}; }
CK_TILE_HOST_DEVICE void print() const
{
printf("map{size_: %d, ", size_);
//
printf("impl_: [");
//
for(const auto& [k, d] : *this)
{
printf("{key: ");
print(k);
printf(", data: ");
print(d);
printf("}, ");
}
//
printf("]");
//
printf("}");
}
};
} // namespace ck_tile
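The map above is deliberately naive: fixed compile-time capacity, linear-scan lookup, no ordering or hashing, and an `operator()` that inserts a default-constructed entry on a missing key (similar to std::map::operator[]). A hypothetical standalone sketch of that core behavior, with `naive_map` standing in for `ck_tile::map`:

```cpp
#include <cassert>
#include <cstddef>
#include <utility>

// Hypothetical stand-in for ck_tile::map: entries live in a fixed-size
// array of pairs, lookup is a linear scan, and find_position() returns
// size_ (the "end" position) when the key is absent.
template <typename K, typename V, std::size_t MaxSize = 128>
struct naive_map
{
    std::pair<K, V> impl_[MaxSize] = {};
    std::size_t size_              = 0;

    std::size_t find_position(const K& k) const
    {
        for(std::size_t i = 0; i < size_; ++i)
            if(impl_[i].first == k)
                return i;
        return size_; // not found
    }

    V& operator()(const K& k)
    {
        std::size_t pos = find_position(k);
        if(pos == size_) // key absent: append a new default entry
        {
            impl_[pos].first = k;
            ++size_;
        }
        return impl_[pos].second;
    }
};
```

The linear scan keeps the structure trivially usable in device code, where the key count is expected to stay small; overflowing MaxSize is unchecked here, matching the FIXME in the original.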


@@ -0,0 +1,99 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include <cstddef>
namespace ck_tile {
// TODO: this structure is not intended to be used by users directly
template <index_t MaxSize>
struct meta_data_buffer
{
CK_TILE_HOST_DEVICE constexpr meta_data_buffer() : buffer_{}, size_{0} {}
template <typename X, typename... Xs>
CK_TILE_HOST_DEVICE constexpr meta_data_buffer(const X& x, const Xs&... xs)
: buffer_{}, size_{0}
{
push(x, xs...);
}
template <typename T>
CK_TILE_HOST_DEVICE constexpr void push(const T& data)
{
if constexpr(!std::is_empty_v<T>)
{
constexpr index_t size = sizeof(T);
auto tmp = bit_cast<array<std::byte, size>>(data);
for(int i = 0; i < size; i++)
{
buffer_(size_) = tmp[i];
size_++;
}
}
}
template <typename X, typename... Xs>
CK_TILE_HOST_DEVICE constexpr void push(const X& x, const Xs&... xs)
{
push(x);
push(xs...);
}
template <typename T>
CK_TILE_HOST_DEVICE constexpr T pop(index_t& pos) const
{
T data;
if constexpr(!std::is_empty_v<T>)
{
constexpr index_t size = sizeof(T);
array<std::byte, size> tmp;
for(int i = 0; i < size; i++)
{
tmp(i) = buffer_[pos];
pos++;
}
data = bit_cast<T>(tmp);
}
return data;
}
template <typename T>
CK_TILE_HOST_DEVICE constexpr T get(index_t pos) const
{
constexpr index_t size = sizeof(T);
array<std::byte, size> tmp;
for(int i = 0; i < size; i++)
{
tmp(i) = buffer_[pos];
pos++;
}
auto data = bit_cast<T>(tmp);
return data;
}
//
array<std::byte, MaxSize> buffer_;
index_t size_ = 0;
};
} // namespace ck_tile
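meta_data_buffer packs trivially-copyable values into a raw byte array on push() and reads them back with a cursor on pop(). A hypothetical standalone sketch of the pattern, with std::memcpy standing in for ck_tile's bit_cast through a byte array (`byte_buffer` is not the real type):

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>

// Hypothetical sketch of the meta_data_buffer pattern: values are
// serialized byte-by-byte on push, and pop() advances a caller-owned
// read cursor by sizeof(T) after deserializing.
template <std::size_t MaxSize>
struct byte_buffer
{
    unsigned char buf_[MaxSize] = {};
    std::size_t size_           = 0;

    template <typename T>
    void push(const T& v)
    {
        std::memcpy(buf_ + size_, &v, sizeof(T)); // serialize into the tail
        size_ += sizeof(T);
    }

    template <typename T>
    T pop(std::size_t& pos) const
    {
        T v{};
        std::memcpy(&v, buf_ + pos, sizeof(T)); // deserialize at the cursor
        pos += sizeof(T);                       // advance the cursor
        return v;
    }
};
```

Values must be popped in push order with matching types, since the buffer stores no type information, only bytes.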


@@ -0,0 +1,100 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/utility/functional.hpp"
namespace ck_tile {
// Don't use this directly. It exists for old CK's internal usage;
// in the future always use array instead
template <index_t N>
using multi_index = array<index_t, N>;
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_multi_index(Xs&&... xs)
{
return make_array<index_t>(index_t{xs}...);
}
template <index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto make_zero_multi_index()
{
return unpack([](auto... xs) { return make_multi_index(xs...); },
typename uniform_sequence_gen<NSize, 0>::type{});
}
template <typename T>
CK_TILE_HOST_DEVICE constexpr auto to_multi_index(const T& x)
{
return unpack([](auto... ys) { return make_multi_index(ys...); }, x);
}
template <index_t NSize, typename X>
CK_TILE_HOST_DEVICE constexpr auto operator+=(multi_index<NSize>& y, const X& x)
{
static_assert(X::size() == NSize, "wrong! size not the same");
static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; });
return y;
}
template <index_t NSize, typename X>
CK_TILE_HOST_DEVICE constexpr auto operator-=(multi_index<NSize>& y, const X& x)
{
static_assert(X::size() == NSize, "wrong! size not the same");
static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; });
return y;
}
template <index_t NSize, typename T>
CK_TILE_HOST_DEVICE constexpr auto operator+(const multi_index<NSize>& a, const T& b)
{
using type = multi_index<NSize>;
static_assert(T::size() == NSize, "wrong! size not the same");
type r;
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] + b[i]; });
return r;
}
template <index_t NSize, typename T>
CK_TILE_HOST_DEVICE constexpr auto operator-(const multi_index<NSize>& a, const T& b)
{
using type = multi_index<NSize>;
static_assert(T::size() == NSize, "wrong! size not the same");
type r;
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] - b[i]; });
return r;
}
template <index_t NSize, typename T>
CK_TILE_HOST_DEVICE constexpr auto operator*(const multi_index<NSize>& a, const T& b)
{
using type = multi_index<NSize>;
static_assert(T::size() == NSize, "wrong! size not the same");
type r;
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] * b[i]; });
return r;
}
// multi_index = index_t * multi_index
template <index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto operator*(index_t a, const multi_index<NSize>& x)
{
multi_index<NSize> r;
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a * x[i]; });
return r;
}
// multi_index = multi_index * index_t
template <index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto operator*(const multi_index<NSize>& x, index_t a)
{
return a * x;
}
} // namespace ck_tile

File diff suppressed because it is too large.


@@ -0,0 +1,78 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include <cstddef>
#include <array>
#include <type_traits>
namespace ck_tile {
// implements a C++20-style std::span: a lightweight, non-owning reference to a sequence,
// whether it has a dynamic or static range; i.e. a view of a contiguous sequence
// TODO: do we need this on device, considering it stores a pointer?
template <typename T>
class span
{
public:
using element_type = T;
using value_type = std::remove_cv_t<element_type>;
using size_type = std::size_t;
using difference_type = std::ptrdiff_t;
using pointer = element_type*;
using const_pointer = const element_type*;
using reference = element_type&;
using const_reference = const element_type&;
using iterator = pointer;
using const_iterator = pointer;
CK_TILE_HOST_DEVICE constexpr span() : span(nullptr, size_type{0}) {}
CK_TILE_HOST_DEVICE constexpr span(pointer first, size_type count) : ptr_(first), size_(count)
{
}
CK_TILE_HOST_DEVICE constexpr span(pointer first, pointer last) : span(first, last - first) {}
template <std::size_t N>
CK_TILE_HOST_DEVICE constexpr span(element_type (&arr)[N]) noexcept : span(arr, N)
{
}
template <std::size_t N>
CK_TILE_HOST_DEVICE constexpr span(std::array<value_type, N>& arr) noexcept
: span(arr.data(), N)
{
}
template <typename Container>
CK_TILE_HOST_DEVICE constexpr span(const Container& container)
: span(container.data(), container.size())
{
}
CK_TILE_HOST_DEVICE constexpr iterator begin() const noexcept { return ptr_; }
CK_TILE_HOST_DEVICE constexpr const_iterator cbegin() const noexcept { return begin(); }
CK_TILE_HOST_DEVICE constexpr iterator end() const noexcept { return begin() + size(); }
CK_TILE_HOST_DEVICE constexpr const_iterator cend() const noexcept { return end(); }
CK_TILE_HOST_DEVICE constexpr reference front() const { return *begin(); }
CK_TILE_HOST_DEVICE constexpr reference back() const { return *(end() - 1); }
CK_TILE_HOST_DEVICE constexpr reference operator[](size_type idx) const
{
return *(begin() + idx);
}
CK_TILE_HOST_DEVICE constexpr pointer data() const noexcept { return ptr_; }
CK_TILE_HOST_DEVICE constexpr size_type size() const noexcept { return size_; }
private:
pointer ptr_;
size_type size_;
};
} // namespace ck_tile


@@ -0,0 +1,41 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/integer.hpp"
namespace ck_tile {
#if CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT == CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE
template <typename T, index_t N>
using statically_indexed_array = tuple_array<T, N>;
#else
// consider marking this alias as deprecated
template <typename T, index_t N>
using statically_indexed_array = array<T, N>;
#endif
// consider always using ck_tile::array for this purpose
#if 0
template <typename X, typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs)
{
return statically_indexed_array<X, sizeof...(Xs) + 1>(x, static_cast<X>(xs)...);
}
// make empty statically_indexed_array
template <typename X>
CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array()
{
return statically_indexed_array<X, 0>();
}
#endif
} // namespace ck_tile


@@ -0,0 +1,165 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/tuple.hpp"
namespace ck_tile {
#if CK_TILE_THREAD_BUFFER_DEFAULT == CK_TILE_THREAD_BUFFER_USE_TUPLE
template <typename T, index_t N>
using thread_buffer = tuple_array<T, N>;
template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
{
return make_tuple(ts...);
}
#else
#if 0
template <typename T, index_t N>
using thread_buffer = array<T, N>;
template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
{
return make_array(ts...);
}
#endif
// clang-format off
template<typename T_, index_t N_>
struct thread_buffer {
using value_type = remove_cvref_t<T_>;
static constexpr index_t N = N_;
value_type data[N];
// TODO: this ctor can't be omitted
CK_TILE_HOST_DEVICE constexpr thread_buffer() : data{} {}
CK_TILE_HOST_DEVICE constexpr thread_buffer(const value_type & o) : data{o} {}
CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
CK_TILE_HOST_DEVICE auto & get() {return data; }
CK_TILE_HOST_DEVICE const auto & get() const {return data; }
CK_TILE_HOST_DEVICE auto & get(index_t i) {return data[i]; }
CK_TILE_HOST_DEVICE const auto & get(index_t i) const {return data[i]; }
CK_TILE_HOST_DEVICE constexpr const auto& operator[](index_t i) const { return get(i); }
CK_TILE_HOST_DEVICE constexpr auto& operator[](index_t i) { return get(i); }
CK_TILE_HOST_DEVICE constexpr auto& operator()(index_t i) { return get(i); } // TODO: compatible
CK_TILE_HOST_DEVICE constexpr auto& at(index_t i) { return get(i); }
CK_TILE_HOST_DEVICE constexpr const auto& at(index_t i) const { return get(i); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at() { return get(I); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return get(I); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return get(I); }
template <typename X_,
typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto _get_as() const
{
using X = remove_cvref_t<X_>;
constexpr index_t kSPerX = vector_traits<X>::vector_size;
static_assert(N % kSPerX == 0);
union {
thread_buffer<X_, N / kSPerX> data {};
// tuple_array<value_type, kSPerX> sub_data;
value_type sub_data[N];
} vx;
static_for<0, N, 1>{}(
[&](auto j) { vx.sub_data[j] = data[j]; });
return vx.data;
}
template <typename X_,
index_t Is,
typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
CK_TILE_HOST_DEVICE const constexpr remove_reference_t<X_> _get_as(number<Is> is) const
{
using X = remove_cvref_t<X_>;
constexpr index_t kSPerX = vector_traits<X>::vector_size;
union {
X_ data {};
tuple_array<value_type, kSPerX> sub_data;
} vx;
static_for<0, kSPerX, 1>{}(
[&](auto j) { vx.sub_data(j) = operator[]((is * number<sizeof(X_)/sizeof(value_type)>{}) + j); });
return vx.data;
}
#if 0
template <typename X_,
index_t Is,
typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void _set_as(number<Is> is, X_ x)
{
using X = remove_cvref_t<X_>;
constexpr index_t kSPerX = vector_traits<X>::vector_size;
union {
X_ data;
tuple_array<value_type, kSPerX> sub_data;
} vx {x};
static_for<0, kSPerX, 1>{}(
[&](auto j) { operator()((is * number<sizeof(X_)/sizeof(value_type)>{}) + j) = vx.sub_data[j]; });
}
#endif
#define TB_COMMON_AS() \
static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
template<typename Tx>
CK_TILE_HOST_DEVICE auto & get_as() {TB_COMMON_AS();
return reinterpret_cast<thread_buffer<Tx, vx>&>(data);}
template<typename Tx>
CK_TILE_HOST_DEVICE constexpr auto get_as() const {TB_COMMON_AS();
if constexpr(sizeof(value_type) <= 1 )
return _get_as<Tx>(); // TODO: the current compiler needs a union to read back 8-bit data; should be fixed in the future
else
return reinterpret_cast<const thread_buffer<Tx, vx>&>(data);}
template<typename Tx, index_t I>
CK_TILE_HOST_DEVICE auto & get_as(number<I>) {TB_COMMON_AS();
return reinterpret_cast<thread_buffer<Tx, vx>&>(data).get(number<I>{});}
template<typename Tx, index_t I>
CK_TILE_HOST_DEVICE constexpr auto get_as(number<I>) const {TB_COMMON_AS();
if constexpr(sizeof(value_type) <= 1 )
return _get_as<Tx>(number<I>{}); // TODO: the current compiler needs a union to read back 8-bit data; should be fixed in the future
else
return reinterpret_cast<const thread_buffer<Tx, vx>&>(data).get(number<I>{});}
template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x)
{ TB_COMMON_AS(); reinterpret_cast<thread_buffer<Tx, vx>&>(data).at(i) = x; }
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x)
{ TB_COMMON_AS(); reinterpret_cast<thread_buffer<Tx, vx>&>(data).at(number<I>{}) = x; }
#undef TB_COMMON_AS
};
// clang-format on
template <typename>
struct vector_traits;
// specialization for thread_buffer
template <typename T, index_t N>
struct vector_traits<thread_buffer<T, N>>
{
using scalar_type = T;
static constexpr index_t vector_size = N;
};
#endif
} // namespace ck_tile


@@ -0,0 +1,781 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include <utility>
#include <initializer_list>
#ifndef CK_TILE_TUPLE_IMPL
#define CK_TILE_TUPLE_IMPL 1
#endif
namespace ck_tile {
namespace impl {
template <typename T, index_t N>
struct tuple_array_impl;
}
template <typename T, index_t N>
using tuple_array = typename impl::tuple_array_impl<T, N>::type;
namespace impl {
// the place where content is stored
template <index_t idx, typename T, bool is_empty = std::is_empty_v<T>>
struct tuple_object
{
};
template <index_t idx, typename T>
struct tuple_object<idx, T, true>
{
CK_TILE_HOST_DEVICE constexpr tuple_object() {}
#if CK_TILE_TUPLE_IMPL == 0
template <typename U>
CK_TILE_HOST_DEVICE constexpr tuple_object(U&&)
{
}
template <typename U>
CK_TILE_HOST_DEVICE constexpr tuple_object(const U&)
{
}
template <typename U>
CK_TILE_HOST_DEVICE constexpr tuple_object(U&)
{
}
#elif CK_TILE_TUPLE_IMPL == 1
template <typename U,
typename std::enable_if<!std::is_same<remove_cvref_t<U>, tuple_object>::value,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr tuple_object(U&&)
{
}
#endif
};
template <index_t idx, typename T>
struct tuple_object<idx, T, false>
{
CK_TILE_HOST_DEVICE constexpr tuple_object() : element{} {}
#if CK_TILE_TUPLE_IMPL == 0
template <typename U>
CK_TILE_HOST_DEVICE constexpr tuple_object(U&& e) : element(std::forward<U>(e))
{
}
template <typename U>
CK_TILE_HOST_DEVICE constexpr tuple_object(const U& e) : element(e)
{
}
template <typename U>
CK_TILE_HOST_DEVICE constexpr tuple_object(U& e) : element(e)
{
}
#elif CK_TILE_TUPLE_IMPL == 1
template <typename U,
typename std::enable_if<!std::is_same<remove_cvref_t<U>, tuple_object>::value,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr tuple_object(U&& e) : element(std::forward<U>(e))
{
}
#endif
T element;
};
// NOTE: we return an instance (not a reference) if the content type is empty
template <index_t I, class T>
CK_TILE_HOST_DEVICE constexpr T getv(const tuple_object<I, T, true>&)
{
return {};
}
template <index_t I, class T>
CK_TILE_HOST_DEVICE constexpr const T& getv(const tuple_object<I, T, false>& x)
{
return x.element;
}
template <index_t I, class T>
CK_TILE_HOST_DEVICE constexpr T& getv(tuple_object<I, T, false>& x)
{
return x.element;
}
template <index_t I, class T>
CK_TILE_HOST_DEVICE constexpr T&& getv(tuple_object<I, T, false>&& x)
{
return static_cast<T&&>(x.element);
}
template <typename index_seq, typename... T>
struct tuple_base;
template <index_t... I, typename... T>
struct tuple_base<sequence<I...>, T...> : tuple_object<I, T>...
{
CK_TILE_HOST_DEVICE constexpr tuple_base() = default;
#if CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST
#define _ILE() (std::initializer_list<U>{}.size() - 1)
template <typename U>
CK_TILE_HOST_DEVICE constexpr tuple_base(std::initializer_list<U> us)
: tuple_object<I, T>(static_cast<T>(*(us.begin() + (I >= _ILE() ? _ILE() : I))))...
{
}
#undef _ILE
#endif
#if CK_TILE_TUPLE_IMPL == 0
template <class... U>
CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U&&... u)
: tuple_object<I, T>(std::forward<U>(u))...
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr explicit tuple_base(const U&... u) : tuple_object<I, T>(u)...
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U&... u) : tuple_object<I, T>(u)...
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple_base(tuple_base<sequence<I...>, U...>&& u)
: tuple_object<I, T>(getv(static_cast<tuple_object<I, U>&&>(u)))...
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple_base(const tuple_base<sequence<I...>, U...>& u)
: tuple_object<I, T>(getv(static_cast<const tuple_object<I, U>&>(u)))...
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple_base(tuple_base<sequence<I...>, U...>& u)
: tuple_object<I, T>(getv(static_cast<tuple_object<I, U>&>(u)))...
{
}
#elif CK_TILE_TUPLE_IMPL == 1
template <class U,
typename std::enable_if<sizeof...(I) == 1 && sizeof...(T) == 1 &&
!std::is_same<remove_cvref_t<U>, tuple_base>::value,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr tuple_base(U&& u) : tuple_object<I, T>(std::forward<U>(u))...
{
}
template <typename... U, typename std::enable_if<sizeof...(U) >= 2, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr tuple_base(U&&... u) : tuple_object<I, T>(std::forward<U>(u))...
{
static_assert(sizeof...(I) == sizeof...(T) && sizeof...(I) == sizeof...(U),
"wrong! inconsistent size");
}
#endif
};
} // namespace impl
template <class... T>
struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
{
CK_TILE_HOST_DEVICE
static constexpr auto size() { return sizeof...(T); }
using base = impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>;
CK_TILE_HOST_DEVICE constexpr tuple() = default;
#if CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST
template <typename U>
CK_TILE_HOST_DEVICE constexpr tuple(std::initializer_list<U> us) : base(us)
{
}
#endif
#if CK_TILE_TUPLE_IMPL == 0
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple(U&&... u) : base(std::forward<U>(u)...)
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple(const U&... u) : base(u...)
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple(U&... u) : base(u...)
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple(tuple<U...>&& u)
: base(static_cast<impl::tuple_base<make_index_sequence<sizeof...(U)>, U...>&&>(u))
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple(const tuple<U...>& u)
: base(static_cast<const impl::tuple_base<make_index_sequence<sizeof...(U)>, U...>&>(u))
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple(tuple<U...>& u)
: base(static_cast<impl::tuple_base<make_index_sequence<sizeof...(U)>, U...>&>(u))
{
}
#elif CK_TILE_TUPLE_IMPL == 1
template <
typename U,
typename std::enable_if<sizeof...(T) == 1 && !std::is_same<remove_cvref_t<U>, tuple>::value,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr tuple(U&& u) : base(std::forward<U>(u))
{
}
template <typename... U,
typename std::enable_if<sizeof...(U) == sizeof...(T) && sizeof...(U) >= 2,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr tuple(U&&... u) : base(std::forward<U>(u)...)
{
}
#endif
CK_TILE_HOST_DEVICE static constexpr bool is_static()
{
bool flag = true;
static_for<0, sizeof...(T), 1>{}([&flag](auto i) {
flag &= is_static_v<remove_cvref_t<__type_pack_element<i.value, T...>>>;
});
return flag;
}
#define TP_COM_() static_assert(I < size(), "wrong! out of range")
// clang-format off
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) const { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at() const { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at(number<I>) const { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at() { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at(number<I>) { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number<I>) { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number<I>) const { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator()(number<I>) { TP_COM_(); return get<I>(); } // TODO: compatible
// the functions below should only be used on tuple_array<> types; no extra checks are performed here
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() { return reinterpret_cast<tuple_array<Tx, size()>&>(*this); }
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() const { return reinterpret_cast<const tuple_array<Tx, size()>&>(*this); }
// the index below refers to the position *AFTER* the type conversion, not before
//template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) { TP_COM_(); return reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(i); }
//template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) const { TP_COM_(); return reinterpret_cast<const tuple_array<Tx, size()>&>(*this).at(i); }
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>) { TP_COM_(); return reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(number<I>{}); }
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>) const { TP_COM_(); return reinterpret_cast<const tuple_array<Tx, size()>&>(*this).at(number<I>{}); }
// template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x) { TP_COM_(); reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(i) = x; }
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x) { TP_COM_(); reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(number<I>{}) = x; }
// clang-format on
#undef TP_COM_
};
template <typename>
struct vector_traits;
// specialization for array
template <typename... T>
struct vector_traits<tuple<T...>>
{
using scalar_type = __type_pack_element<0, T...>;
static constexpr index_t vector_size = sizeof...(T);
};
// template <class... T>
// CK_TILE_HOST_DEVICE constexpr
// tuple<T...>
// make_tuple(T const&... t)
// {
// return {t...};
// }
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr bool operator==(const tuple<Xs...>& a, const tuple<Xs...>& b)
{
bool same = true;
static_for<0, sizeof...(Xs), 1>{}([&](auto i) {
if(a[i] != b[i])
{
same = false;
}
});
return same;
}
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr bool operator!=(const tuple<Xs...>& a, const tuple<Xs...>& b)
{
return !(a == b);
}
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs&&... xs)
{
// here xs is always an lvalue as a function argument
// Xs may be deduced as follows (e.g. try passing in an integer):
// 1). if passed an rvalue (like a function return or int{}) -> Xs is "int"
// 2). if passed a const lvalue -> Xs is "const int &"
// 3). if passed a non-const lvalue -> Xs is "int &"
// so the return type of std::forward depends on Xs:
// 1). std::forward -> int&&
// 2). std::forward -> const int&
// 3). std::forward -> int&
return tuple<remove_cvref_t<Xs>...>(std::forward<Xs>(xs)...);
}
// https://en.cppreference.com/w/cpp/utility/tuple/tie
template <typename... Args>
constexpr tuple<Args&...> tie(Args&... args) noexcept
{
return {args...};
}
template <typename X, typename Y>
struct tuple_concat;
template <typename... Xs, typename... Ys>
struct tuple_concat<tuple<Xs...>, tuple<Ys...>>
{
using type = tuple<Xs..., Ys...>;
};
namespace impl {
// be very careful when using this type (because we want the internal type);
// template deduction will fail when inferring the inner type
// e.g.
// template<typename T, index_t N> using some_wrapper = typename tuple_array_impl<T, N>::type;
// template<typename T, index_t N> void foo(const some_wrapper<T, N>&) {}
// -> the compiler will fail to deduce this type, because it is a non-deduced context
// (https://en.cppreference.com/w/cpp/language/template_argument_deduction, "Non-deduced
// contexts")
//
// -> use this instead
// template<typename Tup> void foo(const Tup&) {}
template <typename T, index_t N>
struct tuple_array_impl
{
using type = typename tuple_concat<typename tuple_array_impl<T, N / 2>::type,
typename tuple_array_impl<T, N - N / 2>::type>::type;
};
template <typename T>
struct tuple_array_impl<T, 0>
{
using type = tuple<>;
};
template <typename T>
struct tuple_array_impl<T, 1>
{
using type = tuple<T>;
};
} // namespace impl
template <typename F, index_t N>
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F&& f, number<N>)
{
return unpack([&f](auto&&... is) { return make_tuple(f(is)...); },
typename arithmetic_sequence_gen<0, N, 1>::type{});
}
template <typename F, index_t N>
CK_TILE_HOST_DEVICE constexpr auto generate_tie(F&& f, number<N>)
{
return unpack([&f](auto&&... is) { return tie(f(is)...); },
typename arithmetic_sequence_gen<0, N, 1>::type{});
}
// tx and ty are tuples of references; the return type will be a tuple of references (not rvalues)
template <typename... X, typename... Y>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple_of_reference(const tuple<X&...>& tx,
const tuple<Y&...>& ty)
{
return unpack2(
[&](auto&&... zs) { return tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
tx,
ty);
}
template <typename... X, typename... Y>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx, const tuple<Y...>& ty)
{
return unpack2(
[&](auto... zs) { return tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
tx,
ty);
}
// Support any number of tuples to concat (also 1)
template <typename... X>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx)
{
return tx;
}
template <typename... X, typename... Tuples>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx, const Tuples&... tuples)
{
return concat_tuple(tx, concat_tuple(tuples...));
}
namespace detail {
template <typename F, typename X, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples_impl(F f, const X& x, sequence<Is...>)
{
return make_tuple(f(x.at(number<Is>{}))...);
}
template <typename F, typename X, typename Y, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto
transform_tuples_impl(F f, const X& x, const Y& y, sequence<Is...>)
{
return make_tuple(f(x.at(number<Is>{}), y.at(number<Is>{}))...);
}
template <typename F, typename X, typename Y, typename Z, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto
transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, sequence<Is...>)
{
return make_tuple(f(x.at(number<Is>{}), y.at(number<Is>{}), z.at(number<Is>{}))...);
}
} // namespace detail
template <typename F, typename X>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x)
{
return detail::transform_tuples_impl(
f, x, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
template <typename F, typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y)
{
return detail::transform_tuples_impl(
f, x, y, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
template <typename F, typename X, typename Y, typename Z>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z)
{
return detail::transform_tuples_impl(
f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
// By default unroll to the fully flattened form
template <index_t Depth = 0, index_t MaxDepth = -1>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<>& t)
{
return t;
}
template <index_t Depth = 0, index_t MaxDepth = -1, typename T>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const T& t)
{
return make_tuple(t);
}
template <index_t Depth = 0, index_t MaxDepth = -1, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<Ts...>& t)
{
if constexpr(Depth == MaxDepth)
{
return t;
}
else
{
return unpack(
[&](auto&&... ts) {
return concat_tuple(unroll_nested_tuple<Depth + 1, MaxDepth>(ts)...);
},
t);
}
}
template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto tuple_reverse(const tuple<Ts...>& t)
{
return generate_tuple(
[&](auto i) {
using Idx = number<tuple<Ts...>::size() - i - 1>;
return t.at(Idx{});
},
number<tuple<Ts...>::size()>{});
}
// Reduce tuple values in specific range using Function
template <index_t Idx, index_t End, typename F, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto tuple_reduce(F&& f, const tuple<Ts...>& t)
{
static_assert(Idx < End, "Wrong parameters for tuple_reduce");
if constexpr(Idx + 1 == End)
{
return t.at(number<Idx>{});
}
else
{
return f(t.at(number<Idx>{}), tuple_reduce<Idx + 1, End>(f, t));
}
}
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto is_nested_tuple(const tuple<Ts...>&)
{
return (is_detected<is_tuple, Ts>::value || ...);
}
template <index_t depth = 0, typename T>
CK_TILE_HOST_DEVICE constexpr auto tuple_depth(const T&)
{
return depth;
}
template <index_t depth = 0, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto tuple_depth(const tuple<Ts...>&)
{
return max(tuple_depth<depth + 1>(Ts{})...);
}
template <typename... Seqs>
CK_TILE_HOST_DEVICE constexpr auto to_array_of_array(tuple<Seqs...> t_of_s)
{
constexpr index_t n0 = sizeof...(Seqs);
constexpr index_t max_n1 = [&] {
index_t max_n1_ = 0;
static_for<0, n0, 1>{}([&](auto i0) {
constexpr index_t n1 = t_of_s[i0].size();
max_n1_ = max_n1_ < n1 ? n1 : max_n1_;
});
return max_n1_;
}();
array<array<index_t, max_n1>, n0> a_of_a{{-1}};
static_for<0, n0, 1>{}([&](auto i0) {
constexpr index_t n1 = t_of_s[i0].size();
static_for<0, n1, 1>{}([&](auto i1) { a_of_a(i0)(i1) = t_of_s[i0][i1]; });
});
return a_of_a;
}
// These should ideally take multi_index<NSize> instead of tuple<Ys...>, even though
// the former is an alias of the latter; but the compiler cannot deduce NSize when
// multi_index<NSize> is used (non-deduced context).
// TODO: how to fix this?
template <typename... Ys,
typename X,
std::enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> =
false>
CK_TILE_HOST_DEVICE constexpr auto operator+=(tuple<Ys...>& y, const X& x)
{
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Ys);
static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; });
return y;
}
template <typename... Ys,
typename X,
std::enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> =
false>
CK_TILE_HOST_DEVICE constexpr auto operator-=(tuple<Ys...>& y, const X& x)
{
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Ys);
static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; });
return y;
}
template <typename... Xs,
typename Y,
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
false>
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
{
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Xs);
tuple<Xs...> r;
static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] + y[i]; });
return r;
}
template <typename... Xs,
typename Y,
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
false>
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
{
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Xs);
tuple<Xs...> r;
static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] - y[i]; });
return r;
}
template <typename... Xs,
typename Y,
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
false>
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const Y& y)
{
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Xs);
tuple<Xs...> r;
static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] * y[i]; });
return r;
}
// MultiIndex = scalar * MultiIndex
template <
typename... Xs,
typename Y,
std::enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
CK_TILE_HOST_DEVICE constexpr auto operator*(Y a, const tuple<Xs...>& x)
{
constexpr index_t NSize = sizeof...(Xs);
tuple<Xs...> r;
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a * x[i]; });
return r;
}
// MultiIndex = MultiIndex * scalar
template <
typename... Xs,
typename Y,
std::enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, Y a)
{
return a * x;
}
// MultiIndex = MultiIndex / MultiIndex
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator/(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Xs);
return generate_tuple([&](auto i) { return x[i] / y[i]; }, number<NSize>{});
}
} // namespace ck_tile
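// The operator overloads above all follow the same pattern: assert that both
// operands have the same size, then apply the scalar operation elementwise via
// static_for. A standalone sketch of that pattern, using std::array and an
// std::index_sequence pack expansion in place of ck_tile::tuple and
// ck_tile::static_for (names here are illustrative, not part of ck_tile):

```cpp
#include <array>
#include <cstddef>
#include <utility>

// Elementwise addition over two equally sized std::arrays, expanded at
// compile time with an index pack -- the same shape as the static_for
// loops in the operators above.
template <typename T, std::size_t N, std::size_t... Is>
constexpr std::array<T, N>
add_impl(const std::array<T, N>& x, const std::array<T, N>& y, std::index_sequence<Is...>)
{
    return {{(x[Is] + y[Is])...}};
}

template <typename T, std::size_t N>
constexpr std::array<T, N> elementwise_add(const std::array<T, N>& x, const std::array<T, N>& y)
{
    return add_impl(x, y, std::make_index_sequence<N>{});
}
```

// Because the whole computation is constexpr, the result can feed further
// compile-time logic, just as the tuple operators above do inside coordinate
// transforms.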
#include <tuple>
// WARNING: these specializations exist only so the compiler can support C++ structured
// bindings for ck_tile::tuple; do not use std::tuple facilities on it directly
namespace std {
template <typename... Ts>
struct tuple_size<ck_tile::tuple<Ts...>> : std::integral_constant<std::size_t, sizeof...(Ts)>
{
};
template <std::size_t I, typename... Ts>
struct tuple_element<I, ck_tile::tuple<Ts...>> : std::tuple_element<I, std::tuple<Ts...>>
{
};
template <typename... Ts>
struct tuple_size<const ck_tile::tuple<Ts...>> : std::integral_constant<std::size_t, sizeof...(Ts)>
{
};
template <std::size_t I, typename... Ts>
struct tuple_element<I, const ck_tile::tuple<Ts...>>
: std::tuple_element<I, const std::tuple<Ts...>>
{
};
} // namespace std
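// The std::tuple_size / std::tuple_element specializations above are the
// standard protocol for making a user-defined type destructurable. A minimal
// standalone sketch of the same protocol for a hypothetical two-int type (the
// type and its get<I>() are illustrative, not part of ck_tile):

```cpp
#include <cstddef>
#include <tuple>

// Minimal custom aggregate with just enough glue for structured bindings:
// std::tuple_size, std::tuple_element specializations, plus an ADL-visible
// get<I>().
struct int_pair
{
    int first;
    int second;
};

template <std::size_t I>
constexpr int get(const int_pair& p)
{
    static_assert(I < 2, "index out of range");
    return I == 0 ? p.first : p.second;
}

namespace std {
template <>
struct tuple_size<int_pair> : integral_constant<size_t, 2>
{
};
template <size_t I>
struct tuple_element<I, int_pair>
{
    using type = int;
};
} // namespace std

// With the protocol in place, `auto [a, b] = int_pair{3, 4};` compiles and
// binds a == 3, b == 4 -- exactly what the specializations above enable for
// ck_tile::tuple.
```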
#if 1
#define TO_TUPLE_OF_NUMBER(a, n) \
_Pragma("clang diagnostic push") _Pragma( \
"clang diagnostic ignored \"-Wc++20-extensions\"")[a]<ck_tile::index_t... IDX_IDX_>( \
ck_tile::sequence<IDX_IDX_...>) \
{ \
return ck_tile::tuple<ck_tile::number<a[ck_tile::number<IDX_IDX_>{}]>...>{}; \
} \
(ck_tile::make_index_sequence<n>{}) _Pragma("clang diagnostic pop")
#else
#define TO_TUPLE_OF_NUMBER(arr, n_) \
[&arr, n_] { \
static_assert(arr.size() >= n_, "wrong! out of bound"); \
\
        static_assert(n_ < 7, "not implemented for more than 6 elements");                \
\
if constexpr(n_ == 0) \
{ \
return ck_tile::tuple<>{}; \
} \
else if constexpr(n_ == 1) \
{ \
return ck_tile::tuple<number<arr[0]>>{}; \
} \
else if constexpr(n_ == 2) \
{ \
return ck_tile::tuple<number<arr[0]>, number<arr[1]>>{}; \
} \
else if constexpr(n_ == 3) \
{ \
return ck_tile::tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>>{}; \
} \
else if constexpr(n_ == 4) \
{ \
return ck_tile:: \
tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>, number<arr[3]>>{}; \
} \
else if constexpr(n_ == 5) \
{ \
return ck_tile::tuple<number<arr[0]>, \
number<arr[1]>, \
number<arr[2]>, \
number<arr[3]>, \
number<arr[4]>>{}; \
} \
else if constexpr(n_ == 6) \
{ \
return ck_tile::tuple<number<arr[0]>, \
number<arr[1]>, \
number<arr[2]>, \
number<arr[3]>, \
number<arr[4]>, \
number<arr[5]>>{}; \
} \
}()
#endif
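// TO_TUPLE_OF_NUMBER relies on an immediately-invoked templated lambda (a
// C++20 extension, hence the suppressed -Wc++20-extensions diagnostic)
// expanding an index pack to lift the elements of a constexpr array into
// template arguments. The core trick, sketched standalone with
// std::integral_constant in place of ck_tile::number (assumes C++20; names
// are illustrative):

```cpp
#include <array>
#include <cstddef>
#include <tuple>
#include <type_traits>
#include <utility>

// Turn a constexpr array into a tuple of integral_constants: each element's
// value is lifted from storage into the type system by expanding an index
// pack inside an immediately-invoked templated lambda.
constexpr std::array<int, 3> dims{2, 4, 8};

constexpr auto dims_as_types = []<std::size_t... Is>(std::index_sequence<Is...>) {
    return std::tuple<std::integral_constant<int, dims[Is]>...>{};
}(std::make_index_sequence<dims.size()>{});

// std::get<1>(dims_as_types).value == 4, now encoded in the type, so it can
// drive further template instantiation -- the reason the macro exists.
```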