CUTLASS 3.6.0 (#1850)

* v3.6

* update changelog

* update readme

* fix typo

* fixing typos

* hopper gemm with weight prefetch

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
Yujia Zhai
2024-10-09 12:33:27 -07:00
committed by GitHub
parent 0837a2a00a
commit cc3c29a81a
354 changed files with 105943 additions and 8203 deletions

View File

@@ -44,7 +44,7 @@
/// Code guidelines and style preferences:
///
/// For perfect forwarding, don't use std::forward, because it may not
/// be defined in device code when compiling with NVRTC. Instead, use
/// be defined in device code when compiling with NVRTC. Instead, use
/// `static_cast<ParameterType&&>(parameter_name)`.
///
/// CuTe generally does not bother forwarding functions, as
@@ -52,24 +52,9 @@
///
/// Throughout CUTLASS, cute::make_tuple always needs to be called
/// namespace-qualified, EVEN IF inside the cute namespace and/or in
/// scope of a "using namespace cute" declaration. Otherwise, the
/// scope of a "using namespace cute" declaration. Otherwise, the
/// compiler may select std::make_tuple instead of cute::make_tuple,
/// due to argument-dependent lookup. Two problems may result from
/// that.
///
/// 1. Functions have an unexpected return type (std::tuple instead of
/// cute::tuple), so functions that take cute::tuple parameters
/// fail to compile (generally inside functions that have template
/// parameters expected to be cute::tuple).
///
/// 2. std::tuple does not have the required __host__ __device__
/// markings, so the CUDA compiler complains if you use it in
/// device code.
///
/// cute::make_tuple will occur more often than std::make_tuple would
/// in modern C++ code, because cute::tuple's design deprioritizes
/// correct operation of CTAD (constructor template argument
/// deduction) in favor of implementation simplicity.
/// due to argument-dependent lookup.
namespace cute
{
@@ -145,6 +130,8 @@ transform_apply(T&& t, F&& f, G&& g)
} else {
return g(f(static_cast<T&&>(t)));
}
CUTE_GCC_UNREACHABLE;
}
template <class T0, class T1, class F, class G>
@@ -157,6 +144,8 @@ transform_apply(T0&& t0, T1&& t1, F&& f, G&& g)
} else {
return g(f(static_cast<T0&&>(t0), static_cast<T1&&>(t1)));
}
CUTE_GCC_UNREACHABLE;
}
template <class T0, class T1, class T2, class F, class G>
@@ -169,6 +158,8 @@ transform_apply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g)
} else {
return g(f(static_cast<T0&&>(t0), static_cast<T1&&>(t1), static_cast<T2&&>(t2)));
}
CUTE_GCC_UNREACHABLE;
}
//
@@ -401,71 +392,36 @@ filter_tuple(T0 const& t0, T1 const& t1, T2 const& t2, F&& f)
namespace detail {
// This impl compiles much faster than cute::apply and variadic args
template <class T, class V, class F>
// Accumulator wrapper for a left fold written as a fold-expression over
// operator|. It stores the binary callable together with the running value;
// each `adaptor | x` step evaluates fn_(val_, x) and yields a new adaptor
// that carries the same callable and the freshly computed value.
template <class Fn, class Val>
struct FoldAdaptor {
  Fn  fn_;   // binary callable, invoked as fn_(accumulated, element)
  Val val_;  // value accumulated so far

  // Combine the stored value with `x` and return the next adaptor in the chain.
  template <class X>
  CUTE_HOST_DEVICE constexpr auto operator|(X&& x) {
    auto next = fn_(val_, static_cast<X&&>(x));
    using Next = decltype(next);
    return FoldAdaptor<Fn, Next>{fn_, next};
  }
};
template <class T, class V, class F, int... Is>
CUTE_HOST_DEVICE constexpr
auto
fold(T&&, V&& v, F&&, seq<>)
fold(T&& t, V const& v, F&& f, seq<Is...>)
{
return v;
return (FoldAdaptor<F,V>{f,v} | ... | get<Is>(static_cast<T&&>(t))).val_;
}
// Single-index overload: applies f once to the initial value v and element I0
// of t, i.e. returns f(v, get<I0>(t)). Both v and t are forwarded with
// static_cast<...&&> rather than std::forward.
template <class T, class V, class F, int I0>
CUTE_HOST_DEVICE constexpr
auto
fold(T&& t, V&& v, F&& f, seq<I0>)
{
return f(static_cast<V&&>(v), get<I0>(static_cast<T&&>(t)));
}
// Two-index overload: left fold over elements I0 and I1 of t, i.e. returns
// f(f(v, get<I0>(t)), get<I1>(t)).
template <class T, class V, class F, int I0, int I1>
CUTE_HOST_DEVICE constexpr
auto
fold(T&& t, V&& v, F&& f, seq<I0,I1>)
{
return f(f(static_cast<V&&>(v), get<I0>(static_cast<T&&>(t))), get<I1>(static_cast<T&&>(t)));
}
// Three-index overload: left fold over elements I0, I1, I2 of t, i.e. returns
// f(f(f(v, get<I0>(t)), get<I1>(t)), get<I2>(t)).
template <class T, class V, class F, int I0, int I1, int I2>
CUTE_HOST_DEVICE constexpr
auto
fold(T&& t, V&& v, F&& f, seq<I0,I1,I2>)
{
return f(f(f(static_cast<V&&>(v), get<I0>(static_cast<T&&>(t))), get<I1>(static_cast<T&&>(t))), get<I2>(static_cast<T&&>(t)));
}
// Four-index overload: left fold over elements I0..I3 of t, i.e. returns
// f(f(f(f(v, get<I0>(t)), get<I1>(t)), get<I2>(t)), get<I3>(t)).
template <class T, class V, class F, int I0, int I1, int I2, int I3>
CUTE_HOST_DEVICE constexpr
auto
fold(T&& t, V&& v, F&& f, seq<I0,I1,I2,I3>)
{
return f(f(f(f(static_cast<V&&>(v), get<I0>(static_cast<T&&>(t))), get<I1>(static_cast<T&&>(t))), get<I2>(static_cast<T&&>(t))), get<I3>(static_cast<T&&>(t)));
}
// Five-or-more-index overload: folds the first four elements (I0..I3) inline,
// then recurses on the remaining indices Is..., passing the partial result as
// the new initial value.
template <class T, class V, class F, int I0, int I1, int I2, int I3, int... Is>
CUTE_HOST_DEVICE constexpr
auto
fold(T&& t, V&& v, F&& f, seq<I0,I1,I2,I3,Is...>)
{
return fold(static_cast<T&&>(t),
// Partial fold of the first four elements becomes the next call's initial value.
f(f(f(f(static_cast<V&&>(v), get<I0>(static_cast<T&&>(t))), get<I1>(static_cast<T&&>(t))), get<I2>(static_cast<T&&>(t))), get<I3>(static_cast<T&&>(t))),
f,
// Only the not-yet-consumed indices are passed on.
seq<Is...>{});
}
} // end namespace detail
template <class T, class V, class F>
CUTE_HOST_DEVICE constexpr
auto
fold(T&& t, V&& v, F&& f)
fold(T&& t, V const& v, F&& f)
{
if constexpr (is_tuple<remove_cvref_t<T>>::value) {
return detail::fold(static_cast<T&&>(t),
static_cast<V&&>(v),
f,
tuple_seq<T>{});
return detail::fold(static_cast<T&&>(t), v, f, tuple_seq<T>{});
} else {
return f(static_cast<V&&>(v), static_cast<T&&>(t));
return f(v, static_cast<T&&>(t));
}
CUTE_GCC_UNREACHABLE;
@@ -477,10 +433,7 @@ auto
fold_first(T&& t, F&& f)
{
if constexpr (is_tuple<remove_cvref_t<T>>::value) {
return detail::fold(static_cast<T&&>(t),
get<0>(static_cast<T&&>(t)),
f,
make_range<1,tuple_size<remove_cvref_t<T>>::value>{});
return detail::fold(static_cast<T&&>(t), get<0>(t), f, make_range<1,tuple_size<remove_cvref_t<T>>::value>{});
} else {
return t;
}
@@ -536,13 +489,23 @@ CUTE_HOST_DEVICE constexpr
auto
take(T const& t)
{
return detail::apply(t, [](auto const&... a) { return cute::make_tuple(a...); }, make_range<B,E>{});
if constexpr (E == -1) {
if constexpr (is_tuple<T>::value) {
return take<B,tuple_size<T>::value>(t);
} else {
return take<B,1>(t);
}
} else
if constexpr (B <= E) {
return detail::apply(t, [](auto const&... a) { return cute::make_tuple(a...); }, make_range<B,E>{});
} else {
static_assert(B <= E);
}
CUTE_GCC_UNREACHABLE;
}
//
// Select tuple elements with given indices.
//
template <int... I, class T>
CUTE_HOST_DEVICE constexpr
auto
@@ -551,19 +514,6 @@ select(T const& t)
return cute::make_tuple(get<I>(t)...);
}
// Select elements of t according to `indices`.
//  - If `indices` is not a tuple, it must be a static integral (enforced by
//    the static_assert); the result is get<Indices::value>(t).
//  - If `indices` is a tuple, select is applied to each of its entries via
//    cute::transform, so the result mirrors the structure of `indices`.
template <class T, class Indices>
CUTE_HOST_DEVICE constexpr
auto
select(T const& t, Indices const& indices)
{
  if constexpr (!is_tuple<Indices>::value) {
    // Leaf case: a single compile-time index.
    static_assert(is_static<Indices>::value, "Order must be static");
    return get<Indices::value>(t);
  } else {
    // Recursive case: map select over every entry of the index tuple.
    return cute::transform(indices, [&t](auto i) { return select(t, i); });
  }
}
// Wrap non-tuples into rank-1 tuples or forward
template <class T>
CUTE_HOST_DEVICE constexpr