mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
unify as tuple_array
This commit is contained in:
@@ -15,10 +15,12 @@
|
||||
#include "ck_tile/core/container/map.hpp"
|
||||
#include "ck_tile/core/container/meta_data_buffer.hpp"
|
||||
#include "ck_tile/core/container/multi_index.hpp"
|
||||
#include "ck_tile/core/container/rbuffer.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/span.hpp"
|
||||
#include "ck_tile/core/container/statically_indexed_array.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/container/tuple_array.hpp"
|
||||
#include "ck_tile/core/numeric/arithmetic.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
|
||||
@@ -51,6 +51,12 @@
|
||||
#define CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE
|
||||
#endif
|
||||
|
||||
#define CK_TILE_RBUFFER_USE_ARRAY 0
|
||||
#define CK_TILE_RBUFFER_USE_TUPLE 1
|
||||
#ifndef CK_TILE_RBUFFER_DEFAULT
|
||||
#define CK_TILE_RBUFFER_DEFAULT CK_TILE_RBUFFER_USE_TUPLE
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_LAUNCH_BOUNDS
|
||||
#define CK_TILE_USE_LAUNCH_BOUNDS 1
|
||||
#endif
|
||||
|
||||
20
include/ck_tile/core/container/rbuffer.hpp
Normal file
20
include/ck_tile/core/container/rbuffer.hpp
Normal file
@@ -0,0 +1,20 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/tuple_array.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
#if CK_TILE_RBUFFER_DEFAULT == CK_TILE_RBUFFER_USE_TUPLE
|
||||
template <typename T, index_t N>
|
||||
using rbuffer = tuple_array<T, N>;
|
||||
#else
|
||||
template <typename T, index_t N>
|
||||
using rbuffer = array<T, N>
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -5,44 +5,15 @@
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/tuple_array.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
#if CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT == CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE
|
||||
namespace detail {
|
||||
template <typename X, typename Y>
|
||||
struct tuple_concat;
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
struct tuple_concat<tuple<Xs...>, tuple<Ys...>>
|
||||
{
|
||||
using type = tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
template <typename T, index_t N>
|
||||
struct statically_indexed_array_impl
|
||||
{
|
||||
using type =
|
||||
typename tuple_concat<typename statically_indexed_array_impl<T, N / 2>::type,
|
||||
typename statically_indexed_array_impl<T, N - N / 2>::type>::type;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct statically_indexed_array_impl<T, 0>
|
||||
{
|
||||
using type = tuple<>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct statically_indexed_array_impl<T, 1>
|
||||
{
|
||||
using type = tuple<T>;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <typename T, index_t N>
|
||||
using statically_indexed_array = typename detail::statically_indexed_array_impl<T, N>::type;
|
||||
using statically_indexed_array = tuple_array<T, N>;
|
||||
|
||||
#else
|
||||
|
||||
@@ -53,7 +24,7 @@ using statically_indexed_array = array<T, N>;
|
||||
#endif
|
||||
|
||||
// consider always use ck_tile::array for this purpose
|
||||
|
||||
#if 0
|
||||
template <typename X, typename... Xs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs)
|
||||
{
|
||||
@@ -66,5 +37,5 @@ CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array()
|
||||
{
|
||||
return statically_indexed_array<X, 0>();
|
||||
}
|
||||
|
||||
#endif
|
||||
} // namespace ck_tile
|
||||
|
||||
60
include/ck_tile/core/container/tuple_array.hpp
Normal file
60
include/ck_tile/core/container/tuple_array.hpp
Normal file
@@ -0,0 +1,60 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace detail {
|
||||
template <typename T, index_t N>
|
||||
struct tuple_array_impl
|
||||
{
|
||||
using type = typename tuple_concat<typename tuple_array_impl<T, N / 2>::type,
|
||||
typename tuple_array_impl<T, N - N / 2>::type>::type;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct tuple_array_impl<T, 0>
|
||||
{
|
||||
using type = tuple<>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct tuple_array_impl<T, 1>
|
||||
{
|
||||
using type = tuple<T>;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <typename T, index_t N>
|
||||
using tuple_array_base_t = typename detail::tuple_array_impl<T, N>::type;
|
||||
|
||||
template <typename T_, index_t N_>
|
||||
struct tuple_array : tuple_array_base_t<T_, N_>
|
||||
{
|
||||
using value_type = T_;
|
||||
static constexpr index_t N = N_;
|
||||
|
||||
// clang-format off
|
||||
#define TA_COM_() static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() { TA_COM_(); return reinterpret_cast<tuple_array<Tx, vx>&>(*this); }
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() const { TA_COM_(); return reinterpret_cast<const tuple_array<Tx, vx>&>(*this); }
|
||||
|
||||
// below index is for index *AFTER* type convert, not before
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) { TA_COM_(); return reinterpret_cast<tuple_array<Tx, vx>&>(*this).at(i); }
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) const { TA_COM_(); return reinterpret_cast<const tuple_array<Tx, vx>&>(*this).at(i); }
|
||||
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>) { TA_COM_(); return reinterpret_cast<tuple_array<Tx, vx>&>(*this).at(number<I>{}); }
|
||||
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>) const { TA_COM_(); return reinterpret_cast<const tuple_array<Tx, vx>&>(*this).at(number<I>{}); }
|
||||
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x) { TA_COM_(); reinterpret_cast<tuple_array<Tx, vx>&>(*this).at(i) = x; }
|
||||
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x) { TA_COM_(); reinterpret_cast<tuple_array<Tx, vx>&>(*this).at(number<I>{}) = x; }
|
||||
#undef TA_COM_
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/container/rbuffer.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -71,7 +72,7 @@ struct static_distributed_tensor
|
||||
constexpr auto sliced_thread_tensor_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...));
|
||||
|
||||
array<DataType, sliced_thread_tensor_desc.get_element_space_size()> sliced_thread_data;
|
||||
rbuffer<DataType, sliced_thread_tensor_desc.get_element_space_size()> sliced_thread_data;
|
||||
|
||||
static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
|
||||
constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};
|
||||
@@ -87,7 +88,7 @@ struct static_distributed_tensor
|
||||
CK_TILE_HOST_DEVICE void
|
||||
set_y_sliced_thread_data(sequence<YSliceOrigins...>,
|
||||
sequence<YSliceLengths...>,
|
||||
const array<DataType, NSlicedData>& sliced_thread_data)
|
||||
const rbuffer<DataType, NSlicedData>& sliced_thread_data)
|
||||
{
|
||||
static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
|
||||
sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
|
||||
@@ -129,7 +130,7 @@ struct static_distributed_tensor
|
||||
}
|
||||
|
||||
//
|
||||
array<DataType, kThreadElementSpaceSize> thread_buf_;
|
||||
rbuffer<DataType, kThreadElementSpaceSize> thread_buf_;
|
||||
};
|
||||
|
||||
template <typename DataType, typename StaticTileDistribution>
|
||||
|
||||
Reference in New Issue
Block a user