This commit is contained in:
carlushuang
2024-02-28 22:57:19 +00:00
parent e60c5aea4e
commit f69356b1d7
130 changed files with 28268 additions and 0 deletions

View File

@@ -0,0 +1,18 @@
# ck_tile/core #
`ck_tile/core` contains all the basic functions and structures needed to create a GPU kernel using `ck_tile`. Users should include only the single header `ck_tile/core.hpp` to access all of this functionality. Everything lives under the `ck_tile` namespace. The coding style under this folder should be similar to `std` (`snake_case` for structures/functions, CamelCase for template types...)
```
algorithm/
coordinate transform and some other reusable algorithm
arch/
contains some basic device building block like mma, buffer addressing, etc...
container/
contains basic container data structure, array/sequence/tuple/...
numeric/
data type, and data type related math
tensor/
tensor descriptors and tile level API
utility/
other utility function for both host/device
```

View File

@@ -0,0 +1,38 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// Build a 1-d "cluster" descriptor that merges every dimension of \p lengths
// (visited in \p order, default 0..N-1) into a single linearized index.
template <typename Lengths,
          typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
CK_TILE_HOST_DEVICE constexpr auto make_cluster_descriptor(
    const Lengths& lengths,
    ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type{})
{
    constexpr index_t num_dim = Lengths::size();

    // lengths rearranged so that position i holds the i-th visited dimension
    const auto lengths_in_visit_order = container_reorder_given_new2old(lengths, order);

    // tuple of per-dimension lengths feeding the merge transform
    const auto merge_lengths = generate_tuple(
        [&](auto i) { return lengths_in_visit_order[i]; }, number<num_dim>{});

    // one merge transform: bottom dims are the visited (reordered) ids,
    // top dim is the single linearized id 0
    return make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(merge_lengths)),
                                            make_tuple(ArrangeOrder{}),
                                            make_tuple(sequence<0>{}));
}
} // namespace ck_tile

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,165 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// Maps a linear access index onto a multi-dimensional tensor so that the
// tensor is covered in vectorized chunks (ScalarsPerAccess scalars per
// dimension per access), traversing dimensions in DimAccessOrder. With
// SnakeCurved == true the traversal direction alternates between sweeps
// (boustrophedon), keeping consecutive accesses spatially adjacent.
template <typename TensorLengths,    // extent of each tensor dimension
          typename DimAccessOrder,   // order in which dimensions are traversed
          typename ScalarsPerAccess, // # of scalars per access in each dimension
          bool SnakeCurved = true>   // alternate direction between sweeps
struct space_filling_curve
{
    // total number of scalar elements in the tensor
    static constexpr index_t TensorSize =
        reduce_on_sequence(TensorLengths{}, multiplies{}, number<1>{});
    static_assert(0 < TensorSize,
                  "space_filling_curve should be used to access a non-empty tensor");

    static constexpr index_t nDim = TensorLengths::size();
    using Index                   = multi_index<nDim>;

    // scalars moved by a single vectorized access
    static constexpr index_t ScalarPerVector =
        reduce_on_sequence(ScalarsPerAccess{}, multiplies{}, number<1>{});

    // number of vectorized accesses along each dimension
    static constexpr auto access_lengths   = TensorLengths{} / ScalarsPerAccess{};
    static constexpr auto dim_access_order = DimAccessOrder{};
    static constexpr auto ordered_access_lengths =
        container_reorder_given_new2old(access_lengths, dim_access_order);

    static constexpr auto to_index_adaptor = make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(ordered_access_lengths)),
        make_tuple(typename arithmetic_sequence_gen<0, nDim, 1>::type{}),
        make_tuple(sequence<0>{}));

    static constexpr auto I0 = number<0>{};
    static constexpr auto I1 = number<1>{};

    // total number of vectorized accesses needed to cover the tensor
    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_access()
    {
        static_assert(TensorLengths::size() == ScalarsPerAccess::size());
        static_assert(TensorLengths{} % ScalarsPerAccess{} ==
                      typename uniform_sequence_gen<TensorLengths::size(), 0>::type{});
        return reduce_on_sequence(TensorLengths{}, multiplies{}, number<1>{}) / ScalarPerVector;
    }

    // multi-dimensional step that moves from access AccessIdx1dHead to AccessIdx1dTail
    template <index_t AccessIdx1dHead, index_t AccessIdx1dTail>
    static CK_TILE_HOST_DEVICE constexpr auto get_step_between(number<AccessIdx1dHead>,
                                                               number<AccessIdx1dTail>)
    {
        static_assert(AccessIdx1dHead >= 0 && AccessIdx1dHead < get_num_of_access(),
                      "1D index out of range");
        static_assert(AccessIdx1dTail >= 0 && AccessIdx1dTail < get_num_of_access(),
                      "1D index out of range");
        constexpr auto idx_head = get_index(number<AccessIdx1dHead>{});
        constexpr auto idx_tail = get_index(number<AccessIdx1dTail>{});
        return idx_tail - idx_head;
    }

    // step from access AccessIdx1d to the next access
    template <index_t AccessIdx1d>
    static CK_TILE_HOST_DEVICE constexpr auto get_forward_step(number<AccessIdx1d>)
    {
        // this is a range check, not a positivity check (message was
        // previously copy-pasted from get_backward_step)
        static_assert(AccessIdx1d < get_num_of_access(), "1D index out of range");
        return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d + 1>{});
    }

    // step from access AccessIdx1d to the previous access
    template <index_t AccessIdx1d>
    static CK_TILE_HOST_DEVICE constexpr auto get_backward_step(number<AccessIdx1d>)
    {
        static_assert(AccessIdx1d > 0, "1D index should be larger than 0");
        return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d - 1>{});
    }

    // multi-dimensional tensor index (in scalars) of the AccessIdx1d-th access
    template <index_t AccessIdx1d>
    static CK_TILE_HOST_DEVICE constexpr Index get_index(number<AccessIdx1d>)
    {
#if 0
        /*
         * \todo: tensor_adaptor::calculate_bottom_index does NOT return constexpr as expected.
         */
        constexpr auto ordered_access_idx = to_index_adaptor.calculate_bottom_index(make_multi_index(number<AccessIdx1d>{}));
#else
        // strides of the (reordered) access grid, computed by a reverse
        // exclusive product scan
        constexpr auto access_strides =
            container_reverse_exclusive_scan(ordered_access_lengths, multiplies{}, number<1>{});
        constexpr auto idx_1d = number<AccessIdx1d>{};
        // Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the
        // idim-th element of multidimensional index.
        // All constexpr variables have to be captured by VALUE.
        constexpr auto compute_index = [ idx_1d, access_strides ](auto idim) constexpr
        {
            constexpr auto compute_index_impl = [ idx_1d, access_strides ](auto jdim) constexpr
            {
                // repeated div/mod peels off one dimension at a time; the
                // value left in `id` after jdim+1 steps is the jdim-th digit
                auto res = idx_1d.value;
                auto id  = 0;
                static_for<0, jdim.value + 1, 1>{}([&](auto kdim) {
                    id = res / access_strides[kdim].value;
                    res -= id * access_strides[kdim].value;
                });
                return id;
            };
            constexpr auto id = compute_index_impl(idim);
            return number<id>{};
        };
        constexpr auto ordered_access_idx = generate_tuple(compute_index, number<nDim>{});
#endif
        // forward_sweep[idim] == true when dimension idim is currently being
        // traversed in its forward direction (parity of the enclosing index)
        constexpr auto forward_sweep = [&]() {
            StaticallyIndexedArray<bool, nDim> forward_sweep_;
            forward_sweep_(I0) = true;
            static_for<1, nDim, 1>{}([&](auto idim) {
                index_t tmp = ordered_access_idx[I0];
                static_for<1, idim, 1>{}(
                    [&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; });
                forward_sweep_(idim) = tmp % 2 == 0;
            });
            return forward_sweep_;
        }();
        // calculate multi-dim tensor index: mirror backward-swept dimensions,
        // undo the dimension reordering, scale back up to scalar units
        auto idx_md = [&]() {
            Index ordered_idx;
            static_for<0, nDim, 1>{}([&](auto idim) {
                ordered_idx(idim) =
                    !SnakeCurved || forward_sweep[idim]
                        ? ordered_access_idx[idim]
                        : ordered_access_lengths[idim] - 1 - ordered_access_idx[idim];
            });
            return container_reorder_given_old2new(ordered_idx, dim_access_order) *
                   ScalarsPerAccess{};
        }();
        return idx_md;
    }

    // FIXME: rename this function
    // same as get_index() but returned as a tuple of number<> (compile-time values)
    template <index_t AccessIdx1d>
    static CK_TILE_HOST_DEVICE constexpr auto get_index_tuple_of_number(number<AccessIdx1d>)
    {
        constexpr auto idx = get_index(number<AccessIdx1d>{});
        return generate_tuple([&](auto i) { return number<idx[i]>{}; }, number<nDim>{});
    }
};
} // namespace ck_tile

View File

@@ -0,0 +1,20 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
// Address Space for AMDGCN
// https://llvm.org/docs/AMDGPUUsage.html#address-space
namespace ck_tile {
// Memory spaces an address may refer to on AMDGCN (see link above).
enum struct address_space_enum
{
    generic, // flat address; may point into any space
    global,  // device global memory
    lds,     // local data share (workgroup-shared memory)
    sgpr,    // scalar registers
    vgpr,    // vector registers
};
} // namespace ck_tile

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,56 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifdef __HIPCC__
// HIP build: emit proper host/device qualifiers
#define CK_TILE_HOST __host__
#define CK_TILE_DEVICE __device__
#define CK_TILE_HOST_DEVICE __host__ __device__
#else
// plain C++ build: qualifiers degrade to `inline` so headers stay ODR-safe
#define CK_TILE_HOST inline
#define CK_TILE_DEVICE inline
#define CK_TILE_HOST_DEVICE inline
#endif
// float -> bf16 rounding mode selectors
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
#ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE
#endif
// float -> fp8 conversion mode selectors
#define CK_TILE_FLOAT_TO_FP8_STANDARD 0
#define CK_TILE_FLOAT_TO_FP8_STOCHASTIC 1
#ifndef CK_TILE_FLOAT_TO_FP8_DEFAULT
#define CK_TILE_FLOAT_TO_FP8_DEFAULT CK_TILE_FLOAT_TO_FP8_STANDARD
#endif
// STATIC_ASSERT compiles away in release (NDEBUG) builds
#ifndef STATIC_ASSERT
#ifndef NDEBUG
#define STATIC_ASSERT(...) static_assert(__VA_ARGS__)
#else
#define STATIC_ASSERT(...)
#endif
#endif // #ifndef STATIC_ASSERT
// In the old rocm period we had to use a tuple-based implementation for
// statically indexed arrays; enable _USE_TUPLE if you hit a compiler error,
// otherwise _USE_ARRAY is the default.
#define CK_TILE_STATICALLY_INDEXED_ARRAY_USE_ARRAY 0
#define CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE 1
#ifndef CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT
#define CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT CK_TILE_STATICALLY_INDEXED_ARRAY_USE_ARRAY
#endif
// presumably gates __launch_bounds__ on kernels -- consumed outside this file
#ifndef CK_TILE_USE_LAUNCH_BOUNDS
#define CK_TILE_USE_LAUNCH_BOUNDS 1
#endif
// presumably enables kernel timing in host launch helpers -- consumed outside this file
#ifndef CK_TILE_TIME_KERNEL
#define CK_TILE_TIME_KERNEL 1
#endif
// default launch-bounds parameters
#define CK_TILE_MAX_THREAD_PER_BLOCK 256
#define CK_TILE_MIN_BLOCK_PER_CU 2

View File

@@ -0,0 +1,201 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
namespace ck_tile {
// use aggregate initialization for this type
// e.g. array<index_t, 4> buf {0}; => {0, 0, 0, 0}, clean
// array<index_t, 4> buf {3, 2}; => {3, 2, 2, 2} (not {3,2,0,0})
// use make_array_with({...}) to construct an array with compatible behavior as old ck
// TODO: manually added constructor same as old ck
// Fixed-size array with aggregate-like initialization. Unlike std::array,
// trailing elements are filled with the LAST initializer value:
//   array<index_t, 4> buf {0};    => {0, 0, 0, 0}
//   array<index_t, 4> buf {3, 2}; => {3, 2, 2, 2} (not {3,2,0,0})
template <typename T_, index_t N_>
struct array
{
    using value_type           = T_;
    static constexpr index_t N = N_;

    value_type data[N];

    CK_TILE_HOST_DEVICE constexpr array() : data{} {}

    // Fill from an initializer list; slots past the list repeat the last
    // given value. NOTE: ilist.size() of a runtime argument is not a constant
    // expression, so the bound cannot be enforced with static_assert (the
    // previous `std::initializer_list<value_type>{}.size()` check was always
    // 0 and thus vacuous); excess initializers are now ignored instead of
    // being written out of bounds.
    CK_TILE_HOST_DEVICE constexpr array(std::initializer_list<value_type> ilist)
    {
        index_t i        = 0;
        value_type vlast = value_type{};
        for(const value_type& val : ilist)
        {
            if(i >= N)
                break; // ignore excess initializers rather than overflow data[]
            data[i] = val;
            vlast   = val;
            ++i;
        }
        for(; i < N; ++i)
        {
            data[i] = vlast;
        }
    }

    CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
    CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<value_type>; }
    // clang-format off
    CK_TILE_HOST_DEVICE constexpr auto& get() { return data; }
    CK_TILE_HOST_DEVICE constexpr const auto& get() const { return data; }
    CK_TILE_HOST_DEVICE constexpr auto& get(index_t i) { return data[i]; }
    CK_TILE_HOST_DEVICE constexpr const auto& get(index_t i) const { return data[i]; }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& get() { return data[I]; }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& get() const { return data[I]; }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& get(number<I>) { return data[I]; }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& get(number<I>) const { return data[I]; }
    CK_TILE_HOST_DEVICE constexpr auto& at(index_t i) { return data[i]; }
    CK_TILE_HOST_DEVICE constexpr const auto& at(index_t i) const { return data[i]; }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at() { return data[I]; }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return data[I]; }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return data[I]; }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return data[I]; }
    CK_TILE_HOST_DEVICE constexpr const value_type& operator[](index_t i) const { return data[i]; }
    CK_TILE_HOST_DEVICE constexpr value_type& operator[](index_t i) { return data[i]; }
    CK_TILE_HOST_DEVICE constexpr value_type& operator()(index_t i) { return data[i]; } // TODO: compatible
    // element-wise copy from any container with a matching static size();
    // returns a reference (the previous `auto` return type produced a copy
    // of *this on every assignment)
    template <typename T>
    CK_TILE_HOST_DEVICE constexpr auto& operator=(const T& a)
    {
        static_assert(T::size() == size(), "wrong! size not the same");
        for(index_t i = 0; i < size(); ++i)
        {
            data[i] = a[i];
        }
        return *this;
    }
    // type punning (strict aliasing) member functions for read/write
    // aliasing this array of type "T", "N" elements
    // as array of type "Tx", sizeof(T)*N/sizeof(Tx) elements
#define AR_AS_COM_()                                        \
    static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
    constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
    template <typename Tx> CK_TILE_HOST_DEVICE constexpr auto& get_as()
    { AR_AS_COM_(); return reinterpret_cast<array<Tx, vx>&>(data); }
    template <typename Tx> CK_TILE_HOST_DEVICE constexpr const auto& get_as() const
    { AR_AS_COM_(); return reinterpret_cast<const array<Tx, vx>&>(data); }
    // below index is for index *AFTER* type convert, not before
    template <typename Tx> CK_TILE_HOST_DEVICE constexpr auto& get_as(index_t i)
    { AR_AS_COM_(); return reinterpret_cast<array<Tx, vx>&>(data).at(i); }
    template <typename Tx> CK_TILE_HOST_DEVICE constexpr const auto& get_as(index_t i) const
    { AR_AS_COM_(); return reinterpret_cast<const array<Tx, vx>&>(data).at(i); }
    template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr auto& get_as(number<I>)
    { AR_AS_COM_(); return reinterpret_cast<array<Tx, vx>&>(data).at(number<I>{}); }
    template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr const auto& get_as(number<I>) const
    { AR_AS_COM_(); return reinterpret_cast<const array<Tx, vx>&>(data).at(number<I>{}); }
    template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x)
    { AR_AS_COM_(); reinterpret_cast<array<Tx, vx>&>(data).at(i) = x; }
    template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x)
    { AR_AS_COM_(); reinterpret_cast<array<Tx, vx>&>(data).at(number<I>{}) = x; }
#undef AR_AS_COM_
    // clang-format on
};
// empty array specialization: carries only the element type, holds no storage
template <typename T>
struct array<T, 0>
{
    using value_type = T;
    CK_TILE_HOST_DEVICE constexpr array() {}
    CK_TILE_HOST_DEVICE static constexpr index_t size() { return 0; }
    CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<T>; };
    CK_TILE_HOST_DEVICE void print() const { printf("array{size: 0, data: []}"); }
};
// build an array whose element type is deduced (cv/ref-stripped) from the
// first argument; all arguments are perfectly forwarded into the initializer
template <typename T, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_array(T&& x, Ts&&... xs)
{
    return array<remove_cvref_t<T>, 1 + sizeof...(Ts)>{std::forward<T>(x),
                                                       std::forward<Ts>(xs)...};
}
// make empty array (zero-length overload; element type must be given explicitly)
template <typename T>
CK_TILE_HOST_DEVICE constexpr auto make_array()
{
    return array<T, 0>{};
}
// compatible with old ck's initializer: make an array and fill it with the
// last element of the initializer_list
#include <initializer_list>
// Build an array<T, Size> from an initializer list, repeating the last given
// value into any remaining slots (old-ck-compatible behavior).
// NOTE: ilist.size() of a runtime argument is not a constant expression, so
// the bound cannot be checked with static_assert (the previous
// `std::initializer_list<T>{}.size()` check was always 0 and thus vacuous);
// excess initializers are now ignored instead of overflowing the array.
template <typename T, index_t Size>
CK_TILE_HOST_DEVICE constexpr auto make_array_with(std::initializer_list<T> ilist)
{
    index_t i = 0;
    T vlast   = T{};
    array<T, Size> arr;
    for(const T& val : ilist)
    {
        if(i >= Size)
            break; // ignore excess initializers rather than write out of bounds
        arr.data[i] = val;
        vlast       = val;
        ++i;
    }
    for(; i < Size; ++i)
    {
        arr.data[i] = vlast;
    }
    return arr;
}
// element-wise equality of two equally-sized arrays
template <typename T, index_t Size>
CK_TILE_HOST_DEVICE constexpr bool operator==(const array<T, Size>& a, const array<T, Size>& b)
{
    for(index_t i = 0; i < Size; ++i)
    {
        if(a[i] != b[i])
        {
            return false; // first mismatch decides
        }
    }
    return true;
}
// element-wise inequality, defined as the negation of operator==
template <typename T, index_t Size>
CK_TILE_HOST_DEVICE constexpr bool operator!=(const array<T, Size>& a, const array<T, Size>& b)
{
    return !(a == b);
}
// copy the first N elements of container \p x into an array<T, N>
// (N must not exceed the source container's static size)
template <typename T, index_t N, typename X>
CK_TILE_HOST_DEVICE constexpr auto to_array(const X& x)
{
    STATIC_ASSERT(N <= X::size(), "");
    array<T, N> ret;
    static_for<0, N, 1>{}([&](auto i) { ret(i) = x[i]; });
    return ret;
}
} // namespace ck_tile

View File

@@ -0,0 +1,483 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/map.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/utility/functional.hpp"
namespace ck_tile {
// append element \p x to array \p a, yielding an array one element longer
template <typename TData, index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto container_push_back(const array<TData, NSize>& a, const TData& x)
{
    array<TData, NSize + 1> r;
    // copy the existing elements, then place x in the new last slot
    static_for<0, NSize, 1>{}([&](auto i) constexpr { r(i) = a[i]; });
    r(NSize) = x;
    return r;
}
// prepend element \p x to tuple \p a
template <typename... Ts, typename T>
CK_TILE_HOST_DEVICE constexpr auto container_push_front(const tuple<Ts...>& a, const T& x)
{
    return container_concat(make_tuple(x), a);
}
// append element \p x to tuple \p a
template <typename... Ts, typename T>
CK_TILE_HOST_DEVICE constexpr auto container_push_back(const tuple<Ts...>& a, const T& x)
{
    return container_concat(a, make_tuple(x));
}
// reorder array with a compile-time map:
// new2old[i] gives the old position read for new position i, i.e.
// new[i] = old[new2old[i]]
template <typename TData, index_t NSize, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_new2old(const array<TData, NSize>& old_array, sequence<IRs...> /*new2old*/)
{
    static_assert(NSize == sizeof...(IRs), "wrong! size not consistent");
    static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
    return make_array<remove_cvref_t<TData>>(old_array[IRs]...);
}
// same, but the map is given old->new; it is inverted and forwarded to the
// new2old overload
template <typename TData, index_t NSize, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_old2new(const array<TData, NSize>& old_array, sequence<IRs...> old2new)
{
    return container_reorder_given_new2old(
        old_array, typename sequence_map_inverse<decltype(old2new)>::type{});
}
// reorder array with a runtime map
template <typename TData, index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_new2old(const array<TData, NSize>& old_array,
                                const map<index_t, index_t>& new2old)
{
    array<TData, NSize> new_array;
    // new_array[new_pos] = old_array[old_pos] for every mapping entry
    for(const auto& [new_pos, old_pos] : new2old)
    {
        new_array(new_pos) = old_array[old_pos];
    }
    return new_array;
}
template <typename TData, index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_old2new(const array<TData, NSize>& old_array,
                                const map<index_t, index_t>& old2new)
{
    array<TData, NSize> new_array;
    // same as above with the map direction flipped
    for(const auto& [old_pos, new_pos] : old2new)
    {
        new_array(new_pos) = old_array[old_pos];
    }
    return new_array;
}
// reorder tuple: new2old[i] gives the old element placed at new position i
template <typename... Ts, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(const tuple<Ts...>& old_tuple,
                                                                   sequence<IRs...> /*new2old*/)
{
    static_assert(sizeof...(Ts) == sizeof...(IRs), "wrong! size not consistent");
    static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
    return make_tuple(old_tuple[number<IRs>{}]...);
}
// old->new map is inverted, then handled by the new2old overload
template <typename... Ts, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_old2new(const tuple<Ts...>& old_tuple,
                                                                   sequence<IRs...> old2new)
{
    return container_reorder_given_new2old(
        old_tuple, typename sequence_map_inverse<decltype(old2new)>::type{});
}
// reorder sequence: new2old[i] gives the old element placed at new position i
template <index_t... Is, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(sequence<Is...> /* old_seq */,
                                                                   sequence<IRs...> /*new2old*/)
{
    static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
    static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
    return sequence<sequence<Is...>::at(number<IRs>{})...>{};
}
// old->new map is inverted, then handled by the new2old overload
template <index_t... Is, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_old2new(sequence<Is...> old_seq,
                                                                   sequence<IRs...> /* old2new */)
{
    static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
    static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
    constexpr auto new2old = typename sequence_map_inverse<sequence<IRs...>>::type{};
    return container_reorder_given_new2old(old_seq, new2old);
}
#if 0
// rocm-4.1 compiler would crash for recursive lambda
template <typename Container,
          typename Reduce,
          typename Init,
          index_t IBegin = 0,
          index_t IEnd = Container::size(),
          index_t IStep = 1>
CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container& x,
                                                    Reduce reduce,
                                                    Init init,
                                                    number<IBegin> = number<0>{},
                                                    number<IEnd> = number<Container::size()>{},
                                                    number<IStep> = number<1>{})
{
    static_assert((IEnd - IBegin) % IStep == 0, "wrong!");
    // f is recursive function, fs is a dummy of f
    // i is index, y_old is current scan, r_old is current reduction
    auto f = [&](auto fs, auto i, auto r_old) {
        auto r_new = reduce(x[i], r_old);
        if constexpr(i.value < IEnd - IStep)
        {
            // recursively call f/fs
            return fs(fs, i + number<IStep>{}, r_new);
        }
        else
        {
            return r_new;
        }
    };
    // start recursion
    return f(f, number<IBegin>{}, init);
}
#else
// recursion step for container_reduce: fold x[I] into the running value
// r_old, then recurse with I + IStep until I reaches IEnd - IStep
template <typename Container,
          typename Reduce,
          typename ROld,
          index_t I,
          index_t IEnd,
          index_t IStep>
CK_TILE_HOST_DEVICE constexpr auto container_reduce_impl(
    const Container& x, Reduce reduce, ROld r_old, number<I> i, number<IEnd>, number<IStep>)
{
    auto r_new = reduce(x[i], r_old);
    if constexpr(i.value < IEnd - IStep)
    {
        return container_reduce_impl(
            x, reduce, r_new, i + number<IStep>{}, number<IEnd>{}, number<IStep>{});
    }
    else
    {
        return r_new;
    }
}
// rocm-4.1 compiler would crash for recursive lambda, hence the free-function
// recursion above.
// container reduce with initial value: folds x[IBegin..IEnd) (stride IStep)
// into init with \p reduce; returns init unchanged for an empty range
template <typename Container,
          typename Reduce,
          typename Init,
          index_t IBegin = 0,
          index_t IEnd = Container::size(),
          index_t IStep = 1>
CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container& x,
                                                    Reduce reduce,
                                                    Init init,
                                                    number<IBegin> = number<0>{},
                                                    number<IEnd> = number<Container::size()>{},
                                                    number<IStep> = number<1>{})
{
    static_assert((IEnd - IBegin) % IStep == 0, "wrong!");
    if constexpr(IEnd > IBegin)
    {
        return container_reduce_impl(
            x, reduce, init, number<IBegin>{}, number<IEnd>{}, number<IStep>{});
    }
    else
    {
        return init;
    }
}
#endif
// reverse inclusive scan over an array:
// y[i] = f(...f(f(init, x[N-1]), x[N-2])..., x[i]), folding from the back
template <typename TData, index_t NSize, typename Reduce>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_inclusive_scan(const array<TData, NSize>& x, Reduce f, TData init)
{
    array<TData, NSize> y;
    TData r = init;
    static_for<NSize - 1, 0, -1>{}([&](auto i) {
        r    = f(r, x[i]);
        y(i) = r;
    });
    // index 0 handled outside the static_for (its range excludes 0)
    r             = f(r, x[number<0>{}]);
    y(number<0>{}) = r;
    return y;
}
// reverse exclusive scan over an array:
// y[N-1] = init, and y[i] = f(...f(init, x[N-1])..., x[i+1]) -- the fold of
// everything strictly behind position i
template <typename TData, index_t NSize, typename Reduce, typename Init>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_exclusive_scan(const array<TData, NSize>& x, Reduce f, Init init)
{
#if 0
    array<TData, NSize> y;
    TData r = init;
    static_for<NSize - 1, 0, -1>{}([&](auto i) {
        y(i) = r;
        r = f(r, x[i]);
    });
    y(number<0>{}) = r;
    return y;
#else
    // plain runtime loop variant of the same scan
    array<TData, NSize> y;
    TData r = init;
    for(index_t i = NSize - 1; i > 0; --i)
    {
        y(i) = r;
        r    = f(r, x[i]);
    }
    y(0) = r;
    return y;
#endif
}
// sequence overload: delegate to the compile-time sequence scan helper
template <index_t... Is, typename Reduce, index_t Init>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_exclusive_scan(const sequence<Is...>& seq, Reduce f, number<Init>)
{
    return reverse_exclusive_scan_sequence(seq, f, number<Init>{});
}
#if 0
// rocm4.1 compiler would crash with recursive lambda
template <typename... Xs, typename Reduce, typename Init>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_exclusive_scan(const tuple<Xs...>& x, Reduce reduce, Init init)
{
    constexpr index_t NSize = sizeof...(Xs);
    // f is recursive function, fs is a dummy of f
    // i is index, y_old is current scan, r_old is current reduction
    auto f = [&](auto fs, auto i, auto y_old, auto r_old) {
        auto r_new = reduce(x[i], r_old);
        auto y_new = container_push_front(y_old, r_new);
        if constexpr(i.value > 1)
        {
            // recursively call f/fs
            return fs(fs, i - number<1>{}, y_new, r_new);
        }
        else
        {
            return y_new;
        }
    };
    // start recursion
    return f(f, number<NSize - 1>{}, make_tuple(init), init);
}
#else
// recursion step for the tuple reverse exclusive scan:
// i is index, y_old is current scan result, r_old is current reduction;
// elements are folded back-to-front and pushed onto the front of y
template <typename... Xs, typename Reduce, index_t I, typename YOld, typename ROld>
CK_TILE_HOST_DEVICE constexpr auto container_reverse_exclusive_scan_impl(
    const tuple<Xs...>& x, Reduce reduce, number<I> i, YOld y_old, ROld r_old)
{
    auto r_new = reduce(x[i], r_old);
    auto y_new = container_push_front(y_old, r_new);
    if constexpr(i.value > 1)
    {
        // recursively call f/fs
        return container_reverse_exclusive_scan_impl(x, reduce, i - number<1>{}, y_new, r_new);
    }
    else
    {
        return y_new;
    }
}
// reverse exclusive scan over a tuple; each result element may have its own
// type, so the scan is built recursively via container_push_front
template <typename... Xs, typename Reduce, typename Init>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_exclusive_scan(const tuple<Xs...>& x, Reduce reduce, Init init)
{
    constexpr index_t NSize = sizeof...(Xs);
    return container_reverse_exclusive_scan_impl(
        x, reduce, number<NSize - 1>{}, make_tuple(init), init);
}
#endif
// TODO: update to be like container_reverse_exclusive_scan to deal with
// tuples of number<> (here all scan results must be assignable to the
// corresponding tuple element from TData)
template <typename... Xs, typename Reduce, typename TData>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_inclusive_scan(const tuple<Xs...>& x, Reduce f, TData init)
{
    constexpr index_t NSize = sizeof...(Xs);
    tuple<Xs...> y;
    TData r = init;
    static_for<NSize - 1, 0, -1>{}([&](auto i) {
        r    = f(r, x[i]);
        y(i) = r;
    });
    // index 0 handled outside the static_for (its range excludes 0)
    r             = f(r, x[number<0>{}]);
    y(number<0>{}) = r;
    return y;
}
// variadic concatenation: peel off the first container and recurse on the rest
template <typename X, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto container_concat(const X& x, const Ys&... ys)
{
    return container_concat(x, container_concat(ys...));
}
// concatenate two arrays of the same element type
template <typename T, index_t NX, index_t NY>
CK_TILE_HOST_DEVICE constexpr auto container_concat(const array<T, NX>& ax, const array<T, NY>& ay)
{
    return unpack2(
        [&](auto&&... zs) { return make_array<T>(std::forward<decltype(zs)>(zs)...); }, ax, ay);
}
// concatenate two tuples
template <typename... X, typename... Y>
CK_TILE_HOST_DEVICE constexpr auto container_concat(const tuple<X...>& tx, const tuple<Y...>& ty)
{
    return unpack2(
        [&](auto&&... zs) { return make_tuple(std::forward<decltype(zs)>(zs)...); }, tx, ty);
}
// single-container base case: identity
template <typename Container>
CK_TILE_HOST_DEVICE constexpr auto container_concat(const Container& x)
{
    return x;
}
// gather arr[Is]... into a new (possibly smaller) array
template <typename T, index_t N, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto get_container_subset(const array<T, N>& arr, sequence<Is...>)
{
    static_assert(N >= sizeof...(Is), "wrong! size");
    if constexpr(sizeof...(Is) > 0)
    {
        return make_array<T>(arr[Is]...);
    }
    else
    {
        // empty pick list yields an empty array
        return array<T, 0>{};
    }
}
// gather tup[Is]... into a new tuple
template <typename... Ts, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto get_container_subset(const tuple<Ts...>& tup, sequence<Is...>)
{
    static_assert(sizeof...(Ts) >= sizeof...(Is), "wrong! size");
    if constexpr(sizeof...(Is) > 0)
    {
        return make_tuple(tup[number<Is>{}]...);
    }
    else
    {
        // empty pick list yields an empty tuple
        return tuple<>{};
    }
}
// scatter x into y at the positions listed in picks: y[picks[i]] = x[i]
template <typename T, index_t N, index_t... Is>
CK_TILE_HOST_DEVICE constexpr void
set_container_subset(array<T, N>& y, sequence<Is...> picks, const array<T, sizeof...(Is)>& x)
{
    static_assert(N >= sizeof...(Is), "wrong! size");
    if constexpr(sizeof...(Is) > 0)
    {
        for(index_t i = 0; i < picks.size(); ++i)
        {
            y(picks[i]) = x[i];
        }
    }
}
// generic overload (e.g. tuples): uses static_for so each element can have
// its own type
template <typename Y, typename X, index_t... Is>
CK_TILE_HOST_DEVICE constexpr void set_container_subset(Y& y, sequence<Is...> picks, const X& x)
{
    static_assert(Y::size() >= sizeof...(Is) && X::size() == sizeof...(Is), "wrong! size");
    if constexpr(sizeof...(Is) > 0)
    {
        static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
    }
}
// return the index of first occurance in the sequence.
// return seq.size(), if not found
template <index_t... Is>
constexpr index_t container_find(sequence<Is...> seq, index_t value)
{
for(auto i = 0; i < seq.size(); i++)
{
if(seq[i] == value)
return i;
}
return seq.size();
}
// convert sequence<Is...> into tuple<number<Is>...>
template <index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence<Is...>)
{
    using Seq = sequence<Is...>;
    return generate_tuple(
        [&](auto i) {
            constexpr index_t tmp = Seq::at(i);
            return number<tmp>{};
        },
        number<Seq::size()>{});
}
// Convert a compile-time "array of arrays" (a_of_b_impl, with a_size rows and
// per-row sizes bs_sizes) into a tuple of sequence<>, one sequence per row,
// via the TO_SEQUENCE macro. Expands to an immediately-invoked lambda so it
// can be used in expression context.
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes)  \
    [a_of_b_impl, a_size, bs_sizes] {                        \
        return ck_tile::generate_tuple(                      \
            [=](auto i) {                                    \
                constexpr auto b_impl    = a_of_b_impl[i];   \
                constexpr index_t b_size = bs_sizes[i];      \
                constexpr auto b         = TO_SEQUENCE(b_impl, b_size); \
                return b;                                    \
            },                                               \
            ck_tile::number<a_size>{});                      \
    }()
} // namespace ck_tile

View File

@@ -0,0 +1,164 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
namespace ck_tile {
// naive map: fixed-capacity key->value store backed by an array of pairs.
// Lookup and insertion are O(size) linear scans; intended for small maps.
template <typename key, typename data, index_t max_size = 128>
struct map
{
    using pair_type = tuple<key, data>;
    using impl_type = array<pair_type, max_size>;
    impl_type impl_; // first size_ slots hold the live entries
    index_t size_;   // number of valid entries in impl_
    // forward iterator over the live entries (mutable access)
    struct iterator
    {
        impl_type& impl_;
        index_t pos_;
        CK_TILE_HOST_DEVICE constexpr iterator(impl_type& impl, index_t pos)
            : impl_{impl}, pos_{pos}
        {
        }
        CK_TILE_HOST_DEVICE constexpr iterator& operator++()
        {
            pos_++;
            return *this;
        }
        CK_TILE_HOST_DEVICE constexpr bool operator!=(const iterator& other) const
        {
            return other.pos_ != pos_;
        }
        CK_TILE_HOST_DEVICE constexpr pair_type& operator*() { return impl_.at(pos_); }
    };
    // forward iterator over the live entries (read-only access)
    struct const_iterator
    {
        const impl_type& impl_;
        index_t pos_;
        CK_TILE_HOST_DEVICE constexpr const_iterator(const impl_type& impl, index_t pos)
            : impl_{impl}, pos_{pos}
        {
        }
        CK_TILE_HOST_DEVICE constexpr const_iterator& operator++()
        {
            pos_++;
            return *this;
        }
        CK_TILE_HOST_DEVICE constexpr bool operator!=(const const_iterator& other) const
        {
            return other.pos_ != pos_;
        }
        CK_TILE_HOST_DEVICE constexpr const pair_type& operator*() const { return impl_.at(pos_); }
    };
    CK_TILE_HOST_DEVICE constexpr map() : impl_{}, size_{0} {}
    CK_TILE_HOST_DEVICE constexpr index_t size() const { return size_; }
    // drop all entries; the underlying storage is left untouched
    CK_TILE_HOST_DEVICE void clear() { size_ = 0; }
    // linear scan; returns the slot of \p key, or size_ when not present
    CK_TILE_HOST_DEVICE constexpr index_t find_position(const key& key) const
    {
        for(index_t i = 0; i < size(); i++)
        {
            if(impl_[i].template at<0>() == key)
            {
                return i;
            }
        }
        return size_;
    }
    CK_TILE_HOST_DEVICE constexpr const_iterator find(const key& key) const
    {
        return const_iterator{impl_, find_position(key)};
    }
    CK_TILE_HOST_DEVICE constexpr iterator find(const key& key)
    {
        return iterator{impl_, find_position(key)};
    }
    // read-only lookup; the key must already exist
    CK_TILE_HOST_DEVICE constexpr const data& operator[](const key& key) const
    {
        const auto it = find(key);
        // FIXME: missing key is only caught by this assert
        assert(it.pos_ < size());
        return impl_[it.pos_].template at<1>();
    }
    // lookup-or-insert, std::map::operator[] style: a missing key is appended
    CK_TILE_HOST_DEVICE constexpr data& operator()(const key& key)
    {
        auto it = find(key);
        // if entry not found
        if(it.pos_ == size())
        {
            impl_(it.pos_).template at<0>() = key;
            size_++;
        }
        // FIXME: capacity overflow is only caught by this assert
        assert(size_ <= max_size);
        return impl_(it.pos_).template at<1>();
    }
    // WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
    CK_TILE_HOST_DEVICE constexpr const_iterator begin() const { return const_iterator{impl_, 0}; }
    // WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
    CK_TILE_HOST_DEVICE constexpr const_iterator end() const
    {
        return const_iterator{impl_, size_};
    }
    // WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
    CK_TILE_HOST_DEVICE constexpr iterator begin() { return iterator{impl_, 0}; }
    // WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
    CK_TILE_HOST_DEVICE constexpr iterator end() { return iterator{impl_, size_}; }
    CK_TILE_HOST_DEVICE void print() const
    {
        // NOTE(review): the structured-binding names below shadow the
        // template parameters `key`/`data`, and the unqualified calls
        // `print(key)` / `print(data)` cannot resolve to this zero-argument
        // member -- presumably a free ck_tile::print overload is intended.
        // Verify this member actually compiles when instantiated.
        printf("map{size_: %d, ", size_);
        //
        printf("impl_: [");
        //
        for(const auto& [key, data] : *this)
        {
            printf("{key: ");
            print(key);
            printf(", data: ");
            print(data);
            printf("}, ");
        }
        //
        printf("]");
        //
        printf("}");
    }
};
} // namespace ck_tile

View File

@@ -0,0 +1,99 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include <cstddef>
namespace ck_tile {
// TODO: this structure is not intented to be used by user
template <index_t MaxSize>
struct meta_data_buffer
{
    // Fixed-capacity byte buffer that serializes a sequence of bit-castable
    // objects (push) and reads them back in order (pop/get). Empty types
    // occupy no space.
    CK_TILE_HOST_DEVICE constexpr meta_data_buffer() : buffer_{}, size_{0} {}

    // construct and immediately serialize the given objects in order
    template <typename X, typename... Xs>
    CK_TILE_HOST_DEVICE constexpr meta_data_buffer(const X& x, const Xs&... xs)
        : buffer_{}, size_{0}
    {
        push(x, xs...);
    }

    // append the raw bytes of one object at the current write position
    // NOTE(review): no bounds check against MaxSize - caller must ensure capacity
    template <typename T>
    CK_TILE_HOST_DEVICE constexpr void push(const T& data)
    {
        if constexpr(!std::is_empty_v<T>)
        {
            constexpr index_t size = sizeof(T);
            auto tmp               = bit_cast<array<std::byte, size>>(data);
            for(int i = 0; i < size; i++)
            {
                buffer_(size_) = tmp[i];
                size_++;
            }
        }
    }

    // append several objects, in argument order
    template <typename X, typename... Xs>
    CK_TILE_HOST_DEVICE constexpr void push(const X& x, const Xs&... xs)
    {
        push(x);
        push(xs...);
    }

    // read one object starting at `pos`, advancing `pos` past its bytes
    template <typename T>
    CK_TILE_HOST_DEVICE constexpr T pop(index_t& pos) const
    {
        T data; // default value is returned for empty types
        if constexpr(!std::is_empty_v<T>)
        {
            constexpr index_t size = sizeof(T);
            array<std::byte, size> tmp;
            for(int i = 0; i < size; i++)
            {
                tmp(i) = buffer_[pos];
                pos++;
            }
            data = bit_cast<T>(tmp);
        }
        return data;
    }

    // read one object starting at `pos` without tracking a cursor
    // (pos is taken by value, so the caller's position is unchanged)
    template <typename T>
    CK_TILE_HOST_DEVICE constexpr T get(index_t pos) const
    {
        constexpr index_t size = sizeof(T);
        array<std::byte, size> tmp;
        for(int i = 0; i < size; i++)
        {
            tmp(i) = buffer_[pos];
            pos++;
        }
        auto data = bit_cast<T>(tmp);
        return data;
    }

    //
    array<std::byte, MaxSize> buffer_; // raw storage
    index_t size_ = 0;                 // bytes written so far
};
} // namespace ck_tile

View File

@@ -0,0 +1,99 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/utility/functional.hpp"
namespace ck_tile {
// deprecated, always use array instead
template <index_t N>
using multi_index = array<index_t, N>;
// build a multi_index (array<index_t, N>) from the given values,
// converting each argument to index_t
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_multi_index(Xs&&... xs)
{
    return make_array<index_t>(index_t{xs}...);
}
// build a multi_index of NSize elements, all zero
template <index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto make_zero_multi_index()
{
    return unpack([](auto... xs) { return make_multi_index(xs...); },
                  typename uniform_sequence_gen<NSize, 0>::type{});
}
// convert any unpackable container (tuple/sequence/...) to a multi_index
template <typename T>
CK_TILE_HOST_DEVICE constexpr auto to_multi_index(const T& x)
{
    return unpack([](auto... ys) { return make_multi_index(ys...); }, x);
}
// element-wise in-place addition; X must expose size() and operator[]
template <index_t NSize, typename X>
CK_TILE_HOST_DEVICE constexpr auto operator+=(multi_index<NSize>& y, const X& x)
{
    static_assert(X::size() == NSize, "wrong! size not the same");
    static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; });
    return y;
}
// element-wise in-place subtraction; X must expose size() and operator[]
template <index_t NSize, typename X>
CK_TILE_HOST_DEVICE constexpr auto operator-=(multi_index<NSize>& y, const X& x)
{
    static_assert(X::size() == NSize, "wrong! size not the same");
    static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; });
    return y;
}
// element-wise addition returning a new multi_index
template <index_t NSize, typename T>
CK_TILE_HOST_DEVICE constexpr auto operator+(const multi_index<NSize>& a, const T& b)
{
    static_assert(T::size() == NSize, "wrong! size not the same");
    multi_index<NSize> result;
    static_for<0, NSize, 1>{}([&](auto i) { result[i] = a[i] + b[i]; });
    return result;
}
// element-wise subtraction returning a new multi_index
template <index_t NSize, typename T>
CK_TILE_HOST_DEVICE constexpr auto operator-(const multi_index<NSize>& a, const T& b)
{
    static_assert(T::size() == NSize, "wrong! size not the same");
    multi_index<NSize> result;
    static_for<0, NSize, 1>{}([&](auto i) { result[i] = a[i] - b[i]; });
    return result;
}
// element-wise (Hadamard) product returning a new multi_index
template <index_t NSize, typename T>
CK_TILE_HOST_DEVICE constexpr auto operator*(const multi_index<NSize>& a, const T& b)
{
    static_assert(T::size() == NSize, "wrong! size not the same");
    multi_index<NSize> result;
    static_for<0, NSize, 1>{}([&](auto i) { result[i] = a[i] * b[i]; });
    return result;
}
// multi_index = index_t * multi_index
// scale every element by scalar a (scalar on the left)
template <index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto operator*(index_t a, const multi_index<NSize>& x)
{
    multi_index<NSize> r;
    static_for<0, NSize, 1>{}([&](auto i) { r[i] = a * x[i]; });
    return r;
}
// multi_index = multi_index * index_t
// scale every element by scalar a (scalar on the right); forwards to the left form
template <index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto operator*(const multi_index<NSize>& x, index_t a)
{
    return a * x;
}
} // namespace ck_tile

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,78 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include <cstddef>
#include <array>
#include <type_traits>
namespace ck_tile {
// implement the c++20 std::span, lightweight, non-owning reference to a sequence
// weather it is dynamic or static range. Or can be seen as a view of a contiguous sequence
// TODO: do we need in device consider this is pointer?
template <typename T>
class span
{
    public:
    using element_type    = T;
    using value_type      = std::remove_cv_t<element_type>;
    using size_type       = std::size_t;
    using difference_type = std::ptrdiff_t;
    using pointer         = element_type*;
    using const_pointer   = const element_type*;
    using reference       = element_type&;
    using const_reference = const element_type&;
    using iterator        = pointer;
    using const_iterator  = pointer;

    // empty view
    CK_TILE_HOST_DEVICE constexpr span() : span(nullptr, size_type{0}) {}
    // view over [first, first + count)
    CK_TILE_HOST_DEVICE constexpr span(pointer first, size_type count) : ptr_(first), size_(count)
    {
    }
    // view over [first, last)
    CK_TILE_HOST_DEVICE constexpr span(pointer first, pointer last) : span(first, last - first) {}
    // view over a C array
    template <std::size_t N>
    CK_TILE_HOST_DEVICE constexpr span(element_type (&arr)[N]) noexcept : span(arr, N)
    {
    }
    // view over a std::array
    template <std::size_t N>
    CK_TILE_HOST_DEVICE constexpr span(std::array<value_type, N>& arr) noexcept
        : span(arr.data(), N)
    {
    }
    // view over any container exposing data()/size()
    template <typename Container>
    CK_TILE_HOST_DEVICE constexpr span(const Container& container)
        : span(container.data(), container.size())
    {
    }

    CK_TILE_HOST_DEVICE constexpr iterator begin() const noexcept { return ptr_; }
    CK_TILE_HOST_DEVICE constexpr const_iterator cbegin() const noexcept { return begin(); }
    CK_TILE_HOST_DEVICE constexpr iterator end() const noexcept { return begin() + size(); }
    CK_TILE_HOST_DEVICE constexpr const_iterator cend() const noexcept { return end(); }

    CK_TILE_HOST_DEVICE constexpr reference front() const { return *begin(); }
    // end() returns the iterator (a raw pointer) by value - a prvalue cannot
    // be decremented in place, so index back from the end instead
    CK_TILE_HOST_DEVICE constexpr reference back() const { return *(end() - 1); }
    CK_TILE_HOST_DEVICE constexpr reference operator[](size_type idx) const
    {
        return *(begin() + idx);
    }
    CK_TILE_HOST_DEVICE constexpr pointer data() const noexcept { return ptr_; }
    CK_TILE_HOST_DEVICE constexpr size_type size() const noexcept { return size_; }

    private:
    pointer ptr_;    // start of the viewed sequence (non-owning)
    size_type size_; // number of elements viewed
};
} // namespace ck_tile

View File

@@ -0,0 +1,70 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/numeric/integer.hpp"
namespace ck_tile {
#if CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT == CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE
namespace detail {
// concatenate the element packs of two tuples into one tuple type
template <typename X, typename Y>
struct tuple_concat;
template <typename... Xs, typename... Ys>
struct tuple_concat<tuple<Xs...>, tuple<Ys...>>
{
    using type = tuple<Xs..., Ys...>;
};
// build tuple<T, T, ..., T> (N copies) by recursive halving, which keeps
// template instantiation depth at O(log N) instead of O(N)
template <typename T, index_t N>
struct statically_indexed_array_impl
{
    using type =
        typename tuple_concat<typename statically_indexed_array_impl<T, N / 2>::type,
                              typename statically_indexed_array_impl<T, N - N / 2>::type>::type;
};
// base case: zero elements
template <typename T>
struct statically_indexed_array_impl<T, 0>
{
    using type = tuple<>;
};
// base case: a single element
template <typename T>
struct statically_indexed_array_impl<T, 1>
{
    using type = tuple<T>;
};
} // namespace detail
template <typename T, index_t N>
using statically_indexed_array = typename detail::statically_indexed_array_impl<T, N>::type;
#else
// consider mark this struct as deprecated
template <typename T, index_t N>
using statically_indexed_array = array<T, N>;
#endif
// consider always use ck_tile::array for this purpose
// build a statically_indexed_array from the given values; the element type is
// taken from the first argument and the rest are cast to it
template <typename X, typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs)
{
    return statically_indexed_array<X, sizeof...(Xs) + 1>(x, static_cast<X>(xs)...);
}
// make empty statically_indexed_array
// make an empty statically_indexed_array of element type X
template <typename X>
CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array()
{
    return statically_indexed_array<X, 0>();
}
} // namespace ck_tile

View File

@@ -0,0 +1,483 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include <utility>
namespace ck_tile {
namespace impl {

// Storage for one tuple element, tagged by its index so repeated types get
// distinct bases. Empty element types are stored as empty bases (EBO) so
// they contribute no size.
template <index_t idx, typename T, bool is_empty = std::is_empty_v<T>>
struct tuple_element
{
};
// empty element: nothing stored
template <index_t idx, typename T>
struct tuple_element<idx, T, true>
{
    CK_TILE_HOST_DEVICE constexpr tuple_element() {}
    CK_TILE_HOST_DEVICE constexpr tuple_element(const T&) {}
};
// non-empty element: holds the value
template <index_t idx, typename T>
struct tuple_element<idx, T, false>
{
    CK_TILE_HOST_DEVICE constexpr tuple_element() {}
    CK_TILE_HOST_DEVICE constexpr tuple_element(const T& e) : element(e) {}
    T element;
};

// accessors: the <I, T> pair is deduced from the base-class argument, which
// selects the I-th element of the tuple
template <std::size_t I, class T>
CK_TILE_HOST_DEVICE constexpr T const& getv(tuple_element<I, T, false> const& x)
{
    return x.element;
}
template <std::size_t I, class T>
CK_TILE_HOST_DEVICE constexpr T& getv(tuple_element<I, T, false>& x)
{
    return x.element;
}
template <std::size_t I, class T>
CK_TILE_HOST_DEVICE constexpr T&& getv(tuple_element<I, T, false>&& x)
{
    return static_cast<T&&>(x.element);
}

// tuple storage: inherits one tuple_element per (index, type) pair
template <typename index_seq, typename... T>
struct tuple_base;
template <index_t... I, typename... T>
struct tuple_base<sequence<I...>, T...> : public tuple_element<I, T>...
{
    CK_TILE_HOST_DEVICE constexpr tuple_base() {}
    // element-wise construction
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U const&... u) : tuple_element<I, T>(u)...
    {
    }
    // converting copy from a tuple_base with other element types
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple_base(tuple_base<sequence<I...>, U...> const& u)
        : tuple_element<I, T>(getv(static_cast<tuple_element<I, U> const&>(u)))...
    {
    }
};
} // namespace impl
// fixed-size heterogeneous container; elements are accessed by compile-time
// index via get/at/operator[] with a number<I> tag
template <class... T>
struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
{
    // number of elements
    CK_TILE_HOST_DEVICE
    static constexpr auto size() { return sizeof...(T); }
    using base = impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>;
    CK_TILE_HOST_DEVICE constexpr tuple() {}
    // element-wise construction
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple(U const&... u)
        : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>(u...)
    {
    }
    // converting copy from a tuple with other element types; cast through the
    // source tuple's storage base (tuple_base takes the index sequence as its
    // first template argument, so tuple_base<U...> would be the wrong type)
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple(tuple<U...> const& u)
        : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>(
              static_cast<typename tuple<U...>::base const&>(u))
    {
    }
    // true when every element type is a compile-time-known (static) type;
    // the element pack of this tuple is T...
    CK_TILE_HOST_DEVICE static constexpr bool is_static()
    {
        bool flag = true;
        static_for<0, sizeof...(T), 1>{}([&flag](auto i) {
            flag &= is_static_v<remove_cvref_t<__type_pack_element<i.value, T...>>>;
        });
        return flag;
    }
#define TP_COM_() static_assert(I < size(), "wrong! out of range")
    // clang-format off
    template<index_t I> CK_TILE_HOST_DEVICE constexpr const auto & get() const { TP_COM_(); return impl::getv<I>(*this); }
    template<index_t I> CK_TILE_HOST_DEVICE constexpr const auto & get(number<I>) const { TP_COM_(); return get<I>(); }
    template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & get() { TP_COM_(); return impl::getv<I>(*this); }
    template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & get(number<I>) { TP_COM_(); return get<I>(); }
    template<index_t I> CK_TILE_HOST_DEVICE constexpr const auto & at() const { TP_COM_(); return impl::getv<I>(*this); }
    template<index_t I> CK_TILE_HOST_DEVICE constexpr const auto & at(number<I>) const { TP_COM_(); return get<I>(); }
    template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & at() { TP_COM_(); return impl::getv<I>(*this); }
    template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & at(number<I>) { TP_COM_(); return get<I>(); }
    template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & operator[](number<I>) { TP_COM_(); return get<I>(); }
    template<index_t I> CK_TILE_HOST_DEVICE constexpr const auto & operator[](number<I>) const { TP_COM_(); return get<I>(); }
    template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & operator()(number<I>) { TP_COM_(); return get<I>(); } // TODO: compatible
    // clang-format on
#undef TP_COM_
};
// template <class... T>
// CK_TILE_HOST_DEVICE constexpr
// tuple<T...>
// make_tuple(T const&... t)
// {
// return {t...};
// }
// build a tuple from the given values, decaying references/cv-qualifiers
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs&&... xs)
{
    return tuple<remove_cvref_t<Xs>...>(std::forward<Xs>(xs)...);
}
// https://en.cppreference.com/w/cpp/utility/tuple/tie
// build a tuple of lvalue references, like std::tie
// https://en.cppreference.com/w/cpp/utility/tuple/tie
template <typename... Args>
constexpr tuple<Args&...> tie(Args&... args) noexcept
{
    return {args...};
}
// type-level concatenation of two tuple types
template <typename X, typename Y>
struct tuple_concat;
template <typename... Xs, typename... Ys>
struct tuple_concat<tuple<Xs...>, tuple<Ys...>>
{
    using type = tuple<Xs..., Ys...>;
};
// build a tuple of N elements where element i is f(number<i>{})
template <typename F, index_t N>
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F&& f, number<N>)
{
    return unpack([&f](auto&&... is) { return make_tuple(f(is)...); },
                  typename arithmetic_sequence_gen<0, N, 1>::type{});
}
// like generate_tuple, but f must return lvalue references; the result is a
// tuple of references (no copies)
template <typename F, index_t N>
CK_TILE_HOST_DEVICE constexpr auto generate_tie(F&& f, number<N>)
{
    return unpack([&f](auto&&... is) { return tie(f(is)...); },
                  typename arithmetic_sequence_gen<0, N, 1>::type{});
}
// tx and ty are tuple of references, return type of will tuple of referennce (not rvalue)
// tx and ty are tuples of references; the result is again a tuple of
// references (not rvalues) into the same objects
template <typename... X, typename... Y>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple_of_reference(const tuple<X&...>& tx,
                                                             const tuple<Y&...>& ty)
{
    return unpack2(
        [&](auto&&... zs) { return tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
        tx,
        ty);
}
// concatenate two tuples by value
template <typename... X, typename... Y>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx, const tuple<Y...>& ty)
{
    return unpack2(
        [&](auto... zs) { return tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
        tx,
        ty);
}
// Support any number of tuples to concat (also 1)
// single-tuple base case for the variadic concat below
template <typename... X>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx)
{
    return tx;
}
// concatenate any number of tuples, right-folding into pairwise concat
template <typename... X, typename... Tuples>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx, const Tuples&... tuples)
{
    return concat_tuple(tx, concat_tuple(tuples...));
}
namespace detail {
// element-wise map over 1/2/3 tuples at the index positions Is...
template <typename F, typename X, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples_impl(F f, const X& x, sequence<Is...>)
{
    return make_tuple(f(x.at(number<Is>{}))...);
}
template <typename F, typename X, typename Y, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto
transform_tuples_impl(F f, const X& x, const Y& y, sequence<Is...>)
{
    return make_tuple(f(x.at(number<Is>{}), y.at(number<Is>{}))...);
}
template <typename F, typename X, typename Y, typename Z, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto
transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, sequence<Is...>)
{
    return make_tuple(f(x.at(number<Is>{}), y.at(number<Is>{}), z.at(number<Is>{}))...);
}
} // namespace detail
// element-wise map: returns make_tuple(f(x[0]), f(x[1]), ...).
// X::size() is already an integral constant expression (tuple::size() returns
// sizeof...(T)); the extra call operator in X::size()() was ill-formed.
template <typename F, typename X>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x)
{
    return detail::transform_tuples_impl(
        f, x, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
// element-wise map over two tuples of equal size; see single-tuple overload
// for why X::size() (not X::size()()) is used
template <typename F, typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y)
{
    return detail::transform_tuples_impl(
        f, x, y, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
// element-wise map over three tuples of equal size; see single-tuple overload
// for why X::size() (not X::size()()) is used
template <typename F, typename X, typename Y, typename Z>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z)
{
    return detail::transform_tuples_impl(
        f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
// By default unroll to the flatten
// base case: an empty tuple stays empty
template <index_t Depth = 0, index_t MaxDepth = -1>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<>& element)
{
    return element;
}
// base case: a non-tuple element is wrapped into a one-element tuple
template <index_t Depth = 0, index_t MaxDepth = -1, typename T>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const T& element)
{
    return make_tuple(element);
}
// recursively flatten a nested tuple, stopping once MaxDepth is reached
// (MaxDepth == -1, the default, means flatten completely)
template <index_t Depth = 0, index_t MaxDepth = -1, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<Ts...>& tuple)
{
    if constexpr(Depth == MaxDepth)
    {
        return tuple;
    }
    else
    {
        return unpack(
            [&](auto&&... ts) {
                return concat_tuple(unroll_nested_tuple<Depth + 1, MaxDepth>(ts)...);
            },
            tuple);
    }
}
// return a tuple with the elements in reverse order.
// The parameter must not be named `tuple`: inside the body that name would
// shadow the class template and `tuple<Ts...>` could no longer be parsed.
// tuple<Ts...>::size() is already an integral constant (size()() was ill-formed).
template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto tuple_reverse(const tuple<Ts...>& t)
{
    return generate_tuple(
        [&](auto i) {
            using Idx = number<tuple<Ts...>::size() - i - 1>;
            return t.at(Idx{});
        },
        number<tuple<Ts...>::size()>{});
}
// Reduce tuple values in specific range using Function
// right-fold of f over the elements in [Idx, End):
// f(t[Idx], f(t[Idx+1], ... t[End-1]))
template <index_t Idx, index_t End, typename F, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto tuple_reduce(F&& f, const tuple<Ts...>& tuple)
{
    static_assert(Idx < End, "Wrong parameters for tuple_reduce");
    if constexpr(Idx + 1 == End)
    {
        return tuple.at(number<Idx>{});
    }
    else
    {
        return f(tuple.at(number<Idx>{}), tuple_reduce<Idx + 1, End>(f, tuple));
    }
}
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
// true when at least one element type is itself tuple-like
// (detected via a member IsTuple(), see is_tuple above)
template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto is_nested_tuple(const tuple<Ts...>&)
{
    return (is_detected<is_tuple, Ts>::value || ...);
}
// base case: a non-tuple value terminates the recursion at the current depth
template <index_t depth = 0, typename T>
CK_TILE_HOST_DEVICE constexpr auto tuple_depth(const T&)
{
    return depth;
}
// maximum nesting depth over all elements (requires Ts to be default-constructible)
// NOTE(review): relies on a math::max overload being visible here - verify,
// since ck_tile utilities are otherwise referenced unqualified in this file
template <index_t depth = 0, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto tuple_depth(const tuple<Ts...>&)
{
    return math::max(tuple_depth<depth + 1>(Ts{})...);
}
// convert a tuple of sequences into a rectangular array-of-array, padding the
// ragged tail of shorter sequences with the initializer value (-1)
// NOTE(review): assumes each sequence's size() returns a callable integral
// constant (hence the trailing ()) - confirm against sequence's definition
template <typename... Seqs>
CK_TILE_HOST_DEVICE constexpr auto to_array_of_array(tuple<Seqs...> t_of_s)
{
    constexpr index_t n0 = sizeof...(Seqs);
    // widest sequence determines the inner array length
    constexpr index_t max_n1 = [&] {
        index_t max_n1_ = 0;
        static_for<0, n0, 1>{}([&](auto i0) {
            constexpr index_t n1 = t_of_s[i0].size()();
            max_n1_ = max_n1_ < n1 ? n1 : max_n1_;
        });
        return max_n1_;
    }();
    array<array<index_t, max_n1>, n0> a_of_a{{-1}};
    static_for<0, n0, 1>{}([&](auto i0) {
        constexpr index_t n1 = t_of_s[i0].size()();
        static_for<0, n1, 1>{}([&](auto i1) { a_of_a(i0)(i1) = t_of_s[i0][i1]; });
    });
    return a_of_a;
}
// Here should use MultiIndex<NSize>, instead of tuple<Ys...>, although the former
// is the alias of the latter. This is because compiler cannot infer the NSize if
// using MultiIndex<NSize>
// TODO: how to fix this?
// element-wise in-place addition; X is any tuple-like with operator[].
// ck_tile containers expose lower-case size() (see tuple::size() and the
// multi_index operators), so X::size() is used here, not X::Size().
template <typename... Ys,
          typename X,
          std::enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> =
              false>
CK_TILE_HOST_DEVICE constexpr auto operator+=(tuple<Ys...>& y, const X& x)
{
    static_assert(X::size() == sizeof...(Ys), "wrong! size not the same");
    constexpr index_t NSize = sizeof...(Ys);
    static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; });
    return y;
}
// element-wise in-place subtraction; uses lower-case X::size() for
// consistency with the other container operators in this file
template <typename... Ys,
          typename X,
          std::enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> =
              false>
CK_TILE_HOST_DEVICE constexpr auto operator-=(tuple<Ys...>& y, const X& x)
{
    static_assert(X::size() == sizeof...(Ys), "wrong! size not the same");
    constexpr index_t NSize = sizeof...(Ys);
    static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; });
    return y;
}
// element-wise addition returning a new tuple; uses lower-case Y::size() for
// consistency with the other container operators in this file
template <typename... Xs,
          typename Y,
          std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
              false>
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
{
    static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
    constexpr index_t NSize = sizeof...(Xs);
    tuple<Xs...> r;
    static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] + y[i]; });
    return r;
}
// element-wise subtraction returning a new tuple; uses lower-case Y::size()
// for consistency with the other container operators in this file
template <typename... Xs,
          typename Y,
          std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
              false>
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
{
    static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
    constexpr index_t NSize = sizeof...(Xs);
    tuple<Xs...> r;
    static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] - y[i]; });
    return r;
}
// element-wise product returning a new tuple; uses lower-case Y::size() for
// consistency with the other container operators in this file
template <typename... Xs,
          typename Y,
          std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
              false>
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const Y& y)
{
    static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
    constexpr index_t NSize = sizeof...(Xs);
    tuple<Xs...> r;
    static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] * y[i]; });
    return r;
}
// MultiIndex = scalar * MultiIndex
// scale every element by scalar a (scalar on the left)
template <
    typename... Xs,
    typename Y,
    std::enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
CK_TILE_HOST_DEVICE constexpr auto operator*(Y a, const tuple<Xs...>& x)
{
    constexpr index_t NSize = sizeof...(Xs);
    tuple<Xs...> r;
    static_for<0, NSize, 1>{}([&](auto i) { r[i] = a * x[i]; });
    return r;
}
// MultiIndex = MultiIndex * scalar
// scale every element by scalar a (scalar on the right); forwards to the left form
template <
    typename... Xs,
    typename Y,
    std::enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, Y a)
{
    return a * x;
}
// element-wise division of two equally-sized tuples
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator/(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
    constexpr index_t NSize = sizeof...(Xs);
    return generate_tuple([&](auto i) { return x[i] / y[i]; }, number<NSize>{});
}
} // namespace ck_tile
// WARNING: needed by compiler for C++ structured binding support only, don't use this
namespace std {
// Make ck_tile::tuple satisfy the std tuple protocol so structured bindings
// (auto [a, b] = t) work. WARNING: for the compiler only, don't use directly.
template <typename... Ts>
struct tuple_size<ck_tile::tuple<Ts...>> : std::integral_constant<std::size_t, sizeof...(Ts)>
{
};
// NOTE(review): this inherits from a namespace-scope ck_tile::tuple_element
// trait (distinct from ck_tile::impl::tuple_element) - presumably declared
// elsewhere in ck_tile; verify it provides the required `type` member.
template <std::size_t I, typename... Ts>
struct tuple_element<I, ck_tile::tuple<Ts...>> : ck_tile::tuple_element<I, ck_tile::tuple<Ts...>>
{
};
// const-qualified variants of the two specializations above
template <typename... Ts>
struct tuple_size<const ck_tile::tuple<Ts...>> : std::integral_constant<std::size_t, sizeof...(Ts)>
{
};
template <std::size_t I, typename... Ts>
struct tuple_element<I, const ck_tile::tuple<Ts...>>
    : ck_tile::tuple_element<I, const ck_tile::tuple<Ts...>>
{
};
} // namespace std

View File

@@ -0,0 +1,116 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <stdint.h>
#pragma once
// Generate the full set of comparison/arithmetic operators for a storage-only
// floating-point type (e.g. bfloat16_t, fp8): every operation converts both
// operands to float, computes in float, and converts back. The only exception
// is unary minus, which just flips the sign bit of the raw storage.
// `type_` must be explicitly convertible to/from float and expose a raw
// integer member `data` of type `type_::raw_type`.
// (Comments cannot appear inside the macro body: a // comment would swallow
// the line-continuation backslash.)
#define CK_TILE_ARITHMETIC_USING_FLOAT(type_) \
    CK_TILE_HOST_DEVICE \
    bool operator==(const type_& x, const type_& y) \
    { \
        return static_cast<float>(x) == static_cast<float>(y); \
    } \
    CK_TILE_HOST_DEVICE \
    bool operator!=(const type_& x, const type_& y) \
    { \
        return static_cast<float>(x) != static_cast<float>(y); \
    } \
    CK_TILE_HOST_DEVICE \
    bool operator<(const type_& x, const type_& y) \
    { \
        return static_cast<float>(x) < static_cast<float>(y); \
    } \
    CK_TILE_HOST_DEVICE \
    bool operator<=(const type_& x, const type_& y) \
    { \
        return static_cast<float>(x) <= static_cast<float>(y); \
    } \
    CK_TILE_HOST_DEVICE \
    bool operator>(const type_& x, const type_& y) \
    { \
        return static_cast<float>(x) > static_cast<float>(y); \
    } \
    CK_TILE_HOST_DEVICE \
    bool operator>=(const type_& x, const type_& y) \
    { \
        return static_cast<float>(x) >= static_cast<float>(y); \
    } \
    CK_TILE_HOST_DEVICE \
    type_ operator+(const type_& x, const type_& y) \
    { \
        return type_(static_cast<float>(x) + static_cast<float>(y)); \
    } \
    CK_TILE_HOST_DEVICE \
    type_ operator-(const type_& x) \
    { \
        constexpr uint32_t bits = sizeof(type_) * 8; \
        constexpr uint32_t mask = 1 << (bits - 1); \
        type_ y = x; \
        y.data ^= static_cast<typename type_::raw_type>(mask); \
        return y; \
    } \
    CK_TILE_HOST_DEVICE \
    type_ operator-(const type_& x, const type_& y) \
    { \
        return type_(static_cast<float>(x) - static_cast<float>(y)); \
    } \
    CK_TILE_HOST_DEVICE \
    type_ operator*(const type_& x, const type_& y) \
    { \
        return type_(static_cast<float>(x) * static_cast<float>(y)); \
    } \
    CK_TILE_HOST_DEVICE \
    type_ operator/(const type_& x, const type_& y) \
    { \
        return type_(static_cast<float>(x) / static_cast<float>(y)); \
    } \
    CK_TILE_HOST_DEVICE \
    type_& operator+=(type_& x, const type_& y) \
    { \
        x = type_(static_cast<float>(x) + static_cast<float>(y)); \
        return x; \
    } \
    CK_TILE_HOST_DEVICE \
    type_& operator-=(type_& x, const type_& y) \
    { \
        x = type_(static_cast<float>(x) - static_cast<float>(y)); \
        return x; \
    } \
    CK_TILE_HOST_DEVICE \
    type_& operator*=(type_& x, const type_& y) \
    { \
        x = type_(static_cast<float>(x) * static_cast<float>(y)); \
        return x; \
    } \
    CK_TILE_HOST_DEVICE \
    type_& operator/=(type_& x, const type_& y) \
    { \
        x = type_(static_cast<float>(x) / static_cast<float>(y)); \
        return x; \
    } \
    CK_TILE_HOST_DEVICE \
    type_& operator++(type_& x) \
    { \
        x = type_(static_cast<float>(x) + 1.f); \
        return x; \
    } \
    CK_TILE_HOST_DEVICE \
    type_& operator--(type_& x) \
    { \
        x = type_(static_cast<float>(x) - 1.f); \
        return x; \
    } \
    CK_TILE_HOST_DEVICE \
    type_ operator++(type_& x, int) \
    { \
        type_ y(x); \
        x = type_(static_cast<float>(x) + 1.f); \
        return y; \
    } \
    CK_TILE_HOST_DEVICE \
    type_ operator--(type_& x, int) \
    { \
        type_ y(x); \
        x = type_(static_cast<float>(x) - 1.f); \
        return y; \
    }

View File

@@ -0,0 +1,263 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/numeric/arithmetic.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include <stdint.h>
#pragma once
namespace ck_tile {
enum class bf16_rounding_mode
{
standard = 0, // rtn
truncate_with_nan,
truncate,
};
template <bf16_rounding_mode rounding = CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT>
CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding> = {});
CK_TILE_HOST_DEVICE
float bf16_to_float_raw(uint16_t x);
// HIP use __hip_bfloat16 as struct
// 16-bit brain floating point: stored as a raw uint16_t (top 16 bits of an
// IEEE-754 binary32). All arithmetic happens through float conversion.
struct alignas(2) bfloat16_t
{
    using raw_type = uint16_t;
    raw_type data; // raw bit pattern
    // reinterpret a raw 16-bit pattern as bfloat16_t (no rounding/conversion)
    CK_TILE_HOST_DEVICE
    static bfloat16_t bit_cast(raw_type x)
    {
        bfloat16_t y;
        y.data = x;
        return y;
    }
    // constructor
    bfloat16_t() = default;
    // construct from float (rounding mode: CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)
    CK_TILE_HOST_DEVICE
    explicit bfloat16_t(const float& x) { data = float_to_bf16_raw(x); }
    // construct from int
    CK_TILE_HOST_DEVICE
    explicit bfloat16_t(const int& x) { data = float_to_bf16_raw(static_cast<float>(x)); }
    // construct from unsigned int
    CK_TILE_HOST_DEVICE
    explicit bfloat16_t(const unsigned int& x) { data = float_to_bf16_raw(static_cast<float>(x)); }
    // cast to float (exact: bf16 is a prefix of binary32)
    CK_TILE_HOST_DEVICE
    explicit operator float() const { return bf16_to_float_raw(data); }
    // cast to int (truncates through float)
    CK_TILE_HOST_DEVICE
    explicit operator int() const { return static_cast<int>(bf16_to_float_raw(data)); }
    // internal access to the raw bits
    CK_TILE_HOST_DEVICE
    raw_type& get() { return data; }
    CK_TILE_HOST_DEVICE
    raw_type get() const { return data; }
};
// round to nearest
// float -> bf16 raw bits with round-to-nearest-even; preserves signaling NaN
CK_TILE_HOST_DEVICE
uint16_t float_to_bf16_rtn_raw(float f)
{
    union
    {
        float fp32;
        uint32_t int32;
    } u = {f};
    if(~u.int32 & 0x7f800000)
    {
        // When the exponent bits are not all 1s, then the value is zero, normal,
        // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
        // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
        // This causes the bfloat16's mantissa to be incremented by 1 if the 16
        // least significant bits of the float mantissa are greater than 0x8000,
        // or if they are equal to 0x8000 and the least significant bit of the
        // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
        // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
        // has the value 0x7f, then incrementing it causes it to become 0x00 and
        // the exponent is incremented by one, which is the next higher FP value
        // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
        // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
        // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
        // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
        // incrementing it causes it to become an exponent of 0xFF and a mantissa
        // of 0x00, which is Inf, the next higher value to the unrounded value.
        u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even
    }
    else if(u.int32 & 0xffff)
    {
        // When all of the exponent bits are 1, the value is Inf or NaN.
        // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
        // mantissa bit. Quiet NaN is indicated by the most significant mantissa
        // bit being 1. Signaling NaN is indicated by the most significant
        // mantissa bit being 0 but some other bit(s) being 1. If any of the
        // lower 16 bits of the mantissa are 1, we set the least significant bit
        // of the bfloat16 mantissa, in order to preserve signaling NaN in case
        // the bloat16's mantissa bits are all 0.
        u.int32 |= 0x10000; // Preserve signaling NaN
    }
    return uint16_t(u.int32 >> 16);
}
// Truncate instead of rounding, preserving SNaN
// Truncate instead of rounding, preserving SNaN: if the input is NaN and the
// surviving mantissa bits would be zero, set the low mantissa bit so the
// result stays NaN instead of becoming Inf
CK_TILE_HOST_DEVICE
uint16_t float_to_bf16_truc_nan_raw(float f)
{
    union
    {
        float fp32;
        uint32_t int32;
    } u = {f};
    return uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff));
}
// Fast truncate instead of rounding, RTZ
// Fast truncate instead of rounding (RTZ): keep only the top 16 bits;
// a NaN with all-zero surviving mantissa bits degrades to Inf
CK_TILE_HOST_DEVICE
uint16_t float_to_bf16_truc_raw(float f)
{
    union
    {
        float fp32;
        uint32_t int32;
    } u = {f};
    return uint16_t(u.int32 >> 16);
}
// float -> bf16 raw bits, dispatching on the compile-time rounding mode
template <bf16_rounding_mode rounding = CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT>
CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding> = {})
{
    if constexpr(rounding == bf16_rounding_mode::standard)
        return float_to_bf16_rtn_raw(f);
    else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
        return float_to_bf16_truc_nan_raw(f);
    else
        return float_to_bf16_truc_raw(f);
}
// bf16 raw bits -> float: widen into the top 16 bits of a binary32 (exact)
CK_TILE_HOST_DEVICE
float bf16_to_float_raw(uint16_t x)
{
    union
    {
        uint32_t int32;
        float fp32;
    } u = {uint32_t(x) << 16};
    return u.fp32;
}
// float -> bfloat16_t with the selected rounding mode
template <bf16_rounding_mode rounding = CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT>
CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant<rounding> = {})
{
    return bfloat16_t::bit_cast(float_to_bf16_raw(f, constant<rounding>{}));
}
// bfloat16_t -> float (exact)
CK_TILE_HOST_DEVICE
float bf16_to_float(bfloat16_t x) { return static_cast<float>(x); }
// half -> bfloat16_t, converting through float with the selected rounding
template <bf16_rounding_mode rounding = CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT>
CK_TILE_HOST_DEVICE bfloat16_t fp16_to_bf16(half_t f, constant<rounding> = {})
{
    return bfloat16_t::bit_cast(float_to_bf16_raw(static_cast<float>(f), constant<rounding>{}));
}
// bfloat16_t -> half, converting through float
// NOTE(review): the declared return type is float although the name and the
// float_to_fp16 call suggest half_t - verify the intended return type
CK_TILE_HOST_DEVICE
float bf16_to_fp16(bfloat16_t x) { return float_to_fp16(static_cast<float>(x)); }
template <class T>
struct numeric_limits;
// std::numeric_limits-style traits for bfloat16_t, expressed as raw bit patterns
// NOTE(review): these constexpr functions call bfloat16_t::bit_cast, which is
// not declared constexpr above - verify they are usable in constant expressions
template <>
struct numeric_limits<bfloat16_t>
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t min() { return bfloat16_t::bit_cast(0x0080); }
    // minumum finite value
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t lowest()
    {
        return bfloat16_t::bit_cast(0xff7f);
    }
    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t max() { return bfloat16_t::bit_cast(0x7f7f); }
    // difference between 1.0 and next value representable by float
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t epsilon()
    {
        return bfloat16_t::bit_cast(0x1000);
    }
    // maximum rounding error
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t round_error() { return bfloat16_t(0.5f); }
    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t infinity()
    {
        return bfloat16_t::bit_cast(0x7f80);
    }
    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t quiet_NaN()
    {
        return bfloat16_t::bit_cast(0x7FFF);
    }
    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t signaling_NaN()
    {
        return bfloat16_t::bit_cast(0x7FFF);
    }
    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t denorm_min()
    {
        return bfloat16_t::bit_cast(0x0001);
    }
};
CK_TILE_ARITHMETIC_USING_FLOAT(bfloat16_t)
// math
// |x|: clear the sign bit of the raw bf16 pattern
CK_TILE_HOST_DEVICE
bfloat16_t abs(const bfloat16_t& x)
{
    const uint16_t magnitude = static_cast<uint16_t>(x.get() & 0x7fff);
    return bfloat16_t::bit_cast(magnitude);
}
// NaN test for bf16: NaN means exponent all ones (0x7F80) with a non-zero
// mantissa, i.e. (bits & 0x7FFF) > 0x7F80.
// NOTE(fix): the previous threshold 0x7C00 is the fp16 infinity pattern
// (copy-paste from half_t); with 8 exponent bits the bf16 infinity pattern is
// 0x7F80, and using 0x7C00 misclassified large finite bf16 values as NaN.
CK_TILE_HOST_DEVICE
bool isnan(const bfloat16_t& x)
{
    uint16_t xx = x.get();
    return (xx & 0x7FFF) > 0x7F80;
}
// square root, computed in float via the amdgcn builtin
CK_TILE_DEVICE
bfloat16_t sqrt(bfloat16_t x)
{
    const float xf = static_cast<float>(x);
    return static_cast<bfloat16_t>(__builtin_amdgcn_sqrtf(xf));
}
// e^x, computed in float via the device fast-math intrinsic
CK_TILE_DEVICE
bfloat16_t exp(bfloat16_t x)
{
    const float xf = static_cast<float>(x);
    return static_cast<bfloat16_t>(__expf(xf));
}
// 2^x, computed in float
CK_TILE_DEVICE
bfloat16_t exp2(bfloat16_t x)
{
    const float xf = static_cast<float>(x);
    return static_cast<bfloat16_t>(exp2f(xf));
}
// natural log, computed in float via the device fast-math intrinsic
CK_TILE_DEVICE
bfloat16_t log(bfloat16_t x)
{
    const float xf = static_cast<float>(x);
    return static_cast<bfloat16_t>(__logf(xf));
}
using bf16_t = bfloat16_t;
} // namespace ck_tile

View File

@@ -0,0 +1,735 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/limits.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/numeric/arithmetic.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include <stdint.h>
#include <type_traits>
#pragma once
namespace ck_tile {
// fp8 rounding modes
// use standard for rounding to nearest, the faster one
// use stochastic for stochastic rounding, helps to avoid error accumulation
// selects how float -> fp8/bf8 conversion rounds (see comment above)
enum class fp8_rounding_mode
{
    standard = 0, // round-to-nearest, the faster path
    stochastic    // stochastic rounding driven by an RNG value
};
/*
* ______________NANOO_________________ | ______________IEEE________________
* e4m3 e5m2 | e4m3 e5m2
* bias : 8 16 | 7 15
* inf : 1.0000.000 1.00000.00 | N/A s.11111.00
* Nan : 1.0000.000 1.00000.00 | s.1111.111 s.11111.{01, 10, 11}
* zero : 0.0000.000 0.00000.00 | s.0000.000 s.00000.00
* Max(norm) : s.1111.111 (240) s.11111.11(57344) | s.1111.110(448) s.11110.11(57344)
* Max(snorm): s.0000.111 s.00000.11 | s.0000.111(448) s.00000.11(57344)
* 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05
* Min(norm) : s.0001.000 s.00001.00 | s.0001.000 s.00001.00
* 2^-7(0.00078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05)
* Min(snorm): s.0000.001 s.00000.01 | s.0000.001 s.00000.01
* 2^-10(0.00097656) 2^-17(7.629395e-06)| 2^-9(0.001953125) 2^-16(1.52588e-05)
*/
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float, constant<rounding> = {});
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float, constant<rounding> = {});
CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t);
CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t);
// 8-bit float, e4m3 layout: sign(1) | exponent(4) | mantissa(3).
// The exponent bias (and hence NaN/max encodings) differs between gfx94x
// (NANOO) and other targets (IEEE) — see the table above.
struct alignas(1) float8_e4m3_t
{
    static constexpr int exponent = 4;
    static constexpr int mantissa = 3;
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    static constexpr int bias = 1 << (exponent - 1); // NANOO
#else
    static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
#endif
    using raw_type = uint8_t;
    raw_type data; // raw 8-bit pattern
    // reinterpret a raw byte as a float8_e4m3_t (no numeric conversion)
    CK_TILE_HOST_DEVICE
    static float8_e4m3_t bit_cast(raw_type x)
    {
        float8_e4m3_t y;
        y.data = x;
        return y;
    }
    // constructor
    float8_e4m3_t() = default;
    // construct from float
    CK_TILE_HOST_DEVICE
    explicit float8_e4m3_t(const float& x) { data = float_to_fp8_raw(x); }
    // construct from int
    CK_TILE_HOST_DEVICE
    explicit float8_e4m3_t(const int& x) { data = float_to_fp8_raw(static_cast<float>(x)); }
    // construct from unsigned int
    CK_TILE_HOST_DEVICE
    explicit float8_e4m3_t(const unsigned int& x)
    {
        data = float_to_fp8_raw(static_cast<float>(x));
    }
    // cast to float
    CK_TILE_HOST_DEVICE
    explicit operator float() const { return fp8_to_float_raw(data); }
    // cast to int (truncates via float)
    CK_TILE_HOST_DEVICE
    explicit operator int() const { return static_cast<int>(fp8_to_float_raw(data)); }
    // internal access
    CK_TILE_HOST_DEVICE
    raw_type& get() { return data; }
    CK_TILE_HOST_DEVICE
    raw_type get() const { return data; }
};
// 8-bit float, e5m2 layout: sign(1) | exponent(5) | mantissa(2).
// Like e4m3, the bias/NaN encoding is NANOO on gfx94x and IEEE elsewhere.
struct alignas(1) float8_e5m2_t
{
    static constexpr int exponent = 5;
    static constexpr int mantissa = 2;
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    static constexpr int bias = 1 << (exponent - 1); // NANOO
#else
    static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
#endif
    using raw_type = uint8_t;
    raw_type data; // raw 8-bit pattern
    // reinterpret a raw byte as a float8_e5m2_t (no numeric conversion)
    CK_TILE_HOST_DEVICE
    static float8_e5m2_t bit_cast(raw_type x)
    {
        float8_e5m2_t y;
        y.data = x;
        return y;
    }
    // constructor
    float8_e5m2_t() = default;
    // construct from float
    CK_TILE_HOST_DEVICE
    explicit float8_e5m2_t(const float& x) { data = float_to_bf8_raw(x); }
    // construct from int
    CK_TILE_HOST_DEVICE
    explicit float8_e5m2_t(const int& x) { data = float_to_bf8_raw(static_cast<float>(x)); }
    // construct from unsigned int
    CK_TILE_HOST_DEVICE
    explicit float8_e5m2_t(const unsigned int& x)
    {
        data = float_to_bf8_raw(static_cast<float>(x));
    }
    // cast to float
    CK_TILE_HOST_DEVICE
    explicit operator float() const { return bf8_to_float_raw(data); }
    // cast to int (truncates via float)
    CK_TILE_HOST_DEVICE
    explicit operator int() const { return static_cast<int>(bf8_to_float_raw(data)); }
    // internal access
    CK_TILE_HOST_DEVICE
    raw_type& get() { return data; }
    CK_TILE_HOST_DEVICE
    raw_type get() const { return data; }
};
// below is sw fp8 conversion, not utilizing hw instruction
namespace impl {
// Software float/half -> f8 conversion (used when no hardware cvt exists).
// X: source type (float or half_t); Y: destination f8 raw representation.
// negative_zero_nan: NANOO encoding (single NaN at 0x80, no negative zero);
// clip: saturate to the f8 max on overflow instead of returning +-inf;
// stoch: stochastic rounding using 'rng', otherwise round-to-nearest-even.
template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng)
{
    // fp8/bf8 exponent/mantissa layout
    constexpr int out_exp = numeric_utils<Y>::exp;
    constexpr int out_mant = numeric_utils<Y>::mant;
    // original type exponent/mantissa layout
    constexpr int in_exp = numeric_utils<X>::exp;
    constexpr int in_mant = numeric_utils<X>::mant;
    int exponent, bias;
    uint32_t head, mantissa, sign;
    // nan code is same for float and half
    constexpr Y nan_code = 0x80;
    constexpr uint32_t nan_mask = numeric_utils<X>::nan_mask;
    // convert to bitwise
    // NOTE(review): reinterpret_cast type punning is formally UB in C++;
    // std::bit_cast / memcpy would be safer — confirm compiler guarantees.
    using T_bitwise = typename numeric_utils<X>::bitwise_type;
    T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));
    // unpack the input, depends on datatype
    head = x_bitwise & numeric_utils<X>::head_mask;
    mantissa = x_bitwise & numeric_utils<X>::mant_mask;
    exponent = (head >> in_mant) & numeric_utils<X>::exp_mask;
    sign = head >> (in_exp + in_mant);
    bias = numeric_utils<X>::bias;
    // all-ones exponent with the input's sign (the input-format +-inf pattern)
    uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant);
    // mask of mantissa bits that will be dropped by the narrowing
    uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1;
    constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2);
    if constexpr(negative_zero_nan)
    {
        if((x_bitwise & nan_mask) == nan_mask)
            return nan_code;
    }
    else
    {
        if((x_bitwise & nan_mask) == nan_mask)
            return signed_inf + (mantissa != 0 ? 1 : 0);
    }
    // check if x is 0.0
    if(x_bitwise == 0)
        return 0;
    // First need to check if it is normal or denorm as there is a difference of implict 1
    // Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
    // The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for
    // RNE, no need to add rng. Then probably need to check whether there is carry and adjust
    // exponent and mantissa again3
    // For IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of exponent bits
    const int out_bias = (1 << (out_exp - 1)) - 1 + (negative_zero_nan ? 1 : 0);
    const int out_denormal_act_exponent = 1 - out_bias; // actual exponent of f8 denormal
    // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
    // out_exponent is the converted f8 exponent with bias encoding
    // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
    // the difference needs to be adjusted and mantissa shifted
    int act_exponent, out_exponent, exponent_diff;
    if(exponent == 0)
    { // fp32/fp16 is in denormal.
        /* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16
here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has
exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in
fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal.
In this case, the fp16 mantissa should be shift left by 1 */
        act_exponent = exponent - bias + 1;
        exponent_diff = out_denormal_act_exponent -
                        act_exponent; // actual exponent is exponent-bias+1 as it is denormal
    }
    else
    { // fp32/fp16 is normal with implicit 1
        act_exponent = exponent - bias;
        if(act_exponent <= out_denormal_act_exponent)
        {
            /* This is the case where fp32/fp16 is normal but it is in f8 denormal range.
For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
actual exponent is -7, it is actually larger due to the implict 1,
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
            exponent_diff = out_denormal_act_exponent - act_exponent;
        }
        else
        { // both fp32/fp16 and f8 are in normal range
            exponent_diff =
                0; // exponent_diff=0 does not mean there is no difference for this case,
                   // act_exponent could be larger. Just that it does not need shift mantissa
        }
        mantissa += (1 << in_mant); // Add the implicit 1 into mantissa
    }
    // tie detection must happen before the shift (see comment below)
    bool midpoint = (mantissa & ((1 << (in_mant - out_mant + exponent_diff)) - 1)) ==
                    (1 << (in_mant - out_mant + exponent_diff - 1));
    /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we
shift right as shift right could rip off some residual part and make something not midpoint look
like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than
midpoint, but after shift right by 4 bits, it would look like midpoint. */
    if(exponent_diff > 0)
        mantissa >>= exponent_diff;
    else if(exponent_diff == -1)
        mantissa <<= -exponent_diff;
    bool implicit_one = mantissa & (1 << in_mant);
    // if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent
    out_exponent =
        (act_exponent + exponent_diff) /*actual f8 exponent*/ + out_bias - (implicit_one ? 0 : 1);
    // Now we have the exponent and mantissa adjusted
    bool odd =
        mantissa &
        (1 << (in_mant - out_mant)); // if the least significant bit that is not truncated is 1
    // stochastic: add rng; RNE: round half-to-even via the midpoint/odd trick
    mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;
    // Now we deal with overflow
    if(out_exponent == 0)
    {
        if((1 << in_mant) & mantissa)
        {
            out_exponent = 1; // denormal overflow to become normal, promote exponent
            // No need to make 1 implicit now as it will be addressed later
        }
    }
    else
    {
        if((1 << (in_mant + 1)) & mantissa)
        {
            mantissa >>= 1;
            out_exponent++;
            // No need to make 1 implicit now as it will be addressed later
        }
    }
    mantissa >>= (in_mant - out_mant);
    if(out_exponent > max_exp)
    {
        if(clip)
        {
            mantissa = (1 << out_mant) - 1;
            out_exponent = max_exp;
        }
        else
        {
            return signed_inf;
        }
    }
    // check if x is 0.0 or -0.0
    if(out_exponent == 0 && mantissa == 0)
        return negative_zero_nan ? 0 : (sign << (out_exp + out_mant));
    mantissa &= (1 << out_mant) - 1;
    return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa;
}
// Software f8 -> float/half conversion (used when no hardware cvt exists).
// X: source f8 type; Y: destination type (float or half_t).
// negative_zero_nan: the input uses the NANOO encoding (0x80 is NaN, no -0).
template <typename X, typename Y, bool negative_zero_nan>
CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
{
    // fp8/bf8 exponent/mantissa layout
    constexpr int in_exp = numeric_utils<X>::exp;
    constexpr int in_mant = numeric_utils<X>::mant;
    // resulting type exponent/mantissa layout
    constexpr int out_exp = numeric_utils<Y>::exp;
    constexpr int out_mant = numeric_utils<Y>::mant;
    // prepare the codes
    constexpr X nan_code = 0x80;
    Y Inf, NegInf, NaN, Neg0;
    using T_bitwise = typename numeric_utils<Y>::bitwise_type;
    constexpr T_bitwise Inf_bitwise = numeric_utils<Y>::Inf;
    constexpr T_bitwise NegInf_bitwise = numeric_utils<Y>::NegInf;
    constexpr T_bitwise NaN_bitwise = numeric_utils<Y>::NaN;
    constexpr T_bitwise Neg0_bitwise = numeric_utils<Y>::Neg0;
    // NOTE(review): reinterpret_cast type punning is formally UB in C++;
    // std::bit_cast / memcpy would be safer — confirm compiler guarantees.
    Inf = *(reinterpret_cast<const Y*>(&Inf_bitwise));
    NegInf = *(reinterpret_cast<const Y*>(&NegInf_bitwise));
    NaN = *(reinterpret_cast<const Y*>(&NaN_bitwise));
    Neg0 = *(reinterpret_cast<const Y*>(&Neg0_bitwise));
    // check if x is 0.0
    if(x == 0)
        return static_cast<Y>(0);
    // unpack the input
    uint32_t sign = x >> (in_exp + in_mant);
    uint32_t mantissa = x & ((1 << in_mant) - 1);
    int exponent = (x & 0x7F) >> in_mant;
    // smallest source exponent that still maps to a normal output value
    constexpr int exp_low_cutoff =
        (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
    T_bitwise retval;
    if constexpr(negative_zero_nan)
    {
        if(x == nan_code)
            return NaN;
    }
    else
    {
        if(x == nan_code)
            return Neg0;
        if(exponent == ((1 << in_exp) - 1))
            return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
    }
    // bf8(e5m2) -> half: identical exponent layout, just widen the mantissa
    if((numeric_utils<Y>::mant == 10) && (numeric_utils<X>::mant == 2) && !negative_zero_nan)
    {
        retval = x;
        retval <<= 8;
        return *(reinterpret_cast<const Y*>(&retval));
    }
    // subnormal input
    if(exponent == 0)
    {
        // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
        int sh = 1 + clz(mantissa) - (32 - in_mant);
        mantissa <<= sh;
        exponent += 1 - sh;
        mantissa &= ((1 << in_mant) - 1);
    }
    exponent += exp_low_cutoff - 1;
    mantissa <<= out_mant - in_mant;
    // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
    if(exponent <= 0)
    {
        mantissa |= 1 << out_mant;
        mantissa >>= 1 - exponent;
        exponent = 0;
    }
    retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
    return *(reinterpret_cast<const Y*>(&retval));
}
// type-checked entry point for the software cast-to-f8 path
template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng)
{
    // only float and half sources are supported
    static_assert(std::is_same_v<X, half_t> || std::is_same_v<X, float>,
                  "Only half and float can be casted.");
    return run_cast_to_f8<X, Y, negative_zero_nan, clip, stoch>(x, rng);
}
// type-checked entry point for the software cast-from-f8 path
template <typename X, typename Y, bool negative_zero_nan>
CK_TILE_HOST_DEVICE Y cast_from_f8(X x)
{
    // only float and half destinations are supported
    static_assert(std::is_same_v<Y, half_t> || std::is_same_v<Y, float>,
                  "only half and float are supported.");
    return run_cast_from_f8<X, Y, negative_zero_nan>(x);
}
} // namespace impl
// float -> fp8(e4m3) raw byte with stochastic rounding.
// On gfx94x: clips to +-240 (the NANOO fp8 max) and uses the hardware
// v_cvt_sr_fp8_f32 instruction; elsewhere falls back to the software
// converter (NANOO encoding, clipping enabled). The RNG value is derived
// from the input value and its address.
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_sr_raw(float x)
{
    constexpr int seed = 42;
    uint32_t rng = prand_generator<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    float max_fp8 = 240.0f;
    x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
    union
    {
        float fval;
        uint32_t i32val;
        uint8_t i8val[4]; // not endian independent
    } val;
    val.fval = x;
    uint32_t ival = 0;
    ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
    val.i32val = ival;
    return val.i8val[0]; // little endian
#else
    constexpr bool negative_zero_nan = true;
    constexpr bool clip = true;
    constexpr fp8_rounding_mode rm = fp8_rounding_mode::stochastic;
    return impl::
        cast_to_f8<float, fp8_t, negative_zero_nan, clip, (rm == fp8_rounding_mode::stochastic)>(
            x, rng);
#endif
}
// float -> bf8(e5m2) raw byte with stochastic rounding.
// On gfx94x uses the hardware v_cvt_sr_bf8_f32 instruction (no clipping
// needed: bf8 max 57344 matches IEEE); elsewhere falls back to the software
// converter (NANOO encoding, clipping enabled).
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_sr_raw(float x)
{
    constexpr int seed = 42;
    uint32_t rng = prand_generator<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    union
    {
        float fval;
        uint32_t i32val;
        uint8_t i8val[4]; // not endian independent
    } val;
    val.fval = x;
    uint32_t ival = 0;
    ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
    val.i32val = ival;
    return val.i8val[0]; // little endian
#else
    constexpr bool negative_zero_nan = true;
    constexpr bool clip = true;
    constexpr fp8_rounding_mode rm = fp8_rounding_mode::stochastic;
    return impl::
        cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == fp8_rounding_mode::stochastic)>(
            x, rng);
#endif
}
// float -> fp8(e4m3) raw byte with round-to-nearest.
// On gfx94x: clips to +-240 (the NANOO fp8 max) and uses the hardware
// packed cvt instruction, keeping only the low byte; elsewhere falls back
// to the software converter (NANOO encoding, clipping enabled, RNE).
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_rtn_raw(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    float max_fp8 = 240.0f;
    x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
    union
    {
        float fval;
        uint32_t i32val;
        uint8_t i8val[4]; // not endian independent
    } val;
    val.fval = x;
    uint32_t ival = 0;
    ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
    val.i32val = ival;
    return val.i8val[0];
#else
    constexpr bool negative_zero_nan = true;
    constexpr bool clip = true;
    constexpr fp8_rounding_mode rm = fp8_rounding_mode::standard;
    constexpr uint32_t rng = 0;
    return impl::
        cast_to_f8<float, fp8_t, negative_zero_nan, clip, (rm == fp8_rounding_mode::stochastic)>(
            x, rng);
#endif
}
// float -> bf8(e5m2) raw byte with round-to-nearest.
// On gfx94x uses the hardware packed cvt instruction, keeping only the low
// byte; elsewhere falls back to the software converter (NANOO, clip, RNE).
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_rtn_raw(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    union
    {
        float fval;
        uint32_t i32val;
        uint8_t i8val[4]; // not endian independent
    } val;
    val.fval = x;
    uint32_t ival = 0;
    ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
    val.i32val = ival;
    return val.i8val[0];
#else
    constexpr bool negative_zero_nan = true;
    constexpr bool clip = true;
    constexpr fp8_rounding_mode rm = fp8_rounding_mode::standard;
    constexpr uint32_t rng = 0;
    return impl::
        cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == fp8_rounding_mode::stochastic)>(
            x, rng);
#endif
}
// clang-format off
// dispatch float -> fp8 raw conversion on the compile-time rounding mode
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float x, constant<rounding> = {})
{
    if constexpr(rounding == fp8_rounding_mode::standard)
    {
        return float_to_fp8_rtn_raw(x);
    }
    else if constexpr(rounding == fp8_rounding_mode::stochastic)
    {
        return float_to_fp8_sr_raw(x);
    }
    else
    {
        return uint8_t{0};
    }
}
// dispatch float -> bf8 raw conversion on the compile-time rounding mode
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float x, constant<rounding> = {})
{
    if constexpr(rounding == fp8_rounding_mode::standard)
    {
        return float_to_bf8_rtn_raw(x);
    }
    else if constexpr(rounding == fp8_rounding_mode::stochastic)
    {
        return float_to_bf8_sr_raw(x);
    }
    else
    {
        return uint8_t{0};
    }
}
// fp8(e4m3) raw byte -> float. Uses the hardware v_cvt_f32_fp8 instruction
// on gfx94x, otherwise the software converter with the NANOO encoding.
CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    float fval;
    uint32_t i32val = static_cast<uint32_t>(x);
    fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
    // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
    return fval;
#else
    constexpr bool negative_zero_nan = true;
    return impl::cast_from_f8<fp8_t, float, negative_zero_nan>(x);
#endif
}
// bf8(e5m2) raw byte -> float. Uses the hardware v_cvt_f32_bf8 instruction
// on gfx94x, otherwise the software converter with the NANOO encoding.
CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    float fval;
    uint32_t i32val = static_cast<uint32_t>(x);
    fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
    // asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
    return fval;
#else
    constexpr bool negative_zero_nan = true;
    return impl::cast_from_f8<bf8_t, float, negative_zero_nan>(x);
#endif
}
// float -> fp8 value type (wraps the raw conversion)
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE float8_e4m3_t float_to_fp8(float x, constant<rounding> = {})
{
    const uint8_t raw = float_to_fp8_raw(x, constant<rounding>{});
    return float8_e4m3_t::bit_cast(raw);
}
// float -> bf8 value type (wraps the raw conversion)
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE float8_e5m2_t float_to_bf8(float x, constant<rounding> = {})
{
    const uint8_t raw = float_to_bf8_raw(x, constant<rounding>{});
    return float8_e5m2_t::bit_cast(raw);
}
// fp8 value type -> float (wraps the raw conversion)
CK_TILE_HOST_DEVICE float fp8_to_float(float8_e4m3_t x) { return fp8_to_float_raw(x.get()); }
// bf8 value type -> float (wraps the raw conversion)
CK_TILE_HOST_DEVICE float bf8_to_float(float8_e5m2_t x) { return bf8_to_float_raw(x.get()); }
// clang-format on
// short value-type aliases used throughout ck_tile
using fp8_t = float8_e4m3_t;
using bf8_t = float8_e5m2_t;
// per-type exponent/mantissa/bias traits (specialized for each numeric type)
template <typename T>
struct numeric_utils;
template <>
struct numeric_utils<fp8_t>
{
    static constexpr int exp = fp8_t::exponent;
    static constexpr int mant = fp8_t::mantissa;
    static constexpr int bias = fp8_t::bias; // NANOO on gfx94x, IEEE elsewhere
};
template <>
struct numeric_utils<bf8_t>
{
    static constexpr int exp = bf8_t::exponent;
    static constexpr int mant = bf8_t::mantissa;
    static constexpr int bias = bf8_t::bias; // NANOO on gfx94x, IEEE elsewhere
};
// Generic numeric_limits, specialized for each ck_tile numeric type.
template <class T>
struct numeric_limits;
// fp8(e4m3) limits. NOTE(review): these raw patterns follow the NANOO
// encoding (0x80 = NaN/inf, 0x7f = max); on non-gfx94x targets the struct
// uses the IEEE bias, where 0x7f is NaN — confirm the intended target set.
template <>
struct numeric_limits<fp8_t>
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr fp8_t min() { return fp8_t::bit_cast(0x08); }
    // minimum finite value (most negative)
    CK_TILE_HOST_DEVICE static constexpr fp8_t lowest() { return fp8_t::bit_cast(0xff); }
    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr fp8_t max() { return fp8_t::bit_cast(0x7f); }
    // difference between 1.0 and next representable value
    CK_TILE_HOST_DEVICE static constexpr fp8_t epsilon() { return fp8_t::bit_cast(0x20); }
    // maximum rounding error
    CK_TILE_HOST_DEVICE static constexpr fp8_t round_error() { return fp8_t(0.5f); }
    // positive infinity value (NANOO folds inf into the single 0x80 code)
    CK_TILE_HOST_DEVICE static constexpr fp8_t infinity() { return fp8_t::bit_cast(0x80); }
    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr fp8_t quiet_NaN() { return fp8_t::bit_cast(0x80); }
    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr fp8_t signaling_NaN() { return fp8_t::bit_cast(0x80); }
    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min() { return fp8_t::bit_cast(0x01); }
};
// bf8(e5m2) limits. NOTE(review): like fp8 above, these patterns follow the
// NANOO encoding (0x80 = NaN/inf) — confirm behavior on IEEE-mode targets.
template <>
struct numeric_limits<bf8_t>
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr bf8_t min() { return bf8_t::bit_cast(0x04); }
    // minimum finite value (most negative)
    CK_TILE_HOST_DEVICE static constexpr bf8_t lowest() { return bf8_t::bit_cast(0xff); }
    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr bf8_t max() { return bf8_t::bit_cast(0x7f); }
    // difference between 1.0 and next representable value
    CK_TILE_HOST_DEVICE static constexpr bf8_t epsilon() { return bf8_t::bit_cast(0x34); }
    // maximum rounding error
    CK_TILE_HOST_DEVICE static constexpr bf8_t round_error() { return bf8_t(0.5f); }
    // positive infinity value (NANOO folds inf into the single 0x80 code)
    CK_TILE_HOST_DEVICE static constexpr bf8_t infinity() { return bf8_t::bit_cast(0x80); }
    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr bf8_t quiet_NaN() { return bf8_t::bit_cast(0x80); }
    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr bf8_t signaling_NaN() { return bf8_t::bit_cast(0x80); }
    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min() { return bf8_t::bit_cast(0x01); }
};
CK_TILE_ARITHMETIC_USING_FLOAT(fp8_t)
CK_TILE_ARITHMETIC_USING_FLOAT(bf8_t)
// math
// |x|: clear the sign bit of the raw fp8 pattern
CK_TILE_HOST_DEVICE
fp8_t abs(const fp8_t& x)
{
    const uint8_t magnitude = static_cast<uint8_t>(x.get() & 0x7f);
    return fp8_t::bit_cast(magnitude);
}
// NaN test for fp8. 0x80 is the unique NaN of the NANOO e4m3 encoding;
// IEEE e4m3 NaNs (s.1111.111, i.e. 0x7f/0xff) are not detected here.
CK_TILE_HOST_DEVICE
bool isnan(const fp8_t& x)
{
    uint8_t xx = x.get();
    return xx == 0x80; // TODO: NANOO
}
// square root, computed in float via the amdgcn builtin
CK_TILE_DEVICE
fp8_t sqrt(fp8_t x)
{
    const float xf = static_cast<float>(x);
    return static_cast<fp8_t>(__builtin_amdgcn_sqrtf(xf));
}
// e^x, computed in float
CK_TILE_DEVICE
fp8_t exp(fp8_t x)
{
    const float xf = static_cast<float>(x);
    return static_cast<fp8_t>(__expf(xf));
}
// 2^x, computed in float
CK_TILE_DEVICE
fp8_t exp2(fp8_t x)
{
    const float xf = static_cast<float>(x);
    return static_cast<fp8_t>(exp2f(xf));
}
// natural log, computed in float
CK_TILE_DEVICE
fp8_t log(fp8_t x)
{
    const float xf = static_cast<float>(x);
    return static_cast<fp8_t>(__logf(xf));
}
// |x|: clear the sign bit of the raw bf8 pattern
CK_TILE_HOST_DEVICE
bf8_t abs(const bf8_t& x)
{
    const uint8_t magnitude = static_cast<uint8_t>(x.get() & 0x7f);
    return bf8_t::bit_cast(magnitude);
}
// NaN test for bf8. 0x80 is the unique NaN of the NANOO e5m2 encoding;
// IEEE e5m2 NaNs (s.11111.{01,10,11}) are not detected here.
CK_TILE_HOST_DEVICE
bool isnan(const bf8_t& x)
{
    uint8_t xx = x.get();
    return xx == 0x80; // TODO: NANOO
}
// square root, computed in float via the amdgcn builtin
CK_TILE_DEVICE
bf8_t sqrt(bf8_t x)
{
    const float xf = static_cast<float>(x);
    return static_cast<bf8_t>(__builtin_amdgcn_sqrtf(xf));
}
// e^x, computed in float
CK_TILE_DEVICE
bf8_t exp(bf8_t x)
{
    const float xf = static_cast<float>(x);
    return static_cast<bf8_t>(__expf(xf));
}
// 2^x, computed in float
CK_TILE_DEVICE
bf8_t exp2(bf8_t x)
{
    const float xf = static_cast<float>(x);
    return static_cast<bf8_t>(exp2f(xf));
}
// natural log, computed in float
CK_TILE_DEVICE
bf8_t log(bf8_t x)
{
    const float xf = static_cast<float>(x);
    return static_cast<bf8_t>(__logf(xf));
}
} // namespace ck_tile

View File

@@ -0,0 +1,278 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include <hip/hip_fp16.h>
#pragma once
namespace ck_tile {
CK_TILE_HOST_DEVICE
float fp16_to_float_hip(const _Float16& x);
CK_TILE_HOST_DEVICE
_Float16 float_to_fp16_hip(const float& x);
// HIP use _Float16 as interchangable data type for float16
struct alignas(2) half_t
{
    using raw_type = uint16_t;
    raw_type data; // raw IEEE-754 binary16 bit pattern

    // reinterpret a raw 16-bit pattern as a half_t (no numeric conversion)
    CK_TILE_HOST_DEVICE
    static half_t bit_cast(raw_type x)
    {
        half_t y;
        y.data = x;
        return y;
    }
    // Reinterpret the stored bits as the HIP native _Float16 type.
    // NOTE(fix): the cast target must be _Float16. The previous
    // reinterpret_cast<const raw_type&>(data) was a no-op on the integer,
    // which was then *numerically* converted to _Float16 on return
    // (e.g. 1.0h, bits 0x3C00, became 15360.0h).
    CK_TILE_HOST_DEVICE
    _Float16 to_fp16() const { return reinterpret_cast<const _Float16&>(data); }
    // constructor
    half_t() = default;
    // construct from HIP half (bit-level reinterpretation)
    CK_TILE_HOST_DEVICE
    explicit half_t(const _Float16& x) : data(reinterpret_cast<const raw_type&>(x)) {}
    // construct from float
    CK_TILE_HOST_DEVICE
    explicit half_t(const float& x) : half_t(float_to_fp16_hip(x)) {}
    // construct from int (round-to-nearest)
    CK_TILE_HOST_DEVICE
    explicit half_t(const int& x) : half_t(__int2half_rn(x)) {}
    // construct from unsigned int (round-to-nearest)
    CK_TILE_HOST_DEVICE
    explicit half_t(const unsigned int& x) : half_t(__uint2half_rn(x)) {}
    // cast to float
    CK_TILE_HOST_DEVICE
    explicit operator float() const { return fp16_to_float_hip(to_fp16()); }
    // cast to int (truncates via float)
    CK_TILE_HOST_DEVICE
    explicit operator int() const { return static_cast<int>(fp16_to_float_hip(to_fp16())); }
    // internal access
    CK_TILE_HOST_DEVICE
    raw_type& get() { return data; }
    CK_TILE_HOST_DEVICE
    raw_type get() const { return data; }
};
// conversions
// widen the HIP native fp16 to float
CK_TILE_HOST_DEVICE
float fp16_to_float_hip(const _Float16& x)
{
    return static_cast<float>(x);
}
// narrow float to the HIP native fp16 (compiler default rounding)
CK_TILE_HOST_DEVICE
_Float16 float_to_fp16_hip(const float& x)
{
    return static_cast<_Float16>(x);
}
// half_t -> float (via the explicit conversion operator)
CK_TILE_HOST_DEVICE
float fp16_to_float(const half_t& x)
{
    return static_cast<float>(x);
}
// float -> half_t (via the converting constructor)
CK_TILE_HOST_DEVICE
half_t float_to_fp16(const float& x)
{
    return half_t{x};
}
// limits
// Generic numeric_limits, specialized per ck_tile data type.
template <class T>
struct numeric_limits;
// fp16 limits. Raw patterns are sign(1) | exponent(5) | mantissa(10), bias 15.
template <>
struct numeric_limits<half_t>
{
    // minimum positive normalized value: 0x0400 = 2^-14
    CK_TILE_HOST_DEVICE static constexpr half_t min() { return half_t::bit_cast(0x0400); }
    // minimum finite value (most negative): -65504
    CK_TILE_HOST_DEVICE static constexpr half_t lowest() { return half_t::bit_cast(0xFBFF); }
    // maximum finite value: 65504
    CK_TILE_HOST_DEVICE static constexpr half_t max() { return half_t::bit_cast(0x7BFF); }
    // difference between 1.0 and the next representable value.
    // NOTE(review): 0x1800 encodes 2^-9; fp16 machine epsilon (2^-10) would
    // be 0x1400 — confirm the intended value.
    CK_TILE_HOST_DEVICE static constexpr half_t epsilon() { return half_t::bit_cast(0x1800); }
    // maximum rounding error (0.5 ulp for round-to-nearest)
    CK_TILE_HOST_DEVICE static constexpr half_t round_error() { return half_t(0.5f); }
    // positive infinity: exponent all ones, mantissa zero
    CK_TILE_HOST_DEVICE static constexpr half_t infinity() { return half_t::bit_cast(0x7C00); }
    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr half_t quiet_NaN() { return half_t::bit_cast(0x7FFF); }
    // signaling NaN (same pattern used here as quiet NaN)
    CK_TILE_HOST_DEVICE static constexpr half_t signaling_NaN() { return half_t::bit_cast(0x7FFF); }
    // smallest positive subnormal value: 2^-24
    CK_TILE_HOST_DEVICE static constexpr half_t denorm_min() { return half_t::bit_cast(0x0001); }
};
// per-type bit-layout traits, consumed by the software f8 converters
template <typename T>
struct numeric_utils;
template <>
struct numeric_utils<half_t>
{
    static constexpr int exp = 5;   // exponent bit count
    static constexpr int mant = 10; // mantissa bit count
    static constexpr int bias = 15; // exponent bias
    static constexpr uint16_t nan_mask = 0x7C00;  // exponent field all ones
    static constexpr uint16_t head_mask = 0xFC00; // sign + exponent fields
    static constexpr uint16_t mant_mask = 0x3FF;  // mantissa field
    static constexpr uint16_t exp_mask = 0x1F;    // exponent after shift
    static constexpr uint32_t Inf = 0x7C00;    // +inf pattern
    static constexpr uint32_t NegInf = 0xFC00; // -inf pattern
    static constexpr uint32_t NaN = 0x7C01;    // a quiet NaN pattern
    static constexpr uint32_t Neg0 = 0x8000;   // negative zero pattern
    using bitwise_type = uint16_t;
};
// arithmetic
// equality, forwarded to the HIP fp16 intrinsic
CK_TILE_HOST_DEVICE
bool operator==(const half_t& lhs, const half_t& rhs)
{
    return __heq(lhs.to_fp16(), rhs.to_fp16());
}
// inequality, forwarded to the HIP fp16 intrinsic
CK_TILE_HOST_DEVICE
bool operator!=(const half_t& lhs, const half_t& rhs)
{
    return __hne(lhs.to_fp16(), rhs.to_fp16());
}
// less-than, forwarded to the HIP fp16 intrinsic
CK_TILE_HOST_DEVICE
bool operator<(const half_t& lhs, const half_t& rhs)
{
    return __hlt(lhs.to_fp16(), rhs.to_fp16());
}
// less-or-equal, forwarded to the HIP fp16 intrinsic
CK_TILE_HOST_DEVICE
bool operator<=(const half_t& lhs, const half_t& rhs)
{
    return __hle(lhs.to_fp16(), rhs.to_fp16());
}
// greater-than, forwarded to the HIP fp16 intrinsic
CK_TILE_HOST_DEVICE
bool operator>(const half_t& lhs, const half_t& rhs)
{
    return __hgt(lhs.to_fp16(), rhs.to_fp16());
}
// greater-or-equal, forwarded to the HIP fp16 intrinsic
CK_TILE_HOST_DEVICE
bool operator>=(const half_t& lhs, const half_t& rhs)
{
    return __hge(lhs.to_fp16(), rhs.to_fp16());
}
// addition via the HIP fp16 intrinsic
CK_TILE_HOST_DEVICE
half_t operator+(const half_t& lhs, const half_t& rhs)
{
    return half_t(__hadd(lhs.to_fp16(), rhs.to_fp16()));
}
// unary negation via the HIP fp16 intrinsic
CK_TILE_HOST_DEVICE
half_t operator-(const half_t& v)
{
    return half_t(__hneg(v.to_fp16()));
}
// subtraction via the HIP fp16 intrinsic
CK_TILE_HOST_DEVICE
half_t operator-(const half_t& lhs, const half_t& rhs)
{
    return half_t(__hsub(lhs.to_fp16(), rhs.to_fp16()));
}
// multiplication via the HIP fp16 intrinsic
CK_TILE_HOST_DEVICE
half_t operator*(const half_t& lhs, const half_t& rhs)
{
    return half_t(__hmul(lhs.to_fp16(), rhs.to_fp16()));
}
// division via the HIP fp16 intrinsic
CK_TILE_HOST_DEVICE
half_t operator/(const half_t& lhs, const half_t& rhs)
{
    return half_t(__hdiv(lhs.to_fp16(), rhs.to_fp16()));
}
// in-place addition
CK_TILE_HOST_DEVICE
half_t& operator+=(half_t& lhs, const half_t& rhs)
{
    lhs = half_t(__hadd(lhs.to_fp16(), rhs.to_fp16()));
    return lhs;
}
// in-place subtraction
CK_TILE_HOST_DEVICE
half_t& operator-=(half_t& lhs, const half_t& rhs)
{
    lhs = half_t(__hsub(lhs.to_fp16(), rhs.to_fp16()));
    return lhs;
}
// in-place multiplication
CK_TILE_HOST_DEVICE
half_t& operator*=(half_t& lhs, const half_t& rhs)
{
    lhs = half_t(__hmul(lhs.to_fp16(), rhs.to_fp16()));
    return lhs;
}
// in-place division
CK_TILE_HOST_DEVICE
half_t& operator/=(half_t& lhs, const half_t& rhs)
{
    lhs = half_t(__hdiv(lhs.to_fp16(), rhs.to_fp16()));
    return lhs;
}
// pre-increment: v = v + 1
CK_TILE_HOST_DEVICE
half_t& operator++(half_t& v)
{
    v = half_t(__hadd(v.to_fp16(), half_t(1.0f).to_fp16()));
    return v;
}
// pre-decrement: v = v - 1
CK_TILE_HOST_DEVICE
half_t& operator--(half_t& v)
{
    v = half_t(__hsub(v.to_fp16(), half_t(1.0f).to_fp16()));
    return v;
}
// post-increment: returns the previous value
CK_TILE_HOST_DEVICE
half_t operator++(half_t& v, int)
{
    half_t prev(v);
    v = half_t(__hadd(v.to_fp16(), half_t(1.0f).to_fp16()));
    return prev;
}
// post-decrement: returns the previous value
CK_TILE_HOST_DEVICE
half_t operator--(half_t& v, int)
{
    half_t prev(v);
    v = half_t(__hsub(v.to_fp16(), half_t(1.0f).to_fp16()));
    return prev;
}
// math
// |x|: clear the sign bit of the raw fp16 pattern
CK_TILE_HOST_DEVICE
half_t abs(const half_t& x)
{
    const uint16_t magnitude = static_cast<uint16_t>(x.get() & 0x7fff);
    return half_t::bit_cast(magnitude);
}
// NaN test for fp16: NaN means exponent all ones (0x7C00) with a non-zero
// mantissa, i.e. (bits & 0x7FFF) > 0x7C00.
CK_TILE_HOST_DEVICE
bool isnan(const half_t& x)
{
    uint16_t xx = x.get();
    return (xx & 0x7FFF) > 0x7C00;
}
// square root, computed in float via the amdgcn builtin
CK_TILE_DEVICE
half_t sqrt(half_t x)
{
    const float xf = static_cast<float>(x);
    return static_cast<half_t>(__builtin_amdgcn_sqrtf(xf));
}
// e^x, computed in float
CK_TILE_DEVICE
half_t exp(half_t x)
{
    const float xf = static_cast<float>(x);
    return static_cast<half_t>(__expf(xf));
}
// 2^x, computed in float
CK_TILE_DEVICE
half_t exp2(half_t x)
{
    const float xf = static_cast<float>(x);
    return static_cast<half_t>(exp2f(xf));
}
// natural log, computed in float
CK_TILE_DEVICE
half_t log(half_t x)
{
    const float xf = static_cast<float>(x);
    return static_cast<half_t>(__logf(xf));
}
using fp16_t = half_t;
} // namespace ck_tile

View File

@@ -0,0 +1,13 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <stdint.h>
namespace ck_tile {
// default index type for tensor coordinates, lengths and strides
using index_t = int32_t;
// wide index type for large linear offsets
using long_index_t = int64_t;
// re-export the fixed-width 8-bit integer into the ck_tile namespace
// (the right-hand side names the global-scope int8_t from <stdint.h>)
using int8_t = int8_t;
} // namespace ck_tile

View File

@@ -0,0 +1,82 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
namespace ck_tile {
// compile-time constant wrapper around an arbitrary non-type template value
// (a generalization of std::integral_constant using a C++17 'auto' NTTP)
template <auto v>
struct constant
{
    using value_type = decltype(v);
    using type = constant; // using injected-class-name
    static constexpr value_type value = v;
    // implicit conversion to the wrapped value
    constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; }
    // call syntax also yields the wrapped value
    constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; }
};
// std::integral_constant lookalike implemented on top of constant<v>;
// inherits the conversion/call operators from the base (see commented lines)
template <typename T, T v>
struct integral_constant : constant<v>
{
    using value_type = T;
    using type = integral_constant; // using injected-class-name
    static constexpr T value = v;
    // constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; }
    // constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; } //
};
template <index_t v>
using number = constant<v>;
template <long_index_t v>
using long_number = integral_constant<long_index_t, v>;
template <bool b>
using bool_constant = constant<b>;
#define CK_TILE_LEFT_UNARY_OP(OP) \
template <auto x> \
CK_TILE_HOST_DEVICE constexpr auto operator OP(constant<x>) \
{ \
return constant<(OP x)>{}; \
}
#define CK_TILE_BINARY_OP(OP) \
template <auto x, auto y> \
CK_TILE_HOST_DEVICE constexpr auto operator OP(constant<x>, constant<y>) \
{ \
return constant<(x OP y)>{}; \
}
CK_TILE_LEFT_UNARY_OP(+)
CK_TILE_LEFT_UNARY_OP(-)
CK_TILE_LEFT_UNARY_OP(~)
CK_TILE_LEFT_UNARY_OP(!)
CK_TILE_LEFT_UNARY_OP(*)
CK_TILE_BINARY_OP(+)
CK_TILE_BINARY_OP(-)
CK_TILE_BINARY_OP(*)
CK_TILE_BINARY_OP(/)
CK_TILE_BINARY_OP(%)
CK_TILE_BINARY_OP(&)
CK_TILE_BINARY_OP(|)
CK_TILE_BINARY_OP(^)
CK_TILE_BINARY_OP(<<)
CK_TILE_BINARY_OP(>>)
CK_TILE_BINARY_OP(&&)
CK_TILE_BINARY_OP(||)
CK_TILE_BINARY_OP(==)
CK_TILE_BINARY_OP(!=)
CK_TILE_BINARY_OP(>)
CK_TILE_BINARY_OP(<)
CK_TILE_BINARY_OP(>=)
CK_TILE_BINARY_OP(<=)
#undef CK_TILE_LEFT_UNARY_OP
#undef CK_TILE_BINARY_OP
} // namespace ck_tile

View File

@@ -0,0 +1,309 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include <type_traits>
#include <stdint.h>
namespace ck_tile {
// functor: multiply the argument by the compile-time scale factor s
template <typename T, T s>
struct scales
{
    CK_TILE_HOST_DEVICE constexpr T operator()(T a) const { return s * a; }
};
// functor: binary addition
template <typename T>
struct plus
{
    CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a + b; }
};
// functor: binary subtraction
template <typename T>
struct minus
{
    CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a - b; }
};
// functor: heterogeneous multiply; result type follows the operands' a * b
struct multiplies
{
    template <typename A, typename B>
    CK_TILE_HOST_DEVICE constexpr auto operator()(const A& a, const B& b) const
    {
        return a * b;
    }
};
// functor: pick the larger of two values (a wins ties)
template <typename T>
struct maximize
{
    CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a >= b ? a : b; }
};
// functor: pick the smaller of two values (a wins ties)
template <typename T>
struct minimize
{
    CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a <= b ? a : b; }
};
// functor: ceiling division, restricted to plain integer index types
template <typename T>
struct integer_divide_ceiler
{
    CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const
    {
        static_assert(std::is_same<T, index_t>{} || std::is_same<T, int>{}, "wrong type");
        return (a + b - number<1>{}) / b;
    }
};
// floor(x / y); relies on the operands' own division (works for number<> too)
template <typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto integer_divide_floor(X x, Y y)
{
    return x / y;
}
// ceil(x / y); number<1> keeps the expression in the compile-time domain when
// both operands are constant<> wrappers
template <typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
{
    return (x + y - number<1>{}) / y;
}
// smallest multiple of y that is >= x
template <typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto integer_least_multiple(X x, Y y)
{
    return y * integer_divide_ceil(x, y);
}
// identity overload: max of a single value
template <typename T>
CK_TILE_HOST_DEVICE constexpr T max(T x)
{
    return x;
}
// homogeneous two-argument max
template <typename T>
CK_TILE_HOST_DEVICE constexpr T max(T x, T y)
{
    return x > y ? x : y;
}
// mixed compile-time/runtime overloads collapse the number<> operand to a
// runtime index_t (the result cannot stay compile-time anyway)
template <index_t X>
CK_TILE_HOST_DEVICE constexpr index_t max(number<X>, index_t y)
{
    return X > y ? X : y;
}
template <index_t Y>
CK_TILE_HOST_DEVICE constexpr index_t max(index_t x, number<Y>)
{
    return x > Y ? x : Y;
}
// variadic max: folds pairwise from the right
template <typename X, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto max(X x, Ys... ys)
{
    static_assert(sizeof...(Ys) > 0, "not enough argument");
    return max(x, max(ys...));
}
// identity overload: min of a single value
template <typename T>
CK_TILE_HOST_DEVICE constexpr T min(T x)
{
    return x;
}
// homogeneous two-argument min
template <typename T>
CK_TILE_HOST_DEVICE constexpr T min(T x, T y)
{
    return x < y ? x : y;
}
// mixed compile-time/runtime overloads, mirroring max above
template <index_t X>
CK_TILE_HOST_DEVICE constexpr index_t min(number<X>, index_t y)
{
    return X < y ? X : y;
}
template <index_t Y>
CK_TILE_HOST_DEVICE constexpr index_t min(index_t x, number<Y>)
{
    return x < Y ? x : Y;
}
// variadic min: folds pairwise from the right
template <typename X, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto min(X x, Ys... ys)
{
    static_assert(sizeof...(Ys) > 0, "not enough argument");
    return min(x, min(ys...));
}
// Clamp x into [lowerbound, upperbound]: raise x to at least lowerbound,
// then cap the result at upperbound.
template <typename T>
CK_TILE_HOST_DEVICE constexpr T clamp(const T& x, const T& lowerbound, const T& upperbound)
{
    const T raised = max(x, lowerbound);
    return min(raised, upperbound);
}
// greatest common divisor, aka highest common factor
// Iterative Euclidean algorithm on the magnitudes; behavior matches the
// recursive formulation: gcd(0, 0) == 0, gcd(0, n) == n, gcd(n, 0) == n,
// and negative inputs are treated by absolute value.
CK_TILE_HOST_DEVICE constexpr index_t gcd(index_t x, index_t y)
{
    index_t a = x < 0 ? -x : x;
    index_t b = y < 0 ? -y : y;
    while(b != 0)
    {
        const index_t r = a % b;

        a = b;
        b = r;
    }
    return a;
}
// compile-time gcd on number<> wrappers; folds to a number<> result
template <index_t X, index_t Y>
CK_TILE_HOST_DEVICE constexpr auto gcd(number<X>, number<Y>)
{
    constexpr auto r = gcd(X, Y);
    return number<r>{};
}
// variadic gcd: folds pairwise from the right
template <typename X,
          typename... Ys,
          typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto gcd(X x, Ys... ys)
{
    return gcd(x, gcd(ys...));
}
// least common multiple
template <typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Y y)
{
    return (x * y) / gcd(x, y);
}
// variadic lcm: folds pairwise from the right
template <typename X,
          typename... Ys,
          typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Ys... ys)
{
    return lcm(x, lcm(ys...));
}
// functor: equality comparison
template <typename T>
struct equal
{
    CK_TILE_HOST_DEVICE constexpr bool operator()(T x, T y) const { return x == y; }
};
// functor: strict less-than comparison
template <typename T>
struct less
{
    CK_TILE_HOST_DEVICE constexpr bool operator()(T x, T y) const { return x < y; }
};
// Round x up to the next power of two (x itself if already a power of two).
CK_TILE_HOST_DEVICE constexpr int32_t next_power_of_two(int32_t x)
{
    // TODO: x need to be 2 ~ 0x7fffffff. 0, 1, or larger than 0x7fffffff will compile fail
    // (__builtin_clz(0) is undefined, and the shift overflows for huge x)
    return 1 << (32 - __builtin_clz(x - 1));
}
// compile-time variant: value passed as a template parameter
template <index_t X>
CK_TILE_HOST_DEVICE constexpr auto next_power_of_two()
{
    constexpr index_t y = next_power_of_two(X);
    return number<y>{};
}
// compile-time variant: value passed as a number<> tag
template <index_t X>
CK_TILE_HOST_DEVICE constexpr auto next_power_of_two(number<X>)
{
    constexpr index_t y = next_power_of_two(X);
    return number<y>{};
}
// floor(log2(x)), i.e. the bit index of the highest set bit
CK_TILE_HOST_DEVICE constexpr int32_t integer_log2_floor(int32_t x)
{
    // TODO: x need to be 1 ~ 0x7fffffff
    // __builtin_clz will produce unexpected result if x is 0;
    return 31 - __builtin_clz(x);
}
// true iff x is exactly a power of two (same input constraints as above)
CK_TILE_HOST_DEVICE constexpr bool is_power_of_two_integer(int32_t x)
{
    // TODO: x need to be 1 ~ 0x7fffffff
    return x == (1 << integer_log2_floor(x));
}
// log2(e), used to fold exp(x) into exp2(x * log2e) on hardware with fast exp2
#ifndef C_LOG2E
#define C_LOG2E 1.44269504088896340736 // log2(e)
#endif
// trait providing log2(e) at the requested floating-point precision
template <typename T>
struct log2e;
template <>
struct log2e<double>
{
    static constexpr double value = C_LOG2E;
};
template <>
struct log2e<float>
{
    static constexpr float value = C_LOG2E;
};
// variable-template shorthand, defaulting to double precision
template <typename T = double>
inline constexpr T log2e_v = log2e<T>::value;
// math
// Absolute value of a float: clear the fp32 sign bit on the raw bits.
// NOTE(review): union type punning is UB in strict ISO C++ (it is a documented
// gcc/clang extension, which hipcc follows) -- confirm this portability level
// is intended; std::bit_cast would be the C++20 alternative.
CK_TILE_HOST_DEVICE
float abs(const float& x)
{
    union
    {
        float f32;
        uint32_t u32;
    } y;
    y.f32 = x;
    y.u32 = y.u32 & 0x7fffffff;
    return y.f32;
}
// NaN test on raw fp32 bits: NaN <=> exponent all ones (0x7F800000) with a
// non-zero mantissa, i.e. the magnitude bits compare greater than 0x7F800000.
// NOTE(review): reinterpret_cast through a reference also skirts strict
// aliasing; same portability caveat as abs above.
CK_TILE_HOST_DEVICE
bool isnan(const float& x)
{
    uint32_t xx = reinterpret_cast<const uint32_t&>(x);
    return (xx & 0x7fffffff) > 0x7F800000;
}
// Device-only fast-math wrappers; these map to AMDGCN / HIP intrinsics and
// trade a few ULPs of accuracy for speed.
CK_TILE_DEVICE
float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); };
CK_TILE_DEVICE
float exp(float x) { return __expf(x); };
CK_TILE_DEVICE
float exp2(float x) { return exp2f(x); };
CK_TILE_DEVICE
float log(float x) { return __logf(x); };
} // namespace ck_tile

View File

@@ -0,0 +1,45 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include <type_traits>
namespace ck_tile {
#if 0
// Convert X to Y, both X and Y are non-const data types.
template <typename Y,
          typename X,
          std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
{
    static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
    return static_cast<Y>(x);
}
// TODO: const version never called, we may never need
// Convert X to Y, either X or Y is a const data type.
template <typename Y,
          typename X,
          std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
{
    static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
    using NonConstY = std::remove_const_t<Y>;
    using NonConstX = std::remove_const_t<X>;
    return static_cast<Y>(type_convert<NonConstY, NonConstX>(x));
}
#else
// compatible way to call conversion operator and constructor of each custom data type
// (static_cast picks up either the converting constructor of Y or the
// conversion operator of X, so one template covers all ck_tile numeric types)
template <typename Y, typename X>
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
{
    static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
    return static_cast<Y>(x);
}
#endif
} // namespace ck_tile

View File

@@ -0,0 +1,304 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
namespace ck_tile {
// TODO: the whole content of this file should consider deprecated!
template <typename T_, index_t N_>
struct vector_type
{
static constexpr index_t N = N_;
using value_type = T_;
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
CK_HOST_DEVICE constexpr vector_type()
{
for(auto i = 0; i < N; i++)
data[i] = static_cast<value_type>(0);
}
CK_HOST_DEVICE constexpr vector_type(type v)
{
auto& r = reinterpret_cast<const array<value_type, N>&>(v);
for(auto i = 0; i < N; i++)
data[i] = r.get(i);
}
value_type data[N];
CK_HOST_DEVICE static constexpr auto size() { return N; }
CK_HOST_DEVICE auto& get() { return data; }
CK_HOST_DEVICE const auto& get() const { return data; }
CK_HOST_DEVICE auto& get(index_t i) { return data[i]; }
CK_HOST_DEVICE const auto& get(index_t i) const { return data[i]; }
template <index_t I>
CK_HOST_DEVICE auto& operator[](number<I>)
{
return data[I];
}
template <index_t I>
CK_HOST_DEVICE const auto& operator[](number<I>) const
{
return data[I];
}
template <index_t I>
CK_HOST_DEVICE auto& operator()(number<I>)
{
return data[I];
}
CK_HOST_DEVICE auto& at(index_t i) { return data[i]; }
CK_HOST_DEVICE const auto& at(index_t i) const { return data[i]; }
template <index_t I>
CK_HOST_DEVICE auto& at()
{
return data[I];
}
template <index_t I>
CK_HOST_DEVICE const auto& at() const
{
return data[I];
}
template <index_t I>
CK_HOST_DEVICE auto& at(number<I>)
{
return data[I];
}
template <index_t I>
CK_HOST_DEVICE const auto& at(number<I>) const
{
return data[I];
}
#define _VT_COMMON_AS() \
static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
template <typename Tx>
CK_HOST_DEVICE auto& get_as()
{
_VT_COMMON_AS();
return reinterpret_cast<array<Tx, vx>&>(data);
}
template <typename Tx>
CK_HOST_DEVICE const auto& get_as() const
{
_VT_COMMON_AS();
return reinterpret_cast<const array<Tx, vx>&>(data);
}
template <typename Tx>
CK_HOST_DEVICE auto& get_as(index_t i)
{
_VT_COMMON_AS();
return reinterpret_cast<array<Tx, vx>&>(data).get(i);
}
template <typename Tx>
CK_HOST_DEVICE const auto& get_as(index_t i) const
{
_VT_COMMON_AS();
return reinterpret_cast<const array<Tx, vx>&>(data).get(i);
}
#undef _VT_COMMON_AS
};
template <typename T, index_t N>
struct vector_type_maker
{
using type = vector_type<T, N>;
};
template <typename T, index_t N0, index_t N1>
struct vector_type_maker<T __attribute__((ext_vector_type(N1))), N0>
{
using type = vector_type<T, N0 * N1>;
};
template <typename T, index_t N0, index_t N1>
struct vector_type_maker<vector_type<T, N1>, N0>
{
using type = vector_type<T, N0 * N1>;
};
template <typename T, index_t N>
using vector_type_maker_t = typename vector_type_maker<T, N>::type;
template <typename T, index_t N>
CK_HOST_DEVICE constexpr auto make_vector_type(number<N>)
{
return typename vector_type_maker<T, N>::type{};
}
// scalar_type
// Maps a (possibly vector) type to its element type and lane count.
template <typename TV>
struct scalar_type;
// is_scalar_type
// true iff TV is a single-lane type
template <typename TV>
struct is_scalar_type
{
    static constexpr bool value = (scalar_type<remove_cvref_t<TV>>::vector_size == 1);
};
// has_same_scalar_type
// true iff X and Y share the same element type (ignoring lane counts)
template <typename X, typename Y>
using has_same_scalar_type = is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                     typename scalar_type<remove_cvref_t<Y>>::type>;
// specialization for raw ext_vector_type vectors
template <typename T, index_t N>
struct scalar_type<T __attribute__((ext_vector_type(N)))>
{
    using type                           = T;
    static constexpr index_t vector_size = N;
};
// specialization for the vector_type wrapper above
template <typename T, index_t N>
struct scalar_type<vector_type<T, N>>
{
    using type                           = T;
    static constexpr index_t vector_size = N;
};
//
// scalar (single-lane) specializations for every numeric type ck_tile supports
template <>
struct scalar_type<double>
{
    using type                           = double;
    static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<float>
{
    using type                           = float;
    static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<half_t>
{
    using type                           = half_t;
    static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<bhalf_t>
{
    using type                           = bhalf_t;
    static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<int64_t>
{
    using type                           = int64_t;
    static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<int32_t>
{
    using type                           = int32_t;
    static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<int8_t>
{
    using type                           = int8_t;
    static constexpr index_t vector_size = 1;
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
struct scalar_type<int4_t>
{
    using type                           = int4_t;
    static constexpr index_t vector_size = 1;
};
#endif
template <>
struct scalar_type<fp8_t>
{
    using type                           = fp8_t;
    static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<bf8_t>
{
    using type                           = bf8_t;
    static constexpr index_t vector_size = 1;
};
// below are some pre-defines of ext_vector_type
// (the suffix is the lane count; each alias is the raw compiler vector type,
// not the vector_type wrapper)
// fp64
using double2_t = typename vector_type<double, 2>::type;
using double4_t = typename vector_type<double, 4>::type;
// fp32
using float2_t  = typename vector_type<float, 2>::type;
using float4_t  = typename vector_type<float, 4>::type;
using float8_t  = typename vector_type<float, 8>::type;
using float16_t = typename vector_type<float, 16>::type;
using float32_t = typename vector_type<float, 32>::type;
using float64_t = typename vector_type<float, 64>::type;
// fp16
using half2_t  = typename vector_type<half_t, 2>::type;
using half4_t  = typename vector_type<half_t, 4>::type;
using half8_t  = typename vector_type<half_t, 8>::type;
using half16_t = typename vector_type<half_t, 16>::type;
using half32_t = typename vector_type<half_t, 32>::type;
using half64_t = typename vector_type<half_t, 64>::type;
// bfp16
using bhalf2_t  = typename vector_type<bhalf_t, 2>::type;
using bhalf4_t  = typename vector_type<bhalf_t, 4>::type;
using bhalf8_t  = typename vector_type<bhalf_t, 8>::type;
using bhalf16_t = typename vector_type<bhalf_t, 16>::type;
using bhalf32_t = typename vector_type<bhalf_t, 32>::type;
using bhalf64_t = typename vector_type<bhalf_t, 64>::type;
// i32
using int32x2_t  = typename vector_type<int32_t, 2>::type;
using int32x4_t  = typename vector_type<int32_t, 4>::type;
using int32x8_t  = typename vector_type<int32_t, 8>::type;
using int32x16_t = typename vector_type<int32_t, 16>::type;
using int32x32_t = typename vector_type<int32_t, 32>::type;
using int32x64_t = typename vector_type<int32_t, 64>::type;
// i8
using int8x2_t  = typename vector_type<int8_t, 2>::type;
using int8x4_t  = typename vector_type<int8_t, 4>::type;
using int8x8_t  = typename vector_type<int8_t, 8>::type;
using int8x16_t = typename vector_type<int8_t, 16>::type;
using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;
// f8
using fp8x2_t  = typename vector_type<fp8_t, 2>::type;
using fp8x4_t  = typename vector_type<fp8_t, 4>::type;
using fp8x8_t  = typename vector_type<fp8_t, 8>::type;
using fp8x16_t = typename vector_type<fp8_t, 16>::type;
using fp8x32_t = typename vector_type<fp8_t, 32>::type;
using fp8x64_t = typename vector_type<fp8_t, 64>::type;
// bf8
using bf8x2_t  = typename vector_type<bf8_t, 2>::type;
using bf8x4_t  = typename vector_type<bf8_t, 4>::type;
using bf8x8_t  = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
} // namespace ck_tile

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,78 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/null_tensor.hpp"
#include "ck_tile/core/tensor/null_tile_window.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// Load a distributed tile from the window's bottom tensor into registers;
// oob_conditional_check toggles per-element out-of-bounds guarding.
template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_,
                                                                         WindowLengths_,
                                                                         TileDistribution_,
                                                                         NumCoord>& tile_window,
                              bool_constant<oob_conditional_check> = {})
{
    return tile_window.load(bool_constant<oob_conditional_check>{});
}
// Raw variant: load directly into a caller-provided tile (no return value).
template <typename T,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
                                  const tile_window_with_static_distribution<BottomTensorView_,
                                                                             WindowLengths_,
                                                                             TileDistribution_,
                                                                             NumCoord>& tile_window,
                                  bool_constant<oob_conditional_check> = {})
{
    tile_window.load_raw(tile, bool_constant<oob_conditional_check>{});
}
// Issue an asynchronous (non-blocking) load from global memory into an LDS
// tile window; completion must be awaited with async_load_fence.
template <typename LdsTileWindow_,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord>
CK_TILE_DEVICE auto
async_load_tile_raw(LdsTileWindow_&& lds_tile,
                    const tile_window_with_static_distribution<BottomTensorView_,
                                                               WindowLengths_,
                                                               TileDistribution_,
                                                               NumCoord>& tile_window)
{
    return tile_window.async_load(lds_tile);
}
// Wait until at most `cnt` vector-memory operations remain outstanding
// (s_waitcnt vmcnt(cnt)); cnt == 0 waits for all pending async loads.
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
{
    asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
// Overloads for a null tile window (opt-out placeholder).
// Fix: these referenced NullTileWindow/NullTensor, but the types introduced by
// this commit are null_tile_window (null_tile_window.hpp) and null_tensor
// (null_tensor.hpp); use the actual snake_case names so the overloads resolve.
template <typename WindowLengths>
CK_TILE_DEVICE auto load_tile(const null_tile_window<WindowLengths>&)
{
    // nothing to load -- yield the empty tag type
    return null_tensor{};
}
// Raw load into a null tile is a no-op.
template <typename T, typename WindowLengths>
CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const null_tile_window<WindowLengths>&)
{
}
} // namespace ck_tile

View File

@@ -0,0 +1,12 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
// Empty tag type returned when "loading" from a null tile window; lets
// tile-level APIs accept an opted-out tensor parameter uniformly.
struct null_tensor
{
};
} // namespace ck_tile

View File

@@ -0,0 +1,87 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// placeholder type if we want to opt-out a tile window parameter
// placeholder type if we want to opt-out a tile window parameter
template <typename WindowLengths_>
struct null_tile_window
{
    using BottomTensorView = null_tensor_view;
    using WindowLengths    = remove_cvref_t<WindowLengths_>;
    using BottomTensorIndex = array<index_t, WindowLengths::size()>;
    CK_TILE_DEVICE constexpr null_tile_window() = default;
    CK_TILE_DEVICE constexpr null_tile_window(const WindowLengths& window_lengths)
        : window_lengths_{window_lengths}
    {
    }
    // accessors mirror the real tile-window interface so generic code compiles
    CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
    CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return null_tensor_view{}; }
    CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; }
    WindowLengths window_lengths_;
};
// utility to check if this is a Null Tile Window
namespace impl {
template <typename>
struct is_null_tile_window : public std::false_type
{
};
template <typename T>
struct is_null_tile_window<null_tile_window<T>> : public std::true_type
{
};
} // namespace impl
// value-level query: true iff the argument is a null_tile_window
template <typename T>
CK_TILE_DEVICE constexpr auto is_null_tile_window(const T&)
{
    return impl::is_null_tile_window<remove_cvref_t<T>>::value;
}
// construct a null tile window from compile-time-known lengths
template <typename WindowLengths>
CK_TILE_DEVICE constexpr auto make_null_tile_window(const WindowLengths& window_lengths)
{
    static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
                  "wrong! lengths should be static");
    return null_tile_window<remove_cvref_t<WindowLengths>>{window_lengths};
}
// make_tile_window overload for a null tensor view: origin and any extra
// arguments (e.g. a tile distribution) are ignored
template <typename WindowLengths, typename... Ts>
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view,
                                               const WindowLengths& window_lengths,
                                               const multi_index<WindowLengths::size()>& /*origin*/,
                                               Ts&&...)
{
    static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
                  "wrong! lengths should be static");
    return null_tile_window<remove_cvref_t<WindowLengths>>{window_lengths};
}
// moving a null tile window is a no-op
template <typename WindowLengths>
CK_TILE_DEVICE void
move_tile_window(null_tile_window<WindowLengths>&,
                 const typename null_tile_window<WindowLengths>::BottomTensorIndex&)
{
}
} // namespace ck_tile

View File

@@ -0,0 +1,171 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/statically_indexed_array.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
namespace ck_tile {
namespace detail {
// In-register shuffle between two distributed tensors that share R/H lengths
// but permute the Y dimensions: gathers input vectors along the output's
// vector dimension, transposes them, and scatters along the input's vector
// dimension. Both tensors must belong to the calling thread.
template <typename OutTensor, typename InTensor>
CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InTensor& in_tensor)
{
    constexpr auto I0 = number<0>{};
    using DataType = typename InTensor::DataType;
    constexpr auto y_in_desc = InTensor::get_tile_distribution().get_ys_to_d_descriptor();
    constexpr auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor();
    // y_dim_out_to_in
    // Build (rh_major, rh_minor) -> y-dim maps for both tensors, then compose
    // them to map each output Y dim to the matching input Y dim.
    constexpr auto get_rh_major_minor_to_y = [](auto dstr_tensor) {
        using DstrEncode = typename decltype(dstr_tensor.get_tile_distribution())::DstrEncode;
        map<array<index_t, 2>, index_t> rh_major_minor_to_y_;
        static_for<0, DstrEncode::NDimY, 1>{}([&](auto i) {
            constexpr index_t rh_major = DstrEncode::ys_to_rhs_major_[i];
            constexpr index_t rh_minor = DstrEncode::ys_to_rhs_minor_[i];
            rh_major_minor_to_y_({rh_major, rh_minor}) = i;
        });
        return rh_major_minor_to_y_;
    };
    constexpr auto rh_major_minor_to_y_in = get_rh_major_minor_to_y(InTensor{});
    constexpr auto rh_major_minor_to_y_out = get_rh_major_minor_to_y(OutTensor{});
    constexpr auto y_dim_out_to_in = [&] {
        map<index_t, index_t> y_dim_out_to_in_;
        for(const auto& [rh_major_minor, y_out] : rh_major_minor_to_y_out)
        {
            y_dim_out_to_in_(y_out) = rh_major_minor_to_y_in[rh_major_minor];
        }
        return y_dim_out_to_in_;
    }();
    //
    constexpr index_t NDimY = InTensor::get_tile_distribution().GetNumOfDimensionY();
    constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths());
    // input and output vector dim in the order of input Y dims
    constexpr index_t y_dim_vec_in = NDimY - 1;
    constexpr index_t y_dim_vec_out = y_dim_out_to_in[NDimY - 1];
    // vector lengths
    constexpr index_t vec_length_in = y_lengths[y_dim_vec_in];
    constexpr index_t vec_length_out = y_lengths[y_dim_vec_out];
    // # of vectors (the transpose swaps lengths and counts)
    constexpr index_t num_vec_in = vec_length_out;
    constexpr index_t num_vec_out = vec_length_in;
    using InVec = vector_type<DataType, vec_length_in>;
    using OutVec = vector_type<DataType, vec_length_out>;
    using InVecType = typename InVec::type;
    using OutVecType = typename OutVec::type;
    // SFC: walk the Y space touching the two vector dims as whole tiles
    constexpr auto scalars_per_access_arr = generate_array(
        [&](auto i) { return (i == y_dim_vec_in or i == y_dim_vec_out) ? y_lengths[i] : 1; },
        number<NDimY>{});
    constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
    using SFC_Y = space_filling_curve<decltype(y_lengths),
                                      typename arithmetic_sequence_gen<0, NDimY, 1>::type,
                                      decltype(scalars_per_access)>;
    constexpr index_t num_access = SFC_Y::get_num_of_access();
    static_assert(num_access > 0, "wrong! num_access should be larger than 0");
    // in/out vectors to be transposed
    statically_indexed_array<InVec, num_vec_in> in_vectors;
    statically_indexed_array<OutVec, num_vec_out> out_vectors;
    // loop over SFC and do transpose
    static_for<0, num_access, 1>{}([&](auto iAccess) {
        // data index [y0, y1, ...] in the order of input tensor
        constexpr auto idx_y_start = SFC_Y::get_index(iAccess);
        // get input vectors
        static_for<0, num_vec_in, 1>{}([&](auto i) {
            constexpr auto idx_y_in = generate_array(
                [&](auto ii) {
                    return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
                },
                number<NDimY>{});
            constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
            // NOTE(review): vector_type in this commit exposes get_as<>(), not
            // AsType<>(); AsType looks like a leftover from the old ck::
            // vector_type API -- confirm this member exists.
            in_vectors(i).template AsType<InVecType>()(I0) =
                in_tensor.get_thread_buffer().template get_as<InVecType>(number<in_offset>{});
        });
        // transpose
        transpose_vectors<DataType, num_vec_in, num_vec_out>{}(in_vectors, out_vectors);
        // set output vectors
        static_for<0, num_vec_out, 1>{}([&](auto i) {
            constexpr auto idx_y_out_tmp = generate_array(
                [&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; },
                number<NDimY>{});
            constexpr auto idx_y_out =
                container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
            constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
            // NOTE(review): `set_as` is not visible on the thread buffer types in
            // this commit, and the in-side offset above is not divided by
            // sizeof(OutVecType) while this one is -- verify the offset units.
            out_tensor.get_thread_buffer().template set_as<OutVecType>(
                number<out_offset / sizeof(OutVecType)>{},
                out_vectors[i].template AsType<OutVecType>()[I0]);
        });
    });
}
} // namespace detail
// Shuffle (permute Y dims of) `in` into `out`, converting element type first.
// Only the case where both distributions share R/H lengths and P->RH mappings
// is implemented; other layouts silently fall through (see else branch).
template <typename OutTensor, typename InTensor>
CK_TILE_DEVICE void shuffle_tile(OutTensor& out, const InTensor& in)
{
    using InDataType = typename InTensor::DataType;
    using OutDataType = typename OutTensor::DataType;
    using InDstrEncode = typename InTensor::StaticTileDistribution::DstrEncode;
    using OutDstrEncode = typename OutTensor::StaticTileDistribution::DstrEncode;
    // type convert
    const auto in_tmp = tile_elementwise_in(type_convert<OutDataType, InDataType>, in);
    // shuffle
    if constexpr(InDstrEncode::rs_lengths_ == OutDstrEncode::rs_lengths_ &&
                 InDstrEncode::hs_lengthss_ == OutDstrEncode::hs_lengthss_ &&
                 InDstrEncode::ps_to_rhss_major_ == OutDstrEncode::ps_to_rhss_major_ &&
                 InDstrEncode::ps_to_rhss_minor_ == OutDstrEncode::ps_to_rhss_minor_ &&
                 InDstrEncode::NDimY == OutDstrEncode::NDimY)
    {
        detail::shuffle_tile_impl_in_thread(out, in_tmp);
    }
    else
    {
        // NOT implemented
    }
}
} // namespace ck_tile

View File

@@ -0,0 +1,94 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
namespace tile_program {
// Build a new tile window covering the [slice_begins, slice_ends) sub-range of
// an existing static-length window. NOTE: This API will override the origin of
// the tile window!
// Fix: call get_bottom_tensor_view() -- ck_tile accessors are snake_case (cf.
// null_tile_window::get_bottom_tensor_view and the core/ README style note);
// GetBottomTensorView was a leftover CamelCase spelling.
template <typename BottomTensorView_,
          typename WindowLengths_,
          index_t... SliceBegins,
          index_t... SliceEnds>
CK_TILE_DEVICE constexpr auto
get_slice_tile(const tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile,
               sequence<SliceBegins...> slice_begins,
               sequence<SliceEnds...> slice_ends)
{
    using TileWindow = tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>;
    // one begin/end pair per window dimension
    static_assert(sizeof...(SliceBegins) == sizeof...(SliceEnds));
    static_assert(sizeof...(SliceBegins) == TileWindow::get_num_of_dimension());
    constexpr auto slice_lengths = slice_ends - slice_begins;
    return make_tile_window(tile.get_bottom_tensor_view(),
                            sequence_to_tuple_of_number(slice_lengths),
                            to_multi_index(slice_begins));
}
// Slice a static distributed tensor along its X dims: derive the sliced
// distribution plus the corresponding Y-space origin/lengths, then copy the
// matching portion of the thread-local data into a fresh tensor.
template <typename DataType_,
          typename StaticTileDistribution_,
          index_t... SliceBegins,
          index_t... SliceEnds>
CK_TILE_DEVICE constexpr auto
get_slice_tile(const static_distributed_tensor<DataType_, StaticTileDistribution_>& tile,
               sequence<SliceBegins...> slice_begins,
               sequence<SliceEnds...> slice_ends)
{
    using DataType = remove_cvref_t<DataType_>;
    using Distribution = remove_cvref_t<StaticTileDistribution_>;
    // (sliced distribution, y origins, y lengths) for the requested X range
    constexpr auto sliced_dstr_yidx_ylen =
        detail::slice_distribution_from_x(Distribution{}, slice_begins, slice_ends);
    constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template at<0>();
    constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>();
    constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>();
    auto sliced_tensor = make_static_distributed_tensor<DataType>(sliced_dstr);
    sliced_tensor.get_thread_buffer() =
        tile.get_y_sliced_thread_data(sliced_y_origins, sliced_y_lengths);
    return sliced_tensor;
}
// Write src_tile's thread data into the [slice_begins, slice_ends) X-range of
// dst_tile. The sliced-out distribution must equal the destination's.
// Fix: static_distributed_tensor defines set_y_sliced_thread_data (snake_case,
// see static_distributed_tensor.hpp); SetSlicedThreadData was a stale
// CamelCase call that does not exist on that type.
template <typename DstDataType_,
          typename DstStaticTileDistribution_,
          typename SrcDataType_,
          typename SrcStaticTileDistribution_,
          index_t... SliceBegins,
          index_t... SliceEnds>
CK_TILE_DEVICE constexpr auto
set_slice_tile(static_distributed_tensor<DstDataType_, DstStaticTileDistribution_>& dst_tile,
               const static_distributed_tensor<SrcDataType_, SrcStaticTileDistribution_>& src_tile,
               sequence<SliceBegins...> slice_begins,
               sequence<SliceEnds...> slice_ends)
{
    using DstDistribution = remove_cvref_t<DstStaticTileDistribution_>;
    // (sliced distribution, y origins, y lengths) for the requested X range
    constexpr auto sliced_dstr_yidx_ylen =
        detail::slice_distribution_from_x(DstDistribution{}, slice_begins, slice_ends);
    constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template at<0>();
    constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>();
    constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>();
    static_assert(is_same_v<decltype(sliced_dstr), DstDistribution>, "wrong!");
    dst_tile.set_y_sliced_thread_data(
        sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer());
}
} // namespace tile_program
} // namespace ck_tile

View File

@@ -0,0 +1,180 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// Tile data owned by one thread, distributed across threads according to
// StaticTileDistribution. Elements live in a compile-time sized per-thread
// array (thread_buf_); they are addressed either through Y indices (via
// ThreadTensorDesc) or through static tile-distributed indices
// (operator[] / operator()).
template <typename DataType_, typename StaticTileDistribution_>
struct static_distributed_tensor
{
    using DataType               = remove_cvref_t<DataType_>;
    using StaticTileDistribution = remove_cvref_t<StaticTileDistribution_>;

    // fixed typo in diagnostic: "compile tile" -> "compile time"
    static_assert(StaticTileDistribution::is_static(),
                  "wrong! StaticTileDistribution should be known at compile time");

    // maps per-thread Y indices to a linear offset into thread_buf_
    using ThreadTensorDesc =
        remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>;

    // number of elements this thread stores
    static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.get_element_space_size();

    CK_TILE_HOST_DEVICE static constexpr auto get_num_of_dimension()
    {
        return StaticTileDistribution::get_num_of_dimension_x();
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_lengths()
    {
        return StaticTileDistribution::get_lengths();
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_tile_distribution()
    {
        return StaticTileDistribution{};
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans()
    {
        return StaticTileDistribution::get_distributed_spans();
    }

    // fill every element of the thread buffer with x
    CK_TILE_HOST_DEVICE void initialize(const DataType& x) { thread_buf_.initialize(x); }

    CK_TILE_HOST_DEVICE constexpr const auto& get_thread_buffer() const { return thread_buf_; }

    CK_TILE_HOST_DEVICE constexpr auto& get_thread_buffer() { return thread_buf_; }

    CK_TILE_HOST_DEVICE static constexpr index_t get_thread_buffer_size()
    {
        return kThreadElementSpaceSize;
    }

    // Copy the Y-subrange [origins, origins + lengths) of the thread buffer
    // into a new packed array (packed over the slice lengths).
    template <index_t... YSliceOrigins, index_t... YSliceLengths>
    CK_TILE_HOST_DEVICE auto get_y_sliced_thread_data(sequence<YSliceOrigins...>,
                                                      sequence<YSliceLengths...>) const
    {
        static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
                          sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
                      "wrong!");

        constexpr auto sliced_thread_tensor_desc =
            make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...));

        array<DataType, sliced_thread_tensor_desc.get_element_space_size()> sliced_thread_data;

        static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
            // shift slice-local index into thread-buffer Y space
            constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};

            sliced_thread_data(number<sliced_thread_tensor_desc.calculate_offset(idx)>{}) =
                thread_buf_[number<ThreadTensorDesc{}.calculate_offset(idx_ys)>{}];
        });

        return sliced_thread_data;
    }

    // Inverse of get_y_sliced_thread_data: scatter a packed slice back into
    // the thread buffer at the given Y origins.
    template <index_t... YSliceOrigins, index_t... YSliceLengths, index_t NSlicedData>
    CK_TILE_HOST_DEVICE void
    set_y_sliced_thread_data(sequence<YSliceOrigins...>,
                             sequence<YSliceLengths...>,
                             const array<DataType, NSlicedData>& sliced_thread_data)
    {
        static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
                          sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
                      "wrong!");

        constexpr auto sliced_thread_tensor_desc =
            make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...));

        static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
            constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};

            thread_buf_(number<ThreadTensorDesc{}.calculate_offset(idx_ys)>{}) =
                sliced_thread_data[number<sliced_thread_tensor_desc.calculate_offset(idx)>{}];
        });
    }

    // read-only element access by (static) tile-distributed indices
    template <typename TileDistributedIndices>
    CK_TILE_HOST_DEVICE constexpr const DataType& operator[](TileDistributedIndices) const
    {
        static_assert(is_static_v<TileDistributedIndices>,
                      "wrong! Tile Distributed Indices should be static");

        constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
            TileDistributedIndices{});

        return thread_buf_[number<ThreadTensorDesc{}.calculate_offset(y_idx)>{}];
    }

    // mutable element access by (static) tile-distributed indices
    template <typename TileDistributedIndices>
    CK_TILE_HOST_DEVICE constexpr DataType& operator()(TileDistributedIndices)
    {
        static_assert(is_static_v<TileDistributedIndices>,
                      "wrong! Tile Distributed Indices should be static");

        constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
            TileDistributedIndices{});

        return thread_buf_(number<ThreadTensorDesc{}.calculate_offset(y_idx)>{});
    }

    // per-thread element storage
    array<DataType, kThreadElementSpaceSize> thread_buf_;
};
// Factory: create an (uninitialized) static_distributed_tensor of DataType
// for the given tile distribution. Only the distribution's type matters; the
// argument's value is ignored.
template <typename DataType, typename StaticTileDistribution>
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution&)
{
    using TensorType = static_distributed_tensor<remove_cvref_t<DataType>,
                                                 remove_cvref_t<StaticTileDistribution>>;
    return TensorType{};
}
// Recover the bottom X indices addressed by a tuple of tile_distributed_index<>
// for the calling thread.
template <typename StaticTileDistribution, typename DistributedIndices>
CK_TILE_HOST_DEVICE constexpr auto
get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution,
                                       DistributedIndices distributed_indices)
{
    // runtime P (partition) index of the calling thread
    const auto p_idx = detail::get_partition_index(tile_distribution);

    // compile-time Y indices encoded by the distributed indices
    constexpr auto ys =
        tile_distribution.get_y_indices_from_distributed_indices(distributed_indices);

    // map (P..., Y...) through the adaptor down to the bottom X indices
    const auto coord = make_tensor_adaptor_coordinate(
        tile_distribution.get_ps_ys_to_xs_adaptor(),
        container_concat(p_idx, to_array<ck_tile::index_t, ys.size()>(ys)));

    return coord.get_bottom_index();
}
// Conditionally fill a distributed tensor: every element whose X indices
// satisfy `predicate` is set to `value`; all other elements are untouched.
template <typename DataType, typename StaticTileDistribution, typename XIndicesPredicate>
CK_TILE_HOST_DEVICE void
set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_tensor,
            DataType value,
            XIndicesPredicate predicate)
{
    constexpr auto spans =
        static_distributed_tensor<DataType, StaticTileDistribution>::get_distributed_spans();

    // walk the two distributed spans of the tensor
    sweep_tile_span(spans[number<0>{}], [&](auto i0) {
        sweep_tile_span(spans[number<1>{}], [&](auto i1) {
            constexpr auto d_idx = make_tuple(i0, i1);

            const auto x_indices =
                get_x_indices_from_distributed_indices(StaticTileDistribution{}, d_idx);

            if(predicate(x_indices))
            {
                out_tensor(d_idx) = value;
            }
        });
    });
}
} // namespace ck_tile

View File

@@ -0,0 +1,93 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// Store a distributed tensor through a tile window that only carries static
// lengths: a distributed window is materialized on the fly from the tensor's
// own tile distribution, then the store is issued through it.
template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename DataType_>
CK_TILE_DEVICE void
store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
           const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
    using ViewDataType = remove_cvref_t<typename BottomTensorView_::DataType>;
    using Distribution = remove_cvref_t<TileDistribution_>;

    // tensor element type must match the destination view's element type
    static_assert(is_same_v<remove_cvref_t<DataType_>, ViewDataType>, "wrong!");

    auto distributed_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
                                               tile_window_tmp.get_window_lengths(),
                                               tile_window_tmp.get_window_origin(),
                                               Distribution{});

    distributed_window.store(dstr_tensor);
}
// Raw variant of store_tile for a static-lengths window: builds a distributed
// window from the tensor's distribution and issues a raw store through it.
template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename DataType_>
CK_TILE_DEVICE void
store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
               const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
    using ViewDataType = remove_cvref_t<typename BottomTensorView_::DataType>;
    using Distribution = remove_cvref_t<TileDistribution_>;

    // tensor element type must match the destination view's element type
    static_assert(is_same_v<remove_cvref_t<DataType_>, ViewDataType>, "wrong!");

    auto distributed_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
                                               tile_window_tmp.get_window_lengths(),
                                               tile_window_tmp.get_window_origin(),
                                               Distribution{});

    distributed_window.store_raw(dstr_tensor);
}
// Store a distributed tensor through a tile window that already carries a
// matching tile distribution; simply forwards to the window's store().
template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          typename DataType_>
CK_TILE_DEVICE void
store_tile(tile_window_with_static_distribution<BottomTensorView_,
                                                WindowLengths_,
                                                TileDistribution_,
                                                NumCoord>& tile_window,
           const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
    tile_window.store(dstr_tensor);
}
// Raw-store a distributed tensor through a tile window that already carries a
// matching tile distribution; simply forwards to the window's store_raw().
template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          typename DataType_>
CK_TILE_DEVICE void
store_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
                                                    WindowLengths_,
                                                    TileDistribution_,
                                                    NumCoord>& tile_window,
               const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
    tile_window.store_raw(dstr_tensor);
}
} // namespace ck_tile

View File

@@ -0,0 +1,30 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// sweep over a span of a distributed tile and apply lambda function F
template <typename TileDistributedSpan_, // tile_distributed_span<...>
          typename F                     // signature: F(tile_distributed_index<...>)
          >
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F& f)
{
    using Span = remove_cvref_t<TileDistributedSpan_>;

    // statically unroll over every index of the span's Impl sequence and hand
    // the corresponding tile_distributed_index to the functor
    static_ford<typename Span::Impl>{}([&](auto impl_idx) {
        f(detail::make_tile_distributed_index(impl_idx));
    });
}
} // namespace ck_tile

View File

@@ -0,0 +1,942 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// Transforms: Tuple<transforms...>
// LowerDimensionHiddenIdss : Tuple<Sequence<...>, ...>
// UpperDimensionHiddenIdss : Tuple<Sequence<...>, ...>
// BottomDimensionHiddenIds : Sequence<...>
// TopDimensionHiddenIds : Sequence<...>
// An index-mapping from "top" dimensions down to "bottom" dimensions through a
// chain of coordinate transforms. All dimensions (bottom, intermediate, top)
// live in one shared "hidden" id space; each transform connects a set of lower
// hidden ids to a set of upper hidden ids.
template <typename Transforms,
          typename LowerDimensionHiddenIdss,
          typename UpperDimensionHiddenIdss,
          typename BottomDimensionHiddenIds,
          typename TopDimensionHiddenIds>
struct tensor_adaptor
{
    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_transform()
    {
        return Transforms::size();
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_transforms() const { return transforms_; }

    CK_TILE_HOST_DEVICE static constexpr auto get_lower_dimension_hidden_idss()
    {
        return LowerDimensionHiddenIdss{};
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_upper_dimension_hidden_idss()
    {
        return UpperDimensionHiddenIdss{};
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_bottom_dimension_hidden_ids()
    {
        return BottomDimensionHiddenIds{};
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_hidden_ids()
    {
        return TopDimensionHiddenIds{};
    }

    // Product of all top-dimension lengths; each length is looked up from the
    // transform that produces the corresponding hidden dimension.
    CK_TILE_HOST_DEVICE static constexpr auto initialize_element_size(const Transforms& transforms)
    {
        const auto lengths = generate_tuple(
            [&](auto idim_top) {
                constexpr index_t idim_hidden = TopDimensionHiddenIds::at(idim_top);

                constexpr auto tmp = get_transform_and_its_upper_dimension(number<idim_hidden>{});

                constexpr index_t itran   = tmp[number<0>{}];
                constexpr index_t idim_up = tmp[number<1>{}];
                constexpr bool found      = tmp[number<2>{}];

                static_assert(found == true,
                              "wrong! not found matching transformation and upper-dimension");

                const auto length =
                    transforms[number<itran>{}].get_upper_lengths()[number<idim_up>{}];

                return length;
            },
            number<ndim_top_>{});

        // TODO: make container_reduce support tuple of number and index_t
        return container_reduce(lengths, multiplies{}, number<1>{});
    }

    // Find which transform produces hidden dimension IDimHidden as one of its
    // upper dimensions. Returns (transform index, upper-dim index, found flag).
    template <index_t IDimHidden>
    CK_TILE_HOST_DEVICE static constexpr auto
    get_transform_and_its_upper_dimension(number<IDimHidden>)
    {
        // FIXME: length of bottom dimension is not known, since info about lower dim length are
        // not saved in transformation
        static_assert(IDimHidden >= ndim_bottom_, "wrong! not implemented");

        index_t itran_found   = 0;
        index_t idim_up_found = 0;
        bool found            = false;

        static_for<0, ntransform_, 1>{}([&](auto itran) {
            constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran];

            static_for<0, up_dim_ids.size(), 1>{}([&](auto idim_up) {
                if constexpr(up_dim_ids[idim_up] == IDimHidden)
                {
                    itran_found   = itran;
                    idim_up_found = idim_up;
                    found         = true;
                }
            });
        });

        return make_tuple(itran_found, idim_up_found, found);
    }

    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_bottom_dimension()
    {
        return BottomDimensionHiddenIds::size();
    }

    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_top_dimension()
    {
        return TopDimensionHiddenIds::size();
    }

    // Count of distinct hidden dimension ids referenced by any transform
    // (union of all lower and upper id sequences, deduplicated).
    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_hidden_dimension()
    {
        constexpr auto all_low_dim_ids = unpack(
            [](auto&&... xs) constexpr { return merge_sequences(xs...); },
            LowerDimensionHiddenIdss{});

        constexpr auto all_up_dim_ids = unpack(
            [](auto&&... xs) constexpr { return merge_sequences(xs...); },
            UpperDimensionHiddenIdss{});

        constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);

        using unique_sort_all_dim_ids = typename sequence_unique_sort<decltype(all_dim_ids),
                                                                      less<index_t>,
                                                                      equal<index_t>>::type;

        return unique_sort_all_dim_ids::size();
    }

    constexpr static index_t ntransform_  = get_num_of_transform();
    constexpr static index_t ndim_hidden_ = get_num_of_hidden_dimension();
    constexpr static index_t ndim_bottom_ = get_num_of_bottom_dimension();
    constexpr static index_t ndim_top_    = get_num_of_top_dimension();

    using HiddenIndex = multi_index<ndim_hidden_>;
    using BottomIndex = multi_index<ndim_bottom_>;
    using TopIndex    = multi_index<ndim_top_>;

    // may be index_t or number<>
    using ElementSize = remove_cv_t<decltype(initialize_element_size(Transforms{}))>;

    public:
    CK_TILE_HOST_DEVICE constexpr tensor_adaptor() = default;

    CK_TILE_HOST_DEVICE constexpr tensor_adaptor(const Transforms& transforms)
        : transforms_{transforms}, element_size_{initialize_element_size(transforms)}
    {
        static_assert(Transforms::size() == ntransform_ &&
                          LowerDimensionHiddenIdss::size() == ntransform_ &&
                          UpperDimensionHiddenIdss::size() == ntransform_,
                      "wrong! inconsistent # of transformations");

        // TODO check dependency of dimensions is valid
    }

    CK_TILE_HOST_DEVICE constexpr auto get_element_size() const { return element_size_; }

    // FIXME: this logic is wrong when getting bottom dimension lengths
    template <index_t IDimHidden>
    CK_TILE_HOST_DEVICE constexpr auto get_hidden_dimension_length(number<IDimHidden>) const
    {
        static_assert(IDimHidden >= 0 && IDimHidden < ndim_hidden_, "wrong! out of range");

        constexpr auto tmp = get_transform_and_its_upper_dimension(number<IDimHidden>{});

        constexpr index_t itran   = tmp[number<0>{}];
        constexpr index_t idim_up = tmp[number<1>{}];
        constexpr bool found      = tmp[number<2>{}];

        static_assert(found == true,
                      "wrong! not found matching transformation and upper-dimension");

        return transforms_[number<itran>{}].get_upper_lengths()[number<idim_up>{}];
    }

    template <index_t IDimTop>
    CK_TILE_HOST_DEVICE constexpr auto get_top_dimension_length(number<IDimTop> idim_top) const
    {
        return get_hidden_dimension_length(TopDimensionHiddenIds::at(idim_top));
    }

#if 0
    // FIXME: get_hidden_dimension_length is wrong when getting bottom dimension lengths
    // NOTE(review): disabled code indexes TopDimensionHiddenIds with a bottom
    // index — presumably should be BottomDimensionHiddenIds; verify before enabling.
    template <index_t IDimBottom>
    CK_TILE_HOST_DEVICE constexpr index_t
    get_bottom_dimension_length(number<IDimBottom> idim_bottom) const
    {
        return get_hidden_dimension_length(TopDimensionHiddenIds::at(idim_bottom));
    }
#endif

    CK_TILE_HOST_DEVICE constexpr auto get_top_dimension_lengths() const
    {
        return generate_tuple([&](auto i) { return get_top_dimension_length(i); },
                              number<ndim_top_>{});
    }

#if 0
    // FIXME: get_hidden_dimension_length is wrong when getting bottom dimension lengths
    CK_TILE_HOST_DEVICE constexpr auto GetBottomDimensionLengths() const
    {
        return generate_tuple([&](auto i) { return get_bottom_dimension_length(i); },
                              number<ndim_bottom_>{});
    }
#endif

    // Map a top index to its bottom index by running every transform's
    // calculate_lower_index, from the last transform down to the first.
    template <typename TopIdx>
    CK_TILE_HOST_DEVICE constexpr auto calculate_bottom_index(const TopIdx& idx_top) const
    {
        static_assert(TopIdx::size() == TopDimensionHiddenIds::size(),
                      "wrong! # of dimension inconsistent");

        constexpr index_t ntransform  = get_num_of_transform();
        constexpr index_t ndim_hidden = get_num_of_hidden_dimension();

        multi_index<ndim_hidden> idx_hidden;

        // initialize uppest index
        set_container_subset(idx_hidden, get_top_dimension_hidden_ids(), idx_top);

        // calculate hidden index
        static_for<ntransform, 0, -1>{}([&](auto itran_p1) {
            auto itran              = itran_p1 - number<1>{};
            const auto& tran        = get_transforms().at(itran);
            constexpr auto dims_low = get_lower_dimension_hidden_idss().at(itran);
            constexpr auto dims_up  = get_upper_dimension_hidden_idss().at(itran);

            const auto idx_up = get_container_subset(idx_hidden, dims_up);

            multi_index<dims_low.size()> idx_low;

            tran.calculate_lower_index(idx_low, idx_up);

            set_container_subset(idx_hidden, dims_low, idx_low);
        });

        return get_container_subset(idx_hidden, BottomDimensionHiddenIds{});
    }

    // true iff every transform and the element size are compile-time known
    CK_TILE_HOST_DEVICE static constexpr bool is_static()
    {
        bool is_known = true;

        static_for<0, Transforms::size(), 1>{}([&](auto i) {
            is_known &= remove_cvref_t<decltype(Transforms{}[i])>::is_known_at_compile_time();
        });

        return is_known && ck_tile::is_known_at_compile_time<ElementSize>::value;
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); }

    // Propagate guaranteed vector length/stride info from bottom to top
    // through every transform; a guaranteed value (!= -1) overrides what the
    // transform computes.
    CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides(
        const array<index_t, ndim_hidden_>& guaranteed_vector_lengths,
        const array<index_t, ndim_hidden_>& guaranteed_vector_strides)
    {
        auto vector_lengths = guaranteed_vector_lengths;
        auto vector_strides = guaranteed_vector_strides;

        static_for<0, get_num_of_transform(), 1>{}([&](auto itran) {
            constexpr auto low_dims = get_lower_dimension_hidden_idss().at(itran);
            constexpr auto up_dims  = get_upper_dimension_hidden_idss().at(itran);

            const auto up_guaranteed_vector_lengths =
                get_container_subset(guaranteed_vector_lengths, up_dims);
            const auto up_guaranteed_vector_strides =
                get_container_subset(guaranteed_vector_strides, up_dims);

            // only need type of transform
            auto [up_vector_lengths, up_vector_strides] =
                Transforms{}.at(itran).calculate_upper_dimension_safe_vector_length_strides(
                    get_container_subset(vector_lengths, low_dims),
                    get_container_subset(vector_strides, low_dims));

            if constexpr(up_dims.size() > 0)
            {
                for(index_t i = 0; i < up_dims.size(); ++i)
                {
                    up_vector_lengths(i) = (up_guaranteed_vector_lengths[i] != -1)
                                               ? up_guaranteed_vector_lengths[i]
                                               : up_vector_lengths[i];

                    up_vector_strides(i) = (up_guaranteed_vector_strides[i] != -1)
                                               ? up_guaranteed_vector_strides[i]
                                               : up_vector_strides[i];
                }
            }

            set_container_subset(vector_lengths, up_dims, up_vector_lengths);
            set_container_subset(vector_strides, up_dims, up_vector_strides);
        });

        constexpr auto top_dims = TopDimensionHiddenIds{};

        return make_tuple(get_container_subset(vector_lengths, top_dims),
                          get_container_subset(vector_strides, top_dims));
    }

    // NOTE(review): the unqualified calls `print(...)` inside a member also
    // named `print` resolve to this zero-argument member during name lookup —
    // confirm this instantiates; a qualified ck_tile::print may be intended.
    CK_TILE_HOST_DEVICE void print() const
    {
        printf("tensor_adaptor{");

        //
        printf("transforms: ");
        print(transforms_);
        printf(", ");

        //
        printf("LowerDimensionHiddenIds: ");
        print(LowerDimensionHiddenIdss{});
        printf(", ");

        //
        printf("UpperDimensionHiddenIds: ");
        print(UpperDimensionHiddenIdss{});
        printf(", ");

        //
        printf("BottomDimensionHiddenIds: ");
        print(BottomDimensionHiddenIds{});
        printf(", ");

        //
        printf("TopDimensionHiddenIds: ");
        print(TopDimensionHiddenIds{});

        printf("}");
    }

    private:
    Transforms transforms_;
    ElementSize element_size_;
};
// Transforms: Tuple<transforms...>
// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>
// Build a tensor_adaptor from a single stage of transforms. Bottom (old-top)
// dimensions get hidden ids [0, ndim_old_top); top (new-top) dimensions get
// hidden ids shifted up by ndim_old_top so every hidden id is unique.
template <typename Transforms, typename LowerDimensionOldTopIdss, typename UpperDimensionNewTopIdss>
CK_TILE_HOST_DEVICE constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms,
                                                                    LowerDimensionOldTopIdss,
                                                                    UpperDimensionNewTopIdss)
{
    constexpr index_t ntransform = Transforms::size();

    static_assert(LowerDimensionOldTopIdss::size() == ntransform &&
                      UpperDimensionNewTopIdss::size() == ntransform,
                  "wrong!");

    // sanity check on LowerDimensionOldTopIdss and UpperDimensionNewTopIdss:
    // the merged id sequences must each be a valid permutation-style map
    constexpr auto all_low_dim_old_top_ids = unpack(
        [](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionOldTopIdss{});

    constexpr auto all_up_dim_new_top_ids = unpack(
        [](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionNewTopIdss{});

    static_assert(is_valid_sequence_map<decltype(all_low_dim_old_top_ids)>::value &&
                      is_valid_sequence_map<decltype(all_up_dim_new_top_ids)>::value,
                  "wrong!");

    constexpr index_t ndim_old_top = all_low_dim_old_top_ids.size();
    constexpr index_t ndim_new_top = all_up_dim_new_top_ids.size();

    // low_dim_hidden_idss: old-top ids used as-is
    constexpr auto low_dim_hidden_idss = LowerDimensionOldTopIdss{};

    // up_dim_hidden_idss: shift UpperDimensionNewTopIdss by ndim_bottom
    constexpr auto up_dim_hidden_idss = generate_tuple(
        [](auto itran) { return UpperDimensionNewTopIdss{}[itran] + number<ndim_old_top>{}; },
        number<ntransform>{});

    // bottom_dim_hidden_ids: 0, 1, ..., ndim_old_top-1
    constexpr auto bottom_dim_hidden_ids =
        typename arithmetic_sequence_gen<0, ndim_old_top, 1>::type{};

    // top_dim_hidden_ids: ndim_old_top, ..., ndim_old_top+ndim_new_top-1
    constexpr auto top_dim_hidden_ids =
        typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + number<ndim_old_top>{};

    return tensor_adaptor<remove_cvref_t<Transforms>,
                          remove_cvref_t<decltype(low_dim_hidden_idss)>,
                          remove_cvref_t<decltype(up_dim_hidden_idss)>,
                          remove_cvref_t<decltype(bottom_dim_hidden_ids)>,
                          remove_cvref_t<decltype(top_dim_hidden_ids)>>{transforms};
}
// TODO: How to fix this? It uses a struct instead of a lambda because lambdas
// don't have constructors, and it is placed outside the scope where it is used
// (transform_tensor_adaptor) because a template cannot be defined inside a
// function template.
// Functor: given an index I (as a type), return — wrapped in number<> — how
// many upper dimensions the I-th transform of NewTransforms has.
template <typename NewTransforms>
struct lambda_get_up_dim_num
{
    template <typename I>
    CK_TILE_HOST_DEVICE constexpr auto operator()(I) const
    {
        using TransformType = remove_reference_t<decltype(NewTransforms{}.at(I{}))>;

        constexpr index_t ndim_up = TransformType::get_num_of_upper_dimension();
        return number<ndim_up>{};
    }
};
// Stack NewTransforms on top of an existing adaptor: the new transforms' lower
// dims are expressed in terms of the old adaptor's top dims, and the result's
// top dims are the new transforms' upper dims, reordered per
// NewUpperDimensionNewTopIdss.
template <typename OldTensorAdaptor,
          typename NewTransforms,
          typename NewLowerDimensionOldTopIdss,
          typename NewUpperDimensionNewTopIdss>
CK_TILE_HOST_DEVICE constexpr auto
transform_tensor_adaptor(const OldTensorAdaptor& old_tensor_adaptor,
                         const NewTransforms& new_transforms,
                         NewLowerDimensionOldTopIdss,
                         NewUpperDimensionNewTopIdss)
{
    // sanity check
    {
        static_assert(NewTransforms::size() == NewLowerDimensionOldTopIdss::size() &&
                          NewTransforms::size() == NewUpperDimensionNewTopIdss::size(),
                      "wrong! inconsitent number of transform");

        constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
                                                NewLowerDimensionOldTopIdss{});

        constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
                                                NewUpperDimensionNewTopIdss{});

        static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
                          is_valid_sequence_map<decltype(all_new_top_ids)>::value,
                      "wrong!");
    }

    // lower dimension's hidden idss
    // convert lower dimension top idss (tuple of sequences) to hidden idss (tuple of
    // sequences)
    constexpr auto low_dim_hidden_idss = transform_tuples(
        // convert lower dimension top ids (a sequence) to hidden ids (a sequence)
        [](auto low_dim_top_ids) constexpr {
            return transform_sequences(
                // convert lower dimension top id to hidden id
                [](auto low_dim_top_id) constexpr {
                    return OldTensorAdaptor::get_top_dimension_hidden_ids()[low_dim_top_id];
                },
                low_dim_top_ids);
        },
        NewLowerDimensionOldTopIdss{});

    constexpr index_t num_new_transform = NewTransforms::size();

    // upper dimension's hidden idss: fresh ids starting after the old
    // adaptor's hidden id range, handed out per transform via a prefix scan
    constexpr index_t old_hidden_dim_number = OldTensorAdaptor::get_num_of_hidden_dimension();

    constexpr auto up_dim_numbers =
        generate_sequence(lambda_get_up_dim_num<NewTransforms>{}, number<num_new_transform>{});

    // NOTE(review): capital-S `Sequence` looks like a leftover from old CK
    // naming (ck_tile uses lowercase `sequence`) — confirm this compiles.
    constexpr auto up_dim_numbers_scan = merge_sequences(
        Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, plus<index_t>{}, number<0>{}));

    constexpr auto up_dim_hidden_idss = generate_tuple(
        [ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr {
            return
                typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
                                                 old_hidden_dim_number + up_dim_numbers_scan[i + 1],
                                                 1>::type{};
        },
        number<num_new_transform>{});

    // new top dimension's hidden ids: flatten, then reorder into the order
    // requested by NewUpperDimensionNewTopIdss
    constexpr auto unordered_new_top_dim_hidden_ids = unpack(
        [](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);

    constexpr auto new_top_dim_unordered2ordered = unpack(
        [](auto... xs) constexpr { return merge_sequences(xs...); }, NewUpperDimensionNewTopIdss{});

    constexpr auto new_top_dim_hidden_ids =
        unordered_new_top_dim_hidden_ids.reorder_old_to_new(new_top_dim_unordered2ordered);

    // put everything together
    const auto all_transforms =
        container_concat(old_tensor_adaptor.get_transforms(), new_transforms);

    constexpr auto all_low_dim_hidden_idss =
        container_concat(OldTensorAdaptor::get_lower_dimension_hidden_idss(), low_dim_hidden_idss);

    constexpr auto all_up_dim_hidden_idss =
        container_concat(OldTensorAdaptor::get_upper_dimension_hidden_idss(), up_dim_hidden_idss);

    return tensor_adaptor<
        remove_cvref_t<decltype(all_transforms)>,
        remove_cvref_t<decltype(all_low_dim_hidden_idss)>,
        remove_cvref_t<decltype(all_up_dim_hidden_idss)>,
        remove_cvref_t<decltype(OldTensorAdaptor::get_bottom_dimension_hidden_ids())>,
        remove_cvref_t<decltype(new_top_dim_hidden_ids)>>{all_transforms};
}
// Compose two adaptors: adaptor0's top dimensions feed adaptor1's bottom
// dimensions, producing one adaptor from adaptor0's bottom to adaptor1's top.
// adaptor1's hidden ids are shifted so they cannot collide with adaptor0's,
// and its bottom ids are rewritten to adaptor0's top ids.
template <typename TensorAdaptor0, typename TensorAdaptor1>
CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& adaptor0,
                                                         const TensorAdaptor1& adaptor1)
{
    static_assert(TensorAdaptor0::get_num_of_top_dimension() ==
                      TensorAdaptor1::get_num_of_bottom_dimension(),
                  "wrong!");

    // all_transforms = transform0 + transform1
    const auto all_transforms =
        container_concat(adaptor0.get_transforms(), adaptor1.get_transforms());

    // shift
    // NOTE(review): `NumericLimits<index_t>::Min()/Max()` looks like old-CK
    // naming — confirm it exists in the ck_tile namespace.
    constexpr index_t adaptor0_max_hidden_id = [&]() {
        index_t adaptor0_max_hidden_id_ = NumericLimits<index_t>::Min();

        static_for<0, TensorAdaptor0::get_num_of_transform(), 1>{}([&](auto itran) {
            constexpr index_t ndim_low =
                TensorAdaptor0{}.get_transforms()[itran].get_num_of_lower_dimension();

            static_for<0, ndim_low, 1>{}([&](auto idim_low) {
                adaptor0_max_hidden_id_ =
                    max(adaptor0_max_hidden_id_,
                        TensorAdaptor0::get_lower_dimension_hidden_idss()[itran][idim_low].value);
            });

            constexpr index_t ndim_up =
                TensorAdaptor0{}.get_transforms()[itran].get_num_of_upper_dimension();

            static_for<0, ndim_up, 1>{}([&](auto idim_up) {
                adaptor0_max_hidden_id_ =
                    max(adaptor0_max_hidden_id_,
                        TensorAdaptor0::get_upper_dimension_hidden_idss()[itran][idim_up].value);
            });
        });

        return adaptor0_max_hidden_id_;
    }();

    constexpr index_t adaptor1_min_hidden_id = [&]() {
        index_t adaptor1_min_hidden_id_ = NumericLimits<index_t>::Max();

        static_for<0, TensorAdaptor1::get_num_of_transform(), 1>{}([&](auto itran) {
            constexpr index_t ndim_low =
                TensorAdaptor1{}.get_transforms()[itran].get_num_of_lower_dimension();

            // get the min of all lower dimensions, but not bottom dimension (because their id
            // will be matched with top id from adaptor0)
            static_for<0, ndim_low, 1>{}([&](auto idim_low) {
                constexpr index_t low_dim_hidden_id =
                    TensorAdaptor1::get_lower_dimension_hidden_idss()[itran][idim_low].value;

                bool is_bottom_dim = false;
                static_for<0, TensorAdaptor1::get_num_of_bottom_dimension(), 1>{}([&](auto i) {
                    if constexpr(low_dim_hidden_id ==
                                 TensorAdaptor1::get_bottom_dimension_hidden_ids()[i])
                    {
                        is_bottom_dim = true;
                    }
                });

                if(!is_bottom_dim)
                {
                    adaptor1_min_hidden_id_ = min(adaptor1_min_hidden_id_, low_dim_hidden_id);
                }
            });

            constexpr index_t ndim_up =
                TensorAdaptor1{}.get_transforms()[itran].get_num_of_upper_dimension();

            // get the min of all upper dimensions
            static_for<0, ndim_up, 1>{}([&](auto idim_up) {
                adaptor1_min_hidden_id_ =
                    min(adaptor1_min_hidden_id_,
                        TensorAdaptor1::get_upper_dimension_hidden_idss()[itran][idim_up].value);
            });
        });

        return adaptor1_min_hidden_id_;
    }();

    // adding this shift to every adaptor1 hidden id places it just above
    // adaptor0's largest hidden id
    constexpr index_t adaptor1_hidden_id_shift =
        adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id;

    constexpr index_t ndim_bottom_1 = TensorAdaptor1::get_num_of_bottom_dimension();

    // all_low_dim_hidden_idss =
    // low_dim_hidden_idss_0 + match_hidden_id_for_1(shift_hidden_id_for_1(low_dim_hidden_idss_1))
    constexpr auto low_dim_hidden_idss_1 = generate_tuple(
        // generate sequence of ids for a transform
        [&](auto itran) {
            constexpr auto ndim_low_1 =
                TensorAdaptor1::get_lower_dimension_hidden_idss()[itran].size();

            constexpr auto low_dim_hidden_ids_1 =
                TensorAdaptor1::get_lower_dimension_hidden_idss()[itran];

            // sequence in, sequence out
            constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr
            {
                auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1);

                // shift hidden id so every dim id is unique
                static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
                    low_dim_hidden_ids_1_mod_(idim_low_1) += adaptor1_hidden_id_shift;
                });

                // match hidden id
                static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
                    static_for<0, ndim_bottom_1, 1>{}([&](auto idim_bottom_1) {
                        // if this low dim is bottom dim, then do id matching
                        if constexpr(low_dim_hidden_ids_1[idim_low_1] ==
                                     TensorAdaptor1::get_bottom_dimension_hidden_ids()
                                         [idim_bottom_1])
                        {
                            low_dim_hidden_ids_1_mod_(idim_low_1) =
                                TensorAdaptor0::get_top_dimension_hidden_ids()[idim_bottom_1];
                        }
                    });
                });

                return low_dim_hidden_ids_1_mod_;
            }
            ();

            return generate_sequence_v2(
                [&](auto i) constexpr { return number<low_dim_hidden_ids_1_mod[i]>{}; },
                number<ndim_low_1>{});
        },
        number<TensorAdaptor1::get_num_of_transform()>{});

    constexpr auto all_low_dim_hidden_idss =
        container_concat(TensorAdaptor0::get_lower_dimension_hidden_idss(), low_dim_hidden_idss_1);

    // all_up_dim_hidden_idss =
    // up_dim_hidden_idss_0 + shift_hidden_id_for_1(up_dim_hidden_idss_1)
    constexpr auto up_dim_hidden_idss_1 = generate_tuple(
        // generate sequence of ids for a transform
        [&](auto itran) {
            constexpr auto ndim_up_1 =
                TensorAdaptor1::get_upper_dimension_hidden_idss()[itran].size();

            constexpr auto up_dim_hidden_ids_1 =
                TensorAdaptor1::get_upper_dimension_hidden_idss()[itran];

            // sequence in, constexpr tuple out
            constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr
            {
                auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1);

                // shift hidden id
                static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) {
                    up_dim_hidden_ids_1_mod_(idim_up_1) += adaptor1_hidden_id_shift;
                });

                return up_dim_hidden_ids_1_mod_;
            }
            ();

            // constexpr tuple to sequence
            return generate_sequence_v2(
                [&](auto i) constexpr { return number<up_dim_hidden_ids_1_mod[i]>{}; },
                number<ndim_up_1>{});
        },
        number<TensorAdaptor1::get_num_of_transform()>{});

    constexpr auto all_up_dim_hidden_idss =
        container_concat(TensorAdaptor0::get_upper_dimension_hidden_idss(), up_dim_hidden_idss_1);

    // bottom_dim_hidden_ids = bottom_dim_hidden_ids_0
    constexpr auto bottom_dim_hidden_ids = TensorAdaptor0::get_bottom_dimension_hidden_ids();

    // top_dim_hidden_ids = shift_hidden_id(top_dim_hidden_ids_1)
    constexpr auto top_dim_hidden_ids =
        TensorAdaptor1::get_top_dimension_hidden_ids() + number<adaptor1_hidden_id_shift>{};

    // put everything together
    return tensor_adaptor<remove_cvref_t<decltype(all_transforms)>,
                          remove_cvref_t<decltype(all_low_dim_hidden_idss)>,
                          remove_cvref_t<decltype(all_up_dim_hidden_idss)>,
                          remove_cvref_t<decltype(bottom_dim_hidden_ids)>,
                          remove_cvref_t<decltype(top_dim_hidden_ids)>>{all_transforms};
}
// Variadic overload: chain three or more adaptors by folding from the right,
// i.e. chain(a, b, c) == chain(a, chain(b, c)); the two-adaptor case is
// handled by the binary overload above.
template <typename Head,
          typename... Tail,
          typename enable_if<sizeof...(Tail) >= 2, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const Head& head, const Tail&... tail)
{
    return chain_tensor_adaptors(head, chain_tensor_adaptors(tail...));
}
} // namespace ck_tile
// Macro function
// Macro function
// construct constexpr tensor_adaptor from constexpr encoding
//
// The encoding is consumed at compile time, but transform lengths are held as
// runtime index_t values popped sequentially out of each transform's meta-data
// buffer (contrast with CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING, which
// wraps every length in number<> so the whole adaptor is a compile-time value).
//
// encoded_tensor_adaptor are Tuple of following objects:
// 1. encoded transforms (array of fixed size). Each encoded transform is a Tuple of following:
// 1.1 name (cood_transform_enum)
// 1.2 meta data for constructor of the transform (byte buffer, read via pop<>)
// 1.3 num of lower dimension (index_t)
// 1.4 lower dimension Ids (array of fixed size)
// 1.5 num of up dimension (index_t)
// 1.6 upper dimension Ids (array of fixed size)
// 2. num of transforms (index_t)
// 3. encoded bottom dimension Ids (array of fixed size)
// 4. num of bottom dimension (index_t)
// 5. encoded top dimension Ids (array of fixed size)
// 6. num of top dimension (index_t)
//
// NOTE: no '//' comments may appear inside the macro body below — backslash-
// newline splicing happens before comment removal, so a line comment would
// swallow the continuation and truncate the macro.
#define CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor)                            \
    [encoded_tensor_adaptor]() {                                                                  \
        using namespace ck_tile;                                                                  \
                                                                                                  \
        constexpr auto encoded_transforms  = encoded_tensor_adaptor.template at<0>();             \
        constexpr index_t num_transform    = encoded_tensor_adaptor.template at<1>();             \
        constexpr auto encoded_bottom_dims = encoded_tensor_adaptor.template at<2>();             \
        constexpr index_t num_bottom_dim   = encoded_tensor_adaptor.template at<3>();             \
        constexpr auto encoded_top_dims    = encoded_tensor_adaptor.template at<4>();             \
        constexpr index_t num_top_dim      = encoded_tensor_adaptor.template at<5>();             \
                                                                                                  \
        constexpr auto trans = [&encoded_transforms, &num_transform]() {                          \
            return generate_tuple(                                                                \
                [&encoded_transforms](auto i) constexpr {                                         \
                    constexpr auto name        = encoded_transforms[i].template at<0>();          \
                    constexpr auto meta_data   = encoded_transforms[i].template at<1>();          \
                    constexpr auto num_low_dim = encoded_transforms[i].template at<2>();          \
                    constexpr auto num_up_dim  = encoded_transforms[i].template at<4>();          \
                                                                                                  \
                    STATIC_ASSERT(name == cood_transform_enum::PassThrough ||                     \
                                      name == cood_transform_enum::pad ||                         \
                                      name == cood_transform_enum::embed ||                       \
                                      name == cood_transform_enum::merge ||                       \
                                      name == cood_transform_enum::unmerge ||                     \
                                      name == cood_transform_enum::replicate,                     \
                                  "");                                                            \
                                                                                                  \
                    if constexpr(name == cood_transform_enum::PassThrough)                        \
                    {                                                                             \
                        index_t pos  = 0;                                                         \
                        auto low_len = meta_data.template pop<index_t>(pos);                      \
                                                                                                  \
                        return make_pass_through_transform(low_len);                              \
                    }                                                                             \
                    else if constexpr(name == cood_transform_enum::pad)                           \
                    {                                                                             \
                        index_t pos    = 0;                                                       \
                        auto low_len   = meta_data.template pop<index_t>(pos);                    \
                        auto left_pad  = meta_data.template pop<index_t>(pos);                    \
                        auto right_pad = meta_data.template pop<index_t>(pos);                    \
                                                                                                  \
                        return make_pad_transform(low_len, left_pad, right_pad);                  \
                    }                                                                             \
                    else if constexpr(name == cood_transform_enum::embed)                         \
                    {                                                                             \
                        index_t pos  = 0;                                                         \
                        auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos);   \
                        auto coefficients =                                                       \
                            meta_data.template pop<array<index_t, num_up_dim>>(pos);              \
                                                                                                  \
                        return make_embed_transform(up_lens, coefficients);                       \
                    }                                                                             \
                    else if constexpr(name == cood_transform_enum::merge)                         \
                    {                                                                             \
                        index_t pos   = 0;                                                        \
                        auto low_lens = meta_data.template pop<array<index_t, num_low_dim>>(pos); \
                                                                                                  \
                        return make_merge_transform(low_lens);                                    \
                    }                                                                             \
                    else if constexpr(name == cood_transform_enum::unmerge)                       \
                    {                                                                             \
                        index_t pos  = 0;                                                         \
                        auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos);   \
                                                                                                  \
                        return make_unmerge_transform(up_lens);                                   \
                    }                                                                             \
                    else if constexpr(name == cood_transform_enum::replicate)                     \
                    {                                                                             \
                        index_t pos  = 0;                                                         \
                        auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos);   \
                                                                                                  \
                        return make_replicate_transform(up_lens);                                 \
                    }                                                                             \
                },                                                                                \
                number<num_transform>{});                                                         \
        }();                                                                                      \
                                                                                                  \
        constexpr auto low_dim_idss = [&encoded_transforms, &num_transform]() {                   \
            return generate_tuple(                                                                \
                [&encoded_transforms](auto i) {                                                   \
                    constexpr auto num_low_dim = encoded_transforms[i].template at<2>();          \
                    constexpr auto low_dims    = encoded_transforms[i].template at<3>();          \
                                                                                                  \
                    return TO_SEQUENCE(low_dims, num_low_dim);                                    \
                },                                                                                \
                number<num_transform>());                                                         \
        }();                                                                                      \
                                                                                                  \
        constexpr auto up_dim_idss = [&encoded_transforms, &num_transform] {                      \
            return generate_tuple(                                                                \
                [&encoded_transforms](auto i) {                                                   \
                    constexpr auto num_up_dim = encoded_transforms[i].template at<4>();           \
                    constexpr auto up_dims    = encoded_transforms[i].template at<5>();           \
                                                                                                  \
                    return TO_SEQUENCE(up_dims, num_up_dim);                                      \
                },                                                                                \
                number<num_transform>());                                                         \
        }();                                                                                      \
                                                                                                  \
        constexpr auto bottom_dim_ids = TO_SEQUENCE(encoded_bottom_dims, num_bottom_dim);         \
        constexpr auto top_dim_ids    = TO_SEQUENCE(encoded_top_dims, num_top_dim);               \
                                                                                                  \
        return tensor_adaptor<remove_cvref_t<decltype(trans)>,                                    \
                              remove_cvref_t<decltype(low_dim_idss)>,                             \
                              remove_cvref_t<decltype(up_dim_idss)>,                              \
                              remove_cvref_t<decltype(bottom_dim_ids)>,                           \
                              remove_cvref_t<decltype(top_dim_ids)>>{trans};                      \
    }()
// Macro function
// construct static tensor_adaptor from constexpr encoding
//
// Static variant: every length is read at a fixed byte offset via
// meta_data.get<T>(offset) and wrapped in number<>, so the resulting adaptor
// is fully known at compile time.
//
// Bug fix: the `pad` branch read `right_pad` with pop<index_t>(offset) instead
// of get<index_t>(offset). pop advances a mutable position reference, so it
// cannot bind the rvalue byte-offset expression (and would mutate a constexpr
// buffer); every other read in this macro uses get. Changed to get.
//
// encoded_tensor_adaptor are Tuple of following objects:
// 1. encoded transforms (array of fixed size). Each encoded transform is a Tuple of following:
// 1.1 name (cood_transform_enum)
// 1.2 meta data for constructor of the transform (byte buffer, read via get<>)
// 1.3 num of lower dimension (index_t)
// 1.4 lower dimension Ids (array of fixed size)
// 1.5 num of up dimension (index_t)
// 1.6 upper dimension Ids (array of fixed size)
// 2. num of transforms (index_t)
// 3. encoded bottom dimension Ids (array of fixed size)
// 4. num of bottom dimension (index_t)
// 5. encoded top dimension Ids (array of fixed size)
// 6. num of top dimension (index_t)
#define CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor)                     \
    [encoded_tensor_adaptor]() {                                                                  \
        using namespace ck_tile;                                                                  \
                                                                                                  \
        constexpr auto encoded_transforms  = encoded_tensor_adaptor.template at<0>();             \
        constexpr index_t num_transform    = encoded_tensor_adaptor.template at<1>();             \
        constexpr auto encoded_bottom_dims = encoded_tensor_adaptor.template at<2>();             \
        constexpr index_t num_bottom_dim   = encoded_tensor_adaptor.template at<3>();             \
        constexpr auto encoded_top_dims    = encoded_tensor_adaptor.template at<4>();             \
        constexpr index_t num_top_dim      = encoded_tensor_adaptor.template at<5>();             \
                                                                                                  \
        constexpr auto trans = [&encoded_transforms, &num_transform]() {                          \
            return generate_tuple(                                                                \
                [&encoded_transforms](auto i) constexpr {                                         \
                    constexpr auto name        = encoded_transforms[i].template at<0>();          \
                    constexpr auto meta_data   = encoded_transforms[i].template at<1>();          \
                    constexpr auto num_low_dim = encoded_transforms[i].template at<2>();          \
                    constexpr auto num_up_dim  = encoded_transforms[i].template at<4>();          \
                                                                                                  \
                    STATIC_ASSERT(name == cood_transform_enum::PassThrough ||                     \
                                      name == cood_transform_enum::pad ||                         \
                                      name == cood_transform_enum::embed ||                       \
                                      name == cood_transform_enum::merge ||                       \
                                      name == cood_transform_enum::unmerge ||                     \
                                      name == cood_transform_enum::replicate,                     \
                                  "");                                                            \
                                                                                                  \
                    if constexpr(name == cood_transform_enum::PassThrough)                        \
                    {                                                                             \
                        constexpr index_t low_len = meta_data.template get<index_t>(0);           \
                                                                                                  \
                        return make_pass_through_transform(number<low_len>{});                    \
                    }                                                                             \
                    else if constexpr(name == cood_transform_enum::pad)                           \
                    {                                                                             \
                        constexpr index_t low_len = meta_data.template get<index_t>(0);           \
                                                                                                  \
                        constexpr index_t left_pad =                                              \
                            meta_data.template get<index_t>(sizeof(low_len));                     \
                                                                                                  \
                        constexpr index_t right_pad =                                             \
                            meta_data.template get<index_t>(sizeof(low_len) + sizeof(left_pad));  \
                                                                                                  \
                        return make_pad_transform(                                                \
                            number<low_len>{}, number<left_pad>{}, number<right_pad>{});          \
                    }                                                                             \
                    else if constexpr(name == cood_transform_enum::embed)                         \
                    {                                                                             \
                        constexpr auto up_lens =                                                  \
                            meta_data.template get<array<index_t, num_up_dim>>(0);                \
                                                                                                  \
                        constexpr auto coefficients =                                             \
                            meta_data.template get<array<index_t, num_up_dim>>(sizeof(up_lens));  \
                                                                                                  \
                        return make_embed_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim),      \
                                                    TO_TUPLE_OF_NUMBER(coefficients,              \
                                                                       num_up_dim));             \
                    }                                                                             \
                    else if constexpr(name == cood_transform_enum::merge)                         \
                    {                                                                             \
                        constexpr auto low_lens =                                                 \
                            meta_data.template get<array<index_t, num_low_dim>>(0);               \
                                                                                                  \
                        return make_merge_transform(TO_TUPLE_OF_NUMBER(low_lens, num_low_dim));   \
                    }                                                                             \
                    else if constexpr(name == cood_transform_enum::unmerge)                       \
                    {                                                                             \
                        constexpr auto up_lens =                                                  \
                            meta_data.template get<array<index_t, num_up_dim>>(0);                \
                                                                                                  \
                        return make_unmerge_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim));   \
                    }                                                                             \
                    else if constexpr(name == cood_transform_enum::replicate)                     \
                    {                                                                             \
                        constexpr auto up_lens =                                                  \
                            meta_data.template get<array<index_t, num_up_dim>>(0);                \
                                                                                                  \
                        return make_replicate_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim)); \
                    }                                                                             \
                },                                                                                \
                number<num_transform>{});                                                         \
        }();                                                                                      \
                                                                                                  \
        constexpr auto low_dim_idss = [&encoded_transforms, &num_transform]() {                   \
            return generate_tuple(                                                                \
                [&encoded_transforms](auto i) {                                                   \
                    constexpr auto num_low_dim = encoded_transforms[i].template at<2>();          \
                    constexpr auto low_dims    = encoded_transforms[i].template at<3>();          \
                                                                                                  \
                    return TO_SEQUENCE(low_dims, num_low_dim);                                    \
                },                                                                                \
                number<num_transform>());                                                         \
        }();                                                                                      \
                                                                                                  \
        constexpr auto up_dim_idss = [&encoded_transforms, &num_transform] {                      \
            return generate_tuple(                                                                \
                [&encoded_transforms](auto i) {                                                   \
                    constexpr auto num_up_dim = encoded_transforms[i].template at<4>();           \
                    constexpr auto up_dims    = encoded_transforms[i].template at<5>();           \
                                                                                                  \
                    return TO_SEQUENCE(up_dims, num_up_dim);                                      \
                },                                                                                \
                number<num_transform>());                                                         \
        }();                                                                                      \
                                                                                                  \
        constexpr auto bottom_dim_ids = TO_SEQUENCE(encoded_bottom_dims, num_bottom_dim);         \
        constexpr auto top_dim_ids    = TO_SEQUENCE(encoded_top_dims, num_top_dim);               \
                                                                                                  \
        return tensor_adaptor<remove_cvref_t<decltype(trans)>,                                    \
                              remove_cvref_t<decltype(low_dim_idss)>,                             \
                              remove_cvref_t<decltype(up_dim_idss)>,                              \
                              remove_cvref_t<decltype(bottom_dim_ids)>,                           \
                              remove_cvref_t<decltype(top_dim_ids)>>{trans};                      \
    }()

View File

@@ -0,0 +1,257 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// Coordinate into a tensor_adaptor's hidden-dimension space.
//
// A tensor_adaptor maps a "top" multi-index to a "bottom" multi-index through
// a chain of coordinate transforms; all intermediate dimensions live in a flat
// "hidden" index space of size NDimHidden. This struct stores the full hidden
// index and exposes the bottom/top slices of it.
//
// BottomDimensionHiddenIds / TopDimensionHiddenIds: sequences of hidden
// dimension ids selecting the bottom / top views.
template <index_t NDimHidden, typename BottomDimensionHiddenIds, typename TopDimensionHiddenIds>
struct tensor_adaptor_coordinate
{
    static constexpr index_t ndim_bottom_ = BottomDimensionHiddenIds::size();
    static constexpr index_t ndim_top_    = TopDimensionHiddenIds::size();

    using HiddenIndex = multi_index<NDimHidden>;
    using BottomIndex = multi_index<ndim_bottom_>;
    using TopIndex    = multi_index<ndim_top_>;

    public:
    CK_TILE_HOST_DEVICE constexpr tensor_adaptor_coordinate() = default;

    // construct from a fully-populated hidden index
    CK_TILE_HOST_DEVICE constexpr tensor_adaptor_coordinate(const HiddenIndex& idx_hidden)
        : idx_hidden_{idx_hidden}
    {
    }

    // slice of the hidden index corresponding to the top dimensions
    CK_TILE_HOST_DEVICE constexpr auto get_top_index() const
    {
        return get_container_subset(idx_hidden_, TopDimensionHiddenIds{});
    }

    // slice of the hidden index corresponding to the bottom dimensions
    CK_TILE_HOST_DEVICE constexpr auto get_bottom_index() const
    {
        return get_container_subset(idx_hidden_, BottomDimensionHiddenIds{});
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_hidden_index() const { return idx_hidden_; }

    CK_TILE_HOST_DEVICE constexpr auto& get_hidden_index() { return idx_hidden_; }

    // full index over all hidden dimensions
    HiddenIndex idx_hidden_;
};
// Build an adaptor coordinate from a top index: seed the hidden index with
// idx_top, then walk the transform chain from last to first, deriving each
// transform's lower index from its upper index.
template <typename Adaptor, typename TopIndex>
CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate(const Adaptor& adaptor,
                                                                  const TopIndex& idx_top)
{
    static_assert(Adaptor::get_num_of_top_dimension() == TopIndex::size(),
                  "wrong! # of dimension inconsistent");

    constexpr index_t num_tran   = Adaptor::get_num_of_transform();
    constexpr index_t num_hidden = Adaptor::get_num_of_hidden_dimension();
    constexpr auto bottom_ids    = Adaptor::get_bottom_dimension_hidden_ids();
    constexpr auto top_ids       = Adaptor::get_top_dimension_hidden_ids();

    multi_index<num_hidden> idx_hidden;

    // seed the visible (top) entries of the hidden index
    set_container_subset(idx_hidden, top_ids, idx_top);

    // propagate downward through the transform chain
    static_for<num_tran, 0, -1>{}([&adaptor, &idx_hidden](auto i_p1) {
        auto i = i_p1 - number<1>{};

        const auto& tran        = adaptor.get_transforms().at(i);
        constexpr auto low_dims = Adaptor::get_lower_dimension_hidden_idss().at(i);
        constexpr auto up_dims  = Adaptor::get_upper_dimension_hidden_idss().at(i);

        const auto idx_up = get_container_subset(idx_hidden, up_dims);

        multi_index<low_dims.size()> idx_low;

        tran.calculate_lower_index(idx_low, idx_up);

        set_container_subset(idx_hidden, low_dims, idx_low);
    });

    return tensor_adaptor_coordinate<num_hidden,
                                     remove_cvref_t<decltype(bottom_ids)>,
                                     remove_cvref_t<decltype(top_ids)>>{idx_hidden};
}
// Move an adaptor coordinate by a top-index delta and report the resulting
// bottom-index delta.
//
// The hidden index is updated incrementally: the top slice is bumped by
// idx_diff_top, then each transform (walked from last to first) converts its
// upper-index delta into a lower-index delta via update_lower_index().
//
// JudgeDoTransforms: when true, first propagate a "did anything change" mask
// down the transform chain and skip transforms whose upper-index deltas are
// all zero; when false, unconditionally run every transform.
template <bool JudgeDoTransforms = true,
          typename Adaptor,
          typename AdaptorCoord,
          typename TopIndex,
          typename BottomIndex>
CK_TILE_HOST_DEVICE constexpr void move_tensor_adaptor_coordinate(const Adaptor& adaptor,
                                                                  AdaptorCoord& coord,
                                                                  const TopIndex& idx_diff_top,
                                                                  BottomIndex& idx_diff_bottom)
{
    constexpr index_t ndim_hidden = Adaptor::get_num_of_hidden_dimension();
    constexpr index_t ndim_top    = Adaptor::get_num_of_top_dimension();
    // constexpr index_t ndim_bottom = Adaptor::get_num_of_bottom_dimension();
    constexpr index_t ntransform = Adaptor::get_num_of_transform();

    // STATIC_ASSERT(TopIndex::size() == ndim_top && BottomIndex::size() == ndim_bottom, "");

    // judge whether calculation of lower diff is needed for each transform
    // use index_t for boolean type
    auto do_transforms = make_zero_multi_index<ntransform>();

    if constexpr(JudgeDoTransforms)
    {
        // per-hidden-dimension flag: can this component's diff be non-zero?
        auto is_non_zero_diff = make_zero_multi_index<ndim_hidden>();

        // decide do_transform by checkout non-zero index diff components
        multi_index<ndim_top> non_zero_diff_pick_top;

        static_for<0, ndim_top, 1>{}(
            [&](auto i) { non_zero_diff_pick_top(i) = (idx_diff_top[i] != 0); });

        set_container_subset(
            is_non_zero_diff, Adaptor::get_top_dimension_hidden_ids(), non_zero_diff_pick_top);

        static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
            constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
            constexpr auto dims_up  = Adaptor::get_upper_dimension_hidden_idss().at(itran);

            const auto non_zero_diff_pick_up = get_container_subset(is_non_zero_diff, dims_up);

            multi_index<dims_low.size()> non_zero_diff_pick_low;

            // if any of upper index diff components is non-zero, then
            // 1) Need to do this transform
            // 2) all components of lower index diff will assume to be non-zero and need to be
            // computed
            const bool idx_diff_up_has_non_zero = container_reduce(
                non_zero_diff_pick_up, [](auto a, auto b) constexpr { return a or b; }, false);

            do_transforms(itran) = idx_diff_up_has_non_zero;

            static_for<0, dims_low.size(), 1>{}(
                [&](auto i) { non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero; });

            set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
        });
    }
    else
    {
        // no judgement requested: run every transform
        static_for<ntransform - 1, -1, -1>{}([&](auto itran) { do_transforms(itran) = 1; });
    }

    // this is what needs to be calculated
    auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();

    // initialize top index diff
    set_container_subset(idx_diff_hidden, Adaptor::get_top_dimension_hidden_ids(), idx_diff_top);

    // this is what needs to be updated
    auto& idx_hidden = coord.get_hidden_index();

    // update top index
    auto idx_hidden_pick_top =
        get_container_subset(idx_hidden, Adaptor::get_top_dimension_hidden_ids());

    idx_hidden_pick_top += idx_diff_top;

    set_container_subset(idx_hidden, Adaptor::get_top_dimension_hidden_ids(), idx_hidden_pick_top);

    // update rest of hidden index
    static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
        if(do_transforms[itran])
        {
            const auto& tran        = adaptor.get_transforms().at(itran);
            constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
            constexpr auto dims_up  = Adaptor::get_upper_dimension_hidden_idss().at(itran);

            const auto idx_up_new  = get_container_subset(idx_hidden, dims_up);
            auto idx_low           = get_container_subset(idx_hidden, dims_low);
            const auto idx_diff_up = get_container_subset(idx_diff_hidden, dims_up);

            multi_index<dims_low.size()> idx_diff_low;

            // compute updated lower index and its diff in one pass
            tran.update_lower_index(idx_diff_low, idx_diff_up, idx_low, idx_up_new);

            set_container_subset(idx_diff_hidden, dims_low, idx_diff_low);
            set_container_subset(idx_hidden, dims_low, idx_low);
        }
    });

    // set bottom index diff
    idx_diff_bottom =
        get_container_subset(idx_diff_hidden, Adaptor::get_bottom_dimension_hidden_ids());
}
// Convenience overload: move the coordinate and discard the bottom-index delta.
template <bool JudgeDoTransforms = true, typename Adaptor, typename AdaptorCoord, typename TopIndex>
CK_TILE_HOST_DEVICE constexpr void move_tensor_adaptor_coordinate(const Adaptor& adaptor,
                                                                  AdaptorCoord& coord,
                                                                  const TopIndex& idx_diff_top)
{
    multi_index<Adaptor::get_num_of_bottom_dimension()> idx_diff_bottom_ignored;

    move_tensor_adaptor_coordinate<JudgeDoTransforms>(
        adaptor, coord, idx_diff_top, idx_diff_bottom_ignored);
}
// Check validity of the hidden index for all transforms whose upper-to-lower
// mapping may produce out-of-range lower indices (e.g. padding), assuming the
// top index itself is already in range.
//
// Fix: the transform was previously copied each iteration (`const auto tran`);
// it is now bound by const reference, consistent with the rest of this file,
// and the static trait is queried via remove_cvref_t of the reference type.
template <typename Adaptor, typename AdaptorCoord>
CK_TILE_HOST_DEVICE constexpr bool
adaptor_coordinate_is_valid_assuming_top_index_is_valid(const Adaptor& adaptor,
                                                        const AdaptorCoord& coord)
{
    bool valid = true;

    constexpr index_t ntransform = Adaptor::get_num_of_transform();

    const auto& idx_hidden = coord.get_hidden_index();

    static_for<ntransform - 1, -1, -1>{}([&adaptor, &idx_hidden, &valid](auto itran) {
        const auto& tran = adaptor.get_transforms().at(itran);
        using tran_t     = remove_cvref_t<decltype(tran)>;

        // check validity, only if current transformation does not always has a valid mapping
        if constexpr(!tran_t::is_valid_upper_index_always_mapped_to_valid_lower_index())
        {
            const auto idx_up = get_container_subset(
                idx_hidden, Adaptor::get_upper_dimension_hidden_idss().at(itran));

            // Comment: using valid = valid && .. will result in weird control flow in ISA
            valid &= tran.is_valid_upper_index_mapped_to_valid_lower_index(idx_up);
        }
    });

    return valid;
}
// Full validity check: every top-index component must lie in [0, length), and
// the derived hidden index must be valid for every transform.
template <typename Adaptor, typename AdaptorCoord>
CK_TILE_HOST_DEVICE constexpr bool adaptor_coordinate_is_valid(const Adaptor& adaptor,
                                                               const AdaptorCoord& coord)
{
    // check top index
    const auto& idx_top = coord.get_top_index();

    bool is_top_index_valid = true;

    static_for<0, Adaptor::get_num_of_dimension(), 1>{}([&](auto i) {
        const bool in_range = idx_top[i] >= 0 && idx_top[i] < adaptor.get_length(i);

        is_top_index_valid = is_top_index_valid && in_range;
    });

    // check other hidden index
    return is_top_index_valid &&
           adaptor_coordinate_is_valid_assuming_top_index_is_valid(adaptor, coord);
}
} // namespace ck_tile

View File

@@ -0,0 +1,92 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// Coordinate into a tensor_descriptor: a tensor_adaptor_coordinate whose
// bottom view is the single linearized-offset dimension (hidden id 0).
//
// Fix: the non-const get_hidden_index() overload was missing constexpr,
// inconsistent with the const overload and with the base class.
template <index_t NDimHidden, typename TopDimensionHiddenIds>
struct tensor_coordinate
    : public tensor_adaptor_coordinate<NDimHidden, sequence<0>, TopDimensionHiddenIds>
{
    using Base = tensor_adaptor_coordinate<NDimHidden, sequence<0>, TopDimensionHiddenIds>;

    // TODO make these private
    static constexpr index_t ndim_top_ = TopDimensionHiddenIds::size();

    using HiddenIndex = multi_index<NDimHidden>;
    using TopIndex    = multi_index<ndim_top_>;

    public:
    CK_TILE_HOST_DEVICE constexpr tensor_coordinate() = default;

    // construct from a fully-populated hidden index
    CK_TILE_HOST_DEVICE constexpr tensor_coordinate(const HiddenIndex& idx_hidden)
        : Base{idx_hidden}
    {
    }

    // construct from tensor_adaptor_coordinate base class
    CK_TILE_HOST_DEVICE constexpr tensor_coordinate(const Base& adaptor_coord) : Base{adaptor_coord}
    {
    }

    // multi-dimensional (top) index
    CK_TILE_HOST_DEVICE constexpr auto get_index() const { return Base::get_top_index(); }

    // linearized element offset: the single bottom-dimension component
    CK_TILE_HOST_DEVICE constexpr index_t get_offset() const
    {
        return Base::get_bottom_index()[number<0>{}];
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_hidden_index() const
    {
        return Base::get_hidden_index();
    }

    CK_TILE_HOST_DEVICE constexpr auto& get_hidden_index() { return Base::get_hidden_index(); }
};
// Build a tensor_coordinate for tensor_desc pointing at element idx_top.
template <typename TensorDesc, typename TopIndex>
CK_TILE_HOST_DEVICE constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc,
                                                          const TopIndex& idx_top)
{
    using top_ids_t = remove_cvref_t<decltype(TensorDesc::get_top_dimension_hidden_ids())>;

    return tensor_coordinate<TensorDesc::get_num_of_hidden_dimension(), top_ids_t>{
        make_tensor_adaptor_coordinate(tensor_desc, idx_top)};
}
// Advance coord by coord_step (a top-index delta).
//
// JudgeDoTransforms: when true, transforms whose upper-index deltas are all
// zero are skipped (see move_tensor_adaptor_coordinate).
//
// Bug fix: JudgeDoTransforms was accepted but not forwarded to the callee, so
// the callee always ran with its default (true) even when a caller passed
// false. It is now forwarded explicitly.
template <bool JudgeDoTransforms = true, typename TensorDesc, typename TensorCoord, typename Index>
CK_TILE_HOST_DEVICE constexpr void
move_tensor_coordinate(const TensorDesc& tensor_desc, TensorCoord& coord, const Index& coord_step)
{
    move_tensor_adaptor_coordinate<JudgeDoTransforms>(tensor_desc, coord, coord_step);
}
// Descriptor-level alias for the adaptor-level validity check (assumes the
// top index is already known to be in range).
template <typename TensorDesc, typename TensorCoord>
CK_TILE_HOST_DEVICE constexpr bool
coordinate_has_valid_offset_assuming_top_index_is_valid(const TensorDesc& tensor_desc,
                                                        const TensorCoord& coord)
{
    return adaptor_coordinate_is_valid_assuming_top_index_is_valid(tensor_desc, coord);
}
// Descriptor-level alias for the full adaptor-level validity check (top index
// range plus per-transform validity).
template <typename TensorDesc, typename TensorCoord>
CK_TILE_HOST_DEVICE constexpr bool coordinate_has_valid_offset(const TensorDesc& tensor_desc,
                                                               const TensorCoord& coord)
{
    return adaptor_coordinate_is_valid(tensor_desc, coord);
}
} // namespace ck_tile

View File

@@ -0,0 +1,472 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// tensor_descriptor: a tensor_adaptor whose bottom view is the single
// linearized-offset dimension (hidden id 0), plus the element space size of
// the underlying buffer and per-hidden-dimension vectorization guarantees
// (-1 = no guarantee).
//
// Transforms: Tuple<transforms...>
// LowerDimensionHiddenIdss : Tuple<sequence<...>, ...>
// UpperDimensionHiddenIdss : Tuple<sequence<...>, ...>
// TopDimensionHiddenIds> : sequence<...>
template <typename Transforms,
          typename LowerDimensionHiddenIdss,
          typename UpperDimensionHiddenIdss,
          typename TopDimensionHiddenIds,
          typename ElementSpaceSize,
          typename GuaranteedVectorLengths_,
          typename GuaranteedVectorSrides_>
struct tensor_descriptor : public tensor_adaptor<Transforms,
                                                 LowerDimensionHiddenIdss,
                                                 UpperDimensionHiddenIdss,
                                                 sequence<0>,
                                                 TopDimensionHiddenIds>
{
    using Base = tensor_adaptor<Transforms,
                                LowerDimensionHiddenIdss,
                                UpperDimensionHiddenIdss,
                                sequence<0>,
                                TopDimensionHiddenIds>;

    using ElementSpaceSizeType = ElementSpaceSize;

    constexpr static index_t ntransform_  = Base::get_num_of_transform();
    constexpr static index_t ndim_hidden_ = Base::get_num_of_hidden_dimension();
    constexpr static index_t ndim_top_    = Base::get_num_of_top_dimension();

    using GuaranteedVectorLengths = GuaranteedVectorLengths_;
    using GuaranteedVectorStrides = GuaranteedVectorSrides_;

    static_assert(GuaranteedVectorLengths::size() == ndim_hidden_ &&
                      GuaranteedVectorStrides::size() == ndim_hidden_,
                  "wrong! inconsistent # of hidden dimensions");

    using TopIndex    = multi_index<ndim_top_>;
    using HiddenIndex = multi_index<ndim_hidden_>;

    public:
    CK_TILE_HOST_DEVICE constexpr tensor_descriptor() = default;

    // construct from the transform chain and the buffer's element space size
    CK_TILE_HOST_DEVICE constexpr tensor_descriptor(const Transforms& transforms,
                                                    ElementSpaceSize element_space_size)
        : Base{transforms}, element_space_size_{element_space_size}
    {
        static_assert(Transforms::size() == ntransform_ &&
                          LowerDimensionHiddenIdss::size() == ntransform_ &&
                          UpperDimensionHiddenIdss::size() == ntransform_,
                      "wrong! inconsistent # of transformations");

        // TODO check dependency of dimensions is valid
    }

    // construct from tensor_adaptor base class
    CK_TILE_HOST_DEVICE constexpr tensor_descriptor(const Base& adaptor,
                                                    ElementSpaceSize element_space_size)
        : Base{adaptor}, element_space_size_{element_space_size}
    {
    }

    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension()
    {
        return Base::get_num_of_top_dimension();
    }

    // length of one top dimension
    template <index_t IDim>
    CK_TILE_HOST_DEVICE constexpr auto get_length(number<IDim> idim) const
    {
        return Base::get_top_dimension_length(idim);
    }

    // NOTE(review): relies on a no-argument Base::get_top_dimension_length()
    // overload — confirm it exists in tensor_adaptor (vs. a pluralized
    // get_top_dimension_lengths()).
    CK_TILE_HOST_DEVICE constexpr auto get_lengths() const
    {
        return Base::get_top_dimension_length();
    }

    CK_TILE_HOST_DEVICE constexpr auto get_element_space_size() const
    {
        return element_space_size_;
    }

    // top multi-index -> linearized element offset
    template <typename Idx>
    CK_TILE_HOST_DEVICE constexpr index_t calculate_offset(const Idx& idx) const
    {
        return Base::calculate_bottom_index(idx)[number<0>{}];
    }

    // TODO make these private
    CK_TILE_HOST_DEVICE constexpr const auto& get_transforms() const
    {
        return Base::get_transforms();
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_lower_dimension_hidden_idss()
    {
        return Base::get_lower_dimension_hidden_idss();
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_upper_dimension_hidden_idss()
    {
        return Base::get_upper_dimension_hidden_idss();
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_hidden_ids()
    {
        return Base::get_top_dimension_hidden_ids();
    }

    // static == all transforms and the element space size are compile-time
    CK_TILE_HOST_DEVICE static constexpr bool is_static()
    {
        return Base::is_known_at_compile_time() &&
               ck_tile::is_known_at_compile_time<ElementSpaceSize>::value;
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); }

    CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides()
    {
        return Base::get_top_dimension_safe_vector_length_strides(
            to_array<index_t, ndim_hidden_>(GuaranteedVectorLengths{}),
            to_array<index_t, ndim_hidden_>(GuaranteedVectorStrides{}));
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("tensor_descriptor{");

        // tensor_adaptor
        Base::print();
        printf(", ");

        // element_space_size_
        printf("element_space_size_: ");
        // NOTE(review): unqualified print(...) resolves against this member
        // print() first (which suppresses ADL) — verify this call compiles
        // when the template is instantiated.
        print(element_space_size_);

        printf("}");
    }

    // TODO make these private
    ElementSpaceSize element_space_size_;
};
// Wrap an existing tensor_adaptor into a tensor_descriptor with the given
// element space size; no vector length/stride guarantees are attached (-1).
template <typename Adaptor, typename ElementSpaceSize>
CK_TILE_HOST_DEVICE constexpr auto
make_tensor_descriptor_from_adaptor(const Adaptor& adaptor,
                                    const ElementSpaceSize& element_space_size)
{
    constexpr index_t ndim_hidden = Adaptor::get_num_of_hidden_dimension();

    using no_guarantee_t = typename uniform_sequence_gen<ndim_hidden, -1>::type;

    return tensor_descriptor<remove_cvref_t<decltype(adaptor.get_transforms())>,
                             remove_cvref_t<decltype(adaptor.get_lower_dimension_hidden_idss())>,
                             remove_cvref_t<decltype(adaptor.get_upper_dimension_hidden_idss())>,
                             remove_cvref_t<decltype(adaptor.get_top_dimension_hidden_ids())>,
                             remove_cvref_t<decltype(element_space_size)>,
                             no_guarantee_t,
                             no_guarantee_t>{adaptor, element_space_size};
}
// Stack new transforms on top of an existing descriptor.
//
// The element space size is inherited from the old descriptor; hidden
// dimensions added by the new transforms get no vectorization guarantee (-1).
//
// NewLowerDimensionOldTopIdss / NewUpperDimensionNewTopIdss are tag types:
// per-transform id lists mapping old top dims to new transform inputs/outputs.
template <typename OldTensorDescriptor,
          typename NewTransforms,
          typename NewLowerDimensionOldTopIdss,
          typename NewUpperDimensionNewTopIdss>
CK_TILE_HOST_DEVICE constexpr auto
transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
                            const NewTransforms& new_transforms,
                            NewLowerDimensionOldTopIdss,
                            NewUpperDimensionNewTopIdss)
{
    const auto element_space_size = old_tensor_desc.get_element_space_size();

    const auto new_tensor_adaptor = transform_tensor_adaptor(old_tensor_desc,
                                                             new_transforms,
                                                             NewLowerDimensionOldTopIdss{},
                                                             NewUpperDimensionNewTopIdss{});

    constexpr index_t NDimHiddenOld = OldTensorDescriptor::get_num_of_hidden_dimension();
    constexpr index_t NDimHiddenNew = decltype(new_tensor_adaptor)::get_num_of_hidden_dimension();

    // keep the old guarantees, pad the newly-added hidden dims with -1
    using NewGuaranteedVectorLengths = typename sequence_merge<
        typename OldTensorDescriptor::GuaranteedVectorLengths,
        typename uniform_sequence_gen<NDimHiddenNew - NDimHiddenOld, -1>::type>::type;

    using NewGuaranteedVectorStrides = typename sequence_merge<
        typename OldTensorDescriptor::GuaranteedVectorStrides,
        typename uniform_sequence_gen<NDimHiddenNew - NDimHiddenOld, -1>::type>::type;

    return tensor_descriptor<
        remove_cvref_t<decltype(new_tensor_adaptor.get_transforms())>,
        remove_cvref_t<decltype(new_tensor_adaptor.get_lower_dimension_hidden_idss())>,
        remove_cvref_t<decltype(new_tensor_adaptor.get_upper_dimension_hidden_idss())>,
        remove_cvref_t<decltype(new_tensor_adaptor.get_top_dimension_hidden_ids())>,
        remove_cvref_t<decltype(element_space_size)>,
        NewGuaranteedVectorLengths,
        NewGuaranteedVectorStrides>{new_tensor_adaptor, element_space_size};
}
namespace detail {
// Recursive helper: accumulate acc + sum_i (lengths[i] - 1) * strides[i], one
// dimension per recursion step, so mixed number<>/index_t operands keep their
// compile-time-ness.
template <typename Lengths, typename Strides, index_t I, typename AccOld>
CK_TILE_HOST_DEVICE constexpr auto calculate_element_space_size_impl(const Lengths& lengths,
                                                                     const Strides& strides,
                                                                     number<I> i,
                                                                     AccOld acc_old)
{
    auto acc_new = acc_old + (lengths[i] - number<1>{}) * strides[i];

    if constexpr(i.value == Lengths::size() - 1)
    {
        return acc_new;
    }
    else
    {
        return calculate_element_space_size_impl(lengths, strides, i + number<1>{}, acc_new);
    }
}
} // namespace detail
/*
 * These functions create naive tensor descriptor
 */
// Create a strided tensor descriptor from per-dimension lengths and strides
// (a single embed transform over the linear offset dimension).
//
// Lengths..., Strides... could be:
//   1) index_t, which is known at run-time, or
//   2) number<>, which is known at compile-time
// element_space_size could be:
//   1) long_index_t, or
//   2) long_number<>
// The two trailing number<> parameters optionally attach a guaranteed vector
// length/stride to the last dimension (-1 = no guarantee).
template <typename... Lengths,
          typename... Strides,
          index_t GuaranteedLastDimensionVectorLength = -1,
          index_t GuaranteedLastDimensionVectorStride = -1,
          typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_descriptor(const tuple<Lengths...>& lengths,
                             const tuple<Strides...>& strides,
                             number<GuaranteedLastDimensionVectorLength> = number<-1>{},
                             number<GuaranteedLastDimensionVectorStride> = number<-1>{})
{
    constexpr index_t N = sizeof...(Lengths);

    // hidden dim 0 is the linear offset; dims 1..N are the visible dimensions
    const auto transforms = make_tuple(make_embed_transform(lengths, strides));

    constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});

    constexpr auto up_dim_hidden_idss =
        make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});

    constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};

    // 1 + sum_i (lengths[i] - 1) * strides[i]
    const auto element_space_size =
        detail::calculate_element_space_size_impl(lengths, strides, number<0>{}, long_number<1>{});

    using GuaranteedVectorLengths =
        typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
                                sequence<GuaranteedLastDimensionVectorLength>>::type;

    using GuaranteedVectorStrides =
        typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
                                sequence<GuaranteedLastDimensionVectorStride>>::type;

    return tensor_descriptor<remove_cv_t<decltype(transforms)>,
                             remove_cv_t<decltype(low_dim_hidden_idss)>,
                             remove_cv_t<decltype(up_dim_hidden_idss)>,
                             remove_cv_t<decltype(visible_dim_hidden_ids)>,
                             remove_cv_t<decltype(element_space_size)>,
                             GuaranteedVectorLengths,
                             GuaranteedVectorStrides>{transforms, element_space_size};
}
// tensor descriptor with offset, the offset will not be added into element space size
// only have an information of the starting offset, and will impact on offset calculation
//
// Fix: the template type parameter was named `offset`, shadowed by the function
// parameter of the same name; renamed to `Offset` (PascalCase, matching the
// file's other type parameters). Template parameter names are not part of the
// call interface, so callers are unaffected.
template <typename... Lengths,
          typename... Strides,
          typename Offset,
          index_t GuaranteedLastDimensionVectorLength = -1,
          index_t GuaranteedLastDimensionVectorStride = -1,
          typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_descriptor_with_offset(const tuple<Lengths...>& lengths,
                                         const tuple<Strides...>& strides,
                                         const Offset& offset,
                                         number<GuaranteedLastDimensionVectorLength> = number<-1>{},
                                         number<GuaranteedLastDimensionVectorStride> = number<-1>{})
{
    // inner descriptor: a single offset transform over the whole element space
    const auto desc_0 = [&]() {
        const auto element_space_size = detail::calculate_element_space_size_impl(
            lengths, strides, number<0>{}, long_number<1>{});

        const auto transforms = make_tuple(make_offset_transform(element_space_size, offset));

        constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});

        constexpr auto up_dim_hidden_idss = make_tuple(sequence<1>{});

        constexpr auto visible_dim_hidden_ids = sequence<1>{};

        using GuaranteedVectorLengths =
            typename sequence_merge<typename uniform_sequence_gen<1, -1>::type,
                                    sequence<GuaranteedLastDimensionVectorLength>>::type;

        using GuaranteedVectorStrides =
            typename sequence_merge<typename uniform_sequence_gen<1, -1>::type,
                                    sequence<GuaranteedLastDimensionVectorStride>>::type;

        return tensor_descriptor<remove_cv_t<decltype(transforms)>,
                                 remove_cv_t<decltype(low_dim_hidden_idss)>,
                                 remove_cv_t<decltype(up_dim_hidden_idss)>,
                                 remove_cv_t<decltype(visible_dim_hidden_ids)>,
                                 remove_cv_t<decltype(element_space_size)>,
                                 GuaranteedVectorLengths,
                                 GuaranteedVectorStrides>{transforms, element_space_size};
    }();

    constexpr index_t N = sizeof...(Lengths);

    // outer: embed the per-dimension lengths/strides on top of the offset
    return transform_tensor_descriptor(
        desc_0,
        make_tuple(make_embed_transform(lengths, strides)),
        make_tuple(sequence<0>{}),
        make_tuple(typename arithmetic_sequence_gen<0, N, 1>::type{}));
}
// Create a packed (contiguous, row-major, unit innermost stride) tensor
// descriptor from per-dimension lengths (a single unmerge transform).
//
// Lengths... could be:
//   1) index_t, which is known at run-time, or
//   2) number<>, which is known at compile-time
// element_space_size could be:
//   1) long_index_t, or
//   2) long_number<>
template <typename... Lengths, index_t GuaranteedLastDimensionVectorLength = -1>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_descriptor_packed(const tuple<Lengths...>& lengths,
                                    number<GuaranteedLastDimensionVectorLength> = number<-1>{})
{
    constexpr index_t N = sizeof...(Lengths);

    // hidden dim 0 is the linear offset; dims 1..N are the visible dimensions
    const auto transforms = make_tuple(make_unmerge_transform(lengths));

    constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});

    constexpr auto up_dim_hidden_idss =
        make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});

    constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};

    // packed layout: element space size is the product of all lengths
    const auto element_space_size = container_reduce(lengths, math::multiplies{}, long_number<1>{});

    using GuaranteedVectorLengths =
        typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
                                sequence<GuaranteedLastDimensionVectorLength>>::type;

    // innermost stride of a packed layout is guaranteed to be 1
    using GuaranteedVectorStrides =
        typename sequence_merge<typename uniform_sequence_gen<N, -1>::type, sequence<1>>::type;

    return tensor_descriptor<remove_cv_t<decltype(transforms)>,
                             remove_cv_t<decltype(low_dim_hidden_idss)>,
                             remove_cv_t<decltype(up_dim_hidden_idss)>,
                             remove_cv_t<decltype(visible_dim_hidden_ids)>,
                             remove_cv_t<decltype(element_space_size)>,
                             GuaranteedVectorLengths,
                             GuaranteedVectorStrides>{transforms, element_space_size};
}
// Packed (contiguous, unit innermost stride) tensor descriptor with a starting
// offset; the offset is not added to the element space size, it only shifts
// offset calculation.
//
// Bug fix: this template previously declared a stray `typename... Strides`
// pack constrained by enable_if<sizeof...(Lengths) == sizeof...(Strides)>.
// No function parameter mentions Strides, so it always deduced to an empty
// pack, making the function un-callable for any non-empty lengths (copy-paste
// from the strided overload above). The pack and constraint are removed; the
// signature callers actually use is unchanged.
template <typename... Lengths, typename Offset, index_t GuaranteedLastDimensionVectorLength = -1>
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_packed_with_offset(
    const tuple<Lengths...>& lengths,
    const Offset& offset,
    number<GuaranteedLastDimensionVectorLength> = number<-1>{})
{
    // inner descriptor: a single offset transform over the whole element space
    const auto desc_0 = [&]() {
        const auto element_space_size =
            container_reduce(lengths, math::multiplies{}, long_number<1>{});

        const auto transforms = make_tuple(make_offset_transform(element_space_size, offset));

        constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});

        constexpr auto up_dim_hidden_idss = make_tuple(sequence<1>{});

        constexpr auto visible_dim_hidden_ids = sequence<1>{};

        using GuaranteedVectorLengths =
            typename sequence_merge<typename uniform_sequence_gen<1, -1>::type,
                                    sequence<GuaranteedLastDimensionVectorLength>>::type;

        // innermost stride of a packed layout is guaranteed to be 1
        using GuaranteedVectorStrides =
            typename sequence_merge<typename uniform_sequence_gen<1, -1>::type, sequence<1>>::type;

        return tensor_descriptor<remove_cv_t<decltype(transforms)>,
                                 remove_cv_t<decltype(low_dim_hidden_idss)>,
                                 remove_cv_t<decltype(up_dim_hidden_idss)>,
                                 remove_cv_t<decltype(visible_dim_hidden_ids)>,
                                 remove_cv_t<decltype(element_space_size)>,
                                 GuaranteedVectorLengths,
                                 GuaranteedVectorStrides>{transforms, element_space_size};
    }();

    constexpr index_t N = sizeof...(Lengths);

    // outer: unmerge the packed lengths on top of the offset transform
    return transform_tensor_descriptor(
        desc_0,
        make_tuple(make_unmerge_transform(lengths)),
        make_tuple(sequence<0>{}),
        make_tuple(typename arithmetic_sequence_gen<0, N, 1>::type{}));
}
// Lengths... could be:
//   1) index_t, which is known at run-time, or
//   2) number<>, which is known at compile-time
// align could be:
//   1) index_t, or
//   2) number<>
//
// Naive descriptor whose innermost dimension is contiguous while the stride
// of dimension N-2 is rounded up to a multiple of "align" (i.e. each
// innermost row is padded to the alignment).
template <typename... Lengths, typename Align>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_descriptor_aligned(const tuple<Lengths...>& lengths, Align align)
{
    constexpr auto I1 = number<1>{};

    constexpr index_t N = sizeof...(Lengths);

    // stride of dim N-2: innermost length rounded up to the alignment
    const auto stride_n_minus_2 = math::integer_least_multiple(lengths[number<N - 1>{}], align);

    auto strides = generate_tuple(
        [&](auto i) {
            if constexpr(i.value == N - 1)
            {
                // innermost dimension is contiguous
                return I1;
            }
            else if constexpr(i.value == N - 2)
            {
                return number<stride_n_minus_2>{};
            }
            else
            {
                // outer dims: accumulate the lengths between (i+1) and (N-1)
                // on top of the aligned stride of dim N-2
                return container_reduce(lengths,
                                        math::multiplies{},
                                        number<stride_n_minus_2>{},
                                        i + I1,
                                        number<N - 1>{},
                                        I1);
            }
        },
        number<N>{});

    return make_naive_tensor_descriptor(lengths, strides);
}
} // namespace ck_tile

View File

@@ -0,0 +1,273 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// Non-owning view of a tensor: couples a buffer_view (raw memory + address
// space) with a tensor_descriptor (logical coordinate -> linear offset
// mapping). All accessors take logical TensorCoord coordinates; out-of-bound
// handling is delegated to the buffer_view together with the descriptor's
// validity check.
template <typename BufferView_, typename TensorDesc_>
struct tensor_view
{
    using BufferView = remove_reference_t<BufferView_>;
    using DataType = typename BufferView::type;
    using TensorDesc = remove_cvref_t<TensorDesc_>;
    // index over the descriptor's top (visible) dimensions
    using TensorIndex = array<index_t, TensorDesc::get_num_of_top_dimension()>;
    // coordinate type (also carries the hidden-dimension index)
    using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{}));

    CK_TILE_HOST_DEVICE constexpr tensor_view() = default;

    CK_TILE_HOST_DEVICE constexpr tensor_view(const BufferView& buffer_view, const TensorDesc& desc)
        : buf_{buffer_view}, desc_{desc}
    {
    }

    CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; }

    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension()
    {
        return TensorDesc::get_num_of_top_dimension();
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_buffer_view() const { return buf_; }

    CK_TILE_HOST_DEVICE constexpr auto& get_buffer_view() { return buf_; }

#if 0
    CK_TILE_HOST_DEVICE constexpr DataType get_element(const TensorCoord& coord) const
    {
        return buf_.template get<DataType>(
            coord.get_offset(),
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
    }

    CK_TILE_HOST_DEVICE constexpr void set_element(const TensorCoord& coord, const DataType& x)
    {
        buf_.template set<DataType>(
            coord.get_offset(),
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            x);
    }
#endif

    // Vectorized load returning X by value.
    // X is vector of DataType.
    // "coord" is coordinate of DataType, not X. "coord" should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
              typename enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
                                           typename scalar_type<remove_cvref_t<DataType>>::type>,
                                 bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
    get_vectorized_elements(const TensorCoord& coord,
                            bool_constant<oob_conditional_check> = {}) const
    {
        return buf_.template get<X>(
            coord.get_offset(),
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            bool_constant<oob_conditional_check>{});
    }

    // Vectorized load writing into a caller-provided destination ("raw" path).
    // X is vector of DataType.
    // "coord" is coordinate of DataType, not X. "coord" should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
              typename enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
                                           typename scalar_type<remove_cvref_t<DataType>>::type>,
                                 bool>::type = false>
    CK_TILE_HOST_DEVICE void
    get_vectorized_elements_raw(remove_cvref_t<X>& dst,
                                const TensorCoord& coord,
                                bool_constant<oob_conditional_check> = {}) const
    {
        return buf_.template get_raw<X, oob_conditional_check>(
            dst,
            coord.get_offset(),
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
    }

    // Asynchronous vectorized load; destination pointer is named smem —
    // presumably LDS/shared memory, TODO confirm against buffer_view::async_get.
    template <typename X,
              typename enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
                                           typename scalar_type<remove_cvref_t<DataType>>::type>,
                                 bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements(remove_cvref_t<DataType>* smem,
                                                                     const TensorCoord& coord) const
    {
        return buf_.template async_get<X>(smem, coord.get_offset(), true /*not used*/);
    }

    // Vectorized store.
    // X is vector of DataType.
    // "coord" is coordinate of DataType, not X. "coord" should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
              typename enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
                                           typename scalar_type<remove_cvref_t<DataType>>::type>,
                                 bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements(
        const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
    {
        buf_.template set<X, oob_conditional_check>(
            coord.get_offset(),
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            x);
    }

    // Vectorized store, "raw" path.
    template <typename X,
              bool oob_conditional_check = true,
              typename enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
                                           typename scalar_type<remove_cvref_t<DataType>>::type>,
                                 bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements_raw(
        const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
    {
        buf_.template set_raw<X, oob_conditional_check>(
            coord.get_offset(),
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            x);
    }

    // debug dump
    // NOTE(review): unqualified print(buf_) inside a member named print — the
    // member hides any namespace-scope print overloads, so this likely fails
    // to compile if instantiated; verify a free print overload is reachable.
    CK_TILE_HOST_DEVICE void print() const
    {
        printf("tensor_view{");

        // buf_
        printf("buf_: ");
        print(buf_);
        printf(", ");

        // desc_
        printf("desc_: ");
        print(desc_);

        printf("}");
    }

    // member
    BufferView buf_;
    TensorDesc desc_;
};
// placeholder type if we want to opt-out a tile view parameter
// (empty tag: used where an API expects a tensor-view slot but no tensor is
// bound)
struct null_tensor_view
{
};
// Wrap a raw pointer and an existing descriptor into a tensor_view.
// The underlying buffer is sized from the descriptor's element space.
template <AddressSpaceEnum BufferAddressSpace = AddressSpaceEnum::Generic,
          typename DataType,
          typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p,
                                                    const tensor_descriptor<Ts...>& desc)
{
    auto buf = make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size());

    return tensor_view<decltype(buf), decltype(desc)>{buf, desc};
}
// Build a tensor_view directly from lengths/strides via a naive descriptor.
// The Guaranteed* numbers are compile-time vectorization promises forwarded
// to the descriptor; -1 means unknown.
template <AddressSpaceEnum BufferAddressSpace = AddressSpaceEnum::Generic,
          typename DataType,
          typename... Lengths,
          typename... Strides,
          index_t GuaranteedLastDimensionVectorLength = -1,
          index_t GuaranteedLastDimensionVectorStride = -1,
          typename enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_view(DataType* p,
                       const tuple<Lengths...>& lengths,
                       const tuple<Strides...>& strides,
                       number<GuaranteedLastDimensionVectorLength> = number<-1>{},
                       number<GuaranteedLastDimensionVectorStride> = number<-1>{})
{
    auto naive_desc = make_naive_tensor_descriptor(lengths,
                                                   strides,
                                                   number<GuaranteedLastDimensionVectorLength>{},
                                                   number<GuaranteedLastDimensionVectorStride>{});

    auto buf = make_buffer_view<BufferAddressSpace>(p, naive_desc.get_element_space_size());

    return tensor_view<decltype(buf), decltype(naive_desc)>{buf, naive_desc};
}
// Build a tensor_view over a packed (contiguous) layout described only by
// its lengths.
template <AddressSpaceEnum BufferAddressSpace = AddressSpaceEnum::Generic,
          typename DataType,
          typename... Lengths,
          index_t GuaranteedLastDimensionVectorLength = -1>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_view_packed(DataType* p,
                              const tuple<Lengths...>& lengths,
                              number<GuaranteedLastDimensionVectorLength> = number<-1>{})
{
    auto packed_desc =
        make_naive_tensor_descriptor_packed(lengths, number<GuaranteedLastDimensionVectorLength>{});

    auto buf = make_buffer_view<BufferAddressSpace>(p, packed_desc.get_element_space_size());

    return tensor_view<decltype(buf), decltype(packed_desc)>{buf, packed_desc};
}
// Apply additional coordinate transforms to an existing view. The underlying
// buffer is shared; only the descriptor changes.
template <typename OldTensorView,
          typename NewTransforms,
          typename NewLowerDimensionOldVisibleIdss,
          typename NewUpperDimensionNewVisibleIdss>
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView& old_tensor_view,
                                                         const NewTransforms& new_transforms,
                                                         NewLowerDimensionOldVisibleIdss,
                                                         NewUpperDimensionNewVisibleIdss)
{
    auto transformed_desc = transform_tensor_descriptor(old_tensor_view.desc_,
                                                        new_transforms,
                                                        NewLowerDimensionOldVisibleIdss{},
                                                        NewUpperDimensionNewVisibleIdss{});

    return tensor_view<typename OldTensorView::BufferView,
                       remove_cvref_t<decltype(transformed_desc)>>{old_tensor_view.buf_,
                                                                   transformed_desc};
}
// Pad each dimension of a tensor view up to a multiple of the corresponding
// tile length.
//
// For every dimension i with DoPads[i] == true, the length is rounded up to
// integer_divide_ceil(old_length, tile_length) * tile_length via a right-pad
// transform; dimensions with DoPads[i] == false are passed through unchanged.
//
// NOTE: the template parameter was renamed from "tensor_view" to "TensorView"
// so it no longer shadows the ck_tile::tensor_view class template (the old
// spelling only compiled because nested-name-specifier lookup ignores
// variables).
template <typename TensorView,
          typename TileLengths, // tuple<...>
          typename DoPads>      // sequence<bool, bool, ...>
CK_TILE_HOST_DEVICE constexpr auto
pad_tensor_view(const TensorView& tensor_view, const TileLengths& tile_lengths, DoPads)
{
    constexpr index_t num_dim = DoPads::size();

    static_assert(num_dim == TileLengths::size() && num_dim == TensorView::get_num_of_dimension(),
                  "wrong! inconsistent # of dimensions");

    // transforms: one right-pad or pass-through per dimension
    const auto transforms = generate_tuple(
        [&](auto idim) {
            const auto old_length = tensor_view.get_tensor_descriptor().get_length(idim);

            const auto tile_length = tile_lengths[idim];

            // round up to the next multiple of the tile length
            const auto new_length =
                math::integer_divide_ceil(old_length, tile_length) * tile_length;

            const auto pad_length = new_length - old_length;

            constexpr bool DoPad = DoPads::at(idim);

            const auto transform =
                conditional_expr<DoPad>(make_right_pad_transform(old_length, pad_length),
                                        make_pass_through_transform(old_length));

            return transform;
        },
        number<num_dim>{});

    // lower dimension Id: transform i consumes exactly dimension i
    const auto lower_dimss =
        generate_tuple([&](auto idim) { return sequence<idim.value>{}; }, number<num_dim>{});

    // upper dimension Id: 1:1 — padding does not reorder dimensions
    const auto upper_dimss = lower_dimss;

    return transform_tensor_view(tensor_view, transforms, lower_dimss, upper_dimss);
}
} // namespace ck_tile

View File

@@ -0,0 +1,754 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// distributed span
// Type-level tag wrapping a compile-time sequence of partial H-dimension
// lengths (one X dimension's span of a distributed tile).
template <index_t... PartialHsLengths>
struct tile_distributed_span
{
    using Impl = sequence<PartialHsLengths...>;

    static constexpr auto impl_ = Impl{};

    // purely compile-time data
    CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
};
// distributed index
// Type-level tag wrapping a compile-time sequence of partial H-dimension
// indices (an index into a tile_distributed_span).
template <index_t... PartialHsIndices>
struct tile_distributed_index
{
    using Impl = sequence<PartialHsIndices...>;

    static constexpr auto impl_ = Impl{};

    // purely compile-time data
    CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
};
namespace detail {
// helper: lift a sequence<Is...> into a tile_distributed_span type
template <index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto make_tile_distributed_span(sequence<Is...>)
{
    return tile_distributed_span<Is...>{};
}
// helper: lift a sequence<Is...> into a tile_distributed_index type
template <index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto make_tile_distributed_index(sequence<Is...>)
{
    return tile_distributed_index<Is...>{};
}
} // namespace detail
// Describes how a tile's X (logical tensor) dimensions are distributed over
// P (partition, e.g. warp/lane) and Y (per-thread) dimensions:
//   - ps_ys_to_xs_: adaptor mapping [p..., y...] -> [x...]
//   - ys_to_d_    : descriptor mapping [y...] -> linearized per-thread index d
template <typename PsYs2XsAdaptor_,
          typename Ys2DDescriptor_,
          typename StaticTileDistributionEncoding_,
          typename TileDistributionDetail_> // FIXME: this holds ad-hoc but useful info,
                                            // should be more elegant
struct tile_distribution
{
    using PsYs2XsAdaptor = remove_cvref_t<PsYs2XsAdaptor_>;
    using Ys2DDescriptor = remove_cvref_t<Ys2DDescriptor_>;
    using DstrEncode     = remove_cvref_t<StaticTileDistributionEncoding_>;
    using DstrDetail     = remove_cvref_t<TileDistributionDetail_>;

    static_assert(PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static(),
                  "wrong! should be static");

    static constexpr index_t NDimX = PsYs2XsAdaptor::get_num_of_bottom_dimension();
    static constexpr index_t NDimY = Ys2DDescriptor::get_num_of_top_dimension();
    static constexpr index_t NDimP = PsYs2XsAdaptor::get_num_of_top_dimension() - NDimY;
    static constexpr index_t NDimR = StaticTileDistributionEncoding_::NDimR;

    PsYs2XsAdaptor ps_ys_to_xs_;
    Ys2DDescriptor ys_to_d_;

    // Dimension-count accessors. snake_case is the preferred spelling in
    // ck_tile (see ck_tile/core README); the CamelCase versions below are
    // kept for backward compatibility.
    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_x() { return NDimX; }
    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_y() { return NDimY; }
    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_p() { return NDimP; }
    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_r() { return NDimR; }

    CK_TILE_HOST_DEVICE static constexpr index_t GetNumOfDimensionY() { return NDimY; }
    CK_TILE_HOST_DEVICE static constexpr index_t GetNumOfDimensionP() { return NDimP; }
    CK_TILE_HOST_DEVICE static constexpr index_t GetNumOfDimensionR() { return NDimR; }

    // X lengths: product of each X dimension's H-minor lengths from the
    // encoding.
    CK_TILE_HOST_DEVICE static constexpr auto get_lengths()
    {
#if 0
        // FIXME: tensor_adaptor::GetBottomDimensionLengths is wrong. re-enable this after it's fixed
        ps_ys_to_xs_.GetBottomDimensionLengths();
#else
        return generate_tuple(
            [&](auto i) {
                constexpr index_t x_length =
                    container_reduce(typename DstrEncode::HsLengthss{}[i], math::multiplies{}, 1);

                return number<x_length>{};
            },
            number<NDimX>{});
#endif
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_ps_ys_to_xs_adaptor() const
    {
        return ps_ys_to_xs_;
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_ys_to_d_descriptor() const { return ys_to_d_; }

    CK_TILE_HOST_DEVICE static constexpr auto get_static_tile_distribution_encoding()
    {
        return DstrEncode{};
    }

#if 1
    // Calculate Replication index [R0, R1, ...] based on Partition index
    // FIXME: very nasty implementation
    template <typename PartitionIndex>
    CK_TILE_HOST_DEVICE auto calculate_rs_index_from_ps_index(const PartitionIndex& ps_idx) const
    {
        static_assert(PartitionIndex::size() == NDimP, "wrong!");

        // walk the adaptor with ys = 0 so the hidden coordinates carry the
        // replication components for this partition index
        const auto ps_ys_idx = container_concat(ps_idx, array<index_t, NDimY>{0});

        const auto dummy_adaptor_coord = make_tensor_adaptor_coordinate(ps_ys_to_xs_, ps_ys_idx);

        array<index_t, NDimR> rs_idx;

        static_for<0, NDimP, 1>{}([&](auto idim_p) {
            constexpr index_t ndim_low = DstrEncode::ps_to_rhss_major_[idim_p].size();

            static_for<0, ndim_low, 1>{}([&](auto i) {
                constexpr index_t rh_major = DstrEncode::ps_to_rhss_major_[idim_p][i];
                constexpr index_t rh_minor = DstrEncode::ps_to_rhss_minor_[idim_p][i];

                // 0-th rh_major is the replicate dimension
                if constexpr(rh_major == 0)
                {
                    constexpr index_t adaptor_hidden_id =
                        DstrDetail::rh_major_minor_to_adaptor_hidden_idss_[rh_major][rh_minor];

                    // fill in
                    rs_idx(rh_minor) = dummy_adaptor_coord.get_hidden_index()[adaptor_hidden_id];
                }
            });
        });

        return rs_idx;
    }
#endif

    // One tile_distributed_span per X dimension, built from the encoding.
    CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans()
    {
        constexpr auto distributed_spans_impl = DstrEncode::detail::distributed_spans_lengthss_;
        constexpr auto ndims_spans_minor      = DstrEncode::detail::ndims_distributed_spans_minor_;

        return generate_tuple(
            [&](auto i) {
                constexpr auto span_impl          = distributed_spans_impl[i];
                constexpr index_t ndim_span_minor = ndims_spans_minor[i];

                constexpr auto span = TO_SEQUENCE(span_impl, ndim_span_minor);

                return detail::make_tile_distributed_span(span);
            },
            number<NDimX>{});
    }

    // FIXME: it's hacky to get Y index from Distributed-Index
    template <typename DistributedIndices>
    CK_TILE_HOST_DEVICE static constexpr auto
    get_y_indices_from_distributed_indices(DistributedIndices)
    {
        constexpr auto ys_idx_arr = [] {
            array<index_t, NDimY> ys_idx;

            static_for<0, NDimY, 1>{}([&](auto i) {
                constexpr index_t span_major = DstrEncode::detail::ys_to_span_major_[i];
                constexpr index_t span_minor = DstrEncode::detail::ys_to_span_minor_[i];

                constexpr auto dstr_index = DistributedIndices{}[number<span_major>{}];

                ys_idx(i) = dstr_index.impl_[span_minor];
            });

            return ys_idx;
        }();

        constexpr index_t ndim_y = NDimY;

        return TO_SEQUENCE(ys_idx_arr, ndim_y);
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_static()
    {
        return PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static();
    }

    // debug dump
    CK_TILE_HOST_DEVICE void print() const
    {
        printf("tile_distribution{");

        //
        printf("tile_distribution_encoding: ");
        print(DstrEncode{});
        printf(", ");

        //
        printf("ps_ys_to_xs_: ");
        print(ps_ys_to_xs_);
        printf(", ");

        //
        printf("ys_to_d_: ");
        print(ys_to_d_);

        //
        printf("}");
    }
};
namespace detail {
// Fill the first (iend - ibegin) slots of a fixed-size array with the values
// ibegin, ibegin+1, ..., iend-1; the remaining slots stay zero.
template <index_t NDimMax>
CK_TILE_HOST_DEVICE constexpr auto make_sequential_index(index_t ibegin, index_t iend)
{
    array<index_t, NDimMax> seq{0};

    index_t pos = 0;
    for(index_t v = ibegin; v < iend; ++v)
    {
        seq(pos++) = v;
    }

    return seq;
}
// this returns a constexpr encoding of tile_distribution
//
// Flattens the static distribution encoding into plain transform lists +
// dim-id tables (the tuple consumed by CONSTRUCT_*_TENSOR_ADAPTOR_FROM_ENCODING):
//   [0] ps_ys_to_xs adaptor encoding  (bottom: x dims, top: p then y dims)
//   [1] ys_to_d adaptor encoding      (1-d d -> y dims)
//   [2] d length                      (product of all y lengths)
//   [3] (rh_major, rh_minor) -> adaptor hidden dim id table
template <typename StaticTileDistributionEncoding_>
CK_TILE_HOST_DEVICE constexpr auto
make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_)
{
    using RsLengths    = typename StaticTileDistributionEncoding_::RsLengths;
    using HsLengthss   = typename StaticTileDistributionEncoding_::HsLengthss;
    using Ps2RHssMajor = typename StaticTileDistributionEncoding_::Ps2RHssMajor;
    using Ps2RHssMinor = typename StaticTileDistributionEncoding_::Ps2RHssMinor;
    using Ys2RHsMajor  = typename StaticTileDistributionEncoding_::Ys2RHsMajor;
    using Ys2RHsMinor  = typename StaticTileDistributionEncoding_::Ys2RHsMinor;

    // FIXME: increase max value if fail
    constexpr index_t kMaxNumTransforms = 20;
    constexpr index_t kMaxMetaDataSize  = 128;
    constexpr index_t kMaxNumDim        = 10;

    using Name     = cood_transform_enum;
    using MetaData = meta_data_buffer<kMaxMetaDataSize>;
    using NumDim   = index_t;
    using Dims     = array<index_t, kMaxNumDim>;
    using Lengths  = array<index_t, kMaxNumDim>;

    // Tile Adaptor
    //   bottom dims [x0, x1, x2, ...]
    //   top dims [p0, p1, ..., y0, y1, ...]
    constexpr index_t ndim_x = HsLengthss::size();

    // Dim Ids: [idim_x_major, idim_x_minor] to [idim_hidden]
    // (row 0 is the replicate "R" group; rows 1..ndim_x are the X groups)
    array<array<index_t, kMaxNumDim>, ndim_x + 1> rh_major_minor_to_hidden_ids;
    array<array<index_t, kMaxNumDim>, ndim_x + 1> rh_major_minor_to_hidden_lengths;

    auto trans = array<tuple<Name, MetaData, NumDim, Dims, NumDim, Dims>, kMaxNumTransforms>{};

    index_t num_tran       = 0;
    // hidden dims 0..ndim_x-1 are the bottom X dims; new hidden dims are
    // appended from here as transforms are recorded
    index_t hidden_dim_cnt = ndim_x;

    // this is replicate transform
    {
        constexpr index_t ndim_r_minor = RsLengths::size();

        constexpr auto r_minor_lengths = RsLengths{};

        trans(num_tran++) = {
            cood_transform_enum::replicate,
            MetaData{to_array<index_t, ndim_r_minor>(r_minor_lengths)},
            NumDim{0},
            Dims{},
            NumDim{ndim_r_minor},
            make_sequential_index<kMaxNumDim>(hidden_dim_cnt, hidden_dim_cnt + ndim_r_minor)};

        for(index_t i = 0; i < ndim_r_minor; ++i)
        {
            rh_major_minor_to_hidden_ids(0)(i)     = hidden_dim_cnt;
            rh_major_minor_to_hidden_lengths(0)(i) = r_minor_lengths[i];

            hidden_dim_cnt++;
        }
    };

    // these are Unmerge transforms for X dimensions
    static_for<0, ndim_x, 1>{}([&trans,
                                &num_tran,
                                &hidden_dim_cnt,
                                &rh_major_minor_to_hidden_ids,
                                &rh_major_minor_to_hidden_lengths](auto idim_x) {
        constexpr auto h_minor_lengths = tuple_element_t<idim_x, HsLengthss>{};

        constexpr index_t ndim_h_minor = h_minor_lengths.size();

        trans(num_tran++) = {
            cood_transform_enum::unmerge,
            MetaData{to_array<index_t, ndim_h_minor>(h_minor_lengths)},
            NumDim{1},
            Dims{idim_x},
            NumDim{ndim_h_minor},
            make_sequential_index<kMaxNumDim>(hidden_dim_cnt, hidden_dim_cnt + ndim_h_minor)};

        for(index_t i = 0; i < ndim_h_minor; ++i)
        {
            rh_major_minor_to_hidden_ids(idim_x + 1)(i)     = hidden_dim_cnt;
            rh_major_minor_to_hidden_lengths(idim_x + 1)(i) = h_minor_lengths[i];

            hidden_dim_cnt++;
        }
    });

    // transform: P dimensions (each P is a merge of its R/H components)
    constexpr index_t ndim_p = Ps2RHssMajor::size();

    Dims hidden_dim_id_ps;

    static_for<0, ndim_p, 1>{}([&](auto iDimP) {
        //
        index_t hidden_dim_id_p = hidden_dim_cnt++;

        hidden_dim_id_ps(iDimP) = hidden_dim_id_p;

        constexpr auto p2RHsMajor = Ps2RHssMajor{}[iDimP];
        constexpr auto p2RHsMinor = Ps2RHssMinor{}[iDimP];

        static_assert(p2RHsMajor.size() == p2RHsMinor.size(), "wrong!");

        constexpr index_t ndim_low = p2RHsMajor.size();

        Dims low_dims;
        Lengths low_lengths;

        for(index_t i = 0; i < ndim_low; ++i)
        {
            index_t rh_major = p2RHsMajor[i];
            index_t rh_minor = p2RHsMinor[i];
            low_dims(i)      = rh_major_minor_to_hidden_ids[rh_major][rh_minor];
            low_lengths(i)   = rh_major_minor_to_hidden_lengths[rh_major][rh_minor];
        }

        trans(num_tran++) = {cood_transform_enum::merge,
                             MetaData{to_array<index_t, ndim_low>(low_lengths)},
                             NumDim{ndim_low},
                             low_dims,
                             NumDim{1},
                             Dims{hidden_dim_id_p}};
    });

    constexpr index_t ndim_bottom = ndim_x;

    constexpr auto bottom_dim_ids = make_sequential_index<kMaxNumDim>(0, ndim_bottom);

    constexpr auto ys_to_rhs_major = Ys2RHsMajor{};
    constexpr auto ys_to_rhs_minor = Ys2RHsMinor{};

    constexpr index_t ndim_y   = Ys2RHsMajor::size();
    constexpr index_t ndim_top = ndim_p + ndim_y;

    // top dims: the P hidden ids first, then the Y hidden ids
    auto top_dim_ids = hidden_dim_id_ps;

    {
        for(index_t i = 0; i < ndim_y; ++i)
        {
            index_t rh_major        = ys_to_rhs_major[i];
            index_t rh_minor        = ys_to_rhs_minor[i];
            top_dim_ids(ndim_p + i) = rh_major_minor_to_hidden_ids[rh_major][rh_minor];
        }
    }

    //
    const auto ps_ys_to_xs_adaptor_encoding =
        make_tuple(trans, num_tran, bottom_dim_ids, ndim_bottom, top_dim_ids, ndim_top);

    // descriptor: [y0, y1, ...] to [d]
    Lengths y_lengths;
    index_t d_length = 1;

    for(index_t i = 0; i < ndim_y; ++i)
    {
        index_t rh_major = ys_to_rhs_major[i];
        index_t rh_minor = ys_to_rhs_minor[i];
        index_t y_length = rh_major_minor_to_hidden_lengths[rh_major][rh_minor];
        y_lengths(i)     = y_length;
        d_length *= y_length;
    }

    auto tran = make_tuple(cood_transform_enum::unmerge,
                           MetaData{to_array<index_t, ndim_y>(y_lengths)},
                           NumDim{1},
                           Dims{0},
                           NumDim{ndim_y},
                           make_sequential_index<kMaxNumDim>(1, ndim_y + 1));

    const auto ys_to_d_adaptor_encoding = make_tuple(
        make_tuple(tran), 1, Dims{0}, 1, make_sequential_index<kMaxNumDim>(1, ndim_y + 1), ndim_y);

    return make_tuple(ps_ys_to_xs_adaptor_encoding,
                      ys_to_d_adaptor_encoding,
                      d_length,
                      rh_major_minor_to_hidden_ids);
}
// FIXME: this is nasty. Move it inside TileDistributionEncoding::detail
// Ad-hoc side table carried by tile_distribution: maps (rh_major, rh_minor)
// to the hidden dimension id of the ps_ys_to_xs adaptor.
template <typename RhMajorMinor2AdaptorHiddenIdss> // tuple<sequence<...>, ...>
struct tile_distribution_detail
{
    static constexpr auto rh_major_minor_to_adaptor_hidden_idss_ =
        to_array_of_array(RhMajorMinor2AdaptorHiddenIdss{});
};
} // namespace detail
// this returns a constexpr tile_distribution
// (d_length is passed as a run-time index_t; see make_static_tile_distribution
// for the fully compile-time variant)
template <typename StaticTileDistributionEncoding_>
CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistributionEncoding_)
{
    using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;

    // lower the encoding into plain transform/dim-id tables
    constexpr auto adaptor_impl =
        detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{});

    constexpr auto ps_ys_to_xs_adaptor_impl          = adaptor_impl.template at<0>();
    constexpr auto ys_to_d_adaptor_impl              = adaptor_impl.template at<1>();
    constexpr index_t d_length                       = adaptor_impl.template at<2>();
    constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>();

    // materialize adaptors from the flattened encodings
    constexpr auto ps_ys_to_xs_adaptor =
        CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl);
    constexpr auto ys_to_d_adaptor = CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl);

    constexpr auto ys_to_d_descriptor =
        make_tensor_descriptor_from_adaptor(ys_to_d_adaptor, d_length);

    //
    constexpr index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_;
    constexpr auto ndims_rhs_minor  = DstrEncode::detail::ndims_rhs_minor_;

    constexpr auto rh_major_minor_to_hidden_ids =
        TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor);

    return tile_distribution<
        remove_cvref_t<decltype(ps_ys_to_xs_adaptor)>,
        remove_cvref_t<decltype(ys_to_d_descriptor)>,
        remove_cvref_t<DstrEncode>,
        detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
        ps_ys_to_xs_adaptor, ys_to_d_descriptor};
}
// this returns a static tile_distribution
// (same as make_tile_distribution, but the adaptors are built with the
// STATIC constructor macro and d_length is wrapped in number<> so the whole
// distribution is compile-time known)
template <typename StaticTileDistributionEncoding_>
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
{
    using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;

    // lower the encoding into plain transform/dim-id tables
    constexpr auto adaptor_impl =
        detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{});

    constexpr auto ps_ys_to_xs_adaptor_impl          = adaptor_impl.template at<0>();
    constexpr auto ys_to_d_adaptor_impl              = adaptor_impl.template at<1>();
    constexpr index_t d_length                       = adaptor_impl.template at<2>();
    constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>();

    // materialize fully-static adaptors from the flattened encodings
    constexpr auto ps_ys_to_xs_adaptor =
        CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl);
    constexpr auto ys_to_d_adaptor =
        CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl);

    constexpr auto ys_to_d_descriptor =
        make_tensor_descriptor_from_adaptor(ys_to_d_adaptor, number<d_length>{});

    //
    constexpr index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_;
    constexpr auto ndims_rhs_minor  = DstrEncode::detail::ndims_rhs_minor_;

    constexpr auto rh_major_minor_to_hidden_ids =
        TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor);

    return tile_distribution<
        remove_cvref_t<decltype(ps_ys_to_xs_adaptor)>,
        remove_cvref_t<decltype(ys_to_d_descriptor)>,
        remove_cvref_t<DstrEncode>,
        detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
        ps_ys_to_xs_adaptor, ys_to_d_descriptor};
}
//***********************************************************************************
namespace detail {
// Current thread's partition index for the given distribution:
//   NDimP == 1 -> [lane]         (warp-tile)
//   NDimP == 2 -> [warp, lane]   (block-tile)
template <typename Distribution>
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
{
    // only support warp-tile and block-tile
    constexpr index_t ndim_p = Distribution::NDimP;

    static_assert(ndim_p == 1 || ndim_p == 2, "wrong!");

    if constexpr(ndim_p == 2)
    {
        return array<index_t, 2>{get_warp_id(), get_lane_id()};
    }
    else
    {
        return array<index_t, 1>{get_lane_id()};
    }
}
// Recursive right-to-left scan that slices a sequence into equal pieces of
// SliceSize elements. Template args: (lengths, mask, ids, SliceSize).
// Masked-out dims (m == 0) keep their full length and do not consume the
// slice budget. See reverse_slice_sequence below for worked examples.
template <typename, typename, typename, index_t>
struct reverse_slice_sequence_impl;

// recursive case: process the head dim after the tail has been scanned
template <index_t x,
          index_t... xs,
          index_t m,
          index_t... ms,
          index_t id,
          index_t... ids,
          index_t SliceSize>
struct reverse_slice_sequence_impl<sequence<x, xs...>,
                                   sequence<m, ms...>,
                                   sequence<id, ids...>,
                                   SliceSize>
{
    using old_scan =
        reverse_slice_sequence_impl<sequence<xs...>, sequence<ms...>, sequence<ids...>, SliceSize>;

    // slice budget left over after the tail consumed its share
    static constexpr auto slice_size = old_scan::remaining_slice_sizes::Front().value;
    // masked (m != 0): take gcd(length, budget); unmasked: keep full length
    static constexpr auto slice_length =
        std::conditional_t<m, number<math::gcd(x, slice_size)>, number<x>>::value;

    // per-dim sliced lengths / number of slices, accumulated left-to-right
    using dim_lengths =
        typename sequence_merge<sequence<slice_length>, typename old_scan::dim_lengths>::type;
    using dim_slices =
        typename sequence_merge<sequence<x / slice_length>, typename old_scan::dim_slices>::type;
    // budget remaining for the next (left) dim; untouched when unmasked
    using remaining_slice_sizes = typename sequence_merge<
        std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>,
        typename old_scan::remaining_slice_sizes>::type;

    // the first idx that sliced length not equal to original length
    static constexpr index_t _flag =
        slice_length != x && remaining_slice_sizes{}.Front().value == 1;
    static constexpr index_t _split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
    static constexpr index_t _split_idx =
        std::conditional_t<_split_flag, number<id>, number<0>>::value;

    // propagate the right-most split found so far
    static constexpr index_t split_flag = _split_flag || old_scan::split_flag;
    static constexpr index_t split_idx  = std::
        conditional_t<old_scan::split_flag, number<old_scan::split_idx>, number<_split_idx>>::value;
};
// base case: right-most dimension of the scan
template <index_t x, index_t m, index_t id, index_t SliceSize>
struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, SliceSize>
{
    // budget starts at the requested SliceSize
    static constexpr auto slice_size = SliceSize;
    // masked (m != 0): take gcd(length, budget); unmasked: keep full length
    static constexpr auto slice_length =
        std::conditional_t<m, number<math::gcd(x, slice_size)>, number<x>>::value;

    using dim_lengths = sequence<slice_length>;
    using dim_slices  = sequence<x / slice_length>;
    using remaining_slice_sizes =
        std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>;

    // the first idx that sliced length not equal to original length
    static constexpr index_t _flag =
        slice_length != x && remaining_slice_sizes{}.Front().value == 1;
    static constexpr index_t split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
    static constexpr index_t split_idx =
        std::conditional_t<split_flag, number<id>, number<0>>::value;
};
// clang-format off
// input a sequence(with optional mask), and the SliceSize : size per slice
// output the sequence each slice, and number of slices
//
// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
// <4, 2, 4, 1, 6>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 3> : 48 slices , slice_idx: 2
// <4, 2, 5, 1, 2>, 10 -> lengths:<1, 1, 5, 1, 2> , nums: <4, 2, 1, 1, 1> : 8 slices , slice_idx: 1
//
// <4, 2, 8>, 64 -> lengths:<4, 2, 8> , nums: <1, 1, 1> : 1 slices , slice_idx: 0
// <4, 2, 8>, 32 -> lengths:<2, 2, 8> , nums: <2, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 8>, 16 -> lengths:<1, 2, 8> , nums: <4, 1, 1> : 4 slices , slice_idx: 0
// <4, 2, 8>, 8 -> lengths:<1, 1, 8> , nums: <4, 2, 1> : 8 slices , slice_idx: 1
// <4, 2, 8>, 4 -> lengths:<1, 1, 4> , nums: <4, 2, 2> : 16 slices , slice_idx: 2
// <4, 2, 8>, 2 -> lengths:<1, 1, 2> , nums: <4, 2, 4> : 32 slices , slice_idx: 2
// <4, 2, 8>, 1 -> lengths:<1, 1, 1> , nums: <4, 2, 8> : 64 slices , slice_idx: 2
//
// <4, 2, 1, 4, 2> / 4 ->
// mask:<1, 1, 1, 0, 1>, -> lengths:<1, 2, 1, 4, 2> , nums: <4, 1, 1, 1, 1> : 8 slices , slice_idx: 0
//
// return tuple<slice_lengths, slice_nums, slice_index>, slice_index is at which index will start
// have split slices (right -> left)
// or the first index that sliced length is different from the original length
// clang-format on
// Slice Seq (optionally restricted by Mask) into equal pieces of SliceSize
// elements, scanning right-to-left; see the clang-format-off comment block
// above for worked examples. Returns
// tuple<slice_lengths, slice_nums, number<slice_idx>>.
template <typename Seq,
          index_t SliceSize,
          typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
constexpr auto reverse_slice_sequence(Seq,
                                      number<SliceSize>,
                                      Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
{
    static_assert(Seq::size() == Mask::size());
    using sliced_type =
        reverse_slice_sequence_impl<Seq,
                                    Mask,
                                    typename arithmetic_sequence_gen<0, Seq::size(), 1>::type,
                                    SliceSize>;
    // the scan must consume the whole slice budget
    static_assert(sliced_type::remaining_slice_sizes::Front().value == 1,
                  "can not evenly divide this sequence, please check");
    return make_tuple(typename sliced_type::dim_lengths{},
                      typename sliced_type::dim_slices{},
                      number<sliced_type::split_idx>{});
}
//
// slice tensor from x_dim, result in split in y_dim, not p_dim.
// We don't support slice cross p_dim (aka, slice different threads)
// also, sliced along y_dim need be the first dim of current dim.
// Multiply Y dim before sliced dim does not make sense
//
// e.g
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 32>, (0 means all length)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 4, 2, 4> -> OK
// |--> slice along this Y dim, is the first dim of X1, totally 4 slices
//
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 8>, (0 means all length)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 1, 2, 4> -> OK
// |--> slice along this Y dim, the P dim is 1 in the left, so is OK
// totally 16 slices
//
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 4>, (0 means all length)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 1, 1, 4> -> Fail
// |--> slice along this P dim, will split threads, not supported
//
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 16>, (0 means all length)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 2, 2, 4> -> OK
//           |--> slice along this Y dim, but this Y dim needs to split into 2
//           sub-dims
// the P dim in the left is 1, means actually not crossing P
//
// Build a sliced tile distribution from X-dim slice windows [begin, end) per X dim.
//
// @param Distribution    a static tile distribution type (passed by value, see NOTE below)
// @param x_slice_begins  per-X-dim slice start indices
// @param x_slice_ends    per-X-dim slice end indices (length = end - begin)
// @return tuple<sliced_distribution, sliced_y_origins, sliced_y_lengths>
template <typename Distribution, index_t... XSliceBegins, index_t... XSliceEnds>
CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
    Distribution, sequence<XSliceBegins...> x_slice_begins, sequence<XSliceEnds...> x_slice_ends)
{
    // NOTE: this function need to be called under constexpr context,
    // due to https://wg21.link/p2280r0 we have to use non-reference type for distribution
    using Encoding = decltype(Distribution::get_static_tile_distribution_encoding());
    static_assert(sizeof...(XSliceBegins) == sizeof...(XSliceEnds));
    // per-X-dim slice lengths
    constexpr auto x_slice_lengths = x_slice_ends - x_slice_begins;
    // offsets to convert (x-dim, h-minor) pairs into flat "uniformed" h indices
    constexpr auto src_h_prefix_sum = Encoding::detail::get_h_dim_lengths_prefix_sum();
    // sorted Y info: which uniformed h indices are Y dims, and how they map back
    constexpr auto src_y_info = Encoding::detail::get_sorted_y_info();
    constexpr auto src_y_dims = src_y_info[number<0>{}];
    constexpr auto src_y_maps = src_y_info[number<1>{}];
    constexpr auto src_y_prefix_sum = src_y_info[number<2>{}];
    // compute <new h lengths, y slice origins, y slice lengths> in one pass
    constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr
    {
        auto y_slice_sorted_origins = make_zero_multi_index<Encoding::NDimY>();
        auto y_slice_lengths = Encoding::detail::ys_lengths_;
        // This lambda will modify some value outside, so c++ will not treat return value as
        // constexpr
        // TODO: ugly
        auto new_h_lengths = transform_tuples(
            [&](auto h_len, auto id) {
                // slice this X dim's h-lengths (right-to-left) to the requested length
                constexpr auto sliced_h =
                    reverse_slice_sequence(h_len, number<x_slice_lengths[id]>{});
                constexpr auto sliced_h_lens = sliced_h[number<0>{}];
                constexpr auto sliced_h_index = sliced_h[number<2>{}];
                // update y_slice_lengths
                // convert the local h-minor index into a flat h index, then find
                // which (sorted) Y dim it corresponds to
                constexpr auto uniformed_h_index = sliced_h_index + number<src_h_prefix_sum[id]>{};
                constexpr auto found_y_index = container_find(src_y_dims, uniformed_h_index);
                static_assert(found_y_index >= 0 && found_y_index < src_y_dims.size(),
                              "not sliced at y dim, please check");
                // propagate the sliced lengths onto the affected Y dims
                static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
                    y_slice_lengths(src_y_maps[found_y_index - i]) =
                        sliced_h_lens[sliced_h_index - i];
                });
                // TODO: add validations not across p dim
                // NOTE: this y_origin is for all dims, not only current dim
                // will later use pick to select target dim
                constexpr auto y_origin = [&]() {
                    // decompose the flat slice begin into per-h-minor coordinates
                    constexpr auto h_trans = make_merge_transform_v3_division_mod(h_len);
                    auto h_origin_ = make_zero_multi_index<h_trans.NDimLow>();
                    h_trans.calculate_lower_index(h_origin_, sequence<x_slice_begins[id].value>{});
                    auto y_origin_ = make_zero_multi_index<Encoding::NDimY>();
                    static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
                        y_origin_(found_y_index - i) = h_origin_[sliced_h_index - i];
                    });
                    return y_origin_;
                }();
                // pick only the Y dims that belong to this X dim
                constexpr auto y_picks = typename arithmetic_sequence_gen<src_y_prefix_sum[id],
                                                                          src_y_prefix_sum[id + 1],
                                                                          1>::type{};
                set_container_subset(
                    y_slice_sorted_origins, y_picks, get_container_subset(y_origin, y_picks));
                return sliced_h_lens;
            },
            typename Encoding::HsLengthss{},
            typename arithmetic_sequence_gen<0, Encoding::HsLengthss::size(), 1>::type{});
        // origins were accumulated in sorted-Y order; restore original Y order
        auto y_slice_origins = container_reorder_given_old2new(y_slice_sorted_origins, src_y_maps);
        return make_tuple(new_h_lengths, y_slice_origins, y_slice_lengths);
    }
    ();
    constexpr auto sliced_h_lengths = sliced_hlen_yidx_ylen[number<0>{}];
    constexpr auto sliced_y_origins_array = sliced_hlen_yidx_ylen[number<1>{}];
    constexpr auto sliced_y_origins_size = sliced_y_origins_array.size();
    constexpr auto sliced_y_lengths_array = sliced_hlen_yidx_ylen[number<2>{}];
    constexpr auto sliced_y_lengths_size = sliced_y_lengths_array.size();
    // convert runtime-style arrays back to compile-time sequences
    constexpr auto sliced_y_origins = TO_SEQUENCE(sliced_y_origins_array, sliced_y_origins_size);
    constexpr auto sliced_y_lengths = TO_SEQUENCE(sliced_y_lengths_array, sliced_y_lengths_size);
    return make_tuple(
        make_static_tile_distribution(
            tile_distribution_encoding<typename Encoding::RsLengths,
                                       decltype(sliced_h_lengths), // only need to change the
                                                                   // h_lengths type
                                       typename Encoding::Ps2RHssMajor,
                                       typename Encoding::Ps2RHssMinor,
                                       typename Encoding::Ys2RHsMajor,
                                       typename Encoding::Ys2RHsMinor>{}),
        sliced_y_origins,
        sliced_y_lengths);
}
} // namespace detail
} // namespace ck_tile

View File

@@ -0,0 +1,761 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// Compile-time encoding of a tile distribution.
//
// Dim families (R/H/P/Y convention used throughout this file):
//   R : replication dims (lengths in RsLengths)
//   H : hierarchical sub-dims of each X dim (lengths in HsLengthss, one sequence per X)
//   P : partition dims — map onto threads (see "slice different threads" note in the
//       slicing helpers); each P dim is described by (rh_major, rh_minor) pairs
//   Y : per-thread dims; each Y dim points at one (rh_major, rh_minor) slot
// rh_major 0 refers to the R dims; rh_major i (i>=1) refers to H dims of X dim i-1.
template <typename RsLengths_, // sequence<...>
          typename HsLengthss_, // tuple<sequence<...>, ...>
          typename Ps2RHssMajor_, // tuple<sequence<...>, ...>
          typename Ps2RHssMinor_, // tuple<sequence<...>, ...>
          typename Ys2RHsMajor_, // sequence<...>
          typename Ys2RHsMinor_> // sequence<...>
struct tile_distribution_encoding
{
    using RsLengths = remove_cvref_t<RsLengths_>;
    using HsLengthss = remove_cvref_t<HsLengthss_>;
    using Ps2RHssMajor = remove_cvref_t<Ps2RHssMajor_>;
    using Ps2RHssMinor = remove_cvref_t<Ps2RHssMinor_>;
    using Ys2RHsMajor = remove_cvref_t<Ys2RHsMajor_>;
    using Ys2RHsMinor = remove_cvref_t<Ys2RHsMinor_>;
    // major/minor mappings must pair up one-to-one
    static_assert(Ps2RHssMajor::size() == Ps2RHssMinor::size(), "wrong!");
    static_assert(Ys2RHsMajor::size() == Ys2RHsMinor::size(), "wrong!");
    static constexpr index_t NDimX = HsLengthss::size();
    static constexpr index_t NDimP = Ps2RHssMajor::size();
    static constexpr index_t NDimY = Ys2RHsMajor::size();
    static constexpr index_t NDimR = RsLengths::size();
    // FIXME: move into detail
    static constexpr auto rs_lengths_ = RsLengths{};
    static constexpr auto hs_lengthss_ = HsLengthss{};
    static constexpr auto ps_to_rhss_major_ = Ps2RHssMajor{};
    static constexpr auto ps_to_rhss_minor_ = Ps2RHssMinor{};
    static constexpr auto ys_to_rhs_major_ = Ys2RHsMajor{};
    static constexpr auto ys_to_rhs_minor_ = Ys2RHsMinor{};
    // redundant but useful info
    // TODO: really bad code, should be over-hauled
    struct detail
    {
        // ndim_rh_major_, ndim_span_minor_
        // rh_major 0 is the R group, 1..NDimX are the H groups -> NDimX + 1 groups
        static constexpr index_t ndim_rh_major_ = NDimX + 1;
        static constexpr index_t ndim_span_major_ = NDimX;
        // ndims_rhs_minor_[ndim_rh_major_]: number of minor dims in each rh-major group
        static constexpr auto ndims_rhs_minor_ = generate_array(
            [](auto i) {
                if constexpr(i.value == 0)
                {
                    return rs_lengths_.size();
                }
                else
                {
                    return hs_lengthss_[i - number<1>{}].size();
                }
            },
            number<ndim_rh_major_>{});
        // max_ndim_rh_minor_: widest rh-major group; used to size rectangular tables below
        static constexpr index_t max_ndim_rh_minor_ =
            container_reduce(ndims_rhs_minor_, math::maximize<index_t>{}, 0);
        // rhs_lengthss_[ndim_rh_major_][max_ndim_rh_minor_]: R lengths then H lengths
        static constexpr auto rhs_lengthss_ =
            to_array_of_array(container_concat(make_tuple(rs_lengths_), hs_lengthss_));
        // ys_lengths_: length of each Y dim, looked up through its (major, minor) slot
        static constexpr auto ys_lengths_ = [] {
            array<index_t, NDimY> ys_lengths_tmp{-1};
            for(index_t i = 0; i < NDimY; i++)
            {
                index_t rh_major = ys_to_rhs_major_[i];
                index_t rh_minor = ys_to_rhs_minor_[i];
                ys_lengths_tmp(i) = rhs_lengthss_[rh_major][rh_minor];
            }
            return ys_lengths_tmp;
        }();
        // rhs_major_minor_to_ys_[ndim_rh_major_][max_ndim_rh_minor_]
        // inverse of the ys_to_rhs mapping; -1 marks slots not owned by any Y dim
        static constexpr auto rhs_major_minor_to_ys_ = [] {
            array<array<index_t, max_ndim_rh_minor_>, NDimX + 1> rhs_major_minor_to_ys_tmp{{-1}};
            static_for<0, NDimY, 1>{}([&](auto i) {
                constexpr index_t rh_major = ys_to_rhs_major_[i];
                constexpr index_t rh_minor = ys_to_rhs_minor_[i];
                rhs_major_minor_to_ys_tmp(rh_major)(rh_minor) = i;
            });
            return rhs_major_minor_to_ys_tmp;
        }();
        // ndims_span_minor_[ndim_span_major_]: how many Y dims land in each X (span-major) dim
        static constexpr auto ndims_span_minor_ = [] {
            array<index_t, NDimX> ndims_span_minor{0};
            for(index_t i = 0; i < NDimY; i++)
            {
                // span-major is the owning X dim (rh_major shifted past the R group)
                const index_t span_major = ys_to_rhs_major_[i] - 1;
                ndims_span_minor(span_major)++;
            }
            return ndims_span_minor;
        }();
        // max_ndim_span_minor_
        static constexpr index_t max_ndim_span_minor_ =
            container_reduce(ndims_span_minor_, math::maximize<index_t>{}, 0);
        // rhs_major_minor_to_span_minor_ [ndim_rh_major_][max_ndim_rh_minor_]
        // per rh-major group: running index of Y-owned slots, -1 otherwise
        static constexpr auto rhs_major_minor_to_span_minor_ = [] {
            array<array<index_t, max_ndim_rh_minor_>, ndim_rh_major_> rhs_major_minor_to_span_minor{
                {-1}};
            static_for<0, ndim_rh_major_, 1>{}([&](auto rh_major) {
                constexpr index_t ndim_rh_minor = ndims_rhs_minor_[rh_major];
                index_t cnt_ndim_span_minor = 0;
                static_for<0, ndim_rh_minor, 1>{}([&](auto rh_minor) {
                    constexpr index_t idim_y = rhs_major_minor_to_ys_[rh_major][rh_minor];
                    if(idim_y >= 0)
                    {
                        rhs_major_minor_to_span_minor(rh_major)(rh_minor) = cnt_ndim_span_minor;
                        cnt_ndim_span_minor++;
                    }
                });
            });
            return rhs_major_minor_to_span_minor;
        }();
        // ys_to_span_major_[NDimY]
        static constexpr auto ys_to_span_major_ =
            generate_array([](auto i) { return ys_to_rhs_major_[i] - 1; }, number<NDimY>{});
        // ys_to_span_minor_[NDimY]
        static constexpr auto ys_to_span_minor_ = generate_array(
            [](auto i) {
                return rhs_major_minor_to_span_minor_[ys_to_rhs_major_[i]][ys_to_rhs_minor_[i]];
            },
            number<NDimY>{});
        // distributed_spans_lengthss_[ndim_span_major_][max_ndim_span_minor_]
        // H length of each Y dim laid out by (span_major, span_minor)
        static constexpr auto distributed_spans_lengthss_ = [] {
            array<array<index_t, max_ndim_span_minor_>, ndim_span_major_>
                distributed_spans_lengthss{{-1}};
            static_for<0, NDimY, 1>{}([&](auto i) {
                const index_t rh_major = ys_to_rhs_major_[i];
                const index_t rh_minor = ys_to_rhs_minor_[i];
                const index_t h_length = hs_lengthss_[number<rh_major - 1>{}][rh_minor];
                const index_t span_major = rh_major - 1;
                const index_t span_minor = rhs_major_minor_to_span_minor_[rh_major][rh_minor];
                distributed_spans_lengthss(span_major)(span_minor) = h_length;
            });
            return distributed_spans_lengthss;
        }();
        // ndims_distributed_spans_minor_[ndim_span_major_]
        static constexpr auto ndims_distributed_spans_minor_ = [] {
            array<index_t, ndim_span_major_> ndims_distributed_spans_minor{0};
            static_for<0, NDimY, 1>{}([&](auto i) {
                const index_t span_major = ys_to_rhs_major_[i] - 1;
                ndims_distributed_spans_minor(span_major)++;
            });
            return ndims_distributed_spans_minor;
        }();
        // does_p_own_r_[NDimP][NDimR]: true if P dim idim_p maps onto R dim rh_minor
        static constexpr auto does_p_own_r_ = [] {
            if constexpr(NDimR > 0)
            {
                array<array<bool, NDimR>, NDimP> does_p_own_r{{false}};
                static_for<0, NDimP, 1>{}([&](auto idim_p) {
                    constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].size();
                    static_for<0, ndim_low, 1>{}([&](auto idim_low) {
                        constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low];
                        constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low];
                        if constexpr(rh_major == 0)
                        {
                            does_p_own_r(idim_p)(rh_minor) = true;
                        }
                    });
                });
                return does_p_own_r;
            }
            else
            {
                return array<array<bool, NDimR>, NDimP>{};
            }
        }();
        // ps_over_rs_derivative_[NDimP][NDimR]: stride of P dim w.r.t. each R dim it owns
        // (product of the rh lengths to the right of that R slot in the P mapping)
        static constexpr auto ps_over_rs_derivative_ = [] {
            if constexpr(NDimR > 0)
            {
                array<array<index_t, NDimR>, NDimP> ps_over_rs_derivative{{0}};
                static_for<0, NDimP, 1>{}([&](auto idim_p) {
                    constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].size();
                    index_t p_over_rh_derivative = 1;
                    // walk right-to-left, accumulating the running stride
                    static_for<ndim_low - 1, -1, -1>{}([&](auto idim_low) {
                        constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low];
                        constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low];
                        constexpr index_t rh_length = rhs_lengthss_[rh_major][rh_minor];
                        if constexpr(rh_major == 0)
                        {
                            ps_over_rs_derivative(idim_p)(rh_minor) = p_over_rh_derivative;
                        }
                        p_over_rh_derivative *= rh_length;
                    });
                });
                return ps_over_rs_derivative;
            }
            else
            {
                return array<array<index_t, NDimR>, NDimP>{};
            }
        }();
        // e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5> --> seq<0, 3, 8>
        CK_TILE_HOST_DEVICE static constexpr auto get_h_dim_lengths_prefix_sum()
        {
            // <len_d0, len_d1, ...>
            // e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5>
            constexpr auto uniformed_h_dim_lengths = generate_sequence_v2(
                [&](auto i) {
                    constexpr index_t size = HsLengthss{}[i].size();
                    return number<size>{};
                },
                number<NDimX>{});
            // <0, len_d0, len_d0+len_d1, ...>
            // e.g. seq<3, 5> --> seq<0, 3, 8>
            constexpr auto h_dim_prefix_sum = prefix_sum_sequence(uniformed_h_dim_lengths);
            return h_dim_prefix_sum;
        }
        // map every Y dim's (major, minor) pair to a flat index over all R+H sub-dims
        CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_y_to_h()
        {
            constexpr auto all_ys_2_rhss = transform_sequences(
                [](auto major, auto minor) constexpr {
                    // <0, 0, len_d0, len_d0+len_d1, ...>
                    constexpr auto x_dim_prefix_sum = merge_sequences(
                        sequence<0>{} /*for R dims*/, get_h_dim_lengths_prefix_sum());
                    return x_dim_prefix_sum.at(major) + minor;
                },
                Ys2RHsMajor{},
                Ys2RHsMinor{});
            return all_ys_2_rhss;
        }
        // return tuple<sorted_dims, sorted_maps, sorted_prefix_sum>
        template <typename IdxSeq, typename PrefixSumSeq>
        CK_TILE_HOST_DEVICE static constexpr auto get_sorted_info(IdxSeq, PrefixSumSeq)
        {
            using sorted_idx =
                sequence_unique_sort<IdxSeq, math::less<index_t>, math::equal<index_t>>;
            constexpr auto sorted_dims = typename sorted_idx::type{};
            constexpr auto sorted_maps = typename sorted_idx::sorted2unsorted_map{};
            constexpr auto sorted_histogram =
                histogram_sorted_sequence(sorted_dims, PrefixSumSeq{});
            constexpr auto sorted_prefix_sum = prefix_sum_sequence(sorted_histogram);
            return make_tuple(sorted_dims, sorted_maps, sorted_prefix_sum);
        }
        CK_TILE_HOST_DEVICE static constexpr auto get_sorted_y_info()
        {
            return get_sorted_info(get_uniformed_idx_y_to_h(), get_h_dim_lengths_prefix_sum());
        }
        // debug dump of every derived table above
        CK_TILE_HOST_DEVICE void print() const
        {
            printf("tile_distribution_encoding::detail{");
            //
            printf("ndim_rh_major_: ");
            print(ndim_rh_major_);
            printf(", ");
            //
            printf("ndim_span_major_: ");
            print(ndim_span_major_);
            printf(", ");
            //
            printf("ndims_rhs_minor_: ");
            print(ndims_rhs_minor_);
            printf(", ");
            //
            printf("ndim_rh_major_: ");
            print(ndim_rh_major_);
            printf(", ");
            //
            printf("max_ndim_rh_minor_: ");
            print(max_ndim_rh_minor_);
            printf(", ");
            //
            printf("rhs_lengthss_: ");
            print(rhs_lengthss_);
            printf(", ");
            //
            printf("ys_lengths_: ");
            print(ys_lengths_);
            printf(", ");
            //
            printf("rhs_major_minor_to_ys_: ");
            print(rhs_major_minor_to_ys_);
            printf(", ");
            //
            printf("ndims_span_minor_: ");
            print(ndims_span_minor_);
            printf(", ");
            //
            printf("max_ndim_span_minor_: ");
            print(max_ndim_span_minor_);
            printf(", ");
            //
            printf("ys_to_span_major_: ");
            print(ys_to_span_major_);
            printf(", ");
            //
            printf("ys_to_span_minor_: ");
            print(ys_to_span_minor_);
            printf(", ");
            //
            printf("distributed_spans_lengthss_: ");
            print(distributed_spans_lengthss_);
            printf(", ");
            //
            printf("ndims_distributed_spans_minor_: ");
            print(ndims_distributed_spans_minor_);
            printf(", ");
            //
            printf("ps_over_rs_derivative_: ");
            print(ps_over_rs_derivative_);
            //
            printf("}");
        }
    };
    // debug dump of the raw encoding plus its derived detail tables
    CK_TILE_HOST_DEVICE void print() const
    {
        printf("tile_distribution_encoding{");
        //
        printf("NDimX: %d, NDimP: %d, NDimY: %d, ", NDimX, NDimP, NDimY);
        //
        printf("rs_lengths_: ");
        print(rs_lengths_);
        printf(", ");
        //
        printf("hs_lengthss_: ");
        print(hs_lengthss_);
        printf(", ");
        //
        printf("ps_to_rhss_major_: ");
        print(ps_to_rhss_major_);
        printf(", ");
        //
        printf("ps_to_rhss_minor_: ");
        print(ps_to_rhss_minor_);
        printf(", ");
        //
        printf("ys_to_rhs_major_: ");
        print(ys_to_rhs_major_);
        printf(", ");
        //
        printf("ys_to_rhs_minor_: ");
        print(ys_to_rhs_minor_);
        printf(", ");
        //
        printf("detail: ");
        print(detail{});
        //
        printf("}");
    }
};
namespace detail {
// Compose two encodings with the same number of X dims into one: for each X dim,
// the outer H sub-dims come first and the inner H sub-dims are appended after them.
// R dims and P dims of both encodings are concatenated (outer first); inner minor
// indices are shifted by the number of outer sub-dims in their rh-major group.
template <typename OuterDstr, typename InnerDstr>
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
{
    static_assert(OuterDstr::NDimX == InnerDstr::NDimX, "wrong!");
    constexpr index_t NDimHMajor = OuterDstr::NDimX;
    using RsLengths =
        sequence_merge_t<typename OuterDstr::RsLengths, typename InnerDstr::RsLengths>;
    // per X dim: outer H lengths followed by inner H lengths
    constexpr auto hs_lengthss = generate_tuple(
        [&](auto i) {
            return merge_sequences(typename OuterDstr::HsLengthss{}[i],
                                   typename InnerDstr::HsLengthss{}[i]);
        },
        number<NDimHMajor>{});
    //
    // for each rh-major group, how many minor dims the outer encoding contributes
    // (inner minor indices must be shifted by this amount)
    constexpr auto rhs_major_2_ndim_outer_rhs_minor = [&]() {
        array<index_t, NDimHMajor + 1> rhs_major_2_ndim_outer_rhs_minor_;
        // R dimension
        rhs_major_2_ndim_outer_rhs_minor_(0) = OuterDstr::RsLengths::size();
        // Hs dimensions
        static_for<0, NDimHMajor, 1>{}([&](auto i) {
            rhs_major_2_ndim_outer_rhs_minor_(i + 1) = typename OuterDstr::HsLengthss{}[i].size();
        });
        return rhs_major_2_ndim_outer_rhs_minor_;
    }();
    // Ps2RHssMinor
    // shift every inner P->rh minor index past the outer sub-dims of its group
    constexpr auto updated_inner_ps_2_rhss_minor = generate_tuple(
        [&](auto p) {
            constexpr auto inner_p_2_rhss_major = typename InnerDstr::Ps2RHssMajor{}[p];
            constexpr auto inner_p_2_rhss_minor = typename InnerDstr::Ps2RHssMinor{}[p];
            constexpr index_t ndim_tmp = inner_p_2_rhss_minor.size();
            constexpr auto updated_inner_p_2_rhss_minor = [&]() {
                array<index_t, ndim_tmp> updated_inner_p_2_rhss_minor_;
                for(index_t i = 0; i < ndim_tmp; i++)
                {
                    index_t rh_major = inner_p_2_rhss_major[i];
                    index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];
                    updated_inner_p_2_rhss_minor_(i) = inner_p_2_rhss_minor[i] + ndim_outer_h_minor;
                }
                return updated_inner_p_2_rhss_minor_;
            }();
            return TO_SEQUENCE(updated_inner_p_2_rhss_minor, ndim_tmp);
        },
        number<InnerDstr::NDimP>{});
    // Ys2RHsMinor
    // same shift for the inner Y->rh minor indices
    constexpr auto updated_inner_ys_2_rhs_minor = [&]() {
        constexpr auto inner_ys_2_rhs_major = typename InnerDstr::Ys2RHsMajor{};
        constexpr auto inner_ys_2_rhs_minor = typename InnerDstr::Ys2RHsMinor{};
        constexpr index_t ndim_tmp = inner_ys_2_rhs_minor.size();
        constexpr auto updated_inner_ys_2_rhs_minor_ = [&]() {
            array<index_t, ndim_tmp> updated_inner_ys_2_rhs_minor__;
            for(index_t i = 0; i < ndim_tmp; i++)
            {
                index_t rh_major = inner_ys_2_rhs_major[i];
                index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];
                updated_inner_ys_2_rhs_minor__(i) = inner_ys_2_rhs_minor[i] + ndim_outer_h_minor;
            }
            return updated_inner_ys_2_rhs_minor__;
        }();
        return TO_SEQUENCE(updated_inner_ys_2_rhs_minor_, ndim_tmp);
    }();
    //
    // concatenate P mappings: outer P dims first, then (shifted) inner P dims
    constexpr auto ps_2_rhss_major =
        container_concat(typename OuterDstr::Ps2RHssMajor{}, typename InnerDstr::Ps2RHssMajor{});
    constexpr auto ps_2_rhss_minor =
        container_concat(typename OuterDstr::Ps2RHssMinor{}, updated_inner_ps_2_rhss_minor);
    //
    // concatenate Y mappings likewise
    constexpr auto ys_2_rhs_major =
        merge_sequences(typename OuterDstr::Ys2RHsMajor{}, typename InnerDstr::Ys2RHsMajor{});
    constexpr auto ys_2_rhs_minor =
        merge_sequences(typename OuterDstr::Ys2RHsMinor{}, updated_inner_ys_2_rhs_minor);
    return tile_distribution_encoding<RsLengths,
                                      remove_cvref_t<decltype(hs_lengthss)>,
                                      remove_cvref_t<decltype(ps_2_rhss_major)>,
                                      remove_cvref_t<decltype(ps_2_rhss_minor)>,
                                      remove_cvref_t<decltype(ys_2_rhs_major)>,
                                      remove_cvref_t<decltype(ys_2_rhs_minor)>>{};
}
// Work-horse for make_reduce_tile_distribution_encoding: removes the X dims listed
// in reduce_dim_xs_in from the encoding. H sub-dims of a reduced X dim that were
// mapped to P dims become new R dims (replication); its Y-mapped sub-dims vanish.
// Returns plain arrays (not sequences) plus the dim counts needed to re-sequence
// them; the caller converts via TO_SEQUENCE / TO_TUPLE_OF_SEQUENCE.
template <typename InDstr, index_t... InReduceDimXs>
CK_TILE_HOST_DEVICE constexpr auto
make_reduce_tile_distribution_encoding_impl(InDstr, sequence<InReduceDimXs...> reduce_dim_xs_in)
{
    constexpr auto I1 = number<1>{};
    // FIXME: increase if fail
    // fixed upper bounds for the scratch arrays (output dim counts are runtime
    // values here, so the arrays must be sized pessimistically)
    constexpr index_t max_ndim_r_out = 20;
    constexpr index_t max_ndim_y_out = 20;
    //
    constexpr index_t ndim_p = InDstr::NDimP;
    constexpr index_t ndim_x_in = InDstr::NDimX;
    constexpr index_t ndim_y_in = InDstr::NDimY;
    constexpr index_t ndim_rh_major_in = InDstr::NDimX + 1;
    constexpr index_t ndim_x_out = ndim_x_in - sizeof...(InReduceDimXs);
    constexpr index_t max_ndim_rh_minor_in = InDstr::detail::max_ndim_rh_minor_;
    // ndims_ps_low: number of (major, minor) pairs per P dim (unchanged by reduce)
    constexpr auto ndims_ps_low = generate_array(
        [&](auto i) { return InDstr::ps_to_rhss_major_[i].size(); }, number<ndim_p>{});
    // is_rh_major_in_for_reduce: marks the rh-major groups being reduced away
    array<bool, ndim_rh_major_in> is_rh_major_in_for_reduce{false};
    for(index_t i = 0; i < reduce_dim_xs_in.size(); i++)
    {
        // +1: rh_major 0 is the R group, X dim i is rh_major i+1
        index_t rh_major = reduce_dim_xs_in[i] + 1;
        is_rh_major_in_for_reduce(rh_major) = true;
    }
    // is_y_in_for_reduce: Y dims that point into a reduced group disappear
    array<bool, ndim_y_in> is_y_in_for_reduce{false};
    for(index_t i = 0; i < ndim_y_in; i++)
    {
        index_t rh_major = InDstr::ys_to_rhs_major_[i];
        if(is_rh_major_in_for_reduce[rh_major])
        {
            is_y_in_for_reduce(i) = true;
        }
    }
    // is_rh_minor_in_for_y_reduce: the individual (major, minor) slots removed via Y
    array<array<bool, max_ndim_rh_minor_in>, ndim_rh_major_in> is_rh_minor_in_for_y_reduce{{false}};
    static_for<0, ndim_y_in, 1>{}([&](auto i) {
        index_t rh_major = InDstr::ys_to_rhs_major_[i];
        index_t rh_minor = InDstr::ys_to_rhs_minor_[i];
        if(is_y_in_for_reduce[i])
        {
            is_rh_minor_in_for_y_reduce(rh_major)(rh_minor) = true;
        }
    });
    // in2out_rh_major: reduced groups collapse into the R group (index 0),
    // surviving groups are renumbered contiguously
    array<index_t, ndim_rh_major_in> in2out_rh_major{-1};
    index_t cnt_ndim_rh_major_out = 0;
    for(index_t i = 0; i < ndim_rh_major_in; i++)
    {
        if(is_rh_major_in_for_reduce[i])
        {
            in2out_rh_major(i) = 0;
        }
        else
        {
            in2out_rh_major(i) = cnt_ndim_rh_major_out;
            cnt_ndim_rh_major_out++;
        }
    }
    // rs_lengths_out, in2out_rh_minor
    array<index_t, max_ndim_r_out> rs_lengths_out{-1};
    array<array<index_t, max_ndim_rh_minor_in>, ndim_rh_major_in> in2out_rh_minor{{-1}};
    // loop over input R dim
    // existing R dims keep their positions
    for(index_t i = 0; i < InDstr::rs_lengths_.size(); i++)
    {
        // rs_lengths_out
        rs_lengths_out(i) = InDstr::rs_lengths_[i];
        // in2out_rh_minor
        in2out_rh_minor(0)(i) = i;
    }
    // loop over input H Dim
    // H sub-dims of reduced X dims that were NOT Y-mapped become new R dims
    index_t cnt_ndim_r_out = InDstr::rs_lengths_.size();
    static_for<1, ndim_rh_major_in, 1>{}([&](auto rh_major_in) {
        constexpr auto h_major_in = rh_major_in - I1;
        constexpr index_t ndim_rh_minor_in = InDstr::hs_lengthss_[h_major_in].size();
        if(is_rh_major_in_for_reduce[rh_major_in])
        {
            for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
            {
                if(not is_rh_minor_in_for_y_reduce[rh_major_in][rh_minor_in])
                {
                    // rs_lengths_out
                    rs_lengths_out(cnt_ndim_r_out) = InDstr::hs_lengthss_[h_major_in][rh_minor_in];
                    // in2out_rh_minor
                    in2out_rh_minor(rh_major_in)(rh_minor_in) = cnt_ndim_r_out;
                    cnt_ndim_r_out++;
                }
            }
        }
        else
        {
            for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
            {
                // in2out_rh_minor
                // surviving H sub-dims keep their minor index
                in2out_rh_minor(rh_major_in)(rh_minor_in) = rh_minor_in;
            }
        }
    });
    // ndim_r_out
    const index_t ndim_r_out = cnt_ndim_r_out;
    // ndims_hs_minor_out, hs_lengthss_out: copy the H lengths of surviving X dims
    array<index_t, ndim_x_out> ndims_hs_minor_out{-1};
    array<array<index_t, max_ndim_rh_minor_in>, ndim_x_out> hs_lengthss_out{{-1}};
    index_t cnt_ndim_x_out = 0;
    static_for<0, ndim_x_in, 1>{}([&](auto i) {
        if(not is_rh_major_in_for_reduce[i + I1])
        {
            // ndims_hs_minor_out
            ndims_hs_minor_out(cnt_ndim_x_out) = InDstr::hs_lengthss_[i].size();
            // hs_lengthss_out
            static_for<0, InDstr::hs_lengthss_[i].size(), 1>{}(
                [&](auto j) { hs_lengthss_out(cnt_ndim_x_out)(j) = InDstr::hs_lengthss_[i][j]; });
            cnt_ndim_x_out++;
        }
    });
    // ps_to_rhss_major_out, ps_to_rhss_minor_out: remap every P slot through the tables
    array<array<index_t, max_ndim_rh_minor_in>, ndim_p> ps_to_rhss_major_out{{-1}};
    array<array<index_t, max_ndim_rh_minor_in>, ndim_p> ps_to_rhss_minor_out{{-1}};
    static_for<0, ndim_p, 1>{}([&](auto idim_p) {
        static_for<0, InDstr::ps_to_rhss_major_[idim_p].size(), 1>{}([&](auto idim_low) {
            index_t rh_major_in = InDstr::ps_to_rhss_major_[idim_p][idim_low];
            index_t rh_minor_in = InDstr::ps_to_rhss_minor_[idim_p][idim_low];
            ps_to_rhss_major_out(idim_p)(idim_low) = in2out_rh_major[rh_major_in];
            ps_to_rhss_minor_out(idim_p)(idim_low) = in2out_rh_minor[rh_major_in][rh_minor_in];
        });
    });
    // ys_to_rhs_major_out, ys_to_rhs_minor_out: keep and remap only surviving Y dims
    array<index_t, max_ndim_y_out> ys_to_rhs_major_out{-1};
    array<index_t, max_ndim_y_out> ys_to_rhs_minor_out{-1};
    index_t cnt_ndim_y_out = 0;
    static_for<0, ndim_y_in, 1>{}([&](auto i) {
        if(not is_y_in_for_reduce[i])
        {
            index_t rh_major_in = InDstr::ys_to_rhs_major_[i];
            index_t rh_minor_in = InDstr::ys_to_rhs_minor_[i];
            ys_to_rhs_major_out(cnt_ndim_y_out) = in2out_rh_major[rh_major_in];
            ys_to_rhs_minor_out(cnt_ndim_y_out) = in2out_rh_minor[rh_major_in][rh_minor_in];
            cnt_ndim_y_out++;
        }
    });
    // ndim_y_out
    const index_t ndim_y_out = cnt_ndim_y_out;
    //
    return make_tuple(ndim_x_out,
                      ndim_p,
                      ndim_y_out,
                      ndim_r_out,
                      ndims_hs_minor_out,
                      ndims_ps_low,
                      rs_lengths_out,
                      hs_lengthss_out,
                      ps_to_rhss_major_out,
                      ps_to_rhss_minor_out,
                      ys_to_rhs_major_out,
                      ys_to_rhs_minor_out);
}
// Remove the X dims listed in reduce_dim_xs_in from an encoding.
// Runs the array-based impl, then lifts its raw arrays back into compile-time
// sequences / tuples-of-sequences to assemble the reduced encoding type.
template <typename InDstr, index_t... InReduceDimXs>
CK_TILE_HOST_DEVICE constexpr auto
make_reduce_tile_distribution_encoding(InDstr, sequence<InReduceDimXs...> reduce_dim_xs_in)
{
    constexpr auto raw = make_reduce_tile_distribution_encoding_impl(InDstr{}, reduce_dim_xs_in);

    // dimension counts of the reduced encoding
    constexpr index_t out_ndim_x = raw.template at<0>();
    constexpr index_t out_ndim_p = raw.template at<1>();
    constexpr index_t out_ndim_y = raw.template at<2>();
    constexpr index_t out_ndim_r = raw.template at<3>();

    // per-dimension minor counts, needed to size the generated sequences
    constexpr auto out_ndims_hs_minor = raw.template at<4>();
    constexpr auto out_ndims_ps_low   = raw.template at<5>();

    // raw length / mapping arrays produced by the impl
    constexpr auto raw_rs_lengths       = raw.template at<6>();
    constexpr auto raw_hs_lengthss      = raw.template at<7>();
    constexpr auto raw_ps_to_rhss_major = raw.template at<8>();
    constexpr auto raw_ps_to_rhss_minor = raw.template at<9>();
    constexpr auto raw_ys_to_rhs_major  = raw.template at<10>();
    constexpr auto raw_ys_to_rhs_minor  = raw.template at<11>();

    // array -> sequence / tuple-of-sequence conversions
    constexpr auto rs_lengths = TO_SEQUENCE(raw_rs_lengths, out_ndim_r);
    constexpr auto hs_lengthss =
        TO_TUPLE_OF_SEQUENCE(raw_hs_lengthss, out_ndim_x, out_ndims_hs_minor);
    constexpr auto ps_to_rhss_major =
        TO_TUPLE_OF_SEQUENCE(raw_ps_to_rhss_major, out_ndim_p, out_ndims_ps_low);
    constexpr auto ps_to_rhss_minor =
        TO_TUPLE_OF_SEQUENCE(raw_ps_to_rhss_minor, out_ndim_p, out_ndims_ps_low);
    constexpr auto ys_to_rhs_major = TO_SEQUENCE(raw_ys_to_rhs_major, out_ndim_y);
    constexpr auto ys_to_rhs_minor = TO_SEQUENCE(raw_ys_to_rhs_minor, out_ndim_y);

    return tile_distribution_encoding<remove_cvref_t<decltype(rs_lengths)>,
                                      remove_cvref_t<decltype(hs_lengthss)>,
                                      remove_cvref_t<decltype(ps_to_rhss_major)>,
                                      remove_cvref_t<decltype(ps_to_rhss_minor)>,
                                      remove_cvref_t<decltype(ys_to_rhs_major)>,
                                      remove_cvref_t<decltype(ys_to_rhs_minor)>>{};
}
} // namespace detail
} // namespace ck_tile

View File

@@ -0,0 +1,191 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// TODO: support tensors with different distribution
// Apply inout_element_func in place to the i-th thread-buffer element of every
// tensor simultaneously. All tensors are assumed to share one distribution (and
// hence one thread-buffer size); the first pack element is used as the reference.
template <typename InOutElementFunc,
          typename... InOutDstrTensors,
          typename = std::enable_if_t<std::conjunction_v<
              std::negation<std::is_same<std::remove_const_t<InOutDstrTensors>, NullTensor>>...>>>
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc& inout_element_func,
                                           InOutDstrTensors&... inout_dstr_tensors)
{
    // TODO: make sure all distributed tensors have same lengths and distribution
    // static_assert(xxx);
    using ref_tensor = type_pack_element<0, InOutDstrTensors...>;
    constexpr index_t nelem = ref_tensor::get_thread_buffer_size();

    static_for<0, nelem, 1>{}([&](auto idx) {
        inout_element_func(inout_dstr_tensors.get_thread_buffer().at(idx)...);
    });
}
// Apply in_element_func element-wise across the input tensors and return a new
// distributed tensor holding the results. The output element type is whatever the
// functor returns; the output distribution is taken from the first input tensor.
template <typename InElementFunc,
          typename... InDstrTensors,
          typename = std::enable_if_t<
              std::conjunction_v<std::negation<std::is_same<InDstrTensors, NullTensor>>...>>>
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
                                        const InDstrTensors&... in_dstr_tensors)
{
    // result element type = what the functor yields for the input data types
    using OutDataType = decltype(in_element_func(typename InDstrTensors::DataType{}...));

    // TODO: make sure all distributed tensors have same lengths and distribution
    // static_assert(xxx);
    using ref_tensor = type_pack_element<0, InDstrTensors...>;
    constexpr auto in_tile_dstr = ref_tensor::get_tile_distribution();
    constexpr index_t nelem     = ref_tensor::get_thread_buffer_size();

    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);

    static_for<0, nelem, 1>{}([&](auto idx) {
        out_dstr_tensor.get_thread_buffer()(idx) =
            in_element_func(in_dstr_tensors.get_thread_buffer()[idx]...);
    });

    return out_dstr_tensor;
}
// Fill every element of a distributed tensor with `value`, converted to the
// tensor's data type.
template <typename DstrTensors, typename T>
CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, const T& value)
{
    using dst_t = typename DstrTensors::DataType;

    const auto assign_value = [&value](auto& elem) {
        elem = type_convert<dst_t, remove_cvref_t<T>>(value);
    };

    tile_elementwise_inout(assign_value, dstr_tensor);
}
// No-op overload: setting a NullTensor does nothing.
template <typename T>
CK_TILE_DEVICE void set_tile(NullTensor&, const T&)
{
}
// TODO: prefer to use per-dword value to set a tensor, in case compiler not doing well with
// sub-dword tensor...
// Compile-time-value overload of set_tile. When the value is 0 and the thread
// buffer is a whole number of dwords, the buffer is zeroed as packed 32-bit words
// (NOTE(review): relies on reinterpret_cast aliasing of the thread buffer —
// presumably safe for these trivially-copyable element types, verify);
// otherwise it falls back to the generic per-element path.
template <typename DstrTensors, index_t v>
CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number<v>)
{
    constexpr index_t tensor_bytes =
        DstrTensors::get_thread_buffer_size() * sizeof(typename DstrTensors::DataType);
    if constexpr(v == 0 && tensor_bytes % 4 == 0)
    {
        // view the thread buffer as dwords and write v (== 0) into each word
        using dvec_t = static_buffer_c<index_t, tensor_bytes / 4>;
        auto& tensor = reinterpret_cast<dvec_t&>(dstr_tensor.get_thread_buffer());
        for(auto i = 0; i < tensor.size(); i++)
            tensor.get(i) = v;
    }
    else
    {
        // generic path: convert v to the element type and assign element-wise
        tile_elementwise_inout(
            [](auto& x) { x = type_convert<typename DstrTensors::DataType, index_t>(v); },
            dstr_tensor);
    }
}
// No-op overload: setting a NullTensor with a compile-time value does nothing.
template <index_t v>
CK_TILE_DEVICE void set_tile(NullTensor&, number<v>)
{
}
// Zero out every element of a distributed tensor.
template <typename DstrTensors>
CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor)
{
    // pass a compile-time zero so overload resolution picks the number<v>
    // set_tile(), which can zero the thread buffer as packed dwords; a plain
    // `0` literal would select the generic runtime overload and miss that
    // fast path entirely
    set_tile(dstr_tensor, number<0>{});
}
// TODO: this is ugly
// Cast a float distributed tensor to fp8/bf8 using the packed hardware
// conversion, 4 elements at a time. Only meaningful on gfx94x; other targets
// fall back to a plain element-wise convert. Requires the thread-buffer size
// to be a multiple of 4 (enforced below).
template <typename OutDataType, typename InDstrTensors>
CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InDstrTensors& in_dstr_tensors)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    // This API is designed to use the _pk_ series of functions
    constexpr auto in_tile_dstr = InDstrTensors::get_tile_distribution();
    constexpr index_t thread_buffer_size = InDstrTensors::get_thread_buffer_size();
    static_assert(thread_buffer_size % 4 == 0);
    constexpr index_t thread_buffer_size_pk = thread_buffer_size / 4;
    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wuninitialized"
    // __builtin_amdgcn_cvt_pk_fp8_f32() this builtin require the old value, and
    // will generate a v_mov_b32 vxxx [old] before cvt, which result in unwanted ISA
    // so we prepare an uninitialized variable purposely, and turn off the warning
    int dummy_old;
    static_for<0, thread_buffer_size_pk, 1>{}([&](auto i) {
        // convert elements 0,1 and 2,3 into the low words of two packed results
        uint32_t x = __builtin_amdgcn_cvt_pk_fp8_f32(
            in_dstr_tensors.get_thread_buffer()[number<4 * i + 0>{}],
            in_dstr_tensors.get_thread_buffer()[number<4 * i + 1>{}],
            dummy_old,
            false); // false -> WORD0
        uint32_t y = __builtin_amdgcn_cvt_pk_fp8_f32(
            in_dstr_tensors.get_thread_buffer()[number<4 * i + 2>{}],
            in_dstr_tensors.get_thread_buffer()[number<4 * i + 3>{}],
            dummy_old,
            false); // false -> WORD0
        // byte-select the four fp8 results into one dword, then store as a 4-vector
        constexpr int32_t m0 = 0x05040100;
        using vec_t = typename vector_type<OutDataType, 4>::type;
        vec_t d = bit_cast<vec_t>(__builtin_amdgcn_perm(y, x, m0));
        out_dstr_tensor.get_thread_buffer().template set_as<vec_t>(number<i>{}, d);
    });
#pragma clang diagnostic pop
    return out_dstr_tensor;
#else
    // fallback
    return tile_elementwise_in(type_convert<OutDataType, typename InDstrTensors::DataType>,
                               in_dstr_tensors);
#endif
}
// Cast a distributed tensor to DstType element-wise. float -> fp8/bf8 tensors
// whose thread buffer is a multiple of 4 elements take the packed-conversion
// path; everything else uses the generic element-wise convert.
template <typename DstType, typename SrcDstrTensors>
CK_TILE_DEVICE auto cast_tile(const SrcDstrTensors& src_tensor)
{
    using src_t = typename SrcDstrTensors::DataType;

    constexpr bool is_fp8_family =
        ck_tile::is_same_v<DstType, fp8_t> || ck_tile::is_same_v<DstType, bf8_t>;
    constexpr bool use_pk_cvt = is_fp8_family && ck_tile::is_same_v<src_t, float> &&
                                (SrcDstrTensors::get_thread_buffer_size() % 4 == 0);

    if constexpr(use_pk_cvt)
    {
        return cast_tile_pk_fp8x4<DstType, SrcDstrTensors>(src_tensor);
    }
    else
    {
        return tile_elementwise_in(type_convert<DstType, src_t>, src_tensor);
    }
}
// no-op function for NullTensor arguments
// Selected when any argument is a NullTensor; the whole element-wise pass is skipped.
template <typename InOutElementFunc,
          typename... MaybeNullTensor,
          typename = std::enable_if_t<
              std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, NullTensor>...>>>
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc&, MaybeNullTensor&&...)
{
}
// no-op function for NullTensor arguments
// Selected when any argument is a NullTensor; returns a NullTensor so callers
// can propagate the "no tensor" case without branching.
template <typename InElementFunc,
          typename... MaybeNullTensor,
          typename = std::enable_if_t<
              std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, NullTensor>...>>>
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc&, MaybeNullTensor&&...)
{
    return NullTensor{};
}
} // namespace ck_tile

View File

@@ -0,0 +1,735 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// A window into a bottom tensor (e.g. a global-memory or LDS view) paired with a
// static tile distribution that maps per-partition (P) and per-element (Y) indices
// onto the window's X dimensions. The constructor pre-computes NumCoord
// (window-adaptor coordinate, bottom-tensor coordinate) bundles so that the
// load/store loops below only need cheap incremental coordinate updates.
template <typename BottomTensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          index_t NumCoord>
struct tile_window_with_static_distribution
{
    using BottomTensorView = remove_reference_t<BottomTensorView_>;
    using WindowLengths = remove_cvref_t<WindowLengths_>;
    using TileDstr = remove_cvref_t<StaticTileDistribution_>;
    using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
    using BottomTensorDesc = typename BottomTensorView::TensorDesc;
    using DataType = remove_cvref_t<typename BottomTensorView::DataType>;
    static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
    static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
    static constexpr index_t NDimP = TileDstr::GetNumOfDimensionP();
    static constexpr index_t NDimY = TileDstr::GetNumOfDimensionY();
    static constexpr auto I0 = number<0>{};
    static constexpr auto I1 = number<1>{};
    // TODO: check WindowLengths and StaticTileDistribution are consistent
    static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
                  "wrong! lengths should be static");
    static_assert(TileDstr::is_static(), "wrong!");
    static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
                  "wrong! inconsistent # of diemsnions");
    using AdaptorTopIndex = array<index_t, NDimWindowAdaptorTop>;
    using BottomTensorIndex = array<index_t, NDimBottomTensor>;
    using WindowAdaptorCoord =
        decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{}));
    using BottomTensorCoord =
        decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{}));
    // Compile-time traits describing how each thread vectorizes its accesses:
    // which Y dimension is the vector dimension, how many scalars per vector,
    // and the space-filling curve that orders the per-thread accesses.
    struct load_store_traits
    {
        private:
        // pick the Y dimension with stride 1 and the largest length as the
        // vector dimension; returns (VectorDimY, ScalarPerVector)
        static constexpr auto get_vector_dim_y_scalar_per_vector()
        {
            const auto [ys_vector_lengths, ys_vector_strides] =
                tile_window_with_static_distribution::
                    get_window_adaptor_ys_safe_vector_length_strides();
            index_t VectorDimY_ = 0;
            index_t ScalarPerVector_ = 1;
            for(index_t i = 0; i < NDimY; ++i)
            {
                if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
                {
                    ScalarPerVector_ = ys_vector_lengths[i];
                    VectorDimY_ = i;
                }
            }
            return make_tuple(VectorDimY_, ScalarPerVector_);
        }
        public:
        static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
        static constexpr index_t ScalarPerVector =
            get_vector_dim_y_scalar_per_vector().template at<1>();
        using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
        using vector_t = typename vector_type_t::type;
        private:
        // per-Y-dimension step size of the space-filling curve:
        // ScalarPerVector on the vector dimension, 1 elsewhere
        static constexpr auto scalars_per_access_ = [] {
            constexpr auto scalars_per_access_arr = generate_array(
                [&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
            /// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
            constexpr auto NDimY_ = NDimY;
            return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
        }();
        static constexpr auto get_space_filling_curve()
        {
            constexpr auto tile_dstr = TileDstr{};
            constexpr auto thread_tensor_lengths_ys =
                to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths());
            // FIXME: need logic to judge dim access order
            using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
            return space_filling_curve<decltype(thread_tensor_lengths_ys),
                                       DimAccessOrder,
                                       decltype(scalars_per_access_)>{};
        }
        public:
        using SFC_Ys = decltype(get_space_filling_curve());
        static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
        static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
        static_assert(NumAccess % NumCoord == 0, "wrong! # of access is not divisible by NumCoord");
    };
    // number of space-filling-curve accesses served by each pre-computed coordinate bundle
    static constexpr index_t NumAccessPerCoord = load_store_traits::NumAccess / NumCoord;
    CK_TILE_DEVICE constexpr tile_window_with_static_distribution() = default;
    // Build the window and pre-compute the per-thread coordinate bundles.
    CK_TILE_DEVICE constexpr tile_window_with_static_distribution(
        const BottomTensorView& bottom_tensor_view,
        const WindowLengths& window_lengths,
        const BottomTensorIndex& window_origin,
        const TileDstr& tile_distribution)
        : bottom_tensor_view_{bottom_tensor_view},
          window_lengths_{window_lengths},
          window_origin_{window_origin},
          tile_dstr_{tile_distribution},
          pre_computed_coords_{}
    {
#if 0 // debug
      // TODO: this use more register for FA, but less register for GEMM
      // need investigation
        // only support warp-tile and block-tile
        static_assert(NDimP == 1 or NDimP == 2, "wrong!");
        WindowAdaptorCoord window_adaptor_thread_coord_tmp;
        if constexpr(NDimP == 1)
        {
            window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
                tile_distribution.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
        }
        else if constexpr(NDimP == 2)
        {
            window_adaptor_thread_coord_tmp =
                make_tensor_adaptor_coordinate(tile_distribution.get_ps_ys_to_xs_adaptor(),
                                               AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
        }
#else
        // TODO: this use less register for FA, but more register for GEMM
        // need investigation
        const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
            tile_distribution.get_ps_ys_to_xs_adaptor(),
            container_concat(detail::get_partition_index(tile_distribution),
                             array<index_t, NDimY>{0}));
#endif
        BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
            window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
        const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
            bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
        // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
        // future load/store() calls (might allocate more registers)
        using Traits = load_store_traits;
        using SFC_Ys = typename Traits::SFC_Ys;
        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
            auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
            // advance both coordinates to this bundle's first access on the curve
            constexpr auto idx_diff_ys =
                SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
            constexpr auto idx_diff_ps_ys = container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
            move_window_adaptor_and_bottom_tensor_thread_coordinate(
                window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
            pre_computed_coords_(iCoord) =
                make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
        });
    }
    CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }
    CK_TILE_DEVICE static constexpr bool has_static_tile_distribution()
    {
        return TileDstr::is_static();
    }
    CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
    CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
    CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
    CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
    // move thread's window adaptor coordinate and bottom tensor coordinate
    // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
    CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
        WindowAdaptorCoord& window_adaptor_thread_coord,
        BottomTensorCoord& bottom_tensor_thread_coord,
        const AdaptorTopIndex& idx_diff_adaptor_top) const
    {
        array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;
        move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
                                       window_adaptor_thread_coord,
                                       idx_diff_adaptor_top,
                                       idx_diff_adaptor_bottom);
        move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
                               bottom_tensor_thread_coord,
                               idx_diff_adaptor_bottom);
    }
    // return vector dimension among [y0, y1, ...]
    CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides()
    {
        // bottom tensor top dimension vector lengths and strides
        const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
            BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
        // window vector lengths/strides
        const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
        const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
        // window adaptor [p0, p1, ..., y0, y1, ...]
        array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
            -1};
        array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
            -1};
        constexpr auto window_adaptor_bottom_dims =
            WindowAdaptor::get_bottom_dimension_hidden_ids();
        set_container_subset(window_adaptor_vector_lengths,
                             window_adaptor_bottom_dims,
                             window_adaptor_bottom_dim_vector_lengths);
        set_container_subset(window_adaptor_vector_strides,
                             window_adaptor_bottom_dims,
                             window_adaptor_bottom_dim_vector_strides);
        const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
            WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
                window_adaptor_vector_lengths, window_adaptor_vector_strides);
        // [y0, y1, ...]
        constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::GetNumOfDimensionP(),
                                                                 NDimWindowAdaptorTop,
                                                                 1>::type{};
        return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
                          get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
    }
    CK_TILE_DEVICE constexpr auto get_num_access() const { return load_store_traits::NumAccess; }
    // Load the whole window from the bottom tensor into a new register-resident
    // distributed tensor, one vector_t per space-filling-curve access.
    template <bool oob_conditional_check = true>
    CK_TILE_DEVICE auto load(bool_constant<oob_conditional_check> = {}) const
    {
        using Traits = load_store_traits;
        using vector_type_t = typename Traits::vector_type_t;
        using vector_t = typename vector_type_t::type;
        using SFC_Ys = typename Traits::SFC_Ys;
        constexpr auto tile_dstr = TileDstr{};
        auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            /// TODO: use structure binding (to be captured later) if compiled in C++20
            auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
            auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
                // data index [y0, y1, ...]
                constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
                // read from bottom tensor
                const vector_t vec_value =
                    get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
                        bottom_tensor_thread_coord, bool_constant<oob_conditional_check>{});
                const vector_type_t vec{vec_value};
                // write into distributed tensor, scattering the vector's scalars
                // along the vector dimension VectorDimY
                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
                    constexpr auto idx_ys = generate_array(
                        [&](auto jj) {
                            return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                            : idx_ys_start[jj];
                        },
                        number<NDimY>{});
                    constexpr index_t d =
                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
                    dst_tensor.get_thread_buffer().template at<d>() =
                        vec.template AsType<DataType>()[j];
                });
                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
                {
                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
                    constexpr auto idx_diff_ps_ys =
                        container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
                }
            });
        });
        return dst_tensor;
    }
    // Load directly into an existing tile, writing whole vectors into the
    // destination's thread buffer (reinterpreted as a buffer of vector_t).
    template <typename DstTile, bool oob_conditional_check = true>
    CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
                                 bool_constant<oob_conditional_check> = {}) const
    {
        using Traits = load_store_traits;
        using vector_type_t = typename Traits::vector_type_t;
        using vector_t = typename vector_type_t::type;
        using SFC_Ys = typename Traits::SFC_Ys;
        static constexpr index_t YElementSize =
            TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
        static_assert(YElementSize % Traits::ScalarPerVector == 0);
        using vectorized_tbuf = StaticBuffer<address_space_enum::vgpr,
                                             vector_t,
                                             YElementSize / Traits::ScalarPerVector,
                                             true>;
        constexpr auto tile_dstr = TileDstr{};
        auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());
        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            /// TODO: use structure binding (to be captured later) if compiled in C++20
            auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
            auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
                // data index [y0, y1, ...]
                constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
                constexpr index_t d =
                    tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
                static_assert(d % Traits::ScalarPerVector == 0);
                get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
                    dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
                    bottom_tensor_thread_coord,
                    bool_constant<oob_conditional_check>{});
                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
                {
                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
                    constexpr auto idx_diff_ps_ys =
                        container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
                }
            });
        });
    }
    // Asynchronously load the window into an LDS tile window, using the m0
    // register to address LDS; the LDS window layout is (issues, warps, lanes).
    // TODO: currently async load only implemented in inline asm
    template <typename LdsTileWindow_, bool oob_conditional_check = true>
    CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
                                   bool_constant<oob_conditional_check> = {}) const
    {
        using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
        // using LdsTensorView = typename LdsTileWindow::BottomTensorView;
        using LdsDataType = typename LdsTileWindow::DataType;
        // using LdsDescriptor = typename LdsTileWindow::BottomTensorDesc;
        // issues * warps * lanes
        static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
        // byte offsets inside the LDS tile, derived from the descriptor
        const index_t size_per_buf =
            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
                make_tuple(number<0>{}, number<0>{}, number<0>{})) *
            sizeof(LdsDataType);
        const index_t size_per_wave =
            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
                make_tuple(number<0>{}, number<1>{}, number<0>{})) *
            sizeof(LdsDataType) -
            size_per_buf;
        const index_t size_per_issue =
            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
                make_tuple(number<1>{}, number<0>{}, number<0>{})) *
            sizeof(LdsDataType) -
            size_per_buf;
        const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
        m0_set_with_memory(m0_init_value); // This should be wave independent
        using Traits = load_store_traits;
        using vector_type_t = typename Traits::vector_type_t;
        using vector_t = typename vector_type_t::type;
        using SFC_Ys = typename Traits::SFC_Ys;
        LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            // TODO: use structure binding (to be captured later) if compiled in C++20
            auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
            auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
                // read from bottom tensor
                get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
                    smem, bottom_tensor_thread_coord);
                // move thread coordinate (and advance the LDS write offset in m0)
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
                {
                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
                    constexpr auto idx_diff_ps_ys =
                        container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
                    m0_inc_with_memory(size_per_issue);
                }
            });
        });
    }
    // Store a register-resident distributed tensor into the window, gathering
    // ScalarPerVector scalars into one vector per space-filling-curve access.
    template <bool oob_conditional_check = true>
    CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
                              bool_constant<oob_conditional_check> = {}) const
    {
        using Traits = load_store_traits;
        using vector_type_t = typename Traits::vector_type_t;
        using vector_t = typename vector_type_t::type;
        using SFC_Ys = typename Traits::SFC_Ys;
        constexpr auto tile_dstr = TileDstr{};
        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            /// TODO: use structure binding (to be captured later) if compiled in C++20
            auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
            auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
                // data index [y0, y1, ...]
                constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
                // read from distributed tensor
                vector_type_t vec;
                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
                    constexpr auto idx_ys = generate_array(
                        [&](auto jj) {
                            return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                            : idx_ys_start[jj];
                        },
                        number<NDimY>{});
                    constexpr index_t d =
                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
                    vec.template AsType<DataType>()(j) =
                        dstr_tensor.get_thread_buffer().template at<d>();
                });
                const vector_t vec_value = vec.template AsType<vector_t>().template at<0>();
                // write into bottom tensor
                get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
                    bottom_tensor_thread_coord, vec_value, bool_constant<oob_conditional_check>{});
                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
                {
                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
                    constexpr auto idx_diff_ps_ys =
                        container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
                }
            });
        });
    }
    // Same as store(), but uses the "raw" set path of the bottom tensor view.
    CK_TILE_DEVICE void
    store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor) const
    {
        using Traits = load_store_traits;
        using vector_type_t = typename Traits::vector_type_t;
        using vector_t = typename vector_type_t::type;
        using SFC_Ys = typename Traits::SFC_Ys;
        constexpr auto tile_dstr = TileDstr{};
        static constexpr bool oob_conditional_check = true;
        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            /// TODO: use structure binding (to be captured later) if compiled in C++20
            auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
            auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
                // data index [y0, y1, ...]
                constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
                // read from distributed tensor
                vector_type_t vec;
                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
                    constexpr auto idx_ys = generate_array(
                        [&](auto jj) {
                            return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                            : idx_ys_start[jj];
                        },
                        number<NDimY>{});
                    constexpr index_t d =
                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
                    vec.template AsType<DataType>()(j) =
                        dstr_tensor.get_thread_buffer().template at<d>();
                });
                const vector_t vec_value = vec.template AsType<vector_t>().template at<0>();
                // write into bottom tensor
                get_bottom_tensor_view()
                    .template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
                        bottom_tensor_thread_coord, vec_value);
                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
                {
                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
                    constexpr auto idx_diff_ps_ys =
                        container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
                }
            });
        });
    }
    // move thread's bottom tensor coordinate
    // [x0', x1', ... ] ==> [offset]
    // also move window-origin
    CK_TILE_DEVICE void move(const BottomTensorIndex& step)
    {
        window_origin_ += step;
        // keep the pre-computed bottom-tensor coordinates in sync with the new origin
        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
                                   pre_computed_coords_(iCoord)(I1),
                                   step);
        });
    }
    // this is the bottom tensor view
    // [x0', x1', ...] ==> [offset]
    BottomTensorView bottom_tensor_view_;
    // compile-time extents of the window
    WindowLengths window_lengths_;
    // origin ([x0', x1', ...]) of window on bottom tensor
    BottomTensorIndex window_origin_;
    // Tile tensor distribution, which contains:
    //   1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
    //   2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
    TileDstr tile_dstr_;
    // this contains:
    //   per-thread coordinate for window adaptor
    //   per-thread coordinate for bottom tensor
    array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord> pre_computed_coords_;
};
// TODO: use strategy
// Factory: build a distributed tile window of the given lengths over
// tensor_view, anchored at origin, using tile_distribution to map threads.
template <typename TensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          index_t NumCoord = 1>
CK_TILE_DEVICE constexpr auto
make_tile_window(const TensorView_& tensor_view,
                 const WindowLengths_& window_lengths,
                 const multi_index<TensorView_::get_num_of_dimension()>& origin,
                 const StaticTileDistribution_& tile_distribution,
                 number<NumCoord> = {})
{
    using window_t = tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
                                                          remove_cvref_t<WindowLengths_>,
                                                          remove_cvref_t<StaticTileDistribution_>,
                                                          NumCoord>;
    return window_t{tensor_view, window_lengths, origin, tile_distribution};
}
// Shift a distributed tile window (its origin and cached per-thread
// coordinates) by `step`, expressed in bottom-tensor index space.
template <typename TensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          index_t NumCoord>
CK_TILE_DEVICE void move_tile_window(
    tile_window_with_static_distribution<TensorView_,
                                         WindowLengths_,
                                         StaticTileDistribution_,
                                         NumCoord>& window,
    const typename tile_window_with_static_distribution<TensorView_,
                                                        WindowLengths_,
                                                        StaticTileDistribution_,
                                                        NumCoord>::BottomTensorIndex& step)
{
    window.move(step);
}
// A "plain" tile window: a bottom tensor view plus compile-time window lengths
// and a runtime origin, with no tile distribution attached (and therefore no
// load/store machinery of its own).
template <typename BottomTensorView_, typename WindowLengths_>
struct tile_window_with_static_lengths
{
    using BottomTensorView = remove_reference_t<BottomTensorView_>;
    using WindowLengths = remove_cvref_t<WindowLengths_>;
    using BottomTensorDesc = typename BottomTensorView::TensorDesc;
    using DataType = typename BottomTensorView::DataType;
    static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
    static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
                  "wrong! lengths should be static");
    using BottomTensorIndex = array<index_t, NDimBottomTensor>;
    CK_TILE_DEVICE constexpr tile_window_with_static_lengths() = default;
    CK_TILE_DEVICE constexpr tile_window_with_static_lengths(
        const BottomTensorView& bottom_tensor_view,
        const WindowLengths& window_lengths,
        const BottomTensorIndex& window_origin)
        : bottom_tensor_view_{bottom_tensor_view},
          window_lengths_{window_lengths},
          window_origin_{window_origin}
    {
    }
    CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }
    CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
    CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
    CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
    // move window-origin
    CK_TILE_DEVICE void move(const BottomTensorIndex& step) { window_origin_ += step; }
    // this is the bottom tensor view
    // [x0', x1', ...] ==> [offset]
    BottomTensorView bottom_tensor_view_;
    // compile-time extents of the window
    WindowLengths window_lengths_;
    // origin ([x0', x1', ...]) of window on bottom tensor
    BottomTensorIndex window_origin_;
};
// Factory: build a plain (distribution-less) tile window over tensor_view.
template <typename TensorView_, typename WindowLengths_>
CK_TILE_DEVICE constexpr auto
make_tile_window(const TensorView_& tensor_view,
                 const WindowLengths_& window_lengths,
                 const multi_index<TensorView_::get_num_of_dimension()>& origin)
{
    static_assert(ck_tile::is_known_at_compile_time<WindowLengths_>::value,
                  "wrong! lengths should be static");
    using window_t = tile_window_with_static_lengths<remove_cvref_t<TensorView_>,
                                                     remove_cvref_t<WindowLengths_>>;
    return window_t{tensor_view, window_lengths, origin};
}
// Shift a plain tile window's origin by `step` (bottom-tensor index space).
template <typename TensorView_, typename WindowLengths_>
CK_TILE_DEVICE void move_tile_window(
    tile_window_with_static_lengths<TensorView_, WindowLengths_>& window,
    const typename tile_window_with_static_lengths<TensorView_, WindowLengths_>::BottomTensorIndex&
        step)
{
    window.move(step);
}
} // namespace ck_tile

View File

@@ -0,0 +1,19 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
namespace ck_tile {
/// @brief Reinterpret the bit pattern of \p x as a value of type \p Y.
///
/// Uses the compiler builtin so it is usable in constexpr and device code
/// (unlike a memcpy-based implementation). Both types must have equal size.
template <typename Y, typename X>
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X& x)
{
    // original had an empty diagnostic string here; name the requirement so a
    // failing compiler produces an actionable error
    static_assert(__has_builtin(__builtin_bit_cast),
                  "bit_cast requires compiler support for __builtin_bit_cast");
    static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type");
    return __builtin_bit_cast(Y, x);
}
} // namespace ck_tile

View File

@@ -0,0 +1,194 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include <stdint.h>
#include <utility>
namespace ck_tile {
namespace detail {
// helper whose variadic constructor accepts (and discards) any arguments;
// kept for expanding parameter packs purely for their side effects
struct swallow
{
    template <typename... Ts>
    CK_TILE_HOST_DEVICE constexpr swallow(Ts&&...)
    {
    }
};
template <class>
struct static_for_impl;
// invoke f(number<I>{}) for every I in the sequence, in order
template <index_t... Is>
struct static_for_impl<sequence<Is...>>
{
    template <class F>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
    {
        // C++17 comma fold: guaranteed left-to-right evaluation, same order as
        // the braced-init-list pack expansion it replaces
        (f(number<Is>{}), ...);
    }
};
} // namespace detail
// Compile-time for-loop: invokes f(number<I>{}) for I = NBegin, NBegin+Increment,
// ... up to (but excluding) NEnd.
// F signature: F(number<Iter>)
template <index_t NBegin, index_t NEnd, index_t Increment>
struct static_for
{
    CK_TILE_HOST_DEVICE constexpr static_for()
    {
        static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0,
                      "Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
        // fixed typo in the diagnostic ("wrongs!" -> "wrong!")
        static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd),
                      "wrong! should (Increment > 0 && NBegin <= NEnd) || (Increment < 0 && "
                      "NBegin >= NEnd)");
    }
    // run the loop body f over the generated compile-time index sequence
    template <class F>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
    {
        detail::static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::type>{}(
            f);
    }
};
// Function object that perfectly forwards its argument unchanged; device-usable
// counterpart of C++20 std::identity.
struct identity
{
    // mark the functor as transparent (same as std::identity) so ordered
    // containers/algorithms can perform heterogeneous lookups with it;
    // purely additive, existing callers are unaffected
    using is_transparent = void;
    template <typename T>
    CK_TILE_HOST_DEVICE constexpr T&& operator()(T&& arg) const noexcept
    {
        return std::forward<T>(arg);
    }
};
namespace detail {
// Recursive implementation of static_ford: peels one dimension per level,
// accumulating the ordered multi-index in CurrentOrderedId.
// RemainLengths: sequence<...>
// Orders: sequence<...>
template <class RemainLengths, class Orders>
struct static_ford_impl
{
    CK_TILE_HOST_DEVICE constexpr static_ford_impl()
    {
        static_assert(RemainLengths::size() > 0, "wrong! should not get here");
    }
    // F signature: F(sequence<...>)
    // CurrentOrderedId: sequence<...>
    template <class F, class CurrentOrderedId>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentOrderedId) const
    {
        // iterate the front dimension, recursing on the rest with the index appended
        static_for<0, RemainLengths::front(), 1>{}([=](auto I) {
            static_ford_impl<decltype(RemainLengths::pop_front()), Orders>{}(
                f, CurrentOrderedId::push_back(I));
        });
    }
};
// base case: all dimensions consumed, invoke f with the reordered multi-index
template <class Orders>
struct static_ford_impl<sequence<>, Orders>
{
    // F signature: F(sequence<...>)
    // OrderedId: sequence<...>
    template <class F, class OrderedId>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f, OrderedId) const
    {
        // retrieve unordered Id
        f(OrderedId::reorder_old_to_new(Orders{}));
    }
};
} // namespace detail
// Compile-time N-dimensional for-loop.
// Lengths is sequence<...>, it is the length of each dimension for
// N-dimensional loop
// Orders is sequence<...>, it is the order of dimension in which static_ford
// will loop over each
// dimension
template <class Lengths,
          class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
struct static_ford
{
    CK_TILE_HOST_DEVICE constexpr static_ford()
    {
        static_assert(Lengths::size() > 0, "wrong! Lengths is empty");
        static_assert(Lengths::size() == Orders::size(), "wrong! inconsistent size");
    }
    // F signature: F(sequence<...> multi_id)
    // multi_id is the unordered multi-index
    template <class F>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
    {
        // iterate in the requested dimension order, then hand f the unordered index
        constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
        detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, sequence<>{});
    }
};
namespace detail {
template <typename Indices>
struct unpack_impl;
// expand container x into individual arguments: f(x[0], x[1], ...)
// (removed a stale `#if 0` alternative that used runtime-style at(number<Is>{}))
template <index_t... Is>
struct unpack_impl<sequence<Is...>>
{
    template <typename F, typename X>
    CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x) const
    {
        return std::forward<F>(f)(std::forward<X>(x).template at<Is>()...);
    }
};
template <typename Seq0, typename Seq1>
struct unpack2_impl;
// TODO: remove this, after properly implementing unpack that takes any number of containers
// expand two containers into one argument list: f(x[0], ..., y[0], ...)
template <index_t... Is, index_t... Js>
struct unpack2_impl<sequence<Is...>, sequence<Js...>>
{
    template <typename F, typename X, typename Y>
    CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x, Y&& y) const
    {
        return std::forward<F>(f)(std::forward<X>(x).template at<Is>()...,
                                  std::forward<Y>(y).template at<Js>()...);
    }
};
} // namespace detail
// Call f with the elements of container x as individual arguments.
template <typename F, typename X>
CK_TILE_HOST_DEVICE constexpr auto unpack(F&& f, X&& x)
{
    using container_t = remove_reference_t<X>;
    using indices_t   = typename arithmetic_sequence_gen<0, container_t::size(), 1>::type;
    return detail::unpack_impl<indices_t>{}(std::forward<F>(f), std::forward<X>(x));
}
// TODO: properly implement unpack that takes any number of containers
// Call f with the elements of x followed by the elements of y as arguments.
template <typename F, typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto unpack2(F&& f, X&& x, Y&& y)
{
    using x_container_t = remove_reference_t<X>;
    using y_container_t = remove_reference_t<Y>;
    using x_indices_t   = typename arithmetic_sequence_gen<0, x_container_t::size(), 1>::type;
    using y_indices_t   = typename arithmetic_sequence_gen<0, y_container_t::size(), 1>::type;
    return detail::unpack2_impl<x_indices_t, y_indices_t>{}(
        std::forward<F>(f), std::forward<X>(x), std::forward<Y>(y));
}
} // namespace ck_tile

View File

@@ -0,0 +1,75 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include <limits>
#include <stdint.h>
namespace ck_tile {
// Host/device wrapper around std::numeric_limits<T>; specializations for
// non-standard types (fp16/fp8/...) are expected to be provided elsewhere.
template <typename T>
struct numeric_limits
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr T min() { return std::numeric_limits<T>::min(); }
    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr T lowest() { return std::numeric_limits<T>::lowest(); }
    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr T max() { return std::numeric_limits<T>::max(); }
    // difference between 1.0 and next value representable by float
    CK_TILE_HOST_DEVICE static constexpr T epsilon() { return std::numeric_limits<T>::epsilon(); }
    // maximum rounding error
    CK_TILE_HOST_DEVICE static constexpr T round_error()
    {
        return std::numeric_limits<T>::round_error();
    }
    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr T infinity() { return std::numeric_limits<T>::infinity(); }
    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr T quiet_NaN()
    {
        return std::numeric_limits<T>::quiet_NaN();
    }
    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr T signaling_NaN()
    {
        return std::numeric_limits<T>::signaling_NaN();
    }
    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr T denorm_min()
    {
        return std::numeric_limits<T>::denorm_min();
    }
};
// numeric_utils<T>: bit-level layout constants of floating-point type T
// (only the float specialization is defined here).
template <typename T>
struct numeric_utils;
template <>
struct numeric_utils<float>
{
    static constexpr int exp = 8;    // number of exponent bits
    static constexpr int mant = 23;  // number of mantissa bits
    static constexpr int bias = 127; // exponent bias
    static constexpr uint32_t nan_mask = 0x7F800000;  // all-ones exponent field (inf/nan)
    static constexpr uint32_t head_mask = 0xFF800000; // sign + exponent bits
    static constexpr uint32_t mant_mask = 0x7FFFFF;   // mantissa bits
    static constexpr uint32_t exp_mask = 0xFF;        // exponent field (right-aligned)
    static constexpr uint32_t Inf = 0x7F800000;       // +inf bit pattern
    static constexpr uint32_t NegInf = 0xFF800000;    // -inf bit pattern
    static constexpr uint32_t NaN = 0x7F800001;       // a (signaling-range) NaN bit pattern
    static constexpr uint32_t Neg0 = 0x80000000;      // -0.0f bit pattern
    using bitwise_type = uint32_t;
};
} // namespace ck_tile

View File

@@ -0,0 +1,261 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include <stdint.h>
namespace ck_tile {
// magic number division
// Caution:
// 1. For uint32_t as dividend: the magic number division implementation used here produces
// a correct result only if the dividend is uint32_t and its value is within the 31-bit range.
// 2. For int32_t as dividend: magic number division for an int32_t dividend has not been
// implemented; the int32_t dividend is bit-wise interpreted as uint32_t and the magic number
// division implementation for uint32_t is then used. Therefore, the dividend value needs to be
// non-negative.
// TODO:
// 1. Implement magic number division for int32_t
// 2. Implement magic number division for uint32_t with the full 32-bit value range
// Magic-number division helper valid for divisors in [1, INT32_MAX].
// calculate_magic_numbers() precomputes a (multiplier, shift) pair such that
//   dividend / divisor == (mulhi32(dividend, multiplier) + dividend) >> shift
// for dividends within 31-bit range, replacing a hardware divide by a
// multiply-high, an add and a shift.
struct magic_division32_bit_range
{
    // uint32_t
    // Runtime divisor -> tuple(multiplier, shift).
    CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(uint32_t divisor)
    {
        // WARNING: magic division is only valid for division inside this range.
        // assert(divisor >= 1 && divisor <= INT32_MAX)
        // smallest shift such that (1 << shift) >= divisor, i.e. ceil(log2(divisor))
        uint32_t shift_u32 = 0;
        while((1U << shift_u32) < divisor)
        {
            shift_u32++;
        };
        // multiplier = floor((2^shift - divisor) * 2^32 / divisor) + 1
        uint64_t tmp_u64        = ((1UL << shift_u32) - divisor) << 32;
        uint32_t multiplier_u32 = tmp_u64 / divisor + 1;
        return make_tuple(multiplier_u32, shift_u32);
    }

    // integral_constant<uint32_t, .>
    // Compile-time overload: the divisor arrives as an integral_constant and the
    // resulting multiplier/shift are returned as integral_constants as well.
    template <uint32_t Divisor, typename = std::enable_if_t<(0 < Divisor)>>
    CK_TILE_HOST_DEVICE static constexpr auto
    calculate_magic_numbers(integral_constant<uint32_t, Divisor>)
    {
        constexpr auto tmp            = calculate_magic_numbers(uint32_t{Divisor});
        constexpr uint32_t multiplier = tmp[number<0>{}];
        constexpr uint32_t shift      = tmp[number<1>{}];
        return make_tuple(integral_constant<uint32_t, multiplier>{},
                          integral_constant<uint32_t, shift>{});
    }

    // integral_constant<int32_t, .>
    // Signed compile-time divisor is handled by reinterpreting it as unsigned.
    template <int32_t Divisor>
    CK_TILE_HOST_DEVICE static constexpr auto
    calculate_magic_numbers(integral_constant<int32_t, Divisor>)
    {
        return calculate_magic_numbers(integral_constant<uint32_t, Divisor>{});
    }

    // magic division for uint32_t
    // device path: __umulhi yields the high 32 bits of the 32x32 product
    CK_TILE_DEVICE static constexpr uint32_t
    do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
    {
        uint32_t tmp = __umulhi(dividend, multiplier);
        return (tmp + dividend) >> shift;
    }

    // host path: emulate the multiply-high with a 64-bit product
    CK_TILE_HOST static constexpr uint32_t
    do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
    {
        uint32_t tmp = (static_cast<uint64_t>(dividend) * multiplier) >> 32;
        return (tmp + dividend) >> shift;
    }

    // magic division for int32_t
    // HACK: use dividend_i32 as if it's uint32_t; dividend_i32 needs to be
    // non-negative for the result to be correct
    // TODO: figure out how to do magic number division for int32_t as dividend
    CK_TILE_DEVICE static constexpr int32_t
    do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
    {
        uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
        uint32_t tmp          = __umulhi(dividend_u32, multiplier);
        return (tmp + dividend_u32) >> shift;
    }

    // host-side counterpart of the signed-dividend hack above
    CK_TILE_HOST static constexpr int32_t
    do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
    {
        uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
        uint32_t tmp          = (static_cast<uint64_t>(dividend_u32) * multiplier) >> 32;
        return (tmp + dividend_u32) >> shift;
    }
};
// magic number division
// This version only works for divisor and dividend within [0, 1 << 16]
// Magic-number division helper restricted to divisor and dividend within
// [0, 1 << 16]. The multiplier fits in 16 bits, so the multiply-high step is a
// plain 32-bit multiply followed by ">> 16" on both host and device.
struct magic_division16_bit_range
{
    // uint32_t
    // Runtime divisor -> tuple(multiplier, shift).
    CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(uint32_t divisor)
    {
        // WARNING: magic division is only valid for division inside this range.
        // assert(divisor >= 1 && divisor <= (1U << 16));
        // smallest shift such that (1 << shift) >= divisor, i.e. ceil(log2(divisor))
        uint32_t shift_u32 = 0;
        while((1U << shift_u32) < divisor)
        {
            shift_u32++;
        };
        // multiplier = floor(2^16 * (2^shift - divisor) / divisor) + 1
        uint32_t one            = 1;
        uint32_t multiplier_u32 = ((one << 16) * ((one << shift_u32) - divisor)) / divisor + 1;
        return make_tuple(multiplier_u32, shift_u32);
    }

    // integral_constant<uint32_t, .>
    // Compile-time overload: the divisor arrives as an integral_constant and the
    // resulting multiplier/shift are returned as integral_constants as well.
    template <uint32_t Divisor>
    CK_TILE_HOST_DEVICE static constexpr auto
    calculate_magic_numbers(integral_constant<uint32_t, Divisor>)
    {
        constexpr auto tmp            = calculate_magic_numbers(uint32_t{Divisor});
        constexpr uint32_t multiplier = tmp[number<0>{}];
        constexpr uint32_t shift      = tmp[number<1>{}];
        return make_tuple(integral_constant<uint32_t, multiplier>{},
                          integral_constant<uint32_t, shift>{});
    }

    // integral_constant<int32_t, .>
    // Signed compile-time divisor is handled by reinterpreting it as unsigned.
    template <int32_t Divisor>
    CK_TILE_HOST_DEVICE static constexpr auto
    calculate_magic_numbers(integral_constant<int32_t, Divisor>)
    {
        return calculate_magic_numbers(integral_constant<uint32_t, Divisor>{});
    }

    // magic division for uint32_t
    // (product fits 32 bits because multiplier and dividend are both <= 2^16)
    CK_TILE_DEVICE static constexpr uint32_t
    do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
    {
        uint32_t tmp = (dividend * multiplier) >> 16;
        return (tmp + dividend) >> shift;
    }

    // host path: identical arithmetic to the device path
    CK_TILE_HOST static constexpr uint32_t
    do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
    {
        uint32_t tmp = (dividend * multiplier) >> 16;
        return (tmp + dividend) >> shift;
    }

    // magic division for int32_t
    // HACK: use dividend_i32 as if it's uint32_t; dividend_i32 needs to be
    // non-negative for the result to be correct
    // TODO: figure out how to do magic number division for int32_t as dividend
    CK_TILE_DEVICE static constexpr int32_t
    do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
    {
        uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
        uint32_t tmp          = (dividend_u32 * multiplier) >> 16;
        return (tmp + dividend_u32) >> shift;
    }

    // host-side counterpart of the signed-dividend hack above
    CK_TILE_HOST static constexpr int32_t
    do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
    {
        uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
        uint32_t tmp          = (dividend_u32 * multiplier) >> 16;
        return (tmp + dividend_u32) >> shift;
    }
};
// use 32bit version
using magic_division = magic_division32_bit_range;
// Runtime magic-number divider: caches the divisor together with its
// precomputed multiplier/shift, so repeated division by the same value avoids
// a hardware integer divide.
struct mdiv
{
    // 1 dword -> 3 dword storage
    uint32_t divisor;
    uint32_t multiplier;
    uint32_t shift; // TODO: 8 bit is enough

    // prefer construct on host
    CK_TILE_HOST_DEVICE mdiv(uint32_t divisor_) : divisor(divisor_)
    {
        const auto magic = magic_division::calculate_magic_numbers(divisor_);

        multiplier = magic[number<0>{}];
        shift      = magic[number<1>{}];
    }

    CK_TILE_HOST_DEVICE mdiv() : divisor(0), multiplier(0), shift(0) {}

    // recompute the cached magic numbers for a new divisor
    CK_TILE_HOST_DEVICE void update(uint32_t divisor_)
    {
        const auto magic = magic_division::calculate_magic_numbers(divisor_);

        divisor    = divisor_;
        multiplier = magic[number<0>{}];
        shift      = magic[number<1>{}];
    }

    // quotient of dividend_ / divisor
    CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const
    {
        return magic_division::do_magic_division(dividend_, multiplier, shift);
    }

    // quotient and remainder in a single call
    CK_TILE_HOST_DEVICE void
    divmod(uint32_t dividend_, uint32_t& quotient_, uint32_t& remainder_) const
    {
        quotient_  = div(dividend_);
        remainder_ = dividend_ - quotient_ * divisor;
    }

    // the divisor this object was configured with
    CK_TILE_HOST_DEVICE uint32_t get() const { return divisor; }
};
// Compact magic-number divider: stores only the precomputed multiplier/shift
// (2 dwords); when a remainder is needed the caller must pass the divisor in
// again.
struct mdiv2
{
    // 1 dword -> 2 dword storage, divisor need compute from runtime
    uint32_t multiplier;
    uint32_t shift; // TODO: 8 bit is enough

    // prefer construct on host
    CK_TILE_HOST_DEVICE mdiv2(uint32_t divisor_)
    {
        const auto magic = magic_division::calculate_magic_numbers(divisor_);

        multiplier = magic[number<0>{}];
        shift      = magic[number<1>{}];
    }

    CK_TILE_HOST_DEVICE mdiv2() : multiplier(0), shift(0) {}

    // quotient of dividend_ by the divisor this object was constructed from
    CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const
    {
        return magic_division::do_magic_division(dividend_, multiplier, shift);
    }

    // quotient and remainder; divisor_ must equal the construction-time divisor
    CK_TILE_HOST_DEVICE void
    divmod(uint32_t dividend_, uint32_t divisor_, uint32_t& quotient_, uint32_t& remainder_) const
    {
        quotient_  = div(dividend_);
        remainder_ = dividend_ - quotient_ * divisor_;
    }
};
} // namespace ck_tile

View File

@@ -0,0 +1,64 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include <stdint.h>
#include <tuple>
#include <type_traits>
namespace ck_tile {
// Fallback pseudo-random generator: any type without a dedicated
// specialization below (i.e. neither fp32 nor fp16) yields no random bits.
template <typename T, uint32_t seed_>
struct prand_generator_t
{
    CK_TILE_HOST_DEVICE uint32_t operator()(int id, T val, uint32_t seed = seed_)
    {
        // all inputs are intentionally ignored; silence unused warnings
        (void)id;
        (void)val;
        (void)seed;
        return 0;
    }
};
// version for fp32
// Derives a deterministic pseudo-random 32-bit value from the element index,
// the fp32 bit pattern and a seed (same (id, val, seed) -> same output).
template <uint32_t seed_>
struct prand_generator_t<float, seed_>
{
    CK_TILE_HOST_DEVICE uint32_t operator()(int id, float val, uint32_t seed = seed_)
    {
        // Copy the float's bits into an integer. The previous
        // *(reinterpret_cast<uint32_t*>(&val)) violated strict aliasing (UB);
        // a bytewise copy is the well-defined equivalent and compiles to the
        // same code.
        uint32_t x;
        __builtin_memcpy(&x, &val, sizeof(x));
        // fold the high half into the low 16 bits, then mix
        uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
        drop_bits ^= x >> 16;
        drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
        drop_bits *= 0x7000149;
        // NOTE: If id is in 64 bit, we are only using lower 32 bit.
        // So, it can have an effect of using same id for multiple elements when the id is
        // very large!
        uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed);
        return rng;
    }
};
// version for fp16
// Derives a deterministic pseudo-random 32-bit value from the element index,
// the fp16 bit pattern and a seed (same (id, val, seed) -> same output).
template <uint32_t seed_>
struct prand_generator_t<half_t, seed_>
{
    CK_TILE_HOST_DEVICE uint32_t operator()(int id, half_t val, uint32_t seed = seed_)
    {
        // Copy the half's 16 payload bits into an integer. The previous
        // *(reinterpret_cast<uint16_t*>(&val)) violated strict aliasing (UB);
        // a bytewise copy of the same 2 bytes is the well-defined equivalent.
        uint16_t x;
        __builtin_memcpy(&x, &val, sizeof(x));
        uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
        drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
        drop_bits *= 0x7000149;
        // NOTE: If id is in 64 bit, we are only using lower 32 bit.
        // So, it can have an effect of using same id for multiple elements when the id is
        // very large!
        uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed);
        return rng;
    }
};
} // namespace ck_tile

View File

@@ -0,0 +1,72 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/container/sequence.hpp"
// TODO: use c++20 nontype template with struct to implement this
#if 1
// clang happens to support this feature (__cpp_generic_lambdas >= 201707) in c++17 mode
//
// TO_SEQUENCE(a, n): expand the first n elements of the constexpr array-like
// object `a` (anything exposing .at(ck_tile::number<I>{})) into a
// ck_tile::sequence<...>. Both `a` and `n` must be constexpr. Implemented with
// a templated generic lambda that is immediately invoked with an index
// sequence; the pragmas silence clang's -Wc++20-extensions warning about the
// explicit lambda template parameter list.
#define TO_SEQUENCE(a, n)                                                                  \
    _Pragma("clang diagnostic push")                                                       \
    _Pragma("clang diagnostic ignored \"-Wc++20-extensions\"")[a]<ck_tile::index_t... Is>( \
        ck_tile::sequence<Is...>)                                                          \
    {                                                                                      \
        return ck_tile::sequence<a.at(ck_tile::number<Is>{})...>{};                        \
    }                                                                                      \
    (make_index_sequence<n>{}) _Pragma("clang diagnostic pop")
#else
// Macro function
// convert constexpr array to sequence, both a/n need to be constexpr (can't be a rvalue like 2)
// Fallback for compilers without templated generic lambdas: a hand-unrolled
// if-constexpr chain supporting n up to 10 elements.
#define TO_SEQUENCE(a, n)                                                                \
    [a, n] {                                                                             \
        static_assert(a.size() >= n, "wrong! out of bound");                             \
        static_assert(n <= 10, "not implemented");                                       \
        if constexpr(n == 0)                                                             \
        {                                                                                \
            return ck_tile::sequence<>{};                                                \
        }                                                                                \
        else if constexpr(n == 1)                                                        \
        {                                                                                \
            return ck_tile::sequence<a[0]>{};                                            \
        }                                                                                \
        else if constexpr(n == 2)                                                        \
        {                                                                                \
            return ck_tile::sequence<a[0], a[1]>{};                                      \
        }                                                                                \
        else if constexpr(n == 3)                                                        \
        {                                                                                \
            return ck_tile::sequence<a[0], a[1], a[2]>{};                                \
        }                                                                                \
        else if constexpr(n == 4)                                                        \
        {                                                                                \
            return ck_tile::sequence<a[0], a[1], a[2], a[3]>{};                          \
        }                                                                                \
        else if constexpr(n == 5)                                                        \
        {                                                                                \
            return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4]>{};                    \
        }                                                                                \
        else if constexpr(n == 6)                                                        \
        {                                                                                \
            return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5]>{};              \
        }                                                                                \
        else if constexpr(n == 7)                                                        \
        {                                                                                \
            return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6]>{};        \
        }                                                                                \
        else if constexpr(n == 8)                                                        \
        {                                                                                \
            return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7]>{};  \
        }                                                                                \
        else if constexpr(n == 9)                                                        \
        {                                                                                \
            return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8]>{}; \
        }                                                                                \
        else if constexpr(n == 10)                                                       \
        {                                                                                \
            return ck_tile::                                                             \
                sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9]>{};  \
        }                                                                                \
    }()
#endif

View File

@@ -0,0 +1,57 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <stdint.h>
#include <tuple>
#include <type_traits>
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"
namespace ck_tile {
// Convert X to Y where neither X nor Y is const-qualified.
// SFINAE keeps this overload and the const-aware one below mutually exclusive.
template <typename Y,
          typename X,
          std::enable_if_t<!std::is_const_v<Y> && !std::is_const_v<X>, bool> = false>
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
{
    static_assert(!std::is_reference_v<Y>);
    static_assert(!std::is_reference_v<X>);
    return static_cast<Y>(x);
}
// Convert X to Y where at least one of X, Y is const-qualified: strip const,
// dispatch to the non-const overload, then cast the result back to Y.
template <typename Y,
          typename X,
          std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
{
    static_assert(!std::is_reference_v<Y>);
    static_assert(!std::is_reference_v<X>);
    return static_cast<Y>(type_convert<std::remove_const_t<Y>, std::remove_const_t<X>>(x));
}
// Registers a specialized conversion: makes type_convert<dtype_, stype_>
// forward to a function named `stype_##_to_##dtype_` (e.g. fp16_t_to_float),
// which is expected to be declared by the numeric-type headers included above.
#define CK_TILE_TYPE_CONVERT(dtype_, stype_)                                           \
    template <>                                                                        \
    inline CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
    {                                                                                  \
        return stype_##_to_##dtype_(x);                                                \
    }
// widen the reduced-precision types to float ...
CK_TILE_TYPE_CONVERT(float, fp16_t)
CK_TILE_TYPE_CONVERT(float, bf16_t)
CK_TILE_TYPE_CONVERT(float, fp8_t)
CK_TILE_TYPE_CONVERT(float, bf8_t)
// ... and narrow float back down to them
CK_TILE_TYPE_CONVERT(fp16_t, float)
CK_TILE_TYPE_CONVERT(bf16_t, float)
CK_TILE_TYPE_CONVERT(fp8_t, float)
CK_TILE_TYPE_CONVERT(bf8_t, float)
} // namespace ck_tile

View File

@@ -0,0 +1,46 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include <type_traits>
#include <stdint.h>
namespace ck_tile {
// Alias shortcuts mirroring the <type_traits> `*_t` helpers so ck_tile code
// can spell them without the std:: prefix.
template <typename T>
using remove_reference_t = std::remove_reference_t<T>;

template <typename T>
using remove_cv_t = std::remove_cv_t<T>;

// strip the reference first, then cv-qualifiers (C++20 std::remove_cvref_t equivalent)
template <typename T>
using remove_cvref_t = std::remove_cv_t<std::remove_reference_t<T>>;

template <typename T>
using remove_pointer_t = std::remove_pointer_t<T>;
namespace impl {
// A type is "static" (its value fully known at compile time) when it reports
// so through a static member `T::is_static()`. Plain arithmetic types carry
// runtime values, so they are never static.
//
// BUGFIX: the previous `std::is_arithmetic<T>::v ? false : T::is_static()`
// referenced a nonexistent member (`::v`) and, being a plain ternary, also
// instantiated `T::is_static()` for arithmetic types that do not declare it.
// `if constexpr` discards the untaken branch entirely.
template <typename T>
constexpr bool is_static_helper()
{
    if constexpr(std::is_arithmetic_v<T>)
        return false;
    else
        return T::is_static();
}

template <typename T>
struct is_static_impl
{
    static constexpr bool value = is_static_helper<T>();
};
} // namespace impl
// Compile-time predicate: true when T's value is fully known at compile time.
// cv/ref qualifiers are stripped first via remove_cvref_t.
template <typename T>
using is_static = impl::is_static_impl<remove_cvref_t<T>>;

// variable-template shorthand for is_static<T>::value
template <typename T>
inline constexpr bool is_static_v = is_static<T>::value;

// TODO: deprecate this
// legacy spelling kept for older call sites; prefer is_static
template <typename T>
using is_known_at_compile_time = is_static<T>;

// NOTE(review): evaluating an rvalue such as a const integer still yields
// false here, which may not be desirable(?) — do we need something like
// is_constexpr()?
} // namespace ck_tile