mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
add code
This commit is contained in:
18
include/ck_tile/core/README.md
Normal file
18
include/ck_tile/core/README.md
Normal file
@@ -0,0 +1,18 @@
|
||||
# ck_tile/core #
|
||||
|
||||
`ck_tile/core` contains all the basic functions and structures needed to create a GPU kernel using `ck_tile`. Users should only include the single header `ck_tile/core.hpp` to access all the functionality. Everything is under the `ck_tile` namespace. The coding style under this folder should be similar to `std` (`snake_case` for structures/functions, CamelCase for template types...)
|
||||
|
||||
```
|
||||
algorithm/
|
||||
coordinate transform and some other reusable algorithm
|
||||
arch/
|
||||
contains some basic device building block like mma, buffer addressing, etc...
|
||||
container/
|
||||
contains basic container data structure, array/sequence/tuple/...
|
||||
numeric/
|
||||
data type, and data type related math
|
||||
tensor/
|
||||
tensor descriptors and tile level API
|
||||
utility/
|
||||
other utility function for both host/device
|
||||
```
|
||||
38
include/ck_tile/core/algorithm/cluster_descriptor.hpp
Normal file
38
include/ck_tile/core/algorithm/cluster_descriptor.hpp
Normal file
@@ -0,0 +1,38 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Build a single-stage tensor adaptor mapping one linearized index to a
// multi-dimensional cluster index, with the dimensions merged in the order
// given by ArrangeOrder (a new2old permutation of [0, Lengths::size())).
//
// lengths: per-dimension cluster sizes.
// order:   permutation used when merging; defaults to the identity order.
//
// Returns an adaptor whose top view is the single merged index (sequence<0>)
// and whose bottom view is the original multi-dimensional index.
template <typename Lengths,
          typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
CK_TILE_HOST_DEVICE constexpr auto make_cluster_descriptor(
    const Lengths& lengths,
    ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type{})
{
    constexpr index_t ndim_low = Lengths::size();

    // reorder the lengths so the merge transform walks them in ArrangeOrder
    const auto reordered_lengths = container_reorder_given_new2old(lengths, order);

    const auto low_lengths = generate_tuple(
        [&](auto idim_low) { return reordered_lengths[idim_low]; }, number<ndim_low>{});

    const auto transform = make_merge_transform(low_lengths);

    // bottom dims are the original dims visited in ArrangeOrder; top is one merged dim
    constexpr auto low_dim_old_top_ids = ArrangeOrder{};

    constexpr auto up_dim_new_top_ids = sequence<0>{};

    return make_single_stage_tensor_adaptor(
        make_tuple(transform), make_tuple(low_dim_old_top_ids), make_tuple(up_dim_new_top_ids));
}
|
||||
|
||||
} // namespace ck_tile
|
||||
1664
include/ck_tile/core/algorithm/coordinate_transform.hpp
Normal file
1664
include/ck_tile/core/algorithm/coordinate_transform.hpp
Normal file
File diff suppressed because it is too large
Load Diff
165
include/ck_tile/core/algorithm/space_filling_curve.hpp
Normal file
165
include/ck_tile/core/algorithm/space_filling_curve.hpp
Normal file
@@ -0,0 +1,165 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/multi_index.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename TensorLengths,
|
||||
typename DimAccessOrder,
|
||||
typename ScalarsPerAccess,
|
||||
bool SnakeCurved = true> // # of scalars per access in each dimension
|
||||
struct space_filling_curve
|
||||
{
|
||||
static constexpr index_t TensorSize =
|
||||
reduce_on_sequence(TensorLengths{}, multiplies{}, number<1>{});
|
||||
static_assert(0 < TensorSize,
|
||||
"space_filling_curve should be used to access a non-empty tensor");
|
||||
|
||||
static constexpr index_t nDim = TensorLengths::size();
|
||||
|
||||
using Index = multi_index<nDim>;
|
||||
|
||||
static constexpr index_t ScalarPerVector =
|
||||
reduce_on_sequence(ScalarsPerAccess{}, multiplies{}, number<1>{});
|
||||
|
||||
static constexpr auto access_lengths = TensorLengths{} / ScalarsPerAccess{};
|
||||
static constexpr auto dim_access_order = DimAccessOrder{};
|
||||
static constexpr auto ordered_access_lengths =
|
||||
container_reorder_given_new2old(access_lengths, dim_access_order);
|
||||
|
||||
static constexpr auto to_index_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(ordered_access_lengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, nDim, 1>::type{}),
|
||||
make_tuple(sequence<0>{}));
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_access()
|
||||
{
|
||||
static_assert(TensorLengths::size() == ScalarsPerAccess::size());
|
||||
static_assert(TensorLengths{} % ScalarsPerAccess{} ==
|
||||
typename uniform_sequence_gen<TensorLengths::size(), 0>::type{});
|
||||
|
||||
return reduce_on_sequence(TensorLengths{}, multiplies{}, number<1>{}) / ScalarPerVector;
|
||||
}
|
||||
|
||||
template <index_t AccessIdx1dHead, index_t AccessIdx1dTail>
|
||||
static CK_TILE_HOST_DEVICE constexpr auto get_step_between(number<AccessIdx1dHead>,
|
||||
number<AccessIdx1dTail>)
|
||||
{
|
||||
static_assert(AccessIdx1dHead >= 0 && AccessIdx1dHead < get_num_of_access(),
|
||||
"1D index out of range");
|
||||
static_assert(AccessIdx1dTail >= 0 && AccessIdx1dTail < get_num_of_access(),
|
||||
"1D index out of range");
|
||||
|
||||
constexpr auto idx_head = get_index(number<AccessIdx1dHead>{});
|
||||
constexpr auto idx_tail = get_index(number<AccessIdx1dTail>{});
|
||||
return idx_tail - idx_head;
|
||||
}
|
||||
|
||||
template <index_t AccessIdx1d>
|
||||
static CK_TILE_HOST_DEVICE constexpr auto get_forward_step(number<AccessIdx1d>)
|
||||
{
|
||||
static_assert(AccessIdx1d < get_num_of_access(), "1D index should be larger than 0");
|
||||
return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d + 1>{});
|
||||
}
|
||||
|
||||
template <index_t AccessIdx1d>
|
||||
static CK_TILE_HOST_DEVICE constexpr auto get_backward_step(number<AccessIdx1d>)
|
||||
{
|
||||
static_assert(AccessIdx1d > 0, "1D index should be larger than 0");
|
||||
|
||||
return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d - 1>{});
|
||||
}
|
||||
|
||||
template <index_t AccessIdx1d>
|
||||
static CK_TILE_HOST_DEVICE constexpr Index get_index(number<AccessIdx1d>)
|
||||
{
|
||||
#if 0
|
||||
/*
|
||||
* \todo: tensor_adaptor::calculate_bottom_index does NOT return constexpr as expected.
|
||||
*/
|
||||
constexpr auto ordered_access_idx = to_index_adaptor.calculate_bottom_index(make_multi_index(number<AccessIdx1d>{}));
|
||||
#else
|
||||
|
||||
constexpr auto access_strides =
|
||||
container_reverse_exclusive_scan(ordered_access_lengths, multiplies{}, number<1>{});
|
||||
|
||||
constexpr auto idx_1d = number<AccessIdx1d>{};
|
||||
// Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the
|
||||
// idim-th element of multidimensional index.
|
||||
// All constexpr variables have to be captured by VALUE.
|
||||
constexpr auto compute_index = [ idx_1d, access_strides ](auto idim) constexpr
|
||||
{
|
||||
constexpr auto compute_index_impl = [ idx_1d, access_strides ](auto jdim) constexpr
|
||||
{
|
||||
auto res = idx_1d.value;
|
||||
auto id = 0;
|
||||
|
||||
static_for<0, jdim.value + 1, 1>{}([&](auto kdim) {
|
||||
id = res / access_strides[kdim].value;
|
||||
res -= id * access_strides[kdim].value;
|
||||
});
|
||||
|
||||
return id;
|
||||
};
|
||||
|
||||
constexpr auto id = compute_index_impl(idim);
|
||||
return number<id>{};
|
||||
};
|
||||
|
||||
constexpr auto ordered_access_idx = generate_tuple(compute_index, number<nDim>{});
|
||||
#endif
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto idim) {
|
||||
index_t tmp = ordered_access_idx[I0];
|
||||
|
||||
static_for<1, idim, 1>{}(
|
||||
[&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; });
|
||||
|
||||
forward_sweep_(idim) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate multi-dim tensor index
|
||||
auto idx_md = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto idim) {
|
||||
ordered_idx(idim) =
|
||||
!SnakeCurved || forward_sweep[idim]
|
||||
? ordered_access_idx[idim]
|
||||
: ordered_access_lengths[idim] - 1 - ordered_access_idx[idim];
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, dim_access_order) *
|
||||
ScalarsPerAccess{};
|
||||
}();
|
||||
return idx_md;
|
||||
}
|
||||
|
||||
// FIXME: rename this function
|
||||
template <index_t AccessIdx1d>
|
||||
static CK_TILE_HOST_DEVICE constexpr auto get_index_tuple_of_number(number<AccessIdx1d>)
|
||||
{
|
||||
constexpr auto idx = get_index(number<AccessIdx1d>{});
|
||||
|
||||
return generate_tuple([&](auto i) { return number<idx[i]>{}; }, number<nDim>{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
20
include/ck_tile/core/arch/amd_address_space.hpp
Normal file
20
include/ck_tile/core/arch/amd_address_space.hpp
Normal file
@@ -0,0 +1,20 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
// Address Space for AMDGCN
|
||||
// https://llvm.org/docs/AMDGPUUsage.html#address-space
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Memory spaces used to tag buffers/pointers in ck_tile device code.
// Enumerator values are made explicit so reordering cannot silently change them.
enum struct address_space_enum
{
    generic = 0, // flat / generic addressing
    global  = 1, // device (off-chip) memory
    lds     = 2, // local data share
    sgpr    = 3, // scalar registers
    vgpr    = 4, // vector registers
};
|
||||
|
||||
} // namespace ck_tile
|
||||
2050
include/ck_tile/core/arch/amd_buffer_addressing.hpp
Normal file
2050
include/ck_tile/core/arch/amd_buffer_addressing.hpp
Normal file
File diff suppressed because it is too large
Load Diff
56
include/ck_tile/core/config.hpp
Normal file
56
include/ck_tile/core/config.hpp
Normal file
@@ -0,0 +1,56 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
// Function-qualifier macros: expand to HIP's __host__/__device__ attributes when
// compiling with hipcc, and degrade to plain `inline` for host-only builds so the
// same headers compile with any C++ compiler.
#ifdef __HIPCC__
#define CK_TILE_HOST __host__
#define CK_TILE_DEVICE __device__
#define CK_TILE_HOST_DEVICE __host__ __device__
#else
#define CK_TILE_HOST inline
#define CK_TILE_DEVICE inline
#define CK_TILE_HOST_DEVICE inline
#endif

// float -> bf16 rounding-mode selectors; default is plain truncation.
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2

#ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE
#endif

// float -> fp8 conversion-mode selectors; default is non-stochastic conversion.
#define CK_TILE_FLOAT_TO_FP8_STANDARD 0
#define CK_TILE_FLOAT_TO_FP8_STOCHASTIC 1

#ifndef CK_TILE_FLOAT_TO_FP8_DEFAULT
#define CK_TILE_FLOAT_TO_FP8_DEFAULT CK_TILE_FLOAT_TO_FP8_STANDARD
#endif

// STATIC_ASSERT is a debug-only static_assert: it compiles away when NDEBUG is set.
#ifndef STATIC_ASSERT
#ifndef NDEBUG
#define STATIC_ASSERT(...) static_assert(__VA_ARGS__)
#else
#define STATIC_ASSERT(...)
#endif
#endif // #ifndef STATIC_ASSERT

// in the old rocm period, we have to use tuple array implementation to implement this
// so turn on the _USE_TUPLE if meet compiler error, otherwise _USE_ARRAY by default.
#define CK_TILE_STATICALLY_INDEXED_ARRAY_USE_ARRAY 0
#define CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE 1
#ifndef CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT
#define CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT CK_TILE_STATICALLY_INDEXED_ARRAY_USE_ARRAY
#endif

// 1 = decorate kernels with __launch_bounds__ using the limits below.
#ifndef CK_TILE_USE_LAUNCH_BOUNDS
#define CK_TILE_USE_LAUNCH_BOUNDS 1
#endif

// NOTE(review): presumably enables kernel timing in host-side launch helpers —
// confirm against the launch utilities that read this flag.
#ifndef CK_TILE_TIME_KERNEL
#define CK_TILE_TIME_KERNEL 1
#endif

// launch-bounds parameters consumed when CK_TILE_USE_LAUNCH_BOUNDS is enabled
#define CK_TILE_MAX_THREAD_PER_BLOCK 256
#define CK_TILE_MIN_BLOCK_PER_CU 2
|
||||
201
include/ck_tile/core/container/array.hpp
Normal file
201
include/ck_tile/core/container/array.hpp
Normal file
@@ -0,0 +1,201 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// use aggregate initialization for this type
|
||||
// e.g. array<index_t, 4> buf {0}; => {0, 0, 0, 0}, clean
|
||||
// array<index_t, 4> buf {3, 2}; => {3, 2, 2, 2} (not {3,2,0,0})
|
||||
// use make_array_with({...}) to construct an array with compatible behavior as old ck
|
||||
// TODO: manually added constructor same as old ck
|
||||
template <typename T_, index_t N_>
|
||||
struct array
|
||||
{
|
||||
using value_type = T_;
|
||||
static constexpr index_t N = N_;
|
||||
value_type data[N];
|
||||
CK_TILE_HOST_DEVICE constexpr array() : data{} {}
|
||||
// TODO: will initialize the data[] with the last value repeatedly
|
||||
// behavior different from std
|
||||
CK_TILE_HOST_DEVICE constexpr array(std::initializer_list<value_type> ilist)
|
||||
{
|
||||
constexpr index_t list_size = std::initializer_list<value_type>{}.size();
|
||||
static_assert(list_size <= N, "out of bound");
|
||||
|
||||
index_t i = 0;
|
||||
value_type vlast = value_type{};
|
||||
|
||||
for(const value_type& val : ilist)
|
||||
{
|
||||
data[i] = val;
|
||||
vlast = val;
|
||||
++i;
|
||||
}
|
||||
for(; i < N; ++i)
|
||||
{
|
||||
data[i] = vlast;
|
||||
}
|
||||
}
|
||||
CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<value_type>; }
|
||||
|
||||
// clang-format off
|
||||
CK_TILE_HOST_DEVICE constexpr auto& get() { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& get() const { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr auto& get(index_t i) { return data[i]; }
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& get(index_t i) const { return data[i]; }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& get() { return data[I]; }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& get() const { return data[I]; }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& get(number<I>) { return data[I]; }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& get(number<I>) const { return data[I]; }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto& at(index_t i) { return data[i]; }
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& at(index_t i) const { return data[i]; }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at() { return data[I]; }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return data[I]; }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return data[I]; }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return data[I]; }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const value_type& operator[](index_t i) const { return data[i]; }
|
||||
CK_TILE_HOST_DEVICE constexpr value_type& operator[](index_t i) { return data[i]; }
|
||||
CK_TILE_HOST_DEVICE constexpr value_type& operator()(index_t i) { return data[i]; } // TODO: compatible
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator=(const T& a)
|
||||
{
|
||||
static_assert(T::size() == size(), "wrong! size not the same");
|
||||
for(index_t i = 0; i < size(); ++i)
|
||||
{
|
||||
data[i] = a[i];
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
// type punning (strict aliasing) member functions for read/write
|
||||
// aliasing this array of type "T", "N" elements
|
||||
// as array of type "Tx", sizeof(T)*N/sizeof(Tx) elements
|
||||
#define AR_AS_COM_() \
|
||||
static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
|
||||
constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
|
||||
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr auto& get_as()
|
||||
{ AR_AS_COM_(); return reinterpret_cast<array<Tx, vx>&>(data); }
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr const auto& get_as() const
|
||||
{ AR_AS_COM_(); return reinterpret_cast<const array<Tx, vx>&>(data); }
|
||||
|
||||
// below index is for index *AFTER* type convert, not before
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr auto& get_as(index_t i)
|
||||
{ AR_AS_COM_(); return reinterpret_cast<array<Tx, vx>&>(data).at(i); }
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr const auto& get_as(index_t i) const
|
||||
{ AR_AS_COM_(); return reinterpret_cast<const array<Tx, vx>&>(data).at(i); }
|
||||
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr auto& get_as(number<I>)
|
||||
{ AR_AS_COM_(); return reinterpret_cast<array<Tx, vx>&>(data).at(number<I>{}); }
|
||||
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr const auto& get_as(number<I>) const
|
||||
{ AR_AS_COM_(); return reinterpret_cast<const array<Tx, vx>&>(data).at(number<I>{}); }
|
||||
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x)
|
||||
{ AR_AS_COM_(); reinterpret_cast<array<Tx, vx>&>(data).at(i) = x; }
|
||||
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x)
|
||||
{ AR_AS_COM_(); reinterpret_cast<array<Tx, vx>&>(data).at(number<I>{}) = x; }
|
||||
#undef AR_AS_COM_
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
// empty Array
|
||||
|
||||
// Zero-length specialization: carries only the value_type, no storage.
template <typename T>
struct array<T, 0>
{
    using value_type = T;

    CK_TILE_HOST_DEVICE constexpr array() {}
    CK_TILE_HOST_DEVICE static constexpr index_t size() { return 0; }
    CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<T>; };
    // fixed representation; there are never any elements to print
    CK_TILE_HOST_DEVICE void print() const { printf("array{size: 0, data: []}"); }
};
|
||||
|
||||
// Build an array from the forwarded arguments; the element type is deduced from
// the FIRST argument and the size is the total argument count.
template <typename T, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_array(T&& x, Ts&&... xs)
{
    using value_type = remove_cvref_t<T>;
    return array<value_type, sizeof...(Ts) + 1>{std::forward<T>(x), std::forward<Ts>(xs)...};
}
|
||||
|
||||
// make empty array
|
||||
// Zero-argument overload: returns the empty array<T, 0>.
template <typename T>
CK_TILE_HOST_DEVICE constexpr auto make_array()
{
    return array<T, 0>{};
}
|
||||
|
||||
// compatible with old ck's initializer, make an array and fill it withe the last element from
|
||||
// initializer_list
|
||||
#include <initializer_list>
|
||||
template <typename T, index_t Size>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_array_with(std::initializer_list<T> ilist)
|
||||
{
|
||||
constexpr index_t list_size = std::initializer_list<T>{}.size();
|
||||
|
||||
static_assert(list_size <= Size, "out of bound");
|
||||
|
||||
index_t i = 0;
|
||||
T vlast = T{};
|
||||
array<T, Size> arr;
|
||||
|
||||
for(const T& val : ilist)
|
||||
{
|
||||
arr.data[i] = val;
|
||||
vlast = val;
|
||||
++i;
|
||||
}
|
||||
|
||||
for(; i < Size; ++i)
|
||||
{
|
||||
arr.data[i] = vlast;
|
||||
}
|
||||
|
||||
return arr;
|
||||
}
|
||||
|
||||
template <typename T, index_t Size>
|
||||
CK_TILE_HOST_DEVICE constexpr bool operator==(const array<T, Size>& a, const array<T, Size>& b)
|
||||
{
|
||||
bool same = true;
|
||||
|
||||
for(index_t i = 0; i < Size; ++i)
|
||||
{
|
||||
if(a[i] != b[i])
|
||||
{
|
||||
same = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return same;
|
||||
}
|
||||
|
||||
// Inequality defined in terms of operator==.
template <typename T, index_t Size>
CK_TILE_HOST_DEVICE constexpr bool operator!=(const array<T, Size>& a, const array<T, Size>& b)
{
    return !(a == b);
}
|
||||
|
||||
// Copy the first N elements of x (any container with static size() and
// operator[]) into an array<T, N>; N may be smaller than x's size.
template <typename T, index_t N, typename X>
CK_TILE_HOST_DEVICE constexpr auto to_array(const X& x)
{
    STATIC_ASSERT(N <= X::size(), "");

    array<T, N> arr;

    static_for<0, N, 1>{}([&x, &arr](auto i) { arr(i) = x[i]; });

    return arr;
}
|
||||
|
||||
} // namespace ck_tile
|
||||
483
include/ck_tile/core/container/container_helper.hpp
Normal file
483
include/ck_tile/core/container/container_helper.hpp
Normal file
@@ -0,0 +1,483 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/map.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Return a copy of `a` grown by one slot, with `x` appended at the end.
template <typename TData, index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto container_push_back(const array<TData, NSize>& a, const TData& x)
{
    array<TData, NSize + 1> r;
    // copy the existing elements, then fill the new last slot
    static_for<0, NSize, 1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; });
    r[number<NSize>{}] = x;
    return r;
}
|
||||
|
||||
// Return a new tuple with `x` prepended before the elements of `a`.
template <typename... Ts, typename T>
CK_TILE_HOST_DEVICE constexpr auto container_push_front(const tuple<Ts...>& a, const T& x)
{
    return container_concat(make_tuple(x), a);
}
|
||||
|
||||
// Return a new tuple with `x` appended after the elements of `a`.
template <typename... Ts, typename T>
CK_TILE_HOST_DEVICE constexpr auto container_push_back(const tuple<Ts...>& a, const T& x)
{
    return container_concat(a, make_tuple(x));
}
|
||||
|
||||
// reorder array
|
||||
// reorder array
// Permute with a compile-time new2old map: result[i] == old_array[IRs[i]].
template <typename TData, index_t NSize, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_new2old(const array<TData, NSize>& old_array, sequence<IRs...> /*new2old*/)
{
    static_assert(NSize == sizeof...(IRs), "wrong! size not consistent");
    static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
    return make_array<remove_cvref_t<TData>>(old_array[IRs]...);
}
|
||||
|
||||
// old2new variant: invert the map at compile time, then delegate to new2old.
template <typename TData, index_t NSize, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_old2new(const array<TData, NSize>& old_array, sequence<IRs...> old2new)
{
    return container_reorder_given_new2old(
        old_array, typename sequence_map_inverse<decltype(old2new)>::type{});
}
|
||||
|
||||
// reorder array
|
||||
// reorder array
// Runtime-map variant: new_array[new_pos] = old_array[old_pos] for each pair.
// NOTE(review): assumes `new2old` covers every index in [0, NSize) exactly once;
// uncovered slots of new_array stay default-initialized — verify against callers.
template <typename TData, index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_new2old(const array<TData, NSize>& old_array,
                                const map<index_t, index_t>& new2old)
{
    array<TData, NSize> new_array;

    for(const auto& [new_pos, old_pos] : new2old)
    {
        new_array(new_pos) = old_array[old_pos];
    }

    return new_array;
}
|
||||
|
||||
// Runtime-map variant with the inverse orientation: new_array[new_pos] =
// old_array[old_pos] for each (old_pos, new_pos) pair in `old2new`.
// NOTE(review): same coverage assumption as the new2old overload above.
template <typename TData, index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_old2new(const array<TData, NSize>& old_array,
                                const map<index_t, index_t>& old2new)
{
    array<TData, NSize> new_array;

    for(const auto& [old_pos, new_pos] : old2new)
    {
        new_array(new_pos) = old_array[old_pos];
    }

    return new_array;
}
|
||||
|
||||
// reorder tuple
|
||||
// reorder tuple
// Permute a tuple with a compile-time new2old map: result[i] == old_tuple[IRs[i]].
template <typename... Ts, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(const tuple<Ts...>& old_tuple,
                                                                   sequence<IRs...> /*new2old*/)
{
    static_assert(sizeof...(Ts) == sizeof...(IRs), "wrong! size not consistent");

    static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");

    return make_tuple(old_tuple[number<IRs>{}]...);
}
|
||||
|
||||
// old2new variant for tuples: invert the map, then delegate to new2old.
template <typename... Ts, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_old2new(const tuple<Ts...>& old_tuple,
                                                                   sequence<IRs...> old2new)
{
    return container_reorder_given_new2old(
        old_tuple, typename sequence_map_inverse<decltype(old2new)>::type{});
}
|
||||
|
||||
// reorder sequence
|
||||
// reorder sequence
// Permute a sequence with a compile-time new2old map: result[i] == old_seq[IRs[i]].
template <index_t... Is, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(sequence<Is...> /* old_seq */,
                                                                   sequence<IRs...> /*new2old*/)
{
    static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");

    static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");

    return sequence<sequence<Is...>::at(number<IRs>{})...>{};
}
|
||||
|
||||
// old2new variant for sequences: invert the map, then delegate to new2old.
template <index_t... Is, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_old2new(sequence<Is...> old_seq,
                                                                   sequence<IRs...> /* old2new */)
{
    static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");

    static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");

    constexpr auto new2old = typename sequence_map_inverse<sequence<IRs...>>::type{};

    return container_reorder_given_new2old(old_seq, new2old);
}
|
||||
|
||||
#if 0
|
||||
// rocm-4.1 compiler would crash for recursive lambda
|
||||
template <typename Container,
|
||||
typename Reduce,
|
||||
typename Init,
|
||||
index_t IBegin = 0,
|
||||
index_t IEnd = Container::size(),
|
||||
index_t IStep = 1>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container& x,
|
||||
Reduce reduce,
|
||||
Init init,
|
||||
number<IBegin> = number<0>{},
|
||||
number<IEnd> = number<Container::size()>{},
|
||||
number<IStep> = number<1>{})
|
||||
{
|
||||
static_assert((IEnd - IBegin) % IStep == 0, "wrong!");
|
||||
|
||||
// f is recursive function, fs is a dummy of f
|
||||
// i is index, y_old is current scan, r_old is current reduction
|
||||
auto f = [&](auto fs, auto i, auto r_old) {
|
||||
auto r_new = reduce(x[i], r_old);
|
||||
|
||||
if constexpr(i.value < IEnd - IStep)
|
||||
{
|
||||
// recursively call f/fs
|
||||
return fs(fs, i + number<IStep>{}, r_new);
|
||||
}
|
||||
else
|
||||
{
|
||||
return r_new;
|
||||
}
|
||||
};
|
||||
|
||||
// start recursion
|
||||
return f(f, number<IBegin>{}, init);
|
||||
}
|
||||
#else
|
||||
// i is index, y_old is current scan, r_old is current reduction
|
||||
// i is index, y_old is current scan, r_old is current reduction
// Recursive helper for container_reduce: folds x[I], x[I+IStep], ... into r_old
// via `reduce`. Implemented as compile-time template recursion instead of a
// recursive lambda, which crashed the rocm-4.1 compiler.
template <typename Container,
          typename Reduce,
          typename ROld,
          index_t I,
          index_t IEnd,
          index_t IStep>
CK_TILE_HOST_DEVICE constexpr auto container_reduce_impl(
    const Container& x, Reduce reduce, ROld r_old, number<I> i, number<IEnd>, number<IStep>)
{
    auto r_new = reduce(x[i], r_old);

    if constexpr(i.value < IEnd - IStep)
    {
        // keep folding the next element
        return container_reduce_impl(
            x, reduce, r_new, i + number<IStep>{}, number<IEnd>{}, number<IStep>{});
    }
    else
    {
        return r_new;
    }
}
|
||||
|
||||
// rocm-4.1 compiler would crash for recursive lambda
|
||||
// container reduce with initial value
|
||||
// rocm-4.1 compiler would crash for recursive lambda
// container reduce with initial value
// Fold x[IBegin], x[IBegin+IStep], ... x[IEnd-IStep] into `init` using `reduce`;
// an empty range returns `init` unchanged. (IEnd - IBegin) must be a multiple of
// IStep.
template <typename Container,
          typename Reduce,
          typename Init,
          index_t IBegin = 0,
          index_t IEnd = Container::size(),
          index_t IStep = 1>
CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container& x,
                                                    Reduce reduce,
                                                    Init init,
                                                    number<IBegin> = number<0>{},
                                                    number<IEnd> = number<Container::size()>{},
                                                    number<IStep> = number<1>{})
{
    static_assert((IEnd - IBegin) % IStep == 0, "wrong!");

    if constexpr(IEnd > IBegin)
    {
        return container_reduce_impl(
            x, reduce, init, number<IBegin>{}, number<IEnd>{}, number<IStep>{});
    }
    else
    {
        return init;
    }
}
|
||||
#endif
|
||||
|
||||
// Reverse inclusive scan over an array: walking from the back and seeding with
// `init`, y[i] is the running reduction of x[i..NSize), so y[NSize-1] ==
// f(init, x[NSize-1]) and y[0] folds every element.
// NOTE(review): assumes NSize >= 1 — the static_for range is ill-formed for an
// empty array; verify against callers.
template <typename TData, index_t NSize, typename Reduce>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_inclusive_scan(const array<TData, NSize>& x, Reduce f, TData init)
{
    array<TData, NSize> y;

    TData r = init;

    static_for<NSize - 1, 0, -1>{}([&](auto i) {
        r = f(r, x[i]);
        y(i) = r;
    });

    // fold the front element last
    r = f(r, x[number<0>{}]);
    y(number<0>{}) = r;

    return y;
}
|
||||
|
||||
// Reverse exclusive scan over an array: y[i] is the reduction of x[i+1..NSize)
// seeded with `init`, so y[NSize-1] == init and y[0] folds all but x[0].
template <typename TData, index_t NSize, typename Reduce, typename Init>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_exclusive_scan(const array<TData, NSize>& x, Reduce f, Init init)
{
#if 0
    // compile-time-unrolled alternative, currently disabled
    array<TData, NSize> y;

    TData r = init;

    static_for<NSize - 1, 0, -1>{}([&](auto i) {
        y(i) = r;
        r = f(r, x[i]);
    });

    y(number<0>{}) = r;

    return y;
#else
    array<TData, NSize> y;

    TData r = init;

    // store the running value BEFORE folding x[i] (exclusive semantics)
    for(index_t i = NSize - 1; i > 0; --i)
    {
        y(i) = r;
        r = f(r, x[i]);
    }

    y(0) = r;

    return y;
#endif
}
|
||||
|
||||
// Sequence overload: delegate to the compile-time sequence scan helper.
template <index_t... Is, typename Reduce, index_t Init>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_exclusive_scan(const sequence<Is...>& seq, Reduce f, number<Init>)
{
    return reverse_exclusive_scan_sequence(seq, f, number<Init>{});
}
|
||||
|
||||
#if 0
|
||||
// rocm4.1 compiler would crash with recursive lambda
|
||||
template <typename... Xs, typename Reduce, typename Init>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
container_reverse_exclusive_scan(const tuple<Xs...>& x, Reduce reduce, Init init)
|
||||
{
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
// f is recursive function, fs is a dummy of f
|
||||
// i is index, y_old is current scan, r_old is current reduction
|
||||
auto f = [&](auto fs, auto i, auto y_old, auto r_old) {
|
||||
auto r_new = reduce(x[i], r_old);
|
||||
|
||||
auto y_new = container_push_front(y_old, r_new);
|
||||
|
||||
if constexpr(i.value > 1)
|
||||
{
|
||||
// recursively call f/fs
|
||||
return fs(fs, i - number<1>{}, y_new, r_new);
|
||||
}
|
||||
else
|
||||
{
|
||||
return y_new;
|
||||
}
|
||||
};
|
||||
|
||||
// start recursion
|
||||
return f(f, number<NSize - 1>{}, make_tuple(init), init);
|
||||
}
|
||||
#else
|
||||
// i is index, y_old is current scan, r_old is current reduction
|
||||
// i is index, y_old is current scan, r_old is current reduction
// Recursive helper for the tuple reverse exclusive scan: walks from the back,
// prepending each new running reduction onto the result tuple. Template
// recursion is used instead of a recursive lambda (rocm-4.1 compiler crash).
template <typename... Xs, typename Reduce, index_t I, typename YOld, typename ROld>
CK_TILE_HOST_DEVICE constexpr auto container_reverse_exclusive_scan_impl(
    const tuple<Xs...>& x, Reduce reduce, number<I> i, YOld y_old, ROld r_old)
{
    auto r_new = reduce(x[i], r_old);

    auto y_new = container_push_front(y_old, r_new);

    if constexpr(i.value > 1)
    {
        // recursively call f/fs
        return container_reverse_exclusive_scan_impl(x, reduce, i - number<1>{}, y_new, r_new);
    }
    else
    {
        return y_new;
    }
}
|
||||
|
||||
// Tuple overload of the reverse exclusive scan: the last slot of the result is
// `init` and each earlier slot folds in one more trailing element of x.
template <typename... Xs, typename Reduce, typename Init>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_exclusive_scan(const tuple<Xs...>& x, Reduce reduce, Init init)
{
    constexpr index_t NSize = sizeof...(Xs);

    return container_reverse_exclusive_scan_impl(
        x, reduce, number<NSize - 1>{}, make_tuple(init), init);
}
|
||||
#endif
|
||||
|
||||
// TODO: update to like container_reverse_exclusive_scan to deal with tuple of number<>
//
// Reverse inclusive scan: y[k] = f(...f(f(init, x[NSize-1]), x[NSize-2])..., x[k]).
// Every element including x[0] is folded in, so y has the same size as x.
// Note the reduction argument order here is f(accumulator, element), which is
// the reverse of the exclusive-scan variants above.
template <typename... Xs, typename Reduce, typename TData>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_inclusive_scan(const tuple<Xs...>& x, Reduce f, TData init)
{
    constexpr index_t NSize = sizeof...(Xs);

    tuple<Xs...> y;

    // running reduction, updated from the back toward the front
    TData r = init;

    static_for<NSize - 1, 0, -1>{}([&](auto i) {
        r = f(r, x[i]);
        y(i) = r;
    });

    // index 0 handled outside the loop (static_for end bound is exclusive)
    r = f(r, x[number<0>{}]);
    y(number<0>{}) = r;

    return y;
}
|
||||
|
||||
// Concatenate any number of containers by folding pairwise from the right.
// Resolution relies on the binary array/tuple overloads below being found at
// instantiation time (arguments live in namespace ck_tile, so ADL applies).
template <typename X, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto container_concat(const X& x, const Ys&... ys)
{
    return container_concat(x, container_concat(ys...));
}

// Binary case: two arrays of the same element type -> array<T, NX + NY>.
template <typename T, index_t NX, index_t NY>
CK_TILE_HOST_DEVICE constexpr auto container_concat(const array<T, NX>& ax, const array<T, NY>& ay)
{
    return unpack2(
        [&](auto&&... zs) { return make_array<T>(std::forward<decltype(zs)>(zs)...); }, ax, ay);
}

// Binary case: two tuples -> tuple<X..., Y...>.
template <typename... X, typename... Y>
CK_TILE_HOST_DEVICE constexpr auto container_concat(const tuple<X...>& tx, const tuple<Y...>& ty)
{
    return unpack2(
        [&](auto&&... zs) { return make_tuple(std::forward<decltype(zs)>(zs)...); }, tx, ty);
}

// Base case of the variadic fold: a single container is returned as-is (by copy).
template <typename Container>
CK_TILE_HOST_DEVICE constexpr auto container_concat(const Container& x)
{
    return x;
}
|
||||
|
||||
// Gather arr[Is]... into a new array, in the order given by the sequence.
// Indices may repeat or reorder; an empty pick list yields array<T, 0>.
template <typename T, index_t N, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto get_container_subset(const array<T, N>& arr, sequence<Is...>)
{
    static_assert(N >= sizeof...(Is), "wrong! size");

    if constexpr(sizeof...(Is) > 0)
    {
        return make_array<T>(arr[Is]...);
    }
    else
    {
        // make_array cannot produce an empty array, so special-case it
        return array<T, 0>{};
    }
}

// Tuple overload: gather tup[number<Is>{}]... into a new tuple, preserving the
// (possibly heterogeneous) element types of the picked positions.
template <typename... Ts, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto get_container_subset(const tuple<Ts...>& tup, sequence<Is...>)
{
    static_assert(sizeof...(Ts) >= sizeof...(Is), "wrong! size");

    if constexpr(sizeof...(Is) > 0)
    {
        return make_tuple(tup[number<Is>{}]...);
    }
    else
    {
        return tuple<>{};
    }
}
|
||||
|
||||
// Scatter: y[picks[i]] = x[i] for each position in the pick sequence.
// Array overload; x must have exactly as many elements as there are picks.
template <typename T, index_t N, index_t... Is>
CK_TILE_HOST_DEVICE constexpr void
set_container_subset(array<T, N>& y, sequence<Is...> picks, const array<T, sizeof...(Is)>& x)
{
    static_assert(N >= sizeof...(Is), "wrong! size");

    if constexpr(sizeof...(Is) > 0)
    {
        // runtime loop is fine here: both containers hold a single value type
        for(index_t i = 0; i < picks.size(); ++i)
        {
            y(picks[i]) = x[i];
        }
    }
}

// Generic overload (e.g. tuple destinations): uses static_for so each element
// access is a compile-time index, as required for heterogeneous containers.
template <typename Y, typename X, index_t... Is>
CK_TILE_HOST_DEVICE constexpr void set_container_subset(Y& y, sequence<Is...> picks, const X& x)
{
    static_assert(Y::size() >= sizeof...(Is) && X::size() == sizeof...(Is), "wrong! size");

    if constexpr(sizeof...(Is) > 0)
    {
        static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
    }
}
|
||||
|
||||
// return the index of first occurrence in the sequence.
// return seq.size(), if not found
template <index_t... Is>
constexpr index_t container_find(sequence<Is...> seq, index_t value)
{
    // linear scan; sequences are small compile-time index lists
    for(auto i = 0; i < seq.size(); i++)
    {
        if(seq[i] == value)
            return i;
    }

    return seq.size();
}
|
||||
|
||||
// Convert sequence<Is...> into tuple<number<Is>...>, i.e. lift each compile-time
// index into its own integral-constant type so it can be carried in a tuple.
template <index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence<Is...>)
{
    using Seq = sequence<Is...>;

    return generate_tuple(
        [&](auto i) {
            // read the i-th index at compile time and wrap it in number<>
            constexpr index_t tmp = Seq::at(i);
            return number<tmp>{};
        },
        number<Seq::size()>{});
}
|
||||
|
||||
// Build a tuple of sequences from a constexpr array-of-arrays:
// a_of_b_impl is the nested data, a_size the outer count, bs_sizes the per-row
// lengths. Each row is converted via TO_SEQUENCE; the immediately-invoked
// lambda keeps the captures usable in constexpr context. Must stay a macro
// because TO_SEQUENCE itself is a macro operating on constexpr values.
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes)           \
    [a_of_b_impl, a_size, bs_sizes] {                                 \
        return ck_tile::generate_tuple(                               \
            [=](auto i) {                                             \
                constexpr auto b_impl    = a_of_b_impl[i];            \
                constexpr index_t b_size = bs_sizes[i];               \
                constexpr auto b         = TO_SEQUENCE(b_impl, b_size); \
                return b;                                             \
            },                                                        \
            ck_tile::number<a_size>{});                               \
    }()
|
||||
|
||||
} // namespace ck_tile
|
||||
164
include/ck_tile/core/container/map.hpp
Normal file
164
include/ck_tile/core/container/map.hpp
Normal file
@@ -0,0 +1,164 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// naive map: fixed-capacity associative container backed by a flat array of
// (key, data) pairs with linear-scan lookup. Intended for small compile-time /
// kernel-side bookkeeping, not as a general-purpose map. Insertion order is
// preserved; there is no erase of individual entries.
template <typename key, typename data, index_t max_size = 128>
struct map
{
    using pair_type = tuple<key, data>;
    using impl_type = array<pair_type, max_size>;

    impl_type impl_; // flat storage; only the first size_ entries are live
    index_t size_;   // number of live entries

    // mutable forward iterator over the live prefix of impl_
    struct iterator
    {
        impl_type& impl_;
        index_t pos_;

        CK_TILE_HOST_DEVICE constexpr iterator(impl_type& impl, index_t pos)
            : impl_{impl}, pos_{pos}
        {
        }

        CK_TILE_HOST_DEVICE constexpr iterator& operator++()
        {
            pos_++;
            return *this;
        }

        CK_TILE_HOST_DEVICE constexpr bool operator!=(const iterator& other) const
        {
            return other.pos_ != pos_;
        }

        CK_TILE_HOST_DEVICE constexpr pair_type& operator*() { return impl_.at(pos_); }
    };

    // read-only counterpart of iterator
    struct const_iterator
    {
        const impl_type& impl_;
        index_t pos_;

        CK_TILE_HOST_DEVICE constexpr const_iterator(const impl_type& impl, index_t pos)
            : impl_{impl}, pos_{pos}
        {
        }

        CK_TILE_HOST_DEVICE constexpr const_iterator& operator++()
        {
            pos_++;

            return *this;
        }

        CK_TILE_HOST_DEVICE constexpr bool operator!=(const const_iterator& other) const
        {
            return other.pos_ != pos_;
        }

        CK_TILE_HOST_DEVICE constexpr const pair_type& operator*() const { return impl_.at(pos_); }
    };

    CK_TILE_HOST_DEVICE constexpr map() : impl_{}, size_{0} {}

    CK_TILE_HOST_DEVICE constexpr index_t size() const { return size_; }

    // drops all entries; storage is left as-is
    CK_TILE_HOST_DEVICE void clear() { size_ = 0; }

    // linear scan for key; returns size_ when absent (note: the parameter name
    // intentionally shadows the template parameter `key` inside this body)
    CK_TILE_HOST_DEVICE constexpr index_t find_position(const key& key) const
    {
        for(index_t i = 0; i < size(); i++)
        {
            if(impl_[i].template at<0>() == key)
            {
                return i;
            }
        }

        return size_;
    }

    CK_TILE_HOST_DEVICE constexpr const_iterator find(const key& key) const
    {
        return const_iterator{impl_, find_position(key)};
    }

    CK_TILE_HOST_DEVICE constexpr iterator find(const key& key)
    {
        return iterator{impl_, find_position(key)};
    }

    // read-only lookup; the key must already be present (checked only by assert)
    CK_TILE_HOST_DEVICE constexpr const data& operator[](const key& key) const
    {
        const auto it = find(key);

        // FIXME
        assert(it.pos_ < size());

        return impl_[it.pos_].template at<1>();
    }

    // lookup-or-insert, like std::map::operator[]; inserts a default slot with
    // this key when absent. Capacity overflow is only caught by the assert.
    CK_TILE_HOST_DEVICE constexpr data& operator()(const key& key)
    {
        auto it = find(key);

        // if entry not found
        if(it.pos_ == size())
        {
            impl_(it.pos_).template at<0>() = key;
            size_++;
        }

        // FIXME
        assert(size_ <= max_size);

        return impl_(it.pos_).template at<1>();
    }

    // WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
    CK_TILE_HOST_DEVICE constexpr const_iterator begin() const { return const_iterator{impl_, 0}; }

    // WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
    CK_TILE_HOST_DEVICE constexpr const_iterator end() const
    {
        return const_iterator{impl_, size_};
    }

    // WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
    CK_TILE_HOST_DEVICE constexpr iterator begin() { return iterator{impl_, 0}; }

    // WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
    CK_TILE_HOST_DEVICE constexpr iterator end() { return iterator{impl_, size_}; }

    // debug dump via printf.
    // NOTE(review): the unqualified calls print(key)/print(data) name-lookup to
    // this member print() first, which takes no arguments — presumably a free
    // ck_tile::print overload is intended; verify this instantiates when used.
    CK_TILE_HOST_DEVICE void print() const
    {
        printf("map{size_: %d, ", size_);
        //
        printf("impl_: [");
        //
        for(const auto& [key, data] : *this)
        {
            printf("{key: ");
            print(key);
            printf(", data: ");
            print(data);
            printf("}, ");
        }
        //
        printf("]");
        //
        printf("}");
    }
};
|
||||
|
||||
} // namespace ck_tile
|
||||
99
include/ck_tile/core/container/meta_data_buffer.hpp
Normal file
99
include/ck_tile/core/container/meta_data_buffer.hpp
Normal file
@@ -0,0 +1,99 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include <cstddef>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// TODO: this structure is not intended to be used by user
//
// Fixed-capacity byte buffer for serializing trivially-copyable values back to
// back (via bit_cast to/from byte arrays). push() appends, pop() reads and
// advances a caller-held cursor, get() reads without tracking the cursor for
// the caller. Empty types occupy zero bytes on push/pop. No bounds checking
// against MaxSize is performed.
template <index_t MaxSize>
struct meta_data_buffer
{
    CK_TILE_HOST_DEVICE constexpr meta_data_buffer() : buffer_{}, size_{0} {}

    // construct and immediately serialize the given values in order
    template <typename X, typename... Xs>
    CK_TILE_HOST_DEVICE constexpr meta_data_buffer(const X& x, const Xs&... xs)
        : buffer_{}, size_{0}
    {
        push(x, xs...);
    }

    // append one value; empty types contribute no bytes
    template <typename T>
    CK_TILE_HOST_DEVICE constexpr void push(const T& data)
    {
        if constexpr(!std::is_empty_v<T>)
        {
            constexpr index_t size = sizeof(T);

            // reinterpret the object representation as raw bytes
            auto tmp = bit_cast<array<std::byte, size>>(data);

            for(int i = 0; i < size; i++)
            {
                buffer_(size_) = tmp[i];

                size_++;
            }
        }
    }

    // append several values in order (overload resolution always routes a
    // single argument to the unary push above)
    template <typename X, typename... Xs>
    CK_TILE_HOST_DEVICE constexpr void push(const X& x, const Xs&... xs)
    {
        push(x);
        push(xs...);
    }

    // read a T starting at pos and advance pos past it
    template <typename T>
    CK_TILE_HOST_DEVICE constexpr T pop(index_t& pos) const
    {
        T data;

        if constexpr(!std::is_empty_v<T>)
        {
            constexpr index_t size = sizeof(T);

            array<std::byte, size> tmp;

            for(int i = 0; i < size; i++)
            {
                tmp(i) = buffer_[pos];

                pos++;
            }

            data = bit_cast<T>(tmp);
        }

        return data;
    }

    // read a T starting at pos without reporting the advanced position
    // (pos is taken by value; the local increments are not visible to callers)
    template <typename T>
    CK_TILE_HOST_DEVICE constexpr T get(index_t pos) const
    {
        constexpr index_t size = sizeof(T);

        array<std::byte, size> tmp;

        for(int i = 0; i < size; i++)
        {
            tmp(i) = buffer_[pos];

            pos++;
        }

        auto data = bit_cast<T>(tmp);

        return data;
    }

    //
    array<std::byte, MaxSize> buffer_; // raw serialized bytes
    index_t size_ = 0;                 // number of bytes currently pushed
};
|
||||
|
||||
} // namespace ck_tile
|
||||
99
include/ck_tile/core/container/multi_index.hpp
Normal file
99
include/ck_tile/core/container/multi_index.hpp
Normal file
@@ -0,0 +1,99 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// deprecated, always use array instead
template <index_t N>
using multi_index = array<index_t, N>;

// Build a multi_index from the given values; each argument is converted to
// index_t via brace-init (which rejects narrowing at compile time).
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_multi_index(Xs&&... xs)
{
    return make_array<index_t>(index_t{xs}...);
}

// All-zero multi_index of the given rank.
template <index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto make_zero_multi_index()
{
    return unpack([](auto... xs) { return make_multi_index(xs...); },
                  typename uniform_sequence_gen<NSize, 0>::type{});
}

// Convert any unpackable container (tuple, sequence, array) to a multi_index.
template <typename T>
CK_TILE_HOST_DEVICE constexpr auto to_multi_index(const T& x)
{
    return unpack([](auto... ys) { return make_multi_index(ys...); }, x);
}
|
||||
|
||||
// Element-wise compound add; X may be any container with matching static size.
// NOTE(review): returns `y` by value (auto deduces a copy), not the customary
// reference — callers chaining the result operate on a temporary. Confirm
// before relying on the return value.
template <index_t NSize, typename X>
CK_TILE_HOST_DEVICE constexpr auto operator+=(multi_index<NSize>& y, const X& x)
{
    static_assert(X::size() == NSize, "wrong! size not the same");
    static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; });
    return y;
}

// Element-wise compound subtract; same return-by-copy caveat as operator+=.
template <index_t NSize, typename X>
CK_TILE_HOST_DEVICE constexpr auto operator-=(multi_index<NSize>& y, const X& x)
{
    static_assert(X::size() == NSize, "wrong! size not the same");
    static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; });
    return y;
}
|
||||
|
||||
template <index_t NSize, typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+(const multi_index<NSize>& a, const T& b)
|
||||
{
|
||||
using type = multi_index<NSize>;
|
||||
static_assert(T::size() == NSize, "wrong! size not the same");
|
||||
type r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] + b[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
template <index_t NSize, typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator-(const multi_index<NSize>& a, const T& b)
|
||||
{
|
||||
using type = multi_index<NSize>;
|
||||
static_assert(T::size() == NSize, "wrong! size not the same");
|
||||
type r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] - b[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
template <index_t NSize, typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator*(const multi_index<NSize>& a, const T& b)
|
||||
{
|
||||
using type = multi_index<NSize>;
|
||||
static_assert(T::size() == NSize, "wrong! size not the same");
|
||||
type r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] * b[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
// multi_index = index_t * multi_index
// Scales every component by the scalar a.
template <index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto operator*(index_t a, const multi_index<NSize>& x)
{
    multi_index<NSize> r;
    static_for<0, NSize, 1>{}([&](auto i) { r[i] = a * x[i]; });
    return r;
}

// multi_index = multi_index * index_t
// Commutes to the scalar-first overload above.
template <index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto operator*(const multi_index<NSize>& x, index_t a)
{
    return a * x;
}
|
||||
|
||||
} // namespace ck_tile
|
||||
1114
include/ck_tile/core/container/sequence.hpp
Normal file
1114
include/ck_tile/core/container/sequence.hpp
Normal file
File diff suppressed because it is too large
Load Diff
78
include/ck_tile/core/container/span.hpp
Normal file
78
include/ck_tile/core/container/span.hpp
Normal file
@@ -0,0 +1,78 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include <cstddef>
|
||||
#include <array>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// implement the c++20 std::span, lightweight, non-owning reference to a sequence
|
||||
// weather it is dynamic or static range. Or can be seen as a view of a contiguous sequence
|
||||
// TODO: do we need in device consider this is pointer?
|
||||
template <typename T>
|
||||
class span
|
||||
{
|
||||
public:
|
||||
using element_type = T;
|
||||
using value_type = std::remove_cv_t<element_type>;
|
||||
using size_type = std::size_t;
|
||||
using difference_type = std::ptrdiff_t;
|
||||
using pointer = element_type*;
|
||||
using const_pointer = const element_type*;
|
||||
using reference = element_type&;
|
||||
using const_reference = const element_type&;
|
||||
using iterator = pointer;
|
||||
using const_iterator = pointer;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr span() : span(nullptr, size_type{0}) {}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr span(pointer first, size_type count) : ptr_(first), size_(count)
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr span(pointer first, pointer last) : span(first, last - first) {}
|
||||
|
||||
template <std::size_t N>
|
||||
CK_TILE_HOST_DEVICE constexpr span(element_type (&arr)[N]) noexcept : span(arr, N)
|
||||
{
|
||||
}
|
||||
|
||||
template <std::size_t N>
|
||||
CK_TILE_HOST_DEVICE constexpr span(std::array<value_type, N>& arr) noexcept
|
||||
: span(arr.data(), N)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename Container>
|
||||
CK_TILE_HOST_DEVICE constexpr span(const Container& container)
|
||||
: span(container.data(), container.size())
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr iterator begin() const noexcept { return ptr_; }
|
||||
CK_TILE_HOST_DEVICE constexpr const_iterator cbegin() const noexcept { return begin(); }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr iterator end() const noexcept { return begin() + size(); }
|
||||
CK_TILE_HOST_DEVICE constexpr const_iterator cend() const noexcept { return end(); }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr reference front() const { return *begin(); }
|
||||
CK_TILE_HOST_DEVICE constexpr reference back() const { return *(--end()); }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr reference operator[](size_type idx) const
|
||||
{
|
||||
return *(begin() + idx);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pointer data() const noexcept { return ptr_; }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr size_type size() const noexcept { return size_; }
|
||||
|
||||
private:
|
||||
pointer ptr_;
|
||||
size_type size_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
70
include/ck_tile/core/container/statically_indexed_array.hpp
Normal file
70
include/ck_tile/core/container/statically_indexed_array.hpp
Normal file
@@ -0,0 +1,70 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
#if CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT == CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE
|
||||
namespace detail {
// type-level concatenation of two tuples
template <typename X, typename Y>
struct tuple_concat;

template <typename... Xs, typename... Ys>
struct tuple_concat<tuple<Xs...>, tuple<Ys...>>
{
    using type = tuple<Xs..., Ys...>;
};

// Build tuple<T, T, ..., T> (N copies) by binary splitting, which keeps the
// template-instantiation depth at O(log N) instead of O(N).
template <typename T, index_t N>
struct statically_indexed_array_impl
{
    using type =
        typename tuple_concat<typename statically_indexed_array_impl<T, N / 2>::type,
                              typename statically_indexed_array_impl<T, N - N / 2>::type>::type;
};

// recursion terminators
template <typename T>
struct statically_indexed_array_impl<T, 0>
{
    using type = tuple<>;
};

template <typename T>
struct statically_indexed_array_impl<T, 1>
{
    using type = tuple<T>;
};
} // namespace detail
|
||||
|
||||
template <typename T, index_t N>
|
||||
using statically_indexed_array = typename detail::statically_indexed_array_impl<T, N>::type;
|
||||
|
||||
#else
|
||||
|
||||
// consider mark this struct as deprecated
|
||||
template <typename T, index_t N>
|
||||
using statically_indexed_array = array<T, N>;
|
||||
|
||||
#endif
|
||||
|
||||
// consider always use ck_tile::array for this purpose

// Build a statically_indexed_array from the given values; the element type is
// deduced from the first argument and the rest are cast to it.
template <typename X, typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs)
{
    return statically_indexed_array<X, sizeof...(Xs) + 1>(x, static_cast<X>(xs)...);
}

// make empty statically_indexed_array
// (separate overload: the variadic version above requires at least one value)
template <typename X>
CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array()
{
    return statically_indexed_array<X, 0>();
}
|
||||
|
||||
} // namespace ck_tile
|
||||
483
include/ck_tile/core/container/tuple.hpp
Normal file
483
include/ck_tile/core/container/tuple.hpp
Normal file
@@ -0,0 +1,483 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include <utility>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace impl {
|
||||
|
||||
// One slot of a tuple, tagged by its index so repeated element types produce
// distinct base classes. Empty element types are specialized to store nothing
// (empty-base optimization); non-empty types hold the value in `element`.
template <index_t idx, typename T, bool is_empty = std::is_empty_v<T>>
struct tuple_element
{
};

// empty payload: nothing stored, constructors accept and drop the value
template <index_t idx, typename T>
struct tuple_element<idx, T, true>
{
    CK_TILE_HOST_DEVICE constexpr tuple_element() {}
    CK_TILE_HOST_DEVICE constexpr tuple_element(const T&) {}
};

// non-empty payload: value stored by copy
template <index_t idx, typename T>
struct tuple_element<idx, T, false>
{
    CK_TILE_HOST_DEVICE constexpr tuple_element() {}
    CK_TILE_HOST_DEVICE constexpr tuple_element(const T& e) : element(e) {}
    T element;
};

// Accessors: deduce (I, T) from the unique base class, giving O(1) element
// lookup without recursion. Only defined for non-empty slots.
template <std::size_t I, class T>
CK_TILE_HOST_DEVICE constexpr T const& getv(tuple_element<I, T, false> const& x)
{
    return x.element;
}

template <std::size_t I, class T>
CK_TILE_HOST_DEVICE constexpr T& getv(tuple_element<I, T, false>& x)
{
    return x.element;
}

// rvalue overload: moves the element out
template <std::size_t I, class T>
CK_TILE_HOST_DEVICE constexpr T&& getv(tuple_element<I, T, false>&& x)
{
    return static_cast<T&&>(x.element);
}
|
||||
|
||||
// Storage aggregate for tuple: inherits one tuple_element<I, T> per slot,
// pairing each type with its index from the sequence.
template <typename index_seq, typename... T>
struct tuple_base;

template <index_t... I, typename... T>
struct tuple_base<sequence<I...>, T...> : public tuple_element<I, T>...
{
    CK_TILE_HOST_DEVICE constexpr tuple_base() {}

    // element-wise construction; explicit to avoid accidental conversions
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U const&... u) : tuple_element<I, T>(u)...
    {
    }

    // converting copy from a tuple_base with different element types:
    // each slot is read from the corresponding base of `u` via getv
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple_base(tuple_base<sequence<I...>, U...> const& u)
        : tuple_element<I, T>(getv(static_cast<tuple_element<I, U> const&>(u)))...
    {
    }
};
|
||||
} // namespace impl
|
||||
|
||||
template <class... T>
|
||||
struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE
|
||||
static constexpr auto size() { return sizeof...(T); }
|
||||
using base = impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>;
|
||||
CK_TILE_HOST_DEVICE constexpr tuple() {}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(U const&... u)
|
||||
: impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>(u...)
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(tuple<U...> const& u)
|
||||
: impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>(
|
||||
static_cast<impl::tuple_base<U...> const&>(u))
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static()
|
||||
{
|
||||
bool flag = true;
|
||||
|
||||
static_for<0, sizeof...(Xs), 1>{}([&flag](auto i) {
|
||||
flag &= is_static_v<remove_cvref_t<__type_pack_element<i.value, Xs...>>>;
|
||||
});
|
||||
|
||||
return flag;
|
||||
}
|
||||
|
||||
#define TP_COM_() static_assert(I < size(), "wrong! out of range")
|
||||
// clang-format off
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr const auto & get() const { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr const auto & get(number<I>) const { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & get() { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & get(number<I>) { TP_COM_(); return get<I>(); }
|
||||
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr const auto & at() const { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr const auto & at(number<I>) const { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & at() { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & at(number<I>) { TP_COM_(); return get<I>(); }
|
||||
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & operator[](number<I>) { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr const auto & operator[](number<I>) const { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & operator()(number<I>) { TP_COM_(); return get<I>(); } // TODO: compatible
|
||||
// clang-format on
|
||||
#undef TP_COM_
|
||||
};
|
||||
|
||||
// template <class... T>
|
||||
// CK_TILE_HOST_DEVICE constexpr
|
||||
// tuple<T...>
|
||||
// make_tuple(T const&... t)
|
||||
// {
|
||||
// return {t...};
|
||||
// }
|
||||
|
||||
// Build a tuple holding decayed copies of the arguments (like std::make_tuple).
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs&&... xs)
{
    return tuple<remove_cvref_t<Xs>...>(std::forward<Xs>(xs)...);
}

// https://en.cppreference.com/w/cpp/utility/tuple/tie
// Tuple of lvalue references to the arguments; useful for multi-assignment.
template <typename... Args>
constexpr tuple<Args&...> tie(Args&... args) noexcept
{
    return {args...};
}

// type-level concatenation of two tuple types
template <typename X, typename Y>
struct tuple_concat;

template <typename... Xs, typename... Ys>
struct tuple_concat<tuple<Xs...>, tuple<Ys...>>
{
    using type = tuple<Xs..., Ys...>;
};
|
||||
|
||||
// Build tuple{f(number<0>{}), f(number<1>{}), ..., f(number<N-1>{})};
// each element keeps its own (possibly distinct) type.
template <typename F, index_t N>
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F&& f, number<N>)
{
    return unpack([&f](auto&&... is) { return make_tuple(f(is)...); },
                  typename arithmetic_sequence_gen<0, N, 1>::type{});
}

// Same as generate_tuple but f must return lvalue references; the result is a
// tuple of references (via tie) rather than copies.
template <typename F, index_t N>
CK_TILE_HOST_DEVICE constexpr auto generate_tie(F&& f, number<N>)
{
    return unpack([&f](auto&&... is) { return tie(f(is)...); },
                  typename arithmetic_sequence_gen<0, N, 1>::type{});
}
|
||||
|
||||
// tx and ty are tuples of references; the return type will be a tuple of
// references as well (not rvalues/copies).
template <typename... X, typename... Y>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple_of_reference(const tuple<X&...>& tx,
                                                             const tuple<Y&...>& ty)
{
    return unpack2(
        [&](auto&&... zs) { return tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
        tx,
        ty);
}

// Concatenate two tuples by value into tuple<X..., Y...>.
template <typename... X, typename... Y>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx, const tuple<Y...>& ty)
{
    return unpack2(
        [&](auto... zs) { return tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
        tx,
        ty);
}

// Support any number of tuples to concat (also 1)
template <typename... X>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx)
{
    return tx;
}

// variadic case: fold right-to-left onto the binary overload above
template <typename... X, typename... Tuples>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx, const Tuples&... tuples)
{
    return concat_tuple(tx, concat_tuple(tuples...));
}
|
||||
|
||||
namespace detail {

// element-wise map over one tuple: make_tuple(f(x[0]), f(x[1]), ...)
template <typename F, typename X, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples_impl(F f, const X& x, sequence<Is...>)
{
    return make_tuple(f(x.at(number<Is>{}))...);
}

// element-wise zip-map over two tuples of equal size
template <typename F, typename X, typename Y, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto
transform_tuples_impl(F f, const X& x, const Y& y, sequence<Is...>)
{
    return make_tuple(f(x.at(number<Is>{}), y.at(number<Is>{}))...);
}

// element-wise zip-map over three tuples of equal size
template <typename F, typename X, typename Y, typename Z, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto
transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, sequence<Is...>)
{
    return make_tuple(f(x.at(number<Is>{}), y.at(number<Is>{}), z.at(number<Is>{}))...);
}

} // namespace detail
|
||||
|
||||
template <typename F, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x)
|
||||
{
|
||||
return detail::transform_tuples_impl(
|
||||
f, x, typename arithmetic_sequence_gen<0, X::size()(), 1>::type{});
|
||||
}
|
||||
|
||||
template <typename F, typename X, typename Y>
|
||||
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y)
|
||||
{
|
||||
return detail::transform_tuples_impl(
|
||||
f, x, y, typename arithmetic_sequence_gen<0, X::size()(), 1>::type{});
|
||||
}
|
||||
|
||||
template <typename F, typename X, typename Y, typename Z>
|
||||
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z)
|
||||
{
|
||||
return detail::transform_tuples_impl(
|
||||
f, x, y, z, typename arithmetic_sequence_gen<0, X::size()(), 1>::type{});
|
||||
}
|
||||
|
||||
// By default unroll to the flatten
// Flatten a nested tuple depth-first. Recursion stops either at a non-tuple
// leaf (wrapped into a 1-tuple) or when Depth reaches MaxDepth; MaxDepth = -1
// means "flatten completely" (the equality is never hit).

// base case: an empty tuple stays empty
template <index_t Depth = 0, index_t MaxDepth = -1>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<>& element)
{
    return element;
}

// base case: a non-tuple leaf becomes a single-element tuple so the results
// can be concatenated uniformly
template <index_t Depth = 0, index_t MaxDepth = -1, typename T>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const T& element)
{
    return make_tuple(element);
}

// recursive case: flatten each element one level deeper, then concatenate
template <index_t Depth = 0, index_t MaxDepth = -1, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<Ts...>& tuple)
{
    if constexpr(Depth == MaxDepth)
    {
        return tuple;
    }
    else
    {
        return unpack(
            [&](auto&&... ts) {
                return concat_tuple(unroll_nested_tuple<Depth + 1, MaxDepth>(ts)...);
            },
            tuple);
    }
}
|
||||
|
||||
template <typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto tuple_reverse(const tuple<Ts...>& tuple)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using Idx = number<tuple<Ts...>::size()() - i - 1>;
|
||||
return tuple.at(Idx{});
|
||||
},
|
||||
number<tuple<Ts...>::size()()>{});
|
||||
}
|
||||
|
||||
// Reduce tuple values in specific range using Function
// Right fold over elements [Idx, End): f(x[Idx], f(x[Idx+1], ... x[End-1])).
// The range must be non-empty; a single element is returned unreduced.
template <index_t Idx, index_t End, typename F, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto tuple_reduce(F&& f, const tuple<Ts...>& tuple)
{
    static_assert(Idx < End, "Wrong parameters for tuple_reduce");
    if constexpr(Idx + 1 == End)
    {
        // single remaining element terminates the recursion
        return tuple.at(number<Idx>{});
    }
    else
    {
        return f(tuple.at(number<Idx>{}), tuple_reduce<Idx + 1, End>(f, tuple));
    }
}
|
||||
|
||||
template <typename T>
|
||||
using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
|
||||
template <typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto is_nested_tuple(const tuple<Ts...>&)
|
||||
{
|
||||
return (is_detected<is_tuple, Ts>::value || ...);
|
||||
}
|
||||
|
||||
template <index_t depth = 0, typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr auto tuple_depth(const T&)
|
||||
{
|
||||
return depth;
|
||||
}
|
||||
|
||||
template <index_t depth = 0, typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto tuple_depth(const tuple<Ts...>&)
|
||||
{
|
||||
return math::max(tuple_depth<depth + 1>(Ts{})...);
|
||||
}
|
||||
|
||||
template <typename... Seqs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto to_array_of_array(tuple<Seqs...> t_of_s)
|
||||
{
|
||||
constexpr index_t n0 = sizeof...(Seqs);
|
||||
|
||||
constexpr index_t max_n1 = [&] {
|
||||
index_t max_n1_ = 0;
|
||||
|
||||
static_for<0, n0, 1>{}([&](auto i0) {
|
||||
constexpr index_t n1 = t_of_s[i0].size()();
|
||||
|
||||
max_n1_ = max_n1_ < n1 ? n1 : max_n1_;
|
||||
});
|
||||
|
||||
return max_n1_;
|
||||
}();
|
||||
|
||||
array<array<index_t, max_n1>, n0> a_of_a{{-1}};
|
||||
|
||||
static_for<0, n0, 1>{}([&](auto i0) {
|
||||
constexpr index_t n1 = t_of_s[i0].size()();
|
||||
|
||||
static_for<0, n1, 1>{}([&](auto i1) { a_of_a(i0)(i1) = t_of_s[i0][i1]; });
|
||||
});
|
||||
|
||||
return a_of_a;
|
||||
}
|
||||
|
||||
// Here should use MultiIndex<NSize>, instead of tuple<Ys...>, although the former
|
||||
// is the alias of the latter. This is because compiler cannot infer the NSize if
|
||||
// using MultiIndex<NSize>
|
||||
// TODO: how to fix this?
|
||||
template <typename... Ys,
|
||||
typename X,
|
||||
std::enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> =
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+=(tuple<Ys...>& y, const X& x)
|
||||
{
|
||||
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Ys);
|
||||
static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; });
|
||||
return y;
|
||||
}
|
||||
|
||||
template <typename... Ys,
|
||||
typename X,
|
||||
std::enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> =
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator-=(tuple<Ys...>& y, const X& x)
|
||||
{
|
||||
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Ys);
|
||||
static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; });
|
||||
return y;
|
||||
}
|
||||
|
||||
template <typename... Xs,
|
||||
typename Y,
|
||||
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
|
||||
{
|
||||
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
tuple<Xs...> r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] + y[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename... Xs,
|
||||
typename Y,
|
||||
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
|
||||
{
|
||||
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
tuple<Xs...> r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] - y[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename... Xs,
|
||||
typename Y,
|
||||
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const Y& y)
|
||||
{
|
||||
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
tuple<Xs...> r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] * y[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
// MultiIndex = scalar * MultiIndex
|
||||
template <
|
||||
typename... Xs,
|
||||
typename Y,
|
||||
std::enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator*(Y a, const tuple<Xs...>& x)
|
||||
{
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
tuple<Xs...> r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a * x[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
// MultiIndex = MultiIndex * scalar
|
||||
template <
|
||||
typename... Xs,
|
||||
typename Y,
|
||||
std::enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, Y a)
|
||||
{
|
||||
return a * x;
|
||||
}
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator/(const tuple<Xs...>& x, const tuple<Ys...>& y)
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
return generate_tuple([&](auto i) { return x[i] / y[i]; }, number<NSize>{});
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
// WARNING: needed by compiler for C++ structured binding support only, don't use this
|
||||
namespace std {
|
||||
|
||||
template <typename... Ts>
|
||||
struct tuple_size<ck_tile::tuple<Ts...>> : std::integral_constant<std::size_t, sizeof...(Ts)>
|
||||
{
|
||||
};
|
||||
|
||||
template <std::size_t I, typename... Ts>
|
||||
struct tuple_element<I, ck_tile::tuple<Ts...>> : ck_tile::tuple_element<I, ck_tile::tuple<Ts...>>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename... Ts>
|
||||
struct tuple_size<const ck_tile::tuple<Ts...>> : std::integral_constant<std::size_t, sizeof...(Ts)>
|
||||
{
|
||||
};
|
||||
|
||||
template <std::size_t I, typename... Ts>
|
||||
struct tuple_element<I, const ck_tile::tuple<Ts...>>
|
||||
: ck_tile::tuple_element<I, const ck_tile::tuple<Ts...>>
|
||||
{
|
||||
};
|
||||
|
||||
} // namespace std
|
||||
116
include/ck_tile/core/numeric/arithmetic.hpp
Normal file
116
include/ck_tile/core/numeric/arithmetic.hpp
Normal file
@@ -0,0 +1,116 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include <stdint.h>
|
||||
|
||||
#pragma once
|
||||
|
||||
// Injects the full set of comparison/arithmetic operators for a storage-only
// floating-point type `type_` (e.g. bfloat16_t, fp8 types). Every operation
// converts to float, computes, and converts back — except unary minus, which
// flips the sign bit of the raw storage directly.
// Requirements on type_: explicit float conversions both ways, a public
// `data` member, and a `raw_type` typedef.
// Fix: the sign-bit mask is built with an unsigned literal (1u) so the shift
// never lands in the sign bit of a signed int (UB when bits == 32).
#define CK_TILE_ARITHMETIC_USING_FLOAT(type_)                        \
    CK_TILE_HOST_DEVICE                                              \
    bool operator==(const type_& x, const type_& y)                  \
    {                                                                \
        return static_cast<float>(x) == static_cast<float>(y);       \
    }                                                                \
    CK_TILE_HOST_DEVICE                                              \
    bool operator!=(const type_& x, const type_& y)                  \
    {                                                                \
        return static_cast<float>(x) != static_cast<float>(y);       \
    }                                                                \
    CK_TILE_HOST_DEVICE                                              \
    bool operator<(const type_& x, const type_& y)                   \
    {                                                                \
        return static_cast<float>(x) < static_cast<float>(y);        \
    }                                                                \
    CK_TILE_HOST_DEVICE                                              \
    bool operator<=(const type_& x, const type_& y)                  \
    {                                                                \
        return static_cast<float>(x) <= static_cast<float>(y);       \
    }                                                                \
    CK_TILE_HOST_DEVICE                                              \
    bool operator>(const type_& x, const type_& y)                   \
    {                                                                \
        return static_cast<float>(x) > static_cast<float>(y);        \
    }                                                                \
    CK_TILE_HOST_DEVICE                                              \
    bool operator>=(const type_& x, const type_& y)                  \
    {                                                                \
        return static_cast<float>(x) >= static_cast<float>(y);       \
    }                                                                \
    CK_TILE_HOST_DEVICE                                              \
    type_ operator+(const type_& x, const type_& y)                  \
    {                                                                \
        return type_(static_cast<float>(x) + static_cast<float>(y)); \
    }                                                                \
    CK_TILE_HOST_DEVICE                                              \
    type_ operator-(const type_& x)                                  \
    {                                                                \
        /* negate by toggling the sign bit of the raw storage */     \
        constexpr uint32_t bits = sizeof(type_) * 8;                 \
        constexpr uint32_t mask = 1u << (bits - 1);                  \
        type_ y                 = x;                                 \
        y.data ^= static_cast<typename type_::raw_type>(mask);       \
        return y;                                                    \
    }                                                                \
    CK_TILE_HOST_DEVICE                                              \
    type_ operator-(const type_& x, const type_& y)                  \
    {                                                                \
        return type_(static_cast<float>(x) - static_cast<float>(y)); \
    }                                                                \
    CK_TILE_HOST_DEVICE                                              \
    type_ operator*(const type_& x, const type_& y)                  \
    {                                                                \
        return type_(static_cast<float>(x) * static_cast<float>(y)); \
    }                                                                \
    CK_TILE_HOST_DEVICE                                              \
    type_ operator/(const type_& x, const type_& y)                  \
    {                                                                \
        return type_(static_cast<float>(x) / static_cast<float>(y)); \
    }                                                                \
    CK_TILE_HOST_DEVICE                                              \
    type_& operator+=(type_& x, const type_& y)                      \
    {                                                                \
        x = type_(static_cast<float>(x) + static_cast<float>(y));    \
        return x;                                                    \
    }                                                                \
    CK_TILE_HOST_DEVICE                                              \
    type_& operator-=(type_& x, const type_& y)                      \
    {                                                                \
        x = type_(static_cast<float>(x) - static_cast<float>(y));    \
        return x;                                                    \
    }                                                                \
    CK_TILE_HOST_DEVICE                                              \
    type_& operator*=(type_& x, const type_& y)                      \
    {                                                                \
        x = type_(static_cast<float>(x) * static_cast<float>(y));    \
        return x;                                                    \
    }                                                                \
    CK_TILE_HOST_DEVICE                                              \
    type_& operator/=(type_& x, const type_& y)                      \
    {                                                                \
        x = type_(static_cast<float>(x) / static_cast<float>(y));    \
        return x;                                                    \
    }                                                                \
    CK_TILE_HOST_DEVICE                                              \
    type_& operator++(type_& x)                                      \
    {                                                                \
        x = type_(static_cast<float>(x) + 1.f);                      \
        return x;                                                    \
    }                                                                \
    CK_TILE_HOST_DEVICE                                              \
    type_& operator--(type_& x)                                      \
    {                                                                \
        x = type_(static_cast<float>(x) - 1.f);                      \
        return x;                                                    \
    }                                                                \
    CK_TILE_HOST_DEVICE                                              \
    type_ operator++(type_& x, int)                                  \
    {                                                                \
        type_ y(x);                                                  \
        x = type_(static_cast<float>(x) + 1.f);                      \
        return y;                                                    \
    }                                                                \
    CK_TILE_HOST_DEVICE                                              \
    type_ operator--(type_& x, int)                                  \
    {                                                                \
        type_ y(x);                                                  \
        x = type_(static_cast<float>(x) - 1.f);                      \
        return y;                                                    \
    }
|
||||
263
include/ck_tile/core/numeric/bfloat16.hpp
Normal file
263
include/ck_tile/core/numeric/bfloat16.hpp
Normal file
@@ -0,0 +1,263 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/numeric/arithmetic.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include <stdint.h>
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Selects how float -> bf16 conversion rounds the dropped mantissa bits.
enum class bf16_rounding_mode
{
    standard = 0,      // round-to-nearest-even (rtn)
    truncate_with_nan, // truncate, but preserve signaling NaN
    truncate,          // plain truncation (round toward zero)
};
|
||||
|
||||
template <bf16_rounding_mode rounding = CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT>
|
||||
CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding> = {});
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float bf16_to_float_raw(uint16_t x);
|
||||
|
||||
// HIP use __hip_bfloat16 as struct
|
||||
struct alignas(2) bfloat16_t
|
||||
{
|
||||
using raw_type = uint16_t;
|
||||
raw_type data;
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
static bfloat16_t bit_cast(raw_type x)
|
||||
{
|
||||
bfloat16_t y;
|
||||
y.data = x;
|
||||
return y;
|
||||
}
|
||||
|
||||
// constructor
|
||||
bfloat16_t() = default;
|
||||
|
||||
// construct from float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit bfloat16_t(const float& x) { data = float_to_bf16_raw(x); }
|
||||
|
||||
// construct from int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit bfloat16_t(const int& x) { data = float_to_bf16_raw(static_cast<float>(x)); }
|
||||
|
||||
// construct from unsigned int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit bfloat16_t(const unsigned int& x) { data = float_to_bf16_raw(static_cast<float>(x)); }
|
||||
|
||||
// cast to float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit operator float() const { return bf16_to_float_raw(data); }
|
||||
|
||||
// cast to int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit operator int() const { return static_cast<int>(bf16_to_float_raw(data)); }
|
||||
|
||||
// internal access
|
||||
CK_TILE_HOST_DEVICE
|
||||
raw_type& get() { return data; }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
raw_type get() const { return data; }
|
||||
};
|
||||
|
||||
// round to nearest
|
||||
CK_TILE_HOST_DEVICE
|
||||
uint16_t float_to_bf16_rtn_raw(float f)
|
||||
{
|
||||
union
|
||||
{
|
||||
float fp32;
|
||||
uint32_t int32;
|
||||
} u = {f};
|
||||
if(~u.int32 & 0x7f800000)
|
||||
{
|
||||
// When the exponent bits are not all 1s, then the value is zero, normal,
|
||||
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
|
||||
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
|
||||
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
|
||||
// least significant bits of the float mantissa are greater than 0x8000,
|
||||
// or if they are equal to 0x8000 and the least significant bit of the
|
||||
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
|
||||
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
|
||||
// has the value 0x7f, then incrementing it causes it to become 0x00 and
|
||||
// the exponent is incremented by one, which is the next higher FP value
|
||||
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
|
||||
// with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
|
||||
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
|
||||
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
|
||||
// incrementing it causes it to become an exponent of 0xFF and a mantissa
|
||||
// of 0x00, which is Inf, the next higher value to the unrounded value.
|
||||
u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even
|
||||
}
|
||||
else if(u.int32 & 0xffff)
|
||||
{
|
||||
// When all of the exponent bits are 1, the value is Inf or NaN.
|
||||
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
|
||||
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
|
||||
// bit being 1. Signaling NaN is indicated by the most significant
|
||||
// mantissa bit being 0 but some other bit(s) being 1. If any of the
|
||||
// lower 16 bits of the mantissa are 1, we set the least significant bit
|
||||
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
|
||||
// the bloat16's mantissa bits are all 0.
|
||||
u.int32 |= 0x10000; // Preserve signaling NaN
|
||||
}
|
||||
return uint16_t(u.int32 >> 16);
|
||||
}
|
||||
|
||||
// Truncate instead of rounding, preserving SNaN
|
||||
CK_TILE_HOST_DEVICE
|
||||
uint16_t float_to_bf16_truc_nan_raw(float f)
|
||||
{
|
||||
union
|
||||
{
|
||||
float fp32;
|
||||
uint32_t int32;
|
||||
} u = {f};
|
||||
return uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff));
|
||||
}
|
||||
|
||||
// Fast truncate instead of rounding, RTZ
|
||||
CK_TILE_HOST_DEVICE
|
||||
uint16_t float_to_bf16_truc_raw(float f)
|
||||
{
|
||||
union
|
||||
{
|
||||
float fp32;
|
||||
uint32_t int32;
|
||||
} u = {f};
|
||||
return uint16_t(u.int32 >> 16);
|
||||
}
|
||||
|
||||
template <bf16_rounding_mode rounding = CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT>
|
||||
CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding> = {})
|
||||
{
|
||||
if constexpr(rounding == bf16_rounding_mode::standard)
|
||||
return float_to_bf16_rtn_raw(f);
|
||||
else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
|
||||
return float_to_bf16_truc_nan_raw(f);
|
||||
else
|
||||
return float_to_bf16_truc_raw(f);
|
||||
}
|
||||
|
||||
// bf16 raw bits -> float: exact (bf16 is the top half of an IEEE binary32),
// so simply shift the 16 bits into the high half of a float pattern.
CK_TILE_HOST_DEVICE
float bf16_to_float_raw(uint16_t x)
{
    union
    {
        uint32_t int32;
        float fp32;
    } u = {uint32_t(x) << 16};
    return u.fp32;
}
|
||||
|
||||
template <bf16_rounding_mode rounding = CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT>
|
||||
CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant<rounding> = {})
|
||||
{
|
||||
return bfloat16_t::bit_cast(float_to_bf16_raw(f, constant<rounding>{}));
|
||||
}
|
||||
|
||||
// bfloat16_t -> float (exact widening conversion).
CK_TILE_HOST_DEVICE
float bf16_to_float(bfloat16_t x) { return static_cast<float>(x); }
|
||||
|
||||
template <bf16_rounding_mode rounding = CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT>
|
||||
CK_TILE_HOST_DEVICE bfloat16_t fp16_to_bf16(half_t f, constant<rounding> = {})
|
||||
{
|
||||
return bfloat16_t::bit_cast(float_to_bf16_raw(static_cast<float>(f), constant<rounding>{}));
|
||||
}
|
||||
|
||||
// bfloat16_t -> half_t: widen exactly to float, then convert to fp16.
// Fix: the declared return type was float although the name and the
// float_to_fp16 call both produce a half_t.
CK_TILE_HOST_DEVICE
half_t bf16_to_fp16(bfloat16_t x) { return float_to_fp16(static_cast<float>(x)); }
|
||||
|
||||
template <class T>
|
||||
struct numeric_limits;
|
||||
|
||||
template <>
|
||||
struct numeric_limits<bfloat16_t>
|
||||
{
|
||||
// minimum finite value, or minimum positive normalized value for float
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t min() { return bfloat16_t::bit_cast(0x0080); }
|
||||
|
||||
// minumum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t lowest()
|
||||
{
|
||||
return bfloat16_t::bit_cast(0xff7f);
|
||||
}
|
||||
|
||||
// maximum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t max() { return bfloat16_t::bit_cast(0x7f7f); }
|
||||
|
||||
// difference between 1.0 and next value representable by float
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t epsilon()
|
||||
{
|
||||
return bfloat16_t::bit_cast(0x1000);
|
||||
}
|
||||
|
||||
// maximum rounding error
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t round_error() { return bfloat16_t(0.5f); }
|
||||
|
||||
// positive infinity value
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t infinity()
|
||||
{
|
||||
return bfloat16_t::bit_cast(0x7f80);
|
||||
}
|
||||
|
||||
// quiet NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t quiet_NaN()
|
||||
{
|
||||
return bfloat16_t::bit_cast(0x7FFF);
|
||||
}
|
||||
|
||||
// signaling NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t signaling_NaN()
|
||||
{
|
||||
return bfloat16_t::bit_cast(0x7FFF);
|
||||
}
|
||||
|
||||
// smallest positive subnormal value
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t denorm_min()
|
||||
{
|
||||
return bfloat16_t::bit_cast(0x0001);
|
||||
}
|
||||
};
|
||||
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(bfloat16_t)
|
||||
|
||||
// math
|
||||
CK_TILE_HOST_DEVICE
|
||||
bfloat16_t abs(const bfloat16_t& x) { return bfloat16_t::bit_cast(x.get() & 0x7fff); }
|
||||
|
||||
// NaN test for bfloat16: exponent all ones (0x7F80) with a nonzero mantissa.
// Fix: the previous threshold 0x7C00 is the *fp16* exponent mask and wrongly
// classified large finite bf16 values (|x| in (0x7C00, 0x7F80]) as NaN.
CK_TILE_HOST_DEVICE
bool isnan(const bfloat16_t& x)
{
    uint16_t xx = x.get();
    return (xx & 0x7FFF) > 0x7F80;
}
|
||||
|
||||
// Device-only math for bfloat16: compute in float via fast HW intrinsics,
// then round back to bf16. (Fix: removed stray semicolons after each
// function body — they were extra empty declarations.)
CK_TILE_DEVICE
bfloat16_t sqrt(bfloat16_t x)
{
    return static_cast<bfloat16_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
}

CK_TILE_DEVICE
bfloat16_t exp(bfloat16_t x) { return static_cast<bfloat16_t>(__expf(static_cast<float>(x))); }

CK_TILE_DEVICE
bfloat16_t exp2(bfloat16_t x) { return static_cast<bfloat16_t>(exp2f(static_cast<float>(x))); }

CK_TILE_DEVICE
bfloat16_t log(bfloat16_t x) { return static_cast<bfloat16_t>(__logf(static_cast<float>(x))); }
|
||||
|
||||
using bf16_t = bfloat16_t;
|
||||
|
||||
} // namespace ck_tile
|
||||
735
include/ck_tile/core/numeric/float8.hpp
Normal file
735
include/ck_tile/core/numeric/float8.hpp
Normal file
@@ -0,0 +1,735 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/limits.hpp"
|
||||
#include "ck_tile/core/utility/random.hpp"
|
||||
#include "ck_tile/core/numeric/arithmetic.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include <stdint.h>
|
||||
#include <type_traits>
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// fp8 rounding modes
|
||||
// use standard for rounding to nearest, the faster one
|
||||
// use stochastic for stochastic rounding, helps to avoid error accumulation
|
||||
// Selects how float -> fp8 conversion rounds.
enum class fp8_rounding_mode
{
    standard = 0, // round to nearest — the fast path
    stochastic    // stochastic rounding — avoids error accumulation
};
|
||||
|
||||
/*
|
||||
* ______________NANOO_________________ | ______________IEEE________________
|
||||
* e4m3 e5m2 | e4m3 e5m2
|
||||
* bias : 8 16 | 7 15
|
||||
* inf : 1.0000.000 1.00000.00 | N/A s.11111.00
|
||||
* Nan : 1.0000.000 1.00000.00 | s.1111.111 s.11111.{01, 10, 11}
|
||||
* zero : 0.0000.000 0.00000.00 | s.0000.000 s.00000.00
|
||||
* Max(norm) : s.1111.111 (240) s.11111.11(57344) | s.1111.110(448) s.11110.11(57344)
|
||||
* Max(snorm): s.0000.111 s.00000.11 | s.0000.111(448) s.00000.11(57344)
|
||||
* 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05
|
||||
* Min(norm) : s.0001.000 s.00001.00 | s.0001.000 s.00001.00
|
||||
* 2^-7(0.00078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05)
|
||||
* Min(snorm): s.0000.001 s.00000.01 | s.0000.001 s.00000.01
|
||||
* 2^-10(0.00097656) 2^-17(7.629395e-06)| 2^-9(0.001953125) 2^-16(1.52588e-05)
|
||||
*/
|
||||
|
||||
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float, constant<rounding> = {});
|
||||
|
||||
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float, constant<rounding> = {});
|
||||
|
||||
CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t);
|
||||
CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t);
|
||||
|
||||
struct alignas(1) float8_e4m3_t
|
||||
{
|
||||
static constexpr int exponent = 4;
|
||||
static constexpr int mantissa = 3;
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
static constexpr int bias = 1 << (exponent - 1); // NANOO
|
||||
#else
|
||||
static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
|
||||
#endif
|
||||
using raw_type = uint8_t;
|
||||
raw_type data;
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
static float8_e4m3_t bit_cast(raw_type x)
|
||||
{
|
||||
float8_e4m3_t y;
|
||||
y.data = x;
|
||||
return y;
|
||||
}
|
||||
|
||||
// constructor
|
||||
float8_e4m3_t() = default;
|
||||
|
||||
// construct from float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit float8_e4m3_t(const float& x) { data = float_to_fp8_raw(x); }
|
||||
|
||||
// construct from int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit float8_e4m3_t(const int& x) { data = float_to_fp8_raw(static_cast<float>(x)); }
|
||||
|
||||
// construct from unsigned int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit float8_e4m3_t(const unsigned int& x)
|
||||
{
|
||||
data = float_to_fp8_raw(static_cast<float>(x));
|
||||
}
|
||||
|
||||
// cast to float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit operator float() const { return fp8_to_float_raw(data); }
|
||||
|
||||
// cast to int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit operator int() const { return static_cast<int>(fp8_to_float_raw(data)); }
|
||||
|
||||
// internal access
|
||||
CK_TILE_HOST_DEVICE
|
||||
raw_type& get() { return data; }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
raw_type get() const { return data; }
|
||||
};
|
||||
|
||||
struct alignas(1) float8_e5m2_t
|
||||
{
|
||||
static constexpr int exponent = 5;
|
||||
static constexpr int mantissa = 2;
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
static constexpr int bias = 1 << (exponent - 1); // NANOO
|
||||
#else
|
||||
static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
|
||||
#endif
|
||||
using raw_type = uint8_t;
|
||||
raw_type data;
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
static float8_e5m2_t bit_cast(raw_type x)
|
||||
{
|
||||
float8_e5m2_t y;
|
||||
y.data = x;
|
||||
return y;
|
||||
}
|
||||
|
||||
// constructor
|
||||
float8_e5m2_t() = default;
|
||||
|
||||
// construct from float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit float8_e5m2_t(const float& x) { data = float_to_bf8_raw(x); }
|
||||
|
||||
// construct from int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit float8_e5m2_t(const int& x) { data = float_to_bf8_raw(static_cast<float>(x)); }
|
||||
|
||||
// construct from unsigned int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit float8_e5m2_t(const unsigned int& x)
|
||||
{
|
||||
data = float_to_bf8_raw(static_cast<float>(x));
|
||||
}
|
||||
|
||||
// cast to float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit operator float() const { return bf8_to_float_raw(data); }
|
||||
|
||||
// cast to int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit operator int() const { return static_cast<int>(bf8_to_float_raw(data)); }
|
||||
|
||||
// internal access
|
||||
CK_TILE_HOST_DEVICE
|
||||
raw_type& get() { return data; }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
raw_type get() const { return data; }
|
||||
};
|
||||
|
||||
// below is sw fp8 conversion, not utilizing hw instruction
|
||||
namespace impl {
|
||||
|
||||
template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
|
||||
CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng)
|
||||
{
|
||||
// fp8/bf8 exponent/mantissa layout
|
||||
constexpr int out_exp = numeric_utils<Y>::exp;
|
||||
constexpr int out_mant = numeric_utils<Y>::mant;
|
||||
|
||||
// original type exponent/mantissa layout
|
||||
constexpr int in_exp = numeric_utils<X>::exp;
|
||||
constexpr int in_mant = numeric_utils<X>::mant;
|
||||
|
||||
int exponent, bias;
|
||||
uint32_t head, mantissa, sign;
|
||||
// nan code is same for float and half
|
||||
constexpr Y nan_code = 0x80;
|
||||
constexpr uint32_t nan_mask = numeric_utils<X>::nan_mask;
|
||||
|
||||
// convert to bitwise
|
||||
using T_bitwise = typename numeric_utils<X>::bitwise_type;
|
||||
T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));
|
||||
|
||||
// unpack the input, depends on datatype
|
||||
head = x_bitwise & numeric_utils<X>::head_mask;
|
||||
mantissa = x_bitwise & numeric_utils<X>::mant_mask;
|
||||
exponent = (head >> in_mant) & numeric_utils<X>::exp_mask;
|
||||
sign = head >> (in_exp + in_mant);
|
||||
bias = numeric_utils<X>::bias;
|
||||
|
||||
uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant);
|
||||
uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1;
|
||||
constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2);
|
||||
|
||||
if constexpr(negative_zero_nan)
|
||||
{
|
||||
if((x_bitwise & nan_mask) == nan_mask)
|
||||
return nan_code;
|
||||
}
|
||||
else
|
||||
{
|
||||
if((x_bitwise & nan_mask) == nan_mask)
|
||||
return signed_inf + (mantissa != 0 ? 1 : 0);
|
||||
}
|
||||
|
||||
// check if x is 0.0
|
||||
if(x_bitwise == 0)
|
||||
return 0;
|
||||
|
||||
// First need to check if it is normal or denorm as there is a difference of implict 1
|
||||
// Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
|
||||
// The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for
|
||||
// RNE, no need to add rng. Then probably need to check whether there is carry and adjust
|
||||
// exponent and mantissa again3
|
||||
|
||||
// For IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of exponent bits
|
||||
const int out_bias = (1 << (out_exp - 1)) - 1 + (negative_zero_nan ? 1 : 0);
|
||||
const int out_denormal_act_exponent = 1 - out_bias; // actual exponent of f8 denormal
|
||||
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
|
||||
// out_exponent is the converted f8 exponent with bias encoding
|
||||
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
|
||||
// the difference needs to be adjusted and mantissa shifted
|
||||
int act_exponent, out_exponent, exponent_diff;
|
||||
|
||||
if(exponent == 0)
|
||||
{ // fp32/fp16 is in denormal.
|
||||
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16
|
||||
here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has
|
||||
exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in
|
||||
fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
|
||||
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal.
|
||||
In this case, the fp16 mantissa should be shift left by 1 */
|
||||
act_exponent = exponent - bias + 1;
|
||||
exponent_diff = out_denormal_act_exponent -
|
||||
act_exponent; // actual exponent is exponent-bias+1 as it is denormal
|
||||
}
|
||||
else
|
||||
{ // fp32/fp16 is normal with implicit 1
|
||||
act_exponent = exponent - bias;
|
||||
if(act_exponent <= out_denormal_act_exponent)
|
||||
{
|
||||
/* This is the case where fp32/fp16 is normal but it is in f8 denormal range.
|
||||
For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
|
||||
actual exponent is -7, it is actually larger due to the implict 1,
|
||||
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
|
||||
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
|
||||
exponent_diff = out_denormal_act_exponent - act_exponent;
|
||||
}
|
||||
else
|
||||
{ // both fp32/fp16 and f8 are in normal range
|
||||
exponent_diff =
|
||||
0; // exponent_diff=0 does not mean there is no difference for this case,
|
||||
// act_exponent could be larger. Just that it does not need shift mantissa
|
||||
}
|
||||
mantissa += (1 << in_mant); // Add the implicit 1 into mantissa
|
||||
}
|
||||
|
||||
bool midpoint = (mantissa & ((1 << (in_mant - out_mant + exponent_diff)) - 1)) ==
|
||||
(1 << (in_mant - out_mant + exponent_diff - 1));
|
||||
/* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we
|
||||
shift right as shift right could rip off some residual part and make something not midpoint look
|
||||
like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than
|
||||
midpoint, but after shift right by 4 bits, it would look like midpoint. */
|
||||
|
||||
if(exponent_diff > 0)
|
||||
mantissa >>= exponent_diff;
|
||||
else if(exponent_diff == -1)
|
||||
mantissa <<= -exponent_diff;
|
||||
bool implicit_one = mantissa & (1 << in_mant);
|
||||
// if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent
|
||||
out_exponent =
|
||||
(act_exponent + exponent_diff) /*actual f8 exponent*/ + out_bias - (implicit_one ? 0 : 1);
|
||||
|
||||
// Now we have the exponent and mantissa adjusted
|
||||
bool odd =
|
||||
mantissa &
|
||||
(1 << (in_mant - out_mant)); // if the least significant bit that is not truncated is 1
|
||||
mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;
|
||||
|
||||
// Now we deal with overflow
|
||||
if(out_exponent == 0)
|
||||
{
|
||||
if((1 << in_mant) & mantissa)
|
||||
{
|
||||
out_exponent = 1; // denormal overflow to become normal, promote exponent
|
||||
// No need to make 1 implicit now as it will be addressed later
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if((1 << (in_mant + 1)) & mantissa)
|
||||
{
|
||||
mantissa >>= 1;
|
||||
out_exponent++;
|
||||
// No need to make 1 implicit now as it will be addressed later
|
||||
}
|
||||
}
|
||||
|
||||
mantissa >>= (in_mant - out_mant);
|
||||
|
||||
if(out_exponent > max_exp)
|
||||
{
|
||||
if(clip)
|
||||
{
|
||||
mantissa = (1 << out_mant) - 1;
|
||||
out_exponent = max_exp;
|
||||
}
|
||||
else
|
||||
{
|
||||
return signed_inf;
|
||||
}
|
||||
}
|
||||
|
||||
// check if x is 0.0 or -0.0
|
||||
if(out_exponent == 0 && mantissa == 0)
|
||||
return negative_zero_nan ? 0 : (sign << (out_exp + out_mant));
|
||||
mantissa &= (1 << out_mant) - 1;
|
||||
return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa;
|
||||
}
|
||||
|
||||
// Bit-exact software conversion from an 8-bit float encoding (fp8/bf8 raw byte X)
// to a wider IEEE-style type Y (float or half).
// negative_zero_nan selects the NANOO/FNUZ convention: the 0x80 code (negative
// zero in IEEE) is repurposed as the single NaN and no infinities are encoded.
template <typename X, typename Y, bool negative_zero_nan>
CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
{
    // fp8/bf8 exponent/mantissa layout
    constexpr int in_exp  = numeric_utils<X>::exp;
    constexpr int in_mant = numeric_utils<X>::mant;

    // resulting type exponent/mantissa layout
    constexpr int out_exp  = numeric_utils<Y>::exp;
    constexpr int out_mant = numeric_utils<Y>::mant;

    // prepare the special-value codes of the destination type
    constexpr X nan_code = 0x80;
    Y Inf, NegInf, NaN, Neg0;
    using T_bitwise = typename numeric_utils<Y>::bitwise_type;

    constexpr T_bitwise Inf_bitwise    = numeric_utils<Y>::Inf;
    constexpr T_bitwise NegInf_bitwise = numeric_utils<Y>::NegInf;
    constexpr T_bitwise NaN_bitwise    = numeric_utils<Y>::NaN;
    constexpr T_bitwise Neg0_bitwise   = numeric_utils<Y>::Neg0;

    // reinterpret the destination bit patterns as values of Y
    Inf    = *(reinterpret_cast<const Y*>(&Inf_bitwise));
    NegInf = *(reinterpret_cast<const Y*>(&NegInf_bitwise));
    NaN    = *(reinterpret_cast<const Y*>(&NaN_bitwise));
    Neg0   = *(reinterpret_cast<const Y*>(&Neg0_bitwise));

    // check if x is 0.0
    if(x == 0)
        return static_cast<Y>(0);

    // unpack the input byte into sign bit, mantissa field and biased exponent
    uint32_t sign     = x >> (in_exp + in_mant);
    uint32_t mantissa = x & ((1 << in_mant) - 1);
    int exponent      = (x & 0x7F) >> in_mant;

    // bias adjustment between the source and destination exponent encodings
    constexpr int exp_low_cutoff =
        (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
    T_bitwise retval;

    if constexpr(negative_zero_nan)
    {
        // NANOO/FNUZ: 0x80 is the one-and-only NaN; there are no infinities
        if(x == nan_code)
            return NaN;
    }
    else
    {
        // IEEE-like: 0x80 is negative zero; all-ones exponent encodes Inf/NaN
        if(x == nan_code)
            return Neg0;
        if(exponent == ((1 << in_exp) - 1))
            return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
    }

    // bf8(e5m2) -> half fast path: both have a 5-bit exponent, so shifting the
    // payload left by 8 reproduces the fp16 bit pattern directly
    if((numeric_utils<Y>::mant == 10) && (numeric_utils<X>::mant == 2) && !negative_zero_nan)
    {
        retval = x;
        retval <<= 8;
        return *(reinterpret_cast<const Y*>(&retval));
    }

    // subnormal input: renormalize by shifting the leading 1 out of the mantissa
    if(exponent == 0)
    {
        // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
        int sh = 1 + clz(mantissa) - (32 - in_mant);
        mantissa <<= sh;
        exponent += 1 - sh;
        mantissa &= ((1 << in_mant) - 1);
    }
    exponent += exp_low_cutoff - 1;
    mantissa <<= out_mant - in_mant;

    // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
    if(exponent <= 0)
    {
        mantissa |= 1 << out_mant;
        mantissa >>= 1 - exponent;
        exponent = 0;
    }

    // repack sign / exponent / mantissa into the destination bit layout
    retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
    return *(reinterpret_cast<const Y*>(&retval));
}
|
||||
|
||||
template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
|
||||
CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng)
|
||||
{
|
||||
// check datatypes
|
||||
constexpr bool is_half = std::is_same<X, half_t>::value;
|
||||
constexpr bool is_float = std::is_same<X, float>::value;
|
||||
static_assert(is_half || is_float, "Only half and float can be casted.");
|
||||
|
||||
return run_cast_to_f8<X, Y, negative_zero_nan, clip, stoch>(x, rng);
|
||||
}
|
||||
|
||||
template <typename X, typename Y, bool negative_zero_nan>
|
||||
CK_TILE_HOST_DEVICE Y cast_from_f8(X x)
|
||||
{
|
||||
// check datatype
|
||||
constexpr bool is_half = std::is_same<Y, half_t>::value;
|
||||
constexpr bool is_float = std::is_same<Y, float>::value;
|
||||
static_assert(is_half || is_float, "only half and float are supported.");
|
||||
|
||||
return run_cast_from_f8<X, Y, negative_zero_nan>(x);
|
||||
}
|
||||
} // namespace impl
|
||||
|
||||
// Convert fp32 -> fp8 (e4m3, FNUZ) raw byte using stochastic rounding.
// On gfx94x the hardware converter is used; elsewhere a bit-exact software
// emulation runs. The random word is derived from the value and its address.
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_sr_raw(float x)
{
    constexpr int seed = 42;
    uint32_t rng       = prand_generator<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    // saturate to +/-240 before the hardware cast (max magnitude used here)
    float max_fp8 = 240.0f;
    x             = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
    union
    {
        float fval;
        uint32_t i32val;
        uint8_t i8val[4]; // not endian independent
    } val;
    val.fval      = x;
    uint32_t ival = 0;
    ival          = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
    val.i32val    = ival;
    return val.i8val[0]; // little endian
#else
    // software fallback: NANOO encoding, clipping enabled, stochastic rounding
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    constexpr fp8_rounding_mode rm   = fp8_rounding_mode::stochastic;
    return impl::
        cast_to_f8<float, fp8_t, negative_zero_nan, clip, (rm == fp8_rounding_mode::stochastic)>(
            x, rng);
#endif
}
|
||||
|
||||
// Convert fp32 -> bf8 (e5m2, FNUZ) raw byte using stochastic rounding.
// Hardware converter on gfx94x; bit-exact software emulation otherwise.
// Note: no pre-clamp here -- bf8's wider exponent makes the fp8 240 clamp moot.
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_sr_raw(float x)
{
    constexpr int seed = 42;
    uint32_t rng       = prand_generator<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    union
    {
        float fval;
        uint32_t i32val;
        uint8_t i8val[4]; // not endian independent
    } val;
    val.fval      = x;
    uint32_t ival = 0;
    ival          = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
    val.i32val    = ival;
    return val.i8val[0]; // little endian
#else
    // software fallback: NANOO encoding, clipping enabled, stochastic rounding
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    constexpr fp8_rounding_mode rm   = fp8_rounding_mode::stochastic;
    return impl::
        cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == fp8_rounding_mode::stochastic)>(
            x, rng);
#endif
}
|
||||
|
||||
// Convert fp32 -> fp8 (e4m3, FNUZ) raw byte with round-to-nearest rounding.
// Hardware packed converter on gfx94x; software emulation otherwise.
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_rtn_raw(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    // saturate to +/-240 before the hardware cast (max magnitude used here)
    float max_fp8 = 240.0f;
    x             = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
    union
    {
        float fval;
        uint32_t i32val;
        uint8_t i8val[4]; // not endian independent
    } val;
    val.fval      = x;
    uint32_t ival = 0;
    // packed convert of (x, x); only the low byte of WORD0 is consumed
    ival       = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
    val.i32val = ival;
    return val.i8val[0];
#else
    // software fallback: NANOO encoding, clipping enabled, standard rounding
    // (rng is unused by the non-stochastic path)
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    constexpr fp8_rounding_mode rm   = fp8_rounding_mode::standard;
    constexpr uint32_t rng           = 0;
    return impl::
        cast_to_f8<float, fp8_t, negative_zero_nan, clip, (rm == fp8_rounding_mode::stochastic)>(
            x, rng);
#endif
}
|
||||
// Convert fp32 -> bf8 (e5m2, FNUZ) raw byte with round-to-nearest rounding.
// Hardware packed converter on gfx94x; software emulation otherwise.
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_rtn_raw(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    union
    {
        float fval;
        uint32_t i32val;
        uint8_t i8val[4]; // not endian independent
    } val;
    val.fval      = x;
    uint32_t ival = 0;
    // packed convert of (x, x); only the low byte of WORD0 is consumed
    ival       = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
    val.i32val = ival;
    return val.i8val[0];
#else
    // software fallback: NANOO encoding, clipping enabled, standard rounding
    // (rng is unused by the non-stochastic path)
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    constexpr fp8_rounding_mode rm   = fp8_rounding_mode::standard;
    constexpr uint32_t rng           = 0;
    return impl::
        cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == fp8_rounding_mode::stochastic)>(
            x, rng);
#endif
}
|
||||
|
||||
// clang-format off
|
||||
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float x, constant<rounding> = {})
|
||||
{
|
||||
if constexpr (rounding == fp8_rounding_mode::standard) return float_to_fp8_rtn_raw(x);
|
||||
else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_fp8_sr_raw(x);
|
||||
else return uint8_t{0};
|
||||
}
|
||||
|
||||
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float x, constant<rounding> = {})
|
||||
{
|
||||
if constexpr (rounding == fp8_rounding_mode::standard) return float_to_bf8_rtn_raw(x);
|
||||
else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_bf8_sr_raw(x);
|
||||
else return uint8_t{0};
|
||||
}
|
||||
|
||||
// Convert a raw fp8 (e4m3) byte to fp32: hardware converter on gfx94x,
// bit-exact software emulation (NANOO encoding) otherwise.
CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    float fval;
    uint32_t i32val = static_cast<uint32_t>(x);
    // converts byte 0 of the dword
    fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
    // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
    return fval;
#else
    constexpr bool negative_zero_nan = true;
    return impl::cast_from_f8<fp8_t, float, negative_zero_nan>(x);
#endif
}
|
||||
|
||||
// Convert a raw bf8 (e5m2) byte to fp32: hardware converter on gfx94x,
// bit-exact software emulation (NANOO encoding) otherwise.
CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    float fval;
    uint32_t i32val = static_cast<uint32_t>(x);
    // converts byte 0 of the dword
    fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
    // asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
    return fval;
#else
    constexpr bool negative_zero_nan = true;
    return impl::cast_from_f8<bf8_t, float, negative_zero_nan>(x);
#endif
}
|
||||
|
||||
// Typed wrappers: same conversions as the *_raw variants above, but taking or
// returning the strongly-typed f8 wrapper structs instead of raw bytes.
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE float8_e4m3_t float_to_fp8(float x, constant<rounding> = {})
{
    return float8_e4m3_t::bit_cast(float_to_fp8_raw(x, constant<rounding>{}));
}

template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE float8_e5m2_t float_to_bf8(float x, constant<rounding> = {})
{
    return float8_e5m2_t::bit_cast(float_to_bf8_raw(x, constant<rounding>{}));
}

CK_TILE_HOST_DEVICE float fp8_to_float(float8_e4m3_t x)
{
    return fp8_to_float_raw(x.get());
}

CK_TILE_HOST_DEVICE float bf8_to_float(float8_e5m2_t x)
{
    return bf8_to_float_raw(x.get());
}
|
||||
|
||||
// clang-format on
|
||||
// short aliases for the two 8-bit float flavors
using fp8_t = float8_e4m3_t;
using bf8_t = float8_e5m2_t;

// numeric_utils exposes a type's bit layout (exponent/mantissa widths, bias);
// primary template is declared only, specializations follow per type.
template <typename T>
struct numeric_utils;

template <>
struct numeric_utils<fp8_t>
{
    static constexpr int exp  = fp8_t::exponent; // exponent bit count
    static constexpr int mant = fp8_t::mantissa; // mantissa bit count
    static constexpr int bias = fp8_t::bias;     // exponent bias
};

template <>
struct numeric_utils<bf8_t>
{
    static constexpr int exp  = bf8_t::exponent; // exponent bit count
    static constexpr int mant = bf8_t::mantissa; // mantissa bit count
    static constexpr int bias = bf8_t::bias;     // exponent bias
};
|
||||
|
||||
// std::numeric_limits-style traits, specialized for ck_tile types.
template <class T>
struct numeric_limits;

// NOTE(review): the 0x80 code serving as infinity / quiet NaN / signaling NaN
// matches the NANOO/FNUZ e4m3 encoding used by the converters above, where
// negative zero is repurposed as the single NaN and no infinity exists.
template <>
struct numeric_limits<fp8_t>
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr fp8_t min() { return fp8_t::bit_cast(0x08); }

    // minumum finite value
    CK_TILE_HOST_DEVICE static constexpr fp8_t lowest() { return fp8_t::bit_cast(0xff); }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr fp8_t max() { return fp8_t::bit_cast(0x7f); }

    // difference between 1.0 and next value representable by float
    CK_TILE_HOST_DEVICE static constexpr fp8_t epsilon() { return fp8_t::bit_cast(0x20); }

    // maximum rounding error
    CK_TILE_HOST_DEVICE static constexpr fp8_t round_error() { return fp8_t(0.5f); }

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr fp8_t infinity() { return fp8_t::bit_cast(0x80); }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr fp8_t quiet_NaN() { return fp8_t::bit_cast(0x80); }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr fp8_t signaling_NaN() { return fp8_t::bit_cast(0x80); }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min() { return fp8_t::bit_cast(0x01); }
};
|
||||
|
||||
// numeric_limits for bf8 (e5m2). Same FNUZ convention as fp8: the 0x80 code
// stands in for infinity and both NaN kinds.
template <>
struct numeric_limits<bf8_t>
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr bf8_t min() { return bf8_t::bit_cast(0x04); }

    // minumum finite value
    CK_TILE_HOST_DEVICE static constexpr bf8_t lowest() { return bf8_t::bit_cast(0xff); }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr bf8_t max() { return bf8_t::bit_cast(0x7f); }

    // difference between 1.0 and next value representable by float
    CK_TILE_HOST_DEVICE static constexpr bf8_t epsilon() { return bf8_t::bit_cast(0x34); }

    // maximum rounding error
    CK_TILE_HOST_DEVICE static constexpr bf8_t round_error() { return bf8_t(0.5f); }

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr bf8_t infinity() { return bf8_t::bit_cast(0x80); }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr bf8_t quiet_NaN() { return bf8_t::bit_cast(0x80); }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr bf8_t signaling_NaN() { return bf8_t::bit_cast(0x80); }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min() { return bf8_t::bit_cast(0x01); }
};
|
||||
|
||||
// generate the usual arithmetic operators by routing through fp32
CK_TILE_ARITHMETIC_USING_FLOAT(fp8_t)
CK_TILE_ARITHMETIC_USING_FLOAT(bf8_t)

// math
CK_TILE_HOST_DEVICE
fp8_t abs(const fp8_t& x) { return fp8_t::bit_cast(x.get() & 0x7f); } // clear sign bit

CK_TILE_HOST_DEVICE
bool isnan(const fp8_t& x)
{
    uint8_t xx = x.get();
    return xx == 0x80; // TODO: NANOO
}

// device-only transcendentals: compute in fp32, narrow back to f8
CK_TILE_DEVICE
fp8_t sqrt(fp8_t x) { return static_cast<fp8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };

CK_TILE_DEVICE
fp8_t exp(fp8_t x) { return static_cast<fp8_t>(__expf(static_cast<float>(x))); };

CK_TILE_DEVICE
fp8_t exp2(fp8_t x) { return static_cast<fp8_t>(exp2f(static_cast<float>(x))); };

CK_TILE_DEVICE
fp8_t log(fp8_t x) { return static_cast<fp8_t>(__logf(static_cast<float>(x))); };

CK_TILE_HOST_DEVICE
bf8_t abs(const bf8_t& x) { return bf8_t::bit_cast(x.get() & 0x7f); } // clear sign bit

CK_TILE_HOST_DEVICE
bool isnan(const bf8_t& x)
{
    uint8_t xx = x.get();
    return xx == 0x80; // TODO: NANOO
}

CK_TILE_DEVICE
bf8_t sqrt(bf8_t x) { return static_cast<bf8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };

CK_TILE_DEVICE
bf8_t exp(bf8_t x) { return static_cast<bf8_t>(__expf(static_cast<float>(x))); };

CK_TILE_DEVICE
bf8_t exp2(bf8_t x) { return static_cast<bf8_t>(exp2f(static_cast<float>(x))); };

CK_TILE_DEVICE
bf8_t log(bf8_t x) { return static_cast<bf8_t>(__logf(static_cast<float>(x))); };
|
||||
|
||||
} // namespace ck_tile
|
||||
278
include/ck_tile/core/numeric/half.hpp
Normal file
278
include/ck_tile/core/numeric/half.hpp
Normal file
@@ -0,0 +1,278 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float fp16_to_float_hip(const _Float16& x);
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
_Float16 float_to_fp16_hip(const float& x);
|
||||
|
||||
// HIP use _Float16 as interchangable data type for float16
|
||||
struct alignas(2) half_t
|
||||
{
|
||||
using raw_type = uint16_t;
|
||||
raw_type data;
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
static half_t bit_cast(raw_type x)
|
||||
{
|
||||
half_t y;
|
||||
y.data = x;
|
||||
return y;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
_Float16 to_fp16() const { return reinterpret_cast<const raw_type&>(data); }
|
||||
|
||||
// constructor
|
||||
half_t() = default;
|
||||
|
||||
// construct from HIP half
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit half_t(const _Float16& x) : data(reinterpret_cast<const raw_type&>(x)) {}
|
||||
|
||||
// construct from float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit half_t(const float& x) : half_t(float_to_fp16_hip(x)) {}
|
||||
|
||||
// construct from int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit half_t(const int& x) : half_t(__int2half_rn(x)) {}
|
||||
|
||||
// construct from unsigned int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit half_t(const unsigned int& x) : half_t(__uint2half_rn(x)) {}
|
||||
|
||||
// cast to float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit operator float() const { return fp16_to_float_hip(to_fp16()); }
|
||||
|
||||
// cast to int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit operator int() const { return static_cast<int>(fp16_to_float_hip(to_fp16())); }
|
||||
|
||||
// internal access
|
||||
CK_TILE_HOST_DEVICE
|
||||
raw_type& get() { return data; }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
raw_type get() const { return data; }
|
||||
};
|
||||
|
||||
// conversions
|
||||
CK_TILE_HOST_DEVICE
|
||||
float fp16_to_float_hip(const _Float16& x)
|
||||
{
|
||||
// return __half2float(x);
|
||||
return static_cast<float>(x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
_Float16 float_to_fp16_hip(const float& x)
|
||||
{
|
||||
// return __float2half(x);
|
||||
return static_cast<_Float16>(x);
|
||||
}
|
||||
|
||||
// named conversions for the wrapped half_t type
CK_TILE_HOST_DEVICE
float fp16_to_float(const half_t& x) { return static_cast<float>(x); }

CK_TILE_HOST_DEVICE
half_t float_to_fp16(const float& x) { return half_t{x}; }
|
||||
|
||||
// limits
|
||||
// std::numeric_limits-style traits for half_t; values are standard IEEE
// binary16 bit patterns.
template <class T>
struct numeric_limits;

template <>
struct numeric_limits<half_t>
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr half_t min() { return half_t::bit_cast(0x0400); }

    // minumum finite value
    CK_TILE_HOST_DEVICE static constexpr half_t lowest() { return half_t::bit_cast(0xFBFF); }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr half_t max() { return half_t::bit_cast(0x7BFF); }

    // difference between 1.0 and next value representable by float
    CK_TILE_HOST_DEVICE static constexpr half_t epsilon() { return half_t::bit_cast(0x1800); }

    // maximum rounding error
    CK_TILE_HOST_DEVICE static constexpr half_t round_error() { return half_t(0.5f); }

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr half_t infinity() { return half_t::bit_cast(0x7C00); }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr half_t quiet_NaN() { return half_t::bit_cast(0x7FFF); }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr half_t signaling_NaN() { return half_t::bit_cast(0x7FFF); }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr half_t denorm_min() { return half_t::bit_cast(0x0001); }
};
|
||||
|
||||
// bit-layout description of binary16, consumed by the f8 conversion routines
template <typename T>
struct numeric_utils;

template <>
struct numeric_utils<half_t>
{
    static constexpr int exp  = 5;  // exponent bits
    static constexpr int mant = 10; // mantissa bits
    static constexpr int bias = 15; // exponent bias
    static constexpr uint16_t nan_mask  = 0x7C00; // exponent field mask
    static constexpr uint16_t head_mask = 0xFC00; // sign + exponent mask
    static constexpr uint16_t mant_mask = 0x3FF;  // mantissa field mask
    static constexpr uint16_t exp_mask  = 0x1F;   // exponent value mask
    static constexpr uint32_t Inf    = 0x7C00; // +inf bit pattern
    static constexpr uint32_t NegInf = 0xFC00; // -inf bit pattern
    static constexpr uint32_t NaN    = 0x7C01; // a NaN bit pattern
    static constexpr uint32_t Neg0   = 0x8000; // -0.0 bit pattern
    using bitwise_type = uint16_t;
};
|
||||
|
||||
// arithmetic: comparisons and arithmetic for half_t are thin wrappers over the
// HIP half intrinsics (__heq/__hadd/...), converting through to_fp16().
CK_TILE_HOST_DEVICE
bool operator==(const half_t& x, const half_t& y) { return __heq(x.to_fp16(), y.to_fp16()); }

CK_TILE_HOST_DEVICE
bool operator!=(const half_t& x, const half_t& y) { return __hne(x.to_fp16(), y.to_fp16()); }

CK_TILE_HOST_DEVICE
bool operator<(const half_t& x, const half_t& y) { return __hlt(x.to_fp16(), y.to_fp16()); }

CK_TILE_HOST_DEVICE
bool operator<=(const half_t& x, const half_t& y) { return __hle(x.to_fp16(), y.to_fp16()); }

CK_TILE_HOST_DEVICE
bool operator>(const half_t& x, const half_t& y) { return __hgt(x.to_fp16(), y.to_fp16()); }

CK_TILE_HOST_DEVICE
bool operator>=(const half_t& x, const half_t& y) { return __hge(x.to_fp16(), y.to_fp16()); }

// binary arithmetic
CK_TILE_HOST_DEVICE
half_t operator+(const half_t& x, const half_t& y)
{
    return half_t(__hadd(x.to_fp16(), y.to_fp16()));
}

// unary negation
CK_TILE_HOST_DEVICE
half_t operator-(const half_t& x) { return half_t(__hneg(x.to_fp16())); }

CK_TILE_HOST_DEVICE
half_t operator-(const half_t& x, const half_t& y)
{
    return half_t(__hsub(x.to_fp16(), y.to_fp16()));
}

CK_TILE_HOST_DEVICE
half_t operator*(const half_t& x, const half_t& y)
{
    return half_t(__hmul(x.to_fp16(), y.to_fp16()));
}

CK_TILE_HOST_DEVICE
half_t operator/(const half_t& x, const half_t& y)
{
    return half_t(__hdiv(x.to_fp16(), y.to_fp16()));
}

// compound assignment
CK_TILE_HOST_DEVICE
half_t& operator+=(half_t& x, const half_t& y)
{
    x = half_t(__hadd(x.to_fp16(), y.to_fp16()));
    return x;
}

CK_TILE_HOST_DEVICE
half_t& operator-=(half_t& x, const half_t& y)
{
    x = half_t(__hsub(x.to_fp16(), y.to_fp16()));
    return x;
}

CK_TILE_HOST_DEVICE
half_t& operator*=(half_t& x, const half_t& y)
{
    x = half_t(__hmul(x.to_fp16(), y.to_fp16()));
    return x;
}

CK_TILE_HOST_DEVICE
half_t& operator/=(half_t& x, const half_t& y)
{
    x = half_t(__hdiv(x.to_fp16(), y.to_fp16()));
    return x;
}

// pre-increment / pre-decrement (by 1.0)
CK_TILE_HOST_DEVICE
half_t& operator++(half_t& x)
{
    x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
    return x;
}

CK_TILE_HOST_DEVICE
half_t& operator--(half_t& x)
{
    x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
    return x;
}

// post-increment / post-decrement (return the prior value)
CK_TILE_HOST_DEVICE
half_t operator++(half_t& x, int)
{
    half_t y(x);
    x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
    return y;
}

CK_TILE_HOST_DEVICE
half_t operator--(half_t& x, int)
{
    half_t y(x);
    x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
    return y;
}
|
||||
|
||||
// math
CK_TILE_HOST_DEVICE
half_t abs(const half_t& x) { return half_t::bit_cast(x.get() & 0x7fff); } // clear sign bit

CK_TILE_HOST_DEVICE
bool isnan(const half_t& x)
{
    uint16_t xx = x.get();
    // NaN: exponent all ones with a nonzero mantissa (strictly above +inf's pattern)
    return (xx & 0x7FFF) > 0x7C00;
}

// device-only transcendentals: computed in fp32, narrowed back to half
CK_TILE_DEVICE
half_t sqrt(half_t x)
{
    return static_cast<half_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
};

CK_TILE_DEVICE
half_t exp(half_t x) { return static_cast<half_t>(__expf(static_cast<float>(x))); };

CK_TILE_DEVICE
half_t exp2(half_t x) { return static_cast<half_t>(exp2f(static_cast<float>(x))); };

CK_TILE_DEVICE
half_t log(half_t x) { return static_cast<half_t>(__logf(static_cast<float>(x))); };

using fp16_t = half_t;
|
||||
|
||||
} // namespace ck_tile
|
||||
13
include/ck_tile/core/numeric/integer.hpp
Normal file
13
include/ck_tile/core/numeric/integer.hpp
Normal file
@@ -0,0 +1,13 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <stdint.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using index_t      = int32_t; // default index/coordinate type used across ck_tile
using long_index_t = int64_t; // wide index for large linear offsets
using int8_t       = int8_t;  // re-expose the global ::int8_t inside ck_tile
|
||||
|
||||
} // namespace ck_tile
|
||||
82
include/ck_tile/core/numeric/integral_constant.hpp
Normal file
82
include/ck_tile/core/numeric/integral_constant.hpp
Normal file
@@ -0,0 +1,82 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// compile-time constant wrapper: like std::integral_constant, but the value
// type is deduced from the non-type template argument via `auto`
template <auto v>
struct constant
{
    using value_type = decltype(v);
    using type       = constant; // using injected-class-name
    static constexpr value_type value = v;
    // both implicit conversion and call syntax yield the wrapped value
    constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; }
    constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; }
};
|
||||
|
||||
// std-style spelling layered on top of constant<>; inherits the conversion
// and call operators from the base
template <typename T, T v>
struct integral_constant : constant<v>
{
    using value_type = T;
    using type       = integral_constant; // using injected-class-name
    static constexpr T value = v;
    // constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; }
    // constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; } //
};
|
||||
|
||||
// convenience aliases for the common constant kinds
template <index_t v>
using number = constant<v>; // 32-bit compile-time index

template <long_index_t v>
using long_number = integral_constant<long_index_t, v>; // 64-bit compile-time index

template <bool b>
using bool_constant = constant<b>; // compile-time boolean
|
||||
|
||||
// Generate unary/binary operators on constant<> so expressions of constants
// fold to new constants at compile time (e.g. number<2>{} + number<3>{} is
// number<5>{}). Macros are #undef'ed immediately after use.
#define CK_TILE_LEFT_UNARY_OP(OP)                                 \
    template <auto x>                                             \
    CK_TILE_HOST_DEVICE constexpr auto operator OP(constant<x>)   \
    {                                                             \
        return constant<(OP x)>{};                                \
    }

#define CK_TILE_BINARY_OP(OP)                                                  \
    template <auto x, auto y>                                                  \
    CK_TILE_HOST_DEVICE constexpr auto operator OP(constant<x>, constant<y>)   \
    {                                                                          \
        return constant<(x OP y)>{};                                           \
    }

CK_TILE_LEFT_UNARY_OP(+)
CK_TILE_LEFT_UNARY_OP(-)
CK_TILE_LEFT_UNARY_OP(~)
CK_TILE_LEFT_UNARY_OP(!)
CK_TILE_LEFT_UNARY_OP(*)

CK_TILE_BINARY_OP(+)
CK_TILE_BINARY_OP(-)
CK_TILE_BINARY_OP(*)
CK_TILE_BINARY_OP(/)
CK_TILE_BINARY_OP(%)
CK_TILE_BINARY_OP(&)
CK_TILE_BINARY_OP(|)
CK_TILE_BINARY_OP(^)
CK_TILE_BINARY_OP(<<)
CK_TILE_BINARY_OP(>>)
CK_TILE_BINARY_OP(&&)
CK_TILE_BINARY_OP(||)
CK_TILE_BINARY_OP(==)
CK_TILE_BINARY_OP(!=)
CK_TILE_BINARY_OP(>)
CK_TILE_BINARY_OP(<)
CK_TILE_BINARY_OP(>=)
CK_TILE_BINARY_OP(<=)

#undef CK_TILE_LEFT_UNARY_OP
#undef CK_TILE_BINARY_OP
|
||||
|
||||
} // namespace ck_tile
|
||||
309
include/ck_tile/core/numeric/math.hpp
Normal file
309
include/ck_tile/core/numeric/math.hpp
Normal file
@@ -0,0 +1,309 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include <type_traits>
|
||||
#include <stdint.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// multiply by a compile-time scalar s
template <typename T, T s>
struct scales
{
    CK_TILE_HOST_DEVICE constexpr T operator()(T a) const { return s * a; }
};

// binary addition functor
template <typename T>
struct plus
{
    CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a + b; }
};

// binary subtraction functor
template <typename T>
struct minus
{
    CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a - b; }
};

// heterogeneous multiplication functor (result type deduced)
struct multiplies
{
    template <typename A, typename B>
    CK_TILE_HOST_DEVICE constexpr auto operator()(const A& a, const B& b) const
    {
        return a * b;
    }
};

// binary max functor (left-biased on ties)
template <typename T>
struct maximize
{
    CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a >= b ? a : b; }
};

// binary min functor (left-biased on ties)
template <typename T>
struct minimize
{
    CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a <= b ? a : b; }
};

// ceil(a / b) for integer types
template <typename T>
struct integer_divide_ceiler
{
    CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const
    {
        static_assert(std::is_same<T, index_t>{} || std::is_same<T, int>{}, "wrong type");
        return (a + b - number<1>{}) / b;
    }
};
|
||||
|
||||
// floor(x / y); relies on integer division truncation (assumes non-negative
// operands for true floor semantics -- TODO confirm callers)
template <typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto integer_divide_floor(X x, Y y)
{
    return x / y;
}

// ceil(x / y); works with number<> operands so the result can stay constexpr
template <typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
{
    return (x + y - number<1>{}) / y;
}

// smallest multiple of y that is >= x
template <typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto integer_least_multiple(X x, Y y)
{
    return y * integer_divide_ceil(x, y);
}
|
||||
|
||||
// variadic max: single-argument base case
template <typename T>
CK_TILE_HOST_DEVICE constexpr T max(T x)
{
    return x;
}

// two-argument max of a common type
template <typename T>
CK_TILE_HOST_DEVICE constexpr T max(T x, T y)
{
    return x > y ? x : y;
}

// mixed number<>/runtime overloads: result decays to runtime index_t
template <index_t X>
CK_TILE_HOST_DEVICE constexpr index_t max(number<X>, index_t y)
{
    return X > y ? X : y;
}

template <index_t Y>
CK_TILE_HOST_DEVICE constexpr index_t max(index_t x, number<Y>)
{
    return x > Y ? x : Y;
}

// variadic recursion: max(x, max(rest...))
template <typename X, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto max(X x, Ys... ys)
{
    static_assert(sizeof...(Ys) > 0, "not enough argument");

    return max(x, max(ys...));
}
|
||||
|
||||
// variadic min: single-argument base case
template <typename T>
CK_TILE_HOST_DEVICE constexpr T min(T x)
{
    return x;
}

// two-argument min of a common type
template <typename T>
CK_TILE_HOST_DEVICE constexpr T min(T x, T y)
{
    return x < y ? x : y;
}

// mixed number<>/runtime overloads: result decays to runtime index_t
template <index_t X>
CK_TILE_HOST_DEVICE constexpr index_t min(number<X>, index_t y)
{
    return X < y ? X : y;
}

template <index_t Y>
CK_TILE_HOST_DEVICE constexpr index_t min(index_t x, number<Y>)
{
    return x < Y ? x : Y;
}

// variadic recursion: min(x, min(rest...))
template <typename X, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto min(X x, Ys... ys)
{
    static_assert(sizeof...(Ys) > 0, "not enough argument");

    return min(x, min(ys...));
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr T clamp(const T& x, const T& lowerbound, const T& upperbound)
|
||||
{
|
||||
return min(max(x, lowerbound), upperbound);
|
||||
}
|
||||
|
||||
// greatest common divisor, aka highest common factor
|
||||
CK_TILE_HOST_DEVICE constexpr index_t gcd(index_t x, index_t y)
|
||||
{
|
||||
if(x < 0)
|
||||
{
|
||||
return gcd(-x, y);
|
||||
}
|
||||
else if(y < 0)
|
||||
{
|
||||
return gcd(x, -y);
|
||||
}
|
||||
else if(x == y || x == 0)
|
||||
{
|
||||
return y;
|
||||
}
|
||||
else if(y == 0)
|
||||
{
|
||||
return x;
|
||||
}
|
||||
else if(x > y)
|
||||
{
|
||||
return gcd(x % y, y);
|
||||
}
|
||||
else
|
||||
{
|
||||
return gcd(x, y % x);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t X, index_t Y>
|
||||
CK_TILE_HOST_DEVICE constexpr auto gcd(number<X>, number<Y>)
|
||||
{
|
||||
constexpr auto r = gcd(X, Y);
|
||||
|
||||
return number<r>{};
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename... Ys,
|
||||
typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto gcd(X x, Ys... ys)
|
||||
{
|
||||
return gcd(x, gcd(ys...));
|
||||
}
|
||||
|
||||
// least common multiple
|
||||
template <typename X, typename Y>
|
||||
CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Y y)
|
||||
{
|
||||
return (x * y) / gcd(x, y);
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename... Ys,
|
||||
typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Ys... ys)
|
||||
{
|
||||
return lcm(x, lcm(ys...));
|
||||
}
|
||||
|
||||
// function object for equality comparison (analogous to std::equal_to<T>)
template <typename T>
struct equal
{
    CK_TILE_HOST_DEVICE constexpr bool operator()(T x, T y) const { return x == y; }
};
|
||||
|
||||
// function object for strict less-than comparison (analogous to std::less<T>)
template <typename T>
struct less
{
    CK_TILE_HOST_DEVICE constexpr bool operator()(T x, T y) const { return x < y; }
};
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr int32_t next_power_of_two(int32_t x)
|
||||
{
|
||||
// TODO: x need to be 2 ~ 0x7fffffff. 0, 1, or larger than 0x7fffffff will compile fail
|
||||
return 1 << (32 - __builtin_clz(x - 1));
|
||||
}
|
||||
|
||||
template <index_t X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto next_power_of_two()
|
||||
{
|
||||
constexpr index_t y = next_power_of_two(X);
|
||||
return number<y>{};
|
||||
}
|
||||
|
||||
template <index_t X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto next_power_of_two(number<X>)
|
||||
{
|
||||
constexpr index_t y = next_power_of_two(X);
|
||||
return number<y>{};
|
||||
}
|
||||
|
||||
// floor(log2(x)) for x in [1, 0x7fffffff]
// __builtin_clz(0) is undefined, so the caller must guarantee x >= 1
CK_TILE_HOST_DEVICE constexpr int32_t integer_log2_floor(int32_t x)
{
    return 31 - __builtin_clz(x);
}
|
||||
|
||||
// true iff x is a positive power of two; false for x <= 0.
// Uses the x & (x - 1) bit trick instead of integer_log2_floor, which relied
// on __builtin_clz and was undefined for x <= 0 (the original TODO).
CK_TILE_HOST_DEVICE constexpr bool is_power_of_two_integer(int32_t x)
{
    return x > 0 && (x & (x - 1)) == 0;
}
|
||||
|
||||
#ifndef C_LOG2E
#define C_LOG2E 1.44269504088896340736 // log2(e)
#endif

// log2e<T>: log2(e) as a constant of floating-point type T
// (only double and float are provided)
template <typename T>
struct log2e;

template <>
struct log2e<double>
{
    static constexpr double value = C_LOG2E;
};

template <>
struct log2e<float>
{
    static constexpr float value = C_LOG2E;
};

// variable-template shorthand; defaults to double precision
template <typename T = double>
inline constexpr T log2e_v = log2e<T>::value;
|
||||
|
||||
// math
|
||||
// absolute value of a float: clear the IEEE-754 sign bit.
// Bit manipulation through a union keeps the operation branch-free and
// preserves NaN payloads (only the sign bit is masked off).
CK_TILE_HOST_DEVICE
float abs(const float& x)
{
    union punner
    {
        float f32;
        uint32_t u32;
    };

    punner y{};
    y.f32 = x;
    y.u32 &= 0x7fffffff;

    return y.f32;
}
|
||||
|
||||
// NaN test: exponent all-ones with a non-zero mantissa, i.e. the masked bit
// pattern exceeds +inf (0x7F800000).
// Uses union type punning (same idiom as abs above) instead of the original
// reinterpret_cast<const uint32_t&>, which violated strict aliasing.
CK_TILE_HOST_DEVICE
bool isnan(const float& x)
{
    union
    {
        float f32;
        uint32_t u32;
    } y;
    y.f32 = x;

    return (y.u32 & 0x7fffffff) > 0x7F800000;
}
|
||||
|
||||
// Device-only wrappers over HIP/clang math builtins.
// (Stray semicolons after each function definition removed.)

// square root via the amdgcn hardware instruction
CK_TILE_DEVICE
float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); }

// fast-math exponential, base e
CK_TILE_DEVICE
float exp(float x) { return __expf(x); }

// exponential, base 2
CK_TILE_DEVICE
float exp2(float x) { return exp2f(x); }

// fast-math natural logarithm
CK_TILE_DEVICE
float log(float x) { return __logf(x); }
|
||||
|
||||
} // namespace ck_tile
|
||||
45
include/ck_tile/core/numeric/type_convert.hpp
Normal file
45
include/ck_tile/core/numeric/type_convert.hpp
Normal file
@@ -0,0 +1,45 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
#if 0
|
||||
// Convert X to Y, both X and Y are non-const data types.
|
||||
template <typename Y,
|
||||
typename X,
|
||||
std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
|
||||
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
|
||||
{
|
||||
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
|
||||
|
||||
return static_cast<Y>(x);
|
||||
}
|
||||
|
||||
// TODO: const version never called, we may never need
|
||||
// Convert X to Y, either X or Y is a const data type.
|
||||
template <typename Y,
|
||||
typename X,
|
||||
std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
|
||||
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
|
||||
{
|
||||
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
|
||||
|
||||
using NonConstY = std::remove_const_t<Y>;
|
||||
using NonConstX = std::remove_const_t<X>;
|
||||
return static_cast<Y>(type_convert<NonConstY, NonConstX>(x));
|
||||
}
|
||||
#else
|
||||
// compatible way to call conversion operator and constructor of each custom data type
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
|
||||
{
|
||||
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
|
||||
|
||||
return static_cast<Y>(x);
|
||||
}
|
||||
#endif
|
||||
} // namespace ck_tile
|
||||
304
include/ck_tile/core/numeric/vector_type.hpp
Normal file
304
include/ck_tile/core/numeric/vector_type.hpp
Normal file
@@ -0,0 +1,304 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// TODO: the whole content of this file should consider deprecated!
|
||||
template <typename T_, index_t N_>
|
||||
struct vector_type
|
||||
{
|
||||
static constexpr index_t N = N_;
|
||||
using value_type = T_;
|
||||
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
|
||||
|
||||
CK_HOST_DEVICE constexpr vector_type()
|
||||
{
|
||||
for(auto i = 0; i < N; i++)
|
||||
data[i] = static_cast<value_type>(0);
|
||||
}
|
||||
CK_HOST_DEVICE constexpr vector_type(type v)
|
||||
{
|
||||
auto& r = reinterpret_cast<const array<value_type, N>&>(v);
|
||||
for(auto i = 0; i < N; i++)
|
||||
data[i] = r.get(i);
|
||||
}
|
||||
|
||||
value_type data[N];
|
||||
CK_HOST_DEVICE static constexpr auto size() { return N; }
|
||||
CK_HOST_DEVICE auto& get() { return data; }
|
||||
CK_HOST_DEVICE const auto& get() const { return data; }
|
||||
CK_HOST_DEVICE auto& get(index_t i) { return data[i]; }
|
||||
CK_HOST_DEVICE const auto& get(index_t i) const { return data[i]; }
|
||||
|
||||
template <index_t I>
|
||||
CK_HOST_DEVICE auto& operator[](number<I>)
|
||||
{
|
||||
return data[I];
|
||||
}
|
||||
template <index_t I>
|
||||
CK_HOST_DEVICE const auto& operator[](number<I>) const
|
||||
{
|
||||
return data[I];
|
||||
}
|
||||
template <index_t I>
|
||||
CK_HOST_DEVICE auto& operator()(number<I>)
|
||||
{
|
||||
return data[I];
|
||||
}
|
||||
|
||||
CK_HOST_DEVICE auto& at(index_t i) { return data[i]; }
|
||||
CK_HOST_DEVICE const auto& at(index_t i) const { return data[i]; }
|
||||
template <index_t I>
|
||||
CK_HOST_DEVICE auto& at()
|
||||
{
|
||||
return data[I];
|
||||
}
|
||||
template <index_t I>
|
||||
CK_HOST_DEVICE const auto& at() const
|
||||
{
|
||||
return data[I];
|
||||
}
|
||||
template <index_t I>
|
||||
CK_HOST_DEVICE auto& at(number<I>)
|
||||
{
|
||||
return data[I];
|
||||
}
|
||||
template <index_t I>
|
||||
CK_HOST_DEVICE const auto& at(number<I>) const
|
||||
{
|
||||
return data[I];
|
||||
}
|
||||
|
||||
#define _VT_COMMON_AS() \
|
||||
static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
|
||||
constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
|
||||
|
||||
template <typename Tx>
|
||||
CK_HOST_DEVICE auto& get_as()
|
||||
{
|
||||
_VT_COMMON_AS();
|
||||
return reinterpret_cast<array<Tx, vx>&>(data);
|
||||
}
|
||||
template <typename Tx>
|
||||
CK_HOST_DEVICE const auto& get_as() const
|
||||
{
|
||||
_VT_COMMON_AS();
|
||||
return reinterpret_cast<const array<Tx, vx>&>(data);
|
||||
}
|
||||
template <typename Tx>
|
||||
CK_HOST_DEVICE auto& get_as(index_t i)
|
||||
{
|
||||
_VT_COMMON_AS();
|
||||
return reinterpret_cast<array<Tx, vx>&>(data).get(i);
|
||||
}
|
||||
template <typename Tx>
|
||||
CK_HOST_DEVICE const auto& get_as(index_t i) const
|
||||
{
|
||||
_VT_COMMON_AS();
|
||||
return reinterpret_cast<const array<Tx, vx>&>(data).get(i);
|
||||
}
|
||||
#undef _VT_COMMON_AS
|
||||
};
|
||||
|
||||
// vector_type_maker<T, N>: normalize (T, N) into one flat vector_type.
// When T is itself a vector (ext_vector_type or vector_type), element
// counts multiply instead of nesting.
template <typename T, index_t N>
struct vector_type_maker
{
    using type = vector_type<T, N>;
};

// flatten: N0 copies of an N1-wide ext_vector_type
template <typename T, index_t N0, index_t N1>
struct vector_type_maker<T __attribute__((ext_vector_type(N1))), N0>
{
    using type = vector_type<T, N0 * N1>;
};

// flatten: N0 copies of an N1-wide vector_type
template <typename T, index_t N0, index_t N1>
struct vector_type_maker<vector_type<T, N1>, N0>
{
    using type = vector_type<T, N0 * N1>;
};

template <typename T, index_t N>
using vector_type_maker_t = typename vector_type_maker<T, N>::type;

// value-level helper: make_vector_type<T>(number<N>{})
template <typename T, index_t N>
CK_HOST_DEVICE constexpr auto make_vector_type(number<N>)
{
    return typename vector_type_maker<T, N>::type{};
}
|
||||
|
||||
// scalar_type<TV>: trait mapping a (possibly vector) type TV to its scalar
// element type (::type) and element count (::vector_size)
template <typename TV>
struct scalar_type;

// is_scalar_type: true iff TV (after stripping cv/ref) is a 1-element type
template <typename TV>
struct is_scalar_type
{
    static constexpr bool value = (scalar_type<remove_cvref_t<TV>>::vector_size == 1);
};

// has_same_scalar_type: true iff X and Y share the same scalar element type,
// ignoring cv/ref qualifiers and vector width
template <typename X, typename Y>
using has_same_scalar_type = is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                     typename scalar_type<remove_cvref_t<Y>>::type>;

// clang ext_vector_type of N elements of T
template <typename T, index_t N>
struct scalar_type<T __attribute__((ext_vector_type(N)))>
{
    using type                           = T;
    static constexpr index_t vector_size = N;
};

// ck_tile vector_type wrapper of N elements of T
template <typename T, index_t N>
struct scalar_type<vector_type<T, N>>
{
    using type                           = T;
    static constexpr index_t vector_size = N;
};

// scalar specializations
template <>
struct scalar_type<double>
{
    using type                           = double;
    static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<float>
{
    using type                           = float;
    static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<half_t>
{
    using type                           = half_t;
    static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<bhalf_t>
{
    using type                           = bhalf_t;
    static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<int64_t>
{
    using type                           = int64_t;
    static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<int32_t>
{
    using type                           = int32_t;
    static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<int8_t>
{
    using type                           = int8_t;
    static constexpr index_t vector_size = 1;
};

#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
struct scalar_type<int4_t>
{
    using type                           = int4_t;
    static constexpr index_t vector_size = 1;
};
#endif

template <>
struct scalar_type<fp8_t>
{
    using type                           = fp8_t;
    static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<bf8_t>
{
    using type                           = bf8_t;
    static constexpr index_t vector_size = 1;
};
|
||||
|
||||
// below are some pre-defines of ext_vector_type
// (each alias is the raw clang vector type, i.e. vector_type<T, N>::type)

// fp64
using double2_t = typename vector_type<double, 2>::type;
using double4_t = typename vector_type<double, 4>::type;

// fp32
using float2_t  = typename vector_type<float, 2>::type;
using float4_t  = typename vector_type<float, 4>::type;
using float8_t  = typename vector_type<float, 8>::type;
using float16_t = typename vector_type<float, 16>::type;
using float32_t = typename vector_type<float, 32>::type;
using float64_t = typename vector_type<float, 64>::type;

// fp16
using half2_t  = typename vector_type<half_t, 2>::type;
using half4_t  = typename vector_type<half_t, 4>::type;
using half8_t  = typename vector_type<half_t, 8>::type;
using half16_t = typename vector_type<half_t, 16>::type;
using half32_t = typename vector_type<half_t, 32>::type;
using half64_t = typename vector_type<half_t, 64>::type;

// bfp16
using bhalf2_t  = typename vector_type<bhalf_t, 2>::type;
using bhalf4_t  = typename vector_type<bhalf_t, 4>::type;
using bhalf8_t  = typename vector_type<bhalf_t, 8>::type;
using bhalf16_t = typename vector_type<bhalf_t, 16>::type;
using bhalf32_t = typename vector_type<bhalf_t, 32>::type;
using bhalf64_t = typename vector_type<bhalf_t, 64>::type;

// i32
using int32x2_t  = typename vector_type<int32_t, 2>::type;
using int32x4_t  = typename vector_type<int32_t, 4>::type;
using int32x8_t  = typename vector_type<int32_t, 8>::type;
using int32x16_t = typename vector_type<int32_t, 16>::type;
using int32x32_t = typename vector_type<int32_t, 32>::type;
using int32x64_t = typename vector_type<int32_t, 64>::type;

// i8
using int8x2_t  = typename vector_type<int8_t, 2>::type;
using int8x4_t  = typename vector_type<int8_t, 4>::type;
using int8x8_t  = typename vector_type<int8_t, 8>::type;
using int8x16_t = typename vector_type<int8_t, 16>::type;
using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;

// f8
using fp8x2_t  = typename vector_type<fp8_t, 2>::type;
using fp8x4_t  = typename vector_type<fp8_t, 4>::type;
using fp8x8_t  = typename vector_type<fp8_t, 8>::type;
using fp8x16_t = typename vector_type<fp8_t, 16>::type;
using fp8x32_t = typename vector_type<fp8_t, 32>::type;
using fp8x64_t = typename vector_type<fp8_t, 64>::type;

// bf8
using bf8x2_t  = typename vector_type<bf8_t, 2>::type;
using bf8x4_t  = typename vector_type<bf8_t, 4>::type;
using bf8x8_t  = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
|
||||
|
||||
} // namespace ck_tile
|
||||
1041
include/ck_tile/core/tensor/buffer_view.hpp
Normal file
1041
include/ck_tile/core/tensor/buffer_view.hpp
Normal file
File diff suppressed because it is too large
Load Diff
78
include/ck_tile/core/tensor/load_tile.hpp
Normal file
78
include/ck_tile/core/tensor/load_tile.hpp
Normal file
@@ -0,0 +1,78 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/null_tensor.hpp"
#include "ck_tile/core/tensor/null_tile_window.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Load the data addressed by `tile_window` and return it as a new
// distributed tensor. oob_conditional_check toggles out-of-bounds guarding
// on the underlying load.
template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_,
                                                                         WindowLengths_,
                                                                         TileDistribution_,
                                                                         NumCoord>& tile_window,
                              bool_constant<oob_conditional_check> = {})
{
    return tile_window.load(bool_constant<oob_conditional_check>{});
}
|
||||
|
||||
// Load the data addressed by `tile_window` directly into the caller-provided
// tile `tile` (raw form, no new tensor is constructed; returns nothing).
template <typename T,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
                                  const tile_window_with_static_distribution<BottomTensorView_,
                                                                             WindowLengths_,
                                                                             TileDistribution_,
                                                                             NumCoord>& tile_window,
                                  bool_constant<oob_conditional_check> = {})
{
    tile_window.load_raw(tile, bool_constant<oob_conditional_check>{});
}
|
||||
|
||||
// Issue an asynchronous load from `tile_window` (global memory) into the
// LDS tile window `lds_tile`; pair with async_load_fence to wait for
// completion. Returns whatever the window's async_load returns.
template <typename LdsTileWindow_,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord>
CK_TILE_DEVICE auto
async_load_tile_raw(LdsTileWindow_&& lds_tile,
                    const tile_window_with_static_distribution<BottomTensorView_,
                                                               WindowLengths_,
                                                               TileDistribution_,
                                                               NumCoord>& tile_window)
{
    return tile_window.async_load(lds_tile);
}
|
||||
|
||||
// Wait until at most `cnt` vector-memory operations remain outstanding
// (s_waitcnt vmcnt); cnt == 0 waits for all previously issued loads.
// NOTE(review): the "n" asm constraint requires an immediate, so `cnt` must
// fold to a compile-time constant at every call site - confirm callers.
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
{
    asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
|
||||
|
||||
template <typename WindowLengths>
|
||||
CK_TILE_DEVICE auto load_tile(const NullTileWindow<WindowLengths>&)
|
||||
{
|
||||
return NullTensor{};
|
||||
}
|
||||
|
||||
template <typename T, typename WindowLengths>
|
||||
CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const NullTileWindow<WindowLengths>&)
|
||||
{
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
12
include/ck_tile/core/tensor/null_tensor.hpp
Normal file
12
include/ck_tile/core/tensor/null_tensor.hpp
Normal file
@@ -0,0 +1,12 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// empty tag type standing in for "no tensor" (e.g. the result of loading
// through a null tile window)
struct null_tensor
{
};
|
||||
|
||||
} // namespace ck_tile
|
||||
87
include/ck_tile/core/tensor/null_tile_window.hpp
Normal file
87
include/ck_tile/core/tensor/null_tile_window.hpp
Normal file
@@ -0,0 +1,87 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// placeholder type if we want to opt-out a tile window parameter
// Behaves like a window over a null_tensor_view: it only carries the
// compile-time window lengths, has no underlying data, and reports an
// all-zero origin.
template <typename WindowLengths_>
struct null_tile_window
{
    using BottomTensorView = null_tensor_view;
    using WindowLengths    = remove_cvref_t<WindowLengths_>;

    using BottomTensorIndex = array<index_t, WindowLengths::size()>;

    CK_TILE_DEVICE constexpr null_tile_window() = default;

    CK_TILE_DEVICE constexpr null_tile_window(const WindowLengths& window_lengths)
        : window_lengths_{window_lengths}
    {
    }

    // extents of the (virtual) window
    CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }

    // there is no real tensor behind this window
    CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return null_tensor_view{}; }

    // origin is always all-zero (value-initialized index array)
    CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; }

    WindowLengths window_lengths_;
};
|
||||
|
||||
// utility to check if this is a Null Tile Window
namespace impl {
// primary template: an arbitrary type is not a null tile window
template <typename>
struct is_null_tile_window : public std::false_type
{
};

// specialization matching any null_tile_window instantiation
template <typename T>
struct is_null_tile_window<null_tile_window<T>> : public std::true_type
{
};
} // namespace impl
|
||||
|
||||
// value-level query: true iff the argument (after stripping cv/ref) is a
// null_tile_window
template <typename T>
CK_TILE_DEVICE constexpr auto is_null_tile_window(const T&)
{
    return impl::is_null_tile_window<remove_cvref_t<T>>::value;
}
|
||||
|
||||
// Build a null_tile_window from window lengths that are known at
// compile time.
template <typename WindowLengths>
CK_TILE_DEVICE constexpr auto make_null_tile_window(const WindowLengths& window_lengths)
{
    static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
                  "wrong! lengths should be static");

    return null_tile_window<remove_cvref_t<WindowLengths>>{window_lengths};
}
|
||||
|
||||
// Overload of make_tile_window for a null_tensor_view: the origin and any
// trailing arguments are discarded and a null_tile_window is produced.
template <typename WindowLengths, typename... Ts>
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view,
                                               const WindowLengths& window_lengths,
                                               const multi_index<WindowLengths::size()>& /*origin*/,
                                               Ts&&...)
{
    static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
                  "wrong! lengths should be static");

    return null_tile_window<remove_cvref_t<WindowLengths>>{window_lengths};
}
|
||||
|
||||
// Moving a null tile window is a no-op: there is no underlying tensor and
// the origin is fixed at zero.
template <typename WindowLengths>
CK_TILE_DEVICE void
move_tile_window(null_tile_window<WindowLengths>&,
                 const typename null_tile_window<WindowLengths>::BottomTensorIndex&)
{
}
|
||||
|
||||
} // namespace ck_tile
|
||||
171
include/ck_tile/core/tensor/shuffle_tile.hpp
Normal file
171
include/ck_tile/core/tensor/shuffle_tile.hpp
Normal file
@@ -0,0 +1,171 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/statically_indexed_array.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/tensor/tile_elementwise.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
namespace detail {
|
||||
|
||||
// In-thread shuffle between two distributed tensors whose distributions
// differ only in the ordering of their Y dimensions: walks both thread
// buffers along a space-filling curve and transposes vec_in x vec_out
// element groups between the input's and output's fastest-varying Y dims.
// Assumes out_tensor/in_tensor distributions share R/H lengths (checked by
// the shuffle_tile caller).
template <typename OutTensor, typename InTensor>
CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InTensor& in_tensor)
{
    constexpr auto I0 = number<0>{};

    using DataType = typename InTensor::DataType;

    // Y-space -> linear thread-buffer offset descriptors for both tensors
    constexpr auto y_in_desc  = InTensor::get_tile_distribution().get_ys_to_d_descriptor();
    constexpr auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor();

    // y_dim_out_to_in: match Y dims of the two distributions through their
    // shared (rh_major, rh_minor) coordinates
    constexpr auto get_rh_major_minor_to_y = [](auto dstr_tensor) {
        using DstrEncode = typename decltype(dstr_tensor.get_tile_distribution())::DstrEncode;

        map<array<index_t, 2>, index_t> rh_major_minor_to_y_;

        static_for<0, DstrEncode::NDimY, 1>{}([&](auto i) {
            constexpr index_t rh_major = DstrEncode::ys_to_rhs_major_[i];
            constexpr index_t rh_minor = DstrEncode::ys_to_rhs_minor_[i];

            rh_major_minor_to_y_({rh_major, rh_minor}) = i;
        });

        return rh_major_minor_to_y_;
    };

    constexpr auto rh_major_minor_to_y_in  = get_rh_major_minor_to_y(InTensor{});
    constexpr auto rh_major_minor_to_y_out = get_rh_major_minor_to_y(OutTensor{});

    // for each output Y dim, the input Y dim carrying the same rh coordinate
    constexpr auto y_dim_out_to_in = [&] {
        map<index_t, index_t> y_dim_out_to_in_;

        for(const auto& [rh_major_minor, y_out] : rh_major_minor_to_y_out)
        {
            y_dim_out_to_in_(y_out) = rh_major_minor_to_y_in[rh_major_minor];
        }

        return y_dim_out_to_in_;
    }();

    //
    constexpr index_t NDimY = InTensor::get_tile_distribution().GetNumOfDimensionY();

    constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths());

    // input and output vector dim in the order of input Y dims
    // (input vectorizes along its last Y dim; output's last Y dim maps to
    // y_dim_vec_out on the input side)
    constexpr index_t y_dim_vec_in  = NDimY - 1;
    constexpr index_t y_dim_vec_out = y_dim_out_to_in[NDimY - 1];

    // vector lengths
    constexpr index_t vec_length_in  = y_lengths[y_dim_vec_in];
    constexpr index_t vec_length_out = y_lengths[y_dim_vec_out];

    // # of vectors (transpose swaps the roles of the two vector dims)
    constexpr index_t num_vec_in  = vec_length_out;
    constexpr index_t num_vec_out = vec_length_in;

    using InVec  = vector_type<DataType, vec_length_in>;
    using OutVec = vector_type<DataType, vec_length_out>;

    using InVecType  = typename InVec::type;
    using OutVecType = typename OutVec::type;

    // SFC covering the Y space, stepping over both vector dims at once
    constexpr auto scalars_per_access_arr = generate_array(
        [&](auto i) { return (i == y_dim_vec_in or i == y_dim_vec_out) ? y_lengths[i] : 1; },
        number<NDimY>{});

    constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);

    using SFC_Y = space_filling_curve<decltype(y_lengths),
                                      typename arithmetic_sequence_gen<0, NDimY, 1>::type,
                                      decltype(scalars_per_access)>;

    constexpr index_t num_access = SFC_Y::get_num_of_access();

    static_assert(num_access > 0, "wrong! num_access should be larger than 0");

    // in/out vectors to be transposed
    statically_indexed_array<InVec, num_vec_in> in_vectors;
    statically_indexed_array<OutVec, num_vec_out> out_vectors;

    // loop over SFC and do transpose
    static_for<0, num_access, 1>{}([&](auto iAccess) {
        // data index [y0, y1, ...] in the order of input tensor
        constexpr auto idx_y_start = SFC_Y::get_index(iAccess);

        // get input vectors: one contiguous vector per position along the
        // output's vector dim
        static_for<0, num_vec_in, 1>{}([&](auto i) {
            constexpr auto idx_y_in = generate_array(
                [&](auto ii) {
                    return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
                },
                number<NDimY>{});

            constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);

            // NOTE(review): vector_type in this change exposes get_as/at,
            // not AsType - confirm this member exists on InVec
            in_vectors(i).template AsType<InVecType>()(I0) =
                in_tensor.get_thread_buffer().template get_as<InVecType>(number<in_offset>{});
        });

        // transpose
        transpose_vectors<DataType, num_vec_in, num_vec_out>{}(in_vectors, out_vectors);

        // set output vectors
        static_for<0, num_vec_out, 1>{}([&](auto i) {
            constexpr auto idx_y_out_tmp = generate_array(
                [&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; },
                number<NDimY>{});

            // reorder from input Y order back into output Y order
            constexpr auto idx_y_out =
                container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);

            constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);

            // NOTE(review): the store index is divided by sizeof(OutVecType)
            // while the load above uses the raw element offset - confirm the
            // differing index units of get_as vs set_as
            out_tensor.get_thread_buffer().template set_as<OutVecType>(
                number<out_offset / sizeof(OutVecType)>{},
                out_vectors[i].template AsType<OutVecType>()[I0]);
        });
    });
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Shuffle tile `in` into `out`, converting the element type on the way.
// Only supports distribution pairs that are identical except for Y-dimension
// ordering (same R/H lengths and same P->RH mappings).
template <typename OutTensor, typename InTensor>
CK_TILE_DEVICE void shuffle_tile(OutTensor& out, const InTensor& in)
{
    using InDataType  = typename InTensor::DataType;
    using OutDataType = typename OutTensor::DataType;

    using InDstrEncode  = typename InTensor::StaticTileDistribution::DstrEncode;
    using OutDstrEncode = typename OutTensor::StaticTileDistribution::DstrEncode;

    // type convert
    const auto in_tmp = tile_elementwise_in(type_convert<OutDataType, InDataType>, in);

    // shuffle
    if constexpr(InDstrEncode::rs_lengths_ == OutDstrEncode::rs_lengths_ &&
                 InDstrEncode::hs_lengthss_ == OutDstrEncode::hs_lengthss_ &&
                 InDstrEncode::ps_to_rhss_major_ == OutDstrEncode::ps_to_rhss_major_ &&
                 InDstrEncode::ps_to_rhss_minor_ == OutDstrEncode::ps_to_rhss_minor_ &&
                 InDstrEncode::NDimY == OutDstrEncode::NDimY)
    {
        detail::shuffle_tile_impl_in_thread(out, in_tmp);
    }
    else
    {
        // NOT implemented
        // NOTE(review): unsupported distribution pairs are silently ignored
        // (out is left untouched) - consider a static_assert here
    }
}
|
||||
|
||||
} // namespace ck_tile
|
||||
94
include/ck_tile/core/tensor/slice_tile.hpp
Normal file
94
include/ck_tile/core/tensor/slice_tile.hpp
Normal file
@@ -0,0 +1,94 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
namespace tile_program {
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
index_t... SliceBegins,
|
||||
index_t... SliceEnds>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
get_slice_tile(const tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile,
|
||||
sequence<SliceBegins...> slice_begins,
|
||||
sequence<SliceEnds...> slice_ends)
|
||||
{
|
||||
using TileWindow = tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>;
|
||||
// NOTE: This API will override the origin of the tile window!
|
||||
static_assert(sizeof...(SliceBegins) == sizeof...(SliceEnds));
|
||||
static_assert(sizeof...(SliceBegins) == TileWindow::get_num_of_dimension());
|
||||
|
||||
constexpr auto slice_lengths = slice_ends - slice_begins;
|
||||
|
||||
return make_tile_window(tile.GetBottomTensorView(),
|
||||
sequence_to_tuple_of_number(slice_lengths),
|
||||
to_multi_index(slice_begins));
|
||||
}
|
||||
|
||||
// Extract the sub-tensor covering [slice_begins, slice_ends) in X-space from
// a static_distributed_tensor: builds the sliced distribution and copies the
// corresponding portion of the thread-local data into a new tensor.
template <typename DataType_,
          typename StaticTileDistribution_,
          index_t... SliceBegins,
          index_t... SliceEnds>
CK_TILE_DEVICE constexpr auto
get_slice_tile(const static_distributed_tensor<DataType_, StaticTileDistribution_>& tile,
               sequence<SliceBegins...> slice_begins,
               sequence<SliceEnds...> slice_ends)
{
    using DataType     = remove_cvref_t<DataType_>;
    using Distribution = remove_cvref_t<StaticTileDistribution_>;

    // map the X-space slice onto the distribution; yields a tuple of
    // (sliced distribution, Y-space origins, Y-space lengths)
    constexpr auto sliced_dstr_yidx_ylen =
        detail::slice_distribution_from_x(Distribution{}, slice_begins, slice_ends);

    constexpr auto sliced_dstr      = sliced_dstr_yidx_ylen.template at<0>();
    constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>();
    constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>();

    auto sliced_tensor = make_static_distributed_tensor<DataType>(sliced_dstr);

    // copy the thread-local slice of the source data
    sliced_tensor.get_thread_buffer() =
        tile.get_y_sliced_thread_data(sliced_y_origins, sliced_y_lengths);

    return sliced_tensor;
}
|
||||
|
||||
template <typename DstDataType_,
|
||||
typename DstStaticTileDistribution_,
|
||||
typename SrcDataType_,
|
||||
typename SrcStaticTileDistribution_,
|
||||
index_t... SliceBegins,
|
||||
index_t... SliceEnds>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
set_slice_tile(static_distributed_tensor<DstDataType_, DstStaticTileDistribution_>& dst_tile,
|
||||
const static_distributed_tensor<SrcDataType_, SrcStaticTileDistribution_>& src_tile,
|
||||
sequence<SliceBegins...> slice_begins,
|
||||
sequence<SliceEnds...> slice_ends)
|
||||
{
|
||||
using DstDistribution = remove_cvref_t<DstStaticTileDistribution_>;
|
||||
|
||||
constexpr auto sliced_dstr_yidx_ylen =
|
||||
detail::slice_distribution_from_x(DstDistribution{}, slice_begins, slice_ends);
|
||||
|
||||
constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template at<0>();
|
||||
constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>();
|
||||
constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>();
|
||||
|
||||
static_assert(is_same_v<decltype(sliced_dstr), DstDistribution>, "wrong!");
|
||||
|
||||
dst_tile.SetSlicedThreadData(sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer());
|
||||
}
|
||||
|
||||
} // namespace tile_program
|
||||
} // namespace ck_tile
|
||||
180
include/ck_tile/core/tensor/static_distributed_tensor.hpp
Normal file
180
include/ck_tile/core/tensor/static_distributed_tensor.hpp
Normal file
@@ -0,0 +1,180 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A tensor whose elements are distributed across the threads of a workgroup
// according to a compile-time tile distribution. Each thread owns only its own
// portion of the data, held in the register-resident `thread_buf_`; the
// distribution maps the thread's (P, Y) coordinates to the logical X-space of
// the full tile.
template <typename DataType_, typename StaticTileDistribution_>
struct static_distributed_tensor
{
    using DataType               = remove_cvref_t<DataType_>;
    using StaticTileDistribution = remove_cvref_t<StaticTileDistribution_>;

    static_assert(StaticTileDistribution::is_static(),
                  "wrong! StaticTileDistribution should be known at compile tile");

    // Descriptor mapping this thread's Y indices to a linear offset into thread_buf_.
    using ThreadTensorDesc =
        remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>;

    // Number of elements each thread holds.
    static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.get_element_space_size();

    // Number of X (logical tile) dimensions.
    CK_TILE_HOST_DEVICE static constexpr auto get_num_of_dimension()
    {
        return StaticTileDistribution::get_num_of_dimension_x();
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_lengths()
    {
        return StaticTileDistribution::get_lengths();
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_tile_distribution()
    {
        return StaticTileDistribution{};
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans()
    {
        return StaticTileDistribution::get_distributed_spans();
    }

    // Fill every element of this thread's buffer with x.
    CK_TILE_HOST_DEVICE void initialize(const DataType& x) { thread_buf_.initialize(x); }

    CK_TILE_HOST_DEVICE constexpr const auto& get_thread_buffer() const { return thread_buf_; }

    CK_TILE_HOST_DEVICE constexpr auto& get_thread_buffer() { return thread_buf_; }

    CK_TILE_HOST_DEVICE static constexpr index_t get_thread_buffer_size()
    {
        return kThreadElementSpaceSize;
    }

    // Copy out the rectangular Y-space slice [YSliceOrigins, YSliceOrigins +
    // YSliceLengths) of this thread's data, packed into a new array.
    template <index_t... YSliceOrigins, index_t... YSliceLengths>
    CK_TILE_HOST_DEVICE auto get_y_sliced_thread_data(sequence<YSliceOrigins...>,
                                                      sequence<YSliceLengths...>) const
    {
        static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
                          sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
                      "wrong!");

        // packed (contiguous) layout for the extracted slice
        constexpr auto sliced_thread_tensor_desc =
            make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...));

        array<DataType, sliced_thread_tensor_desc.get_element_space_size()> sliced_thread_data;

        // compile-time sweep of every index in the slice
        static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
            constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};

            sliced_thread_data(number<sliced_thread_tensor_desc.calculate_offset(idx)>{}) =
                thread_buf_[number<ThreadTensorDesc{}.calculate_offset(idx_ys)>{}];
        });

        return sliced_thread_data;
    }

    // Inverse of get_y_sliced_thread_data(): scatter a packed slice back into
    // this thread's buffer at the given Y-space origin.
    template <index_t... YSliceOrigins, index_t... YSliceLengths, index_t NSlicedData>
    CK_TILE_HOST_DEVICE void
    set_y_sliced_thread_data(sequence<YSliceOrigins...>,
                             sequence<YSliceLengths...>,
                             const array<DataType, NSlicedData>& sliced_thread_data)
    {
        static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
                          sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
                      "wrong!");

        constexpr auto sliced_thread_tensor_desc =
            make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...));

        static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
            constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};

            thread_buf_(number<ThreadTensorDesc{}.calculate_offset(idx_ys)>{}) =
                sliced_thread_data[number<sliced_thread_tensor_desc.calculate_offset(idx)>{}];
        });
    }

    // Read-only element access by compile-time tile-distributed index.
    template <typename TileDistributedIndices>
    CK_TILE_HOST_DEVICE constexpr const DataType& operator[](TileDistributedIndices) const
    {
        static_assert(is_static_v<TileDistributedIndices>,
                      "wrong! Tile Distributed Indices should be static");

        constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
            TileDistributedIndices{});

        return thread_buf_[number<ThreadTensorDesc{}.calculate_offset(y_idx)>{}];
    }

    // Mutable element access by compile-time tile-distributed index.
    template <typename TileDistributedIndices>
    CK_TILE_HOST_DEVICE constexpr DataType& operator()(TileDistributedIndices)
    {
        static_assert(is_static_v<TileDistributedIndices>,
                      "wrong! Tile Distributed Indices should be static");

        constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
            TileDistributedIndices{});

        return thread_buf_(number<ThreadTensorDesc{}.calculate_offset(y_idx)>{});
    }

    // Per-thread storage for this thread's share of the tile.
    array<DataType, kThreadElementSpaceSize> thread_buf_;
};
|
||||
|
||||
template <typename DataType, typename StaticTileDistribution>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution&)
|
||||
{
|
||||
return static_distributed_tensor<remove_cvref_t<DataType>,
|
||||
remove_cvref_t<StaticTileDistribution>>{};
|
||||
}
|
||||
|
||||
// get X indices from tuple of tile_distributed_index<>
|
||||
template <typename StaticTileDistribution, typename DistributedIndices>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution,
|
||||
DistributedIndices distributed_indices)
|
||||
{
|
||||
const auto partition_index = detail::get_partition_index(tile_distribution);
|
||||
constexpr auto y_indices =
|
||||
tile_distribution.get_y_indices_from_distributed_indices(distributed_indices);
|
||||
|
||||
const auto x_coord = make_tensor_adaptor_coordinate(
|
||||
tile_distribution.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(partition_index, to_array<ck_tile::index_t, y_indices.size()>(y_indices)));
|
||||
|
||||
return x_coord.get_bottom_index();
|
||||
}
|
||||
|
||||
template <typename DataType, typename StaticTileDistribution, typename XIndicesPredicate>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_tensor,
|
||||
DataType value,
|
||||
XIndicesPredicate predicate)
|
||||
{
|
||||
constexpr auto out_spans =
|
||||
static_distributed_tensor<DataType, StaticTileDistribution>::get_distributed_spans();
|
||||
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto distributed_indices = make_tuple(idx0, idx1);
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(StaticTileDistribution{},
|
||||
distributed_indices);
|
||||
|
||||
if(predicate(x_indices))
|
||||
{
|
||||
out_tensor(distributed_indices) = value;
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
93
include/ck_tile/core/tensor/store_tile.hpp
Normal file
93
include/ck_tile/core/tensor/store_tile.hpp
Normal file
@@ -0,0 +1,93 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename DataType_>
|
||||
CK_TILE_DEVICE void
|
||||
store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
{
|
||||
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
|
||||
using TileDstr = remove_cvref_t<TileDistribution_>;
|
||||
|
||||
static_assert(is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
|
||||
tile_window_tmp.get_window_lengths(),
|
||||
tile_window_tmp.get_window_origin(),
|
||||
tile_dstr);
|
||||
|
||||
tile_window.store(dstr_tensor);
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename DataType_>
|
||||
CK_TILE_DEVICE void
|
||||
store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
{
|
||||
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
|
||||
using TileDstr = remove_cvref_t<TileDistribution_>;
|
||||
|
||||
static_assert(is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
|
||||
tile_window_tmp.get_window_lengths(),
|
||||
tile_window_tmp.get_window_origin(),
|
||||
tile_dstr);
|
||||
|
||||
tile_window.store_raw(dstr_tensor);
|
||||
}
|
||||
|
||||
// Store a distributed tensor through a tile window that already carries the
// same tile distribution; pure forwarder to the window's store().
template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          typename DataType_>
CK_TILE_DEVICE void
store_tile(tile_window_with_static_distribution<BottomTensorView_,
                                                WindowLengths_,
                                                TileDistribution_,
                                                NumCoord>& tile_window,
           const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
    tile_window.store(dstr_tensor);
}
|
||||
|
||||
// Raw-store a distributed tensor through a tile window that already carries
// the same tile distribution; pure forwarder to the window's store_raw().
template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          typename DataType_>
CK_TILE_DEVICE void
store_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
                                                    WindowLengths_,
                                                    TileDistribution_,
                                                    NumCoord>& tile_window,
               const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
    tile_window.store_raw(dstr_tensor);
}
|
||||
|
||||
} // namespace ck_tile
|
||||
30
include/ck_tile/core/tensor/sweep_tile.hpp
Normal file
30
include/ck_tile/core/tensor/sweep_tile.hpp
Normal file
@@ -0,0 +1,30 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// sweep over a span of a distribted tile and apply lambda function F
|
||||
template <typename TileDistributedSpan_, // tile_distributed_span<...>
|
||||
typename F // signature: F(tile_distributed_index<...>)
|
||||
>
|
||||
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F& f)
|
||||
{
|
||||
using DstrSpan = remove_cvref_t<TileDistributedSpan_>;
|
||||
|
||||
static_ford<typename DstrSpan::Impl>{}([&](auto dstr_idx_impl) {
|
||||
constexpr auto dstr_idx = detail::make_tile_distributed_index(dstr_idx_impl);
|
||||
|
||||
f(dstr_idx);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
942
include/ck_tile/core/tensor/tensor_adaptor.hpp
Normal file
942
include/ck_tile/core/tensor/tensor_adaptor.hpp
Normal file
@@ -0,0 +1,942 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A chain of coordinate transforms between a "bottom" (physical) index space
// and a "top" (logical) index space. All dimensions ever produced by any
// transform live in a single "hidden" id space; each transform declares which
// hidden ids it consumes (lower) and which it produces (upper).
//
// Transforms: Tuple<transforms...>
// LowerDimensionHiddenIdss : Tuple<Sequence<...>, ...>
// UpperDimensionHiddenIdss : Tuple<Sequence<...>, ...>
// BottomDimensionHiddenIds : Sequence<...>
// TopDimensionHiddenIds : Sequence<...>
template <typename Transforms,
          typename LowerDimensionHiddenIdss,
          typename UpperDimensionHiddenIdss,
          typename BottomDimensionHiddenIds,
          typename TopDimensionHiddenIds>
struct tensor_adaptor
{
    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_transform()
    {
        return Transforms::size();
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_transforms() const { return transforms_; }

    CK_TILE_HOST_DEVICE static constexpr auto get_lower_dimension_hidden_idss()
    {
        return LowerDimensionHiddenIdss{};
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_upper_dimension_hidden_idss()
    {
        return UpperDimensionHiddenIdss{};
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_bottom_dimension_hidden_ids()
    {
        return BottomDimensionHiddenIds{};
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_hidden_ids()
    {
        return TopDimensionHiddenIds{};
    }

    // Product of all top-dimension lengths; each length is pulled from the
    // transform that produces the corresponding hidden dimension.
    CK_TILE_HOST_DEVICE static constexpr auto initialize_element_size(const Transforms& transforms)
    {
        const auto lengths = generate_tuple(
            [&](auto idim_top) {
                constexpr index_t idim_hidden = TopDimensionHiddenIds::at(idim_top);

                constexpr auto tmp = get_transform_and_its_upper_dimension(number<idim_hidden>{});

                constexpr index_t itran   = tmp[number<0>{}];
                constexpr index_t idim_up = tmp[number<1>{}];
                constexpr bool found      = tmp[number<2>{}];

                static_assert(found == true,
                              "wrong! not found matching transformation and upper-dimension");

                const auto length =
                    transforms[number<itran>{}].get_upper_lengths()[number<idim_up>{}];

                return length;
            },
            number<ndim_top_>{});

        // TODO: make container_reduce support tuple of number and index_t
        return container_reduce(lengths, multiplies{}, number<1>{});
    }

    // Find which transform produces hidden dimension IDimHidden, and at which
    // upper-dimension slot. Returns (transform index, upper-dim index, found).
    template <index_t IDimHidden>
    CK_TILE_HOST_DEVICE static constexpr auto
    get_transform_and_its_upper_dimension(number<IDimHidden>)
    {
        // FIXME: length of bottom dimension is not known, since info about lower dim length are not
        // saved in transformation
        static_assert(IDimHidden >= ndim_bottom_, "wrong! not implemented");

        index_t itran_found   = 0;
        index_t idim_up_found = 0;
        bool found            = false;

        static_for<0, ntransform_, 1>{}([&](auto itran) {
            constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran];

            static_for<0, up_dim_ids.size(), 1>{}([&](auto idim_up) {
                if constexpr(up_dim_ids[idim_up] == IDimHidden)
                {
                    itran_found   = itran;
                    idim_up_found = idim_up;
                    found         = true;
                }
            });
        });

        return make_tuple(itran_found, idim_up_found, found);
    }

    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_bottom_dimension()
    {
        return BottomDimensionHiddenIds::size();
    }

    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_top_dimension()
    {
        return TopDimensionHiddenIds::size();
    }

    // Count distinct hidden ids appearing on either the lower or upper side of
    // any transform.
    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_hidden_dimension()
    {
        constexpr auto all_low_dim_ids = unpack(
            [](auto&&... xs) constexpr { return merge_sequences(xs...); },
            LowerDimensionHiddenIdss{});

        constexpr auto all_up_dim_ids = unpack(
            [](auto&&... xs) constexpr { return merge_sequences(xs...); },
            UpperDimensionHiddenIdss{});

        constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);

        using unique_sort_all_dim_ids = typename sequence_unique_sort<decltype(all_dim_ids),
                                                                      less<index_t>,
                                                                      equal<index_t>>::type;

        return unique_sort_all_dim_ids::size();
    }

    constexpr static index_t ntransform_  = get_num_of_transform();
    constexpr static index_t ndim_hidden_ = get_num_of_hidden_dimension();
    constexpr static index_t ndim_bottom_ = get_num_of_bottom_dimension();
    constexpr static index_t ndim_top_    = get_num_of_top_dimension();

    using HiddenIndex = multi_index<ndim_hidden_>;
    using BottomIndex = multi_index<ndim_bottom_>;
    using TopIndex    = multi_index<ndim_top_>;

    // may be index_t or number<>
    using ElementSize = remove_cv_t<decltype(initialize_element_size(Transforms{}))>;

    public:
    CK_TILE_HOST_DEVICE constexpr tensor_adaptor() = default;

    CK_TILE_HOST_DEVICE constexpr tensor_adaptor(const Transforms& transforms)
        : transforms_{transforms}, element_size_{initialize_element_size(transforms)}
    {
        static_assert(Transforms::size() == ntransform_ &&
                          LowerDimensionHiddenIdss::size() == ntransform_ &&
                          UpperDimensionHiddenIdss::size() == ntransform_,
                      "wrong! inconsistent # of transformations");

        // TODO check dependency of dimensions is valid
    }

    CK_TILE_HOST_DEVICE constexpr auto get_element_size() const { return element_size_; }

    // FIXME: this logic is wrong when getting bottom dimension lengths
    template <index_t IDimHidden>
    CK_TILE_HOST_DEVICE constexpr auto get_hidden_dimension_length(number<IDimHidden>) const
    {
        static_assert(IDimHidden >= 0 && IDimHidden < ndim_hidden_, "wrong! out of range");

        constexpr auto tmp = get_transform_and_its_upper_dimension(number<IDimHidden>{});

        constexpr index_t itran   = tmp[number<0>{}];
        constexpr index_t idim_up = tmp[number<1>{}];
        constexpr bool found      = tmp[number<2>{}];

        static_assert(found == true,
                      "wrong! not found matching transformation and upper-dimension");

        return transforms_[number<itran>{}].get_upper_lengths()[number<idim_up>{}];
    }

    template <index_t IDimTop>
    CK_TILE_HOST_DEVICE constexpr auto get_top_dimension_length(number<IDimTop> idim_top) const
    {
        return get_hidden_dimension_length(TopDimensionHiddenIds::at(idim_top));
    }

#if 0
    // FIXME: get_hidden_dimension_length is wrong when getting bottom dimension lengths
    template <index_t IDimBottom>
    CK_TILE_HOST_DEVICE constexpr index_t
    get_bottom_dimension_length(number<IDimBottom> idim_bottom) const
    {
        return get_hidden_dimension_length(TopDimensionHiddenIds::at(idim_bottom));
    }
#endif

    CK_TILE_HOST_DEVICE constexpr auto get_top_dimension_lengths() const
    {
        return generate_tuple([&](auto i) { return get_top_dimension_length(i); },
                              number<ndim_top_>{});
    }

#if 0
    // FIXME: get_hidden_dimension_length is wrong when getting bottom dimension lengths
    CK_TILE_HOST_DEVICE constexpr auto GetBottomDimensionLengths() const
    {
        return generate_tuple([&](auto i) { return get_bottom_dimension_length(i); },
                              number<ndim_bottom_>{});
    }
#endif

    // Map a top (logical) index down to the bottom (physical) index by running
    // the transform chain in reverse registration order.
    template <typename TopIdx>
    CK_TILE_HOST_DEVICE constexpr auto calculate_bottom_index(const TopIdx& idx_top) const
    {
        static_assert(TopIdx::size() == TopDimensionHiddenIds::size(),
                      "wrong! # of dimension inconsistent");

        constexpr index_t ntransform  = get_num_of_transform();
        constexpr index_t ndim_hidden = get_num_of_hidden_dimension();

        multi_index<ndim_hidden> idx_hidden;

        // initialize uppest index
        set_container_subset(idx_hidden, get_top_dimension_hidden_ids(), idx_top);

        // calculate hidden index
        static_for<ntransform, 0, -1>{}([&](auto itran_p1) {
            auto itran              = itran_p1 - number<1>{};
            const auto& tran        = get_transforms().at(itran);
            constexpr auto dims_low = get_lower_dimension_hidden_idss().at(itran);
            constexpr auto dims_up  = get_upper_dimension_hidden_idss().at(itran);

            const auto idx_up = get_container_subset(idx_hidden, dims_up);

            multi_index<dims_low.size()> idx_low;

            tran.calculate_lower_index(idx_low, idx_up);

            set_container_subset(idx_hidden, dims_low, idx_low);
        });

        return get_container_subset(idx_hidden, BottomDimensionHiddenIds{});
    }

    // True when every transform (and the element size) is a compile-time value.
    CK_TILE_HOST_DEVICE static constexpr bool is_static()
    {
        bool is_known = true;

        static_for<0, Transforms::size(), 1>{}([&](auto i) {
            is_known &= remove_cvref_t<decltype(Transforms{}[i])>::is_known_at_compile_time();
        });

        return is_known && ck_tile::is_known_at_compile_time<ElementSize>::value;
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); }

    // Propagate per-dimension safe vectorization length/stride guarantees from
    // the bottom dimensions up through every transform; a guaranteed value of
    // -1 means "no guarantee, use what the transform computes".
    CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides(
        const array<index_t, ndim_hidden_>& guaranteed_vector_lengths,
        const array<index_t, ndim_hidden_>& guaranteed_vector_strides)
    {
        auto vector_lengths = guaranteed_vector_lengths;
        auto vector_strides = guaranteed_vector_strides;

        static_for<0, get_num_of_transform(), 1>{}([&](auto itran) {
            constexpr auto low_dims = get_lower_dimension_hidden_idss().at(itran);
            constexpr auto up_dims  = get_upper_dimension_hidden_idss().at(itran);

            const auto up_guaranteed_vector_lengths =
                get_container_subset(guaranteed_vector_lengths, up_dims);
            const auto up_guaranteed_vector_strides =
                get_container_subset(guaranteed_vector_strides, up_dims);

            // only need type of transform
            auto [up_vector_lengths, up_vector_strides] =
                Transforms{}.at(itran).calculate_upper_dimension_safe_vector_length_strides(
                    get_container_subset(vector_lengths, low_dims),
                    get_container_subset(vector_strides, low_dims));

            if constexpr(up_dims.size() > 0)
            {
                for(index_t i = 0; i < up_dims.size(); ++i)
                {
                    up_vector_lengths(i) = (up_guaranteed_vector_lengths[i] != -1)
                                               ? up_guaranteed_vector_lengths[i]
                                               : up_vector_lengths[i];

                    up_vector_strides(i) = (up_guaranteed_vector_strides[i] != -1)
                                               ? up_guaranteed_vector_strides[i]
                                               : up_vector_strides[i];
                }
            }

            set_container_subset(vector_lengths, up_dims, up_vector_lengths);
            set_container_subset(vector_strides, up_dims, up_vector_strides);
        });

        constexpr auto top_dims = TopDimensionHiddenIds{};

        return make_tuple(get_container_subset(vector_lengths, top_dims),
                          get_container_subset(vector_strides, top_dims));
    }

    // Debug dump of the adaptor's transforms and hidden-id wiring.
    CK_TILE_HOST_DEVICE void print() const
    {
        printf("tensor_adaptor{");

        //
        printf("transforms: ");
        print(transforms_);
        printf(", ");

        //
        printf("LowerDimensionHiddenIds: ");
        print(LowerDimensionHiddenIdss{});
        printf(", ");

        //
        printf("UpperDimensionHiddenIds: ");
        print(UpperDimensionHiddenIdss{});
        printf(", ");

        //
        printf("BottomDimensionHiddenIds: ");
        print(BottomDimensionHiddenIds{});
        printf(", ");

        //
        printf("TopDimensionHiddenIds: ");
        print(TopDimensionHiddenIds{});

        printf("}");
    }

    private:
    Transforms transforms_;
    ElementSize element_size_;
};
|
||||
|
||||
// Build a tensor_adaptor from a single stage of transforms. The old top
// dimensions become the bottom (hidden ids [0, ndim_old_top)), and the new
// top dimensions are assigned hidden ids shifted past them.
//
// Transforms: Tuple<transforms...>
// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>
template <typename Transforms, typename LowerDimensionOldTopIdss, typename UpperDimensionNewTopIdss>
CK_TILE_HOST_DEVICE constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms,
                                                                    LowerDimensionOldTopIdss,
                                                                    UpperDimensionNewTopIdss)
{
    constexpr index_t ntransform = Transforms::size();

    static_assert(LowerDimensionOldTopIdss::size() == ntransform &&
                      UpperDimensionNewTopIdss::size() == ntransform,
                  "wrong!");

    // sanity check on LowerDimensionOldTopIdss and UpperDimensionNewTopIdss
    constexpr auto all_low_dim_old_top_ids = unpack(
        [](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionOldTopIdss{});

    constexpr auto all_up_dim_new_top_ids = unpack(
        [](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionNewTopIdss{});

    // each side's ids must form a permutation (no duplicates, no gaps)
    static_assert(is_valid_sequence_map<decltype(all_low_dim_old_top_ids)>::value &&
                      is_valid_sequence_map<decltype(all_up_dim_new_top_ids)>::value,
                  "wrong!");

    constexpr index_t ndim_old_top = all_low_dim_old_top_ids.size();
    constexpr index_t ndim_new_top = all_up_dim_new_top_ids.size();

    // low_dim_hidden_idss: old top ids are already the hidden ids [0, ndim_old_top)
    constexpr auto low_dim_hidden_idss = LowerDimensionOldTopIdss{};

    // up_dim_hidden_idss: shift UpperDimensionNewTopIdss by ndim_bottom
    constexpr auto up_dim_hidden_idss = generate_tuple(
        [](auto itran) { return UpperDimensionNewTopIdss{}[itran] + number<ndim_old_top>{}; },
        number<ntransform>{});

    // bottom_dim_hidden_ids
    constexpr auto bottom_dim_hidden_ids =
        typename arithmetic_sequence_gen<0, ndim_old_top, 1>::type{};

    // top_dim_hidden_ids
    constexpr auto top_dim_hidden_ids =
        typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + number<ndim_old_top>{};

    return tensor_adaptor<remove_cvref_t<Transforms>,
                          remove_cvref_t<decltype(low_dim_hidden_idss)>,
                          remove_cvref_t<decltype(up_dim_hidden_idss)>,
                          remove_cvref_t<decltype(bottom_dim_hidden_ids)>,
                          remove_cvref_t<decltype(top_dim_hidden_ids)>>{transforms};
}
|
||||
|
||||
// TODO: How to fix this? It uses an struct instead of lambda because lambda
|
||||
// doesn't have constructor, and to put it outside the scope where it is used
|
||||
// (transform_tensor_adaptor) because template cannot be defined inside a function
|
||||
// template
|
||||
template <typename NewTransforms>
|
||||
struct lambda_get_up_dim_num
|
||||
{
|
||||
template <typename I>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(I) const
|
||||
{
|
||||
using Tran = remove_reference_t<decltype(NewTransforms{}.at(I{}))>;
|
||||
return number<Tran::get_num_of_upper_dimension()>{};
|
||||
}
|
||||
};
|
||||
|
||||
/// Compose an existing tensor_adaptor with an additional group of coordinate
/// transforms, producing a new adaptor.
///
/// The bottom dimensions of the result are those of the old adaptor; the top
/// dimensions of the result are the upper dimensions of the new transforms.
///
/// @tparam OldTensorAdaptor            adaptor being extended
/// @tparam NewTransforms               tuple of coordinate transforms to append
/// @tparam NewLowerDimensionOldTopIdss tuple of sequences; the i-th sequence lists
///                                     which top dims of the old adaptor feed the
///                                     lower dims of the i-th new transform
/// @tparam NewUpperDimensionNewTopIdss tuple of sequences; the i-th sequence lists
///                                     where the i-th transform's upper dims land
///                                     in the new adaptor's top index
template <typename OldTensorAdaptor,
          typename NewTransforms,
          typename NewLowerDimensionOldTopIdss,
          typename NewUpperDimensionNewTopIdss>
CK_TILE_HOST_DEVICE constexpr auto
transform_tensor_adaptor(const OldTensorAdaptor& old_tensor_adaptor,
                         const NewTransforms& new_transforms,
                         NewLowerDimensionOldTopIdss,
                         NewUpperDimensionNewTopIdss)
{
    // sanity check
    {
        static_assert(NewTransforms::size() == NewLowerDimensionOldTopIdss::size() &&
                          NewTransforms::size() == NewUpperDimensionNewTopIdss::size(),
                      "wrong! inconsitent number of transform");

        constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
                                                NewLowerDimensionOldTopIdss{});

        constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
                                                NewUpperDimensionNewTopIdss{});

        // the flattened old-top and new-top id lists must each be a valid
        // permutation (every dimension referenced exactly once)
        static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
                          is_valid_sequence_map<decltype(all_new_top_ids)>::value,
                      "wrong!");
    }

    // lower dimension's hidden idss
    // convert lower dimension top idss (tuple of sequences) to hidden idss (tuple of
    // sequences)
    constexpr auto low_dim_hidden_idss = transform_tuples(
        // convert lower dimension top ids (a sequence) to hidden ids (a sequence)
        [](auto low_dim_top_ids) constexpr {
            return transform_sequences(
                // convert lower dimension top id to hidden id
                [](auto low_dim_top_id) constexpr {
                    return OldTensorAdaptor::get_top_dimension_hidden_ids()[low_dim_top_id];
                },
                low_dim_top_ids);
        },
        NewLowerDimensionOldTopIdss{});

    constexpr index_t num_new_transform = NewTransforms::size();

    // upper dimension's hidden idss
    // each new transform introduces fresh hidden dimensions, numbered
    // consecutively after the old adaptor's hidden dimensions
    constexpr index_t old_hidden_dim_number = OldTensorAdaptor::get_num_of_hidden_dimension();

    constexpr auto up_dim_numbers =
        generate_sequence(lambda_get_up_dim_num<NewTransforms>{}, number<num_new_transform>{});

    // prefix sum (with a leading 0): up_dim_numbers_scan[i] is the offset of
    // transform i's first new hidden dimension
    constexpr auto up_dim_numbers_scan = merge_sequences(
        Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, plus<index_t>{}, number<0>{}));

    constexpr auto up_dim_hidden_idss = generate_tuple(
        [ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr {
            return
                typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
                                                 old_hidden_dim_number + up_dim_numbers_scan[i + 1],
                                                 1>::type{};
        },
        number<num_new_transform>{});

    // new top dimension's hidden ids
    constexpr auto unordered_new_top_dim_hidden_ids = unpack(
        [](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);

    // reorder the freshly created hidden ids into the caller-requested top order
    constexpr auto new_top_dim_unordered2ordered = unpack(
        [](auto... xs) constexpr { return merge_sequences(xs...); }, NewUpperDimensionNewTopIdss{});

    constexpr auto new_top_dim_hidden_ids =
        unordered_new_top_dim_hidden_ids.reorder_old_to_new(new_top_dim_unordered2ordered);

    // put everything together
    const auto all_transforms =
        container_concat(old_tensor_adaptor.get_transforms(), new_transforms);

    constexpr auto all_low_dim_hidden_idss =
        container_concat(OldTensorAdaptor::get_lower_dimension_hidden_idss(), low_dim_hidden_idss);

    constexpr auto all_up_dim_hidden_idss =
        container_concat(OldTensorAdaptor::get_upper_dimension_hidden_idss(), up_dim_hidden_idss);

    return tensor_adaptor<
        remove_cvref_t<decltype(all_transforms)>,
        remove_cvref_t<decltype(all_low_dim_hidden_idss)>,
        remove_cvref_t<decltype(all_up_dim_hidden_idss)>,
        remove_cvref_t<decltype(OldTensorAdaptor::get_bottom_dimension_hidden_ids())>,
        remove_cvref_t<decltype(new_top_dim_hidden_ids)>>{all_transforms};
}
|
||||
|
||||
/// Chain two adaptors: the top dimensions of adaptor0 are fed into the bottom
/// dimensions of adaptor1, yielding a single adaptor from adaptor0's bottom to
/// adaptor1's top.
///
/// Hidden-dimension ids of adaptor1 are shifted so they do not collide with
/// adaptor0's ids, except for adaptor1's bottom ids, which are rewritten to
/// match adaptor0's top ids (that is the actual "chaining").
template <typename TensorAdaptor0, typename TensorAdaptor1>
CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& adaptor0,
                                                         const TensorAdaptor1& adaptor1)
{
    static_assert(TensorAdaptor0::get_num_of_top_dimension() ==
                      TensorAdaptor1::get_num_of_bottom_dimension(),
                  "wrong!");

    // all_transforms = transform0 + transform1
    const auto all_transforms =
        container_concat(adaptor0.get_transforms(), adaptor1.get_transforms());

    // shift
    // largest hidden-dimension id used anywhere in adaptor0
    constexpr index_t adaptor0_max_hidden_id = [&]() {
        index_t adaptor0_max_hidden_id_ = NumericLimits<index_t>::Min();

        static_for<0, TensorAdaptor0::get_num_of_transform(), 1>{}([&](auto itran) {
            constexpr index_t ndim_low =
                TensorAdaptor0{}.get_transforms()[itran].get_num_of_lower_dimension();

            static_for<0, ndim_low, 1>{}([&](auto idim_low) {
                adaptor0_max_hidden_id_ =
                    max(adaptor0_max_hidden_id_,
                        TensorAdaptor0::get_lower_dimension_hidden_idss()[itran][idim_low].value);
            });

            constexpr index_t ndim_up =
                TensorAdaptor0{}.get_transforms()[itran].get_num_of_upper_dimension();

            static_for<0, ndim_up, 1>{}([&](auto idim_up) {
                adaptor0_max_hidden_id_ =
                    max(adaptor0_max_hidden_id_,
                        TensorAdaptor0::get_upper_dimension_hidden_idss()[itran][idim_up].value);
            });
        });

        return adaptor0_max_hidden_id_;
    }();

    // smallest hidden-dimension id used in adaptor1, excluding its bottom dims
    constexpr index_t adaptor1_min_hidden_id = [&]() {
        index_t adaptor1_min_hidden_id_ = NumericLimits<index_t>::Max();

        static_for<0, TensorAdaptor1::get_num_of_transform(), 1>{}([&](auto itran) {
            constexpr index_t ndim_low =
                TensorAdaptor1{}.get_transforms()[itran].get_num_of_lower_dimension();

            // get the min of all lower dimenions, but not bottom dimension (because their id will
            // be matched with top id from adaptor0)
            static_for<0, ndim_low, 1>{}([&](auto idim_low) {
                constexpr index_t low_dim_hidden_id =
                    TensorAdaptor1::get_lower_dimension_hidden_idss()[itran][idim_low].value;

                bool is_bottom_dim = false;
                static_for<0, TensorAdaptor1::get_num_of_bottom_dimension(), 1>{}([&](auto i) {
                    if constexpr(low_dim_hidden_id ==
                                 TensorAdaptor1::get_bottom_dimension_hidden_ids()[i])
                    {
                        is_bottom_dim = true;
                    }
                });

                if(!is_bottom_dim)
                {
                    adaptor1_min_hidden_id_ = min(adaptor1_min_hidden_id_, low_dim_hidden_id);
                }
            });

            constexpr index_t ndim_up =
                TensorAdaptor1{}.get_transforms()[itran].get_num_of_upper_dimension();

            // get the min of all upper dimensions
            static_for<0, ndim_up, 1>{}([&](auto idim_up) {
                adaptor1_min_hidden_id_ =
                    min(adaptor1_min_hidden_id_,
                        TensorAdaptor1::get_upper_dimension_hidden_idss()[itran][idim_up].value);
            });
        });

        return adaptor1_min_hidden_id_;
    }();

    // shift that relocates adaptor1's (non-bottom) hidden ids just past
    // adaptor0's largest id, guaranteeing uniqueness
    constexpr index_t adaptor1_hidden_id_shift =
        adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id;

    constexpr index_t ndim_bottom_1 = TensorAdaptor1::get_num_of_bottom_dimension();

    // all_low_dim_hidden_idss =
    // low_dim_hidden_idss_0 + match_hidden_id_for_1(shift_hidden_id_for_1(low_dim_hiden_idss_1))
    constexpr auto low_dim_hidden_idss_1 = generate_tuple(
        // generate sequence of ids for a transform
        [&](auto itran) {
            constexpr auto ndim_low_1 =
                TensorAdaptor1::get_lower_dimension_hidden_idss()[itran].size();

            constexpr auto low_dim_hidden_ids_1 =
                TensorAdaptor1::get_lower_dimension_hidden_idss()[itran];

            // sequence in, sequence out
            constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr
            {
                auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1);

                // shift hidden id so every dim id is unique
                static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
                    low_dim_hidden_ids_1_mod_(idim_low_1) += adaptor1_hidden_id_shift;
                });

                // match hidden id
                static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
                    static_for<0, ndim_bottom_1, 1>{}([&](auto idim_bottom_1) {
                        // if this low dim is bottom dim, then do id matching
                        if constexpr(low_dim_hidden_ids_1[idim_low_1] ==
                                     TensorAdaptor1::get_bottom_dimension_hidden_ids()
                                         [idim_bottom_1])
                        {
                            low_dim_hidden_ids_1_mod_(idim_low_1) =
                                TensorAdaptor0::get_top_dimension_hidden_ids()[idim_bottom_1];
                        }
                    });
                });

                return low_dim_hidden_ids_1_mod_;
            }
            ();

            return generate_sequence_v2(
                [&](auto i) constexpr { return number<low_dim_hidden_ids_1_mod[i]>{}; },
                number<ndim_low_1>{});
        },
        number<TensorAdaptor1::get_num_of_transform()>{});

    constexpr auto all_low_dim_hidden_idss =
        container_concat(TensorAdaptor0::get_lower_dimension_hidden_idss(), low_dim_hidden_idss_1);

    // all_up_dim_hidden_idss =
    // up_dim_hidden_idss_0 + shift_hidden_id_for_1(up_dim_hiden_idss_1)
    constexpr auto up_dim_hidden_idss_1 = generate_tuple(
        // generate sequence of ids for a transform
        [&](auto itran) {
            constexpr auto ndim_up_1 =
                TensorAdaptor1::get_upper_dimension_hidden_idss()[itran].size();

            constexpr auto up_dim_hidden_ids_1 =
                TensorAdaptor1::get_upper_dimension_hidden_idss()[itran];

            // sequence in, constexpr tuple out
            constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr
            {
                auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1);

                // shift hidden id
                static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) {
                    up_dim_hidden_ids_1_mod_(idim_up_1) += adaptor1_hidden_id_shift;
                });

                return up_dim_hidden_ids_1_mod_;
            }
            ();

            // constexpr tuple to sequence
            return generate_sequence_v2(
                [&](auto i) constexpr { return number<up_dim_hidden_ids_1_mod[i]>{}; },
                number<ndim_up_1>{});
        },
        number<TensorAdaptor1::get_num_of_transform()>{});

    constexpr auto all_up_dim_hidden_idss =
        container_concat(TensorAdaptor0::get_upper_dimension_hidden_idss(), up_dim_hidden_idss_1);

    // bottom_dim_hidden_ids = bottom_dim_hidden_ids_0
    constexpr auto bottom_dim_hidden_ids = TensorAdaptor0::get_bottom_dimension_hidden_ids();

    // top_dim_hidden_ids = shift_hidden_id(top_dim_hidden_ids_1)
    constexpr auto top_dim_hidden_ids =
        TensorAdaptor1::get_top_dimension_hidden_ids() + number<adaptor1_hidden_id_shift>{};

    // put everything together
    return tensor_adaptor<remove_cvref_t<decltype(all_transforms)>,
                          remove_cvref_t<decltype(all_low_dim_hidden_idss)>,
                          remove_cvref_t<decltype(all_up_dim_hidden_idss)>,
                          remove_cvref_t<decltype(bottom_dim_hidden_ids)>,
                          remove_cvref_t<decltype(top_dim_hidden_ids)>>{all_transforms};
}
|
||||
|
||||
template <typename X, typename... Xs, typename enable_if<sizeof...(Xs) >= 2, bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
|
||||
{
|
||||
return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
// Macro function
// Construct a (runtime-valued) constexpr tensor_adaptor from a constexpr encoding.
//
// encoded_tensor_adaptor is a Tuple of the following objects:
// 1. encoded transforms (array of fixed size). Each encoded transform is a Tuple of following:
//    1.1 name (cood_transform_enum)
//    1.2 meta data for constructor of the transform (read via pop<T>(pos), which
//        advances the cursor "pos" past each decoded value)
//    1.3 num of lower dimension (index_t)
//    1.4 lower dimension Ids (array of fixed size)
//    1.5 num of up dimension (index_t)
//    1.6 upper dimension Ids (array of fixed size)
// 2. num of transforms (index_t)
// 3. encoded bottom dimension Ids (array of fixed size)
// 4. num of bottom dimension (index_t)
// 5. encoded top dimension Ids (array of fixed size)
// 6. num of top dimension (index_t)
//
// NOTE: this must be a macro (not a function) because TO_SEQUENCE and the
// per-transform decoding need the encoding to be constexpr at the call site.
// No // comments inside: a line comment would swallow the trailing backslash.
#define CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor)                            \
    [encoded_tensor_adaptor]() {                                                                  \
        using namespace ck_tile;                                                                  \
                                                                                                  \
        constexpr auto encoded_transforms  = encoded_tensor_adaptor.template at<0>();             \
        constexpr index_t num_transform    = encoded_tensor_adaptor.template at<1>();             \
        constexpr auto encoded_bottom_dims = encoded_tensor_adaptor.template at<2>();             \
        constexpr index_t num_bottom_dim   = encoded_tensor_adaptor.template at<3>();             \
        constexpr auto encoded_top_dims    = encoded_tensor_adaptor.template at<4>();             \
        constexpr index_t num_top_dim      = encoded_tensor_adaptor.template at<5>();             \
                                                                                                  \
        constexpr auto trans = [&encoded_transforms, &num_transform]() {                          \
            return generate_tuple(                                                                \
                [&encoded_transforms](auto i) constexpr {                                         \
                    constexpr auto name        = encoded_transforms[i].template at<0>();          \
                    constexpr auto meta_data   = encoded_transforms[i].template at<1>();          \
                    constexpr auto num_low_dim = encoded_transforms[i].template at<2>();          \
                    constexpr auto num_up_dim  = encoded_transforms[i].template at<4>();          \
                                                                                                  \
                    STATIC_ASSERT(name == cood_transform_enum::PassThrough ||                     \
                                      name == cood_transform_enum::pad ||                         \
                                      name == cood_transform_enum::embed ||                       \
                                      name == cood_transform_enum::merge ||                       \
                                      name == cood_transform_enum::unmerge ||                     \
                                      name == cood_transform_enum::replicate,                     \
                                  "");                                                            \
                                                                                                  \
                    if constexpr(name == cood_transform_enum::PassThrough)                        \
                    {                                                                             \
                        index_t pos  = 0;                                                         \
                        auto low_len = meta_data.template pop<index_t>(pos);                      \
                                                                                                  \
                        return make_pass_through_transform(low_len);                              \
                    }                                                                             \
                    else if constexpr(name == cood_transform_enum::pad)                           \
                    {                                                                             \
                        index_t pos    = 0;                                                       \
                        auto low_len   = meta_data.template pop<index_t>(pos);                    \
                        auto left_pad  = meta_data.template pop<index_t>(pos);                    \
                        auto right_pad = meta_data.template pop<index_t>(pos);                    \
                                                                                                  \
                        return make_pad_transform(low_len, left_pad, right_pad);                  \
                    }                                                                             \
                    else if constexpr(name == cood_transform_enum::embed)                         \
                    {                                                                             \
                        index_t pos  = 0;                                                         \
                        auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos);   \
                        auto coefficients =                                                       \
                            meta_data.template pop<array<index_t, num_up_dim>>(pos);              \
                                                                                                  \
                        return make_embed_transform(up_lens, coefficients);                       \
                    }                                                                             \
                    else if constexpr(name == cood_transform_enum::merge)                         \
                    {                                                                             \
                        index_t pos   = 0;                                                        \
                        auto low_lens = meta_data.template pop<array<index_t, num_low_dim>>(pos); \
                                                                                                  \
                        return make_merge_transform(low_lens);                                    \
                    }                                                                             \
                    else if constexpr(name == cood_transform_enum::unmerge)                       \
                    {                                                                             \
                        index_t pos  = 0;                                                         \
                        auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos);   \
                                                                                                  \
                        return make_unmerge_transform(up_lens);                                   \
                    }                                                                             \
                    else if constexpr(name == cood_transform_enum::replicate)                     \
                    {                                                                             \
                        index_t pos  = 0;                                                         \
                        auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos);   \
                                                                                                  \
                        return make_replicate_transform(up_lens);                                 \
                    }                                                                             \
                },                                                                                \
                number<num_transform>{});                                                         \
        }();                                                                                      \
                                                                                                  \
        constexpr auto low_dim_idss = [&encoded_transforms, &num_transform]() {                   \
            return generate_tuple(                                                                \
                [&encoded_transforms](auto i) {                                                   \
                    constexpr auto num_low_dim = encoded_transforms[i].template at<2>();          \
                    constexpr auto low_dims    = encoded_transforms[i].template at<3>();          \
                                                                                                  \
                    return TO_SEQUENCE(low_dims, num_low_dim);                                    \
                },                                                                                \
                number<num_transform>());                                                         \
        }();                                                                                      \
                                                                                                  \
        constexpr auto up_dim_idss = [&encoded_transforms, &num_transform] {                      \
            return generate_tuple(                                                                \
                [&encoded_transforms](auto i) {                                                   \
                    constexpr auto num_up_dim = encoded_transforms[i].template at<4>();           \
                    constexpr auto up_dims    = encoded_transforms[i].template at<5>();           \
                                                                                                  \
                    return TO_SEQUENCE(up_dims, num_up_dim);                                      \
                },                                                                                \
                number<num_transform>());                                                         \
        }();                                                                                      \
                                                                                                  \
        constexpr auto bottom_dim_ids = TO_SEQUENCE(encoded_bottom_dims, num_bottom_dim);         \
        constexpr auto top_dim_ids    = TO_SEQUENCE(encoded_top_dims, num_top_dim);               \
                                                                                                  \
        return tensor_adaptor<remove_cvref_t<decltype(trans)>,                                    \
                              remove_cvref_t<decltype(low_dim_idss)>,                             \
                              remove_cvref_t<decltype(up_dim_idss)>,                              \
                              remove_cvref_t<decltype(bottom_dim_ids)>,                           \
                              remove_cvref_t<decltype(top_dim_ids)>>{trans};                      \
    }()
|
||||
|
||||
// Macro function
// Construct a fully static (compile-time-length) tensor_adaptor from a constexpr encoding.
//
// encoded_tensor_adaptor is a Tuple of the following objects:
// 1. encoded transforms (array of fixed size). Each encoded transform is a Tuple of following:
//    1.1 name (cood_transform_enum)
//    1.2 meta data for constructor of the transform (read via get<T>(byte_offset),
//        a pure read at a fixed offset — unlike the cursor-advancing pop<T>(pos)
//        used by CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING above)
//    1.3 num of lower dimension (index_t)
//    1.4 lower dimension Ids (array of fixed size)
//    1.5 num of up dimension (index_t)
//    1.6 upper dimension Ids (array of fixed size)
// 2. num of transforms (index_t)
// 3. encoded bottom dimension Ids (array of fixed size)
// 4. num of bottom dimension (index_t)
// 5. encoded top dimension Ids (array of fixed size)
// 6. num of top dimension (index_t)
//
// Fix: the pad branch previously read right_pad with pop<index_t>(offset); pop
// expects a mutable cursor (passed by reference), so an rvalue byte offset is
// the wrong call — get<index_t>(offset) matches every other read in this macro.
// No // comments inside the macro: a line comment would swallow the trailing backslash.
#define CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor)                 \
    [encoded_tensor_adaptor]() {                                                              \
        using namespace ck_tile;                                                              \
                                                                                              \
        constexpr auto encoded_transforms  = encoded_tensor_adaptor.template at<0>();         \
        constexpr index_t num_transform    = encoded_tensor_adaptor.template at<1>();         \
        constexpr auto encoded_bottom_dims = encoded_tensor_adaptor.template at<2>();         \
        constexpr index_t num_bottom_dim   = encoded_tensor_adaptor.template at<3>();         \
        constexpr auto encoded_top_dims    = encoded_tensor_adaptor.template at<4>();         \
        constexpr index_t num_top_dim      = encoded_tensor_adaptor.template at<5>();         \
                                                                                              \
        constexpr auto trans = [&encoded_transforms, &num_transform]() {                      \
            return generate_tuple(                                                            \
                [&encoded_transforms](auto i) constexpr {                                     \
                    constexpr auto name        = encoded_transforms[i].template at<0>();      \
                    constexpr auto meta_data   = encoded_transforms[i].template at<1>();      \
                    constexpr auto num_low_dim = encoded_transforms[i].template at<2>();      \
                    constexpr auto num_up_dim  = encoded_transforms[i].template at<4>();      \
                                                                                              \
                    STATIC_ASSERT(name == cood_transform_enum::PassThrough ||                 \
                                      name == cood_transform_enum::pad ||                     \
                                      name == cood_transform_enum::embed ||                   \
                                      name == cood_transform_enum::merge ||                   \
                                      name == cood_transform_enum::unmerge ||                 \
                                      name == cood_transform_enum::replicate,                 \
                                  "");                                                        \
                                                                                              \
                    if constexpr(name == cood_transform_enum::PassThrough)                    \
                    {                                                                         \
                        constexpr index_t low_len = meta_data.template get<index_t>(0);       \
                                                                                              \
                        return make_pass_through_transform(number<low_len>{});                \
                    }                                                                         \
                    else if constexpr(name == cood_transform_enum::pad)                       \
                    {                                                                         \
                        constexpr index_t low_len = meta_data.template get<index_t>(0);       \
                                                                                              \
                        constexpr index_t left_pad =                                          \
                            meta_data.template get<index_t>(sizeof(low_len));                 \
                                                                                              \
                        constexpr index_t right_pad =                                         \
                            meta_data.template get<index_t>(sizeof(low_len) +                 \
                                                            sizeof(left_pad));               \
                                                                                              \
                        return make_pad_transform(                                            \
                            number<low_len>{}, number<left_pad>{}, number<right_pad>{});      \
                    }                                                                         \
                    else if constexpr(name == cood_transform_enum::embed)                     \
                    {                                                                         \
                        constexpr auto up_lens =                                              \
                            meta_data.template get<array<index_t, num_up_dim>>(0);            \
                                                                                              \
                        constexpr auto coefficients =                                         \
                            meta_data.template get<array<index_t, num_up_dim>>(sizeof(up_lens)); \
                                                                                              \
                        return make_embed_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim),  \
                                                    TO_TUPLE_OF_NUMBER(coefficients, num_up_dim)); \
                    }                                                                         \
                    else if constexpr(name == cood_transform_enum::merge)                     \
                    {                                                                         \
                        constexpr auto low_lens =                                             \
                            meta_data.template get<array<index_t, num_low_dim>>(0);           \
                                                                                              \
                        return make_merge_transform(TO_TUPLE_OF_NUMBER(low_lens, num_low_dim)); \
                    }                                                                         \
                    else if constexpr(name == cood_transform_enum::unmerge)                   \
                    {                                                                         \
                        constexpr auto up_lens =                                              \
                            meta_data.template get<array<index_t, num_up_dim>>(0);            \
                                                                                              \
                        return make_unmerge_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim)); \
                    }                                                                         \
                    else if constexpr(name == cood_transform_enum::replicate)                 \
                    {                                                                         \
                        constexpr auto up_lens =                                              \
                            meta_data.template get<array<index_t, num_up_dim>>(0);            \
                                                                                              \
                        return make_replicate_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim)); \
                    }                                                                         \
                },                                                                            \
                number<num_transform>{});                                                     \
        }();                                                                                  \
                                                                                              \
        constexpr auto low_dim_idss = [&encoded_transforms, &num_transform]() {               \
            return generate_tuple(                                                            \
                [&encoded_transforms](auto i) {                                               \
                    constexpr auto num_low_dim = encoded_transforms[i].template at<2>();      \
                    constexpr auto low_dims    = encoded_transforms[i].template at<3>();      \
                                                                                              \
                    return TO_SEQUENCE(low_dims, num_low_dim);                                \
                },                                                                            \
                number<num_transform>());                                                     \
        }();                                                                                  \
                                                                                              \
        constexpr auto up_dim_idss = [&encoded_transforms, &num_transform] {                  \
            return generate_tuple(                                                            \
                [&encoded_transforms](auto i) {                                               \
                    constexpr auto num_up_dim = encoded_transforms[i].template at<4>();       \
                    constexpr auto up_dims    = encoded_transforms[i].template at<5>();       \
                                                                                              \
                    return TO_SEQUENCE(up_dims, num_up_dim);                                  \
                },                                                                            \
                number<num_transform>());                                                     \
        }();                                                                                  \
                                                                                              \
        constexpr auto bottom_dim_ids = TO_SEQUENCE(encoded_bottom_dims, num_bottom_dim);     \
        constexpr auto top_dim_ids    = TO_SEQUENCE(encoded_top_dims, num_top_dim);           \
                                                                                              \
        return tensor_adaptor<remove_cvref_t<decltype(trans)>,                                \
                              remove_cvref_t<decltype(low_dim_idss)>,                         \
                              remove_cvref_t<decltype(up_dim_idss)>,                          \
                              remove_cvref_t<decltype(bottom_dim_ids)>,                       \
                              remove_cvref_t<decltype(top_dim_ids)>>{trans};                  \
    }()
|
||||
257
include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp
Normal file
257
include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp
Normal file
@@ -0,0 +1,257 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/multi_index.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// A coordinate through a tensor_adaptor.
///
/// Stores the complete hidden index (one component per hidden dimension of the
/// adaptor); the bottom and top indices are subsets of it, selected by the
/// corresponding hidden-dimension id sequences.
template <index_t NDimHidden, typename BottomDimensionHiddenIds, typename TopDimensionHiddenIds>
struct tensor_adaptor_coordinate
{
    static constexpr index_t ndim_bottom_ = BottomDimensionHiddenIds::size();
    static constexpr index_t ndim_top_    = TopDimensionHiddenIds::size();

    using HiddenIndex = multi_index<NDimHidden>;
    using BottomIndex = multi_index<ndim_bottom_>;
    using TopIndex    = multi_index<ndim_top_>;

    public:
    CK_TILE_HOST_DEVICE constexpr tensor_adaptor_coordinate() = default;

    // construct from a fully populated hidden index (see make_tensor_adaptor_coordinate)
    CK_TILE_HOST_DEVICE constexpr tensor_adaptor_coordinate(const HiddenIndex& idx_hidden)
        : idx_hidden_{idx_hidden}
    {
    }

    // top index = hidden-index components selected by TopDimensionHiddenIds (returns a copy)
    CK_TILE_HOST_DEVICE constexpr auto get_top_index() const
    {
        return get_container_subset(idx_hidden_, TopDimensionHiddenIds{});
    }

    // bottom index = hidden-index components selected by BottomDimensionHiddenIds (returns a copy)
    CK_TILE_HOST_DEVICE constexpr auto get_bottom_index() const
    {
        return get_container_subset(idx_hidden_, BottomDimensionHiddenIds{});
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_hidden_index() const { return idx_hidden_; }

    // mutable access used by move_tensor_adaptor_coordinate to update in place
    CK_TILE_HOST_DEVICE constexpr auto& get_hidden_index() { return idx_hidden_; }

    // the full per-hidden-dimension index
    HiddenIndex idx_hidden_;
};
|
||||
|
||||
/// Build a tensor_adaptor_coordinate from a top index by running every
/// transform of the adaptor from top to bottom (last transform first),
/// filling in all hidden-index components along the way.
template <typename Adaptor, typename TopIndex>
CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate(const Adaptor& adaptor,
                                                                  const TopIndex& idx_top)
{
    static_assert(Adaptor::get_num_of_top_dimension() == TopIndex::size(),
                  "wrong! # of dimension inconsistent");

    constexpr index_t ntransform  = Adaptor::get_num_of_transform();
    constexpr index_t ndim_hidden = Adaptor::get_num_of_hidden_dimension();
    constexpr auto bottom_dim_ids = Adaptor::get_bottom_dimension_hidden_ids();
    constexpr auto top_dim_ids    = Adaptor::get_top_dimension_hidden_ids();

    multi_index<ndim_hidden> idx_hidden;

    // initialize visible index
    set_container_subset(idx_hidden, top_dim_ids, idx_top);

    // calculate hidden index
    // iterate transforms in reverse (ntransform-1 down to 0): each transform
    // maps its upper-dimension components to its lower-dimension components
    static_for<ntransform, 0, -1>{}([&adaptor, &idx_hidden](auto itran_p1) {
        auto itran              = itran_p1 - number<1>{};
        const auto& tran        = adaptor.get_transforms().at(itran);
        constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
        constexpr auto dims_up  = Adaptor::get_upper_dimension_hidden_idss().at(itran);

        const auto idx_up = get_container_subset(idx_hidden, dims_up);

        multi_index<dims_low.size()> idx_low;

        tran.calculate_lower_index(idx_low, idx_up);

        set_container_subset(idx_hidden, dims_low, idx_low);
    });

    return tensor_adaptor_coordinate<ndim_hidden,
                                     remove_cvref_t<decltype(bottom_dim_ids)>,
                                     remove_cvref_t<decltype(top_dim_ids)>>{idx_hidden};
}
|
||||
|
||||
/// Move a coordinate by a top-index delta, updating the full hidden index in
/// place and reporting the induced bottom-index delta.
///
/// @tparam JudgeDoTransforms when true, first figure out which transforms are
///         actually affected by the (possibly sparse) top delta and skip the
///         rest; when false, unconditionally re-run every transform.
/// @param adaptor          adaptor describing the transform graph
/// @param coord            coordinate to update (modified in place)
/// @param idx_diff_top     delta applied to the top index
/// @param idx_diff_bottom  out-param receiving the resulting bottom-index delta
template <bool JudgeDoTransforms = true,
          typename Adaptor,
          typename AdaptorCoord,
          typename TopIndex,
          typename BottomIndex>
CK_TILE_HOST_DEVICE constexpr void move_tensor_adaptor_coordinate(const Adaptor& adaptor,
                                                                  AdaptorCoord& coord,
                                                                  const TopIndex& idx_diff_top,
                                                                  BottomIndex& idx_diff_bottom)
{
    constexpr index_t ndim_hidden = Adaptor::get_num_of_hidden_dimension();
    constexpr index_t ndim_top    = Adaptor::get_num_of_top_dimension();
    // constexpr index_t ndim_bottom = Adaptor::get_num_of_bottom_dimension();
    constexpr index_t ntransform = Adaptor::get_num_of_transform();

    // STATIC_ASSERT(TopIndex::size() == ndim_top && BottomIndex::size() == ndim_bottom, "");

    // judge whether calculation of lower diff is needed for each transform
    // use index_t for boolean type
    auto do_transforms = make_zero_multi_index<ntransform>();

    if constexpr(JudgeDoTransforms)
    {
        auto is_non_zero_diff = make_zero_multi_index<ndim_hidden>();

        // decide do_transform by checking for non-zero index diff components
        multi_index<ndim_top> non_zero_diff_pick_top;

        static_for<0, ndim_top, 1>{}(
            [&](auto i) { non_zero_diff_pick_top(i) = (idx_diff_top[i] != 0); });

        set_container_subset(
            is_non_zero_diff, Adaptor::get_top_dimension_hidden_ids(), non_zero_diff_pick_top);

        // propagate the "has a non-zero diff" flag from top to bottom through
        // the transform graph
        static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
            constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
            constexpr auto dims_up  = Adaptor::get_upper_dimension_hidden_idss().at(itran);

            const auto non_zero_diff_pick_up = get_container_subset(is_non_zero_diff, dims_up);

            multi_index<dims_low.size()> non_zero_diff_pick_low;

            // if any of upper index diff components is non-zero, then
            // 1) Need to do this transform
            // 2) all components of lower index diff will assume to be non-zero and need to be
            // computed
            const bool idx_diff_up_has_non_zero = container_reduce(
                non_zero_diff_pick_up, [](auto a, auto b) constexpr { return a or b; }, false);

            do_transforms(itran) = idx_diff_up_has_non_zero;

            static_for<0, dims_low.size(), 1>{}(
                [&](auto i) { non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero; });

            set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
        });
    }
    else
    {
        // no judging: run every transform
        static_for<ntransform - 1, -1, -1>{}([&](auto itran) { do_transforms(itran) = 1; });
    }

    // this is what needs to be calculated
    auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();

    // initialize top index diff
    set_container_subset(idx_diff_hidden, Adaptor::get_top_dimension_hidden_ids(), idx_diff_top);

    // this is what needs to be updated
    auto& idx_hidden = coord.get_hidden_index();

    // update top index
    auto idx_hidden_pick_top =
        get_container_subset(idx_hidden, Adaptor::get_top_dimension_hidden_ids());

    idx_hidden_pick_top += idx_diff_top;

    set_container_subset(idx_hidden, Adaptor::get_top_dimension_hidden_ids(), idx_hidden_pick_top);

    // update rest of hidden index
    static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
        if(do_transforms[itran])
        {
            const auto& tran        = adaptor.get_transforms().at(itran);
            constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
            constexpr auto dims_up  = Adaptor::get_upper_dimension_hidden_idss().at(itran);

            // note: idx_up_new is read after the top update above, so it holds
            // the already-moved upper index
            const auto idx_up_new  = get_container_subset(idx_hidden, dims_up);
            auto idx_low           = get_container_subset(idx_hidden, dims_low);
            const auto idx_diff_up = get_container_subset(idx_diff_hidden, dims_up);

            multi_index<dims_low.size()> idx_diff_low;

            tran.update_lower_index(idx_diff_low, idx_diff_up, idx_low, idx_up_new);

            set_container_subset(idx_diff_hidden, dims_low, idx_diff_low);
            set_container_subset(idx_hidden, dims_low, idx_low);
        }
    });

    // set bottom index diff
    idx_diff_bottom =
        get_container_subset(idx_diff_hidden, Adaptor::get_bottom_dimension_hidden_ids());
}
|
||||
|
||||
template <bool JudgeDoTransforms = true, typename Adaptor, typename AdaptorCoord, typename TopIndex>
|
||||
CK_TILE_HOST_DEVICE constexpr void move_tensor_adaptor_coordinate(const Adaptor& adaptor,
|
||||
AdaptorCoord& coord,
|
||||
const TopIndex& idx_diff_top)
|
||||
{
|
||||
constexpr index_t ndim_bottom = Adaptor::get_num_of_bottom_dimension();
|
||||
|
||||
multi_index<ndim_bottom> tmp;
|
||||
|
||||
move_tensor_adaptor_coordinate<JudgeDoTransforms>(adaptor, coord, idx_diff_top, tmp);
|
||||
}
|
||||
|
||||
/// Check whether a coordinate maps through every transform to a valid lower
/// index, assuming the top index itself is already known to be in range.
///
/// Transforms whose mapping is always valid (e.g. no padding) are skipped at
/// compile time, so this is free for adaptors built only from such transforms.
template <typename Adaptor, typename AdaptorCoord>
CK_TILE_HOST_DEVICE constexpr bool
adaptor_coordinate_is_valid_assuming_top_index_is_valid(const Adaptor& adaptor,
                                                        const AdaptorCoord& coord)
{
    bool valid = true;

    constexpr index_t ntransform = Adaptor::get_num_of_transform();

    const auto& idx_hidden = coord.get_hidden_index();

    static_for<ntransform - 1, -1, -1>{}([&adaptor, &idx_hidden, &valid](auto itran) {
        const auto tran = adaptor.get_transforms().at(itran);

        // check validity, only if current transformation does not always has a valid mapping
        if constexpr(!decltype(tran)::is_valid_upper_index_always_mapped_to_valid_lower_index())
        {
            const auto idx_up = get_container_subset(
                idx_hidden, Adaptor::get_upper_dimension_hidden_idss().at(itran));

            // Comment: using valid = valid && .. will result in weird control flow in ISA
            valid &= tran.is_valid_upper_index_mapped_to_valid_lower_index(idx_up);
        }
    });

    return valid;
}
|
||||
|
||||
template <typename Adaptor, typename AdpatorCoord>
|
||||
CK_TILE_HOST_DEVICE constexpr bool adaptor_coordinate_is_valid(const Adaptor& adaptor,
|
||||
const AdpatorCoord& coord)
|
||||
{
|
||||
// check top index
|
||||
const auto& idx_top = coord.get_top_index();
|
||||
|
||||
bool is_top_index_valid = true;
|
||||
|
||||
static_for<0, Adaptor::get_num_of_dimension(), 1>{}(
|
||||
[&is_top_index_valid, &idx_top, &adaptor](auto i) {
|
||||
is_top_index_valid =
|
||||
is_top_index_valid && (idx_top[i] >= 0 && idx_top[i] < adaptor.get_length(i));
|
||||
});
|
||||
|
||||
// check other hidden index
|
||||
return is_top_index_valid &&
|
||||
adaptor_coordinate_is_valid_assuming_top_index_is_valid(adaptor, coord);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
92
include/ck_tile/core/tensor/tensor_coordinate.hpp
Normal file
92
include/ck_tile/core/tensor/tensor_coordinate.hpp
Normal file
@@ -0,0 +1,92 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/multi_index.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Coordinate into a tensor: a tensor_adaptor_coordinate whose bottom dimension
// is pinned to the single hidden id 0 — the linear memory offset.
template <index_t NDimHidden, typename TopDimensionHiddenIds>
struct tensor_coordinate
    : public tensor_adaptor_coordinate<NDimHidden, sequence<0>, TopDimensionHiddenIds>
{
    using Base = tensor_adaptor_coordinate<NDimHidden, sequence<0>, TopDimensionHiddenIds>;

    // TODO make these private
    static constexpr index_t ndim_top_ = TopDimensionHiddenIds::size();

    using HiddenIndex = multi_index<NDimHidden>;
    using TopIndex    = multi_index<ndim_top_>;

    public:
    CK_TILE_HOST_DEVICE constexpr tensor_coordinate() = default;

    // construct from the full hidden index (indices of every transform stage)
    CK_TILE_HOST_DEVICE constexpr tensor_coordinate(const HiddenIndex& idx_hidden)
        : Base{idx_hidden}
    {
    }

    // construct from tensor_adaptor_coordinate base class
    CK_TILE_HOST_DEVICE constexpr tensor_coordinate(const Base& adaptor_coord) : Base{adaptor_coord}
    {
    }

    // user-visible (top) index of this coordinate
    CK_TILE_HOST_DEVICE constexpr auto get_index() const { return Base::get_top_index(); }

    // linear memory offset: element 0 of the bottom index
    CK_TILE_HOST_DEVICE constexpr index_t get_offset() const
    {
        return Base::get_bottom_index()[number<0>{}];
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_hidden_index() const
    {
        return Base::get_hidden_index();
    }

    CK_TILE_HOST_DEVICE auto& get_hidden_index() { return Base::get_hidden_index(); }
};
|
||||
|
||||
template <typename TensorDesc, typename TopIndex>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc,
|
||||
const TopIndex& idx_top)
|
||||
{
|
||||
const auto adaptor_coord = make_tensor_adaptor_coordinate(tensor_desc, idx_top);
|
||||
|
||||
return tensor_coordinate<TensorDesc::get_num_of_hidden_dimension(),
|
||||
remove_cvref_t<decltype(TensorDesc::get_top_dimension_hidden_ids())>>{
|
||||
adaptor_coord};
|
||||
}
|
||||
|
||||
// Advance coord by coord_step (a delta in the top index), updating the hidden
// indices through the descriptor's transform chain.
// NOTE(review): the JudgeDoTransforms template parameter is accepted but not
// forwarded to move_tensor_adaptor_coordinate — verify against the
// adaptor-coordinate implementation whether it should be passed through.
template <bool JudgeDoTransforms = true, typename TensorDesc, typename TensorCoord, typename Index>
CK_TILE_HOST_DEVICE constexpr void
move_tensor_coordinate(const TensorDesc& tensor_desc, TensorCoord& coord, const Index& coord_step)
{
    move_tensor_adaptor_coordinate(tensor_desc, coord, coord_step);
}
|
||||
|
||||
// Check that coord maps to a valid memory offset, assuming the caller has
// already verified the top index is within the tensor lengths.
// Thin forwarder to the adaptor-level check.
template <typename TensorDesc, typename TensorCoord>
CK_TILE_HOST_DEVICE constexpr bool
coordinate_has_valid_offset_assuming_top_index_is_valid(const TensorDesc& tensor_desc,
                                                        const TensorCoord& coord)
{
    return adaptor_coordinate_is_valid_assuming_top_index_is_valid(tensor_desc, coord);
}
|
||||
|
||||
// Full validity check (top index plus hidden indices) for a tensor coordinate.
// Thin forwarder to the adaptor-level check.
template <typename TensorDesc, typename TensorCoord>
CK_TILE_HOST_DEVICE constexpr bool coordinate_has_valid_offset(const TensorDesc& tensor_desc,
                                                               const TensorCoord& coord)
{
    return adaptor_coordinate_is_valid(tensor_desc, coord);
}
|
||||
|
||||
} // namespace ck_tile
|
||||
472
include/ck_tile/core/tensor/tensor_descriptor.hpp
Normal file
472
include/ck_tile/core/tensor/tensor_descriptor.hpp
Normal file
@@ -0,0 +1,472 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/multi_index.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A tensor descriptor is a tensor_adaptor whose bottom dimension is the single
// hidden id 0 (the linear memory offset), paired with the size of the element
// space the tensor occupies.
//
// Transforms: Tuple<transforms...>
// LowerDimensionHiddenIdss : Tuple<sequence<...>, ...>
// UpperDimensionHiddenIdss : Tuple<sequence<...>, ...>
// TopDimensionHiddenIds> : sequence<...>
template <typename Transforms,
          typename LowerDimensionHiddenIdss,
          typename UpperDimensionHiddenIdss,
          typename TopDimensionHiddenIds,
          typename ElementSpaceSize,
          typename GuaranteedVectorLengths_,
          typename GuaranteedVectorSrides_>
struct tensor_descriptor : public tensor_adaptor<Transforms,
                                                 LowerDimensionHiddenIdss,
                                                 UpperDimensionHiddenIdss,
                                                 sequence<0>,
                                                 TopDimensionHiddenIds>
{
    using Base = tensor_adaptor<Transforms,
                                LowerDimensionHiddenIdss,
                                UpperDimensionHiddenIdss,
                                sequence<0>,
                                TopDimensionHiddenIds>;

    using ElementSpaceSizeType = ElementSpaceSize;

    constexpr static index_t ntransform_  = Base::get_num_of_transform();
    constexpr static index_t ndim_hidden_ = Base::get_num_of_hidden_dimension();
    constexpr static index_t ndim_top_    = Base::get_num_of_top_dimension();

    // per-hidden-dimension vectorization guarantees; -1 means "no guarantee"
    using GuaranteedVectorLengths = GuaranteedVectorLengths_;
    using GuaranteedVectorStrides = GuaranteedVectorSrides_;

    static_assert(GuaranteedVectorLengths::size() == ndim_hidden_ &&
                      GuaranteedVectorStrides::size() == ndim_hidden_,
                  "wrong! inconsistent # of hidden dimensions");

    using TopIndex    = multi_index<ndim_top_>;
    using HiddenIndex = multi_index<ndim_hidden_>;

    public:
    CK_TILE_HOST_DEVICE constexpr tensor_descriptor() = default;

    CK_TILE_HOST_DEVICE constexpr tensor_descriptor(const Transforms& transforms,
                                                    ElementSpaceSize element_space_size)
        : Base{transforms}, element_space_size_{element_space_size}

    {
        static_assert(Transforms::size() == ntransform_ &&
                          LowerDimensionHiddenIdss::size() == ntransform_ &&
                          UpperDimensionHiddenIdss::size() == ntransform_,
                      "wrong! inconsistent # of transformations");

        // TODO check dependency of dimensions is valid
    }

    // construct from tensor_adaptor base class
    CK_TILE_HOST_DEVICE constexpr tensor_descriptor(const Base& adaptor,
                                                    ElementSpaceSize element_space_size)
        : Base{adaptor}, element_space_size_{element_space_size}
    {
    }

    // number of user-visible (top) dimensions
    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension()
    {
        return Base::get_num_of_top_dimension();
    }

    template <index_t IDim>
    CK_TILE_HOST_DEVICE constexpr auto get_length(number<IDim> idim) const
    {
        return Base::get_top_dimension_length(idim);
    }

    CK_TILE_HOST_DEVICE constexpr auto get_lengths() const
    {
        return Base::get_top_dimension_length();
    }

    CK_TILE_HOST_DEVICE constexpr auto get_element_space_size() const
    {
        return element_space_size_;
    }

    // map a top (user-visible) index to a linear memory offset
    template <typename Idx>
    CK_TILE_HOST_DEVICE constexpr index_t calculate_offset(const Idx& idx) const
    {
        return Base::calculate_bottom_index(idx)[number<0>{}];
    }

    // TODO make these private
    CK_TILE_HOST_DEVICE constexpr const auto& get_transforms() const
    {
        return Base::get_transforms();
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_lower_dimension_hidden_idss()
    {
        return Base::get_lower_dimension_hidden_idss();
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_upper_dimension_hidden_idss()
    {
        return Base::get_upper_dimension_hidden_idss();
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_hidden_ids()
    {
        return Base::get_top_dimension_hidden_ids();
    }

    // true when both the adaptor and the element-space size are compile-time known
    CK_TILE_HOST_DEVICE static constexpr bool is_static()
    {
        return Base::is_known_at_compile_time() &&
               ck_tile::is_known_at_compile_time<ElementSpaceSize>::value;
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); }

    CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides()
    {
        return Base::get_top_dimension_safe_vector_length_strides(
            to_array<index_t, ndim_hidden_>(GuaranteedVectorLengths{}),
            to_array<index_t, ndim_hidden_>(GuaranteedVectorStrides{}));
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("tensor_descriptor{");

        // tensor_adaptor
        Base::print();
        printf(", ");

        // element_space_size_
        printf("element_space_size_: ");
        // NOTE(review): unqualified print(...) inside a member named print would
        // ordinarily resolve to this zero-argument member and fail overload
        // resolution — confirm how this resolves (macro or qualified free
        // function elsewhere in the project).
        print(element_space_size_);

        printf("}");
    }

    // TODO make these private
    ElementSpaceSize element_space_size_;
};
|
||||
|
||||
// Wrap an existing tensor_adaptor into a tensor_descriptor with the given
// element-space size; no vectorization guarantees are assumed (-1 per hidden
// dimension).
template <typename Adaptor, typename ElementSpaceSize>
CK_TILE_HOST_DEVICE constexpr auto
make_tensor_descriptor_from_adaptor(const Adaptor& adaptor,
                                    const ElementSpaceSize& element_space_size)
{
    constexpr index_t NDimHidden = Adaptor::get_num_of_hidden_dimension();

    return tensor_descriptor<remove_cvref_t<decltype(adaptor.get_transforms())>,
                             remove_cvref_t<decltype(adaptor.get_lower_dimension_hidden_idss())>,
                             remove_cvref_t<decltype(adaptor.get_upper_dimension_hidden_idss())>,
                             remove_cvref_t<decltype(adaptor.get_top_dimension_hidden_ids())>,
                             remove_cvref_t<decltype(element_space_size)>,
                             typename uniform_sequence_gen<NDimHidden, -1>::type,
                             typename uniform_sequence_gen<NDimHidden, -1>::type>{
        adaptor, element_space_size};
}
|
||||
|
||||
// Stack additional coordinate transforms on top of an existing descriptor.
// The element-space size is unchanged; the new hidden dimensions introduced by
// the added transforms get no vectorization guarantee (-1).
template <typename OldTensorDescriptor,
          typename NewTransforms,
          typename NewLowerDimensionOldTopIdss,
          typename NewUpperDimensionNewTopIdss>
CK_TILE_HOST_DEVICE constexpr auto
transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
                            const NewTransforms& new_transforms,
                            NewLowerDimensionOldTopIdss,
                            NewUpperDimensionNewTopIdss)
{
    const auto element_space_size = old_tensor_desc.get_element_space_size();

    const auto new_tensor_adaptor = transform_tensor_adaptor(old_tensor_desc,
                                                             new_transforms,
                                                             NewLowerDimensionOldTopIdss{},
                                                             NewUpperDimensionNewTopIdss{});

    constexpr index_t NDimHiddenOld = OldTensorDescriptor::get_num_of_hidden_dimension();
    constexpr index_t NDimHiddenNew = decltype(new_tensor_adaptor)::get_num_of_hidden_dimension();

    // keep the old guarantees, pad the newly added hidden dims with -1
    using NewGuaranteedVectorLengths = typename sequence_merge<
        typename OldTensorDescriptor::GuaranteedVectorLengths,
        typename uniform_sequence_gen<NDimHiddenNew - NDimHiddenOld, -1>::type>::type;

    using NewGuaranteedVectorStrides = typename sequence_merge<
        typename OldTensorDescriptor::GuaranteedVectorStrides,
        typename uniform_sequence_gen<NDimHiddenNew - NDimHiddenOld, -1>::type>::type;

    return tensor_descriptor<
        remove_cvref_t<decltype(new_tensor_adaptor.get_transforms())>,
        remove_cvref_t<decltype(new_tensor_adaptor.get_lower_dimension_hidden_idss())>,
        remove_cvref_t<decltype(new_tensor_adaptor.get_upper_dimension_hidden_idss())>,
        remove_cvref_t<decltype(new_tensor_adaptor.get_top_dimension_hidden_ids())>,
        remove_cvref_t<decltype(element_space_size)>,
        NewGuaranteedVectorLengths,
        NewGuaranteedVectorStrides>{new_tensor_adaptor, element_space_size};
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename Lengths, typename Strides, index_t I, typename AccOld>
|
||||
CK_TILE_HOST_DEVICE constexpr auto calculate_element_space_size_impl(const Lengths& lengths,
|
||||
const Strides& strides,
|
||||
number<I> i,
|
||||
AccOld acc_old)
|
||||
{
|
||||
auto acc_new = acc_old + (lengths[i] - number<1>{}) * strides[i];
|
||||
|
||||
if constexpr(i.value < Lengths::size() - 1)
|
||||
{
|
||||
return calculate_element_space_size_impl(lengths, strides, i + number<1>{}, acc_new);
|
||||
}
|
||||
else
|
||||
{
|
||||
return acc_new;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/*
|
||||
* These functions create naive tensor descriptor
|
||||
*/
|
||||
|
||||
// Create a naive (single embed transform) tensor descriptor from lengths and
// strides. The two trailing number<> parameters optionally guarantee a vector
// length/stride for the last dimension (-1 means "no guarantee").
//
// Lengths..., Strides... could be:
//   1) index_t, which is known at run-time, or
//   2) number<>, which is known at compile-time
// element_space_size could be:
//   1) long_index_t, or
//   2) long_number<>
template <typename... Lengths,
          typename... Strides,
          index_t GuaranteedLastDimensionVectorLength = -1,
          index_t GuaranteedLastDimensionVectorStride = -1,
          typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_descriptor(const tuple<Lengths...>& lengths,
                             const tuple<Strides...>& strides,
                             number<GuaranteedLastDimensionVectorLength> = number<-1>{},
                             number<GuaranteedLastDimensionVectorStride> = number<-1>{})
{
    constexpr index_t N = sizeof...(Lengths);

    // single embed transform: hidden dim 0 (the offset) -> N top dims
    const auto transforms = make_tuple(make_embed_transform(lengths, strides));

    constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});

    constexpr auto up_dim_hidden_idss =
        make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});

    constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};

    // element space size = 1 + sum_i (lengths[i] - 1) * strides[i]
    const auto element_space_size =
        detail::calculate_element_space_size_impl(lengths, strides, number<0>{}, long_number<1>{});

    // guarantee applies to the last hidden dimension only
    using GuaranteedVectorLengths =
        typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
                                sequence<GuaranteedLastDimensionVectorLength>>::type;

    using GuaranteedVectorStrides =
        typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
                                sequence<GuaranteedLastDimensionVectorStride>>::type;

    return tensor_descriptor<remove_cv_t<decltype(transforms)>,
                             remove_cv_t<decltype(low_dim_hidden_idss)>,
                             remove_cv_t<decltype(up_dim_hidden_idss)>,
                             remove_cv_t<decltype(visible_dim_hidden_ids)>,
                             remove_cv_t<decltype(element_space_size)>,
                             GuaranteedVectorLengths,
                             GuaranteedVectorStrides>{transforms, element_space_size};
}
|
||||
|
||||
// tensor descriptor with offset, the offset will not be added into element space size
// only have an information of the starting offset, and will impact on offset calculation
//
// NOTE(review): the template type parameter `offset` is shadowed by the
// function parameter of the same name — legal, but consider renaming.
template <typename... Lengths,
          typename... Strides,
          typename offset,
          index_t GuaranteedLastDimensionVectorLength = -1,
          index_t GuaranteedLastDimensionVectorStride = -1,
          typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_descriptor_with_offset(const tuple<Lengths...>& lengths,
                                         const tuple<Strides...>& strides,
                                         const offset& offset,
                                         number<GuaranteedLastDimensionVectorLength> = number<-1>{},
                                         number<GuaranteedLastDimensionVectorStride> = number<-1>{})
{
    // desc_0: 1-D descriptor over the element space that applies the start offset
    const auto desc_0 = [&]() {
        const auto element_space_size = detail::calculate_element_space_size_impl(
            lengths, strides, number<0>{}, long_number<1>{});

        const auto transforms = make_tuple(make_offset_transform(element_space_size, offset));

        constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});

        constexpr auto up_dim_hidden_idss = make_tuple(sequence<1>{});

        constexpr auto visible_dim_hidden_ids = sequence<1>{};

        using GuaranteedVectorLengths =
            typename sequence_merge<typename uniform_sequence_gen<1, -1>::type,
                                    sequence<GuaranteedLastDimensionVectorLength>>::type;

        using GuaranteedVectorStrides =
            typename sequence_merge<typename uniform_sequence_gen<1, -1>::type,
                                    sequence<GuaranteedLastDimensionVectorStride>>::type;

        return tensor_descriptor<remove_cv_t<decltype(transforms)>,
                                 remove_cv_t<decltype(low_dim_hidden_idss)>,
                                 remove_cv_t<decltype(up_dim_hidden_idss)>,
                                 remove_cv_t<decltype(visible_dim_hidden_ids)>,
                                 remove_cv_t<decltype(element_space_size)>,
                                 GuaranteedVectorLengths,
                                 GuaranteedVectorStrides>{transforms, element_space_size};
    }();

    constexpr index_t N = sizeof...(Lengths);

    // stack the strided embed transform on top of the offset descriptor
    return transform_tensor_descriptor(
        desc_0,
        make_tuple(make_embed_transform(lengths, strides)),
        make_tuple(sequence<0>{}),
        make_tuple(typename arithmetic_sequence_gen<0, N, 1>::type{}));
}
|
||||
|
||||
// Create a packed (contiguous, row-major) tensor descriptor: a single unmerge
// transform, with the last dimension guaranteed to have stride 1.
//
// Lengths... could be:
//   1) index_t, which is known at run-time, or
//   2) number<>, which is known at compile-time
// element_space_size could be:
//   1) long_index_t, or
//   2) long_number<>
template <typename... Lengths, index_t GuaranteedLastDimensionVectorLength = -1>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_descriptor_packed(const tuple<Lengths...>& lengths,
                                    number<GuaranteedLastDimensionVectorLength> = number<-1>{})
{
    constexpr index_t N = sizeof...(Lengths);

    // single unmerge transform: hidden dim 0 (the offset) -> N top dims
    const auto transforms = make_tuple(make_unmerge_transform(lengths));

    constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});

    constexpr auto up_dim_hidden_idss =
        make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});

    constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};

    // packed layout: element space is simply the product of all lengths
    const auto element_space_size = container_reduce(lengths, math::multiplies{}, long_number<1>{});

    using GuaranteedVectorLengths =
        typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
                                sequence<GuaranteedLastDimensionVectorLength>>::type;

    // last dimension of a packed tensor is contiguous -> guaranteed stride 1
    using GuaranteedVectorStrides =
        typename sequence_merge<typename uniform_sequence_gen<N, -1>::type, sequence<1>>::type;

    return tensor_descriptor<remove_cv_t<decltype(transforms)>,
                             remove_cv_t<decltype(low_dim_hidden_idss)>,
                             remove_cv_t<decltype(up_dim_hidden_idss)>,
                             remove_cv_t<decltype(visible_dim_hidden_ids)>,
                             remove_cv_t<decltype(element_space_size)>,
                             GuaranteedVectorLengths,
                             GuaranteedVectorStrides>{transforms, element_space_size};
}
|
||||
|
||||
// Packed descriptor with a starting offset: like
// make_naive_tensor_descriptor_packed, but offsets every computed address by
// `offset`. The offset is not added to the element-space size.
// NOTE(review): the template type parameter `offset` is shadowed by the
// function parameter of the same name — legal, but consider renaming.
template <typename... Lengths,
          typename... Strides,
          typename offset,
          index_t GuaranteedLastDimensionVectorLength = -1,
          typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_packed_with_offset(
    const tuple<Lengths...>& lengths,
    const offset& offset,
    number<GuaranteedLastDimensionVectorLength> = number<-1>{})
{
    // desc_0: 1-D descriptor over the element space that applies the start offset
    const auto desc_0 = [&]() {
        const auto element_space_size =
            container_reduce(lengths, math::multiplies{}, long_number<1>{});

        const auto transforms = make_tuple(make_offset_transform(element_space_size, offset));

        constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});

        constexpr auto up_dim_hidden_idss = make_tuple(sequence<1>{});

        constexpr auto visible_dim_hidden_ids = sequence<1>{};

        using GuaranteedVectorLengths =
            typename sequence_merge<typename uniform_sequence_gen<1, -1>::type,
                                    sequence<GuaranteedLastDimensionVectorLength>>::type;

        using GuaranteedVectorStrides =
            typename sequence_merge<typename uniform_sequence_gen<1, -1>::type, sequence<1>>::type;

        return tensor_descriptor<remove_cv_t<decltype(transforms)>,
                                 remove_cv_t<decltype(low_dim_hidden_idss)>,
                                 remove_cv_t<decltype(up_dim_hidden_idss)>,
                                 remove_cv_t<decltype(visible_dim_hidden_ids)>,
                                 remove_cv_t<decltype(element_space_size)>,
                                 GuaranteedVectorLengths,
                                 GuaranteedVectorStrides>{transforms, element_space_size};
    }();

    constexpr index_t N = sizeof...(Lengths);

    // stack the packed unmerge transform on top of the offset descriptor
    return transform_tensor_descriptor(
        desc_0,
        make_tuple(make_unmerge_transform(lengths)),
        make_tuple(sequence<0>{}),
        make_tuple(typename arithmetic_sequence_gen<0, N, 1>::type{}));
}
|
||||
|
||||
// Create a descriptor that is packed except that the stride of dimension N-2
// is rounded up to a multiple of `align` (the last dimension stays contiguous
// with stride 1). Useful for padding the fastest-varying row to an alignment.
//
// Lengths... could be:
//   1) index_t, which is known at run-time, or
//   2) number<>, which is known at compile-time
// align could be:
//   1) index_t, or
//   2) number<>
// NOTE(review): number<stride_n_minus_2>{} below requires the last length and
// align to be compile-time constants — confirm intended usage.
template <typename... Lengths, typename Align>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_descriptor_aligned(const tuple<Lengths...>& lengths, Align align)
{
    constexpr auto I1 = number<1>{};

    constexpr index_t N = sizeof...(Lengths);

    // stride of dimension N-2: last-dim length rounded up to a multiple of align
    const auto stride_n_minus_2 = math::integer_least_multiple(lengths[number<N - 1>{}], align);

    auto strides = generate_tuple(
        [&](auto i) {
            if constexpr(i.value == N - 1)
            {
                // last dimension is contiguous
                return I1;
            }
            else if constexpr(i.value == N - 2)
            {
                return number<stride_n_minus_2>{};
            }
            else
            {
                // product of lengths (i+1 .. N-2) times the aligned stride
                return container_reduce(lengths,
                                        math::multiplies{},
                                        number<stride_n_minus_2>{},
                                        i + I1,
                                        number<N - 1>{},
                                        I1);
            }
        },
        number<N>{});

    return make_naive_tensor_descriptor(lengths, strides);
}
|
||||
|
||||
} // namespace ck_tile
|
||||
273
include/ck_tile/core/tensor/tensor_view.hpp
Normal file
273
include/ck_tile/core/tensor/tensor_view.hpp
Normal file
@@ -0,0 +1,273 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A tensor_view pairs a buffer_view (raw memory in some address space) with a
// tensor descriptor that maps multi-dimensional indices to linear offsets.
// Accessors take a TensorCoord (see make_tensor_coordinate) rather than a raw
// index, so offset computation can be amortized across accesses.
template <typename BufferView_, typename TensorDesc_>
struct tensor_view
{
    using BufferView  = remove_reference_t<BufferView_>;
    using DataType    = typename BufferView::type;
    using TensorDesc  = remove_cvref_t<TensorDesc_>;
    using TensorIndex = array<index_t, TensorDesc::get_num_of_top_dimension()>;
    using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{}));

    CK_TILE_HOST_DEVICE constexpr tensor_view() = default;

    CK_TILE_HOST_DEVICE constexpr tensor_view(const BufferView& buffer_view, const TensorDesc& desc)
        : buf_{buffer_view}, desc_{desc}
    {
    }

    CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; }

    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension()
    {
        return TensorDesc::get_num_of_top_dimension();
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_buffer_view() const { return buf_; }

    CK_TILE_HOST_DEVICE constexpr auto& get_buffer_view() { return buf_; }

#if 0
    CK_TILE_HOST_DEVICE constexpr DataType get_element(const TensorCoord& coord) const
    {
        return buf_.template get<DataType>(
            coord.get_offset(),
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
    }

    CK_TILE_HOST_DEVICE constexpr void set_element(const TensorCoord& coord, const DataType& x)
    {
        buf_.template set<DataType>(
            coord.get_offset(),
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            x);
    }
#endif
    // X is vector of DataType.
    // "coord" is coordinate of DataType, not X. "coord" should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
              typename enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
                                           typename scalar_type<remove_cvref_t<DataType>>::type>,
                                 bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
    get_vectorized_elements(const TensorCoord& coord,
                            bool_constant<oob_conditional_check> = {}) const
    {
        return buf_.template get<X>(
            coord.get_offset(),
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            bool_constant<oob_conditional_check>{});
    }

    // X is vector of DataType.
    // "coord" is coordinate of DataType, not X. "coord" should be aligned to X
    // Raw variant: writes into dst instead of returning a value.
    template <typename X,
              bool oob_conditional_check = true,
              typename enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
                                           typename scalar_type<remove_cvref_t<DataType>>::type>,
                                 bool>::type = false>
    CK_TILE_HOST_DEVICE void
    get_vectorized_elements_raw(remove_cvref_t<X>& dst,
                                const TensorCoord& coord,
                                bool_constant<oob_conditional_check> = {}) const
    {
        return buf_.template get_raw<X, oob_conditional_check>(
            dst,
            coord.get_offset(),
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
    }

    // Asynchronous load directed at the smem destination pointer.
    // NOTE(review): no out-of-bounds check is performed here (validity flag is
    // hard-coded true and documented as unused) — confirm callers guarantee
    // in-bounds coords.
    template <typename X,
              typename enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
                                           typename scalar_type<remove_cvref_t<DataType>>::type>,
                                 bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements(remove_cvref_t<DataType>* smem,
                                                                     const TensorCoord& coord) const
    {
        return buf_.template async_get<X>(smem, coord.get_offset(), true /*not used*/);
    }

    // X is vector of DataType.
    // "coord" is coordinate of DataType, not X. "coord" should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
              typename enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
                                           typename scalar_type<remove_cvref_t<DataType>>::type>,
                                 bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements(
        const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
    {
        buf_.template set<X, oob_conditional_check>(
            coord.get_offset(),
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            x);
    }

    // Raw (unchecked-conversion) variant of set_vectorized_elements.
    template <typename X,
              bool oob_conditional_check = true,
              typename enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
                                           typename scalar_type<remove_cvref_t<DataType>>::type>,
                                 bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements_raw(
        const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
    {
        buf_.template set_raw<X, oob_conditional_check>(
            coord.get_offset(),
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            x);
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("tensor_view{");

        // buf_
        printf("buf_: ");
        print(buf_);
        printf(", ");

        // desc_
        printf("desc_: ");
        print(desc_);

        printf("}");
    }

    // member
    BufferView buf_;
    TensorDesc desc_;
};
|
||||
|
||||
// placeholder type if we want to opt-out a tile view parameter
// (empty tag type — carries no buffer or descriptor)
struct null_tensor_view
{
};
|
||||
|
||||
template <AddressSpaceEnum BufferAddressSpace = AddressSpaceEnum::Generic,
|
||||
typename DataType,
|
||||
typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p,
|
||||
const tensor_descriptor<Ts...>& desc)
|
||||
{
|
||||
auto buffer_view = make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size());
|
||||
|
||||
return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
|
||||
}
|
||||
|
||||
template <AddressSpaceEnum BufferAddressSpace = AddressSpaceEnum::Generic,
|
||||
typename DataType,
|
||||
typename... Lengths,
|
||||
typename... Strides,
|
||||
index_t GuaranteedLastDimensionVectorLength = -1,
|
||||
index_t GuaranteedLastDimensionVectorStride = -1,
|
||||
typename enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_naive_tensor_view(DataType* p,
|
||||
const tuple<Lengths...>& lengths,
|
||||
const tuple<Strides...>& strides,
|
||||
number<GuaranteedLastDimensionVectorLength> = number<-1>{},
|
||||
number<GuaranteedLastDimensionVectorStride> = number<-1>{})
|
||||
{
|
||||
auto desc = make_naive_tensor_descriptor(lengths,
|
||||
strides,
|
||||
number<GuaranteedLastDimensionVectorLength>{},
|
||||
number<GuaranteedLastDimensionVectorStride>{});
|
||||
|
||||
auto buffer_view = make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size());
|
||||
|
||||
return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
|
||||
}
|
||||
|
||||
template <AddressSpaceEnum BufferAddressSpace = AddressSpaceEnum::Generic,
|
||||
typename DataType,
|
||||
typename... Lengths,
|
||||
index_t GuaranteedLastDimensionVectorLength = -1>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_naive_tensor_view_packed(DataType* p,
|
||||
const tuple<Lengths...>& lengths,
|
||||
number<GuaranteedLastDimensionVectorLength> = number<-1>{})
|
||||
{
|
||||
auto desc =
|
||||
make_naive_tensor_descriptor_packed(lengths, number<GuaranteedLastDimensionVectorLength>{});
|
||||
|
||||
auto buffer_view = make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size());
|
||||
|
||||
return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
|
||||
}
|
||||
|
||||
template <typename OldTensorView,
|
||||
typename NewTransforms,
|
||||
typename NewLowerDimensionOldVisibleIdss,
|
||||
typename NewUpperDimensionNewVisibleIdss>
|
||||
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView& old_tensor_view,
|
||||
const NewTransforms& new_transforms,
|
||||
NewLowerDimensionOldVisibleIdss,
|
||||
NewUpperDimensionNewVisibleIdss)
|
||||
{
|
||||
auto new_desc = transform_tensor_descriptor(old_tensor_view.desc_,
|
||||
new_transforms,
|
||||
NewLowerDimensionOldVisibleIdss{},
|
||||
NewUpperDimensionNewVisibleIdss{});
|
||||
|
||||
return tensor_view<typename OldTensorView::BufferView, remove_cvref_t<decltype(new_desc)>>{
|
||||
old_tensor_view.buf_, new_desc};
|
||||
}
|
||||
|
||||
// Pad each dimension of the view up to a multiple of the corresponding tile
// length. DoPads selects per dimension: right-pad (true) or pass through
// unchanged (false).
// NOTE(review): the function parameter `tensor_view` shadows the template
// parameter (and struct) of the same name — legal, but consider renaming.
template <typename tensor_view,
          typename TileLengths, // tuple<...>
          typename DoPads>      // sequence<bool, bool, ...>
CK_TILE_HOST_DEVICE constexpr auto
pad_tensor_view(const tensor_view& tensor_view, const TileLengths& tile_lengths, DoPads)
{
    constexpr index_t num_dim = DoPads::size();

    static_assert(num_dim == TileLengths::size() && num_dim == tensor_view::get_num_of_dimension(),
                  "wrong! inconsistent # of dimensions");

    // transforms: per dimension, round the length up to a multiple of the tile
    // length and right-pad, or pass through untouched
    const auto transforms = generate_tuple(
        [&](auto idim) {
            const auto old_length = tensor_view.get_tensor_descriptor().get_length(idim);

            const auto tile_length = tile_lengths[idim];

            const auto new_length =
                math::integer_divide_ceil(old_length, tile_length) * tile_length;

            const auto pad_length = new_length - old_length;

            constexpr bool DoPad = DoPads::at(idim);

            const auto transform =
                conditional_expr<DoPad>(make_right_pad_transform(old_length, pad_length),
                                        make_pass_through_transform(old_length));

            return transform;
        },
        number<num_dim>{});

    // lower dimension Id (each transform maps dimension i -> dimension i)
    const auto lower_dimss =
        generate_tuple([&](auto idim) { return sequence<idim.value>{}; }, number<num_dim>{});

    // upper dimension Id
    const auto upper_dimss = lower_dimss;

    return transform_tensor_view(tensor_view, transforms, lower_dimss, upper_dimss);
}
|
||||
|
||||
} // namespace ck_tile
|
||||
754
include/ck_tile/core/tensor/tile_distribution.hpp
Normal file
754
include/ck_tile/core/tensor/tile_distribution.hpp
Normal file
@@ -0,0 +1,754 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// distributed span
|
||||
// Distributed span: a compile-time tag carrying the per-X-dimension slice of
// H-lengths that one thread iterates over. All state lives in the sequence
// type; instances are empty.
template <index_t... PartialHsLengths>
struct tile_distributed_span
{
    using Impl = sequence<PartialHsLengths...>;

    static constexpr auto impl_ = Impl{};

    // always compile-time known
    CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
};
|
||||
|
||||
// distributed index
|
||||
// Distributed index: a compile-time tag holding one index within a
// tile_distributed_span. All state lives in the sequence type.
template <index_t... PartialHsIndices>
struct tile_distributed_index
{
    using Impl = sequence<PartialHsIndices...>;

    static constexpr auto impl_ = Impl{};

    // always compile-time known
    CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Convert a sequence<Is...> of span lengths into the corresponding
// tile_distributed_span type.
template <index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto make_tile_distributed_span(sequence<Is...>)
{
    return tile_distributed_span<Is...>{};
}
|
||||
|
||||
// Convert a sequence<Is...> of indices into the corresponding
// tile_distributed_index type.
template <index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto make_tile_distributed_index(sequence<Is...>)
{
    return tile_distributed_index<Is...>{};
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// tile_distribution describes how the elements of a tile (X-dimensions) are
// distributed over threads (P-dimensions: warp/lane) and per-thread registers
// (Y-dimensions mapped to a linearized D-dimension).
//
//   PsYs2XsAdaptor_ - static adaptor mapping [p..., y...] -> [x...]
//   Ys2DDescriptor_ - static descriptor mapping [y...] -> linear d offset
//   StaticTileDistributionEncoding_ - the encoding this was built from
//   TileDistributionDetail_ - FIXME: this holds ad-hoc but useful info,
//                             should be more elegant
template <typename PsYs2XsAdaptor_,
          typename Ys2DDescriptor_,
          typename StaticTileDistributionEncoding_,
          typename TileDistributionDetail_>
struct tile_distribution
{
    using PsYs2XsAdaptor = remove_cvref_t<PsYs2XsAdaptor_>;
    using Ys2DDescriptor = remove_cvref_t<Ys2DDescriptor_>;
    using DstrEncode     = remove_cvref_t<StaticTileDistributionEncoding_>;
    using DstrDetail     = remove_cvref_t<TileDistributionDetail_>;

    static_assert(PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static(),
                  "wrong! should be static");

    // number of tile (X), register (Y), thread-partition (P) and replicate (R) dims
    static constexpr index_t NDimX = PsYs2XsAdaptor::get_num_of_bottom_dimension();
    static constexpr index_t NDimY = Ys2DDescriptor::get_num_of_top_dimension();
    static constexpr index_t NDimP = PsYs2XsAdaptor::get_num_of_top_dimension() - NDimY;
    static constexpr index_t NDimR = StaticTileDistributionEncoding_::NDimR;

    PsYs2XsAdaptor ps_ys_to_xs_; // adaptor [p..., y...] -> [x...]
    Ys2DDescriptor ys_to_d_;     // descriptor [y...] -> d

    // NOTE(review): naming is inconsistent — the X getter is snake_case but the
    // Y/P/R getters are CamelCase; renaming would break existing callers, so it
    // is only flagged here.
    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_x() { return NDimX; }
    CK_TILE_HOST_DEVICE static constexpr index_t GetNumOfDimensionY() { return NDimY; }
    CK_TILE_HOST_DEVICE static constexpr index_t GetNumOfDimensionP() { return NDimP; }
    CK_TILE_HOST_DEVICE static constexpr index_t GetNumOfDimensionR() { return NDimR; }

    // Per-X-dimension tile lengths, computed as the product of that
    // dimension's H-lengths from the encoding.
    CK_TILE_HOST_DEVICE static constexpr auto get_lengths()
    {
#if 0
        // FIXME: tensor_adaptor::GetBottomDimensionLengths is wrong. re-enable this after it's fixed
        ps_ys_to_xs_.GetBottomDimensionLengths();
#else
        return generate_tuple(
            [&](auto i) {
                // length of X dim i = product of its H (hierarchical) lengths
                constexpr index_t x_length =
                    container_reduce(typename DstrEncode::HsLengthss{}[i], math::multiplies{}, 1);

                return number<x_length>{};
            },
            number<NDimX>{});
#endif
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_ps_ys_to_xs_adaptor() const
    {
        return ps_ys_to_xs_;
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_ys_to_d_descriptor() const { return ys_to_d_; }

    CK_TILE_HOST_DEVICE static constexpr auto get_static_tile_distribution_encoding()
    {
        return DstrEncode{};
    }

#if 1
    // Calculate Replication index [R0, R1, ...] based on Partition index.
    // Walks a dummy adaptor coordinate at [ps..., 0...] and reads back the
    // hidden-dimension values that correspond to replicate (rh_major == 0) dims.
    // FIXME: very nasty implementation
    template <typename PartitionIndex>
    CK_TILE_HOST_DEVICE auto calculate_rs_index_from_ps_index(const PartitionIndex& ps_idx) const
    {
        static_assert(PartitionIndex::size() == NDimP, "wrong!");

        // extend [p...] with zero y-indices to form a full top index
        const auto ps_ys_idx = container_concat(ps_idx, array<index_t, NDimY>{0});

        const auto dummy_adaptor_coord = make_tensor_adaptor_coordinate(ps_ys_to_xs_, ps_ys_idx);

        array<index_t, NDimR> rs_idx;

        static_for<0, NDimP, 1>{}([&](auto idim_p) {
            constexpr index_t ndim_low = DstrEncode::ps_to_rhss_major_[idim_p].size();

            static_for<0, ndim_low, 1>{}([&](auto i) {
                constexpr index_t rh_major = DstrEncode::ps_to_rhss_major_[idim_p][i];
                constexpr index_t rh_minor = DstrEncode::ps_to_rhss_minor_[idim_p][i];

                // 0-th rh_major is the replicate dimension
                if constexpr(rh_major == 0)
                {
                    constexpr index_t adaptor_hidden_id =
                        DstrDetail::rh_major_minor_to_adaptor_hidden_idss_[rh_major][rh_minor];

                    // fill in the replicate index from the hidden coordinate
                    rs_idx(rh_minor) = dummy_adaptor_coord.get_hidden_index()[adaptor_hidden_id];
                }
            });
        });

        return rs_idx;
    }
#endif

    // One tile_distributed_span per X dimension, built from the encoding's
    // precomputed span lengths.
    CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans()
    {
        constexpr auto distributed_spans_impl = DstrEncode::detail::distributed_spans_lengthss_;
        constexpr auto ndims_spans_minor      = DstrEncode::detail::ndims_distributed_spans_minor_;

        return generate_tuple(
            [&](auto i) {
                constexpr auto span_impl          = distributed_spans_impl[i];
                constexpr index_t ndim_span_minor = ndims_spans_minor[i];

                constexpr auto span = TO_SEQUENCE(span_impl, ndim_span_minor);

                return detail::make_tile_distributed_span(span);
            },
            number<NDimX>{});
    }

    // Map a tuple of tile_distributed_index (one per span) back to the flat Y
    // index sequence, using the encoding's span-major/minor bookkeeping.
    // FIXME: it's hacky to get Y index from Distributed-Index
    template <typename DistributedIndices>
    CK_TILE_HOST_DEVICE static constexpr auto
    get_y_indices_from_distributed_indices(DistributedIndices)
    {
        constexpr auto ys_idx_arr = [] {
            array<index_t, NDimY> ys_idx;

            static_for<0, NDimY, 1>{}([&](auto i) {
                // which span (major) and position within it (minor) Y dim i maps to
                constexpr index_t span_major = DstrEncode::detail::ys_to_span_major_[i];
                constexpr index_t span_minor = DstrEncode::detail::ys_to_span_minor_[i];

                constexpr auto dstr_index = DistributedIndices{}[number<span_major>{}];

                ys_idx(i) = dstr_index.impl_[span_minor];
            });

            return ys_idx;
        }();

        constexpr index_t ndim_y = NDimY;

        return TO_SEQUENCE(ys_idx_arr, ndim_y);
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_static()
    {
        return PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static();
    }

    // Debug dump of the distribution's encoding, adaptor and descriptor.
    CK_TILE_HOST_DEVICE void print() const
    {
        printf("tile_distribution{");
        //
        printf("tile_distribution_encoding: ");
        print(DstrEncode{});
        printf(", ");
        //
        printf("ps_ys_to_xs_: ");
        print(ps_ys_to_xs_);
        printf(", ");
        //
        printf("ys_to_d_: ");
        print(ys_to_d_);
        //
        printf("}");
    }
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <index_t NDimMax>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_sequential_index(index_t ibegin, index_t iend)
|
||||
{
|
||||
array<index_t, NDimMax> arr{0};
|
||||
|
||||
for(index_t i = 0; i < iend - ibegin; ++i)
|
||||
{
|
||||
arr(i) = ibegin + i;
|
||||
}
|
||||
|
||||
return arr;
|
||||
}
|
||||
|
||||
// this returns a constexpr encoding of tile_distribution
|
||||
// Build, at compile time, the raw encodings needed to construct a
// tile_distribution from a tile_distribution_encoding:
//   - an adaptor encoding mapping [p..., y...] -> [x...]
//   - an adaptor encoding mapping [y...] -> [d]
//   - the linear d length (product of Y lengths)
//   - the [rh_major][rh_minor] -> hidden-dim-id lookup table
//
// "Hidden" dims are the internal dims of the adaptor; X dims occupy hidden ids
// [0, ndim_x), and every R/H sub-dimension gets a fresh hidden id after that.
template <typename StaticTileDistributionEncoding_>
CK_TILE_HOST_DEVICE constexpr auto
make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_)
{
    using RsLengths    = typename StaticTileDistributionEncoding_::RsLengths;
    using HsLengthss   = typename StaticTileDistributionEncoding_::HsLengthss;
    using Ps2RHssMajor = typename StaticTileDistributionEncoding_::Ps2RHssMajor;
    using Ps2RHssMinor = typename StaticTileDistributionEncoding_::Ps2RHssMinor;
    using Ys2RHsMajor  = typename StaticTileDistributionEncoding_::Ys2RHsMajor;
    using Ys2RHsMinor  = typename StaticTileDistributionEncoding_::Ys2RHsMinor;

    // fixed capacities for the encoding buffers
    // FIXME: increase max value if fail
    constexpr index_t kMaxNumTransforms = 20;
    constexpr index_t kMaxMetaDataSize  = 128;
    constexpr index_t kMaxNumDim        = 10;

    using Name     = cood_transform_enum;
    using MetaData = meta_data_buffer<kMaxMetaDataSize>;
    using NumDim   = index_t;
    using Dims     = array<index_t, kMaxNumDim>;
    using Lengths  = array<index_t, kMaxNumDim>;

    // Tile Adaptor
    //   bottom dims: [x0, x1, x2, ...]
    //   top dims:    [p0, p1, ..., y0, y1, ...]
    constexpr index_t ndim_x = HsLengthss::size();

    // Dim Ids: [rh_major, rh_minor] -> hidden dim id / length.
    // Row 0 is the replicate (R) dims; row i+1 is X dim i's H dims.
    array<array<index_t, kMaxNumDim>, ndim_x + 1> rh_major_minor_to_hidden_ids;
    array<array<index_t, kMaxNumDim>, ndim_x + 1> rh_major_minor_to_hidden_lengths;

    auto trans = array<tuple<Name, MetaData, NumDim, Dims, NumDim, Dims>, kMaxNumTransforms>{};

    index_t num_tran       = 0;
    index_t hidden_dim_cnt = ndim_x; // hidden ids [0, ndim_x) are the X dims themselves

    // this is the replicate transform: no lower dims, one upper dim per R length
    {
        constexpr index_t ndim_r_minor = RsLengths::size();

        constexpr auto r_minor_lengths = RsLengths{};

        trans(num_tran++) = {
            cood_transform_enum::replicate,
            MetaData{to_array<index_t, ndim_r_minor>(r_minor_lengths)},
            NumDim{0},
            Dims{},
            NumDim{ndim_r_minor},
            make_sequential_index<kMaxNumDim>(hidden_dim_cnt, hidden_dim_cnt + ndim_r_minor)};

        // register the R dims' hidden ids/lengths (rh_major == 0)
        for(index_t i = 0; i < ndim_r_minor; ++i)
        {
            rh_major_minor_to_hidden_ids(0)(i)     = hidden_dim_cnt;
            rh_major_minor_to_hidden_lengths(0)(i) = r_minor_lengths[i];

            hidden_dim_cnt++;
        }
    };

    // these are Unmerge transforms for X dimensions: X dim i splits into its H dims
    static_for<0, ndim_x, 1>{}([&trans,
                                &num_tran,
                                &hidden_dim_cnt,
                                &rh_major_minor_to_hidden_ids,
                                &rh_major_minor_to_hidden_lengths](auto idim_x) {
        constexpr auto h_minor_lengths = tuple_element_t<idim_x, HsLengthss>{};

        constexpr index_t ndim_h_minor = h_minor_lengths.size();

        trans(num_tran++) = {
            cood_transform_enum::unmerge,
            MetaData{to_array<index_t, ndim_h_minor>(h_minor_lengths)},
            NumDim{1},
            Dims{idim_x},
            NumDim{ndim_h_minor},
            make_sequential_index<kMaxNumDim>(hidden_dim_cnt, hidden_dim_cnt + ndim_h_minor)};

        // register X dim i's H dims (rh_major == idim_x + 1)
        for(index_t i = 0; i < ndim_h_minor; ++i)
        {
            rh_major_minor_to_hidden_ids(idim_x + 1)(i)     = hidden_dim_cnt;
            rh_major_minor_to_hidden_lengths(idim_x + 1)(i) = h_minor_lengths[i];

            hidden_dim_cnt++;
        }
    });

    // transform: P dimensions — each P dim merges its referenced R/H hidden dims
    constexpr index_t ndim_p = Ps2RHssMajor::size();

    Dims hidden_dim_id_ps;

    static_for<0, ndim_p, 1>{}([&](auto iDimP) {
        // allocate a fresh hidden dim for this P dim
        index_t hidden_dim_id_p = hidden_dim_cnt++;

        hidden_dim_id_ps(iDimP) = hidden_dim_id_p;

        constexpr auto p2RHsMajor = Ps2RHssMajor{}[iDimP];
        constexpr auto p2RHsMinor = Ps2RHssMinor{}[iDimP];

        static_assert(p2RHsMajor.size() == p2RHsMinor.size(), "wrong!");

        constexpr index_t ndim_low = p2RHsMajor.size();

        Dims low_dims;
        Lengths low_lengths;

        // gather the hidden ids/lengths this P dim is merged from
        for(index_t i = 0; i < ndim_low; ++i)
        {
            index_t rh_major = p2RHsMajor[i];
            index_t rh_minor = p2RHsMinor[i];
            low_dims(i)      = rh_major_minor_to_hidden_ids[rh_major][rh_minor];
            low_lengths(i)   = rh_major_minor_to_hidden_lengths[rh_major][rh_minor];
        }

        trans(num_tran++) = {cood_transform_enum::merge,
                             MetaData{to_array<index_t, ndim_low>(low_lengths)},
                             NumDim{ndim_low},
                             low_dims,
                             NumDim{1},
                             Dims{hidden_dim_id_p}};
    });

    constexpr index_t ndim_bottom = ndim_x;

    constexpr auto bottom_dim_ids = make_sequential_index<kMaxNumDim>(0, ndim_bottom);

    constexpr auto ys_to_rhs_major = Ys2RHsMajor{};
    constexpr auto ys_to_rhs_minor = Ys2RHsMinor{};

    constexpr index_t ndim_y   = Ys2RHsMajor::size();
    constexpr index_t ndim_top = ndim_p + ndim_y;

    // top dims are [p..., y...]: start with the P hidden ids, append the Y ones
    auto top_dim_ids = hidden_dim_id_ps;

    {
        for(index_t i = 0; i < ndim_y; ++i)
        {
            index_t rh_major        = ys_to_rhs_major[i];
            index_t rh_minor        = ys_to_rhs_minor[i];
            top_dim_ids(ndim_p + i) = rh_major_minor_to_hidden_ids[rh_major][rh_minor];
        }
    }

    //
    const auto ps_ys_to_xs_adaptor_encoding =
        make_tuple(trans, num_tran, bottom_dim_ids, ndim_bottom, top_dim_ids, ndim_top);

    // descriptor: [y0, y1, ...] to [d] — unmerge the linear d into the Y lengths
    Lengths y_lengths;
    index_t d_length = 1;

    for(index_t i = 0; i < ndim_y; ++i)
    {
        index_t rh_major = ys_to_rhs_major[i];
        index_t rh_minor = ys_to_rhs_minor[i];
        index_t y_length = rh_major_minor_to_hidden_lengths[rh_major][rh_minor];
        y_lengths(i)     = y_length;
        d_length *= y_length;
    }

    auto tran = make_tuple(cood_transform_enum::unmerge,
                           MetaData{to_array<index_t, ndim_y>(y_lengths)},
                           NumDim{1},
                           Dims{0},
                           NumDim{ndim_y},
                           make_sequential_index<kMaxNumDim>(1, ndim_y + 1));

    const auto ys_to_d_adaptor_encoding = make_tuple(
        make_tuple(tran), 1, Dims{0}, 1, make_sequential_index<kMaxNumDim>(1, ndim_y + 1), ndim_y);

    return make_tuple(ps_ys_to_xs_adaptor_encoding,
                      ys_to_d_adaptor_encoding,
                      d_length,
                      rh_major_minor_to_hidden_ids);
}
|
||||
|
||||
// FIXME: this is nasty. Move it inside TileDistributionEncoding::detail
|
||||
// Ad-hoc detail carried by tile_distribution: the [rh_major][rh_minor] ->
// adaptor-hidden-dim-id lookup table, stored as a static array-of-arrays.
// FIXME: this is nasty. Move it inside TileDistributionEncoding::detail
template <typename RhMajorMinor2AdaptorHiddenIdss> // tuple<sequence<...>, ...>
struct tile_distribution_detail
{
    static constexpr auto rh_major_minor_to_adaptor_hidden_idss_ =
        to_array_of_array(RhMajorMinor2AdaptorHiddenIdss{});
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// this returns a constexpr tile_distribution
|
||||
// Construct a (constexpr, but not fully type-static) tile_distribution from a
// static encoding. The heavy lifting is done by
// detail::make_adaptor_encoding_for_tile_distribution; the resulting raw
// encodings are then materialized into adaptor/descriptor objects.
template <typename StaticTileDistributionEncoding_>
CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistributionEncoding_)
{
    using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;

    constexpr auto adaptor_impl =
        detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{});

    // unpack: ps_ys->xs encoding, ys->d encoding, d length, hidden-id table
    constexpr auto ps_ys_to_xs_adaptor_impl          = adaptor_impl.template at<0>();
    constexpr auto ys_to_d_adaptor_impl              = adaptor_impl.template at<1>();
    constexpr index_t d_length                       = adaptor_impl.template at<2>();
    constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>();

    constexpr auto ps_ys_to_xs_adaptor =
        CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl);

    constexpr auto ys_to_d_adaptor = CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl);

    // note: d_length is a runtime-typed index_t here (cf. the static variant below)
    constexpr auto ys_to_d_descriptor =
        make_tensor_descriptor_from_adaptor(ys_to_d_adaptor, d_length);

    //
    constexpr index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_;
    constexpr auto ndims_rhs_minor  = DstrEncode::detail::ndims_rhs_minor_;

    // lift the hidden-id table into a tuple-of-sequences type for DstrDetail
    constexpr auto rh_major_minor_to_hidden_ids =
        TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor);

    return tile_distribution<
        remove_cvref_t<decltype(ps_ys_to_xs_adaptor)>,
        remove_cvref_t<decltype(ys_to_d_descriptor)>,
        remove_cvref_t<DstrEncode>,
        detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
        ps_ys_to_xs_adaptor, ys_to_d_descriptor};
}
|
||||
|
||||
// this returns a static tile_distribution
|
||||
// Construct a fully static tile_distribution from a static encoding. Same
// structure as make_tile_distribution, but uses the STATIC adaptor-construction
// macro and a number<> d length so every length/stride is a compile-time type.
template <typename StaticTileDistributionEncoding_>
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
{
    using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;

    constexpr auto adaptor_impl =
        detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{});

    // unpack: ps_ys->xs encoding, ys->d encoding, d length, hidden-id table
    constexpr auto ps_ys_to_xs_adaptor_impl          = adaptor_impl.template at<0>();
    constexpr auto ys_to_d_adaptor_impl              = adaptor_impl.template at<1>();
    constexpr index_t d_length                       = adaptor_impl.template at<2>();
    constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>();

    constexpr auto ps_ys_to_xs_adaptor =
        CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl);

    constexpr auto ys_to_d_adaptor =
        CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl);

    // number<d_length> keeps the element-space size in the type system
    constexpr auto ys_to_d_descriptor =
        make_tensor_descriptor_from_adaptor(ys_to_d_adaptor, number<d_length>{});

    //
    constexpr index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_;
    constexpr auto ndims_rhs_minor  = DstrEncode::detail::ndims_rhs_minor_;

    // lift the hidden-id table into a tuple-of-sequences type for DstrDetail
    constexpr auto rh_major_minor_to_hidden_ids =
        TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor);

    return tile_distribution<
        remove_cvref_t<decltype(ps_ys_to_xs_adaptor)>,
        remove_cvref_t<decltype(ys_to_d_descriptor)>,
        remove_cvref_t<DstrEncode>,
        detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
        ps_ys_to_xs_adaptor, ys_to_d_descriptor};
}
|
||||
|
||||
//***********************************************************************************
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Get the current thread's partition (P) index for the given distribution:
// [lane] for a warp-tile (NDimP == 1), [warp, lane] for a block-tile
// (NDimP == 2).
template <typename Distribution>
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
{
    // only support warp-tile and block-tile
    static_assert(Distribution::NDimP == 1 or Distribution::NDimP == 2, "wrong!");

    if constexpr(Distribution::NDimP == 1)
    {
        return array<index_t, 1>{get_lane_id()};
    }
    else if constexpr(Distribution::NDimP == 2)
    {
        return array<index_t, 2>{get_warp_id(), get_lane_id()};
    }
}
|
||||
|
||||
// Recursive metafunction behind reverse_slice_sequence: scans the length
// sequence from RIGHT to LEFT, consuming SliceSize across the masked (m != 0)
// dimensions via gcd, and records:
//   dim_lengths           - sliced length per dimension
//   dim_slices            - number of slices per dimension (x / sliced length)
//   remaining_slice_sizes - slice budget still to be consumed left of each dim
//   split_flag/split_idx  - first (rightmost-scan) dim whose length was cut
// Template args: <lengths, mask, dim ids, SliceSize>.
template <typename, typename, typename, index_t>
struct reverse_slice_sequence_impl;

// recursive case: peel off the leftmost element, recurse on the tail first
template <index_t x,
          index_t... xs,
          index_t m,
          index_t... ms,
          index_t id,
          index_t... ids,
          index_t SliceSize>
struct reverse_slice_sequence_impl<sequence<x, xs...>,
                                   sequence<m, ms...>,
                                   sequence<id, ids...>,
                                   SliceSize>
{
    // result of scanning everything to the right of this element
    using old_scan =
        reverse_slice_sequence_impl<sequence<xs...>, sequence<ms...>, sequence<ids...>, SliceSize>;

    // slice budget left over after the tail consumed its share
    // NOTE(review): capital-F Front() — verify it matches this sequence type's API
    static constexpr auto slice_size = old_scan::remaining_slice_sizes::Front().value;
    // masked dims take gcd(x, budget); unmasked dims keep their full length
    static constexpr auto slice_length =
        std::conditional_t<m, number<math::gcd(x, slice_size)>, number<x>>::value;

    using dim_lengths =
        typename sequence_merge<sequence<slice_length>, typename old_scan::dim_lengths>::type;
    using dim_slices =
        typename sequence_merge<sequence<x / slice_length>, typename old_scan::dim_slices>::type;
    using remaining_slice_sizes = typename sequence_merge<
        std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>,
        typename old_scan::remaining_slice_sizes>::type;

    // the first idx that sliced length not equal to original length
    static constexpr index_t _flag =
        slice_length != x && remaining_slice_sizes{}.Front().value == 1;
    static constexpr index_t _split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
    static constexpr index_t _split_idx =
        std::conditional_t<_split_flag, number<id>, number<0>>::value;

    // the tail's split (found first in the right-to-left scan) wins
    static constexpr index_t split_flag = _split_flag || old_scan::split_flag;
    static constexpr index_t split_idx  = std::
        conditional_t<old_scan::split_flag, number<old_scan::split_idx>, number<_split_idx>>::value;
};
|
||||
|
||||
// base case: a single element — the full SliceSize budget applies here
template <index_t x, index_t m, index_t id, index_t SliceSize>
struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, SliceSize>
{
    static constexpr auto slice_size = SliceSize;
    // masked: gcd(x, budget); unmasked: keep full length
    static constexpr auto slice_length =
        std::conditional_t<m, number<math::gcd(x, slice_size)>, number<x>>::value;

    using dim_lengths = sequence<slice_length>;
    using dim_slices  = sequence<x / slice_length>;
    using remaining_slice_sizes =
        std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>;

    // the first idx that sliced length not equal to original length
    static constexpr index_t _flag =
        slice_length != x && remaining_slice_sizes{}.Front().value == 1;
    static constexpr index_t split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
    static constexpr index_t split_idx =
        std::conditional_t<split_flag, number<id>, number<0>>::value;
};
|
||||
|
||||
// clang-format off
|
||||
// input a sequence(with optional mask), and the SliceSize : size per slice
|
||||
// output the sequence each slice, and number of slices
|
||||
//
|
||||
// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0
|
||||
// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
|
||||
// <4, 2, 4, 1, 6>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 3> : 48 slices , slice_idx: 2
|
||||
// <4, 2, 5, 1, 2>, 10 -> lengths:<1, 1, 5, 1, 2> , nums: <4, 2, 1, 1, 1> : 8 slices , slice_idx: 1
|
||||
//
|
||||
// <4, 2, 8>, 64 -> lengths:<4, 2, 8> , nums: <1, 1, 1> : 1 slices , slice_idx: 0
|
||||
// <4, 2, 8>, 32 -> lengths:<2, 2, 8> , nums: <2, 1, 1> : 2 slices , slice_idx: 0
|
||||
// <4, 2, 8>, 16 -> lengths:<1, 2, 8> , nums: <4, 1, 1> : 4 slices , slice_idx: 0
|
||||
// <4, 2, 8>, 8 -> lengths:<1, 1, 8> , nums: <4, 2, 1> : 8 slices , slice_idx: 1
|
||||
// <4, 2, 8>, 4 -> lengths:<1, 1, 4> , nums: <4, 2, 2> : 16 slices , slice_idx: 2
|
||||
// <4, 2, 8>, 2 -> lengths:<1, 1, 2> , nums: <4, 2, 4> : 32 slices , slice_idx: 2
|
||||
// <4, 2, 8>, 1 -> lengths:<1, 1, 1> , nums: <4, 2, 8> : 64 slices , slice_idx: 2
|
||||
//
|
||||
// <4, 2, 1, 4, 2> / 4 ->
|
||||
// mask:<1, 1, 1, 0, 1>, -> lengths:<1, 2, 1, 4, 2> , nums: <4, 1, 1, 1, 1> : 8 slices , slice_idx: 0
|
||||
//
|
||||
// return tuple<slice_lengths, slice_nums, slice_index>, slice_index is at which index will start
|
||||
// have split slices (right -> left)
|
||||
// or the first index that sliced length is different from the original length
|
||||
// clang-format on
|
||||
// Slice a length sequence into chunks of SliceSize elements, scanning from the
// rightmost dimension, optionally restricted by a 0/1 Mask (masked-out dims are
// never cut). Returns tuple<per-slice lengths, slice counts per dim,
// split_idx>; see the worked examples in the comment block above.
// Fails to compile if SliceSize does not evenly divide the (masked) sequence.
template <typename Seq,
          index_t SliceSize,
          typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
constexpr auto reverse_slice_sequence(Seq,
                                      number<SliceSize>,
                                      Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
{
    static_assert(Seq::size() == Mask::size());
    using sliced_type =
        reverse_slice_sequence_impl<Seq,
                                    Mask,
                                    typename arithmetic_sequence_gen<0, Seq::size(), 1>::type,
                                    SliceSize>;
    // a leftover budget of exactly 1 means SliceSize divided the sequence evenly
    static_assert(sliced_type::remaining_slice_sizes::Front().value == 1,
                  "can not evenly divide this sequence, please check");
    return make_tuple(typename sliced_type::dim_lengths{},
                      typename sliced_type::dim_slices{},
                      number<sliced_type::split_idx>{});
}
|
||||
|
||||
//
|
||||
// slice tensor from x_dim, result in split in y_dim, not p_dim.
|
||||
// We don't support slice cross p_dim (aka, slice different threads)
|
||||
// also, sliced along y_dim need be the first dim of current dim.
|
||||
// Multiply Y dim before sliced dim does not make sense
|
||||
//
|
||||
// e.g
|
||||
// X0 X1
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 32>, (0 means all length)
|
||||
// Y P P Y P Y P Y
|
||||
// => <1, 4, 32> - <1, 1, 4, 2, 4> -> OK
|
||||
// |--> slice along this Y dim, is the first dim of X1, totally 4 slices
|
||||
//
|
||||
// X0 X1
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 8>, (0 means all length)
|
||||
// Y P P Y P Y P Y
|
||||
// => <1, 4, 32> - <1, 1, 1, 2, 4> -> OK
|
||||
// |--> slice along this Y dim, the P dim is 1 in the left, so is OK
|
||||
// totally 16 slices
|
||||
//
|
||||
// X0 X1
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 4>, (0 means all length)
|
||||
// Y P P Y P Y P Y
|
||||
// => <1, 4, 32> - <1, 1, 1, 1, 4> -> Fail
|
||||
// |--> slice along this P dim, will split threads, not supported
|
||||
//
|
||||
// X0 X1
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 16>, (0 means all length)
|
||||
// Y P P Y P Y P Y
|
||||
// => <1, 4, 32> - <1, 1, 2, 2, 4> -> OK
|
||||
// |--> slice along this Y dim, but this Y sim need to split into 2
|
||||
// subdime
|
||||
// the P dim in the left is 1, means actually not crossing P
|
||||
//
|
||||
// Slice a tile_distribution along its X dimensions (see the long comment block
// above for the supported/unsupported slicing patterns). The slice must land on
// Y (per-thread) dimensions only — slicing across P (thread) dimensions is not
// supported. Returns tuple<sliced distribution, y slice origins, y slice lengths>.
template <typename Distribution, index_t... XSliceBegins, index_t... XSliceEnds>
CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
    Distribution, sequence<XSliceBegins...> x_slice_begins, sequence<XSliceEnds...> x_slice_ends)
{
    // NOTE: this function need to be called under constexpr context,
    // due to https://wg21.link/p2280r0 we have to use non-reference type for distribution
    using Encoding = decltype(Distribution::get_static_tile_distribution_encoding());

    static_assert(sizeof...(XSliceBegins) == sizeof...(XSliceEnds));

    constexpr auto x_slice_lengths = x_slice_ends - x_slice_begins;

    // precomputed bookkeeping from the encoding: flattened-H prefix sums and
    // the Y dims sorted by their position in the flattened H space
    constexpr auto src_h_prefix_sum = Encoding::detail::get_h_dim_lengths_prefix_sum();
    constexpr auto src_y_info       = Encoding::detail::get_sorted_y_info();
    constexpr auto src_y_dims       = src_y_info[number<0>{}];
    constexpr auto src_y_maps       = src_y_info[number<1>{}];
    constexpr auto src_y_prefix_sum = src_y_info[number<2>{}];

    // compute <sliced h-lengths, y slice origins (sorted order), y slice lengths>
    constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr
    {
        auto y_slice_sorted_origins = make_zero_multi_index<Encoding::NDimY>();
        auto y_slice_lengths        = Encoding::detail::ys_lengths_;

        // This lambda will modify some value outside, so c++ will not treat
        // return value as constexpr
        // TODO: ugly
        auto new_h_lengths = transform_tuples(
            [&](auto h_len, auto id) {
                // slice this X dim's H lengths right-to-left by its slice length
                constexpr auto sliced_h =
                    reverse_slice_sequence(h_len, number<x_slice_lengths[id]>{});

                constexpr auto sliced_h_lens  = sliced_h[number<0>{}];
                constexpr auto sliced_h_index = sliced_h[number<2>{}];

                // update y_slice_lengths: map the sliced H position into flat-H
                // space and locate the Y dim it belongs to
                constexpr auto uniformed_h_index = sliced_h_index + number<src_h_prefix_sum[id]>{};
                constexpr auto found_y_index     = container_find(src_y_dims, uniformed_h_index);

                static_assert(found_y_index >= 0 && found_y_index < src_y_dims.size(),
                              "not sliced at y dim, please check");

                static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
                    y_slice_lengths(src_y_maps[found_y_index - i]) =
                        sliced_h_lens[sliced_h_index - i];
                });
                // TODO: add validations not across p dim

                // NOTE: this y_origin is for all dims, not only current dim
                //       will later use pick to select target dim
                constexpr auto y_origin = [&]() {
                    // decompose the slice begin offset into per-H-dim coordinates
                    constexpr auto h_trans = make_merge_transform_v3_division_mod(h_len);
                    auto h_origin_         = make_zero_multi_index<h_trans.NDimLow>();
                    h_trans.calculate_lower_index(h_origin_, sequence<x_slice_begins[id].value>{});

                    auto y_origin_ = make_zero_multi_index<Encoding::NDimY>();
                    static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
                        y_origin_(found_y_index - i) = h_origin_[sliced_h_index - i];
                    });
                    return y_origin_;
                }();

                // pick only the Y dims that belong to the current X dim
                constexpr auto y_picks = typename arithmetic_sequence_gen<src_y_prefix_sum[id],
                                                                          src_y_prefix_sum[id + 1],
                                                                          1>::type{};

                set_container_subset(
                    y_slice_sorted_origins, y_picks, get_container_subset(y_origin, y_picks));
                return sliced_h_lens;
            },
            typename Encoding::HsLengthss{},
            typename arithmetic_sequence_gen<0, Encoding::HsLengthss::size(), 1>::type{});

        // origins were accumulated in sorted-Y order; restore original Y order
        auto y_slice_origins = container_reorder_given_old2new(y_slice_sorted_origins, src_y_maps);

        return make_tuple(new_h_lengths, y_slice_origins, y_slice_lengths);
    }
    ();

    constexpr auto sliced_h_lengths       = sliced_hlen_yidx_ylen[number<0>{}];
    constexpr auto sliced_y_origins_array = sliced_hlen_yidx_ylen[number<1>{}];
    constexpr auto sliced_y_origins_size  = sliced_y_origins_array.size();
    constexpr auto sliced_y_lengths_array = sliced_hlen_yidx_ylen[number<2>{}];
    constexpr auto sliced_y_lengths_size  = sliced_y_lengths_array.size();

    constexpr auto sliced_y_origins = TO_SEQUENCE(sliced_y_origins_array, sliced_y_origins_size);
    constexpr auto sliced_y_lengths = TO_SEQUENCE(sliced_y_lengths_array, sliced_y_lengths_size);

    // rebuild the distribution with the sliced H lengths; the R/P/Y wiring of
    // the encoding is unchanged
    return make_tuple(
        make_static_tile_distribution(
            tile_distribution_encoding<typename Encoding::RsLengths,
                                       decltype(sliced_h_lengths), // only need to change the
                                                                   // h_lengths type
                                       typename Encoding::Ps2RHssMajor,
                                       typename Encoding::Ps2RHssMinor,
                                       typename Encoding::Ys2RHsMajor,
                                       typename Encoding::Ys2RHsMinor>{}),
        sliced_y_origins,
        sliced_y_lengths);
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace ck_tile
|
||||
761
include/ck_tile/core/tensor/tile_distribution_encoding.hpp
Normal file
761
include/ck_tile/core/tensor/tile_distribution_encoding.hpp
Normal file
@@ -0,0 +1,761 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/multi_index.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Compile-time encoding of how a tile's elements are distributed over threads.
///
/// Dimension vocabulary (everything here is compile-time):
///   - X: logical tile dimensions; each X dim x is factorized into the "hidden"
///        components HsLengthss[x].
///   - R: replication components (RsLengths), duplicated across threads.
///   - P: partition dims (e.g. wave/lane ids); each P dim maps onto a list of
///        (rh_major, rh_minor) components it indexes.
///   - Y: per-thread ("owned") dims; each Y dim maps onto exactly one
///        (rh_major, rh_minor) component.
/// Convention: rh_major == 0 selects the R group, rh_major == x + 1 selects the
/// H group of X dim x; rh_minor indexes a component inside the selected group.
template <typename RsLengths_,    // sequence<...>
          typename HsLengthss_,   // tuple<sequence<...>, ...>
          typename Ps2RHssMajor_, // tuple<sequence<...>, ...>
          typename Ps2RHssMinor_, // tuple<sequence<...>, ...>
          typename Ys2RHsMajor_,  // sequence<...>
          typename Ys2RHsMinor_>  // sequence<...>
struct tile_distribution_encoding
{
    using RsLengths    = remove_cvref_t<RsLengths_>;
    using HsLengthss   = remove_cvref_t<HsLengthss_>;
    using Ps2RHssMajor = remove_cvref_t<Ps2RHssMajor_>;
    using Ps2RHssMinor = remove_cvref_t<Ps2RHssMinor_>;
    using Ys2RHsMajor  = remove_cvref_t<Ys2RHsMajor_>;
    using Ys2RHsMinor  = remove_cvref_t<Ys2RHsMinor_>;

    // major/minor maps are given pairwise and must have matching arity
    static_assert(Ps2RHssMajor::size() == Ps2RHssMinor::size(), "wrong!");
    static_assert(Ys2RHsMajor::size() == Ys2RHsMinor::size(), "wrong!");

    static constexpr index_t NDimX = HsLengthss::size();
    static constexpr index_t NDimP = Ps2RHssMajor::size();
    static constexpr index_t NDimY = Ys2RHsMajor::size();
    static constexpr index_t NDimR = RsLengths::size();

    // FIXME: move into detail
    // value-level mirrors of the template parameters, handy for indexing
    static constexpr auto rs_lengths_       = RsLengths{};
    static constexpr auto hs_lengthss_      = HsLengthss{};
    static constexpr auto ps_to_rhss_major_ = Ps2RHssMajor{};
    static constexpr auto ps_to_rhss_minor_ = Ps2RHssMinor{};
    static constexpr auto ys_to_rhs_major_  = Ys2RHsMajor{};
    static constexpr auto ys_to_rhs_minor_  = Ys2RHsMinor{};

    // redundant but useful info
    // TODO: really bad code, should be over-hauled
    struct detail
    {
        // ndim_rh_major_, ndim_span_major_
        // number of rh-major groups: the R group plus one H group per X dim
        static constexpr index_t ndim_rh_major_   = NDimX + 1;
        static constexpr index_t ndim_span_major_ = NDimX;

        // ndims_rhs_minor_[ndim_rh_major_]
        // component count of each rh-major group (R first, then each H group)
        static constexpr auto ndims_rhs_minor_ = generate_array(
            [](auto i) {
                if constexpr(i.value == 0)
                {
                    return rs_lengths_.size();
                }
                else
                {
                    return hs_lengthss_[i - number<1>{}].size();
                }
            },
            number<ndim_rh_major_>{});

        // max_ndim_rh_minor_
        // widest rh-major group; used to size rectangular scratch arrays
        static constexpr index_t max_ndim_rh_minor_ =
            container_reduce(ndims_rhs_minor_, math::maximize<index_t>{}, 0);

        // rhs_lengthss_[ndim_rh_major_][max_ndim_rh_minor_]
        // rectangular copy of [R lengths, H0 lengths, H1 lengths, ...]
        static constexpr auto rhs_lengthss_ =
            to_array_of_array(container_concat(make_tuple(rs_lengths_), hs_lengthss_));

        // ys_lengths_
        // length of each Y dim, looked up through its (rh_major, rh_minor)
        static constexpr auto ys_lengths_ = [] {
            array<index_t, NDimY> ys_lengths_tmp{-1};

            for(index_t i = 0; i < NDimY; i++)
            {
                index_t rh_major = ys_to_rhs_major_[i];
                index_t rh_minor = ys_to_rhs_minor_[i];

                ys_lengths_tmp(i) = rhs_lengthss_[rh_major][rh_minor];
            }

            return ys_lengths_tmp;
        }();

        // rhs_major_minor_to_ys_[ndim_rh_major_][max_ndim_rh_minor_]
        // inverse map: (rh_major, rh_minor) -> Y index, -1 when the component
        // is not owned by any Y dim
        static constexpr auto rhs_major_minor_to_ys_ = [] {
            array<array<index_t, max_ndim_rh_minor_>, NDimX + 1> rhs_major_minor_to_ys_tmp{{-1}};

            static_for<0, NDimY, 1>{}([&](auto i) {
                constexpr index_t rh_major = ys_to_rhs_major_[i];
                constexpr index_t rh_minor = ys_to_rhs_minor_[i];

                rhs_major_minor_to_ys_tmp(rh_major)(rh_minor) = i;
            });

            return rhs_major_minor_to_ys_tmp;
        }();

        // ndims_span_minor_[NDimY]
        // number of Y components that fall into each span-major (== X) dim
        static constexpr auto ndims_span_minor_ = [] {
            array<index_t, NDimX> ndims_span_minor{0};

            for(index_t i = 0; i < NDimY; i++)
            {
                const index_t span_major = ys_to_rhs_major_[i] - 1;

                ndims_span_minor(span_major)++;
            }

            return ndims_span_minor;
        }();

        // max_ndim_span_minor_
        static constexpr index_t max_ndim_span_minor_ =
            container_reduce(ndims_span_minor_, math::maximize<index_t>{}, 0);

        // rhs_major_minor_to_span_minor_ [ndim_rh_major_][max_ndim_rh_minor_]
        // (rh_major, rh_minor) -> ordinal of that Y component within its span;
        // -1 for components not owned by a Y dim
        static constexpr auto rhs_major_minor_to_span_minor_ = [] {
            array<array<index_t, max_ndim_rh_minor_>, ndim_rh_major_> rhs_major_minor_to_span_minor{
                {-1}};

            static_for<0, ndim_rh_major_, 1>{}([&](auto rh_major) {
                constexpr index_t ndim_rh_minor = ndims_rhs_minor_[rh_major];

                index_t cnt_ndim_span_minor = 0;

                static_for<0, ndim_rh_minor, 1>{}([&](auto rh_minor) {
                    constexpr index_t idim_y = rhs_major_minor_to_ys_[rh_major][rh_minor];

                    if(idim_y >= 0)
                    {
                        rhs_major_minor_to_span_minor(rh_major)(rh_minor) = cnt_ndim_span_minor;

                        cnt_ndim_span_minor++;
                    }
                });
            });

            return rhs_major_minor_to_span_minor;
        }();

        // ys_to_span_major_[NDimY]
        // Y index -> span-major (X) dim it lives in
        static constexpr auto ys_to_span_major_ =
            generate_array([](auto i) { return ys_to_rhs_major_[i] - 1; }, number<NDimY>{});

        // ys_to_span_minor_[NDimY]
        // Y index -> position within its span
        static constexpr auto ys_to_span_minor_ = generate_array(
            [](auto i) {
                return rhs_major_minor_to_span_minor_[ys_to_rhs_major_[i]][ys_to_rhs_minor_[i]];
            },
            number<NDimY>{});

        // distributed_spans_lengthss_[ndim_span_major_][max_ndim_span_minor_]
        // per-span H lengths of the thread-owned (Y) components
        static constexpr auto distributed_spans_lengthss_ = [] {
            array<array<index_t, max_ndim_span_minor_>, ndim_span_major_>
                distributed_spans_lengthss{{-1}};

            static_for<0, NDimY, 1>{}([&](auto i) {
                const index_t rh_major = ys_to_rhs_major_[i];
                const index_t rh_minor = ys_to_rhs_minor_[i];

                const index_t h_length = hs_lengthss_[number<rh_major - 1>{}][rh_minor];

                const index_t span_major = rh_major - 1;
                const index_t span_minor = rhs_major_minor_to_span_minor_[rh_major][rh_minor];

                distributed_spans_lengthss(span_major)(span_minor) = h_length;
            });

            return distributed_spans_lengthss;
        }();

        // ndims_distributed_spans_minor_[ndim_span_major_]
        // how many Y components each span-major dim carries
        static constexpr auto ndims_distributed_spans_minor_ = [] {
            array<index_t, ndim_span_major_> ndims_distributed_spans_minor{0};

            static_for<0, NDimY, 1>{}([&](auto i) {
                const index_t span_major = ys_to_rhs_major_[i] - 1;

                ndims_distributed_spans_minor(span_major)++;
            });

            return ndims_distributed_spans_minor;
        }();

        // does_p_own_r_[NDimP][NDimR]
        // true when P dim idim_p indexes replication component rh_minor
        static constexpr auto does_p_own_r_ = [] {
            if constexpr(NDimR > 0)
            {
                array<array<bool, NDimR>, NDimP> does_p_own_r{{false}};

                static_for<0, NDimP, 1>{}([&](auto idim_p) {
                    constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].size();

                    static_for<0, ndim_low, 1>{}([&](auto idim_low) {
                        constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low];
                        constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low];

                        // rh_major == 0 is the R group
                        if constexpr(rh_major == 0)
                        {
                            does_p_own_r(idim_p)(rh_minor) = true;
                        }
                    });
                });

                return does_p_own_r;
            }
            else
            {
                return array<array<bool, NDimR>, NDimP>{};
            }
        }();

        // ps_over_rs_derivative_[NDimP][NDimR]
        // stride of each P dim with respect to the R components it touches
        // (product of the lengths of all lower-significance components)
        static constexpr auto ps_over_rs_derivative_ = [] {
            if constexpr(NDimR > 0)
            {
                array<array<index_t, NDimR>, NDimP> ps_over_rs_derivative{{0}};

                static_for<0, NDimP, 1>{}([&](auto idim_p) {
                    constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].size();

                    index_t p_over_rh_derivative = 1;

                    // walk the P dim's components from least to most significant
                    static_for<ndim_low - 1, -1, -1>{}([&](auto idim_low) {
                        constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low];
                        constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low];

                        constexpr index_t rh_length = rhs_lengthss_[rh_major][rh_minor];

                        if constexpr(rh_major == 0)
                        {
                            ps_over_rs_derivative(idim_p)(rh_minor) = p_over_rh_derivative;
                        }

                        p_over_rh_derivative *= rh_length;
                    });
                });

                return ps_over_rs_derivative;
            }
            else
            {
                return array<array<index_t, NDimR>, NDimP>{};
            }
        }();

        // e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5> --> seq<0, 3, 8>
        CK_TILE_HOST_DEVICE static constexpr auto get_h_dim_lengths_prefix_sum()
        {
            // <len_d0, len_d1, ...>
            // e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5>
            constexpr auto uniformed_h_dim_lengths = generate_sequence_v2(
                [&](auto i) {
                    constexpr index_t size = HsLengthss{}[i].size();
                    return number<size>{};
                },
                number<NDimX>{});

            // <0, len_d0, len_d0+len_d1, ...>
            // e.g. seq<3, 5> --> seq<0, 3, 8>
            constexpr auto h_dim_prefix_sum = prefix_sum_sequence(uniformed_h_dim_lengths);

            return h_dim_prefix_sum;
        }

        // Flatten each Y's (rh_major, rh_minor) pair into a single index over
        // the concatenated [R, H0, H1, ...] component list.
        CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_y_to_h()
        {
            constexpr auto all_ys_2_rhss = transform_sequences(
                [](auto major, auto minor) constexpr {
                    // <0, 0, len_d0, len_d0+len_d1, ...>
                    constexpr auto x_dim_prefix_sum = merge_sequences(
                        sequence<0>{} /*for R dims*/, get_h_dim_lengths_prefix_sum());
                    return x_dim_prefix_sum.at(major) + minor;
                },
                Ys2RHsMajor{},
                Ys2RHsMinor{});

            return all_ys_2_rhss;
        }

        // return tuple<sorted_dims, sorted_maps, sorted_prefix_sum>
        template <typename IdxSeq, typename PrefixSumSeq>
        CK_TILE_HOST_DEVICE static constexpr auto get_sorted_info(IdxSeq, PrefixSumSeq)
        {
            using sorted_idx =
                sequence_unique_sort<IdxSeq, math::less<index_t>, math::equal<index_t>>;

            constexpr auto sorted_dims = typename sorted_idx::type{};
            constexpr auto sorted_maps = typename sorted_idx::sorted2unsorted_map{};

            constexpr auto sorted_histogram =
                histogram_sorted_sequence(sorted_dims, PrefixSumSeq{});
            constexpr auto sorted_prefix_sum = prefix_sum_sequence(sorted_histogram);

            return make_tuple(sorted_dims, sorted_maps, sorted_prefix_sum);
        }

        // Sorted view of the flattened Y->H indices (see get_sorted_info).
        CK_TILE_HOST_DEVICE static constexpr auto get_sorted_y_info()
        {
            return get_sorted_info(get_uniformed_idx_y_to_h(), get_h_dim_lengths_prefix_sum());
        }

        // Debug dump of all derived info.
        // NOTE(review): the unqualified print(...) calls assume free-function
        // overloads resolve here — confirm lookup does not stop at this member.
        // NOTE(review): "ndim_rh_major_" is printed twice below (second one
        // probably intended another field) — confirm before relying on output.
        CK_TILE_HOST_DEVICE void print() const
        {
            printf("tile_distribution_encoding::detail{");
            //
            printf("ndim_rh_major_: ");
            print(ndim_rh_major_);
            printf(", ");
            //
            printf("ndim_span_major_: ");
            print(ndim_span_major_);
            printf(", ");
            //
            printf("ndims_rhs_minor_: ");
            print(ndims_rhs_minor_);
            printf(", ");
            //
            printf("ndim_rh_major_: ");
            print(ndim_rh_major_);
            printf(", ");
            //
            printf("max_ndim_rh_minor_: ");
            print(max_ndim_rh_minor_);
            printf(", ");
            //
            printf("rhs_lengthss_: ");
            print(rhs_lengthss_);
            printf(", ");
            //
            printf("ys_lengths_: ");
            print(ys_lengths_);
            printf(", ");
            //
            printf("rhs_major_minor_to_ys_: ");
            print(rhs_major_minor_to_ys_);
            printf(", ");
            //
            printf("ndims_span_minor_: ");
            print(ndims_span_minor_);
            printf(", ");
            //
            printf("max_ndim_span_minor_: ");
            print(max_ndim_span_minor_);
            printf(", ");
            //
            printf("ys_to_span_major_: ");
            print(ys_to_span_major_);
            printf(", ");
            //
            printf("ys_to_span_minor_: ");
            print(ys_to_span_minor_);
            printf(", ");
            //
            printf("distributed_spans_lengthss_: ");
            print(distributed_spans_lengthss_);
            printf(", ");
            //
            printf("ndims_distributed_spans_minor_: ");
            print(ndims_distributed_spans_minor_);
            printf(", ");
            //
            printf("ps_over_rs_derivative_: ");
            print(ps_over_rs_derivative_);
            //
            printf("}");
        }
    };

    // Debug dump of the encoding itself (template parameters + detail).
    CK_TILE_HOST_DEVICE void print() const
    {
        printf("tile_distribution_encoding{");
        //
        printf("NDimX: %d, NDimP: %d, NDimY: %d, ", NDimX, NDimP, NDimY);
        //
        printf("rs_lengths_: ");
        print(rs_lengths_);
        printf(", ");
        //
        printf("hs_lengthss_: ");
        print(hs_lengthss_);
        printf(", ");
        //
        printf("ps_to_rhss_major_: ");
        print(ps_to_rhss_major_);
        printf(", ");
        //
        printf("ps_to_rhss_minor_: ");
        print(ps_to_rhss_minor_);
        printf(", ");
        //
        printf("ys_to_rhs_major_: ");
        print(ys_to_rhs_major_);
        printf(", ");
        //
        printf("ys_to_rhs_minor_: ");
        print(ys_to_rhs_minor_);
        printf(", ");
        //
        printf("detail: ");
        print(detail{});
        //
        printf("}");
    }
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Compose two encodings defined over the same X dims: per X dim, the inner
// encoding's H components are appended after the outer ones; R, P and Y dims
// are concatenated (outer first). Inner rh_minor indices are therefore shifted
// by the number of outer components in the same rh-major group; rh_major
// indices are unchanged. Returns the combined encoding as a value.
template <typename OuterDstr, typename InnerDstr>
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
{
    static_assert(OuterDstr::NDimX == InnerDstr::NDimX, "wrong!");

    constexpr index_t NDimHMajor = OuterDstr::NDimX;

    using RsLengths =
        sequence_merge_t<typename OuterDstr::RsLengths, typename InnerDstr::RsLengths>;

    // per X dim: outer H factors followed by inner H factors
    constexpr auto hs_lengthss = generate_tuple(
        [&](auto i) {
            return merge_sequences(typename OuterDstr::HsLengthss{}[i],
                                   typename InnerDstr::HsLengthss{}[i]);
        },
        number<NDimHMajor>{});

    //
    // how many outer components each rh-major group contributes; used below to
    // shift the inner encoding's rh_minor indices
    constexpr auto rhs_major_2_ndim_outer_rhs_minor = [&]() {
        array<index_t, NDimHMajor + 1> rhs_major_2_ndim_outer_rhs_minor_;

        // R dimension
        rhs_major_2_ndim_outer_rhs_minor_(0) = OuterDstr::RsLengths::size();

        // Hs dimensions
        static_for<0, NDimHMajor, 1>{}([&](auto i) {
            rhs_major_2_ndim_outer_rhs_minor_(i + 1) = typename OuterDstr::HsLengthss{}[i].size();
        });

        return rhs_major_2_ndim_outer_rhs_minor_;
    }();

    // Ps2RHssMinor
    // inner P -> rh_minor maps, shifted past the outer components
    constexpr auto updated_inner_ps_2_rhss_minor = generate_tuple(
        [&](auto p) {
            constexpr auto inner_p_2_rhss_major = typename InnerDstr::Ps2RHssMajor{}[p];
            constexpr auto inner_p_2_rhss_minor = typename InnerDstr::Ps2RHssMinor{}[p];

            constexpr index_t ndim_tmp = inner_p_2_rhss_minor.size();

            constexpr auto updated_inner_p_2_rhss_minor = [&]() {
                array<index_t, ndim_tmp> updated_inner_p_2_rhss_minor_;

                for(index_t i = 0; i < ndim_tmp; i++)
                {
                    index_t rh_major = inner_p_2_rhss_major[i];

                    index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];

                    updated_inner_p_2_rhss_minor_(i) = inner_p_2_rhss_minor[i] + ndim_outer_h_minor;
                }

                return updated_inner_p_2_rhss_minor_;
            }();

            return TO_SEQUENCE(updated_inner_p_2_rhss_minor, ndim_tmp);
        },
        number<InnerDstr::NDimP>{});

    // Ys2RHsMinor
    // inner Y -> rh_minor map, shifted the same way
    constexpr auto updated_inner_ys_2_rhs_minor = [&]() {
        constexpr auto inner_ys_2_rhs_major = typename InnerDstr::Ys2RHsMajor{};
        constexpr auto inner_ys_2_rhs_minor = typename InnerDstr::Ys2RHsMinor{};

        constexpr index_t ndim_tmp = inner_ys_2_rhs_minor.size();

        constexpr auto updated_inner_ys_2_rhs_minor_ = [&]() {
            array<index_t, ndim_tmp> updated_inner_ys_2_rhs_minor__;

            for(index_t i = 0; i < ndim_tmp; i++)
            {
                index_t rh_major = inner_ys_2_rhs_major[i];

                index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];

                updated_inner_ys_2_rhs_minor__(i) = inner_ys_2_rhs_minor[i] + ndim_outer_h_minor;
            }

            return updated_inner_ys_2_rhs_minor__;
        }();

        return TO_SEQUENCE(updated_inner_ys_2_rhs_minor_, ndim_tmp);
    }();

    //
    // concatenate P maps: outer P dims first, then (shifted) inner P dims
    constexpr auto ps_2_rhss_major =
        container_concat(typename OuterDstr::Ps2RHssMajor{}, typename InnerDstr::Ps2RHssMajor{});

    constexpr auto ps_2_rhss_minor =
        container_concat(typename OuterDstr::Ps2RHssMinor{}, updated_inner_ps_2_rhss_minor);

    //
    // concatenate Y maps the same way
    constexpr auto ys_2_rhs_major =
        merge_sequences(typename OuterDstr::Ys2RHsMajor{}, typename InnerDstr::Ys2RHsMajor{});

    constexpr auto ys_2_rhs_minor =
        merge_sequences(typename OuterDstr::Ys2RHsMinor{}, updated_inner_ys_2_rhs_minor);

    return tile_distribution_encoding<RsLengths,
                                      remove_cvref_t<decltype(hs_lengthss)>,
                                      remove_cvref_t<decltype(ps_2_rhss_major)>,
                                      remove_cvref_t<decltype(ps_2_rhss_minor)>,
                                      remove_cvref_t<decltype(ys_2_rhs_major)>,
                                      remove_cvref_t<decltype(ys_2_rhs_minor)>>{};
}
|
||||
|
||||
// Compute (as plain arrays + counts) the encoding that results from reducing
// the X dims listed in reduce_dim_xs_in:
//   - H components of reduced X dims that are NOT thread-owned (no Y maps to
//     them) become new R components (cross-thread reduction is still pending);
//   - Y dims mapping into reduced X dims disappear;
//   - all remaining rh indices are renumbered into the output index space.
// The result is returned as a tuple of sizes and fixed-capacity arrays; the
// caller turns them back into sequence types via TO_SEQUENCE /
// TO_TUPLE_OF_SEQUENCE (see make_reduce_tile_distribution_encoding).
template <typename InDstr, index_t... InReduceDimXs>
CK_TILE_HOST_DEVICE constexpr auto
make_reduce_tile_distribution_encoding_impl(InDstr, sequence<InReduceDimXs...> reduce_dim_xs_in)
{
    constexpr auto I1 = number<1>{};

    // FIXME: increase if fail
    // capacity bounds for the fixed-size scratch arrays below
    constexpr index_t max_ndim_r_out = 20;
    constexpr index_t max_ndim_y_out = 20;

    //
    constexpr index_t ndim_p           = InDstr::NDimP;
    constexpr index_t ndim_x_in        = InDstr::NDimX;
    constexpr index_t ndim_y_in        = InDstr::NDimY;
    constexpr index_t ndim_rh_major_in = InDstr::NDimX + 1;
    constexpr index_t ndim_x_out       = ndim_x_in - sizeof...(InReduceDimXs);
    constexpr index_t max_ndim_rh_minor_in = InDstr::detail::max_ndim_rh_minor_;

    // ndims_ps_low
    // number of (rh_major, rh_minor) components per P dim (unchanged by reduce)
    constexpr auto ndims_ps_low = generate_array(
        [&](auto i) { return InDstr::ps_to_rhss_major_[i].size(); }, number<ndim_p>{});

    // is_rh_major_in_for_reduce
    // mark the rh-major groups (H groups of the reduced X dims, shifted by 1)
    array<bool, ndim_rh_major_in> is_rh_major_in_for_reduce{false};

    for(index_t i = 0; i < reduce_dim_xs_in.size(); i++)
    {
        index_t rh_major = reduce_dim_xs_in[i] + 1;

        is_rh_major_in_for_reduce(rh_major) = true;
    }

    // is_y_in_for_reduce
    // Y dims that live inside a reduced X dim (they vanish from the output)
    array<bool, ndim_y_in> is_y_in_for_reduce{false};

    for(index_t i = 0; i < ndim_y_in; i++)
    {
        index_t rh_major = InDstr::ys_to_rhs_major_[i];

        if(is_rh_major_in_for_reduce[rh_major])
        {
            is_y_in_for_reduce(i) = true;
        }
    }

    // is_rh_minor_in_for_y_reduce
    // components consumed by a reduced Y (these do NOT become R components)
    array<array<bool, max_ndim_rh_minor_in>, ndim_rh_major_in> is_rh_minor_in_for_y_reduce{{false}};

    static_for<0, ndim_y_in, 1>{}([&](auto i) {
        index_t rh_major = InDstr::ys_to_rhs_major_[i];
        index_t rh_minor = InDstr::ys_to_rhs_minor_[i];

        if(is_y_in_for_reduce[i])
        {
            is_rh_minor_in_for_y_reduce(rh_major)(rh_minor) = true;
        }
    });

    // in2out_rh_major
    // reduced groups collapse into output group 0 (R); surviving groups are
    // renumbered consecutively
    array<index_t, ndim_rh_major_in> in2out_rh_major{-1};
    index_t cnt_ndim_rh_major_out = 0;

    for(index_t i = 0; i < ndim_rh_major_in; i++)
    {
        if(is_rh_major_in_for_reduce[i])
        {
            in2out_rh_major(i) = 0;
        }
        else
        {
            in2out_rh_major(i) = cnt_ndim_rh_major_out;

            cnt_ndim_rh_major_out++;
        }
    }

    // rs_lengths_out, in2out_rh_minor
    array<index_t, max_ndim_r_out> rs_lengths_out{-1};
    array<array<index_t, max_ndim_rh_minor_in>, ndim_rh_major_in> in2out_rh_minor{{-1}};

    // loop over input R dim
    // existing R components keep their lengths and minor positions
    for(index_t i = 0; i < InDstr::rs_lengths_.size(); i++)
    {
        // rs_lengths_out
        rs_lengths_out(i) = InDstr::rs_lengths_[i];

        // in2out_rh_minor
        in2out_rh_minor(0)(i) = i;
    }

    // loop over input H Dim
    index_t cnt_ndim_r_out = InDstr::rs_lengths_.size();

    static_for<1, ndim_rh_major_in, 1>{}([&](auto rh_major_in) {
        constexpr auto h_major_in = rh_major_in - I1;

        constexpr index_t ndim_rh_minor_in = InDstr::hs_lengthss_[h_major_in].size();

        if(is_rh_major_in_for_reduce[rh_major_in])
        {
            // reduced X dim: its non-Y components are appended as new R
            // components of the output
            for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
            {
                if(not is_rh_minor_in_for_y_reduce[rh_major_in][rh_minor_in])
                {
                    // rs_lengths_out
                    rs_lengths_out(cnt_ndim_r_out) = InDstr::hs_lengthss_[h_major_in][rh_minor_in];

                    // in2out_rh_minor
                    in2out_rh_minor(rh_major_in)(rh_minor_in) = cnt_ndim_r_out;

                    cnt_ndim_r_out++;
                }
            }
        }
        else
        {
            // surviving X dim: minor indices are unchanged
            for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
            {
                // in2out_rh_minor
                in2out_rh_minor(rh_major_in)(rh_minor_in) = rh_minor_in;
            }
        }
    });

    // ndim_r_out
    const index_t ndim_r_out = cnt_ndim_r_out;

    // ndims_hs_minor_out, hs_lengthss_out
    // H factorizations of the surviving X dims, packed consecutively
    array<index_t, ndim_x_out> ndims_hs_minor_out{-1};
    array<array<index_t, max_ndim_rh_minor_in>, ndim_x_out> hs_lengthss_out{{-1}};

    index_t cnt_ndim_x_out = 0;

    static_for<0, ndim_x_in, 1>{}([&](auto i) {
        if(not is_rh_major_in_for_reduce[i + I1])
        {
            // ndims_hs_minor_out
            ndims_hs_minor_out(cnt_ndim_x_out) = InDstr::hs_lengthss_[i].size();

            // hs_lengthss_out
            static_for<0, InDstr::hs_lengthss_[i].size(), 1>{}(
                [&](auto j) { hs_lengthss_out(cnt_ndim_x_out)(j) = InDstr::hs_lengthss_[i][j]; });

            cnt_ndim_x_out++;
        }
    });

    // ps_to_rhss_major_out, ps_to_rhss_minor_out
    // remap every P component through the in->out renumbering
    array<array<index_t, max_ndim_rh_minor_in>, ndim_p> ps_to_rhss_major_out{{-1}};
    array<array<index_t, max_ndim_rh_minor_in>, ndim_p> ps_to_rhss_minor_out{{-1}};

    static_for<0, ndim_p, 1>{}([&](auto idim_p) {
        static_for<0, InDstr::ps_to_rhss_major_[idim_p].size(), 1>{}([&](auto idim_low) {
            index_t rh_major_in = InDstr::ps_to_rhss_major_[idim_p][idim_low];
            index_t rh_minor_in = InDstr::ps_to_rhss_minor_[idim_p][idim_low];

            ps_to_rhss_major_out(idim_p)(idim_low) = in2out_rh_major[rh_major_in];
            ps_to_rhss_minor_out(idim_p)(idim_low) = in2out_rh_minor[rh_major_in][rh_minor_in];
        });
    });

    // ys_to_rhs_major_out, ys_to_rhs_minor_out
    // keep only the Y dims that survive the reduction, remapped
    array<index_t, max_ndim_y_out> ys_to_rhs_major_out{-1};
    array<index_t, max_ndim_y_out> ys_to_rhs_minor_out{-1};

    index_t cnt_ndim_y_out = 0;

    static_for<0, ndim_y_in, 1>{}([&](auto i) {
        if(not is_y_in_for_reduce[i])
        {
            index_t rh_major_in = InDstr::ys_to_rhs_major_[i];
            index_t rh_minor_in = InDstr::ys_to_rhs_minor_[i];

            ys_to_rhs_major_out(cnt_ndim_y_out) = in2out_rh_major[rh_major_in];
            ys_to_rhs_minor_out(cnt_ndim_y_out) = in2out_rh_minor[rh_major_in][rh_minor_in];

            cnt_ndim_y_out++;
        }
    });

    // ndim_y_out
    const index_t ndim_y_out = cnt_ndim_y_out;

    //
    return make_tuple(ndim_x_out,
                      ndim_p,
                      ndim_y_out,
                      ndim_r_out,
                      ndims_hs_minor_out,
                      ndims_ps_low,
                      rs_lengths_out,
                      hs_lengthss_out,
                      ps_to_rhss_major_out,
                      ps_to_rhss_minor_out,
                      ys_to_rhs_major_out,
                      ys_to_rhs_minor_out);
}
|
||||
|
||||
// Reduce the X dims listed in reduce_dim_xs_in out of InDstr's encoding and
// return the resulting tile_distribution_encoding value. The array-level
// computation is done by make_reduce_tile_distribution_encoding_impl; this
// wrapper converts the arrays back into sequence / tuple-of-sequence template
// arguments via the TO_SEQUENCE / TO_TUPLE_OF_SEQUENCE macros.
template <typename InDstr, index_t... InReduceDimXs>
CK_TILE_HOST_DEVICE constexpr auto
make_reduce_tile_distribution_encoding(InDstr, sequence<InReduceDimXs...> reduce_dim_xs_in)
{
    constexpr auto impl = make_reduce_tile_distribution_encoding_impl(InDstr{}, reduce_dim_xs_in);

    // unpack the impl tuple (layout matches impl's return statement)
    constexpr index_t ndim_x             = impl.template at<0>();
    constexpr index_t ndim_p             = impl.template at<1>();
    constexpr index_t ndim_y             = impl.template at<2>();
    constexpr index_t ndim_r             = impl.template at<3>();
    constexpr auto ndims_hs_minor        = impl.template at<4>();
    constexpr auto ndims_ps_low          = impl.template at<5>();
    constexpr auto rs_lengths_impl       = impl.template at<6>();
    constexpr auto hs_lengthss_impl      = impl.template at<7>();
    constexpr auto ps_to_rhss_major_impl = impl.template at<8>();
    constexpr auto ps_to_rhss_minor_impl = impl.template at<9>();
    constexpr auto ys_to_rhs_major_impl  = impl.template at<10>();
    constexpr auto ys_to_rhs_minor_impl  = impl.template at<11>();

    // lift the runtime-style arrays back into compile-time sequences
    constexpr auto rs_lengths  = TO_SEQUENCE(rs_lengths_impl, ndim_r);
    constexpr auto hs_lengthss = TO_TUPLE_OF_SEQUENCE(hs_lengthss_impl, ndim_x, ndims_hs_minor);
    constexpr auto ps_to_rhss_major =
        TO_TUPLE_OF_SEQUENCE(ps_to_rhss_major_impl, ndim_p, ndims_ps_low);
    constexpr auto ps_to_rhss_minor =
        TO_TUPLE_OF_SEQUENCE(ps_to_rhss_minor_impl, ndim_p, ndims_ps_low);
    constexpr auto ys_to_rhs_major = TO_SEQUENCE(ys_to_rhs_major_impl, ndim_y);
    constexpr auto ys_to_rhs_minor = TO_SEQUENCE(ys_to_rhs_minor_impl, ndim_y);

    return tile_distribution_encoding<remove_cvref_t<decltype(rs_lengths)>,
                                      remove_cvref_t<decltype(hs_lengthss)>,
                                      remove_cvref_t<decltype(ps_to_rhss_major)>,
                                      remove_cvref_t<decltype(ps_to_rhss_minor)>,
                                      remove_cvref_t<decltype(ys_to_rhs_major)>,
                                      remove_cvref_t<decltype(ys_to_rhs_minor)>>{};
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace ck_tile
|
||||
191
include/ck_tile/core/tensor/tile_elementwise.hpp
Normal file
191
include/ck_tile/core/tensor/tile_elementwise.hpp
Normal file
@@ -0,0 +1,191 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// TODO: support tensors with different distribution
|
||||
template <typename InOutElementFunc,
|
||||
typename... InOutDstrTensors,
|
||||
typename = std::enable_if_t<std::conjunction_v<
|
||||
std::negation<std::is_same<std::remove_const_t<InOutDstrTensors>, NullTensor>>...>>>
|
||||
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc& inout_element_func,
|
||||
InOutDstrTensors&... inout_dstr_tensors)
|
||||
{
|
||||
// TODO: make sure all distributed tensors have same lengths and distribution
|
||||
// static_assert(xxx);
|
||||
|
||||
constexpr index_t thread_buffer_size =
|
||||
type_pack_element<0, InOutDstrTensors...>::get_thread_buffer_size();
|
||||
|
||||
static_for<0, thread_buffer_size, 1>{}(
|
||||
[&](auto i) { inout_element_func(inout_dstr_tensors.get_thread_buffer().at(i)...); });
|
||||
}
|
||||
|
||||
template <typename InElementFunc,
|
||||
typename... InDstrTensors,
|
||||
typename = std::enable_if_t<
|
||||
std::conjunction_v<std::negation<std::is_same<InDstrTensors, NullTensor>>...>>>
|
||||
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
|
||||
const InDstrTensors&... in_dstr_tensors)
|
||||
{
|
||||
using OutDataType = decltype(in_element_func(typename InDstrTensors::DataType{}...));
|
||||
|
||||
// TODO: make sure all distributed tensors have same lengths and distribution
|
||||
// static_assert(xxx);
|
||||
constexpr auto in_tile_dstr = type_pack_element<0, InDstrTensors...>::get_tile_distribution();
|
||||
|
||||
constexpr index_t thread_buffer_size =
|
||||
type_pack_element<0, InDstrTensors...>::get_thread_buffer_size();
|
||||
|
||||
auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
|
||||
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
|
||||
out_dstr_tensor.get_thread_buffer()(i) =
|
||||
in_element_func(in_dstr_tensors.get_thread_buffer()[i]...);
|
||||
});
|
||||
|
||||
return out_dstr_tensor;
|
||||
}
|
||||
|
||||
template <typename DstrTensors, typename T>
|
||||
CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, const T& value)
|
||||
{
|
||||
tile_elementwise_inout(
|
||||
[&value](auto& x) {
|
||||
x = type_convert<typename DstrTensors::DataType, remove_cvref_t<T>>(value);
|
||||
},
|
||||
dstr_tensor);
|
||||
}
|
||||
|
||||
// No-op overload: setting a NullTensor (placeholder tensor) does nothing.
template <typename T>
CK_TILE_DEVICE void set_tile(NullTensor&, const T&)
{
}
|
||||
|
||||
// TODO: prefer to use per-dword value to set a tensor, in case compiler not doing well with
// sub-dword tensor...
// Compile-time-value overload of set_tile. For v == 0 on a buffer whose byte
// size is a multiple of 4, the thread buffer is reinterpreted as packed 32-bit
// words and zeroed wholesale; otherwise it falls back to per-element
// conversion + assignment.
template <typename DstrTensors, index_t v>
CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number<v>)
{
    constexpr index_t tensor_bytes =
        DstrTensors::get_thread_buffer_size() * sizeof(typename DstrTensors::DataType);

    // fast path: zero the buffer one 32-bit word at a time
    // NOTE(review): the reinterpret_cast assumes the thread buffer is
    // contiguous and suitably aligned for index_t — confirm for all DataTypes.
    if constexpr(v == 0 && tensor_bytes % 4 == 0)
    {
        using dvec_t = static_buffer_c<index_t, tensor_bytes / 4>;
        auto& tensor = reinterpret_cast<dvec_t&>(dstr_tensor.get_thread_buffer());
        for(auto i = 0; i < tensor.size(); i++)
            tensor.get(i) = v;
    }
    else
    {
        // generic path: convert v to the element type and assign element-wise
        tile_elementwise_inout(
            [](auto& x) { x = type_convert<typename DstrTensors::DataType, index_t>(v); },
            dstr_tensor);
    }
}
|
||||
|
||||
// No-op overload: setting a NullTensor with a compile-time value does nothing.
template <index_t v>
CK_TILE_DEVICE void set_tile(NullTensor&, number<v>)
{
}
|
||||
|
||||
// Zero-fill a distributed tensor's thread buffer.
// Note: passes the int literal 0, so this dispatches to the generic
// set_tile(DstrTensors&, const T&) overload, not the number<v> fast path.
template <typename DstrTensors>
CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor)
{
    set_tile(dstr_tensor, 0);
}
|
||||
|
||||
// TODO: this is ugly
// Convert a float tile to 8-bit float using the gfx94x packed-convert
// instructions, four elements per iteration (requires the thread buffer size
// to be a multiple of 4). On other targets, falls back to plain element-wise
// type_convert.
// NOTE(review): only the fp8 builtin (__builtin_amdgcn_cvt_pk_fp8_f32) is
// used here, yet cast_tile also routes bf8_t destinations through this
// function — confirm bf8 output is intended to use the fp8 encoder.
template <typename OutDataType, typename InDstrTensors>
CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InDstrTensors& in_dstr_tensors)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    // This API is designed to use the _pk_ serious of function
    constexpr auto in_tile_dstr = InDstrTensors::get_tile_distribution();

    constexpr index_t thread_buffer_size = InDstrTensors::get_thread_buffer_size();
    static_assert(thread_buffer_size % 4 == 0);
    constexpr index_t thread_buffer_size_pk = thread_buffer_size / 4;

    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wuninitialized"
    // __builtin_amdgcn_cvt_pk_fp8_f32() this builtin require the old value, and
    // will generate a v_mov_b32 vxxx [old] before cvt, which result in unwanted ISA
    // so we prepare an uninitialized variable purposely, and turn off the warning
    int dummy_old;
    static_for<0, thread_buffer_size_pk, 1>{}([&](auto i) {
        // pack source elements 4i+0 / 4i+1 into the low half-words of x
        uint32_t x = __builtin_amdgcn_cvt_pk_fp8_f32(
            in_dstr_tensors.get_thread_buffer()[number<4 * i + 0>{}],
            in_dstr_tensors.get_thread_buffer()[number<4 * i + 1>{}],
            dummy_old,
            false); // false -> WORD0

        // pack source elements 4i+2 / 4i+3 into the low half-words of y
        uint32_t y = __builtin_amdgcn_cvt_pk_fp8_f32(
            in_dstr_tensors.get_thread_buffer()[number<4 * i + 2>{}],
            in_dstr_tensors.get_thread_buffer()[number<4 * i + 3>{}],
            dummy_old,
            false); // false -> WORD0

        // byte-select the four fp8 results into one dword (selector 0x05040100
        // picks bytes 0,1 of x and bytes 0,1 of y) and store as a 4-wide vector
        constexpr int32_t m0 = 0x05040100;
        using vec_t          = typename vector_type<OutDataType, 4>::type;

        vec_t d = bit_cast<vec_t>(__builtin_amdgcn_perm(y, x, m0));
        out_dstr_tensor.get_thread_buffer().template set_as<vec_t>(number<i>{}, d);
    });
#pragma clang diagnostic pop

    return out_dstr_tensor;
#else
    // fallback
    return tile_elementwise_in(type_convert<OutDataType, typename InDstrTensors::DataType>,
                               in_dstr_tensors);
#endif
}
|
||||
|
||||
template <typename DstType, typename SrcDstrTensors>
|
||||
CK_TILE_DEVICE auto cast_tile(const SrcDstrTensors& src_tensor)
|
||||
{
|
||||
if constexpr((ck_tile::is_same_v<DstType, fp8_t> ||
|
||||
ck_tile::is_same_v<DstType, bf8_t>)&&ck_tile::
|
||||
is_same_v<typename SrcDstrTensors::DataType, float> &&
|
||||
(SrcDstrTensors::get_thread_buffer_size() % 4 == 0))
|
||||
{
|
||||
return cast_tile_pk_fp8x4<DstType, SrcDstrTensors>(src_tensor);
|
||||
}
|
||||
else
|
||||
return tile_elementwise_in(type_convert<DstType, typename SrcDstrTensors::DataType>,
|
||||
src_tensor);
|
||||
}
|
||||
|
||||
// no-op function for NullTensor arguments
// Selected via SFINAE whenever at least one argument is a NullTensor, so
// generic code can apply an in/out element-wise functor to placeholder
// tensors without special-casing at the call site.
template <typename InOutElementFunc,
          typename... MaybeNullTensor,
          typename = std::enable_if_t<
              std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, NullTensor>...>>>
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc&, MaybeNullTensor&&...)
{
}
|
||||
|
||||
// no-op function for NullTensor arguments
// Selected via SFINAE whenever at least one argument is a NullTensor;
// returns a NullTensor so the null-ness propagates through the pipeline.
template <typename InElementFunc,
          typename... MaybeNullTensor,
          typename = std::enable_if_t<
              std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, NullTensor>...>>>
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc&, MaybeNullTensor&&...)
{
    return NullTensor{};
}
|
||||
|
||||
} // namespace ck_tile
|
||||
735
include/ck_tile/core/tensor/tile_window.hpp
Normal file
735
include/ck_tile/core/tensor/tile_window.hpp
Normal file
@@ -0,0 +1,735 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A movable window over a bottom tensor view, bound to a static tile
// distribution that maps per-thread indices [p0, p1, ..., y0, y1, ...] onto
// tensor coordinates [x0, x1, ...]. Provides vectorized load/store of a
// static_distributed_tensor through the window. NumCoord coordinate bundles
// are pre-computed at construction to speed up repeated load/store calls
// (at the cost of extra registers).
template <typename BottomTensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          index_t NumCoord>
struct tile_window_with_static_distribution
{
    using BottomTensorView = remove_reference_t<BottomTensorView_>;
    using WindowLengths    = remove_cvref_t<WindowLengths_>;
    using TileDstr         = remove_cvref_t<StaticTileDistribution_>;

    // adaptor: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
    using WindowAdaptor    = typename TileDstr::PsYs2XsAdaptor;
    using BottomTensorDesc = typename BottomTensorView::TensorDesc;

    using DataType = remove_cvref_t<typename BottomTensorView::DataType>;

    static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
    static constexpr index_t NDimBottomTensor     = BottomTensorDesc::get_num_of_dimension();

    static constexpr index_t NDimP = TileDstr::GetNumOfDimensionP();
    static constexpr index_t NDimY = TileDstr::GetNumOfDimensionY();

    static constexpr auto I0 = number<0>{};
    static constexpr auto I1 = number<1>{};

    // TODO: check WindowLengths and StaticTileDistribution are consistent

    static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
                  "wrong! lengths should be static");
    static_assert(TileDstr::is_static(), "wrong!");

    static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
                  "wrong! inconsistent # of diemsnions");

    using AdaptorTopIndex   = array<index_t, NDimWindowAdaptorTop>;
    using BottomTensorIndex = array<index_t, NDimBottomTensor>;

    using WindowAdaptorCoord =
        decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{}));

    using BottomTensorCoord =
        decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{}));

    // Compile-time traits deciding how the per-thread Y-space is traversed:
    // which Y dimension is vectorized, how many scalars per vector access,
    // and the space-filling curve ordering the accesses.
    struct load_store_traits
    {
        private:
        // Pick the Y dimension with stride 1 and the largest safe vector
        // length as the vectorized dimension; defaults to (dim 0, 1 scalar).
        static constexpr auto get_vector_dim_y_scalar_per_vector()
        {
            const auto [ys_vector_lengths, ys_vector_strides] =
                tile_window_with_static_distribution::
                    get_window_adaptor_ys_safe_vector_length_strides();

            index_t VectorDimY_      = 0;
            index_t ScalarPerVector_ = 1;

            for(index_t i = 0; i < NDimY; ++i)
            {
                if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
                {
                    ScalarPerVector_ = ys_vector_lengths[i];
                    VectorDimY_      = i;
                }
            }

            return make_tuple(VectorDimY_, ScalarPerVector_);
        }

        public:
        static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
        static constexpr index_t ScalarPerVector =
            get_vector_dim_y_scalar_per_vector().template at<1>();

        using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
        using vector_t      = typename vector_type_t::type;

        private:
        // Per-Y-dimension step of each access: ScalarPerVector along the
        // vectorized dimension, 1 along the others.
        static constexpr auto scalars_per_access_ = [] {
            constexpr auto scalars_per_access_arr = generate_array(
                [&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});

            /// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
            constexpr auto NDimY_ = NDimY;

            return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
        }();

        // Space-filling curve enumerating all vector accesses over the
        // per-thread Y-space of the tile distribution.
        static constexpr auto get_space_filling_curve()
        {
            constexpr auto tile_dstr = TileDstr{};

            constexpr auto thread_tensor_lengths_ys =
                to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths());

            // FIXME: need logic to judge dim access order
            using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;

            return space_filling_curve<decltype(thread_tensor_lengths_ys),
                                       DimAccessOrder,
                                       decltype(scalars_per_access_)>{};
        }

        public:
        using SFC_Ys = decltype(get_space_filling_curve());

        static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();

        static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
        static_assert(NumAccess % NumCoord == 0, "wrong! # of access is not divisible by NumCoord");
    };

    // Accesses are split evenly across the NumCoord pre-computed coordinates.
    static constexpr index_t NumAccessPerCoord = load_store_traits::NumAccess / NumCoord;

    CK_TILE_DEVICE constexpr tile_window_with_static_distribution() = default;

    // Build the window and pre-compute per-thread coordinate bundles.
    // window_origin is the window's top-left index in the bottom tensor.
    CK_TILE_DEVICE constexpr tile_window_with_static_distribution(
        const BottomTensorView& bottom_tensor_view,
        const WindowLengths& window_lengths,
        const BottomTensorIndex& window_origin,
        const TileDstr& tile_distribution)
        : bottom_tensor_view_{bottom_tensor_view},
          window_lengths_{window_lengths},
          window_origin_{window_origin},
          tile_dstr_{tile_distribution},
          pre_computed_coords_{}
    {
#if 0 // debug
        // TODO: this use more register for FA, but less register for GEMM
        // need investigation
        // only support warp-tile and block-tile
        static_assert(NDimP == 1 or NDimP == 2, "wrong!");

        WindowAdaptorCoord window_adaptor_thread_coord_tmp;

        if constexpr(NDimP == 1)
        {
            window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
                tile_distribution.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
        }
        else if constexpr(NDimP == 2)
        {
            window_adaptor_thread_coord_tmp =
                make_tensor_adaptor_coordinate(tile_distribution.get_ps_ys_to_xs_adaptor(),
                                               AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
        }
#else
        // TODO: this use less register for FA, but more register for GEMM
        // need investigation
        const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
            tile_distribution.get_ps_ys_to_xs_adaptor(),
            container_concat(detail::get_partition_index(tile_distribution),
                             array<index_t, NDimY>{0}));
#endif

        // thread's starting index in the bottom tensor = window origin +
        // this thread's partition offset within the window
        BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
            window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();

        const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
            bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);

        // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
        // future load/store() calls (might allocate more registers)
        using Traits = load_store_traits;
        using SFC_Ys = typename Traits::SFC_Ys;

        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
            auto bottom_tensor_thread_coord  = bottom_tensor_thread_coord_tmp;

            // advance this bundle to the start of its access slice
            constexpr auto idx_diff_ys =
                SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});

            constexpr auto idx_diff_ps_ys = container_concat(array<index_t, NDimP>{0}, idx_diff_ys);

            move_window_adaptor_and_bottom_tensor_thread_coordinate(
                window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);

            pre_computed_coords_(iCoord) =
                make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
        });
    }

    CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }

    CK_TILE_DEVICE static constexpr bool has_static_tile_distribution()
    {
        return TileDstr::is_static();
    }

    CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }

    CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }

    CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }

    CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }

    // move thread's window adaptor coordinate and bottom tensor coordinate
    // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
    CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
        WindowAdaptorCoord& window_adaptor_thread_coord,
        BottomTensorCoord& bottom_tensor_thread_coord,
        const AdaptorTopIndex& idx_diff_adaptor_top) const
    {
        array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;

        move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
                                       window_adaptor_thread_coord,
                                       idx_diff_adaptor_top,
                                       idx_diff_adaptor_bottom);

        move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
                               bottom_tensor_thread_coord,
                               idx_diff_adaptor_bottom);
    }

    // return vector dimension among [y0, y1, ...]
    // i.e. propagate the bottom tensor's safe vector lengths/strides up
    // through the window adaptor and extract the Y-dimension subset.
    CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides()
    {
        // bottom tensor top dimension vector lengths and strides
        const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
            BottomTensorDesc::get_top_dimension_safe_vector_length_strides();

        // window vector lengths/strides
        const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
        const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;

        // window adaptor [p0, p1, ..., y0, y1, ...]
        array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
            -1};
        array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
            -1};

        constexpr auto window_adaptor_bottom_dims =
            WindowAdaptor::get_bottom_dimension_hidden_ids();

        set_container_subset(window_adaptor_vector_lengths,
                             window_adaptor_bottom_dims,
                             window_adaptor_bottom_dim_vector_lengths);
        set_container_subset(window_adaptor_vector_strides,
                             window_adaptor_bottom_dims,
                             window_adaptor_bottom_dim_vector_strides);

        const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
            WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
                window_adaptor_vector_lengths, window_adaptor_vector_strides);

        // [y0, y1, ...]
        constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::GetNumOfDimensionP(),
                                                                 NDimWindowAdaptorTop,
                                                                 1>::type{};

        return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
                          get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
    }

    CK_TILE_DEVICE constexpr auto get_num_access() const { return load_store_traits::NumAccess; }

    // Vectorized load of the whole window into a new static_distributed_tensor.
    // oob_conditional_check enables out-of-bounds guarding on each access.
    template <bool oob_conditional_check = true>
    CK_TILE_DEVICE auto load(bool_constant<oob_conditional_check> = {}) const
    {
        using Traits = load_store_traits;

        using vector_type_t = typename Traits::vector_type_t;
        using vector_t      = typename vector_type_t::type;
        using SFC_Ys        = typename Traits::SFC_Ys;

        constexpr auto tile_dstr = TileDstr{};

        auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);

        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            /// TODO: use structure binding (to be captured later) if compiled in C++20
            auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
            auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

                // data index [y0, y1, ...]
                constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);

                // read from bottom tensor
                const vector_t vec_value =
                    get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
                        bottom_tensor_thread_coord, bool_constant<oob_conditional_check>{});

                const vector_type_t vec{vec_value};

                // write into distributed tensor, one scalar at a time along
                // the vectorized Y dimension
                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
                    constexpr auto idx_ys = generate_array(
                        [&](auto jj) {
                            return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                           : idx_ys_start[jj];
                        },
                        number<NDimY>{});

                    constexpr index_t d =
                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);

                    dst_tensor.get_thread_buffer().template at<d>() =
                        vec.template AsType<DataType>()[j];
                });

                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
                {
                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                    constexpr auto idx_diff_ps_ys =
                        container_concat(array<index_t, NDimP>{0}, idx_diff_ys);

                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
                }
            });
        });

        return dst_tensor;
    }

    // Vectorized load directly into an existing tile, writing whole vectors
    // into the destination thread buffer (reinterpreted as vectors) via the
    // "raw" view API instead of per-scalar copies.
    template <typename DstTile, bool oob_conditional_check = true>
    CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
                                 bool_constant<oob_conditional_check> = {}) const
    {
        using Traits = load_store_traits;

        using vector_type_t = typename Traits::vector_type_t;
        using vector_t      = typename vector_type_t::type;
        using SFC_Ys        = typename Traits::SFC_Ys;
        static constexpr index_t YElementSize =
            TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
        static_assert(YElementSize % Traits::ScalarPerVector == 0);
        using vectorized_tbuf = StaticBuffer<address_space_enum::vgpr,
                                             vector_t,
                                             YElementSize / Traits::ScalarPerVector,
                                             true>;

        constexpr auto tile_dstr = TileDstr{};

        auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());

        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            /// TODO: use structure binding (to be captured later) if compiled in C++20
            auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
            auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

                // data index [y0, y1, ...]
                constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
                constexpr index_t d =
                    tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
                static_assert(d % Traits::ScalarPerVector == 0);

                get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
                    dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
                    bottom_tensor_thread_coord,
                    bool_constant<oob_conditional_check>{});

                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
                {
                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                    constexpr auto idx_diff_ps_ys =
                        container_concat(array<index_t, NDimP>{0}, idx_diff_ys);

                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
                }
            });
        });
    }

    // TODO: currently async load only implemented in inline asm
    // Asynchronously load the window into an LDS tile window, driving the
    // m0 register to address per-wave/per-issue LDS destinations.
    template <typename LdsTileWindow_, bool oob_conditional_check = true>
    CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
                                   bool_constant<oob_conditional_check> = {}) const
    {
        using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
        // using LdsTensorView = typename LdsTileWindow::BottomTensorView;
        using LdsDataType = typename LdsTileWindow::DataType;
        // using LdsDescriptor = typename LdsTileWindow::BottomTensorDesc;

        // issues * warps * lanes
        static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded

        // byte offsets derived from the LDS descriptor: base, per-wave
        // stride, and per-issue stride
        const index_t size_per_buf =
            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
                make_tuple(number<0>{}, number<0>{}, number<0>{})) *
            sizeof(LdsDataType);

        const index_t size_per_wave =
            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
                make_tuple(number<0>{}, number<1>{}, number<0>{})) *
                sizeof(LdsDataType) -
            size_per_buf;

        const index_t size_per_issue =
            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
                make_tuple(number<1>{}, number<0>{}, number<0>{})) *
                sizeof(LdsDataType) -
            size_per_buf;

        const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
        m0_set_with_memory(m0_init_value); // This should be wave independent

        using Traits = load_store_traits;

        using vector_type_t = typename Traits::vector_type_t;
        using vector_t      = typename vector_type_t::type;
        using SFC_Ys        = typename Traits::SFC_Ys;

        LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;

        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            // TODO: use structure binding (to be captured later) if compiled in C++20
            auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
            auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

                // read from bottom tensor
                get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
                    smem, bottom_tensor_thread_coord);

                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
                {
                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                    constexpr auto idx_diff_ps_ys =
                        container_concat(array<index_t, NDimP>{0}, idx_diff_ys);

                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);

                    m0_inc_with_memory(size_per_issue);
                }
            });
        });
    }

    // Vectorized store of a distributed tensor through the window into the
    // bottom tensor.
    template <bool oob_conditional_check = true>
    CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
                              bool_constant<oob_conditional_check> = {}) const
    {
        using Traits = load_store_traits;

        using vector_type_t = typename Traits::vector_type_t;
        using vector_t      = typename vector_type_t::type;
        using SFC_Ys        = typename Traits::SFC_Ys;

        constexpr auto tile_dstr = TileDstr{};

        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            /// TODO: use structure binding (to be captured later) if compiled in C++20
            auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
            auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

                // data index [y0, y1, ...]
                constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);

                // read from distributed tensor
                vector_type_t vec;

                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
                    constexpr auto idx_ys = generate_array(
                        [&](auto jj) {
                            return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                           : idx_ys_start[jj];
                        },
                        number<NDimY>{});

                    constexpr index_t d =
                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);

                    vec.template AsType<DataType>()(j) =
                        dstr_tensor.get_thread_buffer().template at<d>();
                });

                const vector_t vec_value = vec.template AsType<vector_t>().template at<0>();

                // write into bottom tensor
                get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
                    bottom_tensor_thread_coord, vec_value, bool_constant<oob_conditional_check>{});

                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
                {
                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                    constexpr auto idx_diff_ps_ys =
                        container_concat(array<index_t, NDimP>{0}, idx_diff_ys);

                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
                }
            });
        });
    }

    // Same as store(), but writes through the "raw" view API (always with
    // out-of-bounds checking enabled).
    CK_TILE_DEVICE void
    store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor) const
    {
        using Traits = load_store_traits;

        using vector_type_t = typename Traits::vector_type_t;
        using vector_t      = typename vector_type_t::type;
        using SFC_Ys        = typename Traits::SFC_Ys;

        constexpr auto tile_dstr                  = TileDstr{};
        static constexpr bool oob_conditional_check = true;

        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            /// TODO: use structure binding (to be captured later) if compiled in C++20
            auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
            auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

                // data index [y0, y1, ...]
                constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);

                // read from distributed tensor
                vector_type_t vec;

                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
                    constexpr auto idx_ys = generate_array(
                        [&](auto jj) {
                            return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                           : idx_ys_start[jj];
                        },
                        number<NDimY>{});

                    constexpr index_t d =
                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);

                    vec.template AsType<DataType>()(j) =
                        dstr_tensor.get_thread_buffer().template at<d>();
                });

                const vector_t vec_value = vec.template AsType<vector_t>().template at<0>();

                // write into bottom tensor
                get_bottom_tensor_view()
                    .template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
                        bottom_tensor_thread_coord, vec_value);

                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
                {
                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                    constexpr auto idx_diff_ps_ys =
                        container_concat(array<index_t, NDimP>{0}, idx_diff_ys);

                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
                }
            });
        });
    }

    // move thread's botom tensor coordiante
    // [x0', x1', ... ] ==> [offset]
    // also move window-origin
    CK_TILE_DEVICE void move(const BottomTensorIndex& step)
    {
        window_origin_ += step;

        // shift every pre-computed bottom-tensor coordinate by the same step;
        // the window-adaptor coordinates are unaffected by a window move
        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
                                   pre_computed_coords_(iCoord)(I1),
                                   step);
        });
    }

    // this is the bottom tensor view
    // [x0', x1', ...] ==> [offset]
    BottomTensorView bottom_tensor_view_;

    //
    WindowLengths window_lengths_;

    // origin ([x0', x1', ...]) of window on bottom tensor
    BottomTensorIndex window_origin_;

    // Tile tensor distribution, which contains:
    //   1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
    //   2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
    TileDstr tile_dstr_;

    // this contains:
    //   per-thread coordinate for window adaptor
    //   per-thread coordinate for bottom tensor
    array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord> pre_computed_coords_;
};
|
||||
|
||||
// TODO: use strategy
|
||||
template <typename TensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
index_t NumCoord = 1>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
make_tile_window(const TensorView_& tensor_view,
|
||||
const WindowLengths_& window_lengths,
|
||||
const multi_index<TensorView_::get_num_of_dimension()>& origin,
|
||||
const StaticTileDistribution_& tile_distribution,
|
||||
number<NumCoord> = {})
|
||||
{
|
||||
return tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
|
||||
remove_cvref_t<WindowLengths_>,
|
||||
remove_cvref_t<StaticTileDistribution_>,
|
||||
NumCoord>{
|
||||
tensor_view, window_lengths, origin, tile_distribution};
|
||||
}
|
||||
|
||||
// Free-function wrapper: shift a distributed tile window (and its
// pre-computed coordinates) by `step` in bottom-tensor index space.
template <typename TensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          index_t NumCoord>
CK_TILE_DEVICE void move_tile_window(
    tile_window_with_static_distribution<TensorView_,
                                         WindowLengths_,
                                         StaticTileDistribution_,
                                         NumCoord>& window,
    const typename tile_window_with_static_distribution<TensorView_,
                                                        WindowLengths_,
                                                        StaticTileDistribution_,
                                                        NumCoord>::BottomTensorIndex& step)
{
    window.move(step);
}
|
||||
|
||||
// A lightweight tile window: a bottom tensor view plus compile-time window
// lengths and a movable origin. Unlike tile_window_with_static_distribution
// it carries no tile distribution and no pre-computed thread coordinates,
// so it only describes *where* the window is, not how threads access it.
template <typename BottomTensorView_, typename WindowLengths_>
struct tile_window_with_static_lengths
{
    using BottomTensorView = remove_reference_t<BottomTensorView_>;
    using WindowLengths    = remove_cvref_t<WindowLengths_>;
    using BottomTensorDesc = typename BottomTensorView::TensorDesc;
    using DataType         = typename BottomTensorView::DataType;

    static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();

    static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
                  "wrong! lengths should be static");

    using BottomTensorIndex = array<index_t, NDimBottomTensor>;

    CK_TILE_DEVICE constexpr tile_window_with_static_lengths() = default;

    CK_TILE_DEVICE constexpr tile_window_with_static_lengths(
        const BottomTensorView& bottom_tensor_view,
        const WindowLengths& window_lengths,
        const BottomTensorIndex& window_origin)
        : bottom_tensor_view_{bottom_tensor_view},
          window_lengths_{window_lengths},
          window_origin_{window_origin}
    {
    }

    CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }

    CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }

    CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }

    CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }

    // move window-origin
    CK_TILE_DEVICE void move(const BottomTensorIndex& step) { window_origin_ += step; }

    // this is the bottom tensor view
    // [x0', x1', ...] ==> [offset]
    BottomTensorView bottom_tensor_view_;

    //
    WindowLengths window_lengths_;

    // origin ([x0', x1', ...]) of window on bottom tensor
    BottomTensorIndex window_origin_;
};
|
||||
|
||||
template <typename TensorView_, typename WindowLengths_>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
make_tile_window(const TensorView_& tensor_view,
|
||||
const WindowLengths_& window_lengths,
|
||||
const multi_index<TensorView_::get_num_of_dimension()>& origin)
|
||||
{
|
||||
static_assert(ck_tile::is_known_at_compile_time<WindowLengths_>::value,
|
||||
"wrong! lengths should be static");
|
||||
|
||||
return tile_window_with_static_lengths<remove_cvref_t<TensorView_>,
|
||||
remove_cvref_t<WindowLengths_>>{
|
||||
tensor_view, window_lengths, origin};
|
||||
}
|
||||
|
||||
template <typename TensorView_, typename WindowLengths_>
|
||||
CK_TILE_DEVICE void move_tile_window(
|
||||
tile_window_with_static_lengths<TensorView_, WindowLengths_>& window,
|
||||
const typename tile_window_with_static_lengths<TensorView_, WindowLengths_>::BottomTensorIndex&
|
||||
step)
|
||||
{
|
||||
window.move(step);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
19
include/ck_tile/core/utility/bit_cast.hpp
Normal file
19
include/ck_tile/core/utility/bit_cast.hpp
Normal file
@@ -0,0 +1,19 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X& x)
|
||||
{
|
||||
static_assert(__has_builtin(__builtin_bit_cast), "");
|
||||
static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type");
|
||||
|
||||
return __builtin_bit_cast(Y, x);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
194
include/ck_tile/core/utility/functional.hpp
Normal file
194
include/ck_tile/core/utility/functional.hpp
Normal file
@@ -0,0 +1,194 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include <stdint.h>
|
||||
#include <utility>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace detail {
|
||||
|
||||
struct swallow
|
||||
{
|
||||
template <typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr swallow(Ts&&...)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
template <class>
|
||||
struct static_for_impl;
|
||||
|
||||
template <index_t... Is>
|
||||
struct static_for_impl<sequence<Is...>>
|
||||
{
|
||||
template <class F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
|
||||
{
|
||||
swallow{(f(number<Is>{}), 0)...};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// F signature: F(number<Iter>)
|
||||
template <index_t NBegin, index_t NEnd, index_t Increment>
|
||||
struct static_for
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr static_for()
|
||||
{
|
||||
static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0,
|
||||
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
|
||||
static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd),
|
||||
"wrongs! should (Increment > 0 && NBegin <= NEnd) || (Increment < 0 && "
|
||||
"NBegin >= NEnd)");
|
||||
}
|
||||
|
||||
template <class F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
|
||||
{
|
||||
detail::static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::type>{}(
|
||||
f);
|
||||
}
|
||||
};
|
||||
|
||||
struct identity
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr T&& operator()(T&& arg) const noexcept
|
||||
{
|
||||
return std::forward<T>(arg);
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
// RemainLengths: sequence<...>
|
||||
// Orders: sequence<...>
|
||||
template <class RemainLengths, class Orders>
|
||||
struct static_ford_impl
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr static_ford_impl()
|
||||
{
|
||||
static_assert(RemainLengths::size() > 0, "wrong! should not get here");
|
||||
}
|
||||
|
||||
// F signature: F(sequence<...>)
|
||||
// CurrentOrderedId: sequence<...>
|
||||
template <class F, class CurrentOrderedId>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentOrderedId) const
|
||||
{
|
||||
static_for<0, RemainLengths::front(), 1>{}([=](auto I) {
|
||||
static_ford_impl<decltype(RemainLengths::pop_front()), Orders>{}(
|
||||
f, CurrentOrderedId::push_back(I));
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <class Orders>
|
||||
struct static_ford_impl<sequence<>, Orders>
|
||||
{
|
||||
// F signature: F(sequence<...>)
|
||||
// OrderedId: sequence<...>
|
||||
template <class F, class OrderedId>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, OrderedId) const
|
||||
{
|
||||
// retrive unordered Id
|
||||
f(OrderedId::reorder_old_to_new(Orders{}));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Lengths is sequence<...>, it is the length of each dimension for
|
||||
// N-dimensional loop
|
||||
// Orders is sequence<...>, it is the order of dimension in which static_ford
|
||||
// will loop over each
|
||||
// dimension
|
||||
template <class Lengths,
|
||||
class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
|
||||
struct static_ford
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr static_ford()
|
||||
{
|
||||
static_assert(Lengths::size() > 0, "wrong! Lengths is empty");
|
||||
static_assert(Lengths::size() == Orders::size(), "wrong! inconsistent size");
|
||||
}
|
||||
|
||||
// F signature: F(sequence<...> multi_id)
|
||||
// multi_id is the unordered multi-index
|
||||
template <class F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
|
||||
{
|
||||
constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
|
||||
detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, sequence<>{});
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename Indices>
|
||||
struct unpack_impl;
|
||||
|
||||
template <index_t... Is>
|
||||
struct unpack_impl<sequence<Is...>>
|
||||
{
|
||||
template <typename F, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x) const
|
||||
{
|
||||
#if 0
|
||||
return std::forward<F>(f)(std::forward<X>(x).at(number<Is>{})...);
|
||||
#else
|
||||
return std::forward<F>(f)(std::forward<X>(x).template at<Is>()...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Seq0, typename Seq1>
|
||||
struct unpack2_impl;
|
||||
|
||||
// TODO: remove this, after properly implementing unpack that takes any number of containers
|
||||
template <index_t... Is, index_t... Js>
|
||||
struct unpack2_impl<sequence<Is...>, sequence<Js...>>
|
||||
{
|
||||
template <typename F, typename X, typename Y>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x, Y&& y) const
|
||||
{
|
||||
#if 0
|
||||
return std::forward<F>(f)(std::forward<X>(x).at(number<Is>{})...,
|
||||
std::forward<Y>(y).at(number<Js>{})...);
|
||||
#else
|
||||
return std::forward<F>(f)(std::forward<X>(x).template at<Is>()...,
|
||||
std::forward<Y>(y).template at<Js>()...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename F, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto unpack(F&& f, X&& x)
|
||||
{
|
||||
using X_ = remove_reference_t<X>;
|
||||
return detail::unpack_impl<typename arithmetic_sequence_gen<0, X_::size(), 1>::type>{}(
|
||||
std::forward<F>(f), std::forward<X>(x));
|
||||
}
|
||||
|
||||
// TODO: properly implement unpack that takes any number of containers
|
||||
template <typename F, typename X, typename Y>
|
||||
CK_TILE_HOST_DEVICE constexpr auto unpack2(F&& f, X&& x, Y&& y)
|
||||
{
|
||||
using X_ = remove_reference_t<X>;
|
||||
using Y_ = remove_reference_t<Y>;
|
||||
return detail::unpack2_impl<typename arithmetic_sequence_gen<0, X_::size(), 1>::type,
|
||||
typename arithmetic_sequence_gen<0, Y_::size(), 1>::type>{}(
|
||||
std::forward<F>(f), std::forward<X>(x), std::forward<Y>(y));
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
75
include/ck_tile/core/utility/limits.hpp
Normal file
75
include/ck_tile/core/utility/limits.hpp
Normal file
@@ -0,0 +1,75 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include <limits>
|
||||
#include <stdint.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename T>
|
||||
struct numeric_limits
|
||||
{
|
||||
// minimum finite value, or minimum positive normalized value for float
|
||||
CK_TILE_HOST_DEVICE static constexpr T min() { return std::numeric_limits<T>::min(); }
|
||||
|
||||
// minumum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr T lowest() { return std::numeric_limits<T>::lowest(); }
|
||||
|
||||
// maximum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr T max() { return std::numeric_limits<T>::max(); }
|
||||
|
||||
// difference between 1.0 and next value representable by float
|
||||
CK_TILE_HOST_DEVICE static constexpr T epsilon() { return std::numeric_limits<T>::epsilon(); }
|
||||
|
||||
// maximum rounding error
|
||||
CK_TILE_HOST_DEVICE static constexpr T round_error()
|
||||
{
|
||||
return std::numeric_limits<T>::round_error();
|
||||
}
|
||||
|
||||
// positive infinity value
|
||||
CK_TILE_HOST_DEVICE static constexpr T infinity() { return std::numeric_limits<T>::infinity(); }
|
||||
|
||||
// quiet NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr T quiet_NaN()
|
||||
{
|
||||
return std::numeric_limits<T>::quiet_NaN();
|
||||
}
|
||||
|
||||
// signaling NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr T signaling_NaN()
|
||||
{
|
||||
return std::numeric_limits<T>::signaling_NaN();
|
||||
}
|
||||
|
||||
// smallest positive subnormal value
|
||||
CK_TILE_HOST_DEVICE static constexpr T denorm_min()
|
||||
{
|
||||
return std::numeric_limits<T>::denorm_min();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct numeric_utils;
|
||||
|
||||
template <>
|
||||
struct numeric_utils<float>
|
||||
{
|
||||
static constexpr int exp = 8;
|
||||
static constexpr int mant = 23;
|
||||
static constexpr int bias = 127;
|
||||
static constexpr uint32_t nan_mask = 0x7F800000;
|
||||
static constexpr uint32_t head_mask = 0xFF800000;
|
||||
static constexpr uint32_t mant_mask = 0x7FFFFF;
|
||||
static constexpr uint32_t exp_mask = 0xFF;
|
||||
static constexpr uint32_t Inf = 0x7F800000;
|
||||
static constexpr uint32_t NegInf = 0xFF800000;
|
||||
static constexpr uint32_t NaN = 0x7F800001;
|
||||
static constexpr uint32_t Neg0 = 0x80000000;
|
||||
using bitwise_type = uint32_t;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
261
include/ck_tile/core/utility/magic_div.hpp
Normal file
261
include/ck_tile/core/utility/magic_div.hpp
Normal file
@@ -0,0 +1,261 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include <stdint.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// magic number division
|
||||
// Caution:
|
||||
// 1. For uint32_t as dividend: magic number division implementation being used would produce
|
||||
// correct result if the dividend is uint32_t and its value is within 31-bit value range.
|
||||
// 2. For int32_t as dividendd: magic number division for int32_t dividened has not been
|
||||
// implemented, the int32_t dividend would be bit-wise interpreted as uint32_t and magic number
|
||||
// division implementation for uint32_t is then used. Therefore, dividend value need to be
|
||||
// non-negative.
|
||||
// TODO:
|
||||
// 1. Implement magic number divison for int32_t
|
||||
// 2. Implement magic number divison for unit32_t with 32-bit value range
|
||||
struct magic_division32_bit_range
|
||||
{
|
||||
// uint32_t
|
||||
CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(uint32_t divisor)
|
||||
{
|
||||
// WARNING: magic division is only valid for division inside this range.
|
||||
// assert(divisor >= 1 && divisor <= INT32_MAX)
|
||||
|
||||
uint32_t shift_u32 = 0;
|
||||
|
||||
while((1U << shift_u32) < divisor)
|
||||
{
|
||||
shift_u32++;
|
||||
};
|
||||
|
||||
uint64_t tmp_u64 = ((1UL << shift_u32) - divisor) << 32;
|
||||
uint32_t multiplier_u32 = tmp_u64 / divisor + 1;
|
||||
|
||||
return make_tuple(multiplier_u32, shift_u32);
|
||||
}
|
||||
|
||||
// integral_constant<uint32_t, .>
|
||||
template <uint32_t Divisor, typename = std::enable_if_t<(0 < Divisor)>>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
calculate_magic_numbers(integral_constant<uint32_t, Divisor>)
|
||||
{
|
||||
constexpr auto tmp = calculate_magic_numbers(uint32_t{Divisor});
|
||||
|
||||
constexpr uint32_t multiplier = tmp[number<0>{}];
|
||||
constexpr uint32_t shift = tmp[number<1>{}];
|
||||
|
||||
return make_tuple(integral_constant<uint32_t, multiplier>{},
|
||||
integral_constant<uint32_t, shift>{});
|
||||
}
|
||||
|
||||
// integral_constant<int32_t, .>
|
||||
template <int32_t Divisor>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
calculate_magic_numbers(integral_constant<int32_t, Divisor>)
|
||||
{
|
||||
return calculate_magic_numbers(integral_constant<uint32_t, Divisor>{});
|
||||
}
|
||||
|
||||
// magic division for uint32_t
|
||||
CK_TILE_DEVICE static constexpr uint32_t
|
||||
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t tmp = __umulhi(dividend, multiplier);
|
||||
return (tmp + dividend) >> shift;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr uint32_t
|
||||
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t tmp = (static_cast<uint64_t>(dividend) * multiplier) >> 32;
|
||||
return (tmp + dividend) >> shift;
|
||||
}
|
||||
|
||||
// magic division for int32_t
|
||||
// HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be
|
||||
// non-negative for result to be correct
|
||||
// TODO: figure out how to do magic number divison for int32_t as dividended
|
||||
CK_TILE_DEVICE static constexpr int32_t
|
||||
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
|
||||
uint32_t tmp = __umulhi(dividend_u32, multiplier);
|
||||
return (tmp + dividend_u32) >> shift;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr int32_t
|
||||
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
|
||||
uint32_t tmp = (static_cast<uint64_t>(dividend_u32) * multiplier) >> 32;
|
||||
return (tmp + dividend_u32) >> shift;
|
||||
}
|
||||
};
|
||||
|
||||
// magic number division
|
||||
// This version on works for divisor and dividended between [0, 1 << 16]
|
||||
struct magic_division16_bit_range
|
||||
{
|
||||
// uint32_t
|
||||
CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(uint32_t divisor)
|
||||
{
|
||||
// WARNING: magic division is only valid for division inside this range.
|
||||
// assert(divisor >= 1 && divisor <= (1U << 16));
|
||||
|
||||
uint32_t shift_u32 = 0;
|
||||
|
||||
while((1U << shift_u32) < divisor)
|
||||
{
|
||||
shift_u32++;
|
||||
};
|
||||
|
||||
uint32_t one = 1;
|
||||
uint32_t multiplier_u32 = ((one << 16) * ((one << shift_u32) - divisor)) / divisor + 1;
|
||||
|
||||
return make_tuple(multiplier_u32, shift_u32);
|
||||
}
|
||||
|
||||
// integral_constant<uint32_t, .>
|
||||
template <uint32_t Divisor>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
calculate_magic_numbers(integral_constant<uint32_t, Divisor>)
|
||||
{
|
||||
constexpr auto tmp = calculate_magic_numbers(uint32_t{Divisor});
|
||||
|
||||
constexpr uint32_t multiplier = tmp[number<0>{}];
|
||||
constexpr uint32_t shift = tmp[number<1>{}];
|
||||
|
||||
return make_tuple(integral_constant<uint32_t, multiplier>{},
|
||||
integral_constant<uint32_t, shift>{});
|
||||
}
|
||||
|
||||
// integral_constant<int32_t, .>
|
||||
template <int32_t Divisor>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
calculate_magic_numbers(integral_constant<int32_t, Divisor>)
|
||||
{
|
||||
return calculate_magic_numbers(integral_constant<uint32_t, Divisor>{});
|
||||
}
|
||||
|
||||
// magic division for uint32_t
|
||||
CK_TILE_DEVICE static constexpr uint32_t
|
||||
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t tmp = (dividend * multiplier) >> 16;
|
||||
return (tmp + dividend) >> shift;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr uint32_t
|
||||
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t tmp = (dividend * multiplier) >> 16;
|
||||
return (tmp + dividend) >> shift;
|
||||
}
|
||||
|
||||
// magic division for int32_t
|
||||
// HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be
|
||||
// non-negative for result to be correct
|
||||
// TODO: figure out how to do magic number divison for int32_t as dividended
|
||||
CK_TILE_DEVICE static constexpr int32_t
|
||||
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
|
||||
uint32_t tmp = (dividend_u32 * multiplier) >> 16;
|
||||
return (tmp + dividend_u32) >> shift;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr int32_t
|
||||
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
|
||||
uint32_t tmp = (dividend_u32 * multiplier) >> 16;
|
||||
return (tmp + dividend_u32) >> shift;
|
||||
}
|
||||
};
|
||||
|
||||
// use 32bit version
|
||||
using magic_division = magic_division32_bit_range;
|
||||
|
||||
struct mdiv
|
||||
{
|
||||
// 1 dword -> 3 dword storage
|
||||
uint32_t divisor;
|
||||
uint32_t multiplier;
|
||||
uint32_t shift; // TODO: 8 bit is enough
|
||||
|
||||
// prefer construct on host
|
||||
CK_TILE_HOST_DEVICE mdiv(uint32_t divisor_) : divisor(divisor_)
|
||||
{
|
||||
auto tmp = magic_division::calculate_magic_numbers(divisor_);
|
||||
|
||||
multiplier = tmp[number<0>{}];
|
||||
shift = tmp[number<1>{}];
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE mdiv() : divisor(0), multiplier(0), shift(0) {}
|
||||
|
||||
CK_TILE_HOST_DEVICE void update(uint32_t divisor_)
|
||||
{
|
||||
divisor = divisor_;
|
||||
auto tmp = magic_division::calculate_magic_numbers(divisor_);
|
||||
|
||||
multiplier = tmp[number<0>{}];
|
||||
shift = tmp[number<1>{}];
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const
|
||||
{
|
||||
return magic_division::do_magic_division(dividend_, multiplier, shift);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void
|
||||
divmod(uint32_t dividend_, uint32_t& quotient_, uint32_t& remainder_) const
|
||||
{
|
||||
quotient_ = div(dividend_);
|
||||
remainder_ = dividend_ - (quotient_ * divisor);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE uint32_t get() const { return divisor; }
|
||||
};
|
||||
|
||||
struct mdiv2
|
||||
{
|
||||
// 1 dword -> 2 dword storage, divisor need compute from runtime
|
||||
uint32_t multiplier;
|
||||
uint32_t shift; // TODO: 8 bit is enough
|
||||
|
||||
// prefer construct on host
|
||||
CK_TILE_HOST_DEVICE mdiv2(uint32_t divisor_)
|
||||
{
|
||||
auto tmp = magic_division::calculate_magic_numbers(divisor_);
|
||||
|
||||
multiplier = tmp[number<0>{}];
|
||||
shift = tmp[number<1>{}];
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE mdiv2() : multiplier(0), shift(0) {}
|
||||
|
||||
CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const
|
||||
{
|
||||
return magic_division::do_magic_division(dividend_, multiplier, shift);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void
|
||||
divmod(uint32_t dividend_, uint32_t divisor_, uint32_t& quotient_, uint32_t& remainder_) const
|
||||
{
|
||||
quotient_ = div(dividend_);
|
||||
remainder_ = dividend_ - (quotient_ * divisor_);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
64
include/ck_tile/core/utility/random.hpp
Normal file
64
include/ck_tile/core/utility/random.hpp
Normal file
@@ -0,0 +1,64 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include <stdint.h>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// return 0 if data is not fp16 or fp32
|
||||
template <typename T, uint32_t seed_>
|
||||
struct prand_generator_t
|
||||
{
|
||||
CK_TILE_HOST_DEVICE uint32_t operator()(int id, T val, uint32_t seed = seed_)
|
||||
{
|
||||
std::ignore = id;
|
||||
std::ignore = val;
|
||||
std::ignore = seed;
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
// version for fp32
|
||||
template <uint32_t seed_>
|
||||
struct prand_generator_t<float, seed_>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE uint32_t operator()(int id, float val, uint32_t seed = seed_)
|
||||
{
|
||||
uint32_t x = *(reinterpret_cast<uint32_t*>(&val));
|
||||
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
|
||||
drop_bits ^= x >> 16;
|
||||
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
|
||||
drop_bits *= 0x7000149;
|
||||
// NOTE: If id is in 64 bit, we are only using lower 32 bit.
|
||||
// So, it can have an effect of using same id for multiple elements when the id is
|
||||
// very large!
|
||||
uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed);
|
||||
return rng;
|
||||
}
|
||||
};
|
||||
|
||||
// version for fp16
|
||||
template <uint32_t seed_>
|
||||
struct prand_generator_t<half_t, seed_>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE uint32_t operator()(int id, half_t val, uint32_t seed = seed_)
|
||||
{
|
||||
uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
|
||||
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
|
||||
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
|
||||
drop_bits *= 0x7000149;
|
||||
// NOTE: If id is in 64 bit, we are only using lower 32 bit.
|
||||
// So, it can have an effect of using same id for multiple elements when the id is
|
||||
// very large!
|
||||
uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed);
|
||||
return rng;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
72
include/ck_tile/core/utility/to_sequence.hpp
Normal file
72
include/ck_tile/core/utility/to_sequence.hpp
Normal file
@@ -0,0 +1,72 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
// TODO: use c++20 nontype template with struct to implement this
|
||||
|
||||
#if 1
// clang happens to support this feature (__cpp_generic_lambdas >= 201707) in c++17 mode:
// an immediately-invoked explicitly-templated generic lambda expands the constexpr
// array `a` (first `n` entries) into a ck_tile::sequence<...>
#define TO_SEQUENCE(a, n)                                                                        \
    _Pragma("clang diagnostic push")                                                             \
        _Pragma("clang diagnostic ignored \"-Wc++20-extensions\"")[a]<ck_tile::index_t... Is>(   \
            ck_tile::sequence<Is...>)                                                            \
    {                                                                                            \
        return ck_tile::sequence<a.at(ck_tile::number<Is>{})...>{};                              \
    }                                                                                            \
    (make_index_sequence<n>{}) _Pragma("clang diagnostic pop")

#else
// Macro function
// convert constexpr array to sequence, both a/n need to be constexpr (can't be a rvalue like 2)
// fallback: enumerate the first n entries explicitly (supported up to n == 10)
#define TO_SEQUENCE(a, n)                                                                 \
    [a, n] {                                                                              \
        static_assert(a.size() >= n, "wrong! out of bound");                              \
        static_assert(n <= 10, "not implemented");                                        \
        if constexpr(n == 0)                                                              \
        {                                                                                 \
            return ck_tile::sequence<>{};                                                 \
        }                                                                                 \
        else if constexpr(n == 1)                                                         \
        {                                                                                 \
            return ck_tile::sequence<a[0]>{};                                             \
        }                                                                                 \
        else if constexpr(n == 2)                                                         \
        {                                                                                 \
            return ck_tile::sequence<a[0], a[1]>{};                                       \
        }                                                                                 \
        else if constexpr(n == 3)                                                         \
        {                                                                                 \
            return ck_tile::sequence<a[0], a[1], a[2]>{};                                 \
        }                                                                                 \
        else if constexpr(n == 4)                                                         \
        {                                                                                 \
            return ck_tile::sequence<a[0], a[1], a[2], a[3]>{};                           \
        }                                                                                 \
        else if constexpr(n == 5)                                                         \
        {                                                                                 \
            return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4]>{};                     \
        }                                                                                 \
        else if constexpr(n == 6)                                                         \
        {                                                                                 \
            return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5]>{};               \
        }                                                                                 \
        else if constexpr(n == 7)                                                         \
        {                                                                                 \
            return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6]>{};         \
        }                                                                                 \
        else if constexpr(n == 8)                                                         \
        {                                                                                 \
            return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7]>{};   \
        }                                                                                 \
        else if constexpr(n == 9)                                                         \
        {                                                                                 \
            return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8]>{}; \
        }                                                                                 \
        else if constexpr(n == 10)                                                        \
        {                                                                                 \
            return ck_tile::                                                              \
                sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9]>{};   \
        }                                                                                 \
    }()
#endif
|
||||
57
include/ck_tile/core/utility/type_convert.hpp
Normal file
57
include/ck_tile/core/utility/type_convert.hpp
Normal file
@@ -0,0 +1,57 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Convert X to Y, both X and Y are non-const data types.
|
||||
template <typename Y,
|
||||
typename X,
|
||||
std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
|
||||
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
|
||||
{
|
||||
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
|
||||
|
||||
return static_cast<Y>(x);
|
||||
}
|
||||
|
||||
// Convert X to Y, either X or Y is a const data type.
|
||||
template <typename Y,
|
||||
typename X,
|
||||
std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
|
||||
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
|
||||
{
|
||||
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
|
||||
|
||||
using non_const_y = std::remove_const_t<Y>;
|
||||
using non_const_x = std::remove_const_t<X>;
|
||||
return static_cast<Y>(type_convert<non_const_y, non_const_x>(x));
|
||||
}
|
||||
|
||||
#define CK_TILE_TYPE_CONVERT(dtype_, stype_) \
|
||||
template <> \
|
||||
inline CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
|
||||
{ \
|
||||
return stype_##_to_##dtype_(x); \
|
||||
}
|
||||
|
||||
CK_TILE_TYPE_CONVERT(float, fp16_t)
|
||||
CK_TILE_TYPE_CONVERT(float, bf16_t)
|
||||
CK_TILE_TYPE_CONVERT(float, fp8_t)
|
||||
CK_TILE_TYPE_CONVERT(float, bf8_t)
|
||||
|
||||
CK_TILE_TYPE_CONVERT(fp16_t, float)
|
||||
CK_TILE_TYPE_CONVERT(bf16_t, float)
|
||||
CK_TILE_TYPE_CONVERT(fp8_t, float)
|
||||
CK_TILE_TYPE_CONVERT(bf8_t, float)
|
||||
|
||||
} // namespace ck_tile
|
||||
46
include/ck_tile/core/utility/type_traits.hpp
Normal file
46
include/ck_tile/core/utility/type_traits.hpp
Normal file
@@ -0,0 +1,46 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include <type_traits>
|
||||
#include <stdint.h>
|
||||
|
||||
namespace ck_tile {

// shorthand aliases over the std type-transformation traits

// remove_cvref_t
template <typename T>
using remove_reference_t = typename std::remove_reference<T>::type;

template <typename T>
using remove_cv_t = typename std::remove_cv<T>::type;

template <typename T>
using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;

template <typename T>
using remove_pointer_t = typename std::remove_pointer<T>::type;

namespace impl {
// Trait backing is_static: arithmetic types carry no compile-time shape
// information, so they report false; any other type is expected to expose a
// static constexpr is_static() member.
template <typename T>
struct is_static_impl
{
    // FIX: the previous initializer
    //   std::is_arithmetic<T>::v ? false : T::is_static();
    // was ill-formed twice over: the std trait member is `value` (not `v`),
    // and `T::is_static()` does not exist for arithmetic T, so even the
    // correct ternary would fail to instantiate for e.g. int. `if constexpr`
    // discards the invalid branch at compile time.
    static constexpr bool get_value()
    {
        if constexpr(std::is_arithmetic<T>::value)
        {
            return false;
        }
        else
        {
            return T::is_static();
        }
    }

    static constexpr bool value = get_value();
};
} // namespace impl

template <typename T>
using is_static = impl::is_static_impl<remove_cvref_t<T>>;

template <typename T>
inline constexpr bool is_static_v = is_static<T>::value;

// TODO: deprecate this
template <typename T>
using is_known_at_compile_time = is_static<T>;
// TODO: if evaluating a rvalue, e.g. a const integer
// , this helper will also return false, which is not good(?)
// do we need something like is_constexpr()?

} // namespace ck_tile
|
||||
Reference in New Issue
Block a user