diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 3877b5ceed..daf5a12d2d 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -15,12 +15,11 @@ #include "ck_tile/core/container/map.hpp" #include "ck_tile/core/container/meta_data_buffer.hpp" #include "ck_tile/core/container/multi_index.hpp" -#include "ck_tile/core/container/rbuffer.hpp" #include "ck_tile/core/container/sequence.hpp" #include "ck_tile/core/container/span.hpp" #include "ck_tile/core/container/statically_indexed_array.hpp" +#include "ck_tile/core/container/thread_buffer.hpp" #include "ck_tile/core/container/tuple.hpp" -#include "ck_tile/core/container/tuple_array.hpp" #include "ck_tile/core/numeric/arithmetic.hpp" #include "ck_tile/core/numeric/bfloat16.hpp" #include "ck_tile/core/numeric/float8.hpp" diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 61ccde3804..8a7d0ac887 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -7,6 +7,7 @@ #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/vector_type.hpp" #include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/container/thread_buffer.hpp" #include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/functional.hpp" @@ -1475,7 +1476,7 @@ template -CK_TILE_DEVICE void amd_buffer_store_raw_impl(const array& dst_thread_data, +CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer& dst_thread_data, int32x4_t dst_wave_buffer_resource, index_t dst_thread_addr_offset, index_t dst_wave_addr_offset, @@ -1889,7 +1890,7 @@ template -CK_TILE_DEVICE void amd_buffer_store_raw(const array& src_thread_data, +CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer& src_thread_data, T* p_dst_wave, const index_t dst_thread_element_offset, const bool dst_thread_element_valid, diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index cfb94ea2f3..86d14dc375 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -51,10 +51,10 @@ #define CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE #endif -#define CK_TILE_RBUFFER_USE_ARRAY 0 -#define CK_TILE_RBUFFER_USE_TUPLE 1 -#ifndef CK_TILE_RBUFFER_DEFAULT -#define CK_TILE_RBUFFER_DEFAULT CK_TILE_RBUFFER_USE_TUPLE +#define CK_TILE_THREAD_BUFFER_USE_ARRAY 0 +#define CK_TILE_THREAD_BUFFER_USE_TUPLE 1 +#ifndef CK_TILE_THREAD_BUFFER_DEFAULT +#define CK_TILE_THREAD_BUFFER_DEFAULT CK_TILE_THREAD_BUFFER_USE_ARRAY #endif #ifndef CK_TILE_USE_LAUNCH_BOUNDS diff --git a/include/ck_tile/core/container/array.hpp b/include/ck_tile/core/container/array.hpp index 7f2041494b..9153d275df 100644 --- a/include/ck_tile/core/container/array.hpp +++ b/include/ck_tile/core/container/array.hpp @@ -157,19 +157,47 @@ struct vector_traits> static constexpr index_t vector_size = N; }; -template -CK_TILE_HOST_DEVICE constexpr auto make_array(Ts&&... xs) +namespace details { +template +struct is_ref_wrapper : std::false_type { - using value_type = remove_cvref_t; - return array{std::forward(xs)...}; +}; +template +struct is_ref_wrapper> : std::true_type +{ +}; + +template +using not_ref_wrapper = std::negation>>; + +template +struct return_type_helper +{ + using type = D; +}; +template +struct return_type_helper : std::common_type +{ + static_assert(std::conjunction_v...>, + "Ts cannot contain reference_wrappers when D is void"); +}; + +template +using return_type = array::type, sizeof...(Ts)>; +} // namespace details + +template +CK_TILE_HOST_DEVICE constexpr details::return_type make_array(Ts&&... ts) +{ + return {std::forward(ts)...}; } -// make empty array -template -CK_TILE_HOST_DEVICE constexpr auto make_array() -{ - return array{}; -} +// // make empty array +// template +// CK_TILE_HOST_DEVICE constexpr auto make_array() +// { +// return array{}; +// } // compatible with old ck's initializer, make an array and fill it withe the last element from // initializer_list diff --git a/include/ck_tile/core/container/rbuffer.hpp b/include/ck_tile/core/container/rbuffer.hpp deleted file mode 100644 index d44eacf2f7..0000000000 --- a/include/ck_tile/core/container/rbuffer.hpp +++ /dev/null @@ -1,20 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core/config.hpp" -#include "ck_tile/core/container/array.hpp" -#include "ck_tile/core/container/tuple_array.hpp" - -namespace ck_tile { - -#if CK_TILE_RBUFFER_DEFAULT == CK_TILE_RBUFFER_USE_TUPLE -template -using rbuffer = tuple_array; -#else -template -using rbuffer = array -#endif - -} // namespace ck_tile diff --git a/include/ck_tile/core/container/statically_indexed_array.hpp b/include/ck_tile/core/container/statically_indexed_array.hpp index 4e1af96cbd..d6da50b627 100644 --- a/include/ck_tile/core/container/statically_indexed_array.hpp +++ b/include/ck_tile/core/container/statically_indexed_array.hpp @@ -5,7 +5,7 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/container/array.hpp" -#include "ck_tile/core/container/tuple_array.hpp" +#include "ck_tile/core/container/tuple.hpp" #include "ck_tile/core/numeric/integer.hpp" namespace ck_tile { diff --git a/include/ck_tile/core/container/thread_buffer.hpp b/include/ck_tile/core/container/thread_buffer.hpp new file mode 100644 index 0000000000..7b8895a953 --- /dev/null +++ b/include/ck_tile/core/container/thread_buffer.hpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/container/tuple.hpp" + +namespace ck_tile { + +#if CK_TILE_THREAD_BUFFER_DEFAULT == CK_TILE_THREAD_BUFFER_USE_TUPLE +template +using thread_buffer = tuple_array; + +template +CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts) +{ + return make_tuple(ts...); +} +#else +template +using thread_buffer = array; + +template +CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts) +{ + return make_array(ts...); +} +#endif + +} // namespace ck_tile diff --git a/include/ck_tile/core/container/tuple.hpp b/include/ck_tile/core/container/tuple.hpp index c146cba9cf..1be5a55dce 100644 --- a/include/ck_tile/core/container/tuple.hpp +++ b/include/ck_tile/core/container/tuple.hpp @@ -12,8 +12,20 @@ #include "ck_tile/core/utility/type_traits.hpp" #include +#ifndef CK_TILE_TUPLE_IMPL +#define CK_TILE_TUPLE_IMPL 1 +#endif + namespace ck_tile { +namespace impl { +template +struct tuple_array_impl; +} + +template +using tuple_array = typename impl::tuple_array_impl::type; + namespace impl { // the place where content is stored @@ -26,37 +38,77 @@ template struct tuple_object { CK_TILE_HOST_DEVICE constexpr tuple_object() {} - CK_TILE_HOST_DEVICE constexpr tuple_object(const T&) {} +#if CK_TILE_TUPLE_IMPL == 0 + template + CK_TILE_HOST_DEVICE constexpr tuple_object(U&&) + { + } + template + CK_TILE_HOST_DEVICE constexpr tuple_object(const U&) + { + } + template + CK_TILE_HOST_DEVICE constexpr tuple_object(U&) + { + } +#elif CK_TILE_TUPLE_IMPL == 1 + template , tuple_object>::value, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr tuple_object(U&&) + { + } +#endif }; template struct tuple_object { CK_TILE_HOST_DEVICE constexpr tuple_object() : element{} {} - CK_TILE_HOST_DEVICE constexpr tuple_object(const T& e) : element(e) {} +#if CK_TILE_TUPLE_IMPL == 0 + template + CK_TILE_HOST_DEVICE constexpr tuple_object(U&& e) : element(std::forward(e)) + { + } + template + CK_TILE_HOST_DEVICE constexpr tuple_object(const U& e) : element(e) + { + } + template + CK_TILE_HOST_DEVICE constexpr tuple_object(U& e) : element(e) + { + } +#elif CK_TILE_TUPLE_IMPL == 1 + template , tuple_object>::value, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr tuple_object(U&& e) : element(std::forward(e)) + { + } +#endif T element; }; // NOTE: we return a instance(not a reference) if content is empty -template +template CK_TILE_HOST_DEVICE constexpr T getv(const tuple_object&) { return {}; } -template +template CK_TILE_HOST_DEVICE constexpr const T& getv(const tuple_object& x) { return x.element; } -template +template CK_TILE_HOST_DEVICE constexpr T& getv(tuple_object& x) { return x.element; } -template +template CK_TILE_HOST_DEVICE constexpr T&& getv(tuple_object&& x) { return static_cast(x.element); @@ -68,18 +120,58 @@ struct tuple_base; template struct tuple_base, T...> : tuple_object... { - CK_TILE_HOST_DEVICE constexpr tuple_base() {} + CK_TILE_HOST_DEVICE constexpr tuple_base() = default; +#if CK_TILE_TUPLE_IMPL == 0 + template + CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U&&... u) + : tuple_object(std::forward(u))... + { + } template CK_TILE_HOST_DEVICE constexpr explicit tuple_base(const U&... u) : tuple_object(u)... { } + template + CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U&... u) : tuple_object(u)... + { + } + + template + CK_TILE_HOST_DEVICE constexpr tuple_base(tuple_base, U...>&& u) + : tuple_object(getv(static_cast&&>(u)))... + { + } + template CK_TILE_HOST_DEVICE constexpr tuple_base(const tuple_base, U...>& u) : tuple_object(getv(static_cast&>(u)))... { } + + template + CK_TILE_HOST_DEVICE constexpr tuple_base(tuple_base, U...>& u) + : tuple_object(getv(static_cast&>(u)))... + { + } +#elif CK_TILE_TUPLE_IMPL == 1 + template , tuple_base>::value, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr tuple_base(U&& u) : tuple_object(std::forward(u))... + { + } + + template = 2, bool>::type = false> + CK_TILE_HOST_DEVICE constexpr tuple_base(U&&... u) : tuple_object(std::forward(u))... + { + static_assert(sizeof...(I) == sizeof...(T) && sizeof...(I) == sizeof...(U), + "wrong! inconsistent size"); + } + +#endif }; } // namespace impl @@ -89,19 +181,56 @@ struct tuple : impl::tuple_base, T...> CK_TILE_HOST_DEVICE static constexpr auto size() { return sizeof...(T); } using base = impl::tuple_base, T...>; - CK_TILE_HOST_DEVICE constexpr tuple() {} + CK_TILE_HOST_DEVICE constexpr tuple() = default; +#if CK_TILE_TUPLE_IMPL == 0 + template + CK_TILE_HOST_DEVICE constexpr tuple(U&&... u) : base(std::forward(u)...) + { + } template CK_TILE_HOST_DEVICE constexpr tuple(const U&... u) : base(u...) { } + template + CK_TILE_HOST_DEVICE constexpr tuple(U&... u) : base(u...) + { + } + + template + CK_TILE_HOST_DEVICE constexpr tuple(tuple&& u) + : base(static_cast, U...>&&>(u)) + { + } + template CK_TILE_HOST_DEVICE constexpr tuple(const tuple& u) : base(static_cast, U...>&>(u)) { } + template + CK_TILE_HOST_DEVICE constexpr tuple(tuple& u) + : base(static_cast, U...>&>(u)) + { + } +#elif CK_TILE_TUPLE_IMPL == 1 + template < + typename U, + typename std::enable_if, tuple>::value, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr tuple(U&& u) : base(std::forward(u)) + { + } + + template = 2, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr tuple(U&&... u) : base(std::forward(u)...) + { + } +#endif CK_TILE_HOST_DEVICE static constexpr bool is_static() { bool flag = true; @@ -128,6 +257,19 @@ struct tuple : impl::tuple_base, T...> template CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number) { TP_COM_(); return get(); } template CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number) const { TP_COM_(); return get(); } template CK_TILE_HOST_DEVICE constexpr decltype(auto) operator()(number) { TP_COM_(); return get(); } // TODO: compatible + + // below function should be used under tuple_array<> type, no extra check will perform here + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() { return reinterpret_cast&>(*this); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() const { return reinterpret_cast&>(*this); } + // below index is for index *AFTER* type convert, not before + //template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) { TP_COM_(); return reinterpret_cast&>(*this).at(i); } + //template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) const { TP_COM_(); return reinterpret_cast&>(*this).at(i); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number) { TP_COM_(); return reinterpret_cast&>(*this).at(number{}); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number) const { TP_COM_(); return reinterpret_cast&>(*this).at(number{}); } + + // template CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x) { TP_COM_(); reinterpret_cast&>(*this).at(i) = x; } + template CK_TILE_HOST_DEVICE constexpr void set_as(number, const Tx & x) { TP_COM_(); reinterpret_cast&>(*this).at(number{}) = x; } + // clang-format on #undef TP_COM_ }; @@ -163,6 +305,15 @@ CK_TILE_HOST_DEVICE constexpr bool operator!=(const tuple& a, const tuple template CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs&&... xs) { + // here xs is always a lvalue as function arg + // Xs may deduced as (e.g try to pass in a integer in following cases) + // 1). if pass in a rvalue (like function return or int{}) -> Xs is "int" + // 2). if pass in a const lvalue -> Xs is "const int &" + // 3). if pass in a non-const lvalue -> Xs is "int &" + // so the return type of std::forward will dependes on Xs + // 1). std::forward -> int&& + // 2). std::forward -> const int& + // 3). std::forward -> int& return tuple...>(std::forward(xs)...); } @@ -182,6 +333,38 @@ struct tuple_concat, tuple> using type = tuple; }; +namespace impl { +// be very careful using this type (because we want the internal type) +// template deduction will fail if infering the inner type +// e.g. +// template using some_wrapper = typename tuple_array_impl::type; +// template void foo(const some_wrapper&) {} +// -> compiler will fail to deduce this type, because this is under non-deduced context +// (https://en.cppreference.com/w/cpp/language/template_argument_deduction, "Non-deduced +// contexts") +// +// -> use this instead +// template void foo(const Tup&) {} +template +struct tuple_array_impl +{ + using type = typename tuple_concat::type, + typename tuple_array_impl::type>::type; +}; + +template +struct tuple_array_impl +{ + using type = tuple<>; +}; + +template +struct tuple_array_impl +{ + using type = tuple; +}; +} // namespace impl + template CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F&& f, number) { diff --git a/include/ck_tile/core/container/tuple_array.hpp b/include/ck_tile/core/container/tuple_array.hpp deleted file mode 100644 index 27fe9d8f5d..0000000000 --- a/include/ck_tile/core/container/tuple_array.hpp +++ /dev/null @@ -1,60 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core/config.hpp" -#include "ck_tile/core/container/array.hpp" -#include "ck_tile/core/container/tuple.hpp" -#include "ck_tile/core/numeric/integer.hpp" - -namespace ck_tile { - -namespace detail { -template -struct tuple_array_impl -{ - using type = typename tuple_concat::type, - typename tuple_array_impl::type>::type; -}; - -template -struct tuple_array_impl -{ - using type = tuple<>; -}; - -template -struct tuple_array_impl -{ - using type = tuple; -}; -} // namespace detail - -template -using tuple_array_base_t = typename detail::tuple_array_impl::type; - -template -struct tuple_array : tuple_array_base_t -{ - using value_type = T_; - static constexpr index_t N = N_; - -// clang-format off -#define TA_COM_() static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); constexpr int vx = sizeof(value_type) * N / sizeof(Tx) - template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() { TA_COM_(); return reinterpret_cast&>(*this); } - template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() const { TA_COM_(); return reinterpret_cast&>(*this); } - - // below index is for index *AFTER* type convert, not before - template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) { TA_COM_(); return reinterpret_cast&>(*this).at(i); } - template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) const { TA_COM_(); return reinterpret_cast&>(*this).at(i); } - template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number) { TA_COM_(); return reinterpret_cast&>(*this).at(number{}); } - template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number) const { TA_COM_(); return reinterpret_cast&>(*this).at(number{}); } - - template CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x) { TA_COM_(); reinterpret_cast&>(*this).at(i) = x; } - template CK_TILE_HOST_DEVICE constexpr void set_as(number, const Tx & x) { TA_COM_(); reinterpret_cast&>(*this).at(number{}) = x; } -#undef TA_COM_ - // clang-format on -}; - -} // namespace ck_tile diff --git a/include/ck_tile/core/numeric/half.hpp b/include/ck_tile/core/numeric/half.hpp index 5ad9e3aacd..4a01a5a985 100644 --- a/include/ck_tile/core/numeric/half.hpp +++ b/include/ck_tile/core/numeric/half.hpp @@ -43,7 +43,7 @@ struct alignas(2) half_t constexpr fp16_hip_t to_fp16() const { return ck_tile::bit_cast(data); } // constructor - constexpr half_t() : data() {} + constexpr half_t() : data{} {} // construct from HIP half CK_TILE_HOST_DEVICE diff --git a/include/ck_tile/core/tensor/static_distributed_tensor.hpp b/include/ck_tile/core/tensor/static_distributed_tensor.hpp index 38f6f21770..299a74bc08 100644 --- a/include/ck_tile/core/tensor/static_distributed_tensor.hpp +++ b/include/ck_tile/core/tensor/static_distributed_tensor.hpp @@ -12,7 +12,7 @@ #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/core/tensor/tile_distribution.hpp" -#include "ck_tile/core/container/rbuffer.hpp" +#include "ck_tile/core/container/thread_buffer.hpp" namespace ck_tile { @@ -72,7 +72,8 @@ struct static_distributed_tensor constexpr auto sliced_thread_tensor_desc = make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...)); - rbuffer sliced_thread_data; + thread_buffer + sliced_thread_data; static_ford>{}([&](auto idx) { constexpr auto idx_ys = idx + sequence{}; @@ -84,11 +85,10 @@ struct static_distributed_tensor return sliced_thread_data; } - template - CK_TILE_HOST_DEVICE void - set_y_sliced_thread_data(sequence, - sequence, - const rbuffer& sliced_thread_data) + template + CK_TILE_HOST_DEVICE void set_y_sliced_thread_data(sequence, + sequence, + const SlicedThreadData& sliced_thread_data) { static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY && sizeof...(YSliceLengths) == StaticTileDistribution::NDimY, @@ -130,7 +130,7 @@ struct static_distributed_tensor } // - rbuffer thread_buf_; + thread_buffer thread_buf_; }; template @@ -140,6 +140,14 @@ CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTi remove_cvref_t>{}; } +template +CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution&, + ThreadBuffer&& thread_buffer_) +{ + return static_distributed_tensor, + remove_cvref_t>{thread_buffer_}; +} + // get X indices from tuple of tile_distributed_index<> template CK_TILE_HOST_DEVICE constexpr auto diff --git a/include/ck_tile/core/tensor/tile_elementwise.hpp b/include/ck_tile/core/tensor/tile_elementwise.hpp index 974cb2ee1e..4d631f90ac 100644 --- a/include/ck_tile/core/tensor/tile_elementwise.hpp +++ b/include/ck_tile/core/tensor/tile_elementwise.hpp @@ -34,20 +34,20 @@ CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc& inout_element } template >...>>> + std::conjunction_v>...>>> CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func, - const InDstrTensors&... in_dstr_tensors) + const InTensor&... in_dstr_tensors) { - using OutDataType = decltype(in_element_func(typename InDstrTensors::DataType{}...)); + using OutDataType = decltype(in_element_func(typename InTensor::DataType{}...)); // TODO: make sure all distributed tensors have same lengths and distribution // static_assert(xxx); - constexpr auto in_tile_dstr = __type_pack_element<0, InDstrTensors...>::get_tile_distribution(); + constexpr auto in_tile_dstr = __type_pack_element<0, InTensor...>::get_tile_distribution(); constexpr index_t thread_buffer_size = - __type_pack_element<0, InDstrTensors...>::get_thread_buffer_size(); + __type_pack_element<0, InTensor...>::get_thread_buffer_size(); auto out_dstr_tensor = make_static_distributed_tensor(in_tile_dstr); @@ -107,15 +107,16 @@ CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor) set_tile(dstr_tensor, 0); } +namespace impl { // TODO: this is ugly -template -CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InDstrTensors& in_dstr_tensors) +template +CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors) { #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // This API is designed to use the _pk_ serious of function - constexpr auto in_tile_dstr = InDstrTensors::get_tile_distribution(); + constexpr auto in_tile_dstr = InTensor::get_tile_distribution(); - constexpr index_t thread_buffer_size = InDstrTensors::get_thread_buffer_size(); + constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size(); static_assert(thread_buffer_size % 4 == 0); constexpr index_t thread_buffer_size_pk = thread_buffer_size / 4; @@ -150,24 +151,90 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InDstrTensors& in_dstr_tensors) return out_dstr_tensor; #else // fallback - return tile_elementwise_in(type_convert, + return tile_elementwise_in(type_convert, in_dstr_tensors); #endif } -template -CK_TILE_DEVICE auto cast_tile(const SrcDstrTensors& src_tensor) +// this function assume either src or dst (or both) date type is under 1 dword +// we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy) +template +CK_TILE_DEVICE auto cast_tile_opt_subdword(const InTensor& in_dstr_tensors) +{ + constexpr auto in_tile_dstr = InTensor::get_tile_distribution(); + + auto out_dstr_tensor = make_static_distributed_tensor(in_tile_dstr); + + using i_type = remove_cvref_t; + using o_type = remove_cvref_t; + constexpr index_t i_elem_bytes = sizeof(i_type); + constexpr index_t o_elem_bytes = sizeof(o_type); + static_assert(i_elem_bytes < 4 || o_elem_bytes < 4); + + constexpr index_t bulk_size = + (i_elem_bytes >= o_elem_bytes) ? (4 / o_elem_bytes) : (4 / i_elem_bytes); + static_assert(bulk_size != 0); + + using o_bulk_type = + std::conditional_t= o_elem_bytes, float, array>; + + constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size(); + + constexpr index_t iters = thread_buffer_size / bulk_size; + constexpr index_t rems = thread_buffer_size % bulk_size; + + // cast the sequence per-bulk + static_for<0, iters, 1>{}([&](auto i) { + union bulk_wrapper + { + o_bulk_type bulk{}; + o_type data[bulk_size]; + } o_bulk; + + // TODO: should use below function, but somehow will result in spill (same as c-forloop) + // static_for<0, bulk_size, 1>{}([&o_bulk, &in_dstr_tensors, &i](auto ib){ + // o_bulk.data[ib.value] = + // static_cast(in_dstr_tensors.get_thread_buffer().template + // get_as()[number{}]); + // }); + + // TODO: fixme, should use above! + static_assert(sizeof(i_type) / sizeof(o_type) == 2); + o_bulk.data[0] = static_cast( + in_dstr_tensors.get_thread_buffer().template get_as()[number<2 * i + 0>{}]); + o_bulk.data[1] = static_cast( + in_dstr_tensors.get_thread_buffer().template get_as()[number<2 * i + 1>{}]); + + out_dstr_tensor.get_thread_buffer().template set_as(i, o_bulk.bulk); + }); + + static_for<0, rems, 1>{}([&](auto r) { + // TODO: introducing local scratch pad? + auto idx = number{}; + out_dstr_tensor.get_thread_buffer().at(idx) = + static_cast(in_dstr_tensors.get_thread_buffer().at(idx)); + }); + + return out_dstr_tensor; +} +} // namespace impl + +template +CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor) { if constexpr((std::is_same_v || - std::is_same_v)&&std::is_same_v)&&std::is_same_v && - (SrcDstrTensors::get_thread_buffer_size() % 4 == 0)) + (SrcTensor::get_thread_buffer_size() % 4 == 0)) { - return cast_tile_pk_fp8x4(src_tensor); + return impl::cast_tile_pk_fp8x4(src_tensor); + } + else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4) + { + return impl::cast_tile_opt_subdword(src_tensor); } else - return tile_elementwise_in(type_convert, - src_tensor); + return tile_elementwise_in(type_convert, src_tensor); } // no-op function for null_tensor arguments diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 643f6d77ef..cd96671456 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -535,7 +535,8 @@ struct tile_window_with_static_distribution using Traits = load_store_traits; // using vector_type_t = typename Traits::vector_type_t; - using vector_t = typename Traits::vector_t; + // using vector_t = typename Traits::vector_t; + using vector_t = thread_buffer; using SFC_Ys = typename Traits::SFC_Ys; constexpr auto tile_dstr = TileDstr{}; @@ -553,10 +554,11 @@ struct tile_window_with_static_distribution // data index [y0, y1, ...] constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + // TODO: below code may result in spill(?) +#if 0 // read from distributed tensor // vector_type_t vec; vector_t vec_value; - static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { constexpr auto idx_ys = generate_array( [&](auto jj) { @@ -564,10 +566,8 @@ struct tile_window_with_static_distribution : idx_ys_start[jj]; }, number{}); - constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); - vec_value.template get_as()(j) = dstr_tensor.get_thread_buffer().template at(); }); @@ -578,7 +578,16 @@ struct tile_window_with_static_distribution get_bottom_tensor_view() .template set_vectorized_elements_raw( bottom_tensor_thread_coord, vec_value); +#else + (void)tile_dstr; + (void)idx_ys_start; + get_bottom_tensor_view() + .template set_vectorized_elements_raw( + bottom_tensor_thread_coord, + dstr_tensor.get_thread_buffer().template get_as( + number{})); +#endif // move thread coordinate if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) { diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index 2bfbb8b38f..5dc49c3b0e 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -38,11 +38,7 @@ struct Default2DEpilogue // TODO: this is ugly if constexpr(kPadM || kPadN) { - // o_dram_window_tmp.foo(); - // ODataType{}.foo(); - // o_acc_tile.foo(); - auto x = cast_tile(o_acc_tile); - store_tile_raw(o_dram_window_tmp, x); + store_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tile)); buffer_store_fence(); } else