mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Merge commit 'a8742f7e31d481b5fb2152ab5428b721c6bcb27b' into develop
This commit is contained in:
@@ -1178,6 +1178,15 @@ struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, Slice
|
||||
// clang-format off
|
||||
// input a sequence(with optional mask), and the SliceSize : size per slice
|
||||
// output the sequence each slice, and number of slices
|
||||
// the length count for slice size is from right to left(reverse slice)
|
||||
// or we can say, find the greatest common divider(gcd) from right to left, for the slice length
|
||||
//
|
||||
// e.g. <2, 8, 4>, slice length = 16
|
||||
// step-1: we take the right most <*, *, 4>, remaining 16/4=4
|
||||
// step-2: we only need 4 out of 8, of the midden dim, hence <*, 4, 4>
|
||||
// step-3: since nonthing remain, so the first dim we only need 1, hence<1, 4, 4>
|
||||
// => we got <1, 4, 4> as length for each slice
|
||||
// => total number of slice = <2, 8, 4> / <1, 4, 4> = <2, 2, 1>
|
||||
//
|
||||
// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0
|
||||
// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
|
||||
@@ -1197,7 +1206,7 @@ struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, Slice
|
||||
//
|
||||
// return tuple<slice_lengths, slice_nums, slice_index>, slice_index is at which index will start
|
||||
// have split slices (right -> left)
|
||||
// or the first index that sliced length is different from the original length
|
||||
// or the first index (right -> left) that sliced length is different from the original length
|
||||
// clang-format on
|
||||
template <typename Seq,
|
||||
index_t SliceSize,
|
||||
@@ -1207,6 +1216,11 @@ constexpr auto reverse_slice_sequence(Seq,
|
||||
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
|
||||
{
|
||||
static_assert(Seq::size() == Mask::size());
|
||||
static_assert(SliceSize != 0, "slice size zero is invalid");
|
||||
static_assert(container_reduce(pick_sequence_elements_by_mask(Seq{}, Mask{}), multiplies{}, 1) %
|
||||
SliceSize ==
|
||||
0,
|
||||
"slice size can't evenly divide input sizes");
|
||||
using sliced_type =
|
||||
impl::reverse_slice_sequence_impl<Seq,
|
||||
Mask,
|
||||
|
||||
@@ -542,26 +542,26 @@ namespace detail {
|
||||
//
|
||||
// e.g
|
||||
// X0 X1
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 32>, (0 means all length)
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 32>, (-1 means the last one)
|
||||
// Y P P Y P Y P Y
|
||||
// => <1, 4, 32> - <1, 1, 4, 2, 4> -> OK
|
||||
// |--> slice along this Y dim, is the first dim of X1, totally 4 slices
|
||||
//
|
||||
// X0 X1
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 8>, (0 means all length)
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 8>, (-1 means the last one)
|
||||
// Y P P Y P Y P Y
|
||||
// => <1, 4, 32> - <1, 1, 1, 2, 4> -> OK
|
||||
// |--> slice along this Y dim, the P dim is 1 in the left, so is OK
|
||||
// totally 16 slices
|
||||
//
|
||||
// X0 X1
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 4>, (0 means all length)
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 4>, (-1 means the last one)
|
||||
// Y P P Y P Y P Y
|
||||
// => <1, 4, 32> - <1, 1, 1, 1, 4> -> Fail
|
||||
// |--> slice along this P dim, will split threads, not supported
|
||||
//
|
||||
// X0 X1
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 16>, (0 means all length)
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 16>, (-1 means the last one)
|
||||
// Y P P Y P Y P Y
|
||||
// => <1, 4, 32> - <1, 1, 2, 2, 4> -> OK
|
||||
// |--> slice along this Y dim, but this Y sim need to split into 2
|
||||
@@ -577,11 +577,39 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
|
||||
using Encoding = decltype(Distribution::get_static_tile_distribution_encoding());
|
||||
|
||||
static_assert(sizeof...(XSliceBegins) == sizeof...(XSliceEnds));
|
||||
static_assert(sizeof...(XSliceBegins) == Encoding::NDimX, "only support slice over h, not r");
|
||||
|
||||
constexpr auto x_slice_lengths = x_slice_ends - x_slice_begins;
|
||||
constexpr auto p_len_over_h = Encoding::detail::get_uniformed_p_dim_lengths_over_h();
|
||||
|
||||
constexpr auto x_slice_ends_ = generate_sequence_v2(
|
||||
[&](auto i) {
|
||||
if constexpr(x_slice_ends[i] == -1)
|
||||
{
|
||||
// -1 means till the end
|
||||
constexpr auto x_length_ =
|
||||
container_reduce(typename Encoding::HsLengthss{}[i], multiplies{}, number<1>{});
|
||||
return x_length_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return x_slice_ends[i];
|
||||
}
|
||||
},
|
||||
number<x_slice_ends.size()>{});
|
||||
|
||||
constexpr auto x_slice_lengths = x_slice_ends_ - x_slice_begins;
|
||||
|
||||
constexpr auto x_slice_lengths_without_p = generate_sequence_v2(
|
||||
[&](auto i) constexpr {
|
||||
constexpr auto len_ = x_slice_lengths[i];
|
||||
static_assert(len_ % p_len_over_h[i] == 0,
|
||||
"slice length must be dividable by p_len_over_h");
|
||||
return number<len_ / p_len_over_h[i]>{};
|
||||
},
|
||||
number<x_slice_lengths.size()>{});
|
||||
|
||||
constexpr auto src_h_prefix_sum = Encoding::detail::get_h_dim_lengths_prefix_sum();
|
||||
constexpr auto src_y_info = Encoding::detail::get_sorted_y_info();
|
||||
constexpr auto src_y_info = Encoding::detail::get_sorted_y_to_h_info();
|
||||
constexpr auto src_y_dims = src_y_info[number<0>{}];
|
||||
constexpr auto src_y_maps = src_y_info[number<1>{}];
|
||||
constexpr auto src_y_prefix_sum = src_y_info[number<2>{}];
|
||||
@@ -590,14 +618,15 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
|
||||
{
|
||||
auto y_slice_sorted_origins = make_zero_multi_index<Encoding::NDimY>();
|
||||
auto y_slice_lengths = Encoding::detail::ys_lengths_;
|
||||
constexpr auto y_to_h_masks = Encoding::detail::get_y_to_h_masks();
|
||||
|
||||
// This lambda will modify some value outside, so c++ will not treat return value as
|
||||
// constexpr
|
||||
// TODO: ugly
|
||||
auto new_h_lengths = transform_tuples(
|
||||
[&](auto h_len, auto id) {
|
||||
constexpr auto sliced_h =
|
||||
reverse_slice_sequence(h_len, number<x_slice_lengths[id]>{});
|
||||
constexpr auto sliced_h = reverse_slice_sequence(
|
||||
h_len, number<x_slice_lengths_without_p[id]>{}, y_to_h_masks[id]);
|
||||
|
||||
constexpr auto sliced_h_lens = sliced_h[number<0>{}];
|
||||
constexpr auto sliced_h_index = sliced_h[number<2>{}];
|
||||
@@ -605,26 +634,39 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
|
||||
// update y_slice_lengths
|
||||
constexpr auto uniformed_h_index = sliced_h_index + number<src_h_prefix_sum[id]>{};
|
||||
constexpr auto found_y_index = container_find(src_y_dims, uniformed_h_index);
|
||||
constexpr auto y_to_h_dim_end = src_y_prefix_sum[id + 1];
|
||||
|
||||
static_assert(found_y_index >= 0 && found_y_index < src_y_dims.size(),
|
||||
"not sliced at y dim, please check");
|
||||
|
||||
static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
|
||||
y_slice_lengths(src_y_maps[found_y_index - i]) =
|
||||
sliced_h_lens[sliced_h_index - i];
|
||||
});
|
||||
{
|
||||
constexpr auto sliced_y_to_h_lens =
|
||||
pick_sequence_elements_by_mask(sliced_h_lens, y_to_h_masks[id]);
|
||||
constexpr auto sliced_y_to_h_dims = sliced_y_to_h_lens.size();
|
||||
static_for<0, sliced_y_to_h_dims, 1>{}([&](auto i) {
|
||||
y_slice_lengths(src_y_maps[y_to_h_dim_end - 1 - i]) =
|
||||
sliced_y_to_h_lens[sliced_y_to_h_dims - 1 - i];
|
||||
});
|
||||
}
|
||||
// TODO: add validations not across p dim
|
||||
|
||||
// NOTE: this y_origin is for all dims, not only current dim
|
||||
// will later use pick to select target dim
|
||||
constexpr auto y_origin = [&]() {
|
||||
constexpr auto h_trans = make_merge_transform_v3_division_mod(h_len);
|
||||
auto h_origin_ = make_zero_multi_index<h_trans.NDimLow>();
|
||||
h_trans.calculate_lower_index(h_origin_, sequence<x_slice_begins[id].value>{});
|
||||
// can't use Encoding::Ys2RHsMajor/Ys2RHsMinor, these are unordered
|
||||
constexpr auto y_to_h_len =
|
||||
pick_sequence_elements_by_mask(h_len, y_to_h_masks[id]);
|
||||
constexpr auto y_to_h_dims = y_to_h_len.size();
|
||||
|
||||
constexpr auto h_trans = make_merge_transform_v3_division_mod(y_to_h_len);
|
||||
auto h_origin_ = make_zero_multi_index<h_trans.NDimLow>();
|
||||
constexpr auto y_begin_ = x_slice_begins[id] / p_len_over_h[id];
|
||||
h_trans.calculate_lower_index(h_origin_, sequence<y_begin_.value>{});
|
||||
|
||||
auto y_origin_ = make_zero_multi_index<Encoding::NDimY>();
|
||||
static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
|
||||
y_origin_(found_y_index - i) = h_origin_[sliced_h_index - i];
|
||||
|
||||
static_for<0, y_to_h_dims, 1>{}([&](auto i) {
|
||||
y_origin_(y_to_h_dim_end - 1 - i) = h_origin_[y_to_h_dims - 1 - i];
|
||||
});
|
||||
return y_origin_;
|
||||
}();
|
||||
|
||||
@@ -255,33 +255,107 @@ struct tile_distribution_encoding
|
||||
}
|
||||
}();
|
||||
|
||||
// e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5> --> seq<0, 3, 8>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_h_dim_lengths_prefix_sum()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_h_dim_lengths()
|
||||
{
|
||||
// <len_d0, len_d1, ...>
|
||||
// e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5>
|
||||
constexpr auto uniformed_h_dim_lengths = generate_sequence_v2(
|
||||
[&](auto i) {
|
||||
constexpr index_t size = HsLengthss{}[i].size();
|
||||
return number<size>{};
|
||||
constexpr index_t size_ = HsLengthss{}[i].size();
|
||||
return number<size_>{};
|
||||
},
|
||||
number<NDimX>{});
|
||||
return uniformed_h_dim_lengths;
|
||||
}
|
||||
|
||||
// note: this function only count the p dim length along h, not r
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_p_dim_lengths_over_h()
|
||||
{
|
||||
// e.g. tuple<seq<1, 4, 32>, seq<1, 2, 8, 4, 4>>
|
||||
// Y P Y Y P Y P Y
|
||||
// | | |
|
||||
// v v v
|
||||
// return : seq<4, 2 * 4> => seq<4, 8>
|
||||
constexpr auto uniformed_ps_to_rhss_major_ =
|
||||
unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_major_);
|
||||
constexpr auto uniformed_ps_to_rhss_minor_ =
|
||||
unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_minor_);
|
||||
|
||||
constexpr auto p_len_ = [&]() {
|
||||
array<index_t, NDimX> len_{1};
|
||||
static_for<0, NDimX, 1>{}([&](auto idim_x_) {
|
||||
constexpr auto major_ = number<idim_x_ + 1>{}; // RDim
|
||||
static_for<0, uniformed_ps_to_rhss_major_.size(), 1>{}([&](auto idim_u_) {
|
||||
if constexpr(major_.value == uniformed_ps_to_rhss_major_[idim_u_])
|
||||
{
|
||||
constexpr auto minor_ = uniformed_ps_to_rhss_minor_[idim_u_];
|
||||
constexpr auto h_length_ = hs_lengthss_[idim_x_][minor_];
|
||||
len_[idim_x_] *= h_length_;
|
||||
}
|
||||
});
|
||||
});
|
||||
return len_;
|
||||
}();
|
||||
constexpr auto p_len_over_h_seq_ = TO_SEQUENCE(p_len_, NDimX);
|
||||
return p_len_over_h_seq_;
|
||||
}
|
||||
|
||||
//
|
||||
// R: seq<3>, H: tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>>
|
||||
// => return seq<1, 3, 5>
|
||||
// R: seq<>, H: tuple<seq<2, 4>, seq<16, 8, 8>>
|
||||
// => return seq<0, 2, 3>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_rh_dim_lengths()
|
||||
{
|
||||
constexpr auto uniformed_rh_dim_lengths =
|
||||
merge_sequences(sequence<NDimR>{} /*for R dims*/, get_uniformed_h_dim_lengths());
|
||||
|
||||
return uniformed_rh_dim_lengths;
|
||||
}
|
||||
|
||||
// e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5> --> seq<0, 3, 8>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_h_dim_lengths_prefix_sum()
|
||||
{
|
||||
// <0, len_d0, len_d0+len_d1, ...>
|
||||
// e.g. seq<3, 5> --> seq<0, 3, 8>
|
||||
constexpr auto h_dim_prefix_sum = prefix_sum_sequence(uniformed_h_dim_lengths);
|
||||
constexpr auto h_dim_prefix_sum = prefix_sum_sequence(get_uniformed_h_dim_lengths());
|
||||
|
||||
return h_dim_prefix_sum;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_y_to_h()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_rh_dim_lengths_prefix_sum()
|
||||
{
|
||||
// <0, len_d0, len_d0+len_d1, ...>
|
||||
// e.g. seq<3, 5> --> seq<0, 3, 8>
|
||||
constexpr auto rh_dim_prefix_sum = prefix_sum_sequence(get_uniformed_rh_dim_lengths());
|
||||
|
||||
return rh_dim_prefix_sum;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_p_to_h()
|
||||
{
|
||||
// tuple<seq<xx..>, seq<yy..>> -> seq<xx..yy..>
|
||||
constexpr auto uniformed_ps_to_rhss_major_ =
|
||||
unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_major_);
|
||||
constexpr auto uniformed_ps_to_rhss_minor_ =
|
||||
unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_minor_);
|
||||
|
||||
constexpr auto all_ps_2_rhss = transform_sequences(
|
||||
[](auto major, auto minor) constexpr {
|
||||
constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum();
|
||||
return rh_dim_prefix_sum.at(major) + minor;
|
||||
},
|
||||
uniformed_ps_to_rhss_major_,
|
||||
uniformed_ps_to_rhss_minor_);
|
||||
|
||||
return all_ps_2_rhss;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_y_to_rh()
|
||||
{
|
||||
constexpr auto all_ys_2_rhss = transform_sequences(
|
||||
[](auto major, auto minor) constexpr {
|
||||
// <0, 0, len_d0, len_d0+len_d1, ...>
|
||||
constexpr auto x_dim_prefix_sum = merge_sequences(
|
||||
sequence<0>{} /*for R dims*/, get_h_dim_lengths_prefix_sum());
|
||||
return x_dim_prefix_sum.at(major) + minor;
|
||||
constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum();
|
||||
return rh_dim_prefix_sum.at(major) + minor;
|
||||
},
|
||||
Ys2RHsMajor{},
|
||||
Ys2RHsMinor{});
|
||||
@@ -289,6 +363,45 @@ struct tile_distribution_encoding
|
||||
return all_ys_2_rhss;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_y_to_h()
|
||||
{
|
||||
// TODO: Y can't point to R
|
||||
constexpr auto all_ys_2_rhss = transform_sequences(
|
||||
[](auto major, auto minor) constexpr {
|
||||
constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum();
|
||||
return rh_dim_prefix_sum.at(major) + minor - NDimR;
|
||||
},
|
||||
Ys2RHsMajor{},
|
||||
Ys2RHsMinor{});
|
||||
|
||||
return all_ys_2_rhss;
|
||||
}
|
||||
|
||||
// return tuple of seq
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_y_to_h_masks()
|
||||
{
|
||||
constexpr auto masks_ = generate_tuple(
|
||||
[&](auto i) {
|
||||
constexpr auto size_ = HsLengthss{}[i].size();
|
||||
constexpr auto current_y_to_h_mask_ = [&]() {
|
||||
array<index_t, size_> m_{0};
|
||||
// TODO: we loop over all y for each h dim
|
||||
for(auto j = 0; j < NDimY; j++)
|
||||
{
|
||||
if(Ys2RHsMajor{}[j] == (i + 1) /*RDim need plus 1*/)
|
||||
{
|
||||
m_[Ys2RHsMinor{}[j]] = 1;
|
||||
}
|
||||
}
|
||||
return m_;
|
||||
}();
|
||||
|
||||
return TO_SEQUENCE(current_y_to_h_mask_, size_);
|
||||
},
|
||||
number<NDimX>{});
|
||||
return masks_;
|
||||
}
|
||||
|
||||
// return tuple<sorted_dims, sorted_maps, sorted_prefix_sum>
|
||||
template <typename IdxSeq, typename PrefixSumSeq>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_sorted_info(IdxSeq, PrefixSumSeq)
|
||||
@@ -305,7 +418,8 @@ struct tile_distribution_encoding
|
||||
return make_tuple(sorted_dims, sorted_maps, sorted_prefix_sum);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_sorted_y_info()
|
||||
// Note here y_to_h does not count R dim!
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_sorted_y_to_h_info()
|
||||
{
|
||||
return get_sorted_info(get_uniformed_idx_y_to_h(), get_h_dim_lengths_prefix_sum());
|
||||
}
|
||||
|
||||
@@ -4,3 +4,4 @@ add_subdirectory(batched_gemm)
|
||||
add_subdirectory(grouped_gemm)
|
||||
add_subdirectory(gemm_multi_d)
|
||||
add_subdirectory(data_type)
|
||||
add_subdirectory(slice_tile)
|
||||
|
||||
1
test/ck_tile/slice_tile/CMakeLists.txt
Normal file
1
test/ck_tile/slice_tile/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_test_executable(test_slice_tile test_slice_tile.cpp)
|
||||
135
test/ck_tile/slice_tile/test_slice_tile.cpp
Normal file
135
test/ck_tile/slice_tile/test_slice_tile.cpp
Normal file
@@ -0,0 +1,135 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include <type_traits>
|
||||
|
||||
// clang-format off
|
||||
template<typename SliceStart_ = ck_tile::sequence<0, 0>,
|
||||
typename SliceEnd_ = ck_tile::sequence<64, 16>,
|
||||
typename Y_Origin_ = ck_tile::sequence<0, 0, 0, 0>>
|
||||
void test_slice_distribution_from_x_case_0(SliceStart_ = {}, SliceEnd_={}, Y_Origin_ = {})
|
||||
{
|
||||
// slice length [-1, 16]
|
||||
using namespace ck_tile;
|
||||
constexpr auto r = detail::slice_distribution_from_x(
|
||||
make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<1, 4, 16>, sequence<2, 2, 1, 4, 4>>,
|
||||
// Y P P Y P Y P Y
|
||||
tuple<sequence<1, 2>, sequence<2, 1>>,
|
||||
tuple<sequence<1, 1>, sequence<3, 2>>,
|
||||
sequence<1, 2, 2, 2>,
|
||||
sequence<0, 0, 2, 4>>{}),
|
||||
SliceStart_{},
|
||||
SliceEnd_{});
|
||||
|
||||
using sliced_dist_enc = remove_cvref_t<decltype(r[number<0>{}].get_static_tile_distribution_encoding())>;
|
||||
using target_dist_enc = tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<1, 4, 16>, sequence<1, 2, 1, 4, 2>>,
|
||||
// Y P P Y P Y P Y
|
||||
tuple<sequence<1, 2>, sequence<2, 1>>,
|
||||
tuple<sequence<1, 1>, sequence<3, 2>>,
|
||||
sequence<1, 2, 2, 2>,
|
||||
sequence<0, 0, 2, 4>>;
|
||||
|
||||
static_assert(std::is_same_v<sliced_dist_enc, target_dist_enc>);
|
||||
|
||||
using sliced_y_origins = remove_cvref_t<decltype(r[number<1>{}])>;
|
||||
using sliced_y_lengths = remove_cvref_t<decltype(r[number<2>{}])>;
|
||||
static_assert(std::is_same_v<sliced_y_origins, Y_Origin_>);
|
||||
static_assert(std::is_same_v<sliced_y_lengths, sequence<1, 1, 1, 2>>);
|
||||
}
|
||||
|
||||
template<typename SliceStart_ = ck_tile::sequence<0, 0>,
|
||||
typename SliceEnd_ = ck_tile::sequence<16, 16>,
|
||||
typename Y_Origin_ = ck_tile::sequence<0, 0, 0, 0, 0>>
|
||||
void test_slice_distribution_from_x_case_1(SliceStart_ = {}, SliceEnd_={}, Y_Origin_ = {})
|
||||
{
|
||||
// slice length [16, 16]
|
||||
using namespace ck_tile;
|
||||
constexpr auto r = detail::slice_distribution_from_x(
|
||||
make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<4, 8, 2>, sequence<2, 4, 2, 8, 2>>,
|
||||
// Y P Y Y P Y Y P
|
||||
tuple<sequence<1>, sequence<2, 2>>,
|
||||
tuple<sequence<1>, sequence<4, 1>>,
|
||||
sequence<1, 1, 2, 2, 2>,
|
||||
sequence<0, 2, 0, 2, 3>>{}),
|
||||
SliceStart_{},
|
||||
SliceEnd_{});
|
||||
|
||||
using sliced_dist_enc = remove_cvref_t<decltype(r[number<0>{}].get_static_tile_distribution_encoding())>;
|
||||
using target_dist_enc = tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<1, 8, 2>, sequence<1, 4, 1, 2, 2>>,
|
||||
// Y P Y Y P Y Y P
|
||||
tuple<sequence<1>, sequence<2, 2>>,
|
||||
tuple<sequence<1>, sequence<4, 1>>,
|
||||
sequence<1, 1, 2, 2, 2>,
|
||||
sequence<0, 2, 0, 2, 3>>;
|
||||
|
||||
static_assert(std::is_same_v<sliced_dist_enc, target_dist_enc>);
|
||||
|
||||
using sliced_y_origins = remove_cvref_t<decltype(r[number<1>{}])>;
|
||||
using sliced_y_lengths = remove_cvref_t<decltype(r[number<2>{}])>;
|
||||
static_assert(std::is_same_v<sliced_y_origins, Y_Origin_>);
|
||||
static_assert(std::is_same_v<sliced_y_lengths, sequence<1, 2, 1, 1, 2>>);
|
||||
}
|
||||
|
||||
template<typename SliceStart_ = ck_tile::sequence<0, 0>,
|
||||
typename SliceEnd_ = ck_tile::sequence<12, 48>,
|
||||
typename Y_Origin_ = ck_tile::sequence<0, 0, 0, 0, 0>>
|
||||
void test_slice_distribution_from_x_case_2(SliceStart_ = {}, SliceEnd_={}, Y_Origin_ = {})
|
||||
{
|
||||
// slice length [12, 48]
|
||||
using namespace ck_tile;
|
||||
constexpr auto r = detail::slice_distribution_from_x(
|
||||
make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<4, 5>,
|
||||
tuple<sequence<4, 3, 2>, sequence<2, 2, 1, 4, 3, 4>>,
|
||||
// Y P Y Y P, Y, P P, Y
|
||||
tuple<sequence<0, 1, 0>, sequence<2, 2, 2>>,
|
||||
tuple<sequence<0, 1, 1>, sequence<4, 1, 3>>,
|
||||
sequence<1, 2, 1, 2, 2>,
|
||||
sequence<2, 0, 0, 5, 2>>{}),
|
||||
SliceStart_{},
|
||||
SliceEnd_{});
|
||||
|
||||
using sliced_dist_enc = remove_cvref_t<decltype(r[number<0>{}].get_static_tile_distribution_encoding())>;
|
||||
using target_dist_enc = tile_distribution_encoding<sequence<4, 5>,
|
||||
tuple<sequence<2, 3, 2>, sequence<1, 2, 1, 4, 3, 2>>,
|
||||
// Y P Y Y P, Y, P P, Y
|
||||
tuple<sequence<0, 1, 0>, sequence<2, 2, 2>>,
|
||||
tuple<sequence<0, 1, 1>, sequence<4, 1, 3>>,
|
||||
sequence<1, 2, 1, 2, 2>,
|
||||
sequence<2, 0, 0, 5, 2>>;
|
||||
|
||||
static_assert(std::is_same_v<sliced_dist_enc, target_dist_enc>);
|
||||
|
||||
using sliced_y_origins = remove_cvref_t<decltype(r[number<1>{}])>;
|
||||
using sliced_y_lengths = remove_cvref_t<decltype(r[number<2>{}])>;
|
||||
static_assert(std::is_same_v<sliced_y_origins, Y_Origin_>);
|
||||
static_assert(std::is_same_v<sliced_y_lengths, sequence<2, 1, 2, 2, 1>>);
|
||||
}
|
||||
|
||||
void test_slice_distribution_from_x()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
test_slice_distribution_from_x_case_0(sequence< 0, 0>{}, sequence<-1, 16>{}, sequence<0, 0, 0, 0>{});
|
||||
test_slice_distribution_from_x_case_0(sequence< 0, 16>{}, sequence<-1, 32>{}, sequence<0, 0, 0, 2>{});
|
||||
test_slice_distribution_from_x_case_0(sequence< 0, 32>{}, sequence<-1, 48>{}, sequence<0, 1, 0, 0>{});
|
||||
test_slice_distribution_from_x_case_0(sequence< 0, 48>{}, sequence<-1, 64>{}, sequence<0, 1, 0, 2>{});
|
||||
|
||||
test_slice_distribution_from_x_case_1(sequence< 0, 0>{}, sequence<16, 16>{}, sequence<0, 0, 0, 0, 0>{});
|
||||
test_slice_distribution_from_x_case_1(sequence<16, 16>{}, sequence<32, 32>{}, sequence<1, 0, 0, 0, 2>{});
|
||||
test_slice_distribution_from_x_case_1(sequence<32, 64>{}, sequence<48, 80>{}, sequence<2, 0, 0, 1, 0>{});
|
||||
test_slice_distribution_from_x_case_1(sequence<48, 208>{}, sequence<64, 224>{}, sequence<3, 0, 1, 1, 2>{});
|
||||
|
||||
test_slice_distribution_from_x_case_2(sequence< 0, 0>{}, sequence<12, 48>{}, sequence<0, 0, 0, 0, 0>{});
|
||||
test_slice_distribution_from_x_case_2(sequence<12, 144>{}, sequence<24, 192>{}, sequence<0, 1, 2, 2, 0>{});
|
||||
}
|
||||
|
||||
// clang-format on
|
||||
int main() { test_slice_distribution_from_x(); }
|
||||
Reference in New Issue
Block a user