mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-19 22:38:56 +00:00
Updates for CUTLASS 3.5.0 (#1468)
This commit is contained in:
@@ -71,85 +71,103 @@ cooperative_copy(uint32_t const& tid,
|
||||
|
||||
// Precondition on tid in DEBUG
|
||||
assert(tid < NumThreads);
|
||||
// Precondition on pointer alignment in DEBUG
|
||||
assert(is_byte_aligned<max(MaxVecBits/8, 1u)>(raw_pointer_cast(src.data())));
|
||||
assert(is_byte_aligned<max(MaxVecBits/8, 1u)>(raw_pointer_cast(dst.data())));
|
||||
//
|
||||
// Determine val+thr vectorization based on src/dst size and number of threads
|
||||
// NOTE: This heuristic promotes parallelization over vectorization
|
||||
//
|
||||
|
||||
constexpr int elem_bits = sizeof_bits_v<typename SrcEngine::value_type>;
|
||||
// Fallback - slow path, naive copy, vectorization disabled
|
||||
if constexpr(size(SrcLayout{}) % NumThreads != 0) {
|
||||
int index = static_cast<int>(tid);
|
||||
CUTE_UNROLL
|
||||
for(int i = 0; i < ceil_div(size(SrcLayout{}), NumThreads); i++) {
|
||||
if(index < size(SrcLayout{})) {
|
||||
dst[index] = src[index];
|
||||
}
|
||||
index += NumThreads;
|
||||
}
|
||||
} else {
|
||||
// Fast path with vectorization
|
||||
|
||||
// The number of elements that can be vectorized in values
|
||||
constexpr int common_elem = decltype(max_common_vector(src, dst))::value;
|
||||
constexpr int common_bits = common_elem * elem_bits;
|
||||
constexpr int total_elem = decltype(size(src))::value;
|
||||
constexpr int total_bits = total_elem * elem_bits;
|
||||
static_assert(total_bits % NumThreads == 0);
|
||||
constexpr int total_bits_per_thr = total_bits / NumThreads;
|
||||
// If there are too many threads to allow a full elem copy, trunc the thrs and use elem_bits
|
||||
constexpr int max_vec_bits_by_thr = cute::max(elem_bits, total_bits_per_thr);
|
||||
// Precondition on pointer alignment in DEBUG
|
||||
assert(is_byte_aligned<max(MaxVecBits/8, 1u)>(raw_pointer_cast(src.data())));
|
||||
assert(is_byte_aligned<max(MaxVecBits/8, 1u)>(raw_pointer_cast(dst.data())));
|
||||
constexpr int elem_bits = sizeof_bits_v<typename SrcEngine::value_type>;
|
||||
|
||||
// Cap the vectorization to the common bits, the max_vec_bits_by_thr, and the MaxVecBits
|
||||
constexpr int vec_bits = cute::min(common_bits, max_vec_bits_by_thr, static_cast<int>(MaxVecBits));
|
||||
// Convert back to number of elements, safe_div
|
||||
static_assert((vec_bits % elem_bits) == 0);
|
||||
constexpr int vec_elem = vec_bits / elem_bits;
|
||||
//
|
||||
// Determine val+thr vectorization based on src/dst size and number of threads
|
||||
// NOTE: This heuristic promotes parallelization over vectorization
|
||||
//
|
||||
|
||||
// Use only part of threads if there's not enough work for all threads
|
||||
constexpr int vec_thrs = (total_elem % (vec_elem * NumThreads) == 0)
|
||||
? NumThreads
|
||||
: (total_elem / vec_elem);
|
||||
// The number of elements that can be vectorized in values
|
||||
constexpr int common_elem = decltype(max_common_vector(src, dst))::value;
|
||||
constexpr int common_bits = common_elem * elem_bits;
|
||||
constexpr int total_elem = decltype(size(src))::value;
|
||||
constexpr int total_bits = total_elem * elem_bits;
|
||||
static_assert(total_bits % NumThreads == 0);
|
||||
constexpr int total_bits_per_thr = total_bits / NumThreads;
|
||||
// If there are too many threads to allow a full elem copy, trunc the thrs and use elem_bits
|
||||
constexpr int max_vec_bits_by_thr = cute::max(elem_bits, total_bits_per_thr);
|
||||
|
||||
// The common layout of the two tensors that can be vectorized over threads
|
||||
// vidx -> coord
|
||||
auto common_layout = max_common_layout(get_nonswizzle_portion(src.layout()),
|
||||
get_nonswizzle_portion(dst.layout()));
|
||||
// Cap the vectorization to the common bits, the max_vec_bits_by_thr, and the MaxVecBits
|
||||
constexpr int vec_bits = cute::min(common_bits, max_vec_bits_by_thr, static_cast<int>(MaxVecBits));
|
||||
// Convert back to number of elements, safe_div
|
||||
static_assert((vec_bits % elem_bits) == 0);
|
||||
constexpr int vec_elem = vec_bits / elem_bits;
|
||||
|
||||
// Scale up the common_layout to cover the entire tensors
|
||||
// vidx -> coord
|
||||
auto full_perm = tile_to_shape(make_layout(common_layout), size(src));
|
||||
// Use only part of threads if there's not enough work for all threads
|
||||
constexpr int vec_thrs = (total_elem % (vec_elem * NumThreads) == 0)
|
||||
? NumThreads
|
||||
: (total_elem / vec_elem);
|
||||
static_assert(vec_thrs <= NumThreads);
|
||||
|
||||
// Create the Tiler
|
||||
// ((vid,tid),iter)
|
||||
auto layout_vt = logical_divide(full_perm, Layout<Shape<Int<vec_elem>, Int<vec_thrs>>>{});
|
||||
// The common layout of the two tensors that can be vectorized over threads
|
||||
// vidx -> coord
|
||||
auto common_layout = max_common_layout(get_nonswizzle_portion(src.layout()),
|
||||
get_nonswizzle_portion(dst.layout()));
|
||||
|
||||
// Apply and slice
|
||||
Tensor src_v = src.compose(layout_vt)(make_coord(_,tid),_);
|
||||
Tensor dst_v = dst.compose(layout_vt)(make_coord(_,tid),_);
|
||||
// Scale up the common_layout to cover the entire tensors
|
||||
// vidx -> coord
|
||||
auto full_perm = tile_to_shape(make_layout(common_layout), size(src));
|
||||
|
||||
// Should account for vec_bits < 8 and/or vec_elem <= 1
|
||||
// And also account for subbyte types, which could cause race conditions
|
||||
// Want to ENFORCE sufficient vectorization in those cases
|
||||
static_assert((vec_bits >= 8), "No support for subbyte copying");
|
||||
using VecType = uint_bit_t<vec_bits>;
|
||||
// Create the Tiler
|
||||
// ((vid,tid),iter)
|
||||
auto layout_vt = logical_divide(full_perm, Layout<Shape<Int<vec_elem>, Int<vec_thrs>>>{});
|
||||
|
||||
// Apply and slice
|
||||
Tensor src_v = src.compose(layout_vt)(make_coord(_,tid),_);
|
||||
Tensor dst_v = dst.compose(layout_vt)(make_coord(_,tid),_);
|
||||
|
||||
// Should account for vec_bits < 8 and/or vec_elem <= 1
|
||||
// And also account for subbyte types, which could cause race conditions
|
||||
// Want to ENFORCE sufficient vectorization in those cases
|
||||
static_assert((vec_bits >= 8), "No support for subbyte copying");
|
||||
using VecType = uint_bit_t<vec_bits>;
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print(" "); print("NumThreads: "); print(NumThreads); print("\n");
|
||||
print(" "); print("src: "); print(src); print("\n");
|
||||
print(" "); print("dst: "); print(dst); print("\n");
|
||||
print(" "); print("common_layout: "); print(common_layout); print("\n");
|
||||
print(" "); print("full_perm: "); print(full_perm); print("\n");
|
||||
print(" "); print("Used vector: "); print(vec_elem); print("\n");
|
||||
print(" "); print("Used threads: "); print(vec_thrs); print("\n");
|
||||
print(" "); print("layout_vt: "); print(layout_vt); print("\n");
|
||||
print(" "); print("src.compose(layout_vt): "); print(src.compose(layout_vt)); print("\n");
|
||||
print(" "); print("dst.compose(layout_vt): "); print(dst.compose(layout_vt)); print("\n");
|
||||
print(" "); print("src_v: "); print(src_v); print("\n");
|
||||
print(" "); print("dst_v: "); print(dst_v); print("\n");
|
||||
print(" "); print("recast<VecType const>(src_v): "); print(recast<VecType const>(src_v)); print("\n");
|
||||
print(" "); print("recast<VecType const>(dst_v): "); print(recast<VecType const>(dst_v)); print("\n");
|
||||
}
|
||||
if (thread0()) {
|
||||
print(" "); print("cooperative_copy -- vec\n");
|
||||
print(" "); print("NumThreads: "); print(NumThreads); print("\n");
|
||||
print(" "); print("MaxVecBits: "); print(MaxVecBits); print("\n");
|
||||
print(" "); print("src: "); print(src); print("\n");
|
||||
print(" "); print("dst: "); print(dst); print("\n");
|
||||
print(" "); print("common_layout: "); print(common_layout); print("\n");
|
||||
print(" "); print("full_perm: "); print(full_perm); print("\n");
|
||||
print(" "); print("Used vector: "); print(vec_elem); print("\n");
|
||||
print(" "); print("Used threads: "); print(vec_thrs); print("\n");
|
||||
print(" "); print("layout_vt: "); print(layout_vt); print("\n");
|
||||
print(" "); print("src.compose(layout_vt): "); print(src.compose(layout_vt)); print("\n");
|
||||
print(" "); print("dst.compose(layout_vt): "); print(dst.compose(layout_vt)); print("\n");
|
||||
print(" "); print("src_v: "); print(src_v); print("\n");
|
||||
print(" "); print("dst_v: "); print(dst_v); print("\n");
|
||||
print(" "); print("recast<VecType const>(src_v): "); print(recast<VecType const>(src_v)); print("\n");
|
||||
print(" "); print("recast<VecType const>(dst_v): "); print(recast<VecType const>(dst_v)); print("\n");
|
||||
}
|
||||
#ifdef __CUDA_ARCH__
|
||||
__syncthreads();
|
||||
__syncthreads();
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// If we're using all threads (static) or the tid is in in-range (dynamic)
|
||||
if (vec_thrs >= NumThreads or tid < vec_thrs) {
|
||||
return copy_if(TrivialPredTensor{}, recast<VecType const>(src_v), recast<VecType>(dst_v));
|
||||
// If we're using all threads (static) or the tid is in in-range (dynamic)
|
||||
if (vec_thrs >= NumThreads or tid < vec_thrs) {
|
||||
return copy_if(TrivialPredTensor{}, recast<VecType const>(src_v), recast<VecType>(dst_v));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@
|
||||
|
||||
#include <cute/atom/mma_atom.hpp>
|
||||
|
||||
#include <cute/algorithm/axpby.hpp>
|
||||
#include <cute/algorithm/functional.hpp>
|
||||
#include <cute/algorithm/gemm.hpp>
|
||||
|
||||
@@ -44,40 +45,37 @@ namespace cute
|
||||
{
|
||||
|
||||
//
|
||||
// Collective Shared-Memory GEMMs
|
||||
// Cooperative Shared-Memory GEMMs
|
||||
//
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Predicated Cooperative GEMM
|
||||
template <class... Args,
|
||||
class Alpha, class TA, class ALayout, class TB, class BLayout,
|
||||
class Beta, class TC, class CLayout,
|
||||
class ALoadTransformOp, class BLoadTransformOp,
|
||||
class CLoadTransformOp, class CStoreTransformOp,
|
||||
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
|
||||
BLayout::rank == 2 && is_smem<TB>::value &&
|
||||
CLayout::rank == 2 && is_smem<TC>::value)>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
cooperative_gemm(ThrMMA<Args...> const& thr_mma,
|
||||
Alpha const& alpha,
|
||||
Tensor<TA, ALayout> sA,
|
||||
Tensor<TB, BLayout> sB,
|
||||
Beta const& beta,
|
||||
Tensor<TC, CLayout> sC,
|
||||
ALoadTransformOp const& sA_load_op /* transforms A values before used in GEMM */,
|
||||
BLoadTransformOp const& sB_load_op /* transforms B values before used in GEMM */)
|
||||
cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
|
||||
Alpha const& alpha,
|
||||
Tensor<TA, ALayout> sA,
|
||||
Tensor<TB, BLayout> sB,
|
||||
Beta const& beta,
|
||||
Tensor<TC, CLayout> sC,
|
||||
ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM
|
||||
BLoadTransformOp const& sB_load_op, // transforms B values before use in GEMM
|
||||
CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM
|
||||
CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM
|
||||
CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN
|
||||
CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK
|
||||
|
||||
using TypeA = typename TA::value_type;
|
||||
using TypeB = typename TB::value_type;
|
||||
using TypeC = typename TC::value_type;
|
||||
|
||||
static_assert(is_same_v<decay_t<invoke_result_t<ALoadTransformOp, TypeA>>, TypeA>,
|
||||
"ALoadTransformOp functor must accept and return value of type TA::value_type");
|
||||
static_assert(is_same_v<decay_t<invoke_result_t<BLoadTransformOp, TypeB>>, TypeB>,
|
||||
"BLoadTransformOp functor must accept and return value of type TB::value_type");
|
||||
|
||||
// Original, static size of the problem
|
||||
auto M = size<0>(sC);
|
||||
auto N = size<1>(sC);
|
||||
@@ -88,39 +86,14 @@ cooperative_gemm(ThrMMA<Args...> const& thr_mma,
|
||||
auto BLK_N = tile_size<1>(thr_mma);
|
||||
auto BLK_K = tile_size<2>(thr_mma);
|
||||
|
||||
// Compute the "residues"
|
||||
auto m_residue = M - BLK_M * (ceil_div(M, BLK_M) - Int<1>{}); // (0,BLK_M]
|
||||
auto n_residue = N - BLK_N * (ceil_div(N, BLK_N) - Int<1>{}); // (0,BLK_N]
|
||||
auto k_residue = K - BLK_K * (ceil_div(K, BLK_K) ); // (-BLK_K,0]
|
||||
|
||||
// Shift the origin so k_residue is zeroth tile
|
||||
sA.data() = &sA(0,k_residue);
|
||||
sB.data() = &sB(0,k_residue);
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
printf("%d in BLK_M (%d)\n", int(m_residue), int(BLK_M));
|
||||
printf("%d in BLK_N (%d)\n", int(n_residue), int(BLK_N));
|
||||
printf("%d in BLK_K (%d)\n", int(k_residue), int(BLK_K));
|
||||
}
|
||||
#endif
|
||||
|
||||
//
|
||||
// MMA Partitioning
|
||||
//
|
||||
|
||||
// Round the layout extents up to BLK_X
|
||||
Tensor rounded_sA = sA.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(K, BLK_K) * BLK_K));
|
||||
Tensor rounded_sB = sB.compose(make_shape(ceil_div(N, BLK_N) * BLK_N, ceil_div(K, BLK_K) * BLK_K));
|
||||
Tensor rounded_sC = sC.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(N, BLK_N) * BLK_N));
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print("rounded_sA: "); print(rounded_sA); print("\n");
|
||||
print("rounded_sB: "); print(rounded_sB); print("\n");
|
||||
print("rounded_sC: "); print(rounded_sC); print("\n");
|
||||
}
|
||||
#endif
|
||||
// Round the layout extents up to BLK_X to satisfy MMA partitioning safety
|
||||
Tensor rounded_sA = sA.compose(make_shape(round_up(M, BLK_M), round_up(K, BLK_K)));
|
||||
Tensor rounded_sB = sB.compose(make_shape(round_up(N, BLK_N), round_up(K, BLK_K)));
|
||||
Tensor rounded_sC = sC.compose(make_shape(round_up(M, BLK_M), round_up(N, BLK_N)));
|
||||
|
||||
// Partition the sA and sB tiles across the threads for the MMA
|
||||
Tensor tCsA = thr_mma.partition_A(rounded_sA); // (MMA,MMA_M,MMA_K)
|
||||
@@ -133,6 +106,13 @@ cooperative_gemm(ThrMMA<Args...> const& thr_mma,
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print(" sA: "); print( sA); print("\n");
|
||||
print(" sB: "); print( sB); print("\n");
|
||||
print(" sC: "); print( sC); print("\n");
|
||||
print("r_sA: "); print(rounded_sA); print("\n");
|
||||
print("r_sB: "); print(rounded_sB); print("\n");
|
||||
print("r_sC: "); print(rounded_sC); print("\n");
|
||||
print(thr_mma);
|
||||
print("tCsA: "); print(tCsA); print("\n");
|
||||
print("tCsB: "); print(tCsB); print("\n");
|
||||
print("tCsC: "); print(tCsC); print("\n");
|
||||
@@ -146,58 +126,232 @@ cooperative_gemm(ThrMMA<Args...> const& thr_mma,
|
||||
// PREDICATION
|
||||
//
|
||||
|
||||
// Allocate the preds for only the MMA-mode of tCsA and tCsB
|
||||
Tensor tCpA = make_tensor<bool>(size<0>(tCsA));
|
||||
Tensor tCpB = make_tensor<bool>(size<0>(tCsB));
|
||||
|
||||
// Create coordinate tensors on a single compute block for predication
|
||||
Tensor cA = make_identity_tensor(make_shape(BLK_M, BLK_K)); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
Tensor cB = make_identity_tensor(make_shape(BLK_N, BLK_K)); // (BLK_M,BLK_K) -> (blk_n,blk_k)
|
||||
// Create coordinate tensors for the problem
|
||||
Tensor cA = make_identity_tensor(shape(rounded_sA)); // (M,K) -> (m,k)
|
||||
Tensor cB = make_identity_tensor(shape(rounded_sB)); // (N,K) -> (n,k)
|
||||
|
||||
// Repeat partitioning with thr_mma
|
||||
Tensor tCcA = thr_mma.partition_A(cA); // (MMA,1,1) -> (blk_m,blk_k)
|
||||
Tensor tCcB = thr_mma.partition_B(cB); // (MMA,1,1) -> (blk_n,blk_k)
|
||||
Tensor tCcA = thr_mma.partition_A(cA); // (MMA,MMA_M,MMA_K) -> (m,k)
|
||||
Tensor tCcB = thr_mma.partition_B(cB); // (MMA,MMA_N,MMA_K) -> (n,k)
|
||||
|
||||
// Populate the m and n predicates
|
||||
// Allocate the preds for MMA- and MMA_MN-modes
|
||||
Tensor tCpA = make_tensor<bool>(make_shape(size<0>(tCsA), size<1>(tCsA)));
|
||||
Tensor tCpB = make_tensor<bool>(make_shape(size<0>(tCsB), size<1>(tCsB)));
|
||||
|
||||
// Populate the predicates on M and N
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size(tCpA); ++i) {
|
||||
tCpA(i) = elem_less(get<0>(tCcA(i)), m_residue);
|
||||
tCpA(i) = elem_less(get<0>(tCcA(_,_,Int<0>{})(i)), shape<0>(sA));
|
||||
}
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size(tCpB); ++i) {
|
||||
tCpB(i) = elem_less(get<0>(tCcB(i)), n_residue);
|
||||
tCpB(i) = elem_less(get<0>(tCcB(_,_,Int<0>{})(i)), shape<0>(sB));
|
||||
}
|
||||
|
||||
#if 0
|
||||
printf("Thr %d: A(%d,%d):%d B(%d,%d):%d\n",
|
||||
threadIdx.x,
|
||||
int(get<0>(tCcA(0))), int(get<1>(tCcA(0))), int(tCpA(0)),
|
||||
int(get<0>(tCcB(0))), int(get<1>(tCcB(0))), int(tCpB(0)));
|
||||
if (thread0()) {
|
||||
print(" cA: "); print( cA); print("\n");
|
||||
print(" cB: "); print( cB); print("\n");
|
||||
print("tCcA: "); print(tCcA); print("\n");
|
||||
print("tCcB: "); print(tCcB); print("\n");
|
||||
print_tensor(tCpA);
|
||||
print_tensor(tCpB);
|
||||
}
|
||||
#endif
|
||||
|
||||
//
|
||||
// PREFETCH k_block = 0 (with k-predication)
|
||||
// PREFETCH k_block = 0
|
||||
// Condition the k-predication on (static) k_block == K_BLOCK_MAX-1, the last k_block
|
||||
// Assumes the MMA-tiling in K is trivial
|
||||
//
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size<0>(tCsA); ++i) { // Copy MMA_I
|
||||
if (k_residue == 0 || get<1>(tCcA(i)) >= -k_residue) { // k_block = 0, predicated on k
|
||||
CUTE_UNROLL
|
||||
for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M, predicated on m
|
||||
tCrA(i,m,0) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,0)) : TypeA{};
|
||||
}
|
||||
}
|
||||
}
|
||||
constexpr int K_BLOCK_MAX = size<2>(tCrA);
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I
|
||||
if (k_residue == 0 || get<1>(tCcB(i)) >= -k_residue) { // k_block = 0, predicated on k
|
||||
CUTE_UNROLL
|
||||
for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N, predicated on n
|
||||
tCrB(i,n,0) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,0)) : TypeB{};
|
||||
}
|
||||
for (int m = 0; m < size<1>(tCrA); ++m) { // Copy MMA_M
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size<0>(tCrA); ++i) { // Copy MMA_I
|
||||
tCrA(i,m,0) = (tCpA(i,m) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,0)), shape<1>(sA)))) ? sA_load_op(tCsA(i,m,0)) : TypeA{};
|
||||
}
|
||||
}
|
||||
CUTE_UNROLL
|
||||
for (int n = 0; n < size<1>(tCrB); ++n) { // Copy MMA_N
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size<0>(tCrB); ++i) { // Copy MMA_I
|
||||
tCrB(i,n,0) = (tCpB(i,n) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,0)), shape<1>(sB)))) ? sB_load_op(tCsB(i,n,0)) : TypeB{};
|
||||
}
|
||||
}
|
||||
//
|
||||
// MAINLOOP
|
||||
//
|
||||
|
||||
// Clear accumulators
|
||||
clear(tCrC);
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block)
|
||||
{
|
||||
if (k_block < K_BLOCK_MAX-1) // static-if not the last k_block
|
||||
{
|
||||
int k_next = k_block + 1; // Load k_next block
|
||||
|
||||
// Condition the k-predication on (static) k_block == K_BLOCK_MAX-1, the last k_block
|
||||
// Assumes the MMA-tiling in K is trivial
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int m = 0; m < size<1>(tCrA); ++m) { // Copy MMA_M
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size<0>(tCrA); ++i) { // Copy MMA_I
|
||||
tCrA(i,m,k_next) = (tCpA(i,m) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,k_next)), shape<1>(sA)))) ? sA_load_op(tCsA(i,m,k_next)) : TypeA{};
|
||||
}
|
||||
}
|
||||
CUTE_UNROLL
|
||||
for (int n = 0; n < size<1>(tCrB); ++n) { // Copy MMA_N
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size<0>(tCrB); ++i) { // Copy MMA_I
|
||||
tCrB(i,n,k_next) = (tCpB(i,n) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,k_next)), shape<1>(sB)))) ? sB_load_op(tCsB(i,n,k_next)) : TypeB{};
|
||||
}
|
||||
}
|
||||
}
|
||||
// GEMM on k_block in registers
|
||||
gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
|
||||
}
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
//
|
||||
|
||||
// Create coordinate tensors for the problem
|
||||
Tensor cC = make_identity_tensor(shape(rounded_sC)); // (M,N) -> (m,n)
|
||||
// Repeat partitioning with thr_mma
|
||||
Tensor tCcC = thr_mma.partition_C(cC); // (MMA,MMA_M,MMA_N) -> (m,n)
|
||||
|
||||
const bool isBetaZero = (beta == Beta{});
|
||||
|
||||
// Custom axpby_if for now
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size(tCrC); ++i)
|
||||
{
|
||||
if (elem_less(tCcC(i), shape(sC)))
|
||||
{
|
||||
tCsC(i) = sC_store_op(isBetaZero ? alpha * static_cast<TypeC>(tCrC(i))
|
||||
: alpha * static_cast<TypeC>(tCrC(i)) +
|
||||
beta * static_cast<TypeC>(sC_load_op(tCsC(i))));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Slow fallback path
|
||||
template <class... Args,
|
||||
class Alpha, class TA, class ALayout, class TB, class BLayout,
|
||||
class Beta, class TC, class CLayout,
|
||||
class ALoadTransformOp, class BLoadTransformOp,
|
||||
class CLoadTransformOp, class CStoreTransformOp,
|
||||
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
|
||||
BLayout::rank == 2 && is_smem<TB>::value &&
|
||||
CLayout::rank == 2 && is_smem<TC>::value)>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
cooperative_gemm_predication(uint32_t thread_idx,
|
||||
TiledMMA<Args...> const& tiled_mma,
|
||||
Alpha const& alpha,
|
||||
Tensor<TA, ALayout> sA,
|
||||
Tensor<TB, BLayout> sB,
|
||||
Beta const& beta,
|
||||
Tensor<TC, CLayout> sC,
|
||||
ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM
|
||||
BLoadTransformOp const& sB_load_op, // transforms B values before use in GEMM
|
||||
CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM
|
||||
CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C
|
||||
{
|
||||
// ThrMMA
|
||||
auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
|
||||
cooperative_gemm_predication(thr_mma, alpha, sA, sB, beta, sC, sA_load_op, sB_load_op, sC_load_op, sC_store_op);
|
||||
}
|
||||
|
||||
// Unpredicated Cooperative GEMM
|
||||
template <class SmemCopyOpA, class SmemCopyOpB, class SmemCopyOpC,
|
||||
class... Args,
|
||||
class Alpha, class TA, class ALayout, class TB, class BLayout,
|
||||
class Beta, class TC, class CLayout,
|
||||
class ALoadTransformOp, class BLoadTransformOp,
|
||||
class CLoadTransformOp, class CStoreTransformOp,
|
||||
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
|
||||
BLayout::rank == 2 && is_smem<TB>::value &&
|
||||
CLayout::rank == 2 && is_smem<TC>::value)>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
cooperative_gemm_no_predication(uint32_t thread_idx,
|
||||
TiledMMA<Args...> const& tiled_mma,
|
||||
Alpha const& alpha,
|
||||
Tensor<TA, ALayout> sA,
|
||||
Tensor<TB, BLayout> sB,
|
||||
Beta const& beta,
|
||||
Tensor<TC, CLayout> sC,
|
||||
ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM
|
||||
BLoadTransformOp const& sB_load_op, // transforms B values before use in GEMM
|
||||
CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM
|
||||
CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C
|
||||
{
|
||||
using TypeA = typename TA::value_type;
|
||||
using TypeB = typename TB::value_type;
|
||||
using TypeC = typename TC::value_type;
|
||||
|
||||
// ThrMMA
|
||||
auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
|
||||
|
||||
//
|
||||
// MMA Partitioning
|
||||
//
|
||||
|
||||
Tensor tCsC = thr_mma.partition_C(sC);
|
||||
// Create register tensors for the MMA to operate on
|
||||
Tensor tCrA = thr_mma.partition_fragment_A(sA); // (MMA,MMA_M,MMA_K)
|
||||
Tensor tCrB = thr_mma.partition_fragment_B(sB); // (MMA,MMA_N,MMA_K)
|
||||
Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N)
|
||||
|
||||
using CopyOpAType = SmemCopyOpA;
|
||||
using CopyOpBType = SmemCopyOpB;
|
||||
|
||||
auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom<CopyOpAType, TypeA>{}, thr_mma);
|
||||
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx);
|
||||
Tensor tCsA = smem_thr_copy_A.partition_S(sA);
|
||||
Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K
|
||||
|
||||
auto smem_tiled_copy_B = make_tiled_copy_B(Copy_Atom<CopyOpBType, TypeB>{}, thr_mma);
|
||||
auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx);
|
||||
Tensor tCsB = smem_thr_copy_B.partition_S(sB);
|
||||
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // CPY_N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrB_copy_view)); // CPY_K
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print(" sA: "); print(sA); print("\n");
|
||||
print(" sB: "); print(sB); print("\n");
|
||||
print(" sC: "); print(sC); print("\n");
|
||||
print(thr_mma); print("\n");
|
||||
print("tCsC: "); print(tCsC); print("\n");
|
||||
print("tCrA: "); print(tCrA); print("\n");
|
||||
print("tCrB: "); print(tCrB); print("\n");
|
||||
print("tCrC: "); print(tCrC); print("\n");
|
||||
print(smem_thr_copy_A); print("\n");
|
||||
print("tCsA: "); print(tCsA); print("\n");
|
||||
print("tCrA_copy_view: "); print(tCrA_copy_view); print("\n");
|
||||
print(smem_thr_copy_B); print("\n");
|
||||
print("tCsB: "); print(tCsB); print("\n");
|
||||
print("tCrB_copy_view: "); print(tCrB_copy_view); print("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
//
|
||||
// PREFETCH
|
||||
//
|
||||
|
||||
copy(smem_tiled_copy_A, tCsA(_,_,Int<0>{}), tCrA_copy_view(_,_,Int<0>{}));
|
||||
copy(smem_tiled_copy_B, tCsB(_,_,Int<0>{}), tCrB_copy_view(_,_,Int<0>{}));
|
||||
//
|
||||
// MAINLOOP
|
||||
//
|
||||
@@ -214,25 +368,15 @@ cooperative_gemm(ThrMMA<Args...> const& thr_mma,
|
||||
if (k_block < K_BLOCK_MAX-1)
|
||||
{
|
||||
// Load the next k_block
|
||||
int k_next = k_block + 1;
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size<0>(tCsA); ++i) { // Copy_if MMA_I predicated on m
|
||||
tCrA(i,m,k_next) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,k_next)) : TypeA{};
|
||||
}
|
||||
}
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I predicated on n
|
||||
tCrB(i,n,k_next) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,k_next)) : TypeB{};
|
||||
}
|
||||
}
|
||||
int k_next = k_block + 1; // statically unrolled
|
||||
copy(smem_tiled_copy_A, tCsA(_,_,k_next), tCrA_copy_view(_,_,k_next));
|
||||
copy(smem_tiled_copy_B, tCsB(_,_,k_next), tCrB_copy_view(_,_,k_next));
|
||||
}
|
||||
|
||||
// Transform A and B, relying on the compiler to remove in case of identity ops
|
||||
cute::transform(tCrA(_,_,k_block), sA_load_op);
|
||||
cute::transform(tCrB(_,_,k_block), sB_load_op);
|
||||
|
||||
// GEMM on k_block in registers
|
||||
gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
|
||||
}
|
||||
@@ -241,53 +385,124 @@ cooperative_gemm(ThrMMA<Args...> const& thr_mma,
|
||||
// Epilogue
|
||||
//
|
||||
|
||||
Tensor cC = make_identity_tensor(make_shape(BLK_M, BLK_N)); // (BLK_M,BLK_N) -> (blk_m,blk_n)
|
||||
Tensor tCcC = thr_mma.partition_C(cC); // (MMA, 1, 1) -> (blk_m,blk_n)
|
||||
|
||||
const bool isBetaZero = (beta == Beta{});
|
||||
|
||||
// Custom axpby_if for now
|
||||
CUTE_UNROLL
|
||||
for (int m = 0; m < size<1>(tCsC); ++m)
|
||||
{
|
||||
CUTE_UNROLL
|
||||
for (int n = 0; n < size<2>(tCsC); ++n)
|
||||
{
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size<0>(tCsC); ++i)
|
||||
{
|
||||
if ((m_residue == BLK_M || m < size<1>(tCrC)-1 || get<0>(tCcC(i)) < m_residue) &&
|
||||
(n_residue == BLK_N || n < size<2>(tCrC)-1 || get<1>(tCcC(i)) < n_residue))
|
||||
{
|
||||
tCsC(i,m,n) = isBetaZero ? alpha * static_cast<TypeC>(tCrC(i,m,n)) : alpha * static_cast<TypeC>(tCrC(i,m,n)) + beta * static_cast<TypeC>(tCsC(i,m,n));
|
||||
}
|
||||
}
|
||||
auto isBetaZero = [&] () {
|
||||
if constexpr (is_complex<Beta>::value) {
|
||||
return beta.real() == Int<0>{} && beta.imag() == Int<0>{};
|
||||
}
|
||||
else {
|
||||
return beta == Int<0>{};
|
||||
}
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
} ();
|
||||
|
||||
using CopyOpCType = SmemCopyOpC;
|
||||
Tensor tCrD = thr_mma.make_fragment_C(tCsC);
|
||||
if(!isBetaZero) {
|
||||
copy(CopyOpCType{}, tCsC, tCrD);
|
||||
// Transform C on/after load
|
||||
cute::transform(tCrD, sC_load_op);
|
||||
}
|
||||
// C = alpha * (A * B) + beta * C
|
||||
axpby(alpha, tCrC, beta, tCrD);
|
||||
// Transform C before/on store
|
||||
cute::transform(tCrD, sC_store_op);
|
||||
copy(CopyOpCType{}, tCrD, tCsC);
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
template <class SmemCopyOpA, class SmemCopyOpB, class SmemCopyOpC,
|
||||
class... Args,
|
||||
class Alpha, class TA, class ALayout, class TB, class BLayout,
|
||||
class Beta, class TC, class CLayout,
|
||||
class ALoadTransformOp = cute::identity, class BLoadTransformOp = cute::identity,
|
||||
class CLoadTransformOp = cute::identity, class CStoreTransformOp = cute::identity,
|
||||
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
|
||||
BLayout::rank == 2 && is_smem<TB>::value &&
|
||||
CLayout::rank == 2 && is_smem<TC>::value)>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
cooperative_gemm(uint32_t thread_idx,
|
||||
TiledMMA<Args...> const& tiled_mma,
|
||||
Alpha const& alpha,
|
||||
Tensor<TA, ALayout> sA,
|
||||
Tensor<TB, BLayout> sB,
|
||||
Beta const& beta,
|
||||
Tensor<TC, CLayout> sC,
|
||||
ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM
|
||||
BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM
|
||||
CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM
|
||||
CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM
|
||||
CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN
|
||||
CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK
|
||||
|
||||
using TypeA = typename TA::value_type;
|
||||
using TypeB = typename TB::value_type;
|
||||
using TypeC = typename TC::value_type;
|
||||
|
||||
static_assert(is_convertible_v<decay_t<invoke_result_t<ALoadTransformOp, TypeA>>, TypeA>,
|
||||
"ALoadTransformOp functor must accept value of type TA::value_type and return value convertible to type TA::value_type");
|
||||
static_assert(is_convertible_v<decay_t<invoke_result_t<BLoadTransformOp, TypeB>>, TypeB>,
|
||||
"BLoadTransformOp functor must accept value of type TB::value_type and return value convertible to type TB::value_type");
|
||||
static_assert(is_convertible_v<decay_t<invoke_result_t<CLoadTransformOp, TypeC>>, TypeC>,
|
||||
"CLoadTransformOp functor must accept value of type TC::value_type and return value convertible to type TC::value_type");
|
||||
static_assert(is_convertible_v<decay_t<invoke_result_t<CStoreTransformOp, TypeC>>, TypeC>,
|
||||
"CStoreTransformOp functor must accept value of type TC::value_type and return value convertible to type TC::value_type");
|
||||
|
||||
static constexpr bool compat = weakly_compatible(tile_shape(TiledMMA<Args...>{}),
|
||||
make_shape(size<0>(sA), size<0>(sB), size<1>(sA)));
|
||||
if constexpr (compat) {
|
||||
detail::cooperative_gemm_no_predication<SmemCopyOpA, SmemCopyOpB, SmemCopyOpC>(
|
||||
thread_idx, tiled_mma, alpha, sA, sB, beta, sC,
|
||||
sA_load_op, sB_load_op, sC_load_op, sC_store_op
|
||||
);
|
||||
} else {
|
||||
detail::cooperative_gemm_predication(
|
||||
thread_idx, tiled_mma, alpha, sA, sB, beta, sC,
|
||||
sA_load_op, sB_load_op, sC_load_op, sC_store_op
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
template <class... Args,
|
||||
class Alpha, class TA, class ALayout, class TB, class BLayout,
|
||||
class Beta, class TC, class CLayout,
|
||||
class ALoadTransformOp = cute::identity, class BLoadTransformOp = cute::identity,
|
||||
class CLoadTransformOp = cute::identity, class CStoreTransformOp = cute::identity,
|
||||
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
|
||||
BLayout::rank == 2 && is_smem<TB>::value &&
|
||||
CLayout::rank == 2 && is_smem<TC>::value)>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
cooperative_gemm(ThrMMA<Args...> const& thr_mma,
|
||||
cooperative_gemm(uint32_t thread_idx,
|
||||
TiledMMA<Args...> const& tiled_mma,
|
||||
Alpha const& alpha,
|
||||
Tensor<TA, ALayout> sA,
|
||||
Tensor<TB, BLayout> sB,
|
||||
Beta const& beta,
|
||||
Tensor<TC, CLayout> sC)
|
||||
Tensor<TC, CLayout> sC,
|
||||
ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM
|
||||
BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM
|
||||
CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM
|
||||
CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C
|
||||
{
|
||||
cooperative_gemm(thr_mma, alpha, sA, sB, beta, sC, identity() /* sA_load_op */, identity() /* sB_load_op */);
|
||||
using CopyOpA = AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<typename TA::value_type>>;
|
||||
using CopyOpB = AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<typename TB::value_type>>;
|
||||
using CopyOpC = AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<typename TC::value_type>>;
|
||||
cooperative_gemm<CopyOpA, CopyOpB, CopyOpC>(
|
||||
thread_idx, tiled_mma, alpha, sA, sB, beta, sC,
|
||||
sA_load_op, sB_load_op, sC_load_op, sC_store_op
|
||||
);
|
||||
}
|
||||
|
||||
// Legacy overload of cute::gemm for backwards-compatibility
|
||||
template <class... Args,
|
||||
class Alpha, class TA, class ALayout, class TB, class BLayout,
|
||||
class Beta, class TC, class CLayout,
|
||||
class ALoadTransformOp, class BLoadTransformOp,
|
||||
class ALoadTransformOp = cute::identity, class BLoadTransformOp = cute::identity,
|
||||
class CLoadTransformOp = cute::identity, class CStoreTransformOp = cute::identity,
|
||||
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
|
||||
BLayout::rank == 2 && is_smem<TB>::value &&
|
||||
CLayout::rank == 2 && is_smem<TC>::value)>
|
||||
@@ -299,28 +514,16 @@ gemm(ThrMMA<Args...> const& thr_mma,
|
||||
Tensor<TB, BLayout> sB,
|
||||
Beta const& beta,
|
||||
Tensor<TC, CLayout> sC,
|
||||
ALoadTransformOp const& sA_load_op /* transforms A values before used in GEMM */,
|
||||
BLoadTransformOp const& sB_load_op /* transforms B values before used in GEMM */)
|
||||
ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM
|
||||
BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM
|
||||
CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM
|
||||
CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C
|
||||
{
|
||||
cooperative_gemm(thr_mma, alpha, sA, sB, beta, sC, sA_load_op, sB_load_op);
|
||||
}
|
||||
|
||||
template <class... Args,
|
||||
class Alpha, class TA, class ALayout, class TB, class BLayout,
|
||||
class Beta, class TC, class CLayout,
|
||||
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
|
||||
BLayout::rank == 2 && is_smem<TB>::value &&
|
||||
CLayout::rank == 2 && is_smem<TC>::value)>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(ThrMMA<Args...> const& thr_mma,
|
||||
Alpha const& alpha,
|
||||
Tensor<TA, ALayout> sA,
|
||||
Tensor<TB, BLayout> sB,
|
||||
Beta const& beta,
|
||||
Tensor<TC, CLayout> sC)
|
||||
{
|
||||
cooperative_gemm(thr_mma, alpha, sA, sB, beta, sC, identity() /* sA_load_op */, identity() /* sB_load_op */);
|
||||
// Goes directly to the slow path to avoid getting thread_idx from thr_mma
|
||||
detail::cooperative_gemm_predication(
|
||||
thr_mma, alpha, sA, sB, beta, sC,
|
||||
sA_load_op, sB_load_op, sC_load_op, sC_store_op
|
||||
);
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
Reference in New Issue
Block a user