diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index f53a6b0fd6..3877b5ceed 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -15,10 +15,12 @@ #include "ck_tile/core/container/map.hpp" #include "ck_tile/core/container/meta_data_buffer.hpp" #include "ck_tile/core/container/multi_index.hpp" +#include "ck_tile/core/container/rbuffer.hpp" #include "ck_tile/core/container/sequence.hpp" #include "ck_tile/core/container/span.hpp" #include "ck_tile/core/container/statically_indexed_array.hpp" #include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/core/container/tuple_array.hpp" #include "ck_tile/core/numeric/arithmetic.hpp" #include "ck_tile/core/numeric/bfloat16.hpp" #include "ck_tile/core/numeric/float8.hpp" diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index a2362fb46b..cfb94ea2f3 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -51,6 +51,12 @@ #define CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE #endif +#define CK_TILE_RBUFFER_USE_ARRAY 0 +#define CK_TILE_RBUFFER_USE_TUPLE 1 +#ifndef CK_TILE_RBUFFER_DEFAULT +#define CK_TILE_RBUFFER_DEFAULT CK_TILE_RBUFFER_USE_TUPLE +#endif + #ifndef CK_TILE_USE_LAUNCH_BOUNDS #define CK_TILE_USE_LAUNCH_BOUNDS 1 #endif diff --git a/include/ck_tile/core/container/rbuffer.hpp b/include/ck_tile/core/container/rbuffer.hpp new file mode 100644 index 0000000000..d44eacf2f7 --- /dev/null +++ b/include/ck_tile/core/container/rbuffer.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/container/tuple_array.hpp" + +namespace ck_tile { + +#if CK_TILE_RBUFFER_DEFAULT == CK_TILE_RBUFFER_USE_TUPLE +template +using rbuffer = tuple_array; +#else +template +using rbuffer = array +#endif + +} // namespace ck_tile diff --git a/include/ck_tile/core/container/statically_indexed_array.hpp b/include/ck_tile/core/container/statically_indexed_array.hpp index 1542ad0768..4e1af96cbd 100644 --- a/include/ck_tile/core/container/statically_indexed_array.hpp +++ b/include/ck_tile/core/container/statically_indexed_array.hpp @@ -5,44 +5,15 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/container/tuple_array.hpp" #include "ck_tile/core/numeric/integer.hpp" namespace ck_tile { #if CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT == CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE -namespace detail { -template -struct tuple_concat; - -template -struct tuple_concat, tuple> -{ - using type = tuple; -}; template -struct statically_indexed_array_impl -{ - using type = - typename tuple_concat::type, - typename statically_indexed_array_impl::type>::type; -}; - -template -struct statically_indexed_array_impl -{ - using type = tuple<>; -}; - -template -struct statically_indexed_array_impl -{ - using type = tuple; -}; -} // namespace detail - -template -using statically_indexed_array = typename detail::statically_indexed_array_impl::type; +using statically_indexed_array = tuple_array; #else @@ -53,7 +24,7 @@ using statically_indexed_array = array; #endif // consider always use ck_tile::array for this purpose - +#if 0 template CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs) { @@ -66,5 +37,5 @@ CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array() { return statically_indexed_array(); } - +#endif } // namespace ck_tile diff --git a/include/ck_tile/core/container/tuple_array.hpp b/include/ck_tile/core/container/tuple_array.hpp new file mode 100644 index 0000000000..27fe9d8f5d --- /dev/null +++ b/include/ck_tile/core/container/tuple_array.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/core/numeric/integer.hpp" + +namespace ck_tile { + +namespace detail { +template +struct tuple_array_impl +{ + using type = typename tuple_concat::type, + typename tuple_array_impl::type>::type; +}; + +template +struct tuple_array_impl +{ + using type = tuple<>; +}; + +template +struct tuple_array_impl +{ + using type = tuple; +}; +} // namespace detail + +template +using tuple_array_base_t = typename detail::tuple_array_impl::type; + +template +struct tuple_array : tuple_array_base_t +{ + using value_type = T_; + static constexpr index_t N = N_; + +// clang-format off +#define TA_COM_() static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); constexpr int vx = sizeof(value_type) * N / sizeof(Tx) + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() { TA_COM_(); return reinterpret_cast&>(*this); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() const { TA_COM_(); return reinterpret_cast&>(*this); } + + // below index is for index *AFTER* type convert, not before + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) { TA_COM_(); return reinterpret_cast&>(*this).at(i); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) const { TA_COM_(); return reinterpret_cast&>(*this).at(i); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number) { TA_COM_(); return reinterpret_cast&>(*this).at(number{}); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number) const { TA_COM_(); return reinterpret_cast&>(*this).at(number{}); } + + template CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x) { TA_COM_(); reinterpret_cast&>(*this).at(i) = x; } + template CK_TILE_HOST_DEVICE constexpr void set_as(number, const Tx & x) { TA_COM_(); reinterpret_cast&>(*this).at(number{}) = x; } +#undef TA_COM_ + // clang-format on +}; + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/static_distributed_tensor.hpp b/include/ck_tile/core/tensor/static_distributed_tensor.hpp index 0c9e0debb1..38f6f21770 100644 --- a/include/ck_tile/core/tensor/static_distributed_tensor.hpp +++ b/include/ck_tile/core/tensor/static_distributed_tensor.hpp @@ -12,6 +12,7 @@ #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/core/container/rbuffer.hpp" namespace ck_tile { @@ -71,7 +72,7 @@ struct static_distributed_tensor constexpr auto sliced_thread_tensor_desc = make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...)); - array sliced_thread_data; + rbuffer sliced_thread_data; static_ford>{}([&](auto idx) { constexpr auto idx_ys = idx + sequence{}; @@ -87,7 +88,7 @@ struct static_distributed_tensor CK_TILE_HOST_DEVICE void set_y_sliced_thread_data(sequence, sequence, - const array& sliced_thread_data) + const rbuffer& sliced_thread_data) { static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY && sizeof...(YSliceLengths) == StaticTileDistribution::NDimY, @@ -129,7 +130,7 @@ struct static_distributed_tensor } // - array thread_buf_; + rbuffer thread_buf_; }; template