mirror of https://github.com/NVIDIA/cutlass.git
3.6.0 update (#2005)
* 3.6.0 update
* doc and swap stuff
---------
Co-authored-by: yuzhai <yuzhai@nvidia.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
@@ -51,19 +51,14 @@ naive_cooperative_copy(uint32_t const& tid,
                        Tensor<SrcEngine, SrcLayout> const& src,
                        Tensor<DstEngine, DstLayout>      & dst)
 {
-  auto N = size(src);
-  if (tid < N) {
-    uint32_t upper_bound = (N / NumThreads) * NumThreads;
-    CUTE_UNROLL
-    for (uint32_t i = 0; i < upper_bound; i += NumThreads) {  // All in-bounds
-      dst[tid + i] = src[tid + i];
-    }
-    if (N % NumThreads != 0) {                                // Likely static condition
-      uint32_t final_idx = tid + upper_bound;
-      if (final_idx < N) {                                    // Final in-bounds
-        dst[final_idx] = src[final_idx];
-      }
-    }
-  }
+  auto N = size(dst);
+  auto R = N % Int<NumThreads>{};
+  if (R > 0 && tid < R) {              // Likely static condition && Residue in-bounds
+    dst[tid] = src[tid];
+  }
+  CUTE_UNROLL
+  for (uint32_t i = uint32_t(R); i < uint32_t(N); i += NumThreads) {  // All in-bounds
+    dst[tid + i] = src[tid + i];
+  }
 }
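The rewrite replaces the old upper-bound loop plus tail-check with a residue-first scheme: the N % NumThreads leftover elements are copied once up front, after which every strided iteration is exactly NumThreads wide and fully in-bounds. A minimal standalone sketch of the same pattern on raw pointers (a hypothetical helper for illustration, not part of the CuTe sources):

    // Residue-first cooperative copy: NT threads copy N elements.
    // tid is this thread's index in [0, NT).
    template <int NT>
    __device__ void residue_first_copy(unsigned tid, float const* src, float* dst, int N)
    {
      int R = N % NT;                     // residue handled once, up front
      if (R > 0 && tid < int(R) && tid < R) {
        dst[tid] = src[tid];
      }
      for (int i = R; i < N; i += NT) {   // every remaining access is in-bounds:
        dst[tid + i] = src[tid + i];      // tid + i <= (NT-1) + (N-NT) = N-1
      }
    }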
@@ -117,12 +112,14 @@ heuristic_permutation(Tensor<AEngine, ALayout> const& a,
 //
 template <uint32_t NumThreads, uint32_t MaxVecBits,
           class SrcEngine, class SrcLayout,
-          class DstEngine, class DstLayout>
+          class DstEngine, class DstLayout,
+          class CopyPolicy = DefaultCopy>
 CUTE_HOST_DEVICE
 void
 cooperative_copy(uint32_t const& tid,
                  Tensor<SrcEngine, SrcLayout> const& src,
-                 Tensor<DstEngine, DstLayout>      & dst)
+                 Tensor<DstEngine, DstLayout>      & dst,
+                 CopyPolicy                   const& cpy = {})
 {
   // Assumes the shapes are static, can generalize/fallback
   CUTE_STATIC_ASSERT_V(is_static<decltype(shape(src))>{} && is_static<decltype(shape(dst))>{});
@@ -283,23 +280,28 @@ cooperative_copy(uint32_t const& tid,
 
     // If we're using all threads (static) or the tid is in-range (dynamic)
     if (vec_thrs == NumThreads or tid < vec_thrs) {
-      return copy_if(TrivialPredTensor{}, recast<VecType const>(src_v), recast<VecType>(dst_v));
+      auto src_c = recast<VecType const>(src_v);
+      auto dst_c = recast<VecType>(dst_v);
+      return copy(cpy, src_c, dst_c);
     }
   }
 }
 
// Default max-vectorization size to value_type size
 template <uint32_t NumThreads,
           class SrcEngine, class SrcLayout,
-          class DstEngine, class DstLayout>
+          class DstEngine, class DstLayout,
+          class CopyPolicy = DefaultCopy>
 CUTE_HOST_DEVICE
 void
 cooperative_copy(uint32_t const& tid,
                  Tensor<SrcEngine, SrcLayout> const& src,
-                 Tensor<DstEngine, DstLayout>      & dst)
+                 Tensor<DstEngine, DstLayout>      & dst,
+                 CopyPolicy                   const& cpy = {})
 {
   constexpr uint32_t MaxVecBits = sizeof_bits_v<typename SrcEngine::value_type>;
-  return cooperative_copy<NumThreads, MaxVecBits>(tid, src, dst);
+  return cooperative_copy<NumThreads, MaxVecBits>(tid, src, dst, cpy);
 }
 
 //
@@ -308,26 +310,30 @@ cooperative_copy(uint32_t const& tid,
 
 template <uint32_t NumThreads,
           class SrcEngine, class SrcLayout,
-          class DstEngine, class DstLayout>
+          class DstEngine, class DstLayout,
+          class CopyPolicy = DefaultCopy>
 CUTE_HOST_DEVICE
 void
 cooperative_copy(uint32_t const& tid,
                  Tensor<SrcEngine, SrcLayout> const& src,
-                 Tensor<DstEngine, DstLayout>     && dst)
+                 Tensor<DstEngine, DstLayout>     && dst,
+                 CopyPolicy                   const& cpy = {})
 {
-  return cooperative_copy<NumThreads>(tid, src, dst);
+  return cooperative_copy<NumThreads>(tid, src, dst, cpy);
 }
 
 template <uint32_t NumThreads, uint32_t MaxVecBits,
           class SrcEngine, class SrcLayout,
-          class DstEngine, class DstLayout>
+          class DstEngine, class DstLayout,
+          class CopyPolicy = DefaultCopy>
 CUTE_HOST_DEVICE
 void
 cooperative_copy(uint32_t const& tid,
                  Tensor<SrcEngine, SrcLayout> const& src,
-                 Tensor<DstEngine, DstLayout>     && dst)
+                 Tensor<DstEngine, DstLayout>     && dst,
+                 CopyPolicy                   const& cpy = {})
 {
-  return cooperative_copy<NumThreads, MaxVecBits>(tid, src, dst);
+  return cooperative_copy<NumThreads, MaxVecBits>(tid, src, dst, cpy);
 }
 
 } // end namespace cute
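Every cooperative_copy overload now threads a CopyPolicy through to the final vectorized copy, defaulting to DefaultCopy. A hedged usage sketch: the tensor setup is elided and assumed to exist elsewhere; only the call shapes follow the signatures above.

    // Inside a kernel, called collectively by all NumThreads threads.
    // gmem_src and smem_dst are assumed CuTe tensors with static shapes.
    constexpr uint32_t NumThreads = 128;
    cute::cooperative_copy<NumThreads>(threadIdx.x, gmem_src, smem_dst);
    // Same copy, but requesting cp.async-style issuing where available:
    cute::cooperative_copy<NumThreads>(threadIdx.x, gmem_src, smem_dst, cute::AutoCopyAsync{});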
@@ -50,31 +50,115 @@ namespace cute
 
 namespace detail {
 
-// Predicated Cooperative GEMM
-template <class... Args,
-          class Alpha, class TA, class ALayout, class TB, class BLayout,
-          class Beta,  class TC, class CLayout,
-          class ALoadTransformOp, class BLoadTransformOp,
-          class CLoadTransformOp, class CStoreTransformOp,
-          __CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
-                          BLayout::rank == 2 && is_smem<TB>::value &&
-                          CLayout::rank == 2 && is_smem<TC>::value)>
+// Slow fallback path:
+template<typename ... Args,
+         typename Alpha, typename TRC, typename RCLayout,
+         typename Beta, class TSC, typename CLayout, typename SCLayout,
+         typename CLoadTransformOp, typename CStoreTransformOp>
 CUTE_HOST_DEVICE
 void
-cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
-                             Alpha const& alpha,
-                             Tensor<TA, ALayout> sA,
-                             Tensor<TB, BLayout> sB,
-                             Beta  const& beta,
-                             Tensor<TC, CLayout> sC,
-                             ALoadTransformOp  const& sA_load_op,   // transforms A values before use in GEMM
-                             BLoadTransformOp  const& sB_load_op,   // transforms B values before use in GEMM
-                             CLoadTransformOp  const& sC_load_op,   // transforms C values before use in GEMM
-                             CStoreTransformOp const& sC_store_op)  // transforms results before they are stored to C
+epilogue_predication(ThrMMA<Args...> const& thr_mma,
+                     Alpha const& alpha,
+                     Tensor<TRC, RCLayout> & tCrC,
+                     Beta  const& beta,
+                     Tensor<TSC, CLayout>  & sC,
+                     Tensor<TSC, SCLayout> & tCsC,
+                     CLoadTransformOp  const& sC_load_op,   // transforms C values before use in GEMM
+                     CStoreTransformOp const& sC_store_op)  // transforms results before they are stored to C
 {
-  using TypeA = typename TA::value_type;
-  using TypeB = typename TB::value_type;
-  using TypeC = typename TC::value_type;
+  using InputTypeC   = typename TSC::value_type;
+  using ComputeTypeC = typename ThrMMA<Args...>::ValTypeC;
+  CUTE_STATIC_ASSERT(CUTE_STL_NAMESPACE::is_same_v<ComputeTypeC, typename TRC::value_type>);
+
+  // Create coordinate tensors for the problem
+  Tensor cC   = make_identity_tensor(shape(sC));   // (M,N) -> (m,n)
+  // Repeat partitioning with thr_mma
+  Tensor tCcC = thr_mma.partition_C(cC);           // (MMA,MMA_M,MMA_N) -> (m,n)
+
+  const bool isBetaZero = [&] () {
+    if constexpr (is_complex<Beta>::value) {
+      return beta.real() == Int<0>{} && beta.imag() == Int<0>{};
+    }
+    else {
+      return beta == Int<0>{};
+    }
+    CUTE_GCC_UNREACHABLE;
+  } ();
+
+  // Custom axpby_if for now
+  CUTE_UNROLL
+  for (int i = 0; i < size(tCrC); ++i)
+  {
+    if (elem_less(tCcC(i), shape(sC)))
+    {
+      tCsC(i) = sC_store_op(isBetaZero ? alpha * tCrC(i)
+                                       : alpha * tCrC(i) +
+                                         beta  * static_cast<ComputeTypeC>(sC_load_op(tCsC(i))));
+    }
+  }
+}
+
+template<class Alpha, class TRC, class RCLayout,
+         class Beta, class TSC, class SCLayout,
+         class CLoadTransformOp, class CStoreTransformOp,
+         class SmemCopyOpC>
+CUTE_HOST_DEVICE
+void
+epilogue_no_predication(Alpha const& alpha,
+                        Tensor<TRC, RCLayout> & tCrC,
+                        Beta  const& beta,
+                        Tensor<TSC, SCLayout> & tCsC,
+                        CLoadTransformOp  const& sC_load_op,   // transforms C values before use in GEMM
+                        CStoreTransformOp const& sC_store_op,  // transforms results before they are stored to C
+                        SmemCopyOpC const& sC_copy_op)
+{
+  using InputTypeC   = typename TSC::value_type;
+  using ComputeTypeC = typename TRC::value_type;
+
+  const bool isBetaZero = [&] () {
+    if constexpr (is_complex<Beta>::value) {
+      return beta.real() == Int<0>{} && beta.imag() == Int<0>{};
+    }
+    else {
+      return beta == Int<0>{};
+    }
+    CUTE_GCC_UNREACHABLE;
+  } ();
+
+  Tensor tCrDi = make_fragment_like(tCsC);
+  Tensor tCrD  = make_fragment_like(tCrC);
+  if(!isBetaZero) {
+    copy(sC_copy_op, tCsC, tCrDi);
+    // Transform C on/after load
+    cute::transform(tCrDi, tCrD, sC_load_op);
+  }
+  // C = alpha * (A * B) + beta * C
+  axpby(alpha, tCrC, beta, tCrD);
+  // Transform C before/on store
+  cute::transform(tCrD, tCrDi, sC_store_op);
+  copy(sC_copy_op, tCrDi, tCsC);
+}
+
+// Predicated Cooperative GEMM
+template <class... Args,
+          class TA, class ALayout, class TB, class BLayout,
+          class TC, class RCLayout,
+          class ALoadTransformOp, class BLoadTransformOp>
+CUTE_HOST_DEVICE
+void
+cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
+                             Tensor<TA, ALayout> const& sA,
+                             Tensor<TB, BLayout> const& sB,
+                             Tensor<TC, RCLayout>     & tCrC,
+                             ALoadTransformOp const& sA_load_op,  // transforms A values before use in GEMM
+                             BLoadTransformOp const& sB_load_op)  // transforms B values before use in GEMM
+{
+  using InputTypeA   = typename TA::value_type;
+  using InputTypeB   = typename TB::value_type;
+  using InputTypeC   = typename TC::value_type;
+  using ComputeTypeA = typename ThrMMA<Args...>::ValTypeA;
+  using ComputeTypeB = typename ThrMMA<Args...>::ValTypeB;
+  using ComputeTypeC = typename ThrMMA<Args...>::ValTypeC;
 
   //
   // MMA Partitioning
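Both new epilogue helpers perform the same update, C = alpha * acc + beta * C, resolving beta == 0 once (including the complex case) so that C is never read when it would only be scaled by zero. A simplified scalar model of the per-element decision (names and types illustrative, not the CuTe code):

    // Scalar analogue of the epilogue update applied to each C element.
    template <class TC, class TAcc, class Alpha, class Beta>
    TC epilogue_update(Alpha alpha, TAcc acc, Beta beta, TC c, bool beta_is_zero)
    {
      // When beta == 0 the load-and-scale of C is skipped entirely, which also
      // avoids propagating NaNs or garbage from an uninitialized C buffer.
      return beta_is_zero ? TC(alpha * acc)
                          : TC(alpha * acc + beta * static_cast<TAcc>(c));
    }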
@@ -83,22 +167,18 @@ cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
   // Partition the sA, sB, and sC tiles across the threads for the MMA
   Tensor tCsA = thr_mma.partition_A(sA);                   // (MMA,MMA_M,MMA_K)
   Tensor tCsB = thr_mma.partition_B(sB);                   // (MMA,MMA_N,MMA_K)
-  Tensor tCsC = thr_mma.partition_C(sC);                   // (MMA,MMA_M,MMA_N)
 
   // Create register tensors for the MMA to operate on
   Tensor tCrA = thr_mma.make_fragment_A(tCsA);             // (MMA,MMA_M,MMA_K)
   Tensor tCrB = thr_mma.make_fragment_B(tCsB);             // (MMA,MMA_N,MMA_K)
-  Tensor tCrC = thr_mma.make_fragment_C(tCsC);             // (MMA,MMA_M,MMA_N)
 
 #if 0
   if (thread0()) {
     print("  sA: "); print(  sA); print("\n");
     print("  sB: "); print(  sB); print("\n");
-    print("  sC: "); print(  sC); print("\n");
     print(thr_mma);
     print("tCsA: "); print(tCsA); print("\n");
     print("tCsB: "); print(tCsB); print("\n");
-    print("tCsC: "); print(tCsC); print("\n");
     print("tCrA: "); print(tCrA); print("\n");
     print("tCrB: "); print(tCrB); print("\n");
     print("tCrC: "); print(tCrC); print("\n");
@@ -154,23 +234,20 @@ cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
   for (int m = 0; m < size<1>(tCrA); ++m) {        // Copy MMA_M
     CUTE_UNROLL
     for (int i = 0; i < size<0>(tCrA); ++i) {      // Copy MMA_I
-      tCrA(i,m,0) = (tCpA(i,m) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,0)), shape<1>(sA)))) ? sA_load_op(tCsA(i,m,0)) : TypeA{};
+      tCrA(i,m,0) = (tCpA(i,m) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,0)), shape<1>(sA)))) ? static_cast<ComputeTypeA>(sA_load_op(tCsA(i,m,0))) : ComputeTypeA{};
     }
   }
   CUTE_UNROLL
   for (int n = 0; n < size<1>(tCrB); ++n) {        // Copy MMA_N
     CUTE_UNROLL
     for (int i = 0; i < size<0>(tCrB); ++i) {      // Copy MMA_I
-      tCrB(i,n,0) = (tCpB(i,n) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,0)), shape<1>(sB)))) ? sB_load_op(tCsB(i,n,0)) : TypeB{};
+      tCrB(i,n,0) = (tCpB(i,n) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,0)), shape<1>(sB)))) ? static_cast<ComputeTypeB>(sB_load_op(tCsB(i,n,0))) : ComputeTypeB{};
     }
   }
   //
   // MAINLOOP
   //
 
-  // Clear accumulators
-  clear(tCrC);
-
   CUTE_UNROLL
   for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block)
   {
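The change above makes the predicated loads honor the MMA's compute types: shared-memory values in InputTypeA/B are cast to the ThrMMA's ValTypeA/B, and out-of-bounds lanes are zero-filled in the compute type rather than the storage type. A tiny sketch of why the explicit widening matters when input and compute types differ (types assumed for illustration):

    #include <cuda_fp16.h>
    // Input stored as half, computed in float: the load must widen before the
    // FMA, and masked-off lanes must be zero in the *compute* type.
    __device__ float load_a_element(__half const* sA, int idx, bool in_bounds)
    {
      return in_bounds ? static_cast<float>(sA[idx]) : float{};  // ComputeTypeA{} analogue
    }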
@@ -185,138 +262,80 @@ cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
       for (int m = 0; m < size<1>(tCrA); ++m) {    // Copy MMA_M
         CUTE_UNROLL
         for (int i = 0; i < size<0>(tCrA); ++i) {  // Copy MMA_I
-          tCrA(i,m,k_next) = (tCpA(i,m) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,k_next)), shape<1>(sA)))) ? sA_load_op(tCsA(i,m,k_next)) : TypeA{};
+          tCrA(i,m,k_next) = (tCpA(i,m) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,k_next)), shape<1>(sA)))) ? static_cast<ComputeTypeA>(sA_load_op(tCsA(i,m,k_next))) : ComputeTypeA{};
         }
       }
       CUTE_UNROLL
       for (int n = 0; n < size<1>(tCrB); ++n) {    // Copy MMA_N
         CUTE_UNROLL
         for (int i = 0; i < size<0>(tCrB); ++i) {  // Copy MMA_I
-          tCrB(i,n,k_next) = (tCpB(i,n) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,k_next)), shape<1>(sB)))) ? sB_load_op(tCsB(i,n,k_next)) : TypeB{};
+          tCrB(i,n,k_next) = (tCpB(i,n) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,k_next)), shape<1>(sB)))) ? static_cast<ComputeTypeB>(sB_load_op(tCsB(i,n,k_next))) : ComputeTypeB{};
         }
       }
     }
     // GEMM on k_block in registers
     gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
   }
-
-  //
-  // Epilogue
-  //
-
-  // Create coordinate tensors for the problem
-  Tensor cC   = make_identity_tensor(shape(sC));   // (M,N) -> (m,n)
-  // Repeat partitioning with thr_mma
-  Tensor tCcC = thr_mma.partition_C(cC);           // (MMA,MMA_M,MMA_N) -> (m,n)
-
-  const bool isBetaZero = (beta == Beta{});
-
-  // Custom axpby_if for now
-  CUTE_UNROLL
-  for (int i = 0; i < size(tCrC); ++i)
-  {
-    if (elem_less(tCcC(i), shape(sC)))
-    {
-      tCsC(i) = sC_store_op(isBetaZero ? alpha * static_cast<TypeC>(tCrC(i))
-                                       : alpha * static_cast<TypeC>(tCrC(i)) +
-                                         beta  * static_cast<TypeC>(sC_load_op(tCsC(i))));
-    }
-  }
 }
 
-// Slow fallback path
-template <class... Args,
-          class Alpha, class TA, class ALayout, class TB, class BLayout,
-          class Beta,  class TC, class CLayout,
-          class ALoadTransformOp, class BLoadTransformOp,
-          class CLoadTransformOp, class CStoreTransformOp,
-          __CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
-                          BLayout::rank == 2 && is_smem<TB>::value &&
-                          CLayout::rank == 2 && is_smem<TC>::value)>
-CUTE_HOST_DEVICE
-void
-cooperative_gemm_predication(uint32_t thread_idx,
-                             TiledMMA<Args...> const& tiled_mma,
-                             Alpha const& alpha,
-                             Tensor<TA, ALayout> sA,
-                             Tensor<TB, BLayout> sB,
-                             Beta  const& beta,
-                             Tensor<TC, CLayout> sC,
-                             ALoadTransformOp  const& sA_load_op,   // transforms A values before use in GEMM
-                             BLoadTransformOp  const& sB_load_op,   // transforms B values before use in GEMM
-                             CLoadTransformOp  const& sC_load_op,   // transforms C values before use in GEMM
-                             CStoreTransformOp const& sC_store_op)  // transforms results before they are stored to C
-{
-  // ThrMMA
-  auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
-  cooperative_gemm_predication(thr_mma, alpha, sA, sB, beta, sC, sA_load_op, sB_load_op, sC_load_op, sC_store_op);
-}
-
 // Unpredicated Cooperative GEMM
-template <class SmemCopyOpA, class SmemCopyOpB, class SmemCopyOpC,
-          class... Args,
-          class Alpha, class TA, class ALayout, class TB, class BLayout,
-          class Beta,  class TC, class CLayout,
+template <class... Args,
+          class TA, class ALayout, class TB, class BLayout,
+          class TC, class CLayout,
           class ALoadTransformOp, class BLoadTransformOp,
-          class CLoadTransformOp, class CStoreTransformOp,
-          __CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
-                          BLayout::rank == 2 && is_smem<TB>::value &&
-                          CLayout::rank == 2 && is_smem<TC>::value)>
+          class SmemCopyOpA, class SmemCopyOpB>
 CUTE_HOST_DEVICE
 void
-cooperative_gemm_no_predication(uint32_t thread_idx,
-                                TiledMMA<Args...> const& tiled_mma,
-                                Alpha const& alpha,
-                                Tensor<TA, ALayout> sA,
-                                Tensor<TB, BLayout> sB,
-                                Beta  const& beta,
-                                Tensor<TC, CLayout> sC,
-                                ALoadTransformOp  const& sA_load_op,   // transforms A values before use in GEMM
-                                BLoadTransformOp  const& sB_load_op,   // transforms B values before use in GEMM
-                                CLoadTransformOp  const& sC_load_op,   // transforms C values before use in GEMM
-                                CStoreTransformOp const& sC_store_op)  // transforms results before they are stored to C
+cooperative_gemm_no_predication(uint32_t thread_idx,
+                                ThrMMA<Args...> const& thr_mma,
+                                Tensor<TA, ALayout> const& sA,
+                                Tensor<TB, BLayout> const& sB,
+                                Tensor<TC, CLayout>      & tCrC,
+                                ALoadTransformOp const& sA_load_op,  // transforms A values before use in GEMM
+                                BLoadTransformOp const& sB_load_op,  // transforms B values before use in GEMM
+                                SmemCopyOpA const& sA_copy_op,
+                                SmemCopyOpB const& sB_copy_op)
 {
-  using TypeA = typename TA::value_type;
-  using TypeB = typename TB::value_type;
-  using TypeC = typename TC::value_type;
+  using InputTypeA   = typename TA::value_type;
+  using InputTypeB   = typename TB::value_type;
+  using InputTypeC   = typename TC::value_type;
+  using ComputeTypeA = typename ThrMMA<Args...>::ValTypeA;
+  using ComputeTypeB = typename ThrMMA<Args...>::ValTypeB;
+  using ComputeTypeC = typename ThrMMA<Args...>::ValTypeC;
 
-  // ThrMMA
-  auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
-
   //
   // MMA Partitioning
   //
 
-  Tensor tCsC = thr_mma.partition_C(sC);
   // Create register tensors for the MMA to operate on
   Tensor tCrA = thr_mma.partition_fragment_A(sA);  // (MMA,MMA_M,MMA_K)
   Tensor tCrB = thr_mma.partition_fragment_B(sB);  // (MMA,MMA_N,MMA_K)
-  Tensor tCrC = thr_mma.make_fragment_C(tCsC);     // (MMA,MMA_M,MMA_N)
 
   using CopyOpAType = SmemCopyOpA;
   using CopyOpBType = SmemCopyOpB;
 
-  auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom<CopyOpAType, TypeA>{}, thr_mma);
+  auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom<CopyOpAType, InputTypeA>{}, thr_mma);
   auto smem_thr_copy_A   = smem_tiled_copy_A.get_thread_slice(thread_idx);
   Tensor tCsA            = smem_thr_copy_A.partition_S(sA);
   Tensor tCrA_copy_view  = smem_thr_copy_A.retile_D(tCrA);
   CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));   // CPY_M
   CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view));   // CPY_K
+  Tensor tCrAi           = make_fragment_like(tCsA);
+  Tensor tCrAi_copy_view = smem_thr_copy_A.retile_D(tCrAi);
+  CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrAi_copy_view));  // CPY_M
+  CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrAi_copy_view));  // CPY_K
 
-  auto smem_tiled_copy_B = make_tiled_copy_B(Copy_Atom<CopyOpBType, TypeB>{}, thr_mma);
+  auto smem_tiled_copy_B = make_tiled_copy_B(Copy_Atom<CopyOpBType, InputTypeB>{}, thr_mma);
   auto smem_thr_copy_B   = smem_tiled_copy_B.get_thread_slice(thread_idx);
   Tensor tCsB            = smem_thr_copy_B.partition_S(sB);
   Tensor tCrB_copy_view  = smem_thr_copy_B.retile_D(tCrB);
   CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));   // CPY_N
   CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrB_copy_view));   // CPY_K
+  Tensor tCrBi           = make_fragment_like(tCsB);
+  Tensor tCrBi_copy_view = smem_thr_copy_B.retile_D(tCrBi);
+  CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrBi_copy_view));  // CPY_N
+  CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrBi_copy_view));  // CPY_K
 
 #if 0
   if (thread0()) {
     print("  sA: "); print(sA); print("\n");
     print("  sB: "); print(sB); print("\n");
-    print("  sC: "); print(sC); print("\n");
     print(thr_mma); print("\n");
-    print("tCsC: "); print(tCsC); print("\n");
     print("tCrA: "); print(tCrA); print("\n");
     print("tCrB: "); print(tCrB); print("\n");
     print("tCrC: "); print(tCrC); print("\n");
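The prefetch now targets the input-type staging fragments tCrAi/tCrBi, which the mainloop (next hunk) transforms into the compute-type fragments tCrA/tCrB one k-block ahead of the MMA. A self-contained sketch of that software-pipelining shape on plain arrays (illustrative only, not the CuTe code):

    // Stage slice k+1 into a double buffer while slice k is being consumed.
    template <int K_BLOCKS, int N>
    __device__ void pipelined_sum(float const* __restrict__ src, float* __restrict__ out)
    {
      float buf[2][N];
      for (int i = 0; i < N; ++i) { buf[0][i] = src[i]; }  // prefetch slice 0
      float acc = 0.f;
      for (int k = 0; k < K_BLOCKS; ++k) {
        if (k + 1 < K_BLOCKS) {
          // Staging the next slice overlaps with the compute below.
          for (int i = 0; i < N; ++i) { buf[(k + 1) % 2][i] = src[(k + 1) * N + i]; }
        }
        for (int i = 0; i < N; ++i) { acc += buf[k % 2][i]; }  // consume current slice
      }
      *out = acc;
    }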
@@ -333,15 +352,12 @@ cooperative_gemm_no_predication(uint32_t thread_idx,
   // PREFETCH
   //
 
-  copy(smem_tiled_copy_A, tCsA(_,_,Int<0>{}), tCrA_copy_view(_,_,Int<0>{}));
-  copy(smem_tiled_copy_B, tCsB(_,_,Int<0>{}), tCrB_copy_view(_,_,Int<0>{}));
+  copy(smem_tiled_copy_A, tCsA(_,_,Int<0>{}), tCrAi_copy_view(_,_,Int<0>{}));
+  copy(smem_tiled_copy_B, tCsB(_,_,Int<0>{}), tCrBi_copy_view(_,_,Int<0>{}));
   //
   // MAINLOOP
   //
 
-  // Clear accumulators
-  clear(tCrC);
-
   constexpr int K_BLOCK_MAX = size<2>(tCrA);
 
   CUTE_UNROLL
@@ -352,132 +368,178 @@ cooperative_gemm_no_predication(uint32_t thread_idx,
     {
       // Load the next k_block
      int k_next = k_block + 1;  // statically unrolled
-      copy(smem_tiled_copy_A, tCsA(_,_,k_next), tCrA_copy_view(_,_,k_next));
-      copy(smem_tiled_copy_B, tCsB(_,_,k_next), tCrB_copy_view(_,_,k_next));
+      copy(smem_tiled_copy_A, tCsA(_,_,k_next), tCrAi_copy_view(_,_,k_next));
+      copy(smem_tiled_copy_B, tCsB(_,_,k_next), tCrBi_copy_view(_,_,k_next));
     }
 
     // Transform A and B, relying on the compiler to remove in case of identity ops
-    cute::transform(tCrA(_,_,k_block), sA_load_op);
-    cute::transform(tCrB(_,_,k_block), sB_load_op);
+    cute::transform(tCrAi(_,_,k_block), tCrA(_,_,k_block), sA_load_op);
+    cute::transform(tCrBi(_,_,k_block), tCrB(_,_,k_block), sB_load_op);
 
     // GEMM on k_block in registers
     gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
   }
-
-  //
-  // Epilogue
-  //
-
-  auto isBetaZero = [&] () {
-    if constexpr (is_complex<Beta>::value) {
-      return beta.real() == Int<0>{} && beta.imag() == Int<0>{};
-    }
-    else {
-      return beta == Int<0>{};
-    }
-    CUTE_GCC_UNREACHABLE;
-  } ();
-
-  using CopyOpCType = SmemCopyOpC;
-  Tensor tCrD = thr_mma.make_fragment_C(tCsC);
-  if(!isBetaZero) {
-    copy(CopyOpCType{}, tCsC, tCrD);
-    // Transform C on/after load
-    cute::transform(tCrD, sC_load_op);
-  }
-  // C = alpha * (A * B) + beta * C
-  axpby(alpha, tCrC, beta, tCrD);
-  // Transform C before/on store
-  cute::transform(tCrD, sC_store_op);
-  copy(CopyOpCType{}, tCrD, tCsC);
 }
 
 } // end namespace detail
 
-template <class SmemCopyOpA, class SmemCopyOpB, class SmemCopyOpC,
-          class... Args,
-          class Alpha, class TA, class ALayout, class TB, class BLayout,
-          class Beta,  class TC, class CLayout,
-          class ALoadTransformOp = cute::identity, class BLoadTransformOp = cute::identity,
-          class CLoadTransformOp = cute::identity, class CStoreTransformOp = cute::identity,
-          __CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
-                          BLayout::rank == 2 && is_smem<TB>::value &&
-                          CLayout::rank == 2 && is_smem<TC>::value)>
-CUTE_HOST_DEVICE
-void
-cooperative_gemm(uint32_t thread_idx,
-                 TiledMMA<Args...> const& tiled_mma,
-                 Alpha const& alpha,
-                 Tensor<TA, ALayout> sA,
-                 Tensor<TB, BLayout> sB,
-                 Beta  const& beta,
-                 Tensor<TC, CLayout> sC,
-                 ALoadTransformOp  const& sA_load_op  = {},  // transforms A values before use in GEMM
-                 BLoadTransformOp  const& sB_load_op  = {},  // transforms B values before use in GEMM
-                 CLoadTransformOp  const& sC_load_op  = {},  // transforms C values before use in GEMM
-                 CStoreTransformOp const& sC_store_op = {})  // transforms results before they are stored to C
-{
-  CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC));  // AM == CM
-  CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC));  // BN == CN
-  CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB));  // AK == BK
-
-  using TypeA = typename TA::value_type;
-  using TypeB = typename TB::value_type;
-  using TypeC = typename TC::value_type;
-
-  static_assert(is_convertible_v<decay_t<invoke_result_t<ALoadTransformOp, TypeA>>, TypeA>,
-                "ALoadTransformOp functor must accept value of type TA::value_type and return value convertible to type TA::value_type");
-  static_assert(is_convertible_v<decay_t<invoke_result_t<BLoadTransformOp, TypeB>>, TypeB>,
-                "BLoadTransformOp functor must accept value of type TB::value_type and return value convertible to type TB::value_type");
-  static_assert(is_convertible_v<decay_t<invoke_result_t<CLoadTransformOp, TypeC>>, TypeC>,
-                "CLoadTransformOp functor must accept value of type TC::value_type and return value convertible to type TC::value_type");
-  static_assert(is_convertible_v<decay_t<invoke_result_t<CStoreTransformOp, TypeC>>, TypeC>,
-                "CStoreTransformOp functor must accept value of type TC::value_type and return value convertible to type TC::value_type");
-
-  static constexpr bool compat = evenly_divides(make_shape(size<0>(sA), size<0>(sB), size<1>(sA)),
-                                                tile_shape(TiledMMA<Args...>{}));
-  if constexpr (compat) {
-    detail::cooperative_gemm_no_predication<SmemCopyOpA, SmemCopyOpB, SmemCopyOpC>(
-      thread_idx, tiled_mma, alpha, sA, sB, beta, sC,
-      sA_load_op, sB_load_op, sC_load_op, sC_store_op
-    );
-  } else {
-    detail::cooperative_gemm_predication(
-      thread_idx, tiled_mma, alpha, sA, sB, beta, sC,
-      sA_load_op, sB_load_op, sC_load_op, sC_store_op
-    );
-  }
-}
-
+// C passed as a shared memory tensor
+// Epilogue included
 template <class... Args,
           class Alpha, class TA, class ALayout, class TB, class BLayout,
           class Beta,  class TC, class CLayout,
           class ALoadTransformOp = cute::identity, class BLoadTransformOp = cute::identity,
           class CLoadTransformOp = cute::identity, class CStoreTransformOp = cute::identity,
-          __CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
-                          BLayout::rank == 2 && is_smem<TB>::value &&
-                          CLayout::rank == 2 && is_smem<TC>::value)>
+          class SmemCopyOpA = DefaultCopy, class SmemCopyOpB = DefaultCopy,
+          class SmemCopyOpC = DefaultCopy>
 CUTE_HOST_DEVICE
 void
 cooperative_gemm(uint32_t thread_idx,
                  TiledMMA<Args...> const& tiled_mma,
                  Alpha const& alpha,
+                 Tensor<TA, ALayout> const& sA,
+                 Tensor<TB, BLayout> const& sB,
+                 Beta  const& beta,
+                 Tensor<TC, CLayout>      & sC,
+                 ALoadTransformOp  const& sA_load_op  = {},  // transforms A values before use in GEMM
+                 BLoadTransformOp  const& sB_load_op  = {},  // transforms B values before use in GEMM
+                 CLoadTransformOp  const& sC_load_op  = {},  // transforms C values before use in GEMM
+                 CStoreTransformOp const& sC_store_op = {},  // transforms results before they are stored to C
+                 SmemCopyOpA const& sA_copy_op = {},
+                 SmemCopyOpB const& sB_copy_op = {},
+                 SmemCopyOpC const& sC_copy_op = {})
+{
+  CUTE_STATIC_ASSERT_V(rank(sA) == Int<2>{});
+  CUTE_STATIC_ASSERT_V(rank(sB) == Int<2>{});
+  CUTE_STATIC_ASSERT_V(rank(sC) == Int<2>{});
+
+  CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC));  // AM == CM
+  CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC));  // BN == CN
+  CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB));  // AK == BK
+
+  using InputTypeA   = typename TA::value_type;
+  using InputTypeB   = typename TB::value_type;
+  using InputTypeC   = typename TC::value_type;
+  using ComputeTypeA = typename TiledMMA<Args...>::ValTypeA;
+  using ComputeTypeB = typename TiledMMA<Args...>::ValTypeB;
+  using ComputeTypeC = typename TiledMMA<Args...>::ValTypeC;
+
+  auto compat = evenly_divides(make_shape(size<0>(sA), size<0>(sB), size<1>(sA)),
+                               tile_shape(TiledMMA<Args...>{}));
+
+  // ThrMMA
+  auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
+  Tensor tCsC = thr_mma.partition_C(sC);        // (MMA,MMA_M,MMA_N) :: InputTypeC
+  Tensor tCrC = thr_mma.make_fragment_C(tCsC);  // (MMA,MMA_M,MMA_N) :: ComputeTypeC
+
+  // Clear accumulators
+  clear(tCrC);
+
+#if 0
+  if (thread0()) {
+    print("   sC: "); print(sC);   print("\n");
+    print(" tCsC: "); print(tCsC); print("\n");
+  }
+#endif
+
+  if constexpr (is_constant<true, decltype(compat)>::value) {
+    detail::cooperative_gemm_no_predication(
+      thread_idx, thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op, sA_copy_op, sB_copy_op
+    );
+    detail::epilogue_no_predication(
+      alpha, tCrC, beta, tCsC, sC_load_op, sC_store_op, sC_copy_op
+    );
+  } else {
+    detail::cooperative_gemm_predication(
+      thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op
+    );
+    detail::epilogue_predication(
+      thr_mma, alpha, tCrC, beta, sC, tCsC, sC_load_op, sC_store_op
+    );
+  }
+}
+
+// C already partitioned into registers on input
+// It can be passed non-empty
+// Epilogue not included
+template <class... Args,
+          class TA, class ALayout, class TB, class BLayout,
+          class TC, class CLayout,
+          class ALoadTransformOp = cute::identity, class BLoadTransformOp = cute::identity,
+          class SmemCopyOpA = DefaultCopy, class SmemCopyOpB = DefaultCopy>
+CUTE_HOST_DEVICE
+void
+cooperative_gemm(uint32_t thread_idx,
+                 TiledMMA<Args...> const& tiled_mma,
+                 Tensor<TA, ALayout> const& sA,
+                 Tensor<TB, BLayout> const& sB,
+                 Tensor<TC, CLayout>      & tCrC,
+                 ALoadTransformOp const& sA_load_op = {},  // transforms A values before use in GEMM
+                 BLoadTransformOp const& sB_load_op = {},  // transforms B values before use in GEMM
+                 SmemCopyOpA const& sA_copy_op = {},
+                 SmemCopyOpB const& sB_copy_op = {})
+{
+  CUTE_STATIC_ASSERT_V(rank(sA) == Int<2>{});
+  CUTE_STATIC_ASSERT_V(rank(sB) == Int<2>{});
+
+  CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB));  // AK == BK
+
+  using InputTypeA   = typename TA::value_type;
+  using InputTypeB   = typename TB::value_type;
+  using InputTypeC   = typename TC::value_type;
+  using ComputeTypeA = typename TiledMMA<Args...>::ValTypeA;
+  using ComputeTypeB = typename TiledMMA<Args...>::ValTypeB;
+  using ComputeTypeC = typename TiledMMA<Args...>::ValTypeC;
+
+  // Check if input C fragment is compatible with thr_mma and problem size
+  using ref_c_frag = decltype(partition_shape_C(tiled_mma, make_shape(size<0>(sA), size<0>(sB))));
+  CUTE_STATIC_ASSERT_V(compatible(shape(ref_c_frag{}), shape(tCrC)));
+
+  auto compat = evenly_divides(make_shape(size<0>(sA), size<0>(sB), size<1>(sA)),
+                               tile_shape(TiledMMA<Args...>{}));
+
+  // ThrMMA
+  auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
+
+  if constexpr (is_constant<true, decltype(compat)>::value) {
+    detail::cooperative_gemm_no_predication(
+      thread_idx, thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op, sA_copy_op, sB_copy_op
+    );
+  } else {
+    detail::cooperative_gemm_predication(
+      thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op
+    );
+  }
+}
+
 // Accept mutable temporaries
 template <class... Args,
           class Alpha, class TA, class ALayout, class TB, class BLayout,
           class Beta,  class TC, class CLayout,
           class ALoadTransformOp = cute::identity, class BLoadTransformOp = cute::identity,
           class CLoadTransformOp = cute::identity, class CStoreTransformOp = cute::identity,
+          class SmemCopyOpA = DefaultCopy, class SmemCopyOpB = DefaultCopy,
+          class SmemCopyOpC = DefaultCopy>
 CUTE_HOST_DEVICE
 void
 cooperative_gemm(uint32_t thread_idx,
-                 TiledMMA<Args...> const& tiled_mma,
-                 Alpha const& alpha,
-                 Tensor<TA, ALayout> sA,
-                 Tensor<TB, BLayout> sB,
-                 Beta  const& beta,
-                 Tensor<TC, CLayout> sC,
-                 ALoadTransformOp  const& sA_load_op  = {},  // transforms A values before use in GEMM
-                 BLoadTransformOp  const& sB_load_op  = {},  // transforms B values before use in GEMM
-                 CLoadTransformOp  const& sC_load_op  = {},  // transforms C values before use in GEMM
-                 CStoreTransformOp const& sC_store_op = {})  // transforms results before they are stored to C
+                 TiledMMA<Args...> const& tiled_mma,
+                 Alpha const& alpha,
+                 Tensor<TA, ALayout> const& sA,
+                 Tensor<TB, BLayout> const& sB,
+                 Beta  const& beta,
+                 Tensor<TC, CLayout>     && sC,
+                 ALoadTransformOp  const& sA_load_op  = {},  // transforms A values before use in GEMM
+                 BLoadTransformOp  const& sB_load_op  = {},  // transforms B values before use in GEMM
+                 CLoadTransformOp  const& sC_load_op  = {},  // transforms C values before use in GEMM
+                 CStoreTransformOp const& sC_store_op = {},  // transforms results before they are stored to C
+                 SmemCopyOpA const& sA_copy_op = {},
+                 SmemCopyOpB const& sB_copy_op = {},
+                 SmemCopyOpC const& sC_copy_op = {})
 {
-  using CopyOpA = AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<typename TA::value_type>>;
-  using CopyOpB = AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<typename TB::value_type>>;
-  using CopyOpC = AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<typename TC::value_type>>;
-  cooperative_gemm<CopyOpA, CopyOpB, CopyOpC>(
-    thread_idx, tiled_mma, alpha, sA, sB, beta, sC,
-    sA_load_op, sB_load_op, sC_load_op, sC_store_op
-  );
+  cooperative_gemm(thread_idx, tiled_mma, alpha, sA, sB, beta, sC,
+                   sA_load_op, sB_load_op, sC_load_op, sC_store_op,
+                   sA_copy_op, sB_copy_op, sC_copy_op);
 }
 
 // Legacy overload of cute::gemm for backwards-compatibility
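The public API now has two shapes: an epilogue-included form that reads and writes C in shared memory, and an epilogue-excluded form that accumulates into a caller-owned register fragment (which may arrive non-empty). A hedged usage sketch; tiled_mma, the tensors, and num_k_tiles are assumed to be set up elsewhere, and sA_tiles/sB_tiles are hypothetical per-k-tile views:

    // Epilogue included: sC = alpha * sA*sB + beta * sC.
    cute::cooperative_gemm(threadIdx.x, tiled_mma, alpha, sA, sB, beta, sC);

    // Epilogue excluded: accumulate across k-tiles, then finish manually.
    auto thr_mma = tiled_mma.get_thread_slice(threadIdx.x);
    auto tCrC    = thr_mma.make_fragment_C(thr_mma.partition_C(sC));
    cute::clear(tCrC);
    for (int k = 0; k < num_k_tiles; ++k) {
      cute::cooperative_gemm(threadIdx.x, tiled_mma, sA_tiles(_,_,k), sB_tiles(_,_,k), tCrC);
    }
    // ...custom epilogue on tCrC here...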
@@ -485,27 +547,38 @@ template <class... Args,
           class Alpha, class TA, class ALayout, class TB, class BLayout,
           class Beta,  class TC, class CLayout,
           class ALoadTransformOp = cute::identity, class BLoadTransformOp = cute::identity,
-          class CLoadTransformOp = cute::identity, class CStoreTransformOp = cute::identity,
-          __CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
-                          BLayout::rank == 2 && is_smem<TB>::value &&
-                          CLayout::rank == 2 && is_smem<TC>::value)>
+          class CLoadTransformOp = cute::identity, class CStoreTransformOp = cute::identity>
 CUTE_HOST_DEVICE
 void
-gemm(ThrMMA<Args...> const& thr_mma,
-     Alpha const& alpha,
-     Tensor<TA, ALayout> sA,
-     Tensor<TB, BLayout> sB,
-     Beta  const& beta,
-     Tensor<TC, CLayout> sC,
-     ALoadTransformOp  const& sA_load_op  = {},  // transforms A values before use in GEMM
-     BLoadTransformOp  const& sB_load_op  = {},  // transforms B values before use in GEMM
-     CLoadTransformOp  const& sC_load_op  = {},  // transforms C values before use in GEMM
-     CStoreTransformOp const& sC_store_op = {})  // transforms results before they are stored to C
+gemm(ThrMMA<Args...> const& thr_mma,
+     Alpha const& alpha,
+     Tensor<TA, ALayout> const& sA,
+     Tensor<TB, BLayout> const& sB,
+     Beta  const& beta,
+     Tensor<TC, CLayout>      & sC,
+     ALoadTransformOp  const& sA_load_op  = {},  // transforms A values before use in GEMM
+     BLoadTransformOp  const& sB_load_op  = {},  // transforms B values before use in GEMM
+     CLoadTransformOp  const& sC_load_op  = {},  // transforms C values before use in GEMM
+     CStoreTransformOp const& sC_store_op = {})  // transforms results before they are stored to C
 {
+  CUTE_STATIC_ASSERT_V(rank(sA) == Int<2>{});
+  CUTE_STATIC_ASSERT_V(rank(sB) == Int<2>{});
+  CUTE_STATIC_ASSERT_V(rank(sC) == Int<2>{});
+
   CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC));  // AM == CM
   CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC));  // BN == CN
   CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB));  // AK == BK
 
+  Tensor tCsC = thr_mma.partition_C(sC);        // (MMA,MMA_M,MMA_N)
+  Tensor tCrC = thr_mma.make_fragment_C(tCsC);  // (MMA,MMA_M,MMA_N)
+  clear(tCrC);
+
+  // Goes directly to the slow path to avoid getting thread_idx from thr_mma
   detail::cooperative_gemm_predication(
-    thr_mma, alpha, sA, sB, beta, sC,
-    sA_load_op, sB_load_op, sC_load_op, sC_store_op
+    thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op
   );
+
+  detail::epilogue_predication(
+    thr_mma, alpha, tCrC, beta, sC, tCsC, sC_load_op, sC_store_op
+  );
 }
@@ -38,79 +38,6 @@
 namespace cute
 {
 
-//
-// Accept mutable temporaries
-//
-
-template <class SrcEngine, class SrcLayout,
-          class DstEngine, class DstLayout>
-CUTE_HOST_DEVICE
-void
-copy(Tensor<SrcEngine, SrcLayout> const& src,
-     Tensor<DstEngine, DstLayout>     && dst)
-{
-  return copy(src, dst);
-}
-
-template <class VecType,
-          class SrcEngine, class SrcLayout,
-          class DstEngine, class DstLayout>
-CUTE_HOST_DEVICE
-void
-copy_vec(Tensor<SrcEngine, SrcLayout> const& src,
-         Tensor<DstEngine, DstLayout>     && dst)
-{
-  return copy_vec<VecType>(src, dst);
-}
-
-template <class SrcEngine, class SrcLayout,
-          class DstEngine, class DstLayout>
-CUTE_HOST_DEVICE
-void
-copy_aligned(Tensor<SrcEngine, SrcLayout> const& src,
-             Tensor<DstEngine, DstLayout>     && dst)
-{
-  return copy_aligned(src, dst);
-}
-
-template <class PrdTensor,
-          class SrcEngine, class SrcLayout,
-          class DstEngine, class DstLayout>
-CUTE_HOST_DEVICE
-void
-copy_if(PrdTensor const& pred,
-        Tensor<SrcEngine, SrcLayout> const& src,
-        Tensor<DstEngine, DstLayout>     && dst)
-{
-  return copy_if(pred, src, dst);
-}
-
-template <class CopyPolicy,
-          class PrdTensor,
-          class SrcEngine, class SrcLayout,
-          class DstEngine, class DstLayout>
-CUTE_HOST_DEVICE
-void
-copy_if(CopyPolicy const& copy_policy,
-        PrdTensor  const& pred,
-        Tensor<SrcEngine, SrcLayout> const& src,
-        Tensor<DstEngine, DstLayout>     && dst)
-{
-  return copy_if(copy_policy, pred, src, dst);
-}
-
-template <class CopyPolicy,
-          class SrcEngine, class SrcLayout,
-          class DstEngine, class DstLayout>
-CUTE_HOST_DEVICE
-void
-copy(CopyPolicy const& copy_policy,
-     Tensor<SrcEngine, SrcLayout> const& src,
-     Tensor<DstEngine, DstLayout>     && dst)
-{
-  return copy(copy_policy, src, dst);
-}
-
 //
 // copy_if -- Predicated Copy
 //
@@ -124,12 +51,13 @@ copy_if(PrdTensor const& pred,
         Tensor<SrcEngine, SrcLayout> const& src,
         Tensor<DstEngine, DstLayout>      & dst)
 {
-  auto copy_op = select_elementwise_copy(src, dst);
+  using SrcType = typename SrcEngine::value_type;
+  using DstType = typename DstEngine::value_type;
 
   CUTE_UNROLL
-  for (int i = 0; i < size(src); ++i) {
+  for (int i = 0; i < size(dst); ++i) {
     if (pred(i)) {
-      copy_op.copy(src(i), dst(i));
+      dst(i) = static_cast<DstType>(static_cast<SrcType>(src(i)));
     }
   }
 }
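The fallback copy_if now lowers each predicated element through an explicit value round-trip, src to SrcType to DstType, instead of a selected copy_op; the inner cast forces a real load through the source value type (for example from a proxy reference) before converting to the destination type. A scalar model of those semantics (plain C++, illustrative):

    // Semantics of the per-element fallback: predicated value conversion.
    template <class SrcT, class DstT, class Pred>
    void copy_if_fallback(SrcT const* src, DstT* dst, int n, Pred pred)
    {
      for (int i = 0; i < n; ++i) {
        if (pred(i)) {
          dst[i] = static_cast<DstT>(static_cast<SrcT>(src[i]));  // two-step cast, as in the diff
        }
      }
    }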
@@ -138,17 +66,6 @@ copy_if(PrdTensor const& pred,
 // copy_if -- Predicated CopyAtom
 //
 
-namespace detail {
-
-// Trait that detects if atom's traits has a member function with(bool)
-template <class, class Enable = void>
-constexpr bool has_with_bool = false;
-
-template <class T>
-constexpr bool has_with_bool<T, cute::void_t<decltype(declval<typename T::Traits>().with(declval<bool>()))>> = true;
-
-} // end namespace detail
-
 template <class... CopyArgs,
           class PredTensor,
           class SrcEngine, class SrcLayout,
@@ -161,73 +78,90 @@ copy_if(Copy_Atom<CopyArgs...> const& copy_atom,
         Tensor<DstEngine, DstLayout>      & dst)  // (V,Rest...)
 {
   static_assert(SrcLayout::rank == DstLayout::rank, "CopyAtom rank-mismatch.");
+  auto has_with_bool = cute::is_valid([](auto t)->void_t<decltype(declval<typename decltype(t)::Traits>().with(true))>{}, copy_atom);
 
   if constexpr (SrcLayout::rank == 1) {   // Dispatch the copy
-    copy_atom.call(src, dst);
+    if constexpr (has_with_bool) {
+      copy_atom.with(pred()).call(src, dst);
+    } else {
+      if (pred()) { copy_atom.call(src, dst); }
+    }
   } else {                                // Loop over all but the first mode
     constexpr int R = SrcLayout::rank;
     Tensor src_v = group_modes<1,R>(src);
     Tensor dst_v = group_modes<1,R>(dst);
     CUTE_UNROLL
-    for (int i = 0; i < size<1>(src_v); ++i) {
-      // If copy traits can be transformed with a predicate value, do it, otherwise branch here
-      if constexpr (detail::has_with_bool<Copy_Atom<CopyArgs...>>) {
+    for (int i = 0; i < size<1>(dst_v); ++i) {
+      if constexpr (has_with_bool) {
        copy_atom.with(pred(i)).call(src_v(_,i), dst_v(_,i));
       } else {
-        if (pred(i)) {
-          copy_atom.call(src_v(_,i), dst_v(_,i));
-        }
+        if (pred(i)) { copy_atom.call(src_v(_,i), dst_v(_,i)); }
       }
     }
   }
 }
 
 //
-// copy_vec -- attempt vectorized copy with VecType
+// copy_if -- AutoCopyAsync
 //
 
-template <class VecType,
+template <class PrdTensor,
           class SrcEngine, class SrcLayout,
           class DstEngine, class DstLayout>
 CUTE_HOST_DEVICE
 void
-copy_vec(Tensor<SrcEngine, SrcLayout> const& src,
-         Tensor<DstEngine, DstLayout>      & dst)
+copy_if(AutoCopyAsync const& cpy,
+        PrdTensor const& pred,
+        Tensor<SrcEngine, SrcLayout> const& src,
+        Tensor<DstEngine, DstLayout>      & dst)
 {
-  static_assert(sizeof_bits_v<VecType> >= 8 && sizeof_bits_v<VecType> % 8 == 0,
-                "Expected a vectorization type of at least a byte.");
+  using SrcElemWithConst = remove_reference_t<typename SrcEngine::reference>;
   using SrcType = typename SrcEngine::value_type;
   using DstType = typename DstEngine::value_type;
-  if constexpr (cute::is_same<SrcType, DstType>::value &&
-                sizeof_bits_v<VecType> > sizeof_bits_v<DstType>)
-  {
-    // Preserve volatility of Src/Dst types.
-    using SrcVecType = conditional_t<is_volatile_v<typename SrcEngine::element_type>, VecType const volatile, VecType const>;
-    using DstVecType = conditional_t<is_volatile_v<typename DstEngine::element_type>, VecType volatile, VecType >;
-    Tensor src_v = recast<SrcVecType>(src);
-    Tensor dst_v = recast<DstVecType>(dst);
-
-#if 0
-    if (thread0()) {
-      print("copy_vec<%db> -- vectorizing copy:\n", int(sizeof_bits_v<VecType>));
-      print("   "); print(src); print(" => "); print(src_v); print("\n");
-      print("   "); print(dst); print(" => "); print(dst_v); print("\n");
-    }
-#endif
-
-    return copy_if(TrivialPredTensor{}, src_v, dst_v);
-  } else {
-#if 0
-    if (thread0()) {
-      print("copy_vec<%db> -- NOT vectorizing copy:\n", int(sizeof_bits_v<VecType>));
-      print("   "); print(src); print("\n");
-      print("   "); print(dst); print("\n");
-    }
-#endif
-
-    return copy_if(TrivialPredTensor{}, src, dst);
-  }
-}
+  auto copy_op = []() {
+#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
+    if constexpr (is_gmem<SrcEngine>::value && is_smem<DstEngine>::value &&
+                  sizeof(SrcType) == sizeof(DstType)) {
+      if constexpr (is_const_v<SrcElemWithConst> && sizeof(SrcType) == 16) {
+        return SM80_CP_ASYNC_CACHEGLOBAL<SrcType,DstType>{};
+      } else if constexpr (sizeof(SrcType) == 4 || sizeof(SrcType) == 8 || sizeof(SrcType) == 16) {
+        return SM80_CP_ASYNC_CACHEALWAYS<SrcType,DstType>{};
+      } else {
+        return UniversalCopy<SrcType,DstType>{};
+      }
+    } else {
+      return UniversalCopy<SrcType,DstType>{};
+    }
+    CUTE_GCC_UNREACHABLE;
+#else
+    return UniversalCopy<SrcType,DstType>{};
+#endif
+  }();
+
+  CUTE_UNROLL
+  for (int i = 0; i < size(dst); ++i) {
+    if (pred(i)) {
+      copy_op.copy(src(i), dst(i));
+    }
+  }
+}
+
+//
+// copy -- AutoCopyAsync
+//
+
+template <class SrcEngine, class SrcLayout,
+          class DstEngine, class DstLayout>
+CUTE_HOST_DEVICE
+void
+copy(AutoCopyAsync const& cpy,
+     Tensor<SrcEngine, SrcLayout> const& src,  // (V,Rest...)
+     Tensor<DstEngine, DstLayout>      & dst)  // (V,Rest...)
+{
+  copy_if(cpy, TrivialPredTensor{}, src, dst);
+}
 
 //
 // copy -- CopyAtom
 //
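The member detection formerly done with a void_t partial specialization now happens inline via cute::is_valid, testing whether Traits.with(bool) is well-formed for the atom at hand. The same idiom in standalone C++17, using the classic void_t detector (names illustrative):

    #include <type_traits>
    #include <utility>

    // Is T::Traits{}.with(true) well-formed?
    template <class, class = void>
    struct has_with_bool : std::false_type {};

    template <class T>
    struct has_with_bool<T, std::void_t<decltype(std::declval<typename T::Traits>().with(true))>>
        : std::true_type {};

    struct TraitsWithPred { TraitsWithPred with(bool) const { return *this; } };
    struct AtomA { using Traits = TraitsWithPred; };
    struct AtomB { using Traits = int; };

    static_assert( has_with_bool<AtomA>::value);
    static_assert(!has_with_bool<AtomB>::value);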
@@ -238,15 +172,56 @@ template <class... CopyArgs,
 CUTE_HOST_DEVICE
 void
 copy(Copy_Atom<CopyArgs...> const& copy_atom,
-     Tensor<SrcEngine, SrcLayout> const& src,
-     Tensor<DstEngine, DstLayout>      & dst)
+     Tensor<SrcEngine, SrcLayout> const& src,  // (V,Rest...)
+     Tensor<DstEngine, DstLayout>      & dst)  // (V,Rest...)
 {
-  return copy_if(copy_atom, TrivialPredTensor{}, src, dst);
+  static_assert(SrcLayout::rank == DstLayout::rank, "CopyAtom rank-mismatch.");
+
+  if constexpr (SrcLayout::rank == 1) {   // Dispatch the copy
+    copy_atom.call(src, dst);
+  } else {                                // Loop over all but the first mode
+    constexpr int R = SrcLayout::rank;
+    Tensor src_v = group_modes<1,R>(src);
+    Tensor dst_v = group_modes<1,R>(dst);
+
+    if constexpr (is_static<decltype(shape(src_v))>::value && is_static<decltype(shape(dst_v))>::value) {
+      CUTE_STATIC_ASSERT_V(size<1>(src_v) == size<1>(dst_v));
+
+      // AutoFilter on the Rest-mode
+      auto dst_null = nullspace(layout<1>(dst_v));
+
+      Tensor dst_n = zipped_divide(dst_v, make_tile(shape<0>(dst_v), dst_null));  // ((V, NLL), (_1, Rest))
+      Tensor src_n = zipped_divide(src_v, make_tile(shape<0>(src_v), dst_null));  // ((V, NLL), (_1, Rest))
+
+      CUTE_STATIC_ASSERT_V(size<1>(src_n) == size<1>(dst_n));
+      CUTE_STATIC_ASSERT_V((cosize<0,1>(dst_n.layout()) == Int<1>{}), "Nullspace definition error");
+      CUTE_STATIC_ASSERT_V((cosize<0,1>(src_n.layout()) == Int<1>{}), "Error: Ambiguous scatter detected in copy");
+      CUTE_STATIC_ASSERT_V((size<1,0>(dst_n) == Int<1>{}));
+      CUTE_STATIC_ASSERT_V((size<1,0>(src_n) == Int<1>{}));
+
+      Tensor dst_c = dst_n(make_coord(_,Int<0>{}),make_coord(Int<0>{},_));  // (V, Rest)
+      Tensor src_c = src_n(make_coord(_,Int<0>{}),make_coord(Int<0>{},_));  // (V, Rest)
+
+      CUTE_STATIC_ASSERT_V(size<1>(src_c) == size<1>(dst_c));
+      CUTE_STATIC_ASSERT_V(shape<0>(dst_c) == shape<0>(dst));
+      CUTE_STATIC_ASSERT_V(shape<0>(src_c) == shape<0>(src));
+
+      CUTE_UNROLL
+      for (int i = 0; i < size<1>(dst_c); ++i) {
+        copy_atom.call(src_c(_,i), dst_c(_,i));
+      }
+    } else {
+      CUTE_UNROLL
+      for (int i = 0; i < size<1>(dst_v); ++i) {
+        copy_atom.call(src_v(_,i), dst_v(_,i));
+      }
+    }
+  }
 }
 
-//////////////////////////////////////////
-// Special Auto-Vectorizing Overloads
-//////////////////////////////////////////
+////////////////////////////////////////////////////////
+// Special Auto-Vectorizing, Auto-Filtering Overloads //
+////////////////////////////////////////////////////////
 
 // Specialization for AutoVectorizingCopyAssumedAlignment<MaxVecBits>
 template <int MaxVecBits, class... Args,
@@ -258,30 +233,67 @@ copy(AutoVectorizingCopyWithAssumedAlignment<MaxVecBits> const&,
      Tensor<SrcEngine, SrcLayout> const& src,
      Tensor<DstEngine, DstLayout>      & dst)
 {
-  constexpr int vec_elem = decltype(max_common_vector(src, dst))::value;
+  constexpr int common_elem = CUTE_STATIC_V(max_common_vector(src, dst));
+  constexpr int align_bits  = CUTE_STATIC_V(gcd(max_alignment(src), max_alignment(dst), Int<MaxVecBits>{}));
+  static_assert(is_integral<decltype(Int<common_elem>{} * sizeof_bits_v<typename SrcEngine::value_type>)>::value, "Error: Attempting a subbit copy!");
+  constexpr int vec_bits    = gcd(common_elem * sizeof_bits_v<typename SrcEngine::value_type>, align_bits);
 
-  constexpr int max_align_src = decltype(max_alignment(src.layout()))::value;
-  constexpr int max_align_dst = decltype(max_alignment(dst.layout()))::value;
-  constexpr int max_align     = gcd(vec_elem, max_align_src, max_align_dst);
+  if constexpr (common_elem > 1 && ((vec_bits % 8) == 0)) {
+    // If more than one element vectorizes to 8bits or more, then recast and copy
+    using VecType = uint_bit_t<vec_bits>;
+    // Preserve volatility
+    using SrcVecType = conditional_t<is_volatile_v<typename SrcEngine::element_type>, VecType const volatile, VecType const>;
+    using DstVecType = conditional_t<is_volatile_v<typename DstEngine::element_type>, VecType volatile, VecType >;
 
-  constexpr int src_bits = sizeof_bits<typename SrcEngine::value_type>::value;
-  constexpr int vec_bits = gcd(src_bits * max_align, MaxVecBits);
+    // Recast
+    Tensor src_v = recast<SrcVecType>(src);
+    Tensor dst_v = recast<DstVecType>(dst);
 
-  if constexpr (vec_elem > 1 && vec_bits >= 8) {
-    // If more than one element vectorizes to 8bits or more, then copy_vec
 #if 0
     if (thread0()) {
-      print("copy -- found max_common_vector of %d elems and vectorization to %d bits\n", vec_elem, vec_bits);
-      print("   "); print(src); print("\n");
-      print("   "); print(dst); print("\n");
+      print("copy -- found max_common_vector of %d elems and vectorization to %d bits\n", common_elem, vec_bits);
+      print("   "); print(src); print(" => "); print(src_v); print("\n");
+      print("   "); print(dst); print(" => "); print(dst_v); print("\n");
     }
 #endif
-    return copy_vec<uint_bit_t<vec_bits>>(src, dst);
+
+    return copy_if(TrivialPredTensor{}, src_v, dst_v);
   } else {
     return copy_if(TrivialPredTensor{}, src, dst);
   }
 }
 
+template <class Base>
+struct AutoFilter {
+  Base const& base;
+  CUTE_HOST_DEVICE AutoFilter(Base const& b) : base(b) {}
+};
+
+// Specialization for AutoFilter
+template <class CopyOp,
+          class SrcEngine, class SrcLayout,
+          class DstEngine, class DstLayout>
+CUTE_HOST_DEVICE
+void
+copy(AutoFilter<CopyOp> const& copy_op,
+     Tensor<SrcEngine, SrcLayout> const& src,
+     Tensor<DstEngine, DstLayout>      & dst)
+{
+  if constexpr (is_constant<true, decltype(size(src) == size(dst))>::value) {
+    auto dst_null = nullspace(dst.layout());
+
+    Tensor dst_n = zipped_divide(dst, dst_null);
+    Tensor src_n = zipped_divide(src, dst_null);
+
+    CUTE_STATIC_ASSERT_V(cosize<0>(dst_n.layout()) == Int<1>{}, "Nullspace definition error");
+    CUTE_STATIC_ASSERT_V(cosize<0>(src_n.layout()) == Int<1>{}, "Error: Ambiguous scatter detected in copy");
+
+    copy(copy_op.base, src_n(Int<0>{},_), dst_n(Int<0>{},_));
+  } else {
+    copy(copy_op.base, src, dst);
+  }
+}
+
 // Auto-vectorizing copy for static layouts
 template <class SrcEngine, class SrcLayout,
           class DstEngine, class DstLayout>
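The rewritten vectorizer picks vec_bits = gcd(common_elem * sizeof_bits<T>, align_bits), where align_bits = gcd(max_alignment(src), max_alignment(dst), MaxVecBits), so the vector width can never exceed what both layouts' provable alignment allows. A worked example under assumed numbers, checked at compile time:

    // 8 contiguous fp16 elements offer 128 candidate bits, but a provable
    // alignment of only 32 bits caps the copy at 32-bit (2-element) vectors.
    constexpr int gcd2(int a, int b) { return b == 0 ? a : gcd2(b, a % b); }
    constexpr int common_elem = 8;    // assumed contiguous run
    constexpr int elem_bits   = 16;   // fp16
    constexpr int align_bits  = 32;   // assumed gcd of alignments and MaxVecBits
    static_assert(gcd2(common_elem * elem_bits, align_bits) == 32, "");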
@@ -292,7 +304,11 @@ copy(Tensor<SrcEngine, SrcLayout> const& src,
 {
   if constexpr (is_static<SrcLayout>::value && is_static<DstLayout>::value) {
     // Assume Tensors with static layouts (e.g. registers) have pointers that are 128b aligned
-    return copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, src, dst);
+    return copy(AutoFilter(AutoVectorizingCopyWithAssumedAlignment<128>{}), src, dst);
+  } else
+  if constexpr (is_static<decltype(shape(src))>::value && is_static<decltype(shape(dst))>::value) {
+    // Tensors with static shapes can be filtered, but do not assume that dynamic layouts are aligned.
+    return copy(AutoFilter(AutoVectorizingCopyWithAssumedAlignment<8>{}), src, dst);
   } else {
     // Do not assume that dynamic layouts are aligned.
     return copy(AutoVectorizingCopyWithAssumedAlignment<8>{}, src, dst);
@@ -307,7 +323,12 @@ void
 copy_aligned(Tensor<SrcEngine, SrcLayout> const& src,
              Tensor<DstEngine, DstLayout>      & dst)
 {
-  return copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, src, dst);
+  if constexpr (is_static<decltype(shape(src))>::value && is_static<decltype(shape(dst))>::value) {
+    // Tensors with static shapes can be filtered
+    return copy(AutoFilter(AutoVectorizingCopyWithAssumedAlignment<128>{}), src, dst);
+  } else {
+    return copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, src, dst);
+  }
 }
 
 // Specializaton for Atom AutoVectorizingCopyAssumedAlignment
@@ -379,4 +400,146 @@ copy(Copy_Atom<Copy_Traits<SM90_BULK_COPY_AUTO, CT_Args...>, CA_Args...> const&
 }
 #endif // #if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED)
 
+//
+// Decay TiledCopy to CopyAtom
+//
+
+template <class CopyAtom, class TV, class Tiler,
+          class PrdTensor,
+          class SrcEngine, class SrcLayout,
+          class DstEngine, class DstLayout>
+CUTE_HOST_DEVICE
+void
+copy_if(TiledCopy<CopyAtom, TV, Tiler> const& tiled_copy,
+        PrdTensor const& pred,
+        Tensor<SrcEngine, SrcLayout> const& src,
+        Tensor<DstEngine, DstLayout>      & dst)
+{
+  return copy_if(static_cast<CopyAtom const&>(tiled_copy), pred, src, dst);
+}
+
+template <class CopyAtom, class TV, class Tiler,
+          class SrcEngine, class SrcLayout,
+          class DstEngine, class DstLayout>
+CUTE_HOST_DEVICE
+void
+copy(TiledCopy<CopyAtom, TV, Tiler> const& tiled_copy,
+     Tensor<SrcEngine, SrcLayout> const& src,
+     Tensor<DstEngine, DstLayout>      & dst)
+{
+  return copy(static_cast<CopyAtom const&>(tiled_copy), src, dst);
+}
+
+template <class TiledCopy, class ThrIdx,
+          class PrdTensor,
+          class SrcEngine, class SrcLayout,
+          class DstEngine, class DstLayout>
+CUTE_HOST_DEVICE
+void
+copy_if(ThrCopy<TiledCopy, ThrIdx> const& thr_copy,
+        PrdTensor const& pred,
+        Tensor<SrcEngine, SrcLayout> const& src,
+        Tensor<DstEngine, DstLayout>      & dst) = delete;
+
+template <class TiledCopy, class ThrIdx,
+          class SrcEngine, class SrcLayout,
+          class DstEngine, class DstLayout>
+CUTE_HOST_DEVICE
+void
+copy(ThrCopy<TiledCopy, ThrIdx> const& thr_copy,
+     Tensor<SrcEngine, SrcLayout> const& src,
+     Tensor<DstEngine, DstLayout>      & dst) = delete;
+
+//
+// Catch uncaught policies
+//
+
+template <class CopyPolicy,
+          class PredTensor,
+          class SrcEngine, class SrcLayout,
+          class DstEngine, class DstLayout>
+CUTE_HOST_DEVICE
+void
+copy_if(CopyPolicy const& cpy,
+        PredTensor const& prd,
+        Tensor<SrcEngine, SrcLayout> const& src,
+        Tensor<DstEngine, DstLayout>      & dst)
+{
+  static_assert(dependent_false<CopyPolicy>, "Unrecognized CopyPolicy.");
+}
+
+template <class CopyPolicy,
+          class SrcEngine, class SrcLayout,
+          class DstEngine, class DstLayout>
+CUTE_HOST_DEVICE
+void
+copy(CopyPolicy const& cpy,
+     Tensor<SrcEngine, SrcLayout> const& src,
+     Tensor<DstEngine, DstLayout>      & dst)
+{
+  static_assert(dependent_false<CopyPolicy>, "Unrecognized CopyPolicy.");
+}
+
+//
+// Accept mutable temporaries
+//
+
+template <class PrdTensor,
+          class SrcEngine, class SrcLayout,
+          class DstEngine, class DstLayout>
+CUTE_HOST_DEVICE
+void
+copy_if(PrdTensor const& pred,
+        Tensor<SrcEngine, SrcLayout> const& src,
+        Tensor<DstEngine, DstLayout>     && dst)
+{
+  return copy_if(pred, src, dst);
+}
+
+template <class CopyPolicy,
+          class PrdTensor,
+          class SrcEngine, class SrcLayout,
+          class DstEngine, class DstLayout>
+CUTE_HOST_DEVICE
+void
+copy_if(CopyPolicy const& copy_policy,
+        PrdTensor  const& pred,
+        Tensor<SrcEngine, SrcLayout> const& src,
+        Tensor<DstEngine, DstLayout>     && dst)
+{
+  return copy_if(copy_policy, pred, src, dst);
+}
+
+template <class SrcEngine, class SrcLayout,
+          class DstEngine, class DstLayout>
+CUTE_HOST_DEVICE
+void
+copy(Tensor<SrcEngine, SrcLayout> const& src,
+     Tensor<DstEngine, DstLayout>     && dst)
+{
+  return copy(src, dst);
+}
+
+template <class CopyPolicy,
+          class SrcEngine, class SrcLayout,
+          class DstEngine, class DstLayout>
+CUTE_HOST_DEVICE
+void
+copy(CopyPolicy const& copy_policy,
+     Tensor<SrcEngine, SrcLayout> const& src,
+     Tensor<DstEngine, DstLayout>     && dst)
+{
+  return copy(copy_policy, src, dst);
+}
+
+template <class SrcEngine, class SrcLayout,
+          class DstEngine, class DstLayout>
+CUTE_HOST_DEVICE
+void
+copy_aligned(Tensor<SrcEngine, SrcLayout> const& src,
+             Tensor<DstEngine, DstLayout>     && dst)
+{
+  return copy_aligned(src, dst);
+}
+
 } // end namespace cute