CUTLASS 3.2.1 (#1113)

* Updates for 3.2.1 release.

* Minor fix in gemm op profiler for raster order.

* Add scheduler mapping for raster order in the kernels.
This commit is contained in:
ANIKET SHIVAM
2023-09-26 14:24:26 -07:00
committed by GitHub
parent e0aaa3c3b3
commit 90d3b0fb18
428 changed files with 22253 additions and 21762 deletions

View File

@@ -68,7 +68,14 @@ axpby(Alpha const& alpha,
Beta const& beta,
Tensor<YEngine, YLayout> & y)
{
auto isBetaZero = (beta == Int<0>{});
auto isBetaZero = [&] () {
if constexpr (is_complex<Beta>::value) {
return beta.real() == Int<0>{} && beta.imag() == Int<0>{};
}
else {
return beta == Int<0>{};
}
} ();
CUTE_UNROLL
for (int i = 0; i < size(x); ++i) {

View File

@@ -218,7 +218,6 @@ gemm(MMA_Atom<MMA> const& mma,
CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM
CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN
CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D));
gemm(mma,
D, // (M,N)
make_tensor(A.data(), append<2>(A.layout())), // (M,1)
@@ -253,7 +252,7 @@ gemm(MMA_Atom<MMA> const& mma,
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutC_TV{}) == Int<1>{});
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutA_TV{}) == Int<1>{});
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutB_TV{}) == Int<1>{});
gemm(mma,
make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N)
make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K)
@@ -282,7 +281,6 @@ gemm(MMA_Atom<MMA> const& mma,
CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM
CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN
CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D));
auto M = size<1>(A);
auto N = size<1>(B);
// REGISTER .reuse OPTIMIZATIONS
@@ -409,7 +407,6 @@ gemm(MMA_Atom<MMA> const& mma,
CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN
CUTE_STATIC_ASSERT_V(size<2>(A) == size<2>(B)); // AK == BK
CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D));
auto K = size<2>(A);
CUTE_UNROLL
@@ -454,7 +451,6 @@ gemm(MMA_Atom<MMA> const& mma,
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutC_TV{}) == Int<1>{});
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutA_TV{}) == Int<1>{});
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutB_TV{}) == Int<1>{});
gemm(mma,
make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N)
make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K)

View File

@@ -140,7 +140,11 @@ CUTE_HOST_DEVICE constexpr
auto
transform_apply(T&& t, F&& f, G&& g)
{
return detail::tapply(static_cast<T&&>(t), f, g, tuple_seq<T>{});
if constexpr (is_tuple<remove_cvref_t<T>>::value) {
return detail::tapply(static_cast<T&&>(t), f, g, tuple_seq<T>{});
} else {
return g(f(static_cast<T&&>(t)));
}
}
template <class T0, class T1, class F, class G>
@@ -148,7 +152,11 @@ CUTE_HOST_DEVICE constexpr
auto
transform_apply(T0&& t0, T1&& t1, F&& f, G&& g)
{
return detail::tapply(static_cast<T0&&>(t0), static_cast<T1&&>(t1), f, g, tuple_seq<T0>{});
if constexpr (is_tuple<remove_cvref_t<T0>>::value) {
return detail::tapply(static_cast<T0&&>(t0), static_cast<T1&&>(t1), f, g, tuple_seq<T0>{});
} else {
return g(f(static_cast<T0&&>(t0), static_cast<T1&&>(t1)));
}
}
template <class T0, class T1, class T2, class F, class G>
@@ -156,7 +164,11 @@ CUTE_HOST_DEVICE constexpr
auto
transform_apply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g)
{
return detail::tapply(static_cast<T0&&>(t0), static_cast<T1&&>(t1), static_cast<T2&&>(t2), f, g, tuple_seq<T0>{});
if constexpr (is_tuple<remove_cvref_t<T0>>::value) {
return detail::tapply(static_cast<T0&&>(t0), static_cast<T1&&>(t1), static_cast<T2&&>(t2), f, g, tuple_seq<T0>{});
} else {
return g(f(static_cast<T0&&>(t0), static_cast<T1&&>(t1), static_cast<T2&&>(t2)));
}
}
//
@@ -306,21 +318,16 @@ transform_leaf(T0 const& t0, T1 const& t1, F&& f)
namespace detail {
template <class T, class F>
CUTE_HOST_DEVICE constexpr
auto
find_if(T const& t, F&& f, seq<>)
{
return cute::integral_constant<int, tuple_size<T>::value>{};
}
template <class T, class F, int I, int... Is>
CUTE_HOST_DEVICE constexpr
auto
find_if(T const& t, F&& f, seq<I,Is...>)
{
if constexpr (decltype(f(get<I>(t)))::value) {
return cute::integral_constant<int, I>{};
return cute::C<I>{};
} else
if constexpr (sizeof...(Is) == 0) {
return cute::C<I+1>{};
} else {
return find_if(t, f, seq<Is...>{});
}
@@ -338,7 +345,7 @@ find_if(T const& t, F&& f)
if constexpr (is_tuple<T>::value) {
return detail::find_if(t, f, tuple_seq<T>{});
} else {
return cute::integral_constant<int, decltype(f(t))::value ? 0 : 1>{};
return cute::C<decltype(f(t))::value ? 0 : 1>{};
}
CUTE_GCC_UNREACHABLE;
@@ -355,12 +362,12 @@ find(T const& t, X const& x)
template <class T, class F>
CUTE_HOST_DEVICE constexpr
auto
none_of(T const& t, F&& f)
any_of(T const& t, F&& f)
{
if constexpr (is_tuple<T>::value) {
return cute::integral_constant<bool, decltype(find_if(t, f))::value == tuple_size<T>::value>{};
return detail::apply(cute::transform(t, f), [&] (auto const&... a) { return (false_type{} || ... || a); }, tuple_seq<T>{});
} else {
return not f(t);
return f(t);
}
CUTE_GCC_UNREACHABLE;
@@ -372,8 +379,7 @@ auto
all_of(T const& t, F&& f)
{
if constexpr (is_tuple<T>::value) {
auto not_f = [&](auto const& a) { return not f(a); };
return cute::integral_constant<bool, decltype(find_if(t, not_f))::value == tuple_size<T>::value>{};
return detail::apply(t, [&] (auto const&... a) { return (true_type{} && ... && f(a)); }, tuple_seq<T>{});
} else {
return f(t);
}
@@ -384,9 +390,9 @@ all_of(T const& t, F&& f)
template <class T, class F>
CUTE_HOST_DEVICE constexpr
auto
any_of(T const& t, F&& f)
none_of(T const& t, F&& f)
{
return not none_of(t, f);
return not any_of(t, f);
}
//
@@ -410,6 +416,14 @@ filter_tuple(T0 const& t0, T1 const& t1, F&& f)
return transform_apply(t0, t1, f, [](auto const&... a) { return cute::tuple_cat(a...); });
}
// Three-input overload of filter_tuple.
// Hands t0, t1, t2 and f to transform_apply; the final callable receives the
// results produced by f as a pack and joins them with cute::tuple_cat.
// NOTE(review): presumably f returns a (possibly empty) tuple per element
// triple, so that empty results drop out of the concatenation -- confirm
// against the one- and two-input overloads.
template <class T0, class T1, class T2, class F>
CUTE_HOST_DEVICE constexpr
auto
filter_tuple(T0 const& t0, T1 const& t1, T2 const& t2, F&& f)
{
return transform_apply(t0, t1, t2, f, [](auto const&... a) { return cute::tuple_cat(a...); });
}
//
// Fold (Reduce, Accumulate)
// (t, v, f) => f(...f(f(v,t_0),t_1),...,t_n)
@@ -595,6 +609,13 @@ unwrap(T const& t)
//
// Flatten a hierarchical tuple to a tuple of depth one.
//
//
// Compile-time trait: is T free of nested tuples?
// Primary template: any type that does not match the tuple<Ts...>
// specialization below is reported as flat (true_type).
template <class T>
struct is_flat : true_type {};
// tuple<Ts...>: flat only when none of the element types Ts satisfies
// is_tuple. The fold starts from `true`, so an empty tuple<> is flat.
template <class... Ts>
struct is_flat<tuple<Ts...>> : bool_constant<(true && ... && (not is_tuple<Ts>::value))> {};
template <class T>
CUTE_HOST_DEVICE constexpr
@@ -602,7 +623,12 @@ auto
flatten_to_tuple(T const& t)
{
if constexpr (is_tuple<T>::value) {
return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
if constexpr (is_flat<T>::value) {
return t;
} else
{
return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
}
} else {
return cute::make_tuple(t);
}
@@ -616,7 +642,12 @@ auto
flatten(T const& t)
{
if constexpr (is_tuple<T>::value) {
return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
if constexpr (is_flat<T>::value) {
return t;
} else
{
return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
}
} else {
return t;
}