mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-19 22:38:56 +00:00
CUTLASS 3.2.1 (#1113)
* Updates for 3.2.1 release. * Minor fix in gemm op profiler for raster order. * Add scheduler mapping for raster order in the kernels.
This commit is contained in:
@@ -68,7 +68,14 @@ axpby(Alpha const& alpha,
|
||||
Beta const& beta,
|
||||
Tensor<YEngine, YLayout> & y)
|
||||
{
|
||||
auto isBetaZero = (beta == Int<0>{});
|
||||
auto isBetaZero = [&] () {
|
||||
if constexpr (is_complex<Beta>::value) {
|
||||
return beta.real() == Int<0>{} && beta.imag() == Int<0>{};
|
||||
}
|
||||
else {
|
||||
return beta == Int<0>{};
|
||||
}
|
||||
} ();
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size(x); ++i) {
|
||||
|
||||
@@ -218,7 +218,6 @@ gemm(MMA_Atom<MMA> const& mma,
|
||||
CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM
|
||||
CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN
|
||||
CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D));
|
||||
|
||||
gemm(mma,
|
||||
D, // (M,N)
|
||||
make_tensor(A.data(), append<2>(A.layout())), // (M,1)
|
||||
@@ -253,7 +252,7 @@ gemm(MMA_Atom<MMA> const& mma,
|
||||
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutC_TV{}) == Int<1>{});
|
||||
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutA_TV{}) == Int<1>{});
|
||||
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutB_TV{}) == Int<1>{});
|
||||
|
||||
|
||||
gemm(mma,
|
||||
make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N)
|
||||
make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K)
|
||||
@@ -282,7 +281,6 @@ gemm(MMA_Atom<MMA> const& mma,
|
||||
CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM
|
||||
CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN
|
||||
CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D));
|
||||
|
||||
auto M = size<1>(A);
|
||||
auto N = size<1>(B);
|
||||
// REGISTER .reuse OPTIMIZATIONS
|
||||
@@ -409,7 +407,6 @@ gemm(MMA_Atom<MMA> const& mma,
|
||||
CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN
|
||||
CUTE_STATIC_ASSERT_V(size<2>(A) == size<2>(B)); // AK == BK
|
||||
CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D));
|
||||
|
||||
auto K = size<2>(A);
|
||||
|
||||
CUTE_UNROLL
|
||||
@@ -454,7 +451,6 @@ gemm(MMA_Atom<MMA> const& mma,
|
||||
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutC_TV{}) == Int<1>{});
|
||||
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutA_TV{}) == Int<1>{});
|
||||
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutB_TV{}) == Int<1>{});
|
||||
|
||||
gemm(mma,
|
||||
make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N)
|
||||
make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K)
|
||||
|
||||
@@ -140,7 +140,11 @@ CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
transform_apply(T&& t, F&& f, G&& g)
|
||||
{
|
||||
return detail::tapply(static_cast<T&&>(t), f, g, tuple_seq<T>{});
|
||||
if constexpr (is_tuple<remove_cvref_t<T>>::value) {
|
||||
return detail::tapply(static_cast<T&&>(t), f, g, tuple_seq<T>{});
|
||||
} else {
|
||||
return g(f(static_cast<T&&>(t)));
|
||||
}
|
||||
}
|
||||
|
||||
template <class T0, class T1, class F, class G>
|
||||
@@ -148,7 +152,11 @@ CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
transform_apply(T0&& t0, T1&& t1, F&& f, G&& g)
|
||||
{
|
||||
return detail::tapply(static_cast<T0&&>(t0), static_cast<T1&&>(t1), f, g, tuple_seq<T0>{});
|
||||
if constexpr (is_tuple<remove_cvref_t<T0>>::value) {
|
||||
return detail::tapply(static_cast<T0&&>(t0), static_cast<T1&&>(t1), f, g, tuple_seq<T0>{});
|
||||
} else {
|
||||
return g(f(static_cast<T0&&>(t0), static_cast<T1&&>(t1)));
|
||||
}
|
||||
}
|
||||
|
||||
template <class T0, class T1, class T2, class F, class G>
|
||||
@@ -156,7 +164,11 @@ CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
transform_apply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g)
|
||||
{
|
||||
return detail::tapply(static_cast<T0&&>(t0), static_cast<T1&&>(t1), static_cast<T2&&>(t2), f, g, tuple_seq<T0>{});
|
||||
if constexpr (is_tuple<remove_cvref_t<T0>>::value) {
|
||||
return detail::tapply(static_cast<T0&&>(t0), static_cast<T1&&>(t1), static_cast<T2&&>(t2), f, g, tuple_seq<T0>{});
|
||||
} else {
|
||||
return g(f(static_cast<T0&&>(t0), static_cast<T1&&>(t1), static_cast<T2&&>(t2)));
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
@@ -306,21 +318,16 @@ transform_leaf(T0 const& t0, T1 const& t1, F&& f)
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
find_if(T const& t, F&& f, seq<>)
|
||||
{
|
||||
return cute::integral_constant<int, tuple_size<T>::value>{};
|
||||
}
|
||||
|
||||
template <class T, class F, int I, int... Is>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
find_if(T const& t, F&& f, seq<I,Is...>)
|
||||
{
|
||||
if constexpr (decltype(f(get<I>(t)))::value) {
|
||||
return cute::integral_constant<int, I>{};
|
||||
return cute::C<I>{};
|
||||
} else
|
||||
if constexpr (sizeof...(Is) == 0) {
|
||||
return cute::C<I+1>{};
|
||||
} else {
|
||||
return find_if(t, f, seq<Is...>{});
|
||||
}
|
||||
@@ -338,7 +345,7 @@ find_if(T const& t, F&& f)
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return detail::find_if(t, f, tuple_seq<T>{});
|
||||
} else {
|
||||
return cute::integral_constant<int, decltype(f(t))::value ? 0 : 1>{};
|
||||
return cute::C<decltype(f(t))::value ? 0 : 1>{};
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
@@ -355,12 +362,12 @@ find(T const& t, X const& x)
|
||||
template <class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
none_of(T const& t, F&& f)
|
||||
any_of(T const& t, F&& f)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return cute::integral_constant<bool, decltype(find_if(t, f))::value == tuple_size<T>::value>{};
|
||||
return detail::apply(cute::transform(t, f), [&] (auto const&... a) { return (false_type{} || ... || a); }, tuple_seq<T>{});
|
||||
} else {
|
||||
return not f(t);
|
||||
return f(t);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
@@ -372,8 +379,7 @@ auto
|
||||
all_of(T const& t, F&& f)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
auto not_f = [&](auto const& a) { return not f(a); };
|
||||
return cute::integral_constant<bool, decltype(find_if(t, not_f))::value == tuple_size<T>::value>{};
|
||||
return detail::apply(t, [&] (auto const&... a) { return (true_type{} && ... && f(a)); }, tuple_seq<T>{});
|
||||
} else {
|
||||
return f(t);
|
||||
}
|
||||
@@ -384,9 +390,9 @@ all_of(T const& t, F&& f)
|
||||
template <class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
any_of(T const& t, F&& f)
|
||||
none_of(T const& t, F&& f)
|
||||
{
|
||||
return not none_of(t, f);
|
||||
return not any_of(t, f);
|
||||
}
|
||||
|
||||
//
|
||||
@@ -410,6 +416,14 @@ filter_tuple(T0 const& t0, T1 const& t1, F&& f)
|
||||
return transform_apply(t0, t1, f, [](auto const&... a) { return cute::tuple_cat(a...); });
|
||||
}
|
||||
|
||||
template <class T0, class T1, class T2, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
filter_tuple(T0 const& t0, T1 const& t1, T2 const& t2, F&& f)
|
||||
{
|
||||
return transform_apply(t0, t1, t2, f, [](auto const&... a) { return cute::tuple_cat(a...); });
|
||||
}
|
||||
|
||||
//
|
||||
// Fold (Reduce, Accumulate)
|
||||
// (t, v, f) => f(...f(f(v,t_0),t_1),...,t_n)
|
||||
@@ -595,6 +609,13 @@ unwrap(T const& t)
|
||||
//
|
||||
// Flatten a hierarchical tuple to a tuple of depth one.
|
||||
//
|
||||
//
|
||||
|
||||
template <class T>
|
||||
struct is_flat : true_type {};
|
||||
|
||||
template <class... Ts>
|
||||
struct is_flat<tuple<Ts...>> : bool_constant<(true && ... && (not is_tuple<Ts>::value))> {};
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
@@ -602,7 +623,12 @@ auto
|
||||
flatten_to_tuple(T const& t)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
|
||||
if constexpr (is_flat<T>::value) {
|
||||
return t;
|
||||
} else
|
||||
{
|
||||
return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
|
||||
}
|
||||
} else {
|
||||
return cute::make_tuple(t);
|
||||
}
|
||||
@@ -616,7 +642,12 @@ auto
|
||||
flatten(T const& t)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
|
||||
if constexpr (is_flat<T>::value) {
|
||||
return t;
|
||||
} else
|
||||
{
|
||||
return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
|
||||
}
|
||||
} else {
|
||||
return t;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user