mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-19 22:38:56 +00:00
v3.9 update (#2203)
* v3.9 update * voidD --------- Co-authored-by: yuzhai <yuzhai@nvidia.com>
This commit is contained in:
@@ -98,19 +98,23 @@ epilogue_predication(ThrMMA<Args...> const& thr_mma,
|
||||
}
|
||||
}
|
||||
|
||||
template<class Alpha, class TRC, class RCLayout,
|
||||
template<class ... Args,
|
||||
class Alpha, class TRC, class RCLayout,
|
||||
class Beta, class TSC, class SCLayout,
|
||||
class CLoadTransformOp, class CStoreTransformOp,
|
||||
class SmemCopyOpC>
|
||||
class SmemCopyLdOpC, class SmemCopyStOpC>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
epilogue_no_predication(Alpha const& alpha,
|
||||
epilogue_no_predication(uint32_t thread_idx,
|
||||
ThrMMA<Args...> const& thr_mma,
|
||||
Alpha const& alpha,
|
||||
Tensor<TRC, RCLayout> & tCrC,
|
||||
Beta const& beta,
|
||||
Tensor<TSC, SCLayout> & tCsC,
|
||||
Tensor<TSC, SCLayout> & sC,
|
||||
CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM
|
||||
CStoreTransformOp const& sC_store_op, // transforms results before they are stored to C
|
||||
SmemCopyOpC const& sC_copy_op)
|
||||
SmemCopyLdOpC const& sC_copy_ld_op,
|
||||
SmemCopyStOpC const& sC_copy_st_op)
|
||||
{
|
||||
using InputTypeC = typename TSC::value_type;
|
||||
using ComputeTypeC = typename TRC::value_type;
|
||||
@@ -125,10 +129,18 @@ epilogue_no_predication(Alpha const& alpha,
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
} ();
|
||||
|
||||
Tensor tCrDi = make_fragment_like(tCsC);
|
||||
Tensor tCrD = make_fragment_like(tCrC);
|
||||
Tensor tCrDi = make_fragment_like<InputTypeC>(tCrD);
|
||||
|
||||
if(!isBetaZero) {
|
||||
copy(sC_copy_op, tCsC, tCrDi);
|
||||
auto smem_tiled_copy_C = make_tiled_copy_C(Copy_Atom<SmemCopyLdOpC, InputTypeC>{}, thr_mma);
|
||||
auto smem_thr_copy_C = smem_tiled_copy_C.get_thread_slice(thread_idx);
|
||||
Tensor tCsC = smem_thr_copy_C.partition_S(sC);
|
||||
Tensor tCrDi_copy_view = smem_thr_copy_C.retile_D(tCrDi);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsC) == size<1>(tCrDi_copy_view)); // CPY_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsC) == size<2>(tCrDi_copy_view)); // CPY_N
|
||||
copy(smem_tiled_copy_C, tCsC, tCrDi_copy_view);
|
||||
|
||||
// Transform C on/after load
|
||||
cute::transform(tCrDi, tCrD, sC_load_op);
|
||||
}
|
||||
@@ -136,7 +148,14 @@ epilogue_no_predication(Alpha const& alpha,
|
||||
axpby(alpha, tCrC, beta, tCrD);
|
||||
// Transform C before/on store
|
||||
cute::transform(tCrD, tCrDi, sC_store_op);
|
||||
copy(sC_copy_op, tCrDi, tCsC);
|
||||
|
||||
auto smem_tiled_copy_C = make_tiled_copy_C(Copy_Atom<SmemCopyStOpC, InputTypeC>{}, thr_mma);
|
||||
auto smem_thr_copy_C = smem_tiled_copy_C.get_thread_slice(thread_idx);
|
||||
Tensor tCsC = smem_thr_copy_C.partition_D(sC);
|
||||
Tensor tCrDi_copy_view = smem_thr_copy_C.retile_S(tCrDi);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsC) == size<1>(tCrDi_copy_view)); // CPY_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsC) == size<2>(tCrDi_copy_view)); // CPY_N
|
||||
copy(smem_tiled_copy_C, tCrDi_copy_view, tCsC);
|
||||
}
|
||||
|
||||
// Predicated Cooperative GEMM
|
||||
@@ -283,7 +302,9 @@ cooperative_gemm_no_predication(uint32_t thread_idx,
|
||||
|
||||
// Create register tensors for the MMA to operate on
|
||||
Tensor tCrA = thr_mma.partition_fragment_A(sA); // (MMA,MMA_M,MMA_K)
|
||||
Tensor tCrAi = make_fragment_like<InputTypeA>(tCrA);
|
||||
Tensor tCrB = thr_mma.partition_fragment_B(sB); // (MMA,MMA_N,MMA_K)
|
||||
Tensor tCrBi = make_fragment_like<InputTypeB>(tCrB);
|
||||
|
||||
using CopyOpAType = SmemCopyOpA;
|
||||
using CopyOpBType = SmemCopyOpB;
|
||||
@@ -291,7 +312,6 @@ cooperative_gemm_no_predication(uint32_t thread_idx,
|
||||
auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom<CopyOpAType, InputTypeA>{}, thr_mma);
|
||||
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx);
|
||||
Tensor tCsA = smem_thr_copy_A.partition_S(sA);
|
||||
Tensor tCrAi = make_fragment_like(tCsA);
|
||||
Tensor tCrAi_copy_view = smem_thr_copy_A.retile_D(tCrAi);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrAi_copy_view)); // CPY_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrAi_copy_view)); // CPY_K
|
||||
@@ -299,7 +319,6 @@ cooperative_gemm_no_predication(uint32_t thread_idx,
|
||||
auto smem_tiled_copy_B = make_tiled_copy_B(Copy_Atom<CopyOpBType, InputTypeB>{}, thr_mma);
|
||||
auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx);
|
||||
Tensor tCsB = smem_thr_copy_B.partition_S(sB);
|
||||
Tensor tCrBi = make_fragment_like(tCsB);
|
||||
Tensor tCrBi_copy_view = smem_thr_copy_B.retile_D(tCrBi);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrBi_copy_view)); // CPY_N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrBi_copy_view)); // CPY_K
|
||||
@@ -346,7 +365,7 @@ template <class... Args,
|
||||
class ALoadTransformOp = cute::identity, class BLoadTransformOp = cute::identity,
|
||||
class CLoadTransformOp = cute::identity, class CStoreTransformOp = cute::identity,
|
||||
class SmemCopyOpA = DefaultCopy, class SmemCopyOpB = DefaultCopy,
|
||||
class SmemCopyOpC = DefaultCopy>
|
||||
class SmemCopyLdOpC = DefaultCopy, class SmemCopyStOpC = DefaultCopy>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
cooperative_gemm(uint32_t thread_idx,
|
||||
@@ -356,13 +375,14 @@ cooperative_gemm(uint32_t thread_idx,
|
||||
Tensor<TB, BLayout> const& sB,
|
||||
Beta const& beta,
|
||||
Tensor<TC, CLayout> & sC,
|
||||
ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM
|
||||
BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM
|
||||
CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM
|
||||
CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C
|
||||
SmemCopyOpA const& sA_copy_op = {},
|
||||
SmemCopyOpB const& sB_copy_op = {},
|
||||
SmemCopyOpC const& sC_copy_op = {})
|
||||
ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM
|
||||
BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM
|
||||
CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM
|
||||
CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C
|
||||
SmemCopyOpA const& sA_copy_op = {},
|
||||
SmemCopyOpB const& sB_copy_op = {},
|
||||
SmemCopyLdOpC const& sC_copy_ld_op = {},
|
||||
SmemCopyStOpC const& sC_copy_st_op = {})
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(rank(sA) == Int<2>{});
|
||||
CUTE_STATIC_ASSERT_V(rank(sB) == Int<2>{});
|
||||
@@ -394,7 +414,7 @@ cooperative_gemm(uint32_t thread_idx,
|
||||
thread_idx, thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op, sA_copy_op, sB_copy_op
|
||||
);
|
||||
detail::epilogue_no_predication(
|
||||
alpha, tCrC, beta, tCsC, sC_load_op, sC_store_op, sC_copy_op
|
||||
thread_idx, thr_mma,alpha, tCrC, beta, sC, sC_load_op, sC_store_op, sC_copy_ld_op, sC_copy_st_op
|
||||
);
|
||||
} else {
|
||||
detail::cooperative_gemm_predication(
|
||||
@@ -466,7 +486,7 @@ template <class... Args,
|
||||
class ALoadTransformOp = cute::identity, class BLoadTransformOp = cute::identity,
|
||||
class CLoadTransformOp = cute::identity, class CStoreTransformOp = cute::identity,
|
||||
class SmemCopyOpA = DefaultCopy, class SmemCopyOpB = DefaultCopy,
|
||||
class SmemCopyOpC = DefaultCopy>
|
||||
class SmemCopyLdOpC = DefaultCopy, class SmemCopyStOpC = DefaultCopy>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
cooperative_gemm(uint32_t thread_idx,
|
||||
@@ -476,17 +496,18 @@ cooperative_gemm(uint32_t thread_idx,
|
||||
Tensor<TB, BLayout> const& sB,
|
||||
Beta const& beta,
|
||||
Tensor<TC, CLayout> && sC,
|
||||
ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM
|
||||
BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM
|
||||
CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM
|
||||
CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C
|
||||
SmemCopyOpA const& sA_copy_op = {},
|
||||
SmemCopyOpB const& sB_copy_op = {},
|
||||
SmemCopyOpC const& sC_copy_op = {})
|
||||
ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM
|
||||
BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM
|
||||
CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM
|
||||
CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C
|
||||
SmemCopyOpA const& sA_copy_op = {},
|
||||
SmemCopyOpB const& sB_copy_op = {},
|
||||
SmemCopyLdOpC const& sC_copy_ld_op = {},
|
||||
SmemCopyStOpC const& sC_copy_st_op = {})
|
||||
{
|
||||
cooperative_gemm(thread_idx, tiled_mma, alpha, sA, sB, beta, sC,
|
||||
sA_load_op, sB_load_op, sC_load_op, sC_store_op,
|
||||
sA_copy_op, sB_copy_op, sC_copy_op);
|
||||
sA_copy_op, sB_copy_op, sC_copy_ld_op, sC_copy_st_op);
|
||||
}
|
||||
|
||||
// Legacy overload of cute::gemm for backwards-compatibility
|
||||
|
||||
Reference in New Issue
Block a user