mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
re-structure tuple/array to avoid spill
This commit is contained in:
@@ -15,12 +15,11 @@
|
||||
#include "ck_tile/core/container/map.hpp"
|
||||
#include "ck_tile/core/container/meta_data_buffer.hpp"
|
||||
#include "ck_tile/core/container/multi_index.hpp"
|
||||
#include "ck_tile/core/container/rbuffer.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/span.hpp"
|
||||
#include "ck_tile/core/container/statically_indexed_array.hpp"
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/container/tuple_array.hpp"
|
||||
#include "ck_tile/core/numeric/arithmetic.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
@@ -1475,7 +1476,7 @@ template <typename T,
|
||||
index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void amd_buffer_store_raw_impl(const array<T, N>& dst_thread_data,
|
||||
CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thread_data,
|
||||
int32x4_t dst_wave_buffer_resource,
|
||||
index_t dst_thread_addr_offset,
|
||||
index_t dst_wave_addr_offset,
|
||||
@@ -1889,7 +1890,7 @@ template <typename T,
|
||||
index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void amd_buffer_store_raw(const array<T, N>& src_thread_data,
|
||||
CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_data,
|
||||
T* p_dst_wave,
|
||||
const index_t dst_thread_element_offset,
|
||||
const bool dst_thread_element_valid,
|
||||
|
||||
@@ -51,10 +51,10 @@
|
||||
#define CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE
|
||||
#endif
|
||||
|
||||
#define CK_TILE_RBUFFER_USE_ARRAY 0
|
||||
#define CK_TILE_RBUFFER_USE_TUPLE 1
|
||||
#ifndef CK_TILE_RBUFFER_DEFAULT
|
||||
#define CK_TILE_RBUFFER_DEFAULT CK_TILE_RBUFFER_USE_TUPLE
|
||||
#define CK_TILE_THREAD_BUFFER_USE_ARRAY 0
|
||||
#define CK_TILE_THREAD_BUFFER_USE_TUPLE 1
|
||||
#ifndef CK_TILE_THREAD_BUFFER_DEFAULT
|
||||
#define CK_TILE_THREAD_BUFFER_DEFAULT CK_TILE_THREAD_BUFFER_USE_ARRAY
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_LAUNCH_BOUNDS
|
||||
|
||||
@@ -157,19 +157,47 @@ struct vector_traits<array<T, N>>
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
template <typename T, typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_array(Ts&&... xs)
|
||||
namespace details {
|
||||
template <class>
|
||||
struct is_ref_wrapper : std::false_type
|
||||
{
|
||||
using value_type = remove_cvref_t<T>;
|
||||
return array<value_type, sizeof...(Ts)>{std::forward<Ts>(xs)...};
|
||||
};
|
||||
template <class T>
|
||||
struct is_ref_wrapper<std::reference_wrapper<T>> : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
template <class T>
|
||||
using not_ref_wrapper = std::negation<is_ref_wrapper<std::decay_t<T>>>;
|
||||
|
||||
template <class D, class...>
|
||||
struct return_type_helper
|
||||
{
|
||||
using type = D;
|
||||
};
|
||||
template <class... Ts>
|
||||
struct return_type_helper<void, Ts...> : std::common_type<Ts...>
|
||||
{
|
||||
static_assert(std::conjunction_v<not_ref_wrapper<Ts>...>,
|
||||
"Ts cannot contain reference_wrappers when D is void");
|
||||
};
|
||||
|
||||
template <class D, class... Ts>
|
||||
using return_type = array<typename return_type_helper<D, Ts...>::type, sizeof...(Ts)>;
|
||||
} // namespace details
|
||||
|
||||
template <typename D = void, typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr details::return_type<D, Ts...> make_array(Ts&&... ts)
|
||||
{
|
||||
return {std::forward<Ts>(ts)...};
|
||||
}
|
||||
|
||||
// make empty array
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_array()
|
||||
{
|
||||
return array<T, 0>{};
|
||||
}
|
||||
// // make empty array
|
||||
// template <typename T>
|
||||
// CK_TILE_HOST_DEVICE constexpr auto make_array()
|
||||
// {
|
||||
// return array<T, 0>{};
|
||||
// }
|
||||
|
||||
// compatible with old ck's initializer, make an array and fill it withe the last element from
|
||||
// initializer_list
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/tuple_array.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
#if CK_TILE_RBUFFER_DEFAULT == CK_TILE_RBUFFER_USE_TUPLE
|
||||
template <typename T, index_t N>
|
||||
using rbuffer = tuple_array<T, N>;
|
||||
#else
|
||||
template <typename T, index_t N>
|
||||
using rbuffer = array<T, N>
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/tuple_array.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
32
include/ck_tile/core/container/thread_buffer.hpp
Normal file
32
include/ck_tile/core/container/thread_buffer.hpp
Normal file
@@ -0,0 +1,32 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
#if CK_TILE_THREAD_BUFFER_DEFAULT == CK_TILE_THREAD_BUFFER_USE_TUPLE
|
||||
template <typename T, index_t N>
|
||||
using thread_buffer = tuple_array<T, N>;
|
||||
|
||||
template <typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
|
||||
{
|
||||
return make_tuple(ts...);
|
||||
}
|
||||
#else
|
||||
template <typename T, index_t N>
|
||||
using thread_buffer = array<T, N>;
|
||||
|
||||
template <typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
|
||||
{
|
||||
return make_array(ts...);
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -12,8 +12,20 @@
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include <utility>
|
||||
|
||||
#ifndef CK_TILE_TUPLE_IMPL
|
||||
#define CK_TILE_TUPLE_IMPL 1
|
||||
#endif
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace impl {
|
||||
template <typename T, index_t N>
|
||||
struct tuple_array_impl;
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
using tuple_array = typename impl::tuple_array_impl<T, N>::type;
|
||||
|
||||
namespace impl {
|
||||
|
||||
// the place where content is stored
|
||||
@@ -26,37 +38,77 @@ template <index_t idx, typename T>
|
||||
struct tuple_object<idx, T, true>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object() {}
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object(const T&) {}
|
||||
#if CK_TILE_TUPLE_IMPL == 0
|
||||
template <typename U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object(U&&)
|
||||
{
|
||||
}
|
||||
template <typename U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object(const U&)
|
||||
{
|
||||
}
|
||||
template <typename U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object(U&)
|
||||
{
|
||||
}
|
||||
#elif CK_TILE_TUPLE_IMPL == 1
|
||||
template <typename U,
|
||||
typename std::enable_if<!std::is_same<remove_cvref_t<U>, tuple_object>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object(U&&)
|
||||
{
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
template <index_t idx, typename T>
|
||||
struct tuple_object<idx, T, false>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object() : element{} {}
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object(const T& e) : element(e) {}
|
||||
#if CK_TILE_TUPLE_IMPL == 0
|
||||
template <typename U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object(U&& e) : element(std::forward<U>(e))
|
||||
{
|
||||
}
|
||||
template <typename U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object(const U& e) : element(e)
|
||||
{
|
||||
}
|
||||
template <typename U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object(U& e) : element(e)
|
||||
{
|
||||
}
|
||||
#elif CK_TILE_TUPLE_IMPL == 1
|
||||
template <typename U,
|
||||
typename std::enable_if<!std::is_same<remove_cvref_t<U>, tuple_object>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object(U&& e) : element(std::forward<U>(e))
|
||||
{
|
||||
}
|
||||
#endif
|
||||
T element;
|
||||
};
|
||||
|
||||
// NOTE: we return a instance(not a reference) if content is empty
|
||||
template <std::size_t I, class T>
|
||||
template <index_t I, class T>
|
||||
CK_TILE_HOST_DEVICE constexpr T getv(const tuple_object<I, T, true>&)
|
||||
{
|
||||
return {};
|
||||
}
|
||||
|
||||
template <std::size_t I, class T>
|
||||
template <index_t I, class T>
|
||||
CK_TILE_HOST_DEVICE constexpr const T& getv(const tuple_object<I, T, false>& x)
|
||||
{
|
||||
return x.element;
|
||||
}
|
||||
|
||||
template <std::size_t I, class T>
|
||||
template <index_t I, class T>
|
||||
CK_TILE_HOST_DEVICE constexpr T& getv(tuple_object<I, T, false>& x)
|
||||
{
|
||||
return x.element;
|
||||
}
|
||||
|
||||
template <std::size_t I, class T>
|
||||
template <index_t I, class T>
|
||||
CK_TILE_HOST_DEVICE constexpr T&& getv(tuple_object<I, T, false>&& x)
|
||||
{
|
||||
return static_cast<T&&>(x.element);
|
||||
@@ -68,18 +120,58 @@ struct tuple_base;
|
||||
template <index_t... I, typename... T>
|
||||
struct tuple_base<sequence<I...>, T...> : tuple_object<I, T>...
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_base() {}
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_base() = default;
|
||||
#if CK_TILE_TUPLE_IMPL == 0
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U&&... u)
|
||||
: tuple_object<I, T>(std::forward<U>(u))...
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr explicit tuple_base(const U&... u) : tuple_object<I, T>(u)...
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U&... u) : tuple_object<I, T>(u)...
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_base(tuple_base<sequence<I...>, U...>&& u)
|
||||
: tuple_object<I, T>(getv(static_cast<tuple_object<I, U>&&>(u)))...
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_base(const tuple_base<sequence<I...>, U...>& u)
|
||||
: tuple_object<I, T>(getv(static_cast<const tuple_object<I, U>&>(u)))...
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_base(tuple_base<sequence<I...>, U...>& u)
|
||||
: tuple_object<I, T>(getv(static_cast<tuple_object<I, U>&>(u)))...
|
||||
{
|
||||
}
|
||||
#elif CK_TILE_TUPLE_IMPL == 1
|
||||
template <class U,
|
||||
typename std::enable_if<sizeof...(I) == 1 && sizeof...(T) == 1 &&
|
||||
!std::is_same<remove_cvref_t<U>, tuple_base>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_base(U&& u) : tuple_object<I, T>(std::forward<U>(u))...
|
||||
{
|
||||
}
|
||||
|
||||
template <typename... U, typename std::enable_if<sizeof...(U) >= 2, bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_base(U&&... u) : tuple_object<I, T>(std::forward<U>(u))...
|
||||
{
|
||||
static_assert(sizeof...(I) == sizeof...(T) && sizeof...(I) == sizeof...(U),
|
||||
"wrong! inconsistent size");
|
||||
}
|
||||
|
||||
#endif
|
||||
};
|
||||
} // namespace impl
|
||||
|
||||
@@ -89,19 +181,56 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
|
||||
CK_TILE_HOST_DEVICE
|
||||
static constexpr auto size() { return sizeof...(T); }
|
||||
using base = impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>;
|
||||
CK_TILE_HOST_DEVICE constexpr tuple() {}
|
||||
CK_TILE_HOST_DEVICE constexpr tuple() = default;
|
||||
#if CK_TILE_TUPLE_IMPL == 0
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(U&&... u) : base(std::forward<U>(u)...)
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(const U&... u) : base(u...)
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(U&... u) : base(u...)
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(tuple<U...>&& u)
|
||||
: base(static_cast<impl::tuple_base<make_index_sequence<sizeof...(U)>, U...>&&>(u))
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(const tuple<U...>& u)
|
||||
: base(static_cast<const impl::tuple_base<make_index_sequence<sizeof...(U)>, U...>&>(u))
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(tuple<U...>& u)
|
||||
: base(static_cast<impl::tuple_base<make_index_sequence<sizeof...(U)>, U...>&>(u))
|
||||
{
|
||||
}
|
||||
#elif CK_TILE_TUPLE_IMPL == 1
|
||||
template <
|
||||
typename U,
|
||||
typename std::enable_if<sizeof...(T) == 1 && !std::is_same<remove_cvref_t<U>, tuple>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(U&& u) : base(std::forward<U>(u))
|
||||
{
|
||||
}
|
||||
|
||||
template <typename... U,
|
||||
typename std::enable_if<sizeof...(U) == sizeof...(T) && sizeof...(U) >= 2,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(U&&... u) : base(std::forward<U>(u)...)
|
||||
{
|
||||
}
|
||||
#endif
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static()
|
||||
{
|
||||
bool flag = true;
|
||||
@@ -128,6 +257,19 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number<I>) { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number<I>) const { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator()(number<I>) { TP_COM_(); return get<I>(); } // TODO: compatible
|
||||
|
||||
// below function should be used under tuple_array<> type, no extra check will perform here
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() { return reinterpret_cast<tuple_array<Tx, size()>&>(*this); }
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() const { return reinterpret_cast<const tuple_array<Tx, size()>&>(*this); }
|
||||
// below index is for index *AFTER* type convert, not before
|
||||
//template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) { TP_COM_(); return reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(i); }
|
||||
//template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) const { TP_COM_(); return reinterpret_cast<const tuple_array<Tx, size()>&>(*this).at(i); }
|
||||
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>) { TP_COM_(); return reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(number<I>{}); }
|
||||
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>) const { TP_COM_(); return reinterpret_cast<const tuple_array<Tx, size()>&>(*this).at(number<I>{}); }
|
||||
|
||||
// template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x) { TP_COM_(); reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(i) = x; }
|
||||
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x) { TP_COM_(); reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(number<I>{}) = x; }
|
||||
|
||||
// clang-format on
|
||||
#undef TP_COM_
|
||||
};
|
||||
@@ -163,6 +305,15 @@ CK_TILE_HOST_DEVICE constexpr bool operator!=(const tuple<Xs...>& a, const tuple
|
||||
template <typename... Xs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs&&... xs)
|
||||
{
|
||||
// here xs is always a lvalue as function arg
|
||||
// Xs may deduced as (e.g try to pass in a integer in following cases)
|
||||
// 1). if pass in a rvalue (like function return or int{}) -> Xs is "int"
|
||||
// 2). if pass in a const lvalue -> Xs is "const int &"
|
||||
// 3). if pass in a non-const lvalue -> Xs is "int &"
|
||||
// so the return type of std::forward will dependes on Xs
|
||||
// 1). std::forward -> int&&
|
||||
// 2). std::forward -> const int&
|
||||
// 3). std::forward -> int&
|
||||
return tuple<remove_cvref_t<Xs>...>(std::forward<Xs>(xs)...);
|
||||
}
|
||||
|
||||
@@ -182,6 +333,38 @@ struct tuple_concat<tuple<Xs...>, tuple<Ys...>>
|
||||
using type = tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
namespace impl {
|
||||
// be very careful using this type (because we want the internal type)
|
||||
// template deduction will fail if infering the inner type
|
||||
// e.g.
|
||||
// template<typename T, index_t N> using some_wrapper = typename tuple_array_impl<T, N>::type;
|
||||
// template<typename T, index_t N> void foo(const some_wrapper<T, N>&) {}
|
||||
// -> compiler will fail to deduce this type, because this is under non-deduced context
|
||||
// (https://en.cppreference.com/w/cpp/language/template_argument_deduction, "Non-deduced
|
||||
// contexts")
|
||||
//
|
||||
// -> use this instead
|
||||
// template<typename Tup> void foo(const Tup&) {}
|
||||
template <typename T, index_t N>
|
||||
struct tuple_array_impl
|
||||
{
|
||||
using type = typename tuple_concat<typename tuple_array_impl<T, N / 2>::type,
|
||||
typename tuple_array_impl<T, N - N / 2>::type>::type;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct tuple_array_impl<T, 0>
|
||||
{
|
||||
using type = tuple<>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct tuple_array_impl<T, 1>
|
||||
{
|
||||
using type = tuple<T>;
|
||||
};
|
||||
} // namespace impl
|
||||
|
||||
template <typename F, index_t N>
|
||||
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F&& f, number<N>)
|
||||
{
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace detail {
|
||||
template <typename T, index_t N>
|
||||
struct tuple_array_impl
|
||||
{
|
||||
using type = typename tuple_concat<typename tuple_array_impl<T, N / 2>::type,
|
||||
typename tuple_array_impl<T, N - N / 2>::type>::type;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct tuple_array_impl<T, 0>
|
||||
{
|
||||
using type = tuple<>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct tuple_array_impl<T, 1>
|
||||
{
|
||||
using type = tuple<T>;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <typename T, index_t N>
|
||||
using tuple_array_base_t = typename detail::tuple_array_impl<T, N>::type;
|
||||
|
||||
template <typename T_, index_t N_>
|
||||
struct tuple_array : tuple_array_base_t<T_, N_>
|
||||
{
|
||||
using value_type = T_;
|
||||
static constexpr index_t N = N_;
|
||||
|
||||
// clang-format off
|
||||
#define TA_COM_() static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() { TA_COM_(); return reinterpret_cast<tuple_array<Tx, vx>&>(*this); }
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() const { TA_COM_(); return reinterpret_cast<const tuple_array<Tx, vx>&>(*this); }
|
||||
|
||||
// below index is for index *AFTER* type convert, not before
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) { TA_COM_(); return reinterpret_cast<tuple_array<Tx, vx>&>(*this).at(i); }
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) const { TA_COM_(); return reinterpret_cast<const tuple_array<Tx, vx>&>(*this).at(i); }
|
||||
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>) { TA_COM_(); return reinterpret_cast<tuple_array<Tx, vx>&>(*this).at(number<I>{}); }
|
||||
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>) const { TA_COM_(); return reinterpret_cast<const tuple_array<Tx, vx>&>(*this).at(number<I>{}); }
|
||||
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x) { TA_COM_(); reinterpret_cast<tuple_array<Tx, vx>&>(*this).at(i) = x; }
|
||||
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x) { TA_COM_(); reinterpret_cast<tuple_array<Tx, vx>&>(*this).at(number<I>{}) = x; }
|
||||
#undef TA_COM_
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -43,7 +43,7 @@ struct alignas(2) half_t
|
||||
constexpr fp16_hip_t to_fp16() const { return ck_tile::bit_cast<fp16_hip_t>(data); }
|
||||
|
||||
// constructor
|
||||
constexpr half_t() : data() {}
|
||||
constexpr half_t() : data{} {}
|
||||
|
||||
// construct from HIP half
|
||||
CK_TILE_HOST_DEVICE
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/container/rbuffer.hpp"
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -72,7 +72,8 @@ struct static_distributed_tensor
|
||||
constexpr auto sliced_thread_tensor_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...));
|
||||
|
||||
rbuffer<DataType, sliced_thread_tensor_desc.get_element_space_size()> sliced_thread_data;
|
||||
thread_buffer<DataType, sliced_thread_tensor_desc.get_element_space_size()>
|
||||
sliced_thread_data;
|
||||
|
||||
static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
|
||||
constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};
|
||||
@@ -84,11 +85,10 @@ struct static_distributed_tensor
|
||||
return sliced_thread_data;
|
||||
}
|
||||
|
||||
template <index_t... YSliceOrigins, index_t... YSliceLengths, index_t NSlicedData>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
set_y_sliced_thread_data(sequence<YSliceOrigins...>,
|
||||
sequence<YSliceLengths...>,
|
||||
const rbuffer<DataType, NSlicedData>& sliced_thread_data)
|
||||
template <index_t... YSliceOrigins, index_t... YSliceLengths, typename SlicedThreadData>
|
||||
CK_TILE_HOST_DEVICE void set_y_sliced_thread_data(sequence<YSliceOrigins...>,
|
||||
sequence<YSliceLengths...>,
|
||||
const SlicedThreadData& sliced_thread_data)
|
||||
{
|
||||
static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
|
||||
sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
|
||||
@@ -130,7 +130,7 @@ struct static_distributed_tensor
|
||||
}
|
||||
|
||||
//
|
||||
rbuffer<DataType, kThreadElementSpaceSize> thread_buf_;
|
||||
thread_buffer<DataType, kThreadElementSpaceSize> thread_buf_;
|
||||
};
|
||||
|
||||
template <typename DataType, typename StaticTileDistribution>
|
||||
@@ -140,6 +140,14 @@ CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTi
|
||||
remove_cvref_t<StaticTileDistribution>>{};
|
||||
}
|
||||
|
||||
template <typename DataType, typename StaticTileDistribution, typename ThreadBuffer>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution&,
|
||||
ThreadBuffer&& thread_buffer_)
|
||||
{
|
||||
return static_distributed_tensor<remove_cvref_t<DataType>,
|
||||
remove_cvref_t<StaticTileDistribution>>{thread_buffer_};
|
||||
}
|
||||
|
||||
// get X indices from tuple of tile_distributed_index<>
|
||||
template <typename StaticTileDistribution, typename DistributedIndices>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
|
||||
@@ -34,20 +34,20 @@ CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc& inout_element
|
||||
}
|
||||
|
||||
template <typename InElementFunc,
|
||||
typename... InDstrTensors,
|
||||
typename... InTensor,
|
||||
typename = std::enable_if_t<
|
||||
std::conjunction_v<std::negation<std::is_same<InDstrTensors, null_tensor>>...>>>
|
||||
std::conjunction_v<std::negation<std::is_same<InTensor, null_tensor>>...>>>
|
||||
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
|
||||
const InDstrTensors&... in_dstr_tensors)
|
||||
const InTensor&... in_dstr_tensors)
|
||||
{
|
||||
using OutDataType = decltype(in_element_func(typename InDstrTensors::DataType{}...));
|
||||
using OutDataType = decltype(in_element_func(typename InTensor::DataType{}...));
|
||||
|
||||
// TODO: make sure all distributed tensors have same lengths and distribution
|
||||
// static_assert(xxx);
|
||||
constexpr auto in_tile_dstr = __type_pack_element<0, InDstrTensors...>::get_tile_distribution();
|
||||
constexpr auto in_tile_dstr = __type_pack_element<0, InTensor...>::get_tile_distribution();
|
||||
|
||||
constexpr index_t thread_buffer_size =
|
||||
__type_pack_element<0, InDstrTensors...>::get_thread_buffer_size();
|
||||
__type_pack_element<0, InTensor...>::get_thread_buffer_size();
|
||||
|
||||
auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
|
||||
|
||||
@@ -107,15 +107,16 @@ CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor)
|
||||
set_tile(dstr_tensor, 0);
|
||||
}
|
||||
|
||||
namespace impl {
|
||||
// TODO: this is ugly
|
||||
template <typename OutDataType, typename InDstrTensors>
|
||||
CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InDstrTensors& in_dstr_tensors)
|
||||
template <typename OutDataType, typename InTensor>
|
||||
CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
// This API is designed to use the _pk_ serious of function
|
||||
constexpr auto in_tile_dstr = InDstrTensors::get_tile_distribution();
|
||||
constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
|
||||
|
||||
constexpr index_t thread_buffer_size = InDstrTensors::get_thread_buffer_size();
|
||||
constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
|
||||
static_assert(thread_buffer_size % 4 == 0);
|
||||
constexpr index_t thread_buffer_size_pk = thread_buffer_size / 4;
|
||||
|
||||
@@ -150,24 +151,90 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InDstrTensors& in_dstr_tensors)
|
||||
return out_dstr_tensor;
|
||||
#else
|
||||
// fallback
|
||||
return tile_elementwise_in(type_convert<OutDataType, typename InDstrTensors::DataType>,
|
||||
return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
|
||||
in_dstr_tensors);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename DstType, typename SrcDstrTensors>
|
||||
CK_TILE_DEVICE auto cast_tile(const SrcDstrTensors& src_tensor)
|
||||
// this function assume either src or dst (or both) date type is under 1 dword
|
||||
// we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy)
|
||||
template <typename OutDataType, typename InTensor>
|
||||
CK_TILE_DEVICE auto cast_tile_opt_subdword(const InTensor& in_dstr_tensors)
|
||||
{
|
||||
constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
|
||||
|
||||
auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
|
||||
|
||||
using i_type = remove_cvref_t<typename InTensor::DataType>;
|
||||
using o_type = remove_cvref_t<OutDataType>;
|
||||
constexpr index_t i_elem_bytes = sizeof(i_type);
|
||||
constexpr index_t o_elem_bytes = sizeof(o_type);
|
||||
static_assert(i_elem_bytes < 4 || o_elem_bytes < 4);
|
||||
|
||||
constexpr index_t bulk_size =
|
||||
(i_elem_bytes >= o_elem_bytes) ? (4 / o_elem_bytes) : (4 / i_elem_bytes);
|
||||
static_assert(bulk_size != 0);
|
||||
|
||||
using o_bulk_type =
|
||||
std::conditional_t<i_elem_bytes >= o_elem_bytes, float, array<o_type, bulk_size>>;
|
||||
|
||||
constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
|
||||
|
||||
constexpr index_t iters = thread_buffer_size / bulk_size;
|
||||
constexpr index_t rems = thread_buffer_size % bulk_size;
|
||||
|
||||
// cast the sequence per-bulk
|
||||
static_for<0, iters, 1>{}([&](auto i) {
|
||||
union bulk_wrapper
|
||||
{
|
||||
o_bulk_type bulk{};
|
||||
o_type data[bulk_size];
|
||||
} o_bulk;
|
||||
|
||||
// TODO: should use below function, but somehow will result in spill (same as c-forloop)
|
||||
// static_for<0, bulk_size, 1>{}([&o_bulk, &in_dstr_tensors, &i](auto ib){
|
||||
// o_bulk.data[ib.value] =
|
||||
// static_cast<o_type>(in_dstr_tensors.get_thread_buffer().template
|
||||
// get_as<i_type>()[number<bulk_size * i.value + ib.value>{}]);
|
||||
// });
|
||||
|
||||
// TODO: fixme, should use above!
|
||||
static_assert(sizeof(i_type) / sizeof(o_type) == 2);
|
||||
o_bulk.data[0] = static_cast<o_type>(
|
||||
in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 0>{}]);
|
||||
o_bulk.data[1] = static_cast<o_type>(
|
||||
in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 1>{}]);
|
||||
|
||||
out_dstr_tensor.get_thread_buffer().template set_as<o_bulk_type>(i, o_bulk.bulk);
|
||||
});
|
||||
|
||||
static_for<0, rems, 1>{}([&](auto r) {
|
||||
// TODO: introducing local scratch pad?
|
||||
auto idx = number<iters * bulk_size + r>{};
|
||||
out_dstr_tensor.get_thread_buffer().at(idx) =
|
||||
static_cast<o_type>(in_dstr_tensors.get_thread_buffer().at(idx));
|
||||
});
|
||||
|
||||
return out_dstr_tensor;
|
||||
}
|
||||
} // namespace impl
|
||||
|
||||
template <typename DstType, typename SrcTensor>
|
||||
CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
|
||||
{
|
||||
if constexpr((std::is_same_v<DstType, fp8_t> ||
|
||||
std::is_same_v<DstType, bf8_t>)&&std::is_same_v<typename SrcDstrTensors::DataType,
|
||||
std::is_same_v<DstType, bf8_t>)&&std::is_same_v<typename SrcTensor::DataType,
|
||||
float> &&
|
||||
(SrcDstrTensors::get_thread_buffer_size() % 4 == 0))
|
||||
(SrcTensor::get_thread_buffer_size() % 4 == 0))
|
||||
{
|
||||
return cast_tile_pk_fp8x4<DstType, SrcDstrTensors>(src_tensor);
|
||||
return impl::cast_tile_pk_fp8x4<DstType, SrcTensor>(src_tensor);
|
||||
}
|
||||
else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4)
|
||||
{
|
||||
return impl::cast_tile_opt_subdword<DstType, SrcTensor>(src_tensor);
|
||||
}
|
||||
else
|
||||
return tile_elementwise_in(type_convert<DstType, typename SrcDstrTensors::DataType>,
|
||||
src_tensor);
|
||||
return tile_elementwise_in(type_convert<DstType, typename SrcTensor::DataType>, src_tensor);
|
||||
}
|
||||
|
||||
// no-op function for null_tensor arguments
|
||||
|
||||
@@ -535,7 +535,8 @@ struct tile_window_with_static_distribution
|
||||
using Traits = load_store_traits;
|
||||
|
||||
// using vector_type_t = typename Traits::vector_type_t;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
// using vector_t = typename Traits::vector_t;
|
||||
using vector_t = thread_buffer<DataType, Traits::ScalarPerVector>;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
@@ -553,10 +554,11 @@ struct tile_window_with_static_distribution
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
|
||||
// TODO: below code may result in spill(?)
|
||||
#if 0
|
||||
// read from distributed tensor
|
||||
// vector_type_t vec;
|
||||
vector_t vec_value;
|
||||
|
||||
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_array(
|
||||
[&](auto jj) {
|
||||
@@ -564,10 +566,8 @@ struct tile_window_with_static_distribution
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
|
||||
|
||||
vec_value.template get_as<DataType>()(j) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
@@ -578,7 +578,16 @@ struct tile_window_with_static_distribution
|
||||
get_bottom_tensor_view()
|
||||
.template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
|
||||
bottom_tensor_thread_coord, vec_value);
|
||||
#else
|
||||
(void)tile_dstr;
|
||||
(void)idx_ys_start;
|
||||
|
||||
get_bottom_tensor_view()
|
||||
.template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
|
||||
bottom_tensor_thread_coord,
|
||||
dstr_tensor.get_thread_buffer().template get_as<vector_t>(
|
||||
number<iCoord * NumAccessPerCoord + iCoordAccess>{}));
|
||||
#endif
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
|
||||
@@ -38,11 +38,7 @@ struct Default2DEpilogue
|
||||
// TODO: this is ugly
|
||||
if constexpr(kPadM || kPadN)
|
||||
{
|
||||
// o_dram_window_tmp.foo();
|
||||
// ODataType{}.foo();
|
||||
// o_acc_tile.foo();
|
||||
auto x = cast_tile<ODataType>(o_acc_tile);
|
||||
store_tile_raw(o_dram_window_tmp, x);
|
||||
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
buffer_store_fence();
|
||||
}
|
||||
else
|
||||
|
||||
Reference in New Issue
Block a user