mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE][CORE] enhance slice_tile api (#2430)
* support slice cross p * fix some bug in y_len * more case * fix a bug when R exist * support -1 to hint end of current length * format * change commit
This commit is contained in:
@@ -1178,6 +1178,15 @@ struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, Slice
|
||||
// clang-format off
|
||||
// input a sequence(with optional mask), and the SliceSize : size per slice
|
||||
// output the sequence each slice, and number of slices
|
||||
// the length count for slice size is from right to left(reverse slice)
|
||||
// or we can say, find the greatest common divisor (gcd) from right to left, for the slice length
|
||||
//
|
||||
// e.g. <2, 8, 4>, slice length = 16
|
||||
// step-1: we take the right most <*, *, 4>, remaining 16/4=4
|
||||
// step-2: we only need 4 out of 8 of the middle dim, hence <*, 4, 4>
|
||||
// step-3: since nothing remains, the first dim only needs 1, hence <1, 4, 4>
|
||||
// => we got <1, 4, 4> as length for each slice
|
||||
// => total number of slice = <2, 8, 4> / <1, 4, 4> = <2, 2, 1>
|
||||
//
|
||||
// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0
|
||||
// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
|
||||
@@ -1197,7 +1206,7 @@ struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, Slice
|
||||
//
|
||||
// return tuple<slice_lengths, slice_nums, slice_index>, slice_index is at which index will start
|
||||
// have split slices (right -> left)
|
||||
// or the first index that sliced length is different from the original length
|
||||
// or the first index (right -> left) that sliced length is different from the original length
|
||||
// clang-format on
|
||||
template <typename Seq,
|
||||
index_t SliceSize,
|
||||
@@ -1207,6 +1216,11 @@ constexpr auto reverse_slice_sequence(Seq,
|
||||
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
|
||||
{
|
||||
static_assert(Seq::size() == Mask::size());
|
||||
static_assert(SliceSize != 0, "slice size zero is invalid");
|
||||
static_assert(container_reduce(pick_sequence_elements_by_mask(Seq{}, Mask{}), multiplies{}, 1) %
|
||||
SliceSize ==
|
||||
0,
|
||||
"slice size can't evenly divide input sizes");
|
||||
using sliced_type =
|
||||
impl::reverse_slice_sequence_impl<Seq,
|
||||
Mask,
|
||||
|
||||
Reference in New Issue
Block a user