mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
support flatmm scaling
This commit is contained in:
@@ -165,6 +165,9 @@ struct sequence
|
||||
return sequence<Is..., Xs...>{};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto sum() { return (Is + ... + 0); }
|
||||
CK_TILE_HOST_DEVICE static constexpr auto product() { return (Is * ... * 1); }
|
||||
|
||||
// pickup element at index <Ids...>
|
||||
template <index_t... Ids>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto extract(number<Ids>...)
|
||||
@@ -1236,9 +1239,8 @@ constexpr auto reverse_slice_sequence(Seq,
|
||||
template <typename Seq,
|
||||
index_t SliceSize,
|
||||
typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
|
||||
constexpr auto slice_sequence(Seq,
|
||||
number<SliceSize>,
|
||||
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
|
||||
constexpr auto
|
||||
slice_sequence(Seq, number<SliceSize>, Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
|
||||
{
|
||||
constexpr auto r =
|
||||
reverse_slice_sequence(Seq{}.reverse(), number<SliceSize>{}, Mask{}.reverse());
|
||||
|
||||
Reference in New Issue
Block a user