v3.9 update (#2203)

* v3.9 update

* voidD

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
This commit is contained in:
Yujia Zhai
2025-04-02 12:11:18 -07:00
committed by GitHub
parent 62750a2b75
commit 6f4921858b
129 changed files with 7719 additions and 2036 deletions

View File

@@ -98,19 +98,23 @@ epilogue_predication(ThrMMA<Args...> const& thr_mma,
}
}
template<class Alpha, class TRC, class RCLayout,
template<class ... Args,
class Alpha, class TRC, class RCLayout,
class Beta, class TSC, class SCLayout,
class CLoadTransformOp, class CStoreTransformOp,
class SmemCopyOpC>
class SmemCopyLdOpC, class SmemCopyStOpC>
CUTE_HOST_DEVICE
void
epilogue_no_predication(Alpha const& alpha,
epilogue_no_predication(uint32_t thread_idx,
ThrMMA<Args...> const& thr_mma,
Alpha const& alpha,
Tensor<TRC, RCLayout> & tCrC,
Beta const& beta,
Tensor<TSC, SCLayout> & tCsC,
Tensor<TSC, SCLayout> & sC,
CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM
CStoreTransformOp const& sC_store_op, // transforms results before they are stored to C
SmemCopyOpC const& sC_copy_op)
SmemCopyLdOpC const& sC_copy_ld_op,
SmemCopyStOpC const& sC_copy_st_op)
{
using InputTypeC = typename TSC::value_type;
using ComputeTypeC = typename TRC::value_type;
@@ -125,10 +129,18 @@ epilogue_no_predication(Alpha const& alpha,
CUTE_GCC_UNREACHABLE;
} ();
Tensor tCrDi = make_fragment_like(tCsC);
Tensor tCrD = make_fragment_like(tCrC);
Tensor tCrDi = make_fragment_like<InputTypeC>(tCrD);
if(!isBetaZero) {
copy(sC_copy_op, tCsC, tCrDi);
auto smem_tiled_copy_C = make_tiled_copy_C(Copy_Atom<SmemCopyLdOpC, InputTypeC>{}, thr_mma);
auto smem_thr_copy_C = smem_tiled_copy_C.get_thread_slice(thread_idx);
Tensor tCsC = smem_thr_copy_C.partition_S(sC);
Tensor tCrDi_copy_view = smem_thr_copy_C.retile_D(tCrDi);
CUTE_STATIC_ASSERT_V(size<1>(tCsC) == size<1>(tCrDi_copy_view)); // CPY_M
CUTE_STATIC_ASSERT_V(size<2>(tCsC) == size<2>(tCrDi_copy_view)); // CPY_N
copy(smem_tiled_copy_C, tCsC, tCrDi_copy_view);
// Transform C on/after load
cute::transform(tCrDi, tCrD, sC_load_op);
}
@@ -136,7 +148,14 @@ epilogue_no_predication(Alpha const& alpha,
axpby(alpha, tCrC, beta, tCrD);
// Transform C before/on store
cute::transform(tCrD, tCrDi, sC_store_op);
copy(sC_copy_op, tCrDi, tCsC);
auto smem_tiled_copy_C = make_tiled_copy_C(Copy_Atom<SmemCopyStOpC, InputTypeC>{}, thr_mma);
auto smem_thr_copy_C = smem_tiled_copy_C.get_thread_slice(thread_idx);
Tensor tCsC = smem_thr_copy_C.partition_D(sC);
Tensor tCrDi_copy_view = smem_thr_copy_C.retile_S(tCrDi);
CUTE_STATIC_ASSERT_V(size<1>(tCsC) == size<1>(tCrDi_copy_view)); // CPY_M
CUTE_STATIC_ASSERT_V(size<2>(tCsC) == size<2>(tCrDi_copy_view)); // CPY_N
copy(smem_tiled_copy_C, tCrDi_copy_view, tCsC);
}
// Predicated Cooperative GEMM
@@ -283,7 +302,9 @@ cooperative_gemm_no_predication(uint32_t thread_idx,
// Create register tensors for the MMA to operate on
Tensor tCrA = thr_mma.partition_fragment_A(sA); // (MMA,MMA_M,MMA_K)
Tensor tCrAi = make_fragment_like<InputTypeA>(tCrA);
Tensor tCrB = thr_mma.partition_fragment_B(sB); // (MMA,MMA_N,MMA_K)
Tensor tCrBi = make_fragment_like<InputTypeB>(tCrB);
using CopyOpAType = SmemCopyOpA;
using CopyOpBType = SmemCopyOpB;
@@ -291,7 +312,6 @@ cooperative_gemm_no_predication(uint32_t thread_idx,
auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom<CopyOpAType, InputTypeA>{}, thr_mma);
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx);
Tensor tCsA = smem_thr_copy_A.partition_S(sA);
Tensor tCrAi = make_fragment_like(tCsA);
Tensor tCrAi_copy_view = smem_thr_copy_A.retile_D(tCrAi);
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrAi_copy_view)); // CPY_M
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrAi_copy_view)); // CPY_K
@@ -299,7 +319,6 @@ cooperative_gemm_no_predication(uint32_t thread_idx,
auto smem_tiled_copy_B = make_tiled_copy_B(Copy_Atom<CopyOpBType, InputTypeB>{}, thr_mma);
auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx);
Tensor tCsB = smem_thr_copy_B.partition_S(sB);
Tensor tCrBi = make_fragment_like(tCsB);
Tensor tCrBi_copy_view = smem_thr_copy_B.retile_D(tCrBi);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrBi_copy_view)); // CPY_N
CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrBi_copy_view)); // CPY_K
@@ -346,7 +365,7 @@ template <class... Args,
class ALoadTransformOp = cute::identity, class BLoadTransformOp = cute::identity,
class CLoadTransformOp = cute::identity, class CStoreTransformOp = cute::identity,
class SmemCopyOpA = DefaultCopy, class SmemCopyOpB = DefaultCopy,
class SmemCopyOpC = DefaultCopy>
class SmemCopyLdOpC = DefaultCopy, class SmemCopyStOpC = DefaultCopy>
CUTE_HOST_DEVICE
void
cooperative_gemm(uint32_t thread_idx,
@@ -356,13 +375,14 @@ cooperative_gemm(uint32_t thread_idx,
Tensor<TB, BLayout> const& sB,
Beta const& beta,
Tensor<TC, CLayout> & sC,
ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM
BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM
CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM
CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C
SmemCopyOpA const& sA_copy_op = {},
SmemCopyOpB const& sB_copy_op = {},
SmemCopyOpC const& sC_copy_op = {})
ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM
BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM
CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM
CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C
SmemCopyOpA const& sA_copy_op = {},
SmemCopyOpB const& sB_copy_op = {},
SmemCopyLdOpC const& sC_copy_ld_op = {},
SmemCopyStOpC const& sC_copy_st_op = {})
{
CUTE_STATIC_ASSERT_V(rank(sA) == Int<2>{});
CUTE_STATIC_ASSERT_V(rank(sB) == Int<2>{});
@@ -394,7 +414,7 @@ cooperative_gemm(uint32_t thread_idx,
thread_idx, thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op, sA_copy_op, sB_copy_op
);
detail::epilogue_no_predication(
alpha, tCrC, beta, tCsC, sC_load_op, sC_store_op, sC_copy_op
thread_idx, thr_mma,alpha, tCrC, beta, sC, sC_load_op, sC_store_op, sC_copy_ld_op, sC_copy_st_op
);
} else {
detail::cooperative_gemm_predication(
@@ -466,7 +486,7 @@ template <class... Args,
class ALoadTransformOp = cute::identity, class BLoadTransformOp = cute::identity,
class CLoadTransformOp = cute::identity, class CStoreTransformOp = cute::identity,
class SmemCopyOpA = DefaultCopy, class SmemCopyOpB = DefaultCopy,
class SmemCopyOpC = DefaultCopy>
class SmemCopyLdOpC = DefaultCopy, class SmemCopyStOpC = DefaultCopy>
CUTE_HOST_DEVICE
void
cooperative_gemm(uint32_t thread_idx,
@@ -476,17 +496,18 @@ cooperative_gemm(uint32_t thread_idx,
Tensor<TB, BLayout> const& sB,
Beta const& beta,
Tensor<TC, CLayout> && sC,
ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM
BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM
CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM
CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C
SmemCopyOpA const& sA_copy_op = {},
SmemCopyOpB const& sB_copy_op = {},
SmemCopyOpC const& sC_copy_op = {})
ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM
BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM
CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM
CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C
SmemCopyOpA const& sA_copy_op = {},
SmemCopyOpB const& sB_copy_op = {},
SmemCopyLdOpC const& sC_copy_ld_op = {},
SmemCopyStOpC const& sC_copy_st_op = {})
{
cooperative_gemm(thread_idx, tiled_mma, alpha, sA, sB, beta, sC,
sA_load_op, sB_load_op, sC_load_op, sC_store_op,
sA_copy_op, sB_copy_op, sC_copy_op);
sA_copy_op, sB_copy_op, sC_copy_ld_op, sC_copy_st_op);
}
// Legacy overload of cute::gemm for backwards-compatibility