support flatmm scaling

This commit is contained in:
Feng Shijie
2025-07-23 19:04:22 +00:00
parent 3f7d848dd3
commit 5a1183ebbd
7 changed files with 476 additions and 318 deletions

View File

@@ -165,6 +165,9 @@ struct sequence
return sequence<Is..., Xs...>{};
}
CK_TILE_HOST_DEVICE static constexpr auto sum() { return (Is + ... + 0); }
CK_TILE_HOST_DEVICE static constexpr auto product() { return (Is * ... * 1); }
// pickup element at index <Ids...>
template <index_t... Ids>
CK_TILE_HOST_DEVICE static constexpr auto extract(number<Ids>...)
@@ -1236,9 +1239,8 @@ constexpr auto reverse_slice_sequence(Seq,
template <typename Seq,
index_t SliceSize,
typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
constexpr auto slice_sequence(Seq,
number<SliceSize>,
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
constexpr auto
slice_sequence(Seq, number<SliceSize>, Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
{
constexpr auto r =
reverse_slice_sequence(Seq{}.reverse(), number<SliceSize>{}, Mask{}.reverse());