mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
[CK_TILE] fused-moe first version (#1634)
* moe pipeline
* update code
* compile OK
* update
* update cpu reference
* update pipeline_gemm0
* compiler ok
* update pipeline
* rename to ex pipeline
* block-asm
* update
* update
* update first gemm ok
* compute correct
* update file structure
* update README
* update
* update
* update code
* update API
* return unsupport case
* add comment
* update readme
* update
* uncomment
* update
* fix build err
---------
Co-authored-by: valarLip <340077269@qq.com>
[ROCm/composable_kernel commit: 440e28b08f]
This commit is contained in:
@@ -621,6 +621,65 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
|
||||
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void lds_load_fence(index_t cnt = 0)
|
||||
{
|
||||
asm volatile("s_waitcnt lgkmcnt(%0)" : : "n"(cnt) : "memory");
|
||||
}
|
||||
|
||||
template <typename scalar_type, index_t N, bool pre_nop = false>
|
||||
struct buffer_atomic_add_if;
|
||||
|
||||
template <bool pre_nop>
|
||||
struct buffer_atomic_add_if<bf16_t, 2, pre_nop>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(const T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t flag = 1)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
auto save_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = float;
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
|
||||
"global_atomic_pk_add_bf16 %0, %1, %2 offset:%3\n"
|
||||
"s_mov_b64 exec %5"
|
||||
:
|
||||
: "v"(v_offset),
|
||||
"v"(bit_cast<mbuf_t>(value)),
|
||||
"s"(res.xy),
|
||||
"n"(i_offset),
|
||||
"v"(flag),
|
||||
"s"(save_exec)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename scalar_type, index_t N, bool pre_nop = false>
|
||||
struct buffer_atomic_add;
|
||||
|
||||
template <bool pre_nop>
|
||||
struct buffer_atomic_add<bf16_t, 2, pre_nop>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(const T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t /*flag = 1*/)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
using mbuf_t = float;
|
||||
asm volatile("global_atomic_pk_add_bf16 %0, %1, %2 offset:%3"
|
||||
:
|
||||
: "v"(v_offset), "v"(bit_cast<mbuf_t>(value)), "s"(res.xy), "n"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
namespace impl {
|
||||
// below type indicate the data type used for buffer load inline asm
|
||||
// clang-format off
|
||||
@@ -810,6 +869,11 @@ CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0)
|
||||
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto async_load_fence_raw(index_t cnt = 0)
|
||||
{
|
||||
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
|
||||
}
|
||||
|
||||
// buffer load i8
|
||||
CK_TILE_DEVICE_EXTERN int8_t
|
||||
llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
|
||||
@@ -2378,6 +2442,45 @@ CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE void amd_buffer_atomic_add_raw(const thread_buffer<T, N>& src_thread_data,
|
||||
T* p_dst_wave,
|
||||
const index_t dst_thread_element_offset,
|
||||
const index_t dst_linear_element_offset,
|
||||
const bool dst_thread_element_valid,
|
||||
const index_t dst_element_space_size,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
const int32x4_t dst_wave_buffer_resource =
|
||||
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
|
||||
|
||||
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
|
||||
index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T);
|
||||
|
||||
if constexpr(oob_conditional_check)
|
||||
{
|
||||
buffer_atomic_add_if<T, N, pre_nop>{}(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
0,
|
||||
dst_linear_addr_offset,
|
||||
dst_thread_element_valid);
|
||||
}
|
||||
else
|
||||
{
|
||||
buffer_atomic_add<T, N, pre_nop>{}(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
0,
|
||||
dst_linear_addr_offset,
|
||||
1);
|
||||
}
|
||||
}
|
||||
|
||||
// buffer_atomic_max requires:
|
||||
// 1) p_dst_wave must point to global memory
|
||||
// 2) p_dst_wave must be a wavewise pointer.
|
||||
|
||||
@@ -73,6 +73,24 @@ CK_TILE_DEVICE void block_sync_lds()
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
asm volatile("s_wait_loadcnt %0 \n"
|
||||
"s_barrier_signal -1 \n"
|
||||
"s_barrier_wait -1"
|
||||
:
|
||||
: "n"(cnt)
|
||||
: "memory");
|
||||
#else
|
||||
asm volatile("s_waitcnt vmcnt(%0) \n"
|
||||
"s_barrier"
|
||||
:
|
||||
: "n"(cnt)
|
||||
: "memory");
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void block_sync_lds_direct_load()
|
||||
{
|
||||
asm volatile("\
|
||||
|
||||
@@ -102,4 +102,28 @@ CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE auto flag_to_exec(const T& v_flag)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
// per-thread v_flag store into 2x sgpr
|
||||
uint32x2_t exec_flag;
|
||||
asm volatile("v_cmp_ge_u32 %[s_exec_flag], %[v_flag], 1"
|
||||
: [s_exec_flag] "=s"(exec_flag)
|
||||
: [v_flag] "v"(v_flag));
|
||||
return exec_flag;
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
CK_TILE_DEVICE auto cmp_lt_to_exec(const X& x, const Y& y)
|
||||
{
|
||||
static_assert(sizeof(X) == 4 && sizeof(Y) == 4);
|
||||
// per-thread cmp store into 2x sgpr
|
||||
uint32x2_t exec_flag;
|
||||
asm volatile("v_cmp_lt_u32 %[s_exec_flag], %[v_x], %[v_y]"
|
||||
: [s_exec_flag] "=s"(exec_flag)
|
||||
: [v_x] "v"(x), [v_y] "v"(y));
|
||||
return exec_flag;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -437,34 +437,74 @@ struct buffer_view<address_space_enum::global,
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <memory_operation_enum Op,
|
||||
typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
|
||||
CK_TILE_DEVICE void update(index_t i,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
if constexpr(Op == memory_operation_enum::set)
|
||||
{
|
||||
this->template set<X>(i, linear_offset, is_valid_element, x);
|
||||
this->template set<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
|
||||
}
|
||||
else if constexpr(Op == memory_operation_enum::atomic_add)
|
||||
{
|
||||
this->template atomic_add<X>(i, linear_offset, is_valid_element, x);
|
||||
this->template atomic_add<X, oob_conditional_check>(
|
||||
i, linear_offset, is_valid_element, x);
|
||||
}
|
||||
else if constexpr(Op == memory_operation_enum::atomic_max)
|
||||
{
|
||||
this->template atomic_max<X>(i, linear_offset, is_valid_element, x);
|
||||
this->template atomic_max<X, oob_conditional_check>(
|
||||
i, linear_offset, is_valid_element, x);
|
||||
}
|
||||
// FIXME: remove memory_operation_enum::add
|
||||
else if constexpr(Op == memory_operation_enum::add)
|
||||
{
|
||||
auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
|
||||
this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
|
||||
auto tmp =
|
||||
this->template get<X, oob_conditional_check>(i, linear_offset, is_valid_element);
|
||||
this->template set<X, oob_conditional_check>(
|
||||
i, linear_offset, is_valid_element, x + tmp);
|
||||
// tmp += x;
|
||||
// this->template set<X>(i, is_valid_element, tmp);
|
||||
}
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <memory_operation_enum Op,
|
||||
typename X,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void update_raw(index_t i,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
if constexpr(Op == memory_operation_enum::set)
|
||||
{
|
||||
this->template set_raw<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
|
||||
}
|
||||
else if constexpr(Op == memory_operation_enum::atomic_add)
|
||||
{
|
||||
this->template atomic_add_raw<X, oob_conditional_check, pre_nop>(
|
||||
i, linear_offset, is_valid_element, x);
|
||||
}
|
||||
else if constexpr(Op == memory_operation_enum::atomic_max)
|
||||
{
|
||||
// this->template atomic_max_raw<X>(i, linear_offset, is_valid_element, x);
|
||||
}
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
@@ -533,6 +573,7 @@ struct buffer_view<address_space_enum::global,
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
@@ -585,6 +626,39 @@ struct buffer_view<address_space_enum::global,
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = true,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void
|
||||
atomic_add_raw(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
|
||||
{
|
||||
// using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
|
||||
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
|
||||
static_assert(get_address_space() == address_space_enum::global, "only support global mem");
|
||||
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
amd_buffer_atomic_add_raw<remove_cvref_t<T>,
|
||||
t_per_x,
|
||||
Coherence,
|
||||
oob_conditional_check,
|
||||
pre_nop>(
|
||||
x, p_data_, i, linear_offset, is_valid_element, buffer_size_);
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
|
||||
@@ -22,28 +22,32 @@ template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
|
||||
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
LinearBottomDims_>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
|
||||
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename DistributedTensor_,
|
||||
@@ -51,15 +55,35 @@ template <typename DistributedTensor_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
|
||||
const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.load(dst_tile, bool_constant<oob_conditional_check>{});
|
||||
return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename DistributedTensor_,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
|
||||
const tile_window_linear<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
LinearBottomDims_>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -76,6 +100,7 @@ template <typename T,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto load_tile_raw(T& tile,
|
||||
@@ -83,11 +108,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
tile_window.load_raw(
|
||||
tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
@@ -95,6 +121,7 @@ template <typename T,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto load_tile_raw(T& tile,
|
||||
@@ -102,11 +129,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
LinearBottomDims_>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
tile_window.load_raw(
|
||||
tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename LdsTileWindow_,
|
||||
@@ -114,6 +142,7 @@ template <typename LdsTileWindow_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto
|
||||
@@ -122,11 +151,14 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
return tile_window.async_load_raw(
|
||||
lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
return tile_window.async_load_raw(lds_tile,
|
||||
number<i_access>{},
|
||||
bool_constant<oob_conditional_check>{},
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename LdsTileWindow_,
|
||||
@@ -134,6 +166,7 @@ template <typename LdsTileWindow_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
|
||||
@@ -141,11 +174,14 @@ CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
LinearBottomDims_>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
return tile_window.async_load_raw(
|
||||
lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
return tile_window.async_load_raw(lds_tile,
|
||||
number<i_access>{},
|
||||
bool_constant<oob_conditional_check>{},
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
|
||||
|
||||
@@ -201,4 +201,30 @@ CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number
|
||||
return unpacks;
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
// check if 2 static_distributed_tensor has same data type and size of element
|
||||
// but only difference in distribution
|
||||
template <typename X, typename Y>
|
||||
struct is_similiar_distributed_tensor
|
||||
{
|
||||
static constexpr bool value = false;
|
||||
};
|
||||
|
||||
template <typename TypeX, typename DistX, typename TypeY, typename DistY>
|
||||
struct is_similiar_distributed_tensor<static_distributed_tensor<TypeX, DistX>,
|
||||
static_distributed_tensor<TypeY, DistY>>
|
||||
{
|
||||
using Tx = static_distributed_tensor<TypeX, DistX>;
|
||||
using Ty = static_distributed_tensor<TypeY, DistY>;
|
||||
static constexpr bool value = std::is_same_v<typename Tx::DataType, typename Ty::DataType> &&
|
||||
Tx::get_thread_buffer_size() == Ty::get_thread_buffer_size();
|
||||
};
|
||||
|
||||
template <typename X, typename Y>
|
||||
inline constexpr bool is_similiar_distributed_tensor_v =
|
||||
is_similiar_distributed_tensor<X, Y>::value;
|
||||
|
||||
} // namespace detail
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -333,6 +333,48 @@ struct tensor_view
|
||||
coord.get_offset(), linear_offset, is_valid_element, x);
|
||||
}
|
||||
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
update_vectorized_elements_raw(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
x);
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
update_vectorized_elements_raw(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
|
||||
coord.get_offset(), linear_offset, is_valid_element, x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("tensor_view{");
|
||||
|
||||
@@ -292,12 +292,15 @@ struct tile_window_with_static_distribution
|
||||
{
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
|
||||
load(dst_tensor, bool_constant<oob_conditional_check>{});
|
||||
load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
|
||||
return dst_tensor;
|
||||
}
|
||||
|
||||
template <typename DistributedTensor, bool oob_conditional_check = true>
|
||||
template <typename DistributedTensor,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
@@ -785,6 +788,73 @@ struct tile_window_with_static_distribution
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true, bool pre_nop>
|
||||
CK_TILE_DEVICE void update_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
/// TODO: use structure binding (to be captured later) if compiled in C++20
|
||||
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
|
||||
// read from distributed tensor
|
||||
vector_t vec_value;
|
||||
|
||||
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
|
||||
|
||||
vec_value.template get_as<DataType>()(j) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
0,
|
||||
vec_value,
|
||||
bool_constant<oob_conditional_check>{},
|
||||
bool_constant<pre_nop>{});
|
||||
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// move thread's botom tensor coordiante
|
||||
// [x0', x1', ... ] ==> [offset]
|
||||
// also move window-origin
|
||||
|
||||
@@ -432,23 +432,38 @@ struct tile_window_linear
|
||||
CK_TILE_DEVICE static constexpr index_t get_bottom_linear_offset(number<i_access>)
|
||||
{
|
||||
constexpr auto linear_coord = get_bottom_linear_coordinate(number<i_access>{});
|
||||
// since this is linear offset, we assum bottom X tensor is always linear
|
||||
constexpr index_t linear_offset = [&]() {
|
||||
constexpr auto x_idx_ = linear_coord;
|
||||
constexpr auto x_len_ = TileDstr{}.get_lengths();
|
||||
static_assert(x_idx_.size() == x_len_.size());
|
||||
constexpr index_t x_dims_ = x_idx_.size();
|
||||
index_t cu_stride_ = 1;
|
||||
index_t cu_offset_ = 0;
|
||||
static_for<0, x_dims_, 1>{}([&](auto i_) {
|
||||
auto r_i_ = number<x_dims_ - i_ - 1>{};
|
||||
cu_offset_ += x_idx_[r_i_] * cu_stride_;
|
||||
cu_stride_ *= x_len_[r_i_];
|
||||
});
|
||||
return cu_offset_;
|
||||
}();
|
||||
|
||||
return linear_offset;
|
||||
constexpr auto is_pure_linear_tensor =
|
||||
reduce_on_sequence(LinearBottomDims{}, multiplies{}, number<1>{});
|
||||
if constexpr(is_pure_linear_tensor)
|
||||
{
|
||||
// this case usually is a LDS window, everything is known at compile tile.
|
||||
// we directly use BottomTensorView transform to compute the offset, in case padding
|
||||
auto bottom_tensor_coord =
|
||||
make_tensor_coordinate(BottomTensorView{}.get_tensor_descriptor(), linear_coord);
|
||||
return bottom_tensor_coord.get_offset();
|
||||
}
|
||||
else
|
||||
{
|
||||
// this case usually is a global window, where last dim can be linear
|
||||
// we hack here, that use the original TileDstr to compute the linear offset
|
||||
// ... hoping that there is no extra padding between other dims, which make sense
|
||||
// since that would introduce runtime length (so can't use linear offset)
|
||||
constexpr index_t linear_offset = [&]() {
|
||||
constexpr auto x_idx_ = linear_coord;
|
||||
constexpr auto x_len_ = TileDstr{}.get_lengths();
|
||||
static_assert(x_idx_.size() == x_len_.size());
|
||||
constexpr index_t x_dims_ = x_idx_.size();
|
||||
index_t cu_stride_ = 1;
|
||||
index_t cu_offset_ = 0;
|
||||
static_for<0, x_dims_, 1>{}([&](auto i_) {
|
||||
auto r_i_ = number<x_dims_ - i_ - 1>{};
|
||||
cu_offset_ += x_idx_[r_i_] * cu_stride_;
|
||||
cu_stride_ *= x_len_[r_i_];
|
||||
});
|
||||
return cu_offset_;
|
||||
}();
|
||||
return linear_offset;
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return traits::NumAccess; }
|
||||
@@ -509,6 +524,64 @@ struct tile_window_linear
|
||||
return dst_tensor;
|
||||
}
|
||||
|
||||
template <typename DstTile, index_t i_access = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load(DstTile& dst_tensor,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using vector_t = typename traits::vector_t;
|
||||
using SFC_Ys = typename traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
// auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
|
||||
|
||||
auto issue = [&](auto i_access_) {
|
||||
constexpr auto IAccess = number<i_access_>{};
|
||||
|
||||
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
|
||||
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
|
||||
auto bottom_tensor_flag = cached_flags_[IAccess];
|
||||
|
||||
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
|
||||
|
||||
// read from bottom tensor
|
||||
const vector_t vec_value =
|
||||
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
linear_offset,
|
||||
bottom_tensor_flag,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
#if 1
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
|
||||
// write into distributed tensor
|
||||
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
|
||||
|
||||
dst_tensor.get_thread_buffer().template at<d>() =
|
||||
vec_value.template get_as<DataType>()[j];
|
||||
});
|
||||
#else
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
|
||||
static_assert(d % traits::ScalarPerVector == 0);
|
||||
|
||||
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
|
||||
number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
|
||||
#endif
|
||||
};
|
||||
|
||||
WINDOW_DISPATCH_ISSUE();
|
||||
|
||||
return dst_tensor;
|
||||
}
|
||||
|
||||
template <typename DstTile,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
@@ -849,6 +922,58 @@ struct tile_window_linear
|
||||
WINDOW_DISPATCH_ISSUE();
|
||||
}
|
||||
|
||||
template <index_t i_access = -1, bool oob_conditional_check = true, bool pre_nop = false>
|
||||
CK_TILE_DEVICE void update_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
|
||||
using vector_t = typename traits::vector_t;
|
||||
using SFC_Ys = typename traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
auto issue = [&](auto i_access_) {
|
||||
constexpr auto IAccess = number<i_access_>{};
|
||||
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
|
||||
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
|
||||
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
|
||||
auto bottom_tensor_flag = cached_flags_[IAccess];
|
||||
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
|
||||
|
||||
// read from distributed tensor
|
||||
vector_t vec_value;
|
||||
|
||||
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
|
||||
|
||||
vec_value.template get_as<DataType>()(j) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
linear_offset,
|
||||
bottom_tensor_flag,
|
||||
vec_value,
|
||||
bool_constant<oob_conditional_check>{},
|
||||
bool_constant<pre_nop>{});
|
||||
};
|
||||
|
||||
WINDOW_DISPATCH_ISSUE();
|
||||
}
|
||||
|
||||
// move thread's botom tensor coordiante
|
||||
// [x0', x1', ... ] ==> [offset]
|
||||
// also move window-origin
|
||||
|
||||
54
include/ck_tile/core/tensor/tile_window_utils.hpp
Normal file
54
include/ck_tile/core/tensor/tile_window_utils.hpp
Normal file
@@ -0,0 +1,54 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/utility.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
#pragma once
|
||||
namespace ck_tile {
|
||||
|
||||
// input a lds store tile, extract some information from it
|
||||
// used to set m0 value for gfx9 serious
|
||||
template <typename LdsTileWindow_>
|
||||
CK_TILE_DEVICE auto get_async_store_smem_info(LdsTileWindow_&& lds_tile)
|
||||
{
|
||||
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
|
||||
using LdsDataType = typename LdsTileWindow::DataType;
|
||||
|
||||
// issues * warps * lanes
|
||||
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
|
||||
|
||||
const index_t size_per_buf =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<0>{}, number<0>{})) *
|
||||
sizeof(LdsDataType);
|
||||
|
||||
const index_t size_per_wave =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<1>{}, number<0>{})) *
|
||||
sizeof(LdsDataType) -
|
||||
size_per_buf;
|
||||
|
||||
const index_t size_per_issue =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<1>{}, number<0>{}, number<0>{})) *
|
||||
sizeof(LdsDataType) -
|
||||
size_per_buf;
|
||||
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
|
||||
return make_tuple(m0_init_value, size_per_issue);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -41,15 +41,65 @@ template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
typename DataType_>
|
||||
typename DataType_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void
|
||||
update_tile(tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
tile_window.update(dstr_tensor);
|
||||
tile_window.update(dstr_tensor, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
typename DataType_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE void
|
||||
update_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
tile_window.update_raw(dstr_tensor,
|
||||
number<i_access>{},
|
||||
bool_constant<oob_conditional_check>{},
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
typename DataType_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto update_tile_raw(
|
||||
tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
|
||||
tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
tile_window.update_raw(dstr_tensor,
|
||||
number<i_access>{},
|
||||
bool_constant<oob_conditional_check>{},
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
116
include/ck_tile/core/utility/static_counter.hpp
Normal file
116
include/ck_tile/core/utility/static_counter.hpp
Normal file
@@ -0,0 +1,116 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Context, index_t Start = 0, index_t Step = 1>
|
||||
struct static_counter
|
||||
{
|
||||
public:
|
||||
template <typename Unique>
|
||||
static constexpr index_t next()
|
||||
{
|
||||
return next<Unique>(0) * Step + Start;
|
||||
}
|
||||
|
||||
template <unsigned long long>
|
||||
static constexpr index_t next()
|
||||
{
|
||||
struct Unique
|
||||
{
|
||||
};
|
||||
return next<Unique>(0) * Step + Start;
|
||||
}
|
||||
|
||||
template <typename Unique>
|
||||
static constexpr index_t current()
|
||||
{
|
||||
return current<Unique>(0) * Step + Start;
|
||||
}
|
||||
|
||||
template <unsigned long long>
|
||||
static constexpr index_t current()
|
||||
{
|
||||
struct Unique
|
||||
{
|
||||
};
|
||||
return current<Unique>(0) * Step + Start;
|
||||
}
|
||||
|
||||
private:
|
||||
template <index_t I>
|
||||
struct slot
|
||||
{
|
||||
_Pragma("GCC diagnostic push");
|
||||
_Pragma("GCC diagnostic ignored \"-Wundefined-internal\"");
|
||||
friend constexpr bool slot_allocated(slot<I>);
|
||||
_Pragma("GCC diagnostic pop");
|
||||
};
|
||||
|
||||
template <index_t I>
|
||||
struct allocate_slot
|
||||
{
|
||||
friend constexpr bool slot_allocated(slot<I>) { return true; }
|
||||
enum
|
||||
{
|
||||
value = I
|
||||
};
|
||||
};
|
||||
|
||||
// If slot_allocated(slot<I>) has NOT been defined, then SFINAE will keep this function out of
|
||||
// the overload set...
|
||||
template <typename Unique, index_t I = 0, bool = slot_allocated(slot<I>())>
|
||||
static constexpr index_t next(index_t)
|
||||
{
|
||||
return next<Unique, I + 1>(0);
|
||||
}
|
||||
|
||||
// ...And this function will be used, instead, which will define slot_allocated(slot<I>) via
|
||||
// allocate_slot<I>.
|
||||
template <typename Unique, index_t I = 0>
|
||||
static constexpr index_t next(double)
|
||||
{
|
||||
return allocate_slot<I>::value;
|
||||
}
|
||||
|
||||
// If slot_allocated(slot<I>) has NOT been defined, then SFINAE will keep this function out of
|
||||
// the overload set...
|
||||
template <typename Unique, index_t I = Start, bool = slot_allocated(slot<I>())>
|
||||
static constexpr index_t current(index_t)
|
||||
{
|
||||
return current<Unique, I + 1>(0);
|
||||
}
|
||||
|
||||
// ...And this function will be used, instead, which will return the current counter, or assert
|
||||
// in case next() hasn't been called yet.
|
||||
template <typename Unique, index_t I = Start>
|
||||
static constexpr index_t current(double)
|
||||
{
|
||||
static_assert(I != 0, "You must invoke next() first");
|
||||
|
||||
return I - 1;
|
||||
}
|
||||
};
|
||||
|
||||
namespace impl {
|
||||
template <int I>
|
||||
struct static_counter_uniq_;
|
||||
}
|
||||
|
||||
#define MAKE_SC() \
|
||||
ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>> {}
|
||||
#define MAKE_SC_WITH(start_, step_) \
|
||||
ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>, start_, step_> {}
|
||||
#define NEXT_SC(c_) c_.next<__COUNTER__>()
|
||||
#define NEXT_SCI(c_, static_i_) c_.next<__COUNTER__ + static_i_>()
|
||||
|
||||
// Usage:
|
||||
// constexpr auto c = MAKE_SC()
|
||||
// NEXT_SC(c) // -> constexpr 0
|
||||
// NEXT_SC(c) // -> constexpr 1
|
||||
// NEXT_SC(c) // -> constexpr 2
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user