mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-20 06:48:59 +00:00
CUTLASS 3.5.0 (#1411)
This commit is contained in:
@@ -33,6 +33,7 @@
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cute/tensor_predicate.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
@@ -43,15 +44,17 @@ namespace cute
|
||||
template <class Alpha,
|
||||
class XEngine, class XLayout,
|
||||
class Beta,
|
||||
class YEngine, class YLayout>
|
||||
class YEngine, class YLayout,
|
||||
class PrdTensor = TrivialPredTensor>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
axpby(Alpha const& alpha,
|
||||
Tensor<XEngine, XLayout> const& x,
|
||||
Beta const& beta,
|
||||
Tensor<YEngine, YLayout> && y)
|
||||
Tensor<YEngine, YLayout> && y,
|
||||
PrdTensor const& p = {})
|
||||
{
|
||||
return axpby(alpha, x, beta, y);
|
||||
return axpby(alpha, x, beta, y, p);
|
||||
}
|
||||
|
||||
//
|
||||
@@ -60,13 +63,15 @@ axpby(Alpha const& alpha,
|
||||
template <class Alpha,
|
||||
class XEngine, class XLayout,
|
||||
class Beta,
|
||||
class YEngine, class YLayout>
|
||||
class YEngine, class YLayout,
|
||||
class PrdTensor = TrivialPredTensor>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
axpby(Alpha const& alpha,
|
||||
Tensor<XEngine, XLayout> const& x,
|
||||
Beta const& beta,
|
||||
Tensor<YEngine, YLayout> & y)
|
||||
Tensor<YEngine, YLayout> & y,
|
||||
PrdTensor const& p = {})
|
||||
{
|
||||
auto isBetaZero = [&] () {
|
||||
if constexpr (is_complex<Beta>::value) {
|
||||
@@ -81,7 +86,9 @@ axpby(Alpha const& alpha,
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size(x); ++i) {
|
||||
y(i) = (isBetaZero ? alpha * x(i) : alpha * x(i) + beta * y(i));
|
||||
if (p(i)) {
|
||||
y(i) = (isBetaZero ? alpha * x(i) : alpha * x(i) + beta * y(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
196
include/cute/algorithm/cooperative_copy.hpp
Normal file
196
include/cute/algorithm/cooperative_copy.hpp
Normal file
@@ -0,0 +1,196 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/atom/copy_atom.hpp>
|
||||
|
||||
#include <cute/algorithm/copy.hpp>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cute/tensor_predicate.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
// cooperative_copy<NumThreads, MaxVecBits>(thr_idx, src, dst)
|
||||
// Use NumThreads to copy src to dst with element vectorization up to MaxVecBits.
|
||||
// @pre 0 <= @a tid < NumThreads
|
||||
// @pre Tensors @a src and @a dst are aligned up to MaxVecBits.
|
||||
//
|
||||
template <uint32_t NumThreads, uint32_t MaxVecBits,
|
||||
class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
cooperative_copy(uint32_t const& tid,
|
||||
Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> & dst)
|
||||
{
|
||||
// Assumes the shapes are static, can generalize
|
||||
CUTE_STATIC_ASSERT_V(size(src) == size(dst));
|
||||
// Assumes the types are the same, can generalize
|
||||
static_assert(sizeof_bits_v<typename SrcEngine::value_type> == sizeof_bits_v<typename DstEngine::value_type>);
|
||||
static_assert(MaxVecBits == sizeof_bits_v<typename SrcEngine::value_type> ||
|
||||
MaxVecBits == 8 || MaxVecBits == 16 || MaxVecBits == 32 || MaxVecBits == 64 || MaxVecBits == 128,
|
||||
"Expected MaxVecBits to be value size or 8 or 16 or 32 or 64 or 128 for alignment and performance.");
|
||||
// Check that the tensors are likely shared across threads: either gmem or smem
|
||||
static_assert((is_gmem<SrcEngine>::value || is_smem<SrcEngine>::value),
|
||||
"cooperative_copy expects shared gmem or smem source tensor.");
|
||||
static_assert((is_gmem<DstEngine>::value || is_smem<DstEngine>::value),
|
||||
"cooperative_copy expects shared gmem or smem destination tensor.");
|
||||
|
||||
// Precondition on tid in DEBUG
|
||||
assert(tid < NumThreads);
|
||||
// Precondition on pointer alignment in DEBUG
|
||||
assert(is_byte_aligned<max(MaxVecBits/8, 1u)>(raw_pointer_cast(src.data())));
|
||||
assert(is_byte_aligned<max(MaxVecBits/8, 1u)>(raw_pointer_cast(dst.data())));
|
||||
//
|
||||
// Determine val+thr vectorization based on src/dst size and number of threads
|
||||
// NOTE: This heuristic promotes parallelization over vectorization
|
||||
//
|
||||
|
||||
constexpr int elem_bits = sizeof_bits_v<typename SrcEngine::value_type>;
|
||||
|
||||
// The number of elements that can be vectorized in values
|
||||
constexpr int common_elem = decltype(max_common_vector(src, dst))::value;
|
||||
constexpr int common_bits = common_elem * elem_bits;
|
||||
constexpr int total_elem = decltype(size(src))::value;
|
||||
constexpr int total_bits = total_elem * elem_bits;
|
||||
static_assert(total_bits % NumThreads == 0);
|
||||
constexpr int total_bits_per_thr = total_bits / NumThreads;
|
||||
// If there are too many threads to allow a full elem copy, trunc the thrs and use elem_bits
|
||||
constexpr int max_vec_bits_by_thr = cute::max(elem_bits, total_bits_per_thr);
|
||||
|
||||
// Cap the vectorization to the common bits, the max_vec_bits_by_thr, and the MaxVecBits
|
||||
constexpr int vec_bits = cute::min(common_bits, max_vec_bits_by_thr, static_cast<int>(MaxVecBits));
|
||||
// Convert back to number of elements, safe_div
|
||||
static_assert((vec_bits % elem_bits) == 0);
|
||||
constexpr int vec_elem = vec_bits / elem_bits;
|
||||
|
||||
// Use only part of threads if there's not enough work for all threads
|
||||
constexpr int vec_thrs = (total_elem % (vec_elem * NumThreads) == 0)
|
||||
? NumThreads
|
||||
: (total_elem / vec_elem);
|
||||
|
||||
// The common layout of the two tensors that can be vectorized over threads
|
||||
// vidx -> coord
|
||||
auto common_layout = max_common_layout(get_nonswizzle_portion(src.layout()),
|
||||
get_nonswizzle_portion(dst.layout()));
|
||||
|
||||
// Scale up the common_layout to cover the entire tensors
|
||||
// vidx -> coord
|
||||
auto full_perm = tile_to_shape(make_layout(common_layout), size(src));
|
||||
|
||||
// Create the Tiler
|
||||
// ((vid,tid),iter)
|
||||
auto layout_vt = logical_divide(full_perm, Layout<Shape<Int<vec_elem>, Int<vec_thrs>>>{});
|
||||
|
||||
// Apply and slice
|
||||
Tensor src_v = src.compose(layout_vt)(make_coord(_,tid),_);
|
||||
Tensor dst_v = dst.compose(layout_vt)(make_coord(_,tid),_);
|
||||
|
||||
// Should account for vec_bits < 8 and/or vec_elem <= 1
|
||||
// And also account for subbyte types, which could cause race conditions
|
||||
// Want to ENFORCE sufficient vectorization in those cases
|
||||
static_assert((vec_bits >= 8), "No support for subbyte copying");
|
||||
using VecType = uint_bit_t<vec_bits>;
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print(" "); print("NumThreads: "); print(NumThreads); print("\n");
|
||||
print(" "); print("src: "); print(src); print("\n");
|
||||
print(" "); print("dst: "); print(dst); print("\n");
|
||||
print(" "); print("common_layout: "); print(common_layout); print("\n");
|
||||
print(" "); print("full_perm: "); print(full_perm); print("\n");
|
||||
print(" "); print("Used vector: "); print(vec_elem); print("\n");
|
||||
print(" "); print("Used threads: "); print(vec_thrs); print("\n");
|
||||
print(" "); print("layout_vt: "); print(layout_vt); print("\n");
|
||||
print(" "); print("src.compose(layout_vt): "); print(src.compose(layout_vt)); print("\n");
|
||||
print(" "); print("dst.compose(layout_vt): "); print(dst.compose(layout_vt)); print("\n");
|
||||
print(" "); print("src_v: "); print(src_v); print("\n");
|
||||
print(" "); print("dst_v: "); print(dst_v); print("\n");
|
||||
print(" "); print("recast<VecType const>(src_v): "); print(recast<VecType const>(src_v)); print("\n");
|
||||
print(" "); print("recast<VecType const>(dst_v): "); print(recast<VecType const>(dst_v)); print("\n");
|
||||
}
|
||||
#ifdef __CUDA_ARCH__
|
||||
__syncthreads();
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// If we're using all threads (static) or the tid is in in-range (dynamic)
|
||||
if (vec_thrs >= NumThreads or tid < vec_thrs) {
|
||||
return copy_if(TrivialPredTensor{}, recast<VecType const>(src_v), recast<VecType>(dst_v));
|
||||
}
|
||||
}
|
||||
|
||||
template <uint32_t NumThreads,
|
||||
class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
cooperative_copy(uint32_t const& tid,
|
||||
Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> & dst)
|
||||
{
|
||||
constexpr uint32_t MaxVecBits = sizeof_bits_v<typename SrcEngine::value_type>;
|
||||
return cooperative_copy<NumThreads, MaxVecBits>(tid, src, dst);
|
||||
}
|
||||
|
||||
// Accept mutable temporaries
|
||||
template <uint32_t NumThreads,
|
||||
class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
cooperative_copy(uint32_t const& tid,
|
||||
Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> && dst)
|
||||
{
|
||||
return cooperative_copy<NumThreads>(tid, src, dst);
|
||||
}
|
||||
|
||||
// Accept mutable temporaries
|
||||
template <uint32_t NumThreads,
|
||||
uint32_t MaxVecBits,
|
||||
class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
cooperative_copy(uint32_t const& tid,
|
||||
Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> && dst)
|
||||
{
|
||||
return cooperative_copy<NumThreads, MaxVecBits>(tid, src, dst);
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
326
include/cute/algorithm/cooperative_gemm.hpp
Normal file
326
include/cute/algorithm/cooperative_gemm.hpp
Normal file
@@ -0,0 +1,326 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
#include <cute/util/type_traits.hpp>
|
||||
|
||||
#include <cute/atom/mma_atom.hpp>
|
||||
|
||||
#include <cute/algorithm/functional.hpp>
|
||||
#include <cute/algorithm/gemm.hpp>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//
|
||||
// Collective Shared-Memory GEMMs
|
||||
//
|
||||
|
||||
template <class... Args,
|
||||
class Alpha, class TA, class ALayout, class TB, class BLayout,
|
||||
class Beta, class TC, class CLayout,
|
||||
class ALoadTransformOp, class BLoadTransformOp,
|
||||
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
|
||||
BLayout::rank == 2 && is_smem<TB>::value &&
|
||||
CLayout::rank == 2 && is_smem<TC>::value)>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
cooperative_gemm(ThrMMA<Args...> const& thr_mma,
|
||||
Alpha const& alpha,
|
||||
Tensor<TA, ALayout> sA,
|
||||
Tensor<TB, BLayout> sB,
|
||||
Beta const& beta,
|
||||
Tensor<TC, CLayout> sC,
|
||||
ALoadTransformOp const& sA_load_op /* transforms A values before used in GEMM */,
|
||||
BLoadTransformOp const& sB_load_op /* transforms B values before used in GEMM */)
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM
|
||||
CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN
|
||||
CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK
|
||||
|
||||
using TypeA = typename TA::value_type;
|
||||
using TypeB = typename TB::value_type;
|
||||
using TypeC = typename TC::value_type;
|
||||
|
||||
static_assert(is_same_v<decay_t<invoke_result_t<ALoadTransformOp, TypeA>>, TypeA>,
|
||||
"ALoadTransformOp functor must accept and return value of type TA::value_type");
|
||||
static_assert(is_same_v<decay_t<invoke_result_t<BLoadTransformOp, TypeB>>, TypeB>,
|
||||
"BLoadTransformOp functor must accept and return value of type TB::value_type");
|
||||
|
||||
// Original, static size of the problem
|
||||
auto M = size<0>(sC);
|
||||
auto N = size<1>(sC);
|
||||
auto K = size<1>(sA);
|
||||
|
||||
// Block size of the compute tile
|
||||
auto BLK_M = tile_size<0>(thr_mma);
|
||||
auto BLK_N = tile_size<1>(thr_mma);
|
||||
auto BLK_K = tile_size<2>(thr_mma);
|
||||
|
||||
// Compute the "residues"
|
||||
auto m_residue = M - BLK_M * (ceil_div(M, BLK_M) - Int<1>{}); // (0,BLK_M]
|
||||
auto n_residue = N - BLK_N * (ceil_div(N, BLK_N) - Int<1>{}); // (0,BLK_N]
|
||||
auto k_residue = K - BLK_K * (ceil_div(K, BLK_K) ); // (-BLK_K,0]
|
||||
|
||||
// Shift the origin so k_residue is zeroth tile
|
||||
sA.data() = &sA(0,k_residue);
|
||||
sB.data() = &sB(0,k_residue);
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
printf("%d in BLK_M (%d)\n", int(m_residue), int(BLK_M));
|
||||
printf("%d in BLK_N (%d)\n", int(n_residue), int(BLK_N));
|
||||
printf("%d in BLK_K (%d)\n", int(k_residue), int(BLK_K));
|
||||
}
|
||||
#endif
|
||||
|
||||
//
|
||||
// MMA Partitioning
|
||||
//
|
||||
|
||||
// Round the layout extents up to BLK_X
|
||||
Tensor rounded_sA = sA.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(K, BLK_K) * BLK_K));
|
||||
Tensor rounded_sB = sB.compose(make_shape(ceil_div(N, BLK_N) * BLK_N, ceil_div(K, BLK_K) * BLK_K));
|
||||
Tensor rounded_sC = sC.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(N, BLK_N) * BLK_N));
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print("rounded_sA: "); print(rounded_sA); print("\n");
|
||||
print("rounded_sB: "); print(rounded_sB); print("\n");
|
||||
print("rounded_sC: "); print(rounded_sC); print("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
// Partition the sA and sB tiles across the threads for the MMA
|
||||
Tensor tCsA = thr_mma.partition_A(rounded_sA); // (MMA,MMA_M,MMA_K)
|
||||
Tensor tCsB = thr_mma.partition_B(rounded_sB); // (MMA,MMA_N,MMA_K)
|
||||
Tensor tCsC = thr_mma.partition_C(rounded_sC); // (MMA,MMA_M,MMA_N)
|
||||
// Create register tensors for the MMA to operate on
|
||||
Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K)
|
||||
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)
|
||||
Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N)
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print("tCsA: "); print(tCsA); print("\n");
|
||||
print("tCsB: "); print(tCsB); print("\n");
|
||||
print("tCsC: "); print(tCsC); print("\n");
|
||||
print("tCrA: "); print(tCrA); print("\n");
|
||||
print("tCrB: "); print(tCrB); print("\n");
|
||||
print("tCrC: "); print(tCrC); print("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
//
|
||||
// PREDICATION
|
||||
//
|
||||
|
||||
// Allocate the preds for only the MMA-mode of tCsA and tCsB
|
||||
Tensor tCpA = make_tensor<bool>(size<0>(tCsA));
|
||||
Tensor tCpB = make_tensor<bool>(size<0>(tCsB));
|
||||
|
||||
// Create coordinate tensors on a single compute block for predication
|
||||
Tensor cA = make_identity_tensor(make_shape(BLK_M, BLK_K)); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
Tensor cB = make_identity_tensor(make_shape(BLK_N, BLK_K)); // (BLK_M,BLK_K) -> (blk_n,blk_k)
|
||||
|
||||
// Repeat partitioning with thr_mma
|
||||
Tensor tCcA = thr_mma.partition_A(cA); // (MMA,1,1) -> (blk_m,blk_k)
|
||||
Tensor tCcB = thr_mma.partition_B(cB); // (MMA,1,1) -> (blk_n,blk_k)
|
||||
|
||||
// Populate the m and n predicates
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size(tCpA); ++i) {
|
||||
tCpA(i) = elem_less(get<0>(tCcA(i)), m_residue);
|
||||
}
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size(tCpB); ++i) {
|
||||
tCpB(i) = elem_less(get<0>(tCcB(i)), n_residue);
|
||||
}
|
||||
|
||||
#if 0
|
||||
printf("Thr %d: A(%d,%d):%d B(%d,%d):%d\n",
|
||||
threadIdx.x,
|
||||
int(get<0>(tCcA(0))), int(get<1>(tCcA(0))), int(tCpA(0)),
|
||||
int(get<0>(tCcB(0))), int(get<1>(tCcB(0))), int(tCpB(0)));
|
||||
#endif
|
||||
|
||||
//
|
||||
// PREFETCH k_block = 0 (with k-predication)
|
||||
//
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size<0>(tCsA); ++i) { // Copy MMA_I
|
||||
if (k_residue == 0 || get<1>(tCcA(i)) >= -k_residue) { // k_block = 0, predicated on k
|
||||
CUTE_UNROLL
|
||||
for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M, predicated on m
|
||||
tCrA(i,m,0) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,0)) : TypeA{};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I
|
||||
if (k_residue == 0 || get<1>(tCcB(i)) >= -k_residue) { // k_block = 0, predicated on k
|
||||
CUTE_UNROLL
|
||||
for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N, predicated on n
|
||||
tCrB(i,n,0) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,0)) : TypeB{};
|
||||
}
|
||||
}
|
||||
}
|
||||
//
|
||||
// MAINLOOP
|
||||
//
|
||||
|
||||
// Clear accumulators
|
||||
clear(tCrC);
|
||||
|
||||
constexpr int K_BLOCK_MAX = size<2>(tCrA);
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block)
|
||||
{
|
||||
// static-if load the next k_block. No k-predication required on these loads.
|
||||
if (k_block < K_BLOCK_MAX-1)
|
||||
{
|
||||
// Load the next k_block
|
||||
int k_next = k_block + 1;
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size<0>(tCsA); ++i) { // Copy_if MMA_I predicated on m
|
||||
tCrA(i,m,k_next) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,k_next)) : TypeA{};
|
||||
}
|
||||
}
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I predicated on n
|
||||
tCrB(i,n,k_next) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,k_next)) : TypeB{};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GEMM on k_block in registers
|
||||
gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
|
||||
}
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
//
|
||||
|
||||
Tensor cC = make_identity_tensor(make_shape(BLK_M, BLK_N)); // (BLK_M,BLK_N) -> (blk_m,blk_n)
|
||||
Tensor tCcC = thr_mma.partition_C(cC); // (MMA, 1, 1) -> (blk_m,blk_n)
|
||||
|
||||
const bool isBetaZero = (beta == Beta{});
|
||||
|
||||
// Custom axpby_if for now
|
||||
CUTE_UNROLL
|
||||
for (int m = 0; m < size<1>(tCsC); ++m)
|
||||
{
|
||||
CUTE_UNROLL
|
||||
for (int n = 0; n < size<2>(tCsC); ++n)
|
||||
{
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size<0>(tCsC); ++i)
|
||||
{
|
||||
if ((m_residue == BLK_M || m < size<1>(tCrC)-1 || get<0>(tCcC(i)) < m_residue) &&
|
||||
(n_residue == BLK_N || n < size<2>(tCrC)-1 || get<1>(tCcC(i)) < n_residue))
|
||||
{
|
||||
tCsC(i,m,n) = isBetaZero ? alpha * static_cast<TypeC>(tCrC(i,m,n)) : alpha * static_cast<TypeC>(tCrC(i,m,n)) + beta * static_cast<TypeC>(tCsC(i,m,n));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class... Args,
|
||||
class Alpha, class TA, class ALayout, class TB, class BLayout,
|
||||
class Beta, class TC, class CLayout,
|
||||
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
|
||||
BLayout::rank == 2 && is_smem<TB>::value &&
|
||||
CLayout::rank == 2 && is_smem<TC>::value)>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
cooperative_gemm(ThrMMA<Args...> const& thr_mma,
|
||||
Alpha const& alpha,
|
||||
Tensor<TA, ALayout> sA,
|
||||
Tensor<TB, BLayout> sB,
|
||||
Beta const& beta,
|
||||
Tensor<TC, CLayout> sC)
|
||||
{
|
||||
cooperative_gemm(thr_mma, alpha, sA, sB, beta, sC, identity() /* sA_load_op */, identity() /* sB_load_op */);
|
||||
}
|
||||
|
||||
template <class... Args,
|
||||
class Alpha, class TA, class ALayout, class TB, class BLayout,
|
||||
class Beta, class TC, class CLayout,
|
||||
class ALoadTransformOp, class BLoadTransformOp,
|
||||
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
|
||||
BLayout::rank == 2 && is_smem<TB>::value &&
|
||||
CLayout::rank == 2 && is_smem<TC>::value)>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(ThrMMA<Args...> const& thr_mma,
|
||||
Alpha const& alpha,
|
||||
Tensor<TA, ALayout> sA,
|
||||
Tensor<TB, BLayout> sB,
|
||||
Beta const& beta,
|
||||
Tensor<TC, CLayout> sC,
|
||||
ALoadTransformOp const& sA_load_op /* transforms A values before used in GEMM */,
|
||||
BLoadTransformOp const& sB_load_op /* transforms B values before used in GEMM */)
|
||||
{
|
||||
cooperative_gemm(thr_mma, alpha, sA, sB, beta, sC, sA_load_op, sB_load_op);
|
||||
}
|
||||
|
||||
template <class... Args,
|
||||
class Alpha, class TA, class ALayout, class TB, class BLayout,
|
||||
class Beta, class TC, class CLayout,
|
||||
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
|
||||
BLayout::rank == 2 && is_smem<TB>::value &&
|
||||
CLayout::rank == 2 && is_smem<TC>::value)>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(ThrMMA<Args...> const& thr_mma,
|
||||
Alpha const& alpha,
|
||||
Tensor<TA, ALayout> sA,
|
||||
Tensor<TB, BLayout> sB,
|
||||
Beta const& beta,
|
||||
Tensor<TC, CLayout> sC)
|
||||
{
|
||||
cooperative_gemm(thr_mma, alpha, sA, sB, beta, sC, identity() /* sA_load_op */, identity() /* sB_load_op */);
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
@@ -145,10 +145,10 @@ copy_if(PrdTensor const& pred,
|
||||
namespace detail {
|
||||
|
||||
// Trait that detects if atom's traits has a member function with(bool)
|
||||
template<typename, typename Enable = void>
|
||||
template <class, class Enable = void>
|
||||
constexpr bool has_with_bool = false;
|
||||
|
||||
template<typename T>
|
||||
template <class T>
|
||||
constexpr bool has_with_bool<T, cute::void_t<decltype(declval<typename T::Traits>().with(declval<bool>()))>> = true;
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
@@ -33,6 +33,7 @@
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/util/type_traits.hpp>
|
||||
#include <cute/numeric/complex.hpp>
|
||||
|
||||
/** C++14 <functional> extensions */
|
||||
|
||||
@@ -46,7 +47,7 @@ struct identity {
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto) operator()(T&& arg) const {
|
||||
return std::forward<T>(arg);
|
||||
return static_cast<T&&>(arg);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -69,7 +70,7 @@ struct constant_fn {
|
||||
template <class T> \
|
||||
CUTE_HOST_DEVICE constexpr \
|
||||
decltype(auto) operator()(T&& arg) const { \
|
||||
return OP std::forward<T>(arg); \
|
||||
return OP static_cast<T&&>(arg); \
|
||||
} \
|
||||
}
|
||||
#define CUTE_RIGHT_UNARY_OP(NAME,OP) \
|
||||
@@ -77,7 +78,7 @@ struct constant_fn {
|
||||
template <class T> \
|
||||
CUTE_HOST_DEVICE constexpr \
|
||||
decltype(auto) operator()(T&& arg) const { \
|
||||
return std::forward<T>(arg) OP ; \
|
||||
return static_cast<T&&>(arg) OP ; \
|
||||
} \
|
||||
}
|
||||
#define CUTE_NAMED_UNARY_OP(NAME,OP) \
|
||||
@@ -85,7 +86,7 @@ struct constant_fn {
|
||||
template <class T> \
|
||||
CUTE_HOST_DEVICE constexpr \
|
||||
decltype(auto) operator()(T&& arg) const { \
|
||||
return OP (std::forward<T>(arg)); \
|
||||
return OP (static_cast<T&&>(arg)); \
|
||||
} \
|
||||
}
|
||||
|
||||
@@ -115,7 +116,7 @@ struct shift_right_const {
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto) operator()(T&& arg) const {
|
||||
return std::forward<T>(arg) >> Shift;
|
||||
return static_cast<T&&>(arg) >> Shift;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -126,7 +127,7 @@ struct shift_left_const {
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto) operator()(T&& arg) const {
|
||||
return std::forward<T>(arg) << Shift;
|
||||
return static_cast<T&&>(arg) << Shift;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -139,7 +140,7 @@ struct shift_left_const {
|
||||
template <class T, class U> \
|
||||
CUTE_HOST_DEVICE constexpr \
|
||||
decltype(auto) operator()(T&& lhs, U&& rhs) const { \
|
||||
return std::forward<T>(lhs) OP std::forward<U>(rhs); \
|
||||
return static_cast<T&&>(lhs) OP static_cast<U&&>(rhs); \
|
||||
} \
|
||||
}
|
||||
#define CUTE_NAMED_BINARY_OP(NAME,OP) \
|
||||
@@ -147,7 +148,7 @@ struct shift_left_const {
|
||||
template <class T, class U> \
|
||||
CUTE_HOST_DEVICE constexpr \
|
||||
decltype(auto) operator()(T&& lhs, U&& rhs) const { \
|
||||
return OP (std::forward<T>(lhs), std::forward<U>(rhs)); \
|
||||
return OP (static_cast<T&&>(lhs), static_cast<U&&>(rhs)); \
|
||||
} \
|
||||
}
|
||||
|
||||
@@ -273,7 +274,7 @@ struct bound_fn {
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
operator()(T&& arg) {
|
||||
return fn_(arg_, std::forward<T>(arg));
|
||||
return fn_(arg_, static_cast<T&&>(arg));
|
||||
}
|
||||
|
||||
Fn fn_;
|
||||
|
||||
@@ -252,7 +252,7 @@ gemm(MMA_Atom<MMA> const& mma,
|
||||
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutC_TV{}) == Int<1>{});
|
||||
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutA_TV{}) == Int<1>{});
|
||||
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutB_TV{}) == Int<1>{});
|
||||
|
||||
|
||||
gemm(mma,
|
||||
make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N)
|
||||
make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K)
|
||||
@@ -451,6 +451,7 @@ gemm(MMA_Atom<MMA> const& mma,
|
||||
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutC_TV{}) == Int<1>{});
|
||||
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutA_TV{}) == Int<1>{});
|
||||
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutB_TV{}) == Int<1>{});
|
||||
|
||||
gemm(mma,
|
||||
make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N)
|
||||
make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K)
|
||||
@@ -496,245 +497,4 @@ gemm(MMA_Atom<MMA> const& mma,
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Collective Shared-Memory GEMMs
|
||||
//
|
||||
|
||||
template <class... Args,
|
||||
class Alpha, class TA, class ALayout, class TB, class BLayout,
|
||||
class Beta, class TC, class CLayout,
|
||||
class ALoadTransformOp, class BLoadTransformOp,
|
||||
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
|
||||
BLayout::rank == 2 && is_smem<TB>::value &&
|
||||
CLayout::rank == 2 && is_smem<TC>::value)>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(ThrMMA<Args...> const& thr_mma,
|
||||
Alpha const& alpha,
|
||||
Tensor<TA, ALayout> sA,
|
||||
Tensor<TB, BLayout> sB,
|
||||
Beta const& beta,
|
||||
Tensor<TC, CLayout> sC,
|
||||
ALoadTransformOp const& sA_load_op /* transforms A values before used in GEMM */,
|
||||
BLoadTransformOp const& sB_load_op /* transforms B values before used in GEMM */)
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM
|
||||
CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN
|
||||
CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK
|
||||
|
||||
using TypeA = typename TA::value_type;
|
||||
using TypeB = typename TB::value_type;
|
||||
using TypeC = typename TC::value_type;
|
||||
|
||||
static_assert(is_same_v<decay_t<invoke_result_t<ALoadTransformOp, TypeA>>, TypeA>,
|
||||
"ALoadTransformOp functor must accept and return value of type TA::value_type");
|
||||
static_assert(is_same_v<decay_t<invoke_result_t<BLoadTransformOp, TypeB>>, TypeB>,
|
||||
"BLoadTransformOp functor must accept and return value of type TB::value_type");
|
||||
|
||||
// Original, static size of the problem
|
||||
auto M = size<0>(sC);
|
||||
auto N = size<1>(sC);
|
||||
auto K = size<1>(sA);
|
||||
|
||||
// Block size of the compute tile
|
||||
auto BLK_M = tile_size<0>(thr_mma);
|
||||
auto BLK_N = tile_size<1>(thr_mma);
|
||||
auto BLK_K = tile_size<2>(thr_mma);
|
||||
|
||||
// Compute the "residues"
|
||||
auto m_residue = M - BLK_M * (ceil_div(M, BLK_M) - Int<1>{}); // (0,BLK_M]
|
||||
auto n_residue = N - BLK_N * (ceil_div(N, BLK_N) - Int<1>{}); // (0,BLK_N]
|
||||
auto k_residue = K - BLK_K * (ceil_div(K, BLK_K) ); // (-BLK_K,0]
|
||||
|
||||
// Shift the origin so k_residue is zeroth tile
|
||||
sA.data() = &sA(0,k_residue);
|
||||
sB.data() = &sB(0,k_residue);
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
printf("%d in BLK_M (%d)\n", int(m_residue), int(BLK_M));
|
||||
printf("%d in BLK_N (%d)\n", int(n_residue), int(BLK_N));
|
||||
printf("%d in BLK_K (%d)\n", int(k_residue), int(BLK_K));
|
||||
}
|
||||
#endif
|
||||
|
||||
//
|
||||
// MMA Partitioning
|
||||
//
|
||||
|
||||
// Round the layout extents up to BLK_X
|
||||
Tensor rounded_sA = sA.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(K, BLK_K) * BLK_K));
|
||||
Tensor rounded_sB = sB.compose(make_shape(ceil_div(N, BLK_N) * BLK_N, ceil_div(K, BLK_K) * BLK_K));
|
||||
Tensor rounded_sC = sC.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(N, BLK_N) * BLK_N));
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print(rounded_sA.layout()); print("\n");
|
||||
print(rounded_sB.layout()); print("\n");
|
||||
print(rounded_sC.layout()); print("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
// Partition the sA and sB tiles across the threads for the MMA
|
||||
Tensor tCsA = thr_mma.partition_A(rounded_sA); // (MMA,MMA_M,MMA_K)
|
||||
Tensor tCsB = thr_mma.partition_B(rounded_sB); // (MMA,MMA_N,MMA_K)
|
||||
Tensor tCsC = thr_mma.partition_C(rounded_sC); // (MMA,MMA_M,MMA_N)
|
||||
// Create register tensors for the MMA to operate on
|
||||
Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K)
|
||||
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)
|
||||
Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N)
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print(tCsA.layout()); print("\n");
|
||||
print(tCsB.layout()); print("\n");
|
||||
print(tCsC.layout()); print("\n");
|
||||
print(tCrA.layout()); print("\n");
|
||||
print(tCrB.layout()); print("\n");
|
||||
print(tCrC.layout()); print("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
//
|
||||
// PREDICATION
|
||||
//
|
||||
|
||||
// Allocate the preds for only the MMA-mode of tCsA and tCsB
|
||||
Tensor tCpA = make_tensor<bool>(size<0>(tCsA));
|
||||
Tensor tCpB = make_tensor<bool>(size<0>(tCsB));
|
||||
|
||||
// Create coordinate tensors on a single compute block for predication
|
||||
Tensor cA = make_identity_tensor(make_shape(BLK_M, BLK_K)); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
Tensor cB = make_identity_tensor(make_shape(BLK_N, BLK_K)); // (BLK_M,BLK_K) -> (blk_n,blk_k)
|
||||
|
||||
// Repeat partitioning with thr_mma
|
||||
Tensor tCcA = thr_mma.partition_A(cA); // (MMA,1,1) -> (blk_m,blk_k)
|
||||
Tensor tCcB = thr_mma.partition_B(cB); // (MMA,1,1) -> (blk_n,blk_k)
|
||||
|
||||
// Populate the m and n predicates
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size(tCpA); ++i) {
|
||||
tCpA(i) = elem_less(get<0>(tCcA(i)), m_residue);
|
||||
}
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size(tCpB); ++i) {
|
||||
tCpB(i) = elem_less(get<0>(tCcB(i)), n_residue);
|
||||
}
|
||||
|
||||
#if 0
|
||||
printf("Thr %d: A(%d,%d):%d B(%d,%d):%d\n",
|
||||
threadIdx.x,
|
||||
int(get<0>(tCcA(0))), int(get<1>(tCcA(0))), int(tCpA(0)),
|
||||
int(get<0>(tCcB(0))), int(get<1>(tCcB(0))), int(tCpB(0)));
|
||||
#endif
|
||||
|
||||
//
|
||||
// PREFETCH k_block = 0 (with k-predication)
|
||||
//
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size<0>(tCsA); ++i) { // Copy MMA_I
|
||||
if (k_residue == 0 || get<1>(tCcA(i)) >= -k_residue) { // k_block = 0, predicated on k
|
||||
CUTE_UNROLL
|
||||
for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M, predicated on m
|
||||
tCrA(i,m,0) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,0)) : TypeA{};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I
|
||||
if (k_residue == 0 || get<1>(tCcB(i)) >= -k_residue) { // k_block = 0, predicated on k
|
||||
CUTE_UNROLL
|
||||
for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N, predicated on n
|
||||
tCrB(i,n,0) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,0)) : TypeB{};
|
||||
}
|
||||
}
|
||||
}
|
||||
//
|
||||
// MAINLOOP
|
||||
//
|
||||
|
||||
// Clear accumulators
|
||||
clear(tCrC);
|
||||
|
||||
constexpr int K_BLOCK_MAX = size<2>(tCrA);
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block)
|
||||
{
|
||||
// static-if load the next k_block. No k-predication required on these loads.
|
||||
if (k_block < K_BLOCK_MAX-1)
|
||||
{
|
||||
// Load the next k_block
|
||||
int k_next = k_block + 1;
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size<0>(tCsA); ++i) { // Copy_if MMA_I predicated on m
|
||||
tCrA(i,m,k_next) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,k_next)) : TypeA{};
|
||||
}
|
||||
}
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I predicated on n
|
||||
tCrB(i,n,k_next) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,k_next)) : TypeB{};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GEMM on k_block in registers
|
||||
gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
|
||||
}
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
//
|
||||
|
||||
Tensor cC = make_identity_tensor(make_shape(BLK_M, BLK_N)); // (BLK_M,BLK_N) -> (blk_m,blk_n)
|
||||
Tensor tCcC = thr_mma.partition_C(cC); // (MMA, 1, 1) -> (blk_m,blk_n)
|
||||
|
||||
const bool isBetaZero = (beta == Beta{});
|
||||
|
||||
// Custom axpby_if for now
|
||||
CUTE_UNROLL
|
||||
for (int m = 0; m < size<1>(tCsC); ++m)
|
||||
{
|
||||
CUTE_UNROLL
|
||||
for (int n = 0; n < size<2>(tCsC); ++n)
|
||||
{
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size<0>(tCsC); ++i)
|
||||
{
|
||||
if ((m_residue == BLK_M || m < size<1>(tCrC)-1 || get<0>(tCcC(i)) < m_residue) &&
|
||||
(n_residue == BLK_N || n < size<2>(tCrC)-1 || get<1>(tCcC(i)) < n_residue))
|
||||
{
|
||||
tCsC(i,m,n) = isBetaZero ? alpha * tCrC(i,m,n) : alpha * tCrC(i,m,n) + beta * tCsC(i,m,n);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class... Args,
|
||||
class Alpha, class TA, class ALayout, class TB, class BLayout,
|
||||
class Beta, class TC, class CLayout,
|
||||
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
|
||||
BLayout::rank == 2 && is_smem<TB>::value &&
|
||||
CLayout::rank == 2 && is_smem<TC>::value)>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(ThrMMA<Args...> const& thr_mma,
|
||||
Alpha const& alpha,
|
||||
Tensor<TA, ALayout> sA,
|
||||
Tensor<TB, BLayout> sB,
|
||||
Beta const& beta,
|
||||
Tensor<TC, CLayout> sC)
|
||||
{
|
||||
gemm(thr_mma, alpha, sA, sB, beta, sC, identity() /* sA_load_op */, identity() /* sB_load_op */);
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
153
include/cute/algorithm/prefetch.hpp
Normal file
153
include/cute/algorithm/prefetch.hpp
Normal file
@@ -0,0 +1,153 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
#include <cute/atom/copy_atom.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//
|
||||
// Prefetch global tensors into L2
|
||||
//
|
||||
|
||||
template <uint32_t NumThreads, uint32_t FetchBytes = 64,
|
||||
class GEngine, class GLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
cooperative_prefetch(uint32_t const& tid,
|
||||
Tensor<GEngine, GLayout> const& src)
|
||||
{
|
||||
static_assert(is_gmem<GEngine>::value, "Expected global tensor for prefetch");
|
||||
|
||||
constexpr int V = decltype(max_common_vector(src, src))::value;
|
||||
|
||||
if constexpr (V > 1) {
|
||||
// L2 sector is 32B, default fetch granularity is 64B
|
||||
using VecType = conditional_t<(V * sizeof_bits_v<typename GEngine::value_type>) < (FetchBytes * 8),
|
||||
ArrayEngine<typename GEngine::value_type, V>,
|
||||
uint8_t[FetchBytes] >;
|
||||
|
||||
Tensor src_v = recast<VecType const>(src);
|
||||
CUTE_UNROLL
|
||||
for (int i = tid; i < size(src_v); i += NumThreads) {
|
||||
prefetch(raw_pointer_cast(&src_v(i)));
|
||||
}
|
||||
} else {
|
||||
CUTE_UNROLL
|
||||
for (int i = tid; i < size(src); i += NumThreads) {
|
||||
prefetch(raw_pointer_cast(&src(i)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class GEngine, class GLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
prefetch(Tensor<GEngine, GLayout> const& src)
|
||||
{
|
||||
return cooperative_prefetch<1>(0, src);
|
||||
}
|
||||
|
||||
// Prefetch with copy atom
|
||||
namespace detail {

// has_prefetch<CopyOp>: true iff CopyOp declares a nested PREFETCH type,
// i.e. the copy operation provides a dedicated prefetch variant.
// Detection idiom: the partial specialization is viable only when
// `typename CopyOp::PREFETCH` is well-formed.
template <class CopyOp, class = void>
constexpr bool has_prefetch = false;

template <class CopyOp>
constexpr bool has_prefetch<CopyOp, void_t<typename CopyOp::PREFETCH>> = true;

// is_prefetch<CopyOp>: true iff CopyOp is its own PREFETCH type,
// i.e. CopyOp is itself a prefetch operation rather than a copy that
// merely offers one.
template <class CopyOp, class = void>
constexpr bool is_prefetch = false;

template <class CopyOp>
constexpr bool is_prefetch<CopyOp, void_t<typename CopyOp::PREFETCH>> = is_same_v<CopyOp, typename CopyOp::PREFETCH>;

} // end namespace detail
|
||||
|
||||
// Prefetch `src` through a copy atom: if the atom's CopyOp declares a
// dedicated PREFETCH variant, rebind the atom to it and issue the prefetch
// via copy(); otherwise fall back to the generic per-element L2 prefetch.
template <class CopyOp, class... CT_Args, class... CA_Args,
          class GEngine, class GLayout>
CUTE_HOST_DEVICE
void
prefetch(Copy_Atom<Copy_Traits<CopyOp, CT_Args...>, CA_Args...> const& atom,
         Tensor<GEngine, GLayout> const& src)
{
  if constexpr (detail::has_prefetch<CopyOp>) {
    // Rebind the atom to the op's PREFETCH variant, preserving the
    // trait and atom arguments
    using Prefetch_Traits = Copy_Traits<typename CopyOp::PREFETCH, CT_Args...>;
    using Prefetch_Atom = Copy_Atom<Prefetch_Traits, CA_Args...>;
    Prefetch_Atom prefetch_atom{atom};
    // copy() requires a destination parameter, but prefetch atoms never
    // write through it, so aliasing src (cast away const) is safe here.
    auto& dst = const_cast<Tensor<GEngine, GLayout>&>(src); // dst is ignored for prefetch atoms
    return copy(prefetch_atom, src, dst);
  } else {
    return prefetch(src);
  }
}
|
||||
|
||||
#if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED)
|
||||
// L2 prefetch via SM90 bulk copy (BLKCP): vectorize `src` as widely as its
// layout allows and issue bulk-copy prefetch requests over those vectors.
template <class... CT_Args,
          class SrcEngine, class SrcLayout>
CUTE_HOST_DEVICE
void
prefetch(Copy_Traits<SM90_BULK_COPY_AUTO, CT_Args...> const& atom,
         Tensor<SrcEngine, SrcLayout> const& src)
{
  using SrcType = typename SrcEngine::value_type;
  static_assert(is_gmem<SrcEngine>::value, "Expected global tensor for L2 prefetch");

  // Widest common layout of src with itself == maximal contiguous vector
  auto tiler = max_common_layout(src, src);
  constexpr int vec_elem = decltype(size(tiler))::value;
  constexpr int vec_bits = vec_elem * sizeof_bits_v<SrcType>;
  // Bulk copy requires at least a 128-bit access width
  static_assert(vec_bits >= 128, "Expected at least 128-bits for BLKCP");

  // Construct a new concrete Atom of the vector size
  auto bulk_atom = Copy_Atom<Copy_Traits<SM90_BULK_COPY_G2S, Int<vec_bits>>, SrcType>{};

  // Split src into vector-sized tiles and dispatch to the atom overload
  return prefetch(bulk_atom, logical_divide(src, tiler));
}
|
||||
|
||||
// Backwards-compat. Throw out any extra Copy_Atom args.
|
||||
template <class... CT_Args, class... CA_Args,
|
||||
class SrcEngine, class SrcLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
prefetch(Copy_Atom<Copy_Traits<SM90_BULK_COPY_AUTO, CT_Args...>, CA_Args...> const& atom,
|
||||
Tensor<SrcEngine, SrcLayout> const& src)
|
||||
{
|
||||
return prefetch(static_cast<Copy_Traits<SM90_BULK_COPY_AUTO, CT_Args...> const&>(atom), src);
|
||||
}
|
||||
#endif // #if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED)
|
||||
|
||||
} // end namespace cute
|
||||
@@ -50,7 +50,7 @@ for_each(Tensor<Engine,Layout> const& tensor, UnaryOp&& op)
|
||||
{
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size(tensor); ++i) {
|
||||
static_cast<UnaryOp&&>(op)(tensor(i));
|
||||
op(tensor(i));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ for_each(Tensor<Engine,Layout>& tensor, UnaryOp&& op)
|
||||
{
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size(tensor); ++i) {
|
||||
static_cast<UnaryOp&&>(op)(tensor(i));
|
||||
op(tensor(i));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,7 +71,7 @@ CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
for_each(Tensor<Engine,Layout>&& tensor, UnaryOp&& op)
|
||||
{
|
||||
return for_each(tensor, static_cast<UnaryOp&&>(op));
|
||||
return for_each(tensor, op);
|
||||
}
|
||||
|
||||
//
|
||||
@@ -86,7 +86,7 @@ transform(Tensor<Engine,Layout>& tensor, UnaryOp&& op)
|
||||
{
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size(tensor); ++i) {
|
||||
tensor(i) = static_cast<UnaryOp&&>(op)(tensor(i));
|
||||
tensor(i) = op(tensor(i));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,27 +96,34 @@ CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
transform(Tensor<Engine,Layout>&& tensor, UnaryOp&& op)
|
||||
{
|
||||
return transform(tensor, std::forward<UnaryOp>(op));
|
||||
return transform(tensor, op);
|
||||
}
|
||||
|
||||
// Similar to std::transform transforms one tensors and assigns it to another
|
||||
template <class EngineIn, class LayoutIn, class EngineOut, class LayoutOut, class UnaryOp>
|
||||
template <class EngineIn, class LayoutIn,
|
||||
class EngineOut, class LayoutOut,
|
||||
class UnaryOp>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
transform(Tensor<EngineIn,LayoutIn>& tensor_in, Tensor<EngineOut,LayoutOut>& tensor_out, UnaryOp&& op)
|
||||
transform(Tensor<EngineIn, LayoutIn > const& tensor_in,
|
||||
Tensor<EngineOut,LayoutOut> & tensor_out,
|
||||
UnaryOp&& op)
|
||||
{
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size(tensor_in); ++i) {
|
||||
tensor_out(i) = static_cast<UnaryOp&&>(op)(tensor_in(i));
|
||||
tensor_out(i) = op(tensor_in(i));
|
||||
}
|
||||
}
|
||||
|
||||
// Accept mutable temporaries
|
||||
template <class EngineIn, class LayoutIn,
|
||||
class EngineOut, class LayoutOut, class UnaryOp>
|
||||
class EngineOut, class LayoutOut,
|
||||
class UnaryOp>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
transform(Tensor<EngineIn,LayoutIn>&& tensor_in, Tensor<EngineOut,LayoutOut>&& tensor_out, UnaryOp&& op)
|
||||
transform(Tensor<EngineIn, LayoutIn > const& tensor_in,
|
||||
Tensor<EngineOut,LayoutOut> && tensor_out,
|
||||
UnaryOp&& op)
|
||||
{
|
||||
return transform(tensor_in, tensor_out, op);
|
||||
}
|
||||
@@ -127,29 +134,31 @@ transform(Tensor<EngineIn,LayoutIn>&& tensor_in, Tensor<EngineOut,LayoutOut>&& t
|
||||
// assigns it to tensor_out
|
||||
template <class EngineIn1, class LayoutIn1,
|
||||
class EngineIn2, class LayoutIn2,
|
||||
class EngineOut, class LayoutOut, class BinaryOp>
|
||||
class EngineOut, class LayoutOut,
|
||||
class BinaryOp>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
transform(Tensor<EngineIn1,LayoutIn1>& tensor_in1,
|
||||
Tensor<EngineIn2,LayoutIn2>& tensor_in2,
|
||||
Tensor<EngineOut,LayoutOut>& tensor_out,
|
||||
transform(Tensor<EngineIn1,LayoutIn1> const& tensor_in1,
|
||||
Tensor<EngineIn2,LayoutIn2> const& tensor_in2,
|
||||
Tensor<EngineOut,LayoutOut> & tensor_out,
|
||||
BinaryOp&& op)
|
||||
{
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size(tensor_in1); ++i) {
|
||||
tensor_out(i) = static_cast<BinaryOp&&>(op)(tensor_in1(i), tensor_in2(i));
|
||||
tensor_out(i) = op(tensor_in1(i), tensor_in2(i));
|
||||
}
|
||||
}
|
||||
|
||||
// Accept mutable temporaries
|
||||
template <class EngineIn1, class LayoutIn1,
|
||||
class EngineIn2, class LayoutIn2,
|
||||
class EngineOut, class LayoutOut, class BinaryOp>
|
||||
class EngineOut, class LayoutOut,
|
||||
class BinaryOp>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
transform(Tensor<EngineIn1,LayoutIn1>&& tensor_in1,
|
||||
Tensor<EngineIn2,LayoutIn2>&& tensor_in2,
|
||||
Tensor<EngineOut,LayoutOut>&& tensor_out,
|
||||
transform(Tensor<EngineIn1,LayoutIn1> const& tensor_in1,
|
||||
Tensor<EngineIn2,LayoutIn2> const& tensor_in2,
|
||||
Tensor<EngineOut,LayoutOut> && tensor_out,
|
||||
BinaryOp&& op)
|
||||
{
|
||||
return transform(tensor_in1, tensor_in2, tensor_out, op);
|
||||
|
||||
@@ -204,36 +204,6 @@ for_each_leaf(T&& t, F&& f)
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
//
|
||||
// For Sequence
|
||||
// (s, t, f) => (f(t[s_0]),f(t[s_1]),...,f(t[s_n]))
|
||||
//
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <int... I, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
for_sequence(seq<I...> const&, F&& f) {
|
||||
(f(Int<I>{}), ...);
|
||||
}
|
||||
|
||||
}; // end namespace detail
|
||||
|
||||
template <int... I, class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
for_sequence(seq<I...> const& s, T&& t, F&& f) {
|
||||
detail::for_sequence(s, [&](auto&& i){ f(get<remove_cvref_t<decltype(i)>::value>(static_cast<T&&>(t))); });
|
||||
}
|
||||
|
||||
template <int I, class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
for_sequence(T&& t, F&& f) {
|
||||
for_sequence(make_seq<I>{}, static_cast<T&&>(t), static_cast<F&&>(f));
|
||||
}
|
||||
|
||||
//
|
||||
// Transform
|
||||
// (t, f) => (f(t_0),f(t_1),...,f(t_n))
|
||||
@@ -551,15 +521,15 @@ take(T const& t)
|
||||
template <int... I, class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
select(T const & t)
|
||||
select(T const& t)
|
||||
{
|
||||
return cute::make_tuple(get<I>(t)...);
|
||||
}
|
||||
|
||||
template <class T, typename Indices>
|
||||
template <class T, class Indices>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
select(T const & t, Indices const & indices)
|
||||
select(T const& t, Indices const& indices)
|
||||
{
|
||||
if constexpr (is_tuple<Indices>::value) {
|
||||
return cute::transform(indices, [&t](auto i) { return select(t, i); });
|
||||
@@ -655,7 +625,7 @@ flatten(T const& t)
|
||||
|
||||
namespace detail {
|
||||
|
||||
template<class FlatTuple, class TargetProfile>
|
||||
template <class FlatTuple, class TargetProfile>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
unflatten_impl(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
|
||||
@@ -680,7 +650,7 @@ unflatten_impl(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
|
||||
// @pre rank(flatten(@a target_profile)) == rank(@a flat_tuple)
|
||||
// @post congruent(@a result, @a target_profile)
|
||||
// @post flatten(@a result) == @a flat_tuple
|
||||
template<class FlatTuple, class TargetProfile>
|
||||
template <class FlatTuple, class TargetProfile>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
unflatten(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
|
||||
@@ -865,6 +835,7 @@ append(T const& a, X const& x)
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class T, class X>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
@@ -902,6 +873,7 @@ prepend(T const& a, X const& x)
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class T, class X>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
@@ -1105,14 +1077,13 @@ zip2_by(T const& t, TG const& guide)
|
||||
|
||||
/// @return A tuple of the elements of @c t in reverse order.
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr auto
|
||||
reverse(T const& t) {
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
reverse(T const& t)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return detail::apply(t, [] (auto const&... a) {
|
||||
return cute::make_tuple(a...);
|
||||
}, tuple_rseq<T>{});
|
||||
}
|
||||
else {
|
||||
return detail::apply(t, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_rseq<T>{});
|
||||
} else {
|
||||
return t;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user