[CK_TILE][CORE] enhance slice_tile api (#2430)

* support slice cross p

* fix some bug in y_len

* more case

* fix a bug when R exist

* support -1 to hint end of current length

* format

* change commit
This commit is contained in:
carlushuang
2025-07-07 11:13:12 +08:00
committed by GitHub
parent 7998ae8969
commit a8742f7e31
6 changed files with 337 additions and 30 deletions

View File

@@ -1178,6 +1178,15 @@ struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, Slice
// clang-format off
// input a sequence(with optional mask), and the SliceSize : size per slice
// output the sequence each slice, and number of slices
// the length count for slice size is from right to left(reverse slice)
// or we can say, find the greatest common divider(gcd) from right to left, for the slice length
//
// e.g. <2, 8, 4>, slice length = 16
// step-1: we take the right most <*, *, 4>, remaining 16/4=4
// step-2: we only need 4 out of 8, of the midden dim, hence <*, 4, 4>
// step-3: since nonthing remain, so the first dim we only need 1, hence<1, 4, 4>
// => we got <1, 4, 4> as length for each slice
// => total number of slice = <2, 8, 4> / <1, 4, 4> = <2, 2, 1>
//
// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
@@ -1197,7 +1206,7 @@ struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, Slice
//
// return tuple<slice_lengths, slice_nums, slice_index>, slice_index is at which index will start
// have split slices (right -> left)
// or the first index that sliced length is different from the original length
// or the first index (right -> left) that sliced length is different from the original length
// clang-format on
template <typename Seq,
index_t SliceSize,
@@ -1207,6 +1216,11 @@ constexpr auto reverse_slice_sequence(Seq,
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
{
static_assert(Seq::size() == Mask::size());
static_assert(SliceSize != 0, "slice size zero is invalid");
static_assert(container_reduce(pick_sequence_elements_by_mask(Seq{}, Mask{}), multiplies{}, 1) %
SliceSize ==
0,
"slice size can't evenly divide input sizes");
using sliced_type =
impl::reverse_slice_sequence_impl<Seq,
Mask,