unify as tuple_array

This commit is contained in:
carlushuang
2024-03-06 18:36:45 +00:00
parent 7df3947819
commit 26a25eb4cd
6 changed files with 96 additions and 36 deletions

View File

@@ -15,10 +15,12 @@
#include "ck_tile/core/container/map.hpp"
#include "ck_tile/core/container/meta_data_buffer.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/container/rbuffer.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/span.hpp"
#include "ck_tile/core/container/statically_indexed_array.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/tuple_array.hpp"
#include "ck_tile/core/numeric/arithmetic.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"

View File

@@ -51,6 +51,12 @@
#define CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE
#endif
// Backend selector for ck_tile::rbuffer.
// rbuffer.hpp compares CK_TILE_RBUFFER_DEFAULT against CK_TILE_RBUFFER_USE_TUPLE:
// when they are equal rbuffer<T, N> aliases tuple_array<T, N>, otherwise it
// aliases array<T, N>.
#define CK_TILE_RBUFFER_USE_ARRAY 0
#define CK_TILE_RBUFFER_USE_TUPLE 1
// The tuple backend is the default; define CK_TILE_RBUFFER_DEFAULT before this
// point (e.g. on the compiler command line) to pick the array backend instead.
#ifndef CK_TILE_RBUFFER_DEFAULT
#define CK_TILE_RBUFFER_DEFAULT CK_TILE_RBUFFER_USE_TUPLE
#endif
// Defaults to 1 unless the build has already defined it.
// NOTE(review): presumably toggles launch-bounds annotations on kernels; the
// use site is not visible here - confirm before relying on this.
#ifndef CK_TILE_USE_LAUNCH_BOUNDS
#define CK_TILE_USE_LAUNCH_BOUNDS 1
#endif

View File

@@ -0,0 +1,20 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/tuple_array.hpp"
namespace ck_tile {
// rbuffer<T, N>: fixed-size buffer of N elements of type T whose backing
// container is chosen at compile time through CK_TILE_RBUFFER_DEFAULT:
//   - CK_TILE_RBUFFER_USE_TUPLE -> tuple_array<T, N>
//   - anything else             -> array<T, N>
#if CK_TILE_RBUFFER_DEFAULT == CK_TILE_RBUFFER_USE_TUPLE
template <typename T, index_t N>
using rbuffer = tuple_array<T, N>;
#else
// The alias-declaration must end with ';' - without it this branch does not
// compile when the array backend is selected.
template <typename T, index_t N>
using rbuffer = array<T, N>;
#endif
} // namespace ck_tile

View File

@@ -5,44 +5,15 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/tuple_array.hpp"
#include "ck_tile/core/numeric/integer.hpp"
namespace ck_tile {
#if CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT == CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE
namespace detail {
template <typename X, typename Y>
struct tuple_concat;
template <typename... Xs, typename... Ys>
struct tuple_concat<tuple<Xs...>, tuple<Ys...>>
{
using type = tuple<Xs..., Ys...>;
};
template <typename T, index_t N>
struct statically_indexed_array_impl
{
using type =
typename tuple_concat<typename statically_indexed_array_impl<T, N / 2>::type,
typename statically_indexed_array_impl<T, N - N / 2>::type>::type;
};
template <typename T>
struct statically_indexed_array_impl<T, 0>
{
using type = tuple<>;
};
template <typename T>
struct statically_indexed_array_impl<T, 1>
{
using type = tuple<T>;
};
} // namespace detail
template <typename T, index_t N>
using statically_indexed_array = typename detail::statically_indexed_array_impl<T, N>::type;
using statically_indexed_array = tuple_array<T, N>;
#else
@@ -53,7 +24,7 @@ using statically_indexed_array = array<T, N>;
#endif
// consider always using ck_tile::array for this purpose
#if 0
template <typename X, typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs)
{
@@ -66,5 +37,5 @@ CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array()
{
return statically_indexed_array<X, 0>();
}
#endif
} // namespace ck_tile

View File

@@ -0,0 +1,60 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/integer.hpp"
namespace ck_tile {
namespace detail {
template <typename T, index_t N>
struct tuple_array_impl
{
using type = typename tuple_concat<typename tuple_array_impl<T, N / 2>::type,
typename tuple_array_impl<T, N - N / 2>::type>::type;
};
template <typename T>
struct tuple_array_impl<T, 0>
{
using type = tuple<>;
};
template <typename T>
struct tuple_array_impl<T, 1>
{
using type = tuple<T>;
};
} // namespace detail
// Shorthand for the tuple type produced by detail::tuple_array_impl<T, N>;
// used as the base class of tuple_array<T, N>.
template <typename T, index_t N>
using tuple_array_base_t = typename detail::tuple_array_impl<T, N>::type;
// Array-like wrapper over a homogeneous tuple of N_ elements of type T_.
// It derives from tuple_array_base_t<T_, N_> and adds get_as / set_as, which
// reinterpret the object as a tuple_array of another element type Tx.
template <typename T_, index_t N_>
struct tuple_array : tuple_array_base_t<T_, N_>
{
    using value_type           = T_;
    static constexpr index_t N = N_;

    // clang-format off
    // Shared prologue of every accessor below: checks that the total byte size
    // (sizeof(value_type) * N) is a multiple of sizeof(Tx) and defines vx, the
    // element count of the reinterpreted tuple_array<Tx, vx>.
#define TA_VIEW_PROLOGUE_()                                  \
    static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
    constexpr int vx = sizeof(value_type) * N / sizeof(Tx)

    // Whole-object view: *this reinterpreted as tuple_array<Tx, vx>.
    template <typename Tx>
    CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as()
    {
        TA_VIEW_PROLOGUE_();
        return reinterpret_cast<tuple_array<Tx, vx>&>(*this);
    }

    template <typename Tx>
    CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() const
    {
        TA_VIEW_PROLOGUE_();
        return reinterpret_cast<const tuple_array<Tx, vx>&>(*this);
    }

    // Element access on the reinterpreted view.
    // The index refers to positions *AFTER* the type conversion, not before.
    template <typename Tx>
    CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i)
    {
        TA_VIEW_PROLOGUE_();
        return reinterpret_cast<tuple_array<Tx, vx>&>(*this).at(i);
    }

    template <typename Tx>
    CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) const
    {
        TA_VIEW_PROLOGUE_();
        return reinterpret_cast<const tuple_array<Tx, vx>&>(*this).at(i);
    }

    // Same as above with a compile-time index.
    template <typename Tx, index_t I>
    CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>)
    {
        TA_VIEW_PROLOGUE_();
        return reinterpret_cast<tuple_array<Tx, vx>&>(*this).at(number<I>{});
    }

    template <typename Tx, index_t I>
    CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>) const
    {
        TA_VIEW_PROLOGUE_();
        return reinterpret_cast<const tuple_array<Tx, vx>&>(*this).at(number<I>{});
    }

    // Stores x into element i of the reinterpreted view (index after conversion).
    template <typename Tx>
    CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx& x)
    {
        TA_VIEW_PROLOGUE_();
        reinterpret_cast<tuple_array<Tx, vx>&>(*this).at(i) = x;
    }

    // Compile-time-index variant of set_as.
    template <typename Tx, index_t I>
    CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx& x)
    {
        TA_VIEW_PROLOGUE_();
        reinterpret_cast<tuple_array<Tx, vx>&>(*this).at(number<I>{}) = x;
    }

#undef TA_VIEW_PROLOGUE_
    // clang-format on
};
} // namespace ck_tile

View File

@@ -12,6 +12,7 @@
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/container/rbuffer.hpp"
namespace ck_tile {
@@ -71,7 +72,7 @@ struct static_distributed_tensor
constexpr auto sliced_thread_tensor_desc =
make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...));
array<DataType, sliced_thread_tensor_desc.get_element_space_size()> sliced_thread_data;
rbuffer<DataType, sliced_thread_tensor_desc.get_element_space_size()> sliced_thread_data;
static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};
@@ -87,7 +88,7 @@ struct static_distributed_tensor
CK_TILE_HOST_DEVICE void
set_y_sliced_thread_data(sequence<YSliceOrigins...>,
sequence<YSliceLengths...>,
const array<DataType, NSlicedData>& sliced_thread_data)
const rbuffer<DataType, NSlicedData>& sliced_thread_data)
{
static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
@@ -129,7 +130,7 @@ struct static_distributed_tensor
}
//
array<DataType, kThreadElementSpaceSize> thread_buf_;
rbuffer<DataType, kThreadElementSpaceSize> thread_buf_;
};
template <typename DataType, typename StaticTileDistribution>