Implement grouped gemm tile loop for RDNA4 (#3304)

* feat: grouped gemm tile loop support for RDNA4

* fix: removed extra parameter from grouped gemm example instance

* fix: FP8 check incorrectly enabling FP8 on RDNA3
This commit is contained in:
Erwin Terpstra
2026-01-13 07:14:23 +01:00
committed by GitHub
parent 141f77aa12
commit eb041079a3
44 changed files with 3067 additions and 1223 deletions

View File

@@ -7,6 +7,7 @@
#include "ck/utility/sequence.hpp"
#include "ck/utility/type.hpp"
#include "ck/utility/enable_if.hpp"
#include <tuple>
namespace ck {
@@ -220,4 +221,49 @@ constexpr Tuple<Args&...> tie(Args&... args) noexcept
return {args...};
}
//
// tuple_map: Map tuple with a different type
// e.g. tuple_map<Wrapper, Tuple<T1, T2, T3>> becomes Tuple<Wrapper<T1>, Wrapper<T2>, Wrapper<T3>>
//
template <template <typename> class Wrapper, typename Tuple>
struct tuple_map;
template <template <typename> class Wrapper, typename... Ts>
struct tuple_map<Wrapper, Tuple<Ts...>>
{
using type = Tuple<Wrapper<Ts>...>;
};
template <template <typename> class Wrapper, typename Tuple>
using tuple_map_t = typename tuple_map<Wrapper, Tuple>::type;
//
// tuple_element_or: helper to access type element of a tuple by index, with the option to default
// to a type if the index is out of range of the tuple size
//
namespace detail {
// Base template (will be specialized on the boolean)
template <ck::index_t N, typename Tuple, typename Default, bool InRange = (N < Tuple::Size())>
struct tuple_element_or_impl;
// Specialization for the in-range case: use tuple_element_t
template <ck::index_t N, typename Tuple, typename Default>
struct tuple_element_or_impl<N, Tuple, Default, true>
{
using type = tuple_element_t<N, Tuple>;
};
// Specialization for the out-of-range case: use Default
template <ck::index_t N, typename Tuple, typename Default>
struct tuple_element_or_impl<N, Tuple, Default, false>
{
using type = Default;
};
} // namespace detail
// User-facing alias
template <ck::index_t N, typename Tuple, typename Default>
using tuple_element_or_t = typename detail::tuple_element_or_impl<N, Tuple, Default>::type;
} // namespace ck