[CK_TILE] MX Flatmm Use Byte Pointer Arithmetic for A Tensor (#3446)

* A as bytes

* Reformat with static_for_product
This commit is contained in:
Yi DING
2025-12-19 10:28:13 +08:00
committed by GitHub
parent c0ee71d735
commit 2220cbaba7
4 changed files with 309 additions and 314 deletions

View File

@@ -82,6 +82,34 @@ struct static_for<0, N, 1> : detail::make_applier<N>
using detail::make_applier<N>::operator();
};
template <typename... Ts>
struct static_for_product;
template <index_t... Is>
struct static_for_product<static_for<Is...>> : public static_for<Is...>
{
};
template <index_t... Is>
struct static_for_product<sequence<Is...>> : public static_for<Is...>
{
};
template <index_t I>
struct static_for_product<number<I>> : public static_for<0, I, 1>
{
};
template <typename First, typename... Rest>
struct static_for_product<First, Rest...>
{
template <typename F>
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
{
static_for_product<First>{}([=](auto I) {
static_for_product<Rest...>{}([=](auto... Is) { //
f(I, Is...);
});
});
}
};
struct identity
{
template <typename T>