make sure thread_buffer can be tuple/array

This commit is contained in:
carlushuang
2024-03-13 22:03:42 +00:00
parent 616932068d
commit 04762d212b
6 changed files with 89 additions and 46 deletions

View File

@@ -532,49 +532,59 @@ namespace impl{
// Insert a dummy inline-asm data dependency on every dword held in `b`, so the
// compiler cannot reorder or eliminate the surrounding memory operations.
// static_for makes each index a compile-time constant (number<i>), which is
// required once thread_buffer may be tuple-backed (tuples have no runtime
// subscript) — see CK_TILE_THREAD_BUFFER_USE_TUPLE.
// NOTE: the diff rendering had left both the old runtime-index loop and this
// static_for replacement in place; only the compile-time version is kept.
template<index_t N>
CK_TILE_DEVICE void insert_dummy_dep_per_dword(array<float, N>& b)
{
    static_for<0, b.size(), 1>{}([&](auto i) {
        asm volatile(" " : : "v"(b.get(i)) : "memory");
    });
}
#if 1
// The specializations below fold all dwords of the buffer into a single asm
// statement, emitting one scheduling barrier per buffer instead of one per dword.
template<>
CK_TILE_DEVICE void insert_dummy_dep_per_dword<2>(array<float, 2>& b)
{
    // single asm statement covering both dwords; number<I>{} keeps the index
    // compile-time so a tuple-backed thread_buffer works too
    asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})) : "memory");
}
// Merged dummy-dependency for 3 dwords: one asm statement covering the whole
// buffer, indices passed as number<I>{} so they stay compile-time constants.
template<>
CK_TILE_DEVICE void insert_dummy_dep_per_dword<3>(array<float, 3>& b)
{
    asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})) : "memory");
}
// Merged dummy-dependency for 4 dwords: one asm statement covering the whole
// buffer, indices passed as number<I>{} so they stay compile-time constants.
template<>
CK_TILE_DEVICE void insert_dummy_dep_per_dword<4>(array<float, 4>& b)
{
    asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})),
                         "v"(b.get(number<2>{})), "v"(b.get(number<3>{})) : "memory");
}
// Merged dummy-dependency for 8 dwords: one asm statement covering the whole
// buffer, indices passed as number<I>{} so they stay compile-time constants.
template<>
CK_TILE_DEVICE void insert_dummy_dep_per_dword<8>(array<float, 8>& b)
{
    asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})), "v"(b.get(number<3>{})),
                         "v"(b.get(number<4>{})), "v"(b.get(number<5>{})), "v"(b.get(number<6>{})), "v"(b.get(number<7>{})) : "memory");
}
// Merged dummy-dependency for 16 dwords: one asm statement covering the whole
// buffer, indices passed as number<I>{} so they stay compile-time constants.
template<>
CK_TILE_DEVICE void insert_dummy_dep_per_dword<16>(array<float, 16>& b)
{
    asm volatile(" " : : "v"(b.get(number<0>{})),  "v"(b.get(number<1>{})),  "v"(b.get(number<2>{})),  "v"(b.get(number<3>{})),
                         "v"(b.get(number<4>{})),  "v"(b.get(number<5>{})),  "v"(b.get(number<6>{})),  "v"(b.get(number<7>{})),
                         "v"(b.get(number<8>{})),  "v"(b.get(number<9>{})),  "v"(b.get(number<10>{})), "v"(b.get(number<11>{})),
                         "v"(b.get(number<12>{})), "v"(b.get(number<13>{})), "v"(b.get(number<14>{})), "v"(b.get(number<15>{})) : "memory");
}
// Merged dummy-dependency for 32 dwords: one asm statement covering the whole
// buffer, indices passed as number<I>{} so they stay compile-time constants.
template<>
CK_TILE_DEVICE void insert_dummy_dep_per_dword<32>(array<float, 32>& b)
{
    asm volatile(" " : : "v"(b.get(number<0>{})),  "v"(b.get(number<1>{})),  "v"(b.get(number<2>{})),  "v"(b.get(number<3>{})),
                         "v"(b.get(number<4>{})),  "v"(b.get(number<5>{})),  "v"(b.get(number<6>{})),  "v"(b.get(number<7>{})),
                         "v"(b.get(number<8>{})),  "v"(b.get(number<9>{})),  "v"(b.get(number<10>{})), "v"(b.get(number<11>{})),
                         "v"(b.get(number<12>{})), "v"(b.get(number<13>{})), "v"(b.get(number<14>{})), "v"(b.get(number<15>{})),
                         "v"(b.get(number<16>{})), "v"(b.get(number<17>{})), "v"(b.get(number<18>{})), "v"(b.get(number<19>{})),
                         "v"(b.get(number<20>{})), "v"(b.get(number<21>{})), "v"(b.get(number<22>{})), "v"(b.get(number<23>{})),
                         "v"(b.get(number<24>{})), "v"(b.get(number<25>{})), "v"(b.get(number<26>{})), "v"(b.get(number<27>{})),
                         "v"(b.get(number<28>{})), "v"(b.get(number<29>{})), "v"(b.get(number<30>{})), "v"(b.get(number<31>{})) : "memory");
}
#endif
CK_TILE_DEVICE void insert_dummy_dep() {}
template<typename T>
@@ -876,14 +886,15 @@ enum struct amd_buffer_coherence_enum
template <index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE array<int8_t, N> amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset)
CK_TILE_DEVICE thread_buffer<int8_t, N>
amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset)
{
static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
"wrong! not implemented");
using rtn_type = array<int8_t, N>;
using rtn_type = thread_buffer<int8_t, N>;
if constexpr(N == 1)
{
@@ -939,7 +950,7 @@ CK_TILE_DEVICE array<int8_t, N> amd_buffer_load_impl_with_bytes(int32x4_t src_wa
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(int32_t),
static_cast<index_t>(coherence));
array<int32_t, 8> tmp;
thread_buffer<int32_t, 8> tmp;
tmp.template get_as<int32x4_t>()(number<0>{}) = tmp0;
tmp.template get_as<int32x4_t>()(number<1>{}) = tmp1;
@@ -968,7 +979,7 @@ CK_TILE_DEVICE array<int8_t, N> amd_buffer_load_impl_with_bytes(int32x4_t src_wa
src_wave_addr_offset + 12 * sizeof(int32_t),
static_cast<index_t>(coherence));
array<int32_t, 16> tmp;
thread_buffer<int32_t, 16> tmp;
tmp.template get_as<int32x4_t>()(number<0>{}) = tmp0;
tmp.template get_as<int32x4_t>()(number<1>{}) = tmp1;
@@ -986,9 +997,9 @@ CK_TILE_DEVICE array<int8_t, N> amd_buffer_load_impl_with_bytes(int32x4_t src_wa
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE array<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset)
CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset)
{
static_assert(
(std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
@@ -1002,7 +1013,7 @@ CK_TILE_DEVICE array<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffer_resour
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented");
using rtn_type = array<T, N>;
using rtn_type = thread_buffer<T, N>;
if constexpr(std::is_same<T, float>::value) // fp32
{
@@ -1032,7 +1043,7 @@ CK_TILE_DEVICE array<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffer_resour
}
else if constexpr(N == 8)
{
array<float, 8> tmp;
thread_buffer<float, 8> tmp;
tmp.template get_as<fp32x4_t>()(number<0>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
@@ -1050,7 +1061,7 @@ CK_TILE_DEVICE array<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffer_resour
}
else if constexpr(N == 16)
{
array<float, 16> tmp;
thread_buffer<float, 16> tmp;
tmp.template get_as<fp32x4_t>()(number<0>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
@@ -1165,7 +1176,7 @@ template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE void amd_buffer_load_raw_impl(array<T, N>& dst,
CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset,
@@ -1175,7 +1186,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(array<T, N>& dst,
static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16,
"wrong! not supported by buffer_load instruction");
using type = array<T, N>;
using type = thread_buffer<T, N>;
if constexpr(oob_conditional_check)
{
buffer_load_if<sizeof(type)>{}(
@@ -1208,7 +1219,7 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
template <index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const array<int8_t, N> src_thread_data,
CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t, N> src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
@@ -1308,7 +1319,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const array<int8_t, N> src_
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE void amd_buffer_store_impl(const array<T, N> src_thread_data,
CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer<T, N> src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
@@ -1396,7 +1407,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array<T, N> src_thread_data,
else if constexpr(N == 8)
{
#if 0
array<fp16_t, 8> tmp{src_thread_data};
thread_buffer<fp16_t, 8> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.template get_as<fp16x4_t>()[number<0>{}],
dst_wave_buffer_resource,
@@ -1463,7 +1474,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array<T, N> src_thread_data,
}
else
{
using r_t = array<int8_t, sizeof(T) * N>;
using r_t = thread_buffer<int8_t, sizeof(T) * N>;
amd_buffer_store_impl_with_bytes<sizeof(T) * N, coherence>(bit_cast<r_t>(src_thread_data),
dst_wave_buffer_resource,
@@ -1486,7 +1497,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16,
"wrong! not supported by buffer_store instruction");
using type = array<T, N>;
using type = thread_buffer<T, N>;
if constexpr(oob_conditional_check)
{
buffer_store_if<sizeof(type)>{}(dst_thread_data,
@@ -1507,7 +1518,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
}
template <typename T, index_t N>
CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const array<T, N>& src_thread_data,
CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
@@ -1667,7 +1678,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const array<T, N>& src_thread_dat
}
template <typename T, index_t N>
CK_TILE_DEVICE void amd_buffer_atomic_max_impl(const array<T, N> src_thread_data,
CK_TILE_DEVICE void amd_buffer_atomic_max_impl(const thread_buffer<T, N> src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
@@ -1742,7 +1753,7 @@ template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE array<T, N>
CK_TILE_DEVICE thread_buffer<T, N>
amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
index_t src_thread_element_offset,
bool src_thread_element_valid,
@@ -1763,10 +1774,10 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
return amd_buffer_load_impl<T, N, coherence>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#else
array<T, N> tmp =
thread_buffer<T, N> tmp =
amd_buffer_load_impl<T, N, coherence>(src_wave_buffer_resource, src_thread_addr_offset, 0);
if constexpr(oob_conditional_check)
return src_thread_element_valid ? tmp : array<T, N>{0};
return src_thread_element_valid ? tmp : thread_buffer<T, N>{0};
else
return tmp;
#endif
@@ -1780,7 +1791,7 @@ template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE array<T, N>
CK_TILE_DEVICE thread_buffer<T, N>
amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
index_t src_thread_element_offset,
bool src_thread_element_valid,
@@ -1792,11 +1803,11 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
array<T, N> tmp =
thread_buffer<T, N> tmp =
amd_buffer_load_impl<T, N, coherence>(src_wave_buffer_resource, src_thread_addr_offset, 0);
if constexpr(oob_conditional_check)
return src_thread_element_valid ? tmp : array<T, N>{customized_value};
return src_thread_element_valid ? tmp : thread_buffer<T, N>{customized_value};
else
return tmp;
}
@@ -1805,7 +1816,7 @@ template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE void amd_buffer_load_raw(array<T, N>& dst,
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
const T* p_src_wave,
index_t src_thread_element_offset,
index_t src_element_space_size,
@@ -1849,7 +1860,7 @@ template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE void amd_buffer_store(const array<T, N>& src_thread_data,
CK_TILE_DEVICE void amd_buffer_store(const thread_buffer<T, N>& src_thread_data,
T* p_dst_wave,
const index_t dst_thread_element_offset,
const bool dst_thread_element_valid,
@@ -1913,7 +1924,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_d
// 2) p_dst_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t N>
CK_TILE_DEVICE void amd_buffer_atomic_add(const array<T, N>& src_thread_data,
CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_data,
T* p_dst_wave,
const index_t dst_thread_element_offset,
const bool dst_thread_element_valid,
@@ -1943,7 +1954,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add(const array<T, N>& src_thread_data,
// 2) p_dst_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t N>
CK_TILE_DEVICE void amd_buffer_atomic_max(const array<T, N>& src_thread_data,
CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_data,
T* p_dst_wave,
const index_t dst_thread_element_offset,
const bool dst_thread_element_valid,

View File

@@ -49,6 +49,16 @@
#define CK_TILE_THREAD_BUFFER_DEFAULT CK_TILE_THREAD_BUFFER_USE_ARRAY
#endif
// Enable a std::initializer_list constructor on tuple only when tuple is the
// selected thread_buffer backend; can be overridden by defining the macro first.
#ifndef CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST
#if CK_TILE_THREAD_BUFFER_DEFAULT == CK_TILE_THREAD_BUFFER_USE_TUPLE
// if using tuple-array as thread_buffer implementation, need to support {} brace init
// ... with similar behavior as array
#define CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST 1
#else
#define CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST 0
#endif
#endif
// On by default; presumably gates __launch_bounds__ annotations on kernel entry
// points — confirm at the use sites before relying on this.
#ifndef CK_TILE_USE_LAUNCH_BOUNDS
#define CK_TILE_USE_LAUNCH_BOUNDS 1
#endif

View File

@@ -11,7 +11,8 @@
namespace ck_tile {
// Deprecated — don't use this directly. It exists only for old CK's internal
// usage; in the future always use array instead.
template <index_t N>
using multi_index = array<index_t, N>;

View File

@@ -11,6 +11,7 @@
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include <utility>
#include <initializer_list>
#ifndef CK_TILE_TUPLE_IMPL
#define CK_TILE_TUPLE_IMPL 1
@@ -121,6 +122,17 @@ template <index_t... I, typename... T>
struct tuple_base<sequence<I...>, T...> : tuple_object<I, T>...
{
CK_TILE_HOST_DEVICE constexpr tuple_base() = default;
#if CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST
#define _ILE() (std::initializer_list<U>{}.size() - 1)
template <typename U>
CK_TILE_HOST_DEVICE constexpr tuple_base(std::initializer_list<U> us)
: tuple_object<I, T>(static_cast<T>(*(us.begin() + (I >= _ILE() ? _ILE() : I))))...
{
}
#undef _ILE
#endif
#if CK_TILE_TUPLE_IMPL == 0
template <class... U>
CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U&&... u)
@@ -182,6 +194,14 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
static constexpr auto size() { return sizeof...(T); }
using base = impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>;
CK_TILE_HOST_DEVICE constexpr tuple() = default;
#if CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST
template <typename U>
CK_TILE_HOST_DEVICE constexpr tuple(std::initializer_list<U> us) : base(us)
{
}
#endif
#if CK_TILE_TUPLE_IMPL == 0
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple(U&&... u) : base(std::forward<U>(u)...)

View File

@@ -91,7 +91,7 @@ struct tile_window_with_static_distribution
// using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
// using vector_t = typename vector_type_t::type;
using vector_t = array<DataType, ScalarPerVector>;
using vector_t = thread_buffer<DataType, ScalarPerVector>;
private:
static constexpr auto scalars_per_access_ = [] {

View File

@@ -29,3 +29,4 @@
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"