mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-19 22:38:56 +00:00
CUTLASS 3.3.0 (#1167)
* Release 3.3.0: adds support for mixed-precision GEMMs on Hopper and Ampere; adds support for < 16B aligned GEMMs on Hopper; enhancements to EVT; enhancements to the Python interface; enhancements to sub-byte type handling in CuTe; several other bug fixes and performance improvements. * Minor doc update.
This commit is contained in:
@@ -130,6 +130,17 @@ copy_if(PrdTensor const& pred,
|
||||
// copy_if -- Predicated CopyAtom
|
||||
//
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Compile-time trait: has_with_bool<T> is true when T has a nested type
// `Traits` and a value of that type can be called as `.with(bool)`;
// otherwise it is false.
// copy_if uses it to choose between passing the predicate value to the atom
// via `with(...)` and branching on the predicate itself.
|
||||
// Primary variable template: used whenever the partial specialization below
// is not viable, so the trait defaults to false. The second parameter exists
// only to give the specialization a place to run its validity check.
template<typename, typename Enable = void>
|
||||
constexpr bool has_with_bool = false;
|
||||
|
||||
// Partial specialization: selected only if the decltype expression is
// well-formed. If `T::Traits` does not exist, or `.with(bool)` is not callable
// on it, substitution fails silently and the primary template is used instead.
template<typename T>
|
||||
constexpr bool has_with_bool<T, cute::void_t<decltype(declval<typename T::Traits>().with(declval<bool>()))>> = true;
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
template <class... CopyArgs,
|
||||
class PredTensor,
|
||||
class SrcEngine, class SrcLayout,
|
||||
@@ -150,8 +161,14 @@ copy_if(Copy_Atom<CopyArgs...> const& copy_atom,
|
||||
auto dst_v = group_modes<1,R>(dst);
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size<1>(src_v); ++i) {
|
||||
if (pred(i)) {
|
||||
copy_atom.call(src_v(_,i), dst_v(_,i));
|
||||
// If copy traits can be transformed with a predicate value, do it, otherwise branch here
|
||||
if constexpr (detail::has_with_bool<Copy_Atom<CopyArgs...>>) {
|
||||
copy_atom.with(pred(i)).call(src_v(_,i), dst_v(_,i));
|
||||
}
|
||||
else {
|
||||
if (pred(i)) {
|
||||
copy_atom.call(src_v(_,i), dst_v(_,i));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -169,15 +186,17 @@ void
|
||||
copy_vec(Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> & dst)
|
||||
{
|
||||
using SrcType = typename SrcEngine::value_type;
|
||||
using DstType = typename DstEngine::value_type;
|
||||
using SrcType = typename SrcEngine::element_type;
|
||||
using DstType = typename DstEngine::element_type;
|
||||
if constexpr (sizeof(SrcType) == sizeof(DstType) && sizeof(VecType) > sizeof(DstType))
|
||||
{
|
||||
/* @pre is_aligned<N>(src.data()) &&
|
||||
* is_aligned<N>(dst.data())
|
||||
*/
|
||||
auto src_v = recast<VecType const>(src);
|
||||
auto dst_v = recast<VecType >(dst);
|
||||
using SrcVecType = conditional_t<is_volatile_v<SrcType>, VecType const volatile, VecType const>;
|
||||
using DstVecType = conditional_t<is_volatile_v<DstType>, VecType volatile, VecType >;
|
||||
auto src_v = recast<SrcVecType>(src);
|
||||
auto dst_v = recast<DstVecType>(dst);
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
|
||||
@@ -170,6 +170,76 @@ CUTE_NAMED_BINARY_OP(min_fn, cute::min);
|
||||
#undef CUTE_BINARY_OP
|
||||
#undef CUTE_NAMED_BINARY_OP
|
||||
|
||||
/**********/
|
||||
/** Fold **/
|
||||
/**********/
|
||||
|
||||
// CUTE_FOLD_OP(NAME, OP) defines four stateless function-object types that
// apply the operator OP across a parameter pack with a C++17 fold expression:
//
//   NAME_unary_rfold  : operator()(t...)     returns (t OP ...)
//   NAME_unary_lfold  : operator()(t...)     returns (... OP t)
//   NAME_binary_rfold : operator()(u, t...)  returns (t OP ... OP u)
//   NAME_binary_lfold : operator()(u, t...)  returns (u OP ... OP t)
//
// In the binary forms the first argument `u` is the initial value of the fold
// and the remaining arguments form the pack.
// Per the C++ rules for fold expressions, a unary fold over an empty pack is
// only valid for `&&` and `||` (and the comma operator); for every other OP
// the unary functors need at least one argument.
// The macro body ends without a semicolon, so each use must supply one.
#define CUTE_FOLD_OP(NAME,OP) \
|
||||
struct NAME##_unary_rfold { \
|
||||
template <class... T> \
|
||||
CUTE_HOST_DEVICE constexpr \
|
||||
auto operator()(T&&... t) const { \
|
||||
return (t OP ...); \
|
||||
} \
|
||||
}; \
|
||||
struct NAME##_unary_lfold { \
|
||||
template <class... T> \
|
||||
CUTE_HOST_DEVICE constexpr \
|
||||
auto operator()(T&&... t) const { \
|
||||
return (... OP t); \
|
||||
} \
|
||||
}; \
|
||||
struct NAME##_binary_rfold { \
|
||||
template <class U, class... T> \
|
||||
CUTE_HOST_DEVICE constexpr \
|
||||
auto operator()(U&& u, T&&... t) const { \
|
||||
return (t OP ... OP u); \
|
||||
} \
|
||||
}; \
|
||||
struct NAME##_binary_lfold { \
|
||||
template <class U, class... T> \
|
||||
CUTE_HOST_DEVICE constexpr \
|
||||
auto operator()(U&& u, T&&... t) const { \
|
||||
return (u OP ... OP t); \
|
||||
} \
|
||||
}
|
||||
|
||||
// Each line below expands CUTE_FOLD_OP once, defining the four functor types
// <name>_unary_rfold, <name>_unary_lfold, <name>_binary_rfold and
// <name>_binary_lfold for the given operator.

// Arithmetic operators.
CUTE_FOLD_OP(plus, +);
|
||||
CUTE_FOLD_OP(minus, -);
|
||||
CUTE_FOLD_OP(multiplies, *);
|
||||
CUTE_FOLD_OP(divides, /);
|
||||
CUTE_FOLD_OP(modulus, %);
|
||||
|
||||
// Arithmetic compound-assignment operators.
CUTE_FOLD_OP(plus_assign, +=);
|
||||
CUTE_FOLD_OP(minus_assign, -=);
|
||||
CUTE_FOLD_OP(multiplies_assign, *=);
|
||||
CUTE_FOLD_OP(divides_assign, /=);
|
||||
CUTE_FOLD_OP(modulus_assign, %=);
|
||||
|
||||
// Bitwise and shift operators.
CUTE_FOLD_OP(bit_and, &);
|
||||
CUTE_FOLD_OP(bit_or, |);
|
||||
CUTE_FOLD_OP(bit_xor, ^);
|
||||
CUTE_FOLD_OP(left_shift, <<);
|
||||
CUTE_FOLD_OP(right_shift, >>);
|
||||
|
||||
// Bitwise and shift compound-assignment operators.
CUTE_FOLD_OP(bit_and_assign, &=);
|
||||
CUTE_FOLD_OP(bit_or_assign, |=);
|
||||
CUTE_FOLD_OP(bit_xor_assign, ^=);
|
||||
CUTE_FOLD_OP(left_shift_assign, <<=);
|
||||
CUTE_FOLD_OP(right_shift_assign, >>=);
|
||||
|
||||
// Logical operators. These are the only operators in this list for which a
// unary fold over an empty pack is valid C++.
CUTE_FOLD_OP(logical_and, &&);
|
||||
CUTE_FOLD_OP(logical_or, ||);
|
||||
|
||||
// Comparison operators. A fold over a comparison nests the comparisons
// (e.g. a right fold of `<` over a, b, c is `a < (b < c)`), so with more than
// two operands the result is not a pairwise chain check.
CUTE_FOLD_OP(equal_to, ==);
|
||||
CUTE_FOLD_OP(not_equal_to, !=);
|
||||
CUTE_FOLD_OP(greater, >);
|
||||
CUTE_FOLD_OP(less, <);
|
||||
CUTE_FOLD_OP(greater_equal, >=);
|
||||
CUTE_FOLD_OP(less_equal, <=);
|
||||
|
||||
// The helper macro is removed here, so it cannot be used past this point.
#undef CUTE_FOLD_OP
|
||||
|
||||
/**********/
|
||||
/** Meta **/
|
||||
/**********/
|
||||
|
||||
Reference in New Issue
Block a user