re-structure tuple/array to avoid spill

This commit is contained in:
carlushuang
2024-03-11 15:32:10 +00:00
parent 0bd76de8a6
commit 9f34bcb431
14 changed files with 386 additions and 143 deletions

View File

@@ -15,12 +15,11 @@
#include "ck_tile/core/container/map.hpp"
#include "ck_tile/core/container/meta_data_buffer.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/container/rbuffer.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/span.hpp"
#include "ck_tile/core/container/statically_indexed_array.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/tuple_array.hpp"
#include "ck_tile/core/numeric/arithmetic.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"

View File

@@ -7,6 +7,7 @@
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
@@ -1475,7 +1476,7 @@ template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE void amd_buffer_store_raw_impl(const array<T, N>& dst_thread_data,
CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset,
@@ -1889,7 +1890,7 @@ template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE void amd_buffer_store_raw(const array<T, N>& src_thread_data,
CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_data,
T* p_dst_wave,
const index_t dst_thread_element_offset,
const bool dst_thread_element_valid,

View File

@@ -51,10 +51,10 @@
#define CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE
#endif
#define CK_TILE_RBUFFER_USE_ARRAY 0
#define CK_TILE_RBUFFER_USE_TUPLE 1
#ifndef CK_TILE_RBUFFER_DEFAULT
#define CK_TILE_RBUFFER_DEFAULT CK_TILE_RBUFFER_USE_TUPLE
#define CK_TILE_THREAD_BUFFER_USE_ARRAY 0
#define CK_TILE_THREAD_BUFFER_USE_TUPLE 1
#ifndef CK_TILE_THREAD_BUFFER_DEFAULT
#define CK_TILE_THREAD_BUFFER_DEFAULT CK_TILE_THREAD_BUFFER_USE_ARRAY
#endif
#ifndef CK_TILE_USE_LAUNCH_BOUNDS

View File

@@ -157,19 +157,47 @@ struct vector_traits<array<T, N>>
static constexpr index_t vector_size = N;
};
template <typename T, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_array(Ts&&... xs)
namespace details {
template <class>
struct is_ref_wrapper : std::false_type
{
using value_type = remove_cvref_t<T>;
return array<value_type, sizeof...(Ts)>{std::forward<Ts>(xs)...};
};
template <class T>
struct is_ref_wrapper<std::reference_wrapper<T>> : std::true_type
{
};
template <class T>
using not_ref_wrapper = std::negation<is_ref_wrapper<std::decay_t<T>>>;
template <class D, class...>
struct return_type_helper
{
using type = D;
};
template <class... Ts>
struct return_type_helper<void, Ts...> : std::common_type<Ts...>
{
static_assert(std::conjunction_v<not_ref_wrapper<Ts>...>,
"Ts cannot contain reference_wrappers when D is void");
};
template <class D, class... Ts>
using return_type = array<typename return_type_helper<D, Ts...>::type, sizeof...(Ts)>;
} // namespace details
template <typename D = void, typename... Ts>
CK_TILE_HOST_DEVICE constexpr details::return_type<D, Ts...> make_array(Ts&&... ts)
{
return {std::forward<Ts>(ts)...};
}
// make empty array
template <typename T>
CK_TILE_HOST_DEVICE constexpr auto make_array()
{
return array<T, 0>{};
}
// // make empty array
// template <typename T>
// CK_TILE_HOST_DEVICE constexpr auto make_array()
// {
// return array<T, 0>{};
// }
// compatible with old ck's initializer, make an array and fill it with the last element from
// initializer_list

View File

@@ -1,20 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/tuple_array.hpp"
namespace ck_tile {
// rbuffer<T, N>: per-thread buffer of N elements of T (presumably held in
// registers — TODO confirm). The backing implementation is selected at
// compile time by CK_TILE_RBUFFER_DEFAULT: a homogeneous tuple (tuple_array)
// or a plain array.
#if CK_TILE_RBUFFER_DEFAULT == CK_TILE_RBUFFER_USE_TUPLE
template <typename T, index_t N>
using rbuffer = tuple_array<T, N>;
#else
// NOTE: added the missing ';' — the previous line would fail to compile
// whenever this branch was selected.
template <typename T, index_t N>
using rbuffer = array<T, N>;
#endif
} // namespace ck_tile

View File

@@ -5,7 +5,7 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/tuple_array.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/integer.hpp"
namespace ck_tile {

View File

@@ -0,0 +1,32 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include <utility>
namespace ck_tile {
// thread_buffer<T, N>: per-thread element storage used by tile containers.
// The backing implementation is selected at compile time by
// CK_TILE_THREAD_BUFFER_DEFAULT: a homogeneous tuple (tuple_array, declared
// in tuple.hpp) or a plain array.
#if CK_TILE_THREAD_BUFFER_DEFAULT == CK_TILE_THREAD_BUFFER_USE_TUPLE
template <typename T, index_t N>
using thread_buffer = tuple_array<T, N>;

// Build a thread_buffer from the given values.
// FIX: forward the arguments instead of passing them as lvalues, so that
// rvalue arguments are moved rather than copied.
template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
{
    return make_tuple(std::forward<Ts>(ts)...);
}
#else
template <typename T, index_t N>
using thread_buffer = array<T, N>;

// Build a thread_buffer from the given values (array-backed variant).
// FIX: same perfect-forwarding correction as the tuple branch.
template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
{
    return make_array(std::forward<Ts>(ts)...);
}
#endif
} // namespace ck_tile

View File

@@ -12,8 +12,20 @@
#include "ck_tile/core/utility/type_traits.hpp"
#include <utility>
#ifndef CK_TILE_TUPLE_IMPL
#define CK_TILE_TUPLE_IMPL 1
#endif
namespace ck_tile {
namespace impl {
template <typename T, index_t N>
struct tuple_array_impl;
}
template <typename T, index_t N>
using tuple_array = typename impl::tuple_array_impl<T, N>::type;
namespace impl {
// the place where content is stored
@@ -26,37 +38,77 @@ template <index_t idx, typename T>
struct tuple_object<idx, T, true>
{
CK_TILE_HOST_DEVICE constexpr tuple_object() {}
CK_TILE_HOST_DEVICE constexpr tuple_object(const T&) {}
#if CK_TILE_TUPLE_IMPL == 0
template <typename U>
CK_TILE_HOST_DEVICE constexpr tuple_object(U&&)
{
}
template <typename U>
CK_TILE_HOST_DEVICE constexpr tuple_object(const U&)
{
}
template <typename U>
CK_TILE_HOST_DEVICE constexpr tuple_object(U&)
{
}
#elif CK_TILE_TUPLE_IMPL == 1
template <typename U,
typename std::enable_if<!std::is_same<remove_cvref_t<U>, tuple_object>::value,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr tuple_object(U&&)
{
}
#endif
};
template <index_t idx, typename T>
struct tuple_object<idx, T, false>
{
CK_TILE_HOST_DEVICE constexpr tuple_object() : element{} {}
CK_TILE_HOST_DEVICE constexpr tuple_object(const T& e) : element(e) {}
#if CK_TILE_TUPLE_IMPL == 0
template <typename U>
CK_TILE_HOST_DEVICE constexpr tuple_object(U&& e) : element(std::forward<U>(e))
{
}
template <typename U>
CK_TILE_HOST_DEVICE constexpr tuple_object(const U& e) : element(e)
{
}
template <typename U>
CK_TILE_HOST_DEVICE constexpr tuple_object(U& e) : element(e)
{
}
#elif CK_TILE_TUPLE_IMPL == 1
template <typename U,
typename std::enable_if<!std::is_same<remove_cvref_t<U>, tuple_object>::value,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr tuple_object(U&& e) : element(std::forward<U>(e))
{
}
#endif
T element;
};
// NOTE: we return a instance(not a reference) if content is empty
template <std::size_t I, class T>
template <index_t I, class T>
CK_TILE_HOST_DEVICE constexpr T getv(const tuple_object<I, T, true>&)
{
return {};
}
template <std::size_t I, class T>
template <index_t I, class T>
CK_TILE_HOST_DEVICE constexpr const T& getv(const tuple_object<I, T, false>& x)
{
return x.element;
}
template <std::size_t I, class T>
template <index_t I, class T>
CK_TILE_HOST_DEVICE constexpr T& getv(tuple_object<I, T, false>& x)
{
return x.element;
}
template <std::size_t I, class T>
template <index_t I, class T>
CK_TILE_HOST_DEVICE constexpr T&& getv(tuple_object<I, T, false>&& x)
{
return static_cast<T&&>(x.element);
@@ -68,18 +120,58 @@ struct tuple_base;
template <index_t... I, typename... T>
struct tuple_base<sequence<I...>, T...> : tuple_object<I, T>...
{
CK_TILE_HOST_DEVICE constexpr tuple_base() {}
CK_TILE_HOST_DEVICE constexpr tuple_base() = default;
#if CK_TILE_TUPLE_IMPL == 0
template <class... U>
CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U&&... u)
: tuple_object<I, T>(std::forward<U>(u))...
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr explicit tuple_base(const U&... u) : tuple_object<I, T>(u)...
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U&... u) : tuple_object<I, T>(u)...
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple_base(tuple_base<sequence<I...>, U...>&& u)
: tuple_object<I, T>(getv(static_cast<tuple_object<I, U>&&>(u)))...
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple_base(const tuple_base<sequence<I...>, U...>& u)
: tuple_object<I, T>(getv(static_cast<const tuple_object<I, U>&>(u)))...
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple_base(tuple_base<sequence<I...>, U...>& u)
: tuple_object<I, T>(getv(static_cast<tuple_object<I, U>&>(u)))...
{
}
#elif CK_TILE_TUPLE_IMPL == 1
template <class U,
typename std::enable_if<sizeof...(I) == 1 && sizeof...(T) == 1 &&
!std::is_same<remove_cvref_t<U>, tuple_base>::value,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr tuple_base(U&& u) : tuple_object<I, T>(std::forward<U>(u))...
{
}
template <typename... U, typename std::enable_if<sizeof...(U) >= 2, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr tuple_base(U&&... u) : tuple_object<I, T>(std::forward<U>(u))...
{
static_assert(sizeof...(I) == sizeof...(T) && sizeof...(I) == sizeof...(U),
"wrong! inconsistent size");
}
#endif
};
} // namespace impl
@@ -89,19 +181,56 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
CK_TILE_HOST_DEVICE
static constexpr auto size() { return sizeof...(T); }
using base = impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>;
CK_TILE_HOST_DEVICE constexpr tuple() {}
CK_TILE_HOST_DEVICE constexpr tuple() = default;
#if CK_TILE_TUPLE_IMPL == 0
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple(U&&... u) : base(std::forward<U>(u)...)
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple(const U&... u) : base(u...)
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple(U&... u) : base(u...)
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple(tuple<U...>&& u)
: base(static_cast<impl::tuple_base<make_index_sequence<sizeof...(U)>, U...>&&>(u))
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple(const tuple<U...>& u)
: base(static_cast<const impl::tuple_base<make_index_sequence<sizeof...(U)>, U...>&>(u))
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple(tuple<U...>& u)
: base(static_cast<impl::tuple_base<make_index_sequence<sizeof...(U)>, U...>&>(u))
{
}
#elif CK_TILE_TUPLE_IMPL == 1
template <
typename U,
typename std::enable_if<sizeof...(T) == 1 && !std::is_same<remove_cvref_t<U>, tuple>::value,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr tuple(U&& u) : base(std::forward<U>(u))
{
}
template <typename... U,
typename std::enable_if<sizeof...(U) == sizeof...(T) && sizeof...(U) >= 2,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr tuple(U&&... u) : base(std::forward<U>(u)...)
{
}
#endif
CK_TILE_HOST_DEVICE static constexpr bool is_static()
{
bool flag = true;
@@ -128,6 +257,19 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number<I>) { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number<I>) const { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator()(number<I>) { TP_COM_(); return get<I>(); } // TODO: compatible
// below function should be used under tuple_array<> type, no extra check will perform here
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() { return reinterpret_cast<tuple_array<Tx, size()>&>(*this); }
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() const { return reinterpret_cast<const tuple_array<Tx, size()>&>(*this); }
// below index is for index *AFTER* type convert, not before
//template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) { TP_COM_(); return reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(i); }
//template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) const { TP_COM_(); return reinterpret_cast<const tuple_array<Tx, size()>&>(*this).at(i); }
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>) { TP_COM_(); return reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(number<I>{}); }
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>) const { TP_COM_(); return reinterpret_cast<const tuple_array<Tx, size()>&>(*this).at(number<I>{}); }
// template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x) { TP_COM_(); reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(i) = x; }
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x) { TP_COM_(); reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(number<I>{}) = x; }
// clang-format on
#undef TP_COM_
};
@@ -163,6 +305,15 @@ CK_TILE_HOST_DEVICE constexpr bool operator!=(const tuple<Xs...>& a, const tuple
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs&&... xs)
{
// here xs is always a lvalue as function arg
// Xs may deduced as (e.g try to pass in a integer in following cases)
// 1). if pass in a rvalue (like function return or int{}) -> Xs is "int"
// 2). if pass in a const lvalue -> Xs is "const int &"
// 3). if pass in a non-const lvalue -> Xs is "int &"
// so the return type of std::forward depends on Xs
// 1). std::forward -> int&&
// 2). std::forward -> const int&
// 3). std::forward -> int&
return tuple<remove_cvref_t<Xs>...>(std::forward<Xs>(xs)...);
}
@@ -182,6 +333,38 @@ struct tuple_concat<tuple<Xs...>, tuple<Ys...>>
using type = tuple<Xs..., Ys...>;
};
namespace impl {
// be very careful using this type (because we want the internal type)
// template deduction will fail if infering the inner type
// e.g.
// template<typename T, index_t N> using some_wrapper = typename tuple_array_impl<T, N>::type;
// template<typename T, index_t N> void foo(const some_wrapper<T, N>&) {}
// -> compiler will fail to deduce this type, because this is under non-deduced context
// (https://en.cppreference.com/w/cpp/language/template_argument_deduction, "Non-deduced
// contexts")
//
// -> use this instead
// template<typename Tup> void foo(const Tup&) {}
template <typename T, index_t N>
struct tuple_array_impl
{
using type = typename tuple_concat<typename tuple_array_impl<T, N / 2>::type,
typename tuple_array_impl<T, N - N / 2>::type>::type;
};
template <typename T>
struct tuple_array_impl<T, 0>
{
using type = tuple<>;
};
template <typename T>
struct tuple_array_impl<T, 1>
{
using type = tuple<T>;
};
} // namespace impl
template <typename F, index_t N>
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F&& f, number<N>)
{

View File

@@ -1,60 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/integer.hpp"
namespace ck_tile {
namespace detail {
// Recursively build tuple<T, T, ..., T> (N repetitions) by splitting N in
// half; halving keeps the template recursion depth at O(log N) rather than
// O(N).
template <typename T, index_t N>
struct tuple_array_impl
{
using type = typename tuple_concat<typename tuple_array_impl<T, N / 2>::type,
typename tuple_array_impl<T, N - N / 2>::type>::type;
};
// base case: zero elements -> empty tuple
template <typename T>
struct tuple_array_impl<T, 0>
{
using type = tuple<>;
};
// base case: one element -> single-element tuple
template <typename T>
struct tuple_array_impl<T, 1>
{
using type = tuple<T>;
};
} // namespace detail
// Convenience alias for the homogeneous tuple produced above.
template <typename T, index_t N>
using tuple_array_base_t = typename detail::tuple_array_impl<T, N>::type;
// tuple_array: an array-like container whose storage is a homogeneous tuple.
// The get_as/set_as helpers reinterpret the underlying storage as a
// tuple_array of a different element type Tx (e.g. viewing scalar storage as
// wider vector elements). They rely on reinterpret_cast of the object
// representation — Tx must be layout-compatible with the storage; the macro
// below only checks total size divisibility, nothing more.
template <typename T_, index_t N_>
struct tuple_array : tuple_array_base_t<T_, N_>
{
using value_type = T_;
static constexpr index_t N = N_;
// clang-format off
// TA_COM_ asserts the total byte size is divisible by sizeof(Tx) and
// computes vx = the number of Tx elements spanning the same storage.
#define TA_COM_() static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
// reinterpret the whole buffer as tuple_array<Tx, vx> (mutable / const views)
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() { TA_COM_(); return reinterpret_cast<tuple_array<Tx, vx>&>(*this); }
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() const { TA_COM_(); return reinterpret_cast<const tuple_array<Tx, vx>&>(*this); }
// below index is for index *AFTER* type convert, not before
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) { TA_COM_(); return reinterpret_cast<tuple_array<Tx, vx>&>(*this).at(i); }
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) const { TA_COM_(); return reinterpret_cast<const tuple_array<Tx, vx>&>(*this).at(i); }
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>) { TA_COM_(); return reinterpret_cast<tuple_array<Tx, vx>&>(*this).at(number<I>{}); }
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>) const { TA_COM_(); return reinterpret_cast<const tuple_array<Tx, vx>&>(*this).at(number<I>{}); }
// write a Tx value at Tx-typed index i (runtime / compile-time variants)
template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x) { TA_COM_(); reinterpret_cast<tuple_array<Tx, vx>&>(*this).at(i) = x; }
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x) { TA_COM_(); reinterpret_cast<tuple_array<Tx, vx>&>(*this).at(number<I>{}) = x; }
#undef TA_COM_
// clang-format on
};
} // namespace ck_tile

View File

@@ -43,7 +43,7 @@ struct alignas(2) half_t
constexpr fp16_hip_t to_fp16() const { return ck_tile::bit_cast<fp16_hip_t>(data); }
// constructor
constexpr half_t() : data() {}
constexpr half_t() : data{} {}
// construct from HIP half
CK_TILE_HOST_DEVICE

View File

@@ -12,7 +12,7 @@
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/container/rbuffer.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
namespace ck_tile {
@@ -72,7 +72,8 @@ struct static_distributed_tensor
constexpr auto sliced_thread_tensor_desc =
make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...));
rbuffer<DataType, sliced_thread_tensor_desc.get_element_space_size()> sliced_thread_data;
thread_buffer<DataType, sliced_thread_tensor_desc.get_element_space_size()>
sliced_thread_data;
static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};
@@ -84,11 +85,10 @@ struct static_distributed_tensor
return sliced_thread_data;
}
template <index_t... YSliceOrigins, index_t... YSliceLengths, index_t NSlicedData>
CK_TILE_HOST_DEVICE void
set_y_sliced_thread_data(sequence<YSliceOrigins...>,
sequence<YSliceLengths...>,
const rbuffer<DataType, NSlicedData>& sliced_thread_data)
template <index_t... YSliceOrigins, index_t... YSliceLengths, typename SlicedThreadData>
CK_TILE_HOST_DEVICE void set_y_sliced_thread_data(sequence<YSliceOrigins...>,
sequence<YSliceLengths...>,
const SlicedThreadData& sliced_thread_data)
{
static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
@@ -130,7 +130,7 @@ struct static_distributed_tensor
}
//
rbuffer<DataType, kThreadElementSpaceSize> thread_buf_;
thread_buffer<DataType, kThreadElementSpaceSize> thread_buf_;
};
template <typename DataType, typename StaticTileDistribution>
@@ -140,6 +140,14 @@ CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTi
remove_cvref_t<StaticTileDistribution>>{};
}
template <typename DataType, typename StaticTileDistribution, typename ThreadBuffer>
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution&,
ThreadBuffer&& thread_buffer_)
{
return static_distributed_tensor<remove_cvref_t<DataType>,
remove_cvref_t<StaticTileDistribution>>{thread_buffer_};
}
// get X indices from tuple of tile_distributed_index<>
template <typename StaticTileDistribution, typename DistributedIndices>
CK_TILE_HOST_DEVICE constexpr auto

View File

@@ -34,20 +34,20 @@ CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc& inout_element
}
template <typename InElementFunc,
typename... InDstrTensors,
typename... InTensor,
typename = std::enable_if_t<
std::conjunction_v<std::negation<std::is_same<InDstrTensors, null_tensor>>...>>>
std::conjunction_v<std::negation<std::is_same<InTensor, null_tensor>>...>>>
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
const InDstrTensors&... in_dstr_tensors)
const InTensor&... in_dstr_tensors)
{
using OutDataType = decltype(in_element_func(typename InDstrTensors::DataType{}...));
using OutDataType = decltype(in_element_func(typename InTensor::DataType{}...));
// TODO: make sure all distributed tensors have same lengths and distribution
// static_assert(xxx);
constexpr auto in_tile_dstr = __type_pack_element<0, InDstrTensors...>::get_tile_distribution();
constexpr auto in_tile_dstr = __type_pack_element<0, InTensor...>::get_tile_distribution();
constexpr index_t thread_buffer_size =
__type_pack_element<0, InDstrTensors...>::get_thread_buffer_size();
__type_pack_element<0, InTensor...>::get_thread_buffer_size();
auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
@@ -107,15 +107,16 @@ CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor)
set_tile(dstr_tensor, 0);
}
namespace impl {
// TODO: this is ugly
template <typename OutDataType, typename InDstrTensors>
CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InDstrTensors& in_dstr_tensors)
template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// This API is designed to use the _pk_ series of functions
constexpr auto in_tile_dstr = InDstrTensors::get_tile_distribution();
constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
constexpr index_t thread_buffer_size = InDstrTensors::get_thread_buffer_size();
constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
static_assert(thread_buffer_size % 4 == 0);
constexpr index_t thread_buffer_size_pk = thread_buffer_size / 4;
@@ -150,24 +151,90 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InDstrTensors& in_dstr_tensors)
return out_dstr_tensor;
#else
// fallback
return tile_elementwise_in(type_convert<OutDataType, typename InDstrTensors::DataType>,
return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
in_dstr_tensors);
#endif
}
template <typename DstType, typename SrcDstrTensors>
CK_TILE_DEVICE auto cast_tile(const SrcDstrTensors& src_tensor)
// this function assumes either src or dst (or both) data type is under 1 dword
// we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy)
template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_opt_subdword(const InTensor& in_dstr_tensors)
{
constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
using i_type = remove_cvref_t<typename InTensor::DataType>;
using o_type = remove_cvref_t<OutDataType>;
constexpr index_t i_elem_bytes = sizeof(i_type);
constexpr index_t o_elem_bytes = sizeof(o_type);
static_assert(i_elem_bytes < 4 || o_elem_bytes < 4);
constexpr index_t bulk_size =
(i_elem_bytes >= o_elem_bytes) ? (4 / o_elem_bytes) : (4 / i_elem_bytes);
static_assert(bulk_size != 0);
using o_bulk_type =
std::conditional_t<i_elem_bytes >= o_elem_bytes, float, array<o_type, bulk_size>>;
constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
constexpr index_t iters = thread_buffer_size / bulk_size;
constexpr index_t rems = thread_buffer_size % bulk_size;
// cast the sequence per-bulk
static_for<0, iters, 1>{}([&](auto i) {
union bulk_wrapper
{
o_bulk_type bulk{};
o_type data[bulk_size];
} o_bulk;
// TODO: should use below function, but somehow will result in spill (same as c-forloop)
// static_for<0, bulk_size, 1>{}([&o_bulk, &in_dstr_tensors, &i](auto ib){
// o_bulk.data[ib.value] =
// static_cast<o_type>(in_dstr_tensors.get_thread_buffer().template
// get_as<i_type>()[number<bulk_size * i.value + ib.value>{}]);
// });
// TODO: fixme, should use above!
static_assert(sizeof(i_type) / sizeof(o_type) == 2);
o_bulk.data[0] = static_cast<o_type>(
in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 0>{}]);
o_bulk.data[1] = static_cast<o_type>(
in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 1>{}]);
out_dstr_tensor.get_thread_buffer().template set_as<o_bulk_type>(i, o_bulk.bulk);
});
static_for<0, rems, 1>{}([&](auto r) {
// TODO: introducing local scratch pad?
auto idx = number<iters * bulk_size + r>{};
out_dstr_tensor.get_thread_buffer().at(idx) =
static_cast<o_type>(in_dstr_tensors.get_thread_buffer().at(idx));
});
return out_dstr_tensor;
}
} // namespace impl
template <typename DstType, typename SrcTensor>
CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
{
if constexpr((std::is_same_v<DstType, fp8_t> ||
std::is_same_v<DstType, bf8_t>)&&std::is_same_v<typename SrcDstrTensors::DataType,
std::is_same_v<DstType, bf8_t>)&&std::is_same_v<typename SrcTensor::DataType,
float> &&
(SrcDstrTensors::get_thread_buffer_size() % 4 == 0))
(SrcTensor::get_thread_buffer_size() % 4 == 0))
{
return cast_tile_pk_fp8x4<DstType, SrcDstrTensors>(src_tensor);
return impl::cast_tile_pk_fp8x4<DstType, SrcTensor>(src_tensor);
}
else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4)
{
return impl::cast_tile_opt_subdword<DstType, SrcTensor>(src_tensor);
}
else
return tile_elementwise_in(type_convert<DstType, typename SrcDstrTensors::DataType>,
src_tensor);
return tile_elementwise_in(type_convert<DstType, typename SrcTensor::DataType>, src_tensor);
}
// no-op function for null_tensor arguments

View File

@@ -535,7 +535,8 @@ struct tile_window_with_static_distribution
using Traits = load_store_traits;
// using vector_type_t = typename Traits::vector_type_t;
using vector_t = typename Traits::vector_t;
// using vector_t = typename Traits::vector_t;
using vector_t = thread_buffer<DataType, Traits::ScalarPerVector>;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
@@ -553,10 +554,11 @@ struct tile_window_with_static_distribution
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
// TODO: below code may result in spill(?)
#if 0
// read from distributed tensor
// vector_type_t vec;
vector_t vec_value;
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
[&](auto jj) {
@@ -564,10 +566,8 @@ struct tile_window_with_static_distribution
: idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) =
dstr_tensor.get_thread_buffer().template at<d>();
});
@@ -578,7 +578,16 @@ struct tile_window_with_static_distribution
get_bottom_tensor_view()
.template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
bottom_tensor_thread_coord, vec_value);
#else
(void)tile_dstr;
(void)idx_ys_start;
get_bottom_tensor_view()
.template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
bottom_tensor_thread_coord,
dstr_tensor.get_thread_buffer().template get_as<vector_t>(
number<iCoord * NumAccessPerCoord + iCoordAccess>{}));
#endif
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{

View File

@@ -38,11 +38,7 @@ struct Default2DEpilogue
// TODO: this is ugly
if constexpr(kPadM || kPadN)
{
// o_dram_window_tmp.foo();
// ODataType{}.foo();
// o_acc_tile.foo();
auto x = cast_tile<ODataType>(o_acc_tile);
store_tile_raw(o_dram_window_tmp, x);
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
buffer_store_fence();
}
else