diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 8a7d0ac887..5c54f6cda2 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -532,49 +532,59 @@ namespace impl{ template CK_TILE_DEVICE void insert_dummy_dep_per_dword(array& b) { - for (auto i = 0; i < b.size(); i++) asm volatile(" " : : "v"(b.get(i)) : "memory"); + static_for<0, b.size(), 1>{}([&](auto i){ + asm volatile(" " : : "v"(b.get(i)) : "memory"); + }); } - +#if 1 +// below specialization just merge size() of dwords into single section template<> CK_TILE_DEVICE void insert_dummy_dep_per_dword<2>(array& b) { - asm volatile(" " : : "v"(b.get(0)), "v"(b.get(1)) : "memory"); + asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})) : "memory"); } template<> CK_TILE_DEVICE void insert_dummy_dep_per_dword<3>(array& b) { - asm volatile(" " : : "v"(b.get(0)), "v"(b.get(1)), "v"(b.get(2)) : "memory"); + asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})) : "memory"); } template<> CK_TILE_DEVICE void insert_dummy_dep_per_dword<4>(array& b) { - asm volatile(" " : : "v"(b.get(0)), "v"(b.get(1)), "v"(b.get(2)), "v"(b.get(3)) : "memory"); + asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})), "v"(b.get(number<3>{})) : "memory"); } template<> CK_TILE_DEVICE void insert_dummy_dep_per_dword<8>(array& b) { - asm volatile(" " : : "v"(b.get(0)), "v"(b.get(1)), "v"(b.get(2)), "v"(b.get(3)), "v"(b.get(4)), "v"(b.get(5)), "v"(b.get(6)), "v"(b.get(7)) : "memory"); + asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})), "v"(b.get(number<3>{})), + "v"(b.get(number<4>{})), "v"(b.get(number<5>{})), "v"(b.get(number<6>{})), "v"(b.get(number<7>{})) : "memory"); } template<> CK_TILE_DEVICE void insert_dummy_dep_per_dword<16>(array& b) { - asm volatile(" 
" : : "v"(b.get(0)), "v"(b.get(1)), "v"(b.get(2)), "v"(b.get(3)), "v"(b.get(4)), "v"(b.get(5)), "v"(b.get(6)), "v"(b.get(7)), - "v"(b.get(8)), "v"(b.get(9)), "v"(b.get(10)), "v"(b.get(11)), "v"(b.get(12)), "v"(b.get(13)), "v"(b.get(14)), "v"(b.get(15)) : "memory"); + asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})), "v"(b.get(number<3>{})), + "v"(b.get(number<4>{})), "v"(b.get(number<5>{})), "v"(b.get(number<6>{})), "v"(b.get(number<7>{})), + "v"(b.get(number<8>{})), "v"(b.get(number<9>{})), "v"(b.get(number<10>{})), "v"(b.get(number<11>{})), + "v"(b.get(number<12>{})), "v"(b.get(number<13>{})), "v"(b.get(number<14>{})), "v"(b.get(number<15>{})) : "memory"); } template<> CK_TILE_DEVICE void insert_dummy_dep_per_dword<32>(array& b) { - asm volatile(" " : : "v"(b.get(0)), "v"(b.get(1)), "v"(b.get(2)), "v"(b.get(3)), "v"(b.get(4)), "v"(b.get(5)), "v"(b.get(6)), "v"(b.get(7)), - "v"(b.get(8)), "v"(b.get(9)), "v"(b.get(10)), "v"(b.get(11)), "v"(b.get(12)), "v"(b.get(13)), "v"(b.get(14)), "v"(b.get(15)), - "v"(b.get(16)), "v"(b.get(17)), "v"(b.get(18)), "v"(b.get(19)), "v"(b.get(20)), "v"(b.get(21)), "v"(b.get(22)), "v"(b.get(23)), - "v"(b.get(24)), "v"(b.get(25)), "v"(b.get(26)), "v"(b.get(27)), "v"(b.get(28)), "v"(b.get(29)), "v"(b.get(30)), "v"(b.get(31)) : "memory"); + asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})), "v"(b.get(number<3>{})), + "v"(b.get(number<4>{})), "v"(b.get(number<5>{})), "v"(b.get(number<6>{})), "v"(b.get(number<7>{})), + "v"(b.get(number<8>{})), "v"(b.get(number<9>{})), "v"(b.get(number<10>{})), "v"(b.get(number<11>{})), + "v"(b.get(number<12>{})), "v"(b.get(number<13>{})), "v"(b.get(number<14>{})), "v"(b.get(number<15>{})), + "v"(b.get(number<16>{})), "v"(b.get(number<17>{})), "v"(b.get(number<18>{})), "v"(b.get(number<19>{})), + "v"(b.get(number<20>{})), "v"(b.get(number<21>{})), "v"(b.get(number<22>{})), "v"(b.get(number<23>{})), + 
"v"(b.get(number<24>{})), "v"(b.get(number<25>{})), "v"(b.get(number<26>{})), "v"(b.get(number<27>{})), + "v"(b.get(number<28>{})), "v"(b.get(number<29>{})), "v"(b.get(number<30>{})), "v"(b.get(number<31>{})) : "memory"); } - +#endif CK_TILE_DEVICE void insert_dummy_dep() {} template @@ -876,14 +886,15 @@ enum struct amd_buffer_coherence_enum template -CK_TILE_DEVICE array amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource, - index_t src_thread_addr_offset, - index_t src_wave_addr_offset) +CK_TILE_DEVICE thread_buffer +amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset) { static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, "wrong! not implemented"); - using rtn_type = array; + using rtn_type = thread_buffer; if constexpr(N == 1) { @@ -939,7 +950,7 @@ CK_TILE_DEVICE array amd_buffer_load_impl_with_bytes(int32x4_t src_wa src_thread_addr_offset, src_wave_addr_offset + 4 * sizeof(int32_t), static_cast(coherence)); - array tmp; + thread_buffer tmp; tmp.template get_as()(number<0>{}) = tmp0; tmp.template get_as()(number<1>{}) = tmp1; @@ -968,7 +979,7 @@ CK_TILE_DEVICE array amd_buffer_load_impl_with_bytes(int32x4_t src_wa src_wave_addr_offset + 12 * sizeof(int32_t), static_cast(coherence)); - array tmp; + thread_buffer tmp; tmp.template get_as()(number<0>{}) = tmp0; tmp.template get_as()(number<1>{}) = tmp1; @@ -986,9 +997,9 @@ CK_TILE_DEVICE array amd_buffer_load_impl_with_bytes(int32x4_t src_wa template -CK_TILE_DEVICE array amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, - index_t src_thread_addr_offset, - index_t src_wave_addr_offset) +CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset) { static_assert( (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || @@ -1002,7 +1013,7 @@ CK_TILE_DEVICE array 
amd_buffer_load_impl(int32x4_t src_wave_buffer_resour (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); - using rtn_type = array; + using rtn_type = thread_buffer; if constexpr(std::is_same::value) // fp32 { @@ -1032,7 +1043,7 @@ CK_TILE_DEVICE array amd_buffer_load_impl(int32x4_t src_wave_buffer_resour } else if constexpr(N == 8) { - array tmp; + thread_buffer tmp; tmp.template get_as()(number<0>{}) = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, @@ -1050,7 +1061,7 @@ CK_TILE_DEVICE array amd_buffer_load_impl(int32x4_t src_wave_buffer_resour } else if constexpr(N == 16) { - array tmp; + thread_buffer tmp; tmp.template get_as()(number<0>{}) = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, @@ -1165,7 +1176,7 @@ template -CK_TILE_DEVICE void amd_buffer_load_raw_impl(array& dst, +CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer& dst, int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset, @@ -1175,7 +1186,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(array& dst, static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16, "wrong! 
not supported by buffer_load instruction"); - using type = array; + using type = thread_buffer; if constexpr(oob_conditional_check) { buffer_load_if{}( @@ -1208,7 +1219,7 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem, template -CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const array src_thread_data, +CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer src_thread_data, int32x4_t dst_wave_buffer_resource, index_t dst_thread_addr_offset, index_t dst_wave_addr_offset) @@ -1308,7 +1319,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const array src_ template -CK_TILE_DEVICE void amd_buffer_store_impl(const array src_thread_data, +CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer src_thread_data, int32x4_t dst_wave_buffer_resource, index_t dst_thread_addr_offset, index_t dst_wave_addr_offset) @@ -1396,7 +1407,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array src_thread_data, else if constexpr(N == 8) { #if 0 - array tmp{src_thread_data}; + thread_buffer tmp{src_thread_data}; llvm_amdgcn_raw_buffer_store_fp16x4(tmp.template get_as()[number<0>{}], dst_wave_buffer_resource, @@ -1463,7 +1474,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array src_thread_data, } else { - using r_t = array; + using r_t = thread_buffer; amd_buffer_store_impl_with_bytes(bit_cast(src_thread_data), dst_wave_buffer_resource, @@ -1486,7 +1497,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer& dst_thr static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16, "wrong! 
not supported by buffer_store instruction"); - using type = array; + using type = thread_buffer; if constexpr(oob_conditional_check) { buffer_store_if{}(dst_thread_data, @@ -1507,7 +1518,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer& dst_thr } template -CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const array& src_thread_data, +CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer& src_thread_data, int32x4_t dst_wave_buffer_resource, index_t dst_thread_addr_offset, index_t dst_wave_addr_offset) @@ -1667,7 +1678,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const array& src_thread_dat } template -CK_TILE_DEVICE void amd_buffer_atomic_max_impl(const array src_thread_data, +CK_TILE_DEVICE void amd_buffer_atomic_max_impl(const thread_buffer src_thread_data, int32x4_t dst_wave_buffer_resource, index_t dst_thread_addr_offset, index_t dst_wave_addr_offset) @@ -1742,7 +1753,7 @@ template -CK_TILE_DEVICE array +CK_TILE_DEVICE thread_buffer amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, index_t src_thread_element_offset, bool src_thread_element_valid, @@ -1763,10 +1774,10 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, return amd_buffer_load_impl( src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); #else - array tmp = + thread_buffer tmp = amd_buffer_load_impl(src_wave_buffer_resource, src_thread_addr_offset, 0); if constexpr(oob_conditional_check) - return src_thread_element_valid ? tmp : array{0}; + return src_thread_element_valid ? 
tmp : thread_buffer{0}; else return tmp; #endif @@ -1780,7 +1791,7 @@ template -CK_TILE_DEVICE array +CK_TILE_DEVICE thread_buffer amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, index_t src_thread_element_offset, bool src_thread_element_valid, @@ -1792,11 +1803,11 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); - array tmp = + thread_buffer tmp = amd_buffer_load_impl(src_wave_buffer_resource, src_thread_addr_offset, 0); if constexpr(oob_conditional_check) - return src_thread_element_valid ? tmp : array{customized_value}; + return src_thread_element_valid ? tmp : thread_buffer{customized_value}; else return tmp; } @@ -1805,7 +1816,7 @@ template -CK_TILE_DEVICE void amd_buffer_load_raw(array& dst, +CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer& dst, const T* p_src_wave, index_t src_thread_element_offset, index_t src_element_space_size, @@ -1849,7 +1860,7 @@ template -CK_TILE_DEVICE void amd_buffer_store(const array& src_thread_data, +CK_TILE_DEVICE void amd_buffer_store(const thread_buffer& src_thread_data, T* p_dst_wave, const index_t dst_thread_element_offset, const bool dst_thread_element_valid, @@ -1913,7 +1924,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer& src_thread_d // 2) p_dst_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. template -CK_TILE_DEVICE void amd_buffer_atomic_add(const array& src_thread_data, +CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer& src_thread_data, T* p_dst_wave, const index_t dst_thread_element_offset, const bool dst_thread_element_valid, @@ -1943,7 +1954,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add(const array& src_thread_data, // 2) p_dst_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. 
template -CK_TILE_DEVICE void amd_buffer_atomic_max(const array& src_thread_data, +CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer& src_thread_data, T* p_dst_wave, const index_t dst_thread_element_offset, const bool dst_thread_element_valid, diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 965281d98b..4688356ff1 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -49,6 +49,16 @@ #define CK_TILE_THREAD_BUFFER_DEFAULT CK_TILE_THREAD_BUFFER_USE_ARRAY #endif +#ifndef CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST +#if CK_TILE_THREAD_BUFFER_DEFAULT == CK_TILE_THREAD_BUFFER_USE_TUPLE +// if using tuple-array as thread_buffer implementation, need to support {} brace init +// ... with similar behavior as array +#define CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST 1 +#else +#define CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST 0 +#endif +#endif + #ifndef CK_TILE_USE_LAUNCH_BOUNDS #define CK_TILE_USE_LAUNCH_BOUNDS 1 #endif diff --git a/include/ck_tile/core/container/multi_index.hpp b/include/ck_tile/core/container/multi_index.hpp index b78c35a8a5..921c590df8 100644 --- a/include/ck_tile/core/container/multi_index.hpp +++ b/include/ck_tile/core/container/multi_index.hpp @@ -11,7 +11,8 @@ namespace ck_tile { -// deprecated, always use array instead +// Don't use this directly. This is for old CK's internal usage, +// in the future always use array instead template using multi_index = array; diff --git a/include/ck_tile/core/container/tuple.hpp b/include/ck_tile/core/container/tuple.hpp index 6692012d0d..cb8c2c70c6 100644 --- a/include/ck_tile/core/container/tuple.hpp +++ b/include/ck_tile/core/container/tuple.hpp @@ -11,6 +11,7 @@ #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/type_traits.hpp" #include +#include #ifndef CK_TILE_TUPLE_IMPL #define CK_TILE_TUPLE_IMPL 1 @@ -121,6 +122,17 @@ template struct tuple_base, T...> : tuple_object... 
{ CK_TILE_HOST_DEVICE constexpr tuple_base() = default; + +#if CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST +#define _ILE() (std::initializer_list{}.size() - 1) + template + CK_TILE_HOST_DEVICE constexpr tuple_base(std::initializer_list us) + : tuple_object(static_cast(*(us.begin() + (I >= _ILE() ? _ILE() : I))))... + { + } +#undef _ILE +#endif + #if CK_TILE_TUPLE_IMPL == 0 template CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U&&... u) @@ -182,6 +194,14 @@ struct tuple : impl::tuple_base, T...> static constexpr auto size() { return sizeof...(T); } using base = impl::tuple_base, T...>; CK_TILE_HOST_DEVICE constexpr tuple() = default; + +#if CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST + template + CK_TILE_HOST_DEVICE constexpr tuple(std::initializer_list us) : base(us) + { + } +#endif + #if CK_TILE_TUPLE_IMPL == 0 template CK_TILE_HOST_DEVICE constexpr tuple(U&&... u) : base(std::forward(u)...) diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 0eaddb9947..dc6f482abd 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -91,7 +91,7 @@ struct tile_window_with_static_distribution // using vector_type_t = vector_type_maker_t; // using vector_t = typename vector_type_t::type; - using vector_t = array; + using vector_t = thread_buffer; private: static constexpr auto scalars_per_access_ = [] { diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index c7ebcf9606..cbe37e8769 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -29,3 +29,4 @@ #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +