mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-24 14:54:34 +00:00
CUTLASS 3.6.0 (#1850)
* v3.6 * update changelog * update readme * fix typo * fixing typos * hopper gemm with weight prefetch --------- Co-authored-by: yuzhai <yuzhai@nvidia.com> Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
@@ -44,7 +44,7 @@
|
||||
/// Code guidelines and style preferences:
|
||||
///
|
||||
/// For perfect forwarding, don't use std::forward, because it may not
|
||||
/// be defined in device code when compiling with NVRTC. Instead, use
|
||||
/// be defined in device code when compiling with NVRTC. Instead, use
|
||||
/// `static_cast<ParameterType&&>(parameter_name)`.
|
||||
///
|
||||
/// CuTe generally does not bother forwarding functions, as
|
||||
@@ -52,24 +52,9 @@
|
||||
///
|
||||
/// Throughout CUTLASS, cute::make_tuple always needs to be called
|
||||
/// namespace-qualified, EVEN If inside the cute namespace and/or in
|
||||
/// scope of a "using namespace cute" declaration. Otherwise, the
|
||||
/// scope of a "using namespace cute" declaration. Otherwise, the
|
||||
/// compiler may select std::make_tuple instead of cute::make_tuple,
|
||||
/// due to argument-dependent lookup. Two problems may result from
|
||||
/// that.
|
||||
///
|
||||
/// 1. Functions have an unexpected return type (std::tuple instead of
|
||||
/// cute::tuple), so functions that take cute::tuple parameters
|
||||
/// fail to compile (generally inside functions that have template
|
||||
/// parameters expected to be cute::tuple).
|
||||
///
|
||||
/// 2. std::tuple does not have the required __host__ __device__
|
||||
/// markings, so the CUDA compiler complains if you use it in
|
||||
/// device code.
|
||||
///
|
||||
/// cute::make_tuple will occur more often than std::make_tuple would
|
||||
/// in modern C++ code, because cute::tuple's design deprioritizes
|
||||
/// correct operation of CTAD (constructor template argument
|
||||
/// deduction) in favor of implementation simplicity.
|
||||
/// due to argument-dependent lookup.
|
||||
|
||||
namespace cute
|
||||
{
|
||||
@@ -145,6 +130,8 @@ transform_apply(T&& t, F&& f, G&& g)
|
||||
} else {
|
||||
return g(f(static_cast<T&&>(t)));
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class T0, class T1, class F, class G>
|
||||
@@ -157,6 +144,8 @@ transform_apply(T0&& t0, T1&& t1, F&& f, G&& g)
|
||||
} else {
|
||||
return g(f(static_cast<T0&&>(t0), static_cast<T1&&>(t1)));
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class T0, class T1, class T2, class F, class G>
|
||||
@@ -169,6 +158,8 @@ transform_apply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g)
|
||||
} else {
|
||||
return g(f(static_cast<T0&&>(t0), static_cast<T1&&>(t1), static_cast<T2&&>(t2)));
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
//
|
||||
@@ -401,71 +392,36 @@ filter_tuple(T0 const& t0, T1 const& t1, T2 const& t2, F&& f)
|
||||
|
||||
namespace detail {
|
||||
|
||||
// This impl compiles much faster than cute::apply and variadic args
|
||||
template <class T, class V, class F>
|
||||
template <class Fn, class Val>
|
||||
struct FoldAdaptor {
|
||||
template <class X>
|
||||
CUTE_HOST_DEVICE constexpr auto operator|(X&& x) {
|
||||
auto r = fn_(val_, static_cast<X&&>(x));
|
||||
return FoldAdaptor<Fn, decltype(r)>{fn_, r};
|
||||
}
|
||||
Fn fn_;
|
||||
Val val_;
|
||||
};
|
||||
|
||||
template <class T, class V, class F, int... Is>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
fold(T&&, V&& v, F&&, seq<>)
|
||||
fold(T&& t, V const& v, F&& f, seq<Is...>)
|
||||
{
|
||||
return v;
|
||||
return (FoldAdaptor<F,V>{f,v} | ... | get<Is>(static_cast<T&&>(t))).val_;
|
||||
}
|
||||
|
||||
template <class T, class V, class F, int I0>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
fold(T&& t, V&& v, F&& f, seq<I0>)
|
||||
{
|
||||
return f(static_cast<V&&>(v), get<I0>(static_cast<T&&>(t)));
|
||||
}
|
||||
|
||||
template <class T, class V, class F, int I0, int I1>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
fold(T&& t, V&& v, F&& f, seq<I0,I1>)
|
||||
{
|
||||
return f(f(static_cast<V&&>(v), get<I0>(static_cast<T&&>(t))), get<I1>(static_cast<T&&>(t)));
|
||||
}
|
||||
|
||||
template <class T, class V, class F, int I0, int I1, int I2>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
fold(T&& t, V&& v, F&& f, seq<I0,I1,I2>)
|
||||
{
|
||||
return f(f(f(static_cast<V&&>(v), get<I0>(static_cast<T&&>(t))), get<I1>(static_cast<T&&>(t))), get<I2>(static_cast<T&&>(t)));
|
||||
}
|
||||
|
||||
template <class T, class V, class F, int I0, int I1, int I2, int I3>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
fold(T&& t, V&& v, F&& f, seq<I0,I1,I2,I3>)
|
||||
{
|
||||
return f(f(f(f(static_cast<V&&>(v), get<I0>(static_cast<T&&>(t))), get<I1>(static_cast<T&&>(t))), get<I2>(static_cast<T&&>(t))), get<I3>(static_cast<T&&>(t)));
|
||||
}
|
||||
|
||||
template <class T, class V, class F, int I0, int I1, int I2, int I3, int... Is>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
fold(T&& t, V&& v, F&& f, seq<I0,I1,I2,I3,Is...>)
|
||||
{
|
||||
return fold(static_cast<T&&>(t),
|
||||
f(f(f(f(static_cast<V&&>(v), get<I0>(static_cast<T&&>(t))), get<I1>(static_cast<T&&>(t))), get<I2>(static_cast<T&&>(t))), get<I3>(static_cast<T&&>(t))),
|
||||
f,
|
||||
seq<Is...>{});
|
||||
}
|
||||
} // end namespace detail
|
||||
|
||||
template <class T, class V, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
fold(T&& t, V&& v, F&& f)
|
||||
fold(T&& t, V const& v, F&& f)
|
||||
{
|
||||
if constexpr (is_tuple<remove_cvref_t<T>>::value) {
|
||||
return detail::fold(static_cast<T&&>(t),
|
||||
static_cast<V&&>(v),
|
||||
f,
|
||||
tuple_seq<T>{});
|
||||
return detail::fold(static_cast<T&&>(t), v, f, tuple_seq<T>{});
|
||||
} else {
|
||||
return f(static_cast<V&&>(v), static_cast<T&&>(t));
|
||||
return f(v, static_cast<T&&>(t));
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
@@ -477,10 +433,7 @@ auto
|
||||
fold_first(T&& t, F&& f)
|
||||
{
|
||||
if constexpr (is_tuple<remove_cvref_t<T>>::value) {
|
||||
return detail::fold(static_cast<T&&>(t),
|
||||
get<0>(static_cast<T&&>(t)),
|
||||
f,
|
||||
make_range<1,tuple_size<remove_cvref_t<T>>::value>{});
|
||||
return detail::fold(static_cast<T&&>(t), get<0>(t), f, make_range<1,tuple_size<remove_cvref_t<T>>::value>{});
|
||||
} else {
|
||||
return t;
|
||||
}
|
||||
@@ -536,13 +489,23 @@ CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
take(T const& t)
|
||||
{
|
||||
return detail::apply(t, [](auto const&... a) { return cute::make_tuple(a...); }, make_range<B,E>{});
|
||||
if constexpr (E == -1) {
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return take<B,tuple_size<T>::value>(t);
|
||||
} else {
|
||||
return take<B,1>(t);
|
||||
}
|
||||
} else
|
||||
if constexpr (B <= E) {
|
||||
return detail::apply(t, [](auto const&... a) { return cute::make_tuple(a...); }, make_range<B,E>{});
|
||||
} else {
|
||||
static_assert(B <= E);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
//
|
||||
// Select tuple elements with given indices.
|
||||
//
|
||||
|
||||
template <int... I, class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
@@ -551,19 +514,6 @@ select(T const& t)
|
||||
return cute::make_tuple(get<I>(t)...);
|
||||
}
|
||||
|
||||
template <class T, class Indices>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
select(T const& t, Indices const& indices)
|
||||
{
|
||||
if constexpr (is_tuple<Indices>::value) {
|
||||
return cute::transform(indices, [&t](auto i) { return select(t, i); });
|
||||
} else {
|
||||
static_assert(is_static<Indices>::value, "Order must be static");
|
||||
return get<Indices::value>(t);
|
||||
}
|
||||
}
|
||||
|
||||
// Wrap non-tuples into rank-1 tuples or forward
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
|
||||
Reference in New Issue
Block a user