CUTLASS 3.3.0 (#1167)

* Release 3.3.0

Adds support for mixed-precision GEMMs on Hopper and Ampere
Adds support for < 16B aligned GEMMs on Hopper
Enhancements to EVT
Enhancements to Python interface
Enhancements to Sub-byte type handling in CuTe
Several other bug-fixes and performance improvements.

* minor doc update
This commit is contained in:
Pradeep Ramani
2023-11-02 08:09:05 -07:00
committed by GitHub
parent 922fb5108b
commit c008b4aea8
263 changed files with 16214 additions and 5008 deletions

View File

@@ -130,6 +130,17 @@ copy_if(PrdTensor const& pred,
// copy_if -- Predicated CopyAtom
//
namespace detail {
// Trait that detects if atom's traits has a member function with(bool)
template<typename, typename Enable = void>
constexpr bool has_with_bool = false;
template<typename T>
constexpr bool has_with_bool<T, cute::void_t<decltype(declval<typename T::Traits>().with(declval<bool>()))>> = true;
} // end namespace detail
template <class... CopyArgs,
class PredTensor,
class SrcEngine, class SrcLayout,
@@ -150,8 +161,14 @@ copy_if(Copy_Atom<CopyArgs...> const& copy_atom,
auto dst_v = group_modes<1,R>(dst);
CUTE_UNROLL
for (int i = 0; i < size<1>(src_v); ++i) {
if (pred(i)) {
copy_atom.call(src_v(_,i), dst_v(_,i));
// If copy traits can be transformed with a predicate value, do it, otherwise branch here
if constexpr (detail::has_with_bool<Copy_Atom<CopyArgs...>>) {
copy_atom.with(pred(i)).call(src_v(_,i), dst_v(_,i));
}
else {
if (pred(i)) {
copy_atom.call(src_v(_,i), dst_v(_,i));
}
}
}
}
@@ -169,15 +186,17 @@ void
copy_vec(Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
{
using SrcType = typename SrcEngine::value_type;
using DstType = typename DstEngine::value_type;
using SrcType = typename SrcEngine::element_type;
using DstType = typename DstEngine::element_type;
if constexpr (sizeof(SrcType) == sizeof(DstType) && sizeof(VecType) > sizeof(DstType))
{
/* @pre is_aligned<N>(src.data()) &&
* is_aligned<N>(dst.data())
*/
auto src_v = recast<VecType const>(src);
auto dst_v = recast<VecType >(dst);
using SrcVecType = conditional_t<is_volatile_v<SrcType>, VecType const volatile, VecType const>;
using DstVecType = conditional_t<is_volatile_v<DstType>, VecType volatile, VecType >;
auto src_v = recast<SrcVecType>(src);
auto dst_v = recast<DstVecType>(dst);
#if 0
if (thread0()) {

View File

@@ -170,6 +170,76 @@ CUTE_NAMED_BINARY_OP(min_fn, cute::min);
#undef CUTE_BINARY_OP
#undef CUTE_NAMED_BINARY_OP
/**********/
/** Fold **/
/**********/
// CUTE_FOLD_OP(NAME, OP) stamps out four function-object structs that apply
// the C++17 fold-expression forms of operator OP over a parameter pack:
//   NAME_unary_rfold (t...)     -> (t0 OP (t1 OP ... OP tn))    unary right fold
//   NAME_unary_lfold (t...)     -> ((t0 OP t1) OP ... OP tn)    unary left fold
//   NAME_binary_rfold (u, t...) -> (t0 OP ... OP (tn OP u))     right fold, init u
//   NAME_binary_lfold (u, t...) -> ((u OP t0) OP ... OP tn)     left fold, init u
// NOTE: per the language rules, a unary fold over an empty pack is ill-formed
// except for &&, ||, and the comma operator, so callers of the unary forms
// must supply at least one argument for the other operators.
// (Comments are kept outside the macro body: a // comment on a continued
// line would swallow the trailing backslash.)
#define CUTE_FOLD_OP(NAME,OP) \
struct NAME##_unary_rfold { \
template <class... T> \
CUTE_HOST_DEVICE constexpr \
auto operator()(T&&... t) const { \
return (t OP ...); \
} \
}; \
struct NAME##_unary_lfold { \
template <class... T> \
CUTE_HOST_DEVICE constexpr \
auto operator()(T&&... t) const { \
return (... OP t); \
} \
}; \
struct NAME##_binary_rfold { \
template <class U, class... T> \
CUTE_HOST_DEVICE constexpr \
auto operator()(U&& u, T&&... t) const { \
return (t OP ... OP u); \
} \
}; \
struct NAME##_binary_lfold { \
template <class U, class... T> \
CUTE_HOST_DEVICE constexpr \
auto operator()(U&& u, T&&... t) const { \
return (u OP ... OP t); \
} \
}
// Arithmetic operators
CUTE_FOLD_OP(plus, +);
CUTE_FOLD_OP(minus, -);
CUTE_FOLD_OP(multiplies, *);
CUTE_FOLD_OP(divides, /);
CUTE_FOLD_OP(modulus, %);
// Compound-assignment operators (fold the assignment itself)
CUTE_FOLD_OP(plus_assign, +=);
CUTE_FOLD_OP(minus_assign, -=);
CUTE_FOLD_OP(multiplies_assign, *=);
CUTE_FOLD_OP(divides_assign, /=);
CUTE_FOLD_OP(modulus_assign, %=);
// Bitwise and shift operators
CUTE_FOLD_OP(bit_and, &);
CUTE_FOLD_OP(bit_or, |);
CUTE_FOLD_OP(bit_xor, ^);
CUTE_FOLD_OP(left_shift, <<);
CUTE_FOLD_OP(right_shift, >>);
CUTE_FOLD_OP(bit_and_assign, &=);
CUTE_FOLD_OP(bit_or_assign, |=);
CUTE_FOLD_OP(bit_xor_assign, ^=);
CUTE_FOLD_OP(left_shift_assign, <<=);
CUTE_FOLD_OP(right_shift_assign, >>=);
// Logical operators (the only ones valid on an empty unary fold)
CUTE_FOLD_OP(logical_and, &&);
CUTE_FOLD_OP(logical_or, ||);
// Comparison operators
CUTE_FOLD_OP(equal_to, ==);
CUTE_FOLD_OP(not_equal_to, !=);
CUTE_FOLD_OP(greater, >);
CUTE_FOLD_OP(less, <);
CUTE_FOLD_OP(greater_equal, >=);
CUTE_FOLD_OP(less_equal, <=);
#undef CUTE_FOLD_OP
/**********/
/** Meta **/
/**********/