CUTLASS 3.5.0 (#1411)

This commit is contained in:
Vijay Thakkar
2024-03-19 17:51:04 -04:00
committed by GitHub
parent ffa34e7075
commit 629f4653c3
468 changed files with 48730 additions and 7253 deletions

View File

@@ -33,6 +33,7 @@
#include <cute/config.hpp>
#include <cute/tensor.hpp>
#include <cute/tensor_predicate.hpp>
namespace cute
{
@@ -43,15 +44,17 @@ namespace cute
template <class Alpha,
class XEngine, class XLayout,
class Beta,
class YEngine, class YLayout>
class YEngine, class YLayout,
class PrdTensor = TrivialPredTensor>
CUTE_HOST_DEVICE
void
axpby(Alpha const& alpha,
Tensor<XEngine, XLayout> const& x,
Beta const& beta,
Tensor<YEngine, YLayout> && y)
Tensor<YEngine, YLayout> && y,
PrdTensor const& p = {})
{
return axpby(alpha, x, beta, y);
return axpby(alpha, x, beta, y, p);
}
//
@@ -60,13 +63,15 @@ axpby(Alpha const& alpha,
template <class Alpha,
class XEngine, class XLayout,
class Beta,
class YEngine, class YLayout>
class YEngine, class YLayout,
class PrdTensor = TrivialPredTensor>
CUTE_HOST_DEVICE
void
axpby(Alpha const& alpha,
Tensor<XEngine, XLayout> const& x,
Beta const& beta,
Tensor<YEngine, YLayout> & y)
Tensor<YEngine, YLayout> & y,
PrdTensor const& p = {})
{
auto isBetaZero = [&] () {
if constexpr (is_complex<Beta>::value) {
@@ -81,7 +86,9 @@ axpby(Alpha const& alpha,
CUTE_UNROLL
for (int i = 0; i < size(x); ++i) {
y(i) = (isBetaZero ? alpha * x(i) : alpha * x(i) + beta * y(i));
if (p(i)) {
y(i) = (isBetaZero ? alpha * x(i) : alpha * x(i) + beta * y(i));
}
}
}

View File

@@ -0,0 +1,196 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/atom/copy_atom.hpp>
#include <cute/algorithm/copy.hpp>
#include <cute/tensor.hpp>
#include <cute/tensor_predicate.hpp>
namespace cute
{
// cooperative_copy<NumThreads, MaxVecBits>(tid, src, dst)
// Use NumThreads to copy src to dst with element vectorization up to MaxVecBits.
//
// Each participating thread copies the slice of src/dst selected by its @a tid from a
// (vector, thread, iteration) tiling of the common layout of the two tensors.
// This function does not synchronize; any ordering between the cooperating threads
// before or after the copy has to be provided by the caller.
//
// @tparam NumThreads  Number of threads cooperating in the copy.
// @tparam MaxVecBits  Upper bound, in bits, on the vector width used per access.
// @param  tid         Index of the calling thread within the cooperating group.
// @param  src         Source tensor (gmem or smem, static shape).
// @param  dst         Destination tensor (gmem or smem, static shape, same size as src).
//
// @pre 0 <= @a tid < NumThreads
// @pre Tensors @a src and @a dst are aligned up to MaxVecBits.
//
template <uint32_t NumThreads, uint32_t MaxVecBits,
          class SrcEngine, class SrcLayout,
          class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
cooperative_copy(uint32_t                     const& tid,
                 Tensor<SrcEngine, SrcLayout> const& src,
                 Tensor<DstEngine, DstLayout>      & dst)
{
  // Assumes the shapes are static, can generalize
  CUTE_STATIC_ASSERT_V(size(src) == size(dst));
  // Assumes the types are the same, can generalize
  static_assert(sizeof_bits_v<typename SrcEngine::value_type> == sizeof_bits_v<typename DstEngine::value_type>);
  static_assert(MaxVecBits == sizeof_bits_v<typename SrcEngine::value_type> ||
                MaxVecBits == 8 || MaxVecBits == 16 || MaxVecBits == 32 || MaxVecBits == 64 || MaxVecBits == 128,
                "Expected MaxVecBits to be value size or 8 or 16 or 32 or 64 or 128 for alignment and performance.");
  // Check that the tensors are likely shared across threads: either gmem or smem
  static_assert((is_gmem<SrcEngine>::value || is_smem<SrcEngine>::value),
                "cooperative_copy expects shared gmem or smem source tensor.");
  static_assert((is_gmem<DstEngine>::value || is_smem<DstEngine>::value),
                "cooperative_copy expects shared gmem or smem destination tensor.");

  // Precondition on tid in DEBUG
  assert(tid < NumThreads);
  // Precondition on pointer alignment in DEBUG
  assert(is_byte_aligned<max(MaxVecBits/8, 1u)>(raw_pointer_cast(src.data())));
  assert(is_byte_aligned<max(MaxVecBits/8, 1u)>(raw_pointer_cast(dst.data())));

  //
  // Determine val+thr vectorization based on src/dst size and number of threads
  // NOTE: This heuristic promotes parallelization over vectorization
  //

  constexpr int elem_bits = sizeof_bits_v<typename SrcEngine::value_type>;

  // The number of elements that can be vectorized in values
  constexpr int common_elem = decltype(max_common_vector(src, dst))::value;
  constexpr int common_bits = common_elem * elem_bits;
  constexpr int total_elem = decltype(size(src))::value;
  constexpr int total_bits = total_elem * elem_bits;
  static_assert(total_bits % NumThreads == 0);
  constexpr int total_bits_per_thr = total_bits / NumThreads;
  // If there are too many threads to allow a full elem copy, trunc the thrs and use elem_bits
  constexpr int max_vec_bits_by_thr = cute::max(elem_bits, total_bits_per_thr);

  // Cap the vectorization to the common bits, the max_vec_bits_by_thr, and the MaxVecBits
  constexpr int vec_bits = cute::min(common_bits, max_vec_bits_by_thr, static_cast<int>(MaxVecBits));
  // Convert back to number of elements, safe_div
  static_assert((vec_bits % elem_bits) == 0);
  constexpr int vec_elem = vec_bits / elem_bits;

  // Use only part of threads if there's not enough work for all threads
  constexpr int vec_thrs = (total_elem % (vec_elem * NumThreads) == 0)
                             ? NumThreads
                             : (total_elem / vec_elem);
  // In the non-divisible case the tiling below has a single iteration and one vector per
  // "thread" index in [0, vec_thrs). If vec_thrs exceeded NumThreads, the vectors that belong
  // to indices NumThreads..vec_thrs-1 would never be copied by anyone, so reject it up front.
  static_assert(vec_thrs <= static_cast<int>(NumThreads),
                "cooperative_copy: the vectorized work is not divisible by NumThreads and exceeds it; "
                "some elements would not be copied.");

  // The common layout of the two tensors that can be vectorized over threads
  // vidx -> coord
  auto common_layout = max_common_layout(get_nonswizzle_portion(src.layout()),
                                         get_nonswizzle_portion(dst.layout()));

  // Scale up the common_layout to cover the entire tensors
  // vidx -> coord
  auto full_perm = tile_to_shape(make_layout(common_layout), size(src));

  // Create the Tiler
  // ((vid,tid),iter)
  auto layout_vt = logical_divide(full_perm, Layout<Shape<Int<vec_elem>, Int<vec_thrs>>>{});

  // Apply and slice
  Tensor src_v = src.compose(layout_vt)(make_coord(_,tid),_);
  Tensor dst_v = dst.compose(layout_vt)(make_coord(_,tid),_);

  // Should account for vec_bits < 8 and/or vec_elem <= 1
  // And also account for subbyte types, which could cause race conditions
  // Want to ENFORCE sufficient vectorization in those cases
  static_assert((vec_bits >= 8), "No support for subbyte copying");
  using VecType = uint_bit_t<vec_bits>;

#if 0
  if (thread0()) {
    print(" "); print("NumThreads: "); print(NumThreads); print("\n");
    print(" "); print("src: "); print(src); print("\n");
    print(" "); print("dst: "); print(dst); print("\n");
    print(" "); print("common_layout: "); print(common_layout); print("\n");
    print(" "); print("full_perm: "); print(full_perm); print("\n");
    print(" "); print("Used vector: "); print(vec_elem); print("\n");
    print(" "); print("Used threads: "); print(vec_thrs); print("\n");
    print(" "); print("layout_vt: "); print(layout_vt); print("\n");
    print(" "); print("src.compose(layout_vt): "); print(src.compose(layout_vt)); print("\n");
    print(" "); print("dst.compose(layout_vt): "); print(dst.compose(layout_vt)); print("\n");
    print(" "); print("src_v: "); print(src_v); print("\n");
    print(" "); print("dst_v: "); print(dst_v); print("\n");
    print(" "); print("recast<VecType const>(src_v): "); print(recast<VecType const>(src_v)); print("\n");
    print(" "); print("recast<VecType const>(dst_v): "); print(recast<VecType const>(dst_v)); print("\n");
  }
#ifdef __CUDA_ARCH__
  __syncthreads();
#endif
#endif

  // If we're using all threads (static) or the tid is in range (dynamic)
  if (vec_thrs >= NumThreads or tid < vec_thrs) {
    return copy_if(TrivialPredTensor{}, recast<VecType const>(src_v), recast<VecType>(dst_v));
  }
}
// cooperative_copy<NumThreads>(tid, src, dst)
// Convenience overload without an explicit MaxVecBits: the vectorization cap is set to the
// bit width of the source value type, then the call is forwarded to the
// <NumThreads, MaxVecBits> overload.
template <uint32_t NumThreads,
          class SrcEngine, class SrcLayout,
          class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
cooperative_copy(uint32_t                     const& tid,
                 Tensor<SrcEngine, SrcLayout> const& src,
                 Tensor<DstEngine, DstLayout>      & dst)
{
  using SrcValueType = typename SrcEngine::value_type;
  // Cap the vector width at one source element
  constexpr uint32_t ElemBits = sizeof_bits_v<SrcValueType>;
  cooperative_copy<NumThreads, ElemBits>(tid, src, dst);
}
// Accept mutable temporaries
// The destination is bound as an rvalue here; inside the body it is a named lvalue,
// so the call below resolves to the lvalue-destination overload.
template <uint32_t NumThreads,
          class SrcEngine, class SrcLayout,
          class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
cooperative_copy(uint32_t                     const& tid,
                 Tensor<SrcEngine, SrcLayout> const& src,
                 Tensor<DstEngine, DstLayout>     && dst)
{
  cooperative_copy<NumThreads>(tid, src, dst);
}
// Accept mutable temporaries
// Same as the <NumThreads, MaxVecBits> overload, but the destination may be a temporary
// tensor; it is forwarded as an lvalue to the lvalue-destination overload.
template <uint32_t NumThreads, uint32_t MaxVecBits,
          class SrcEngine, class SrcLayout,
          class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
cooperative_copy(uint32_t                     const& tid,
                 Tensor<SrcEngine, SrcLayout> const& src,
                 Tensor<DstEngine, DstLayout>     && dst)
{
  cooperative_copy<NumThreads, MaxVecBits>(tid, src, dst);
}
} // end namespace cute

View File

@@ -0,0 +1,326 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/util/type_traits.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cute/algorithm/functional.hpp>
#include <cute/algorithm/gemm.hpp>
#include <cute/tensor.hpp>
namespace cute
{
//
// Collective Shared-Memory GEMMs
//
// cooperative_gemm(thr_mma, alpha, sA, sB, beta, sC, sA_load_op, sB_load_op)
// Per-thread part of a GEMM on rank-2 shared-memory tensors:
//   sC = alpha * mma(sA, sB) + beta * sC     (sC = alpha * mma(sA, sB) when beta == Beta{})
// where sA is (M,K), sB is (N,K) and sC is (M,N), all with static extents.
//
// Extents that are not multiples of the MMA tile are handled by predication: the M and N
// residues are placed in the last tile, the K residue in the first tile (the A/B origins
// are shifted so that the partial K tile is tile 0). Predicated-off A/B elements are loaded
// as zero; predicated-off C elements are not written.
//
// This function contains no synchronization of its own.
//
// @param thr_mma     Thread slice of the tiled MMA used to partition the tensors.
// @param alpha       Scale applied to the accumulated product.
// @param sA          (M,K) smem tensor, taken by value (its data pointer is shifted locally).
// @param sB          (N,K) smem tensor, taken by value (its data pointer is shifted locally).
// @param beta        Scale applied to the existing contents of sC.
// @param sC          (M,N) smem tensor that is read (when beta is nonzero) and written.
// @param sA_load_op  Applied to every A value loaded from smem; must return TA::value_type.
// @param sB_load_op  Applied to every B value loaded from smem; must return TB::value_type.
template <class... Args,
          class Alpha, class TA, class ALayout, class TB, class BLayout,
          class Beta,  class TC, class CLayout,
          class ALoadTransformOp, class BLoadTransformOp,
          __CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
                          BLayout::rank == 2 && is_smem<TB>::value &&
                          CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
cooperative_gemm(ThrMMA<Args...> const& thr_mma,
                 Alpha const& alpha,
                 Tensor<TA, ALayout> sA,
                 Tensor<TB, BLayout> sB,
                 Beta const& beta,
                 Tensor<TC, CLayout> sC,
                 ALoadTransformOp const& sA_load_op /* transforms A values before used in GEMM */,
                 BLoadTransformOp const& sB_load_op /* transforms B values before used in GEMM */)
{
  CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM
  CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN
  CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK

  using TypeA = typename TA::value_type;
  using TypeB = typename TB::value_type;
  using TypeC = typename TC::value_type;

  static_assert(is_same_v<decay_t<invoke_result_t<ALoadTransformOp, TypeA>>, TypeA>,
                "ALoadTransformOp functor must accept and return value of type TA::value_type");
  static_assert(is_same_v<decay_t<invoke_result_t<BLoadTransformOp, TypeB>>, TypeB>,
                "BLoadTransformOp functor must accept and return value of type TB::value_type");

  // Original, static size of the problem
  auto M = size<0>(sC);
  auto N = size<1>(sC);
  auto K = size<1>(sA);

  // Block size of the compute tile
  auto BLK_M = tile_size<0>(thr_mma);
  auto BLK_N = tile_size<1>(thr_mma);
  auto BLK_K = tile_size<2>(thr_mma);

  // Compute the "residues"
  auto m_residue = M - BLK_M * (ceil_div(M, BLK_M) - Int<1>{}); // (0,BLK_M]
  auto n_residue = N - BLK_N * (ceil_div(N, BLK_N) - Int<1>{}); // (0,BLK_N]
  auto k_residue = K - BLK_K * (ceil_div(K, BLK_K)           ); // (-BLK_K,0]

  // Shift the origin so k_residue is zeroth tile
  sA.data() = &sA(0,k_residue);
  sB.data() = &sB(0,k_residue);

#if 0
  if (thread0()) {
    printf("%d in BLK_M (%d)\n", int(m_residue), int(BLK_M));
    printf("%d in BLK_N (%d)\n", int(n_residue), int(BLK_N));
    printf("%d in BLK_K (%d)\n", int(k_residue), int(BLK_K));
  }
#endif

  //
  // MMA Partitioning
  //

  // Round the layout extents up to BLK_X
  Tensor rounded_sA = sA.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(K, BLK_K) * BLK_K));
  Tensor rounded_sB = sB.compose(make_shape(ceil_div(N, BLK_N) * BLK_N, ceil_div(K, BLK_K) * BLK_K));
  Tensor rounded_sC = sC.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(N, BLK_N) * BLK_N));

#if 0
  if (thread0()) {
    print("rounded_sA: "); print(rounded_sA); print("\n");
    print("rounded_sB: "); print(rounded_sB); print("\n");
    print("rounded_sC: "); print(rounded_sC); print("\n");
  }
#endif

  // Partition the sA and sB tiles across the threads for the MMA
  Tensor tCsA = thr_mma.partition_A(rounded_sA); // (MMA,MMA_M,MMA_K)
  Tensor tCsB = thr_mma.partition_B(rounded_sB); // (MMA,MMA_N,MMA_K)
  Tensor tCsC = thr_mma.partition_C(rounded_sC); // (MMA,MMA_M,MMA_N)
  // Create register tensors for the MMA to operate on
  Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K)
  Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)
  Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N)

#if 0
  if (thread0()) {
    print("tCsA: "); print(tCsA); print("\n");
    print("tCsB: "); print(tCsB); print("\n");
    print("tCsC: "); print(tCsC); print("\n");
    print("tCrA: "); print(tCrA); print("\n");
    print("tCrB: "); print(tCrB); print("\n");
    print("tCrC: "); print(tCrC); print("\n");
  }
#endif

  //
  // PREDICATION
  //

  // Allocate the preds for only the MMA-mode of tCsA and tCsB
  Tensor tCpA = make_tensor<bool>(size<0>(tCsA));
  Tensor tCpB = make_tensor<bool>(size<0>(tCsB));

  // Create coordinate tensors on a single compute block for predication
  Tensor cA = make_identity_tensor(make_shape(BLK_M, BLK_K)); // (BLK_M,BLK_K) -> (blk_m,blk_k)
  Tensor cB = make_identity_tensor(make_shape(BLK_N, BLK_K)); // (BLK_N,BLK_K) -> (blk_n,blk_k)

  // Repeat partitioning with thr_mma
  Tensor tCcA = thr_mma.partition_A(cA); // (MMA,1,1) -> (blk_m,blk_k)
  Tensor tCcB = thr_mma.partition_B(cB); // (MMA,1,1) -> (blk_n,blk_k)

  // Populate the m and n predicates
  CUTE_UNROLL
  for (int i = 0; i < size(tCpA); ++i) {
    tCpA(i) = elem_less(get<0>(tCcA(i)), m_residue);
  }
  CUTE_UNROLL
  for (int i = 0; i < size(tCpB); ++i) {
    tCpB(i) = elem_less(get<0>(tCcB(i)), n_residue);
  }

#if 0
  printf("Thr %d: A(%d,%d):%d B(%d,%d):%d\n",
         threadIdx.x,
         int(get<0>(tCcA(0))), int(get<1>(tCcA(0))), int(tCpA(0)),
         int(get<0>(tCcB(0))), int(get<1>(tCcB(0))), int(tCpB(0)));
#endif

  //
  // PREFETCH k_block = 0 (with k-predication)
  //
  // Elements that fail the k-predicate lie before the start of the (shifted) tensor.
  // They must still be written as zero: the register fragments are not otherwise
  // initialized in this function and the MMA below consumes the whole k_block.
  //

  CUTE_UNROLL
  for (int i = 0; i < size<0>(tCsA); ++i) { // Copy MMA_I
    if (k_residue == 0 || get<1>(tCcA(i)) >= -k_residue) { // k_block = 0, predicated on k
      CUTE_UNROLL
      for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M, predicated on m
        tCrA(i,m,0) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,0)) : TypeA{};
      }
    } else {
      CUTE_UNROLL
      for (int m = 0; m < size<1>(tCsA); ++m) { // Zero-fill the k-predicated-off elements
        tCrA(i,m,0) = TypeA{};
      }
    }
  }

  CUTE_UNROLL
  for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I
    if (k_residue == 0 || get<1>(tCcB(i)) >= -k_residue) { // k_block = 0, predicated on k
      CUTE_UNROLL
      for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N, predicated on n
        tCrB(i,n,0) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,0)) : TypeB{};
      }
    } else {
      CUTE_UNROLL
      for (int n = 0; n < size<1>(tCsB); ++n) { // Zero-fill the k-predicated-off elements
        tCrB(i,n,0) = TypeB{};
      }
    }
  }

  //
  // MAINLOOP
  //

  // Clear accumulators
  clear(tCrC);

  constexpr int K_BLOCK_MAX = size<2>(tCrA);

  CUTE_UNROLL
  for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block)
  {
    // static-if load the next k_block. No k-predication required on these loads.
    if (k_block < K_BLOCK_MAX-1)
    {
      // Load the next k_block
      int k_next = k_block + 1;

      CUTE_UNROLL
      for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M
        CUTE_UNROLL
        for (int i = 0; i < size<0>(tCsA); ++i) { // Copy_if MMA_I predicated on m
          tCrA(i,m,k_next) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,k_next)) : TypeA{};
        }
      }

      CUTE_UNROLL
      for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N
        CUTE_UNROLL
        for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I predicated on n
          tCrB(i,n,k_next) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,k_next)) : TypeB{};
        }
      }
    }

    // GEMM on k_block in registers
    gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
  }

  //
  // Epilogue
  //

  Tensor cC = make_identity_tensor(make_shape(BLK_M, BLK_N)); // (BLK_M,BLK_N) -> (blk_m,blk_n)
  Tensor tCcC = thr_mma.partition_C(cC); // (MMA, 1, 1) -> (blk_m,blk_n)

  // When beta compares equal to a value-initialized Beta, sC is overwritten without being read
  const bool isBetaZero = (beta == Beta{});

  // Custom axpby_if for now
  CUTE_UNROLL
  for (int m = 0; m < size<1>(tCsC); ++m)
  {
    CUTE_UNROLL
    for (int n = 0; n < size<2>(tCsC); ++n)
    {
      CUTE_UNROLL
      for (int i = 0; i < size<0>(tCsC); ++i)
      {
        if ((m_residue == BLK_M || m < size<1>(tCrC)-1 || get<0>(tCcC(i)) < m_residue) &&
            (n_residue == BLK_N || n < size<2>(tCrC)-1 || get<1>(tCcC(i)) < n_residue))
        {
          tCsC(i,m,n) = isBetaZero ? alpha * static_cast<TypeC>(tCrC(i,m,n)) : alpha * static_cast<TypeC>(tCrC(i,m,n)) + beta * static_cast<TypeC>(tCsC(i,m,n));
        }
      }
    }
  }
}
// cooperative_gemm(thr_mma, alpha, sA, sB, beta, sC)
// Overload without load transforms: forwards to the full overload with cute::identity
// as both the A and the B load op, so values are used as loaded.
template <class... Args,
          class Alpha, class TA, class ALayout, class TB, class BLayout,
          class Beta,  class TC, class CLayout,
          __CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
                          BLayout::rank == 2 && is_smem<TB>::value &&
                          CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
cooperative_gemm(ThrMMA<Args...> const& thr_mma,
                 Alpha const& alpha,
                 Tensor<TA, ALayout> sA,
                 Tensor<TB, BLayout> sB,
                 Beta const& beta,
                 Tensor<TC, CLayout> sC)
{
  // One pass-through functor serves as both sA_load_op and sB_load_op
  auto const no_transform = identity();
  cooperative_gemm(thr_mma, alpha, sA, sB, beta, sC, no_transform, no_transform);
}
// gemm(thr_mma, alpha, sA, sB, beta, sC, sA_load_op, sB_load_op)
// ThrMMA + rank-2 smem overload of gemm; a thin forwarder to cooperative_gemm with
// the same arguments.
template <class... Args,
          class Alpha, class TA, class ALayout, class TB, class BLayout,
          class Beta,  class TC, class CLayout,
          class ALoadTransformOp, class BLoadTransformOp,
          __CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
                          BLayout::rank == 2 && is_smem<TB>::value &&
                          CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(ThrMMA<Args...> const& thr_mma,
     Alpha const& alpha,
     Tensor<TA, ALayout> sA,
     Tensor<TB, BLayout> sB,
     Beta const& beta,
     Tensor<TC, CLayout> sC,
     ALoadTransformOp const& sA_load_op /* transforms A values before used in GEMM */,
     BLoadTransformOp const& sB_load_op /* transforms B values before used in GEMM */)
{
  cooperative_gemm(thr_mma,
                   alpha, sA, sB,
                   beta,  sC,
                   sA_load_op, sB_load_op);
}
// gemm(thr_mma, alpha, sA, sB, beta, sC)
// ThrMMA + rank-2 smem overload of gemm without load transforms; forwards to
// cooperative_gemm with cute::identity for both the A and the B load op.
template <class... Args,
          class Alpha, class TA, class ALayout, class TB, class BLayout,
          class Beta,  class TC, class CLayout,
          __CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
                          BLayout::rank == 2 && is_smem<TB>::value &&
                          CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(ThrMMA<Args...> const& thr_mma,
     Alpha const& alpha,
     Tensor<TA, ALayout> sA,
     Tensor<TB, BLayout> sB,
     Beta const& beta,
     Tensor<TC, CLayout> sC)
{
  // One pass-through functor serves as both sA_load_op and sB_load_op
  auto const no_transform = identity();
  cooperative_gemm(thr_mma, alpha, sA, sB, beta, sC, no_transform, no_transform);
}
} // end namespace cute

View File

@@ -145,10 +145,10 @@ copy_if(PrdTensor const& pred,
namespace detail {
// Trait that detects if atom's traits has a member function with(bool)
template<typename, typename Enable = void>
template <class, class Enable = void>
constexpr bool has_with_bool = false;
template<typename T>
template <class T>
constexpr bool has_with_bool<T, cute::void_t<decltype(declval<typename T::Traits>().with(declval<bool>()))>> = true;
} // end namespace detail

View File

@@ -33,6 +33,7 @@
#include <cute/config.hpp>
#include <cute/util/type_traits.hpp>
#include <cute/numeric/complex.hpp>
/** C++14 <functional> extensions */
@@ -46,7 +47,7 @@ struct identity {
template <class T>
CUTE_HOST_DEVICE constexpr
decltype(auto) operator()(T&& arg) const {
return std::forward<T>(arg);
return static_cast<T&&>(arg);
}
};
@@ -69,7 +70,7 @@ struct constant_fn {
template <class T> \
CUTE_HOST_DEVICE constexpr \
decltype(auto) operator()(T&& arg) const { \
return OP std::forward<T>(arg); \
return OP static_cast<T&&>(arg); \
} \
}
#define CUTE_RIGHT_UNARY_OP(NAME,OP) \
@@ -77,7 +78,7 @@ struct constant_fn {
template <class T> \
CUTE_HOST_DEVICE constexpr \
decltype(auto) operator()(T&& arg) const { \
return std::forward<T>(arg) OP ; \
return static_cast<T&&>(arg) OP ; \
} \
}
#define CUTE_NAMED_UNARY_OP(NAME,OP) \
@@ -85,7 +86,7 @@ struct constant_fn {
template <class T> \
CUTE_HOST_DEVICE constexpr \
decltype(auto) operator()(T&& arg) const { \
return OP (std::forward<T>(arg)); \
return OP (static_cast<T&&>(arg)); \
} \
}
@@ -115,7 +116,7 @@ struct shift_right_const {
template <class T>
CUTE_HOST_DEVICE constexpr
decltype(auto) operator()(T&& arg) const {
return std::forward<T>(arg) >> Shift;
return static_cast<T&&>(arg) >> Shift;
}
};
@@ -126,7 +127,7 @@ struct shift_left_const {
template <class T>
CUTE_HOST_DEVICE constexpr
decltype(auto) operator()(T&& arg) const {
return std::forward<T>(arg) << Shift;
return static_cast<T&&>(arg) << Shift;
}
};
@@ -139,7 +140,7 @@ struct shift_left_const {
template <class T, class U> \
CUTE_HOST_DEVICE constexpr \
decltype(auto) operator()(T&& lhs, U&& rhs) const { \
return std::forward<T>(lhs) OP std::forward<U>(rhs); \
return static_cast<T&&>(lhs) OP static_cast<U&&>(rhs); \
} \
}
#define CUTE_NAMED_BINARY_OP(NAME,OP) \
@@ -147,7 +148,7 @@ struct shift_left_const {
template <class T, class U> \
CUTE_HOST_DEVICE constexpr \
decltype(auto) operator()(T&& lhs, U&& rhs) const { \
return OP (std::forward<T>(lhs), std::forward<U>(rhs)); \
return OP (static_cast<T&&>(lhs), static_cast<U&&>(rhs)); \
} \
}
@@ -273,7 +274,7 @@ struct bound_fn {
CUTE_HOST_DEVICE constexpr
decltype(auto)
operator()(T&& arg) {
return fn_(arg_, std::forward<T>(arg));
return fn_(arg_, static_cast<T&&>(arg));
}
Fn fn_;

View File

@@ -252,7 +252,7 @@ gemm(MMA_Atom<MMA> const& mma,
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutC_TV{}) == Int<1>{});
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutA_TV{}) == Int<1>{});
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutB_TV{}) == Int<1>{});
gemm(mma,
make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N)
make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K)
@@ -451,6 +451,7 @@ gemm(MMA_Atom<MMA> const& mma,
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutC_TV{}) == Int<1>{});
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutA_TV{}) == Int<1>{});
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutB_TV{}) == Int<1>{});
gemm(mma,
make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N)
make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K)
@@ -496,245 +497,4 @@ gemm(MMA_Atom<MMA> const& mma,
}
}
//
// Collective Shared-Memory GEMMs
//
template <class... Args,
class Alpha, class TA, class ALayout, class TB, class BLayout,
class Beta, class TC, class CLayout,
class ALoadTransformOp, class BLoadTransformOp,
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
BLayout::rank == 2 && is_smem<TB>::value &&
CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(ThrMMA<Args...> const& thr_mma,
Alpha const& alpha,
Tensor<TA, ALayout> sA,
Tensor<TB, BLayout> sB,
Beta const& beta,
Tensor<TC, CLayout> sC,
ALoadTransformOp const& sA_load_op /* transforms A values before used in GEMM */,
BLoadTransformOp const& sB_load_op /* transforms B values before used in GEMM */)
{
CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM
CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN
CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK
using TypeA = typename TA::value_type;
using TypeB = typename TB::value_type;
using TypeC = typename TC::value_type;
static_assert(is_same_v<decay_t<invoke_result_t<ALoadTransformOp, TypeA>>, TypeA>,
"ALoadTransformOp functor must accept and return value of type TA::value_type");
static_assert(is_same_v<decay_t<invoke_result_t<BLoadTransformOp, TypeB>>, TypeB>,
"BLoadTransformOp functor must accept and return value of type TB::value_type");
// Original, static size of the problem
auto M = size<0>(sC);
auto N = size<1>(sC);
auto K = size<1>(sA);
// Block size of the compute tile
auto BLK_M = tile_size<0>(thr_mma);
auto BLK_N = tile_size<1>(thr_mma);
auto BLK_K = tile_size<2>(thr_mma);
// Compute the "residues"
auto m_residue = M - BLK_M * (ceil_div(M, BLK_M) - Int<1>{}); // (0,BLK_M]
auto n_residue = N - BLK_N * (ceil_div(N, BLK_N) - Int<1>{}); // (0,BLK_N]
auto k_residue = K - BLK_K * (ceil_div(K, BLK_K) ); // (-BLK_K,0]
// Shift the origin so k_residue is zeroth tile
sA.data() = &sA(0,k_residue);
sB.data() = &sB(0,k_residue);
#if 0
if (thread0()) {
printf("%d in BLK_M (%d)\n", int(m_residue), int(BLK_M));
printf("%d in BLK_N (%d)\n", int(n_residue), int(BLK_N));
printf("%d in BLK_K (%d)\n", int(k_residue), int(BLK_K));
}
#endif
//
// MMA Partitioning
//
// Round the layout extents up to BLK_X
Tensor rounded_sA = sA.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(K, BLK_K) * BLK_K));
Tensor rounded_sB = sB.compose(make_shape(ceil_div(N, BLK_N) * BLK_N, ceil_div(K, BLK_K) * BLK_K));
Tensor rounded_sC = sC.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(N, BLK_N) * BLK_N));
#if 0
if (thread0()) {
print(rounded_sA.layout()); print("\n");
print(rounded_sB.layout()); print("\n");
print(rounded_sC.layout()); print("\n");
}
#endif
// Partition the sA and sB tiles across the threads for the MMA
Tensor tCsA = thr_mma.partition_A(rounded_sA); // (MMA,MMA_M,MMA_K)
Tensor tCsB = thr_mma.partition_B(rounded_sB); // (MMA,MMA_N,MMA_K)
Tensor tCsC = thr_mma.partition_C(rounded_sC); // (MMA,MMA_M,MMA_N)
// Create register tensors for the MMA to operate on
Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K)
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)
Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N)
#if 0
if (thread0()) {
print(tCsA.layout()); print("\n");
print(tCsB.layout()); print("\n");
print(tCsC.layout()); print("\n");
print(tCrA.layout()); print("\n");
print(tCrB.layout()); print("\n");
print(tCrC.layout()); print("\n");
}
#endif
//
// PREDICATION
//
// Allocate the preds for only the MMA-mode of tCsA and tCsB
Tensor tCpA = make_tensor<bool>(size<0>(tCsA));
Tensor tCpB = make_tensor<bool>(size<0>(tCsB));
// Create coordinate tensors on a single compute block for predication
Tensor cA = make_identity_tensor(make_shape(BLK_M, BLK_K)); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor cB = make_identity_tensor(make_shape(BLK_N, BLK_K)); // (BLK_M,BLK_K) -> (blk_n,blk_k)
// Repeat partitioning with thr_mma
Tensor tCcA = thr_mma.partition_A(cA); // (MMA,1,1) -> (blk_m,blk_k)
Tensor tCcB = thr_mma.partition_B(cB); // (MMA,1,1) -> (blk_n,blk_k)
// Populate the m and n predicates
CUTE_UNROLL
for (int i = 0; i < size(tCpA); ++i) {
tCpA(i) = elem_less(get<0>(tCcA(i)), m_residue);
}
CUTE_UNROLL
for (int i = 0; i < size(tCpB); ++i) {
tCpB(i) = elem_less(get<0>(tCcB(i)), n_residue);
}
#if 0
printf("Thr %d: A(%d,%d):%d B(%d,%d):%d\n",
threadIdx.x,
int(get<0>(tCcA(0))), int(get<1>(tCcA(0))), int(tCpA(0)),
int(get<0>(tCcB(0))), int(get<1>(tCcB(0))), int(tCpB(0)));
#endif
//
// PREFETCH k_block = 0 (with k-predication)
//
CUTE_UNROLL
for (int i = 0; i < size<0>(tCsA); ++i) { // Copy MMA_I
if (k_residue == 0 || get<1>(tCcA(i)) >= -k_residue) { // k_block = 0, predicated on k
CUTE_UNROLL
for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M, predicated on m
tCrA(i,m,0) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,0)) : TypeA{};
}
}
}
CUTE_UNROLL
for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I
if (k_residue == 0 || get<1>(tCcB(i)) >= -k_residue) { // k_block = 0, predicated on k
CUTE_UNROLL
for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N, predicated on n
tCrB(i,n,0) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,0)) : TypeB{};
}
}
}
//
// MAINLOOP
//
// Clear accumulators
clear(tCrC);
constexpr int K_BLOCK_MAX = size<2>(tCrA);
CUTE_UNROLL
for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block)
{
// static-if load the next k_block. No k-predication required on these loads.
if (k_block < K_BLOCK_MAX-1)
{
// Load the next k_block
int k_next = k_block + 1;
CUTE_UNROLL
for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M
CUTE_UNROLL
for (int i = 0; i < size<0>(tCsA); ++i) { // Copy_if MMA_I predicated on m
tCrA(i,m,k_next) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,k_next)) : TypeA{};
}
}
CUTE_UNROLL
for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N
CUTE_UNROLL
for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I predicated on n
tCrB(i,n,k_next) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,k_next)) : TypeB{};
}
}
}
// GEMM on k_block in registers
gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
}
//
// Epilogue
//
Tensor cC = make_identity_tensor(make_shape(BLK_M, BLK_N)); // (BLK_M,BLK_N) -> (blk_m,blk_n)
Tensor tCcC = thr_mma.partition_C(cC); // (MMA, 1, 1) -> (blk_m,blk_n)
const bool isBetaZero = (beta == Beta{});
// Custom axpby_if for now
CUTE_UNROLL
for (int m = 0; m < size<1>(tCsC); ++m)
{
CUTE_UNROLL
for (int n = 0; n < size<2>(tCsC); ++n)
{
CUTE_UNROLL
for (int i = 0; i < size<0>(tCsC); ++i)
{
if ((m_residue == BLK_M || m < size<1>(tCrC)-1 || get<0>(tCcC(i)) < m_residue) &&
(n_residue == BLK_N || n < size<2>(tCrC)-1 || get<1>(tCcC(i)) < n_residue))
{
tCsC(i,m,n) = isBetaZero ? alpha * tCrC(i,m,n) : alpha * tCrC(i,m,n) + beta * tCsC(i,m,n);
}
}
}
}
}
template <class... Args,
class Alpha, class TA, class ALayout, class TB, class BLayout,
class Beta, class TC, class CLayout,
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
BLayout::rank == 2 && is_smem<TB>::value &&
CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(ThrMMA<Args...> const& thr_mma,
Alpha const& alpha,
Tensor<TA, ALayout> sA,
Tensor<TB, BLayout> sB,
Beta const& beta,
Tensor<TC, CLayout> sC)
{
gemm(thr_mma, alpha, sA, sB, beta, sC, identity() /* sA_load_op */, identity() /* sB_load_op */);
}
} // end namespace cute

View File

@@ -0,0 +1,153 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/tensor.hpp>
#include <cute/atom/copy_atom.hpp>
namespace cute
{
//
// Prefetch global tensors into L2
//
// Cooperatively prefetch a global-memory tensor.
//
// The work is split into "prefetch units"; the caller identified by tid handles
// units tid, tid + NumThreads, tid + 2*NumThreads, ... and calls prefetch() on
// the address of each one.
//
//   NumThreads : stride between consecutive units handled by the same caller
//   FetchBytes : cap on the width of one prefetch unit (64 bytes by default)
//   tid        : index of the calling participant
//   src        : tensor to prefetch; is_gmem must hold for its engine
template <uint32_t NumThreads, uint32_t FetchBytes = 64,
          class GEngine, class GLayout>
CUTE_HOST_DEVICE
void
cooperative_prefetch(uint32_t                 const& tid,
                     Tensor<GEngine, GLayout> const& src)
{
  static_assert(is_gmem<GEngine>::value, "Expected global tensor for prefetch");

  using Element = typename GEngine::value_type;

  // Number of elements reported by max_common_vector for src against itself
  constexpr int VecElems = decltype(max_common_vector(src, src))::value;

  if constexpr (VecElems <= 1) {
    // Nothing to vectorize: one prefetch per element
    CUTE_UNROLL
    for (int idx = tid; idx < size(src); idx += NumThreads) {
      prefetch(raw_pointer_cast(&src(idx)));
    }
  } else {
    // L2 sector is 32B, default fetch granularity is 64B.
    // Use the VecElems-wide vector while it is narrower than FetchBytes,
    // otherwise fall back to a FetchBytes-sized byte block.
    using FetchUnit = conditional_t<(VecElems * sizeof_bits_v<Element>) < (FetchBytes * 8),
                                    ArrayEngine<Element, VecElems>,
                                    uint8_t[FetchBytes]>;

    Tensor src_units = recast<FetchUnit const>(src);

    CUTE_UNROLL
    for (int idx = tid; idx < size(src_units); idx += NumThreads) {
      prefetch(raw_pointer_cast(&src_units(idx)));
    }
  }
}
// Prefetch a whole global-memory tensor from a single caller:
// cooperative_prefetch with exactly one participant whose id is 0.
template <class GEngine, class GLayout>
CUTE_HOST_DEVICE
void
prefetch(Tensor<GEngine, GLayout> const& src)
{
  cooperative_prefetch<1>(0, src);
}
// Prefetch with copy atom
namespace detail {

// has_prefetch<CopyOp>: true when CopyOp declares a nested type named PREFETCH
template <class CopyOp, class Enable = void>
constexpr bool has_prefetch = false;

template <class CopyOp>
constexpr bool has_prefetch<CopyOp, void_t<typename CopyOp::PREFETCH>> = true;

// is_prefetch<CopyOp>: true when CopyOp declares PREFETCH and that type is CopyOp itself
template <class CopyOp, class Enable = void>
constexpr bool is_prefetch = false;

template <class CopyOp>
constexpr bool is_prefetch<CopyOp, void_t<typename CopyOp::PREFETCH>> =
    is_same_v<CopyOp, typename CopyOp::PREFETCH>;

} // end namespace detail
// Prefetch src using the prefetch flavour of a copy atom, when it has one.
//
// If CopyOp declares a nested PREFETCH op, an atom with the same remaining
// trait/atom arguments is built around that op (constructed from atom) and
// run through copy(). Otherwise this falls back to the plain prefetch(src).
template <class CopyOp, class... CT_Args, class... CA_Args,
          class GEngine, class GLayout>
CUTE_HOST_DEVICE
void
prefetch(Copy_Atom<Copy_Traits<CopyOp, CT_Args...>, CA_Args...> const& atom,
         Tensor<GEngine, GLayout>                               const& src)
{
  if constexpr (!detail::has_prefetch<CopyOp>) {
    // No dedicated prefetch op for this atom
    return prefetch(src);
  } else {
    using PrefetchOp   = typename CopyOp::PREFETCH;
    using PrefetchAtom = Copy_Atom<Copy_Traits<PrefetchOp, CT_Args...>, CA_Args...>;

    PrefetchAtom pf_atom{atom};

    // copy() needs a destination argument; dst is ignored for prefetch atoms
    auto& unused_dst = const_cast<Tensor<GEngine, GLayout>&>(src);
    return copy(pf_atom, src, unused_dst);
  }
}
#if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED)
// Prefetch for SM90_BULK_COPY_AUTO traits.
//
// The bulk width is not taken from atom (which is unused) but derived from src:
// the tiler returned by max_common_layout(src, src) fixes the element count,
// and a concrete SM90_BULK_COPY_G2S atom of that many bits is built and
// dispatched through the Copy_Atom prefetch overload.
template <class... CT_Args,
          class SrcEngine, class SrcLayout>
CUTE_HOST_DEVICE
void
prefetch(Copy_Traits<SM90_BULK_COPY_AUTO, CT_Args...> const& atom,
         Tensor<SrcEngine, SrcLayout>                 const& src)
{
  static_assert(is_gmem<SrcEngine>::value, "Expected global tensor for L2 prefetch");

  using SrcType = typename SrcEngine::value_type;

  auto vec_layout = max_common_layout(src, src);

  // Static width of one bulk operation, in elements and then in bits
  constexpr int num_elems = decltype(size(vec_layout))::value;
  constexpr int num_bits  = num_elems * sizeof_bits_v<SrcType>;
  static_assert(num_bits >= 128, "Expected at least 128-bits for BLKCP");

  // Construct a new concrete Atom of the vector size
  using BulkAtom = Copy_Atom<Copy_Traits<SM90_BULK_COPY_G2S, Int<num_bits>>, SrcType>;
  BulkAtom bulk_atom{};

  return prefetch(bulk_atom, logical_divide(src, vec_layout));
}
// Backwards-compat. Throw out any extra Copy_Atom args.
// Backwards-compat. Accepts a full Copy_Atom wrapping SM90_BULK_COPY_AUTO traits,
// discards the extra Copy_Atom arguments by viewing the atom as its Copy_Traits,
// and forwards to the Copy_Traits overload.
template <class... CT_Args, class... CA_Args,
          class SrcEngine, class SrcLayout>
CUTE_HOST_DEVICE
void
prefetch(Copy_Atom<Copy_Traits<SM90_BULK_COPY_AUTO, CT_Args...>, CA_Args...> const& atom,
         Tensor<SrcEngine, SrcLayout>                                        const& src)
{
  using BulkAutoTraits = Copy_Traits<SM90_BULK_COPY_AUTO, CT_Args...>;
  return prefetch(static_cast<BulkAutoTraits const&>(atom), src);
}
#endif // #if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED)
} // end namespace cute

View File

@@ -50,7 +50,7 @@ for_each(Tensor<Engine,Layout> const& tensor, UnaryOp&& op)
{
CUTE_UNROLL
for (int i = 0; i < size(tensor); ++i) {
static_cast<UnaryOp&&>(op)(tensor(i));
op(tensor(i));
}
}
@@ -61,7 +61,7 @@ for_each(Tensor<Engine,Layout>& tensor, UnaryOp&& op)
{
CUTE_UNROLL
for (int i = 0; i < size(tensor); ++i) {
static_cast<UnaryOp&&>(op)(tensor(i));
op(tensor(i));
}
}
@@ -71,7 +71,7 @@ CUTE_HOST_DEVICE constexpr
void
for_each(Tensor<Engine,Layout>&& tensor, UnaryOp&& op)
{
return for_each(tensor, static_cast<UnaryOp&&>(op));
return for_each(tensor, op);
}
//
@@ -86,7 +86,7 @@ transform(Tensor<Engine,Layout>& tensor, UnaryOp&& op)
{
CUTE_UNROLL
for (int i = 0; i < size(tensor); ++i) {
tensor(i) = static_cast<UnaryOp&&>(op)(tensor(i));
tensor(i) = op(tensor(i));
}
}
@@ -96,27 +96,34 @@ CUTE_HOST_DEVICE constexpr
void
transform(Tensor<Engine,Layout>&& tensor, UnaryOp&& op)
{
return transform(tensor, std::forward<UnaryOp>(op));
return transform(tensor, op);
}
// Similar to std::transform transforms one tensors and assigns it to another
template <class EngineIn, class LayoutIn, class EngineOut, class LayoutOut, class UnaryOp>
template <class EngineIn, class LayoutIn,
class EngineOut, class LayoutOut,
class UnaryOp>
CUTE_HOST_DEVICE constexpr
void
transform(Tensor<EngineIn,LayoutIn>& tensor_in, Tensor<EngineOut,LayoutOut>& tensor_out, UnaryOp&& op)
transform(Tensor<EngineIn, LayoutIn > const& tensor_in,
Tensor<EngineOut,LayoutOut> & tensor_out,
UnaryOp&& op)
{
CUTE_UNROLL
for (int i = 0; i < size(tensor_in); ++i) {
tensor_out(i) = static_cast<UnaryOp&&>(op)(tensor_in(i));
tensor_out(i) = op(tensor_in(i));
}
}
// Accept mutable temporaries
template <class EngineIn, class LayoutIn,
class EngineOut, class LayoutOut, class UnaryOp>
class EngineOut, class LayoutOut,
class UnaryOp>
CUTE_HOST_DEVICE constexpr
void
transform(Tensor<EngineIn,LayoutIn>&& tensor_in, Tensor<EngineOut,LayoutOut>&& tensor_out, UnaryOp&& op)
transform(Tensor<EngineIn, LayoutIn > const& tensor_in,
Tensor<EngineOut,LayoutOut> && tensor_out,
UnaryOp&& op)
{
return transform(tensor_in, tensor_out, op);
}
@@ -127,29 +134,31 @@ transform(Tensor<EngineIn,LayoutIn>&& tensor_in, Tensor<EngineOut,LayoutOut>&& t
// assigns it to tensor_out
template <class EngineIn1, class LayoutIn1,
class EngineIn2, class LayoutIn2,
class EngineOut, class LayoutOut, class BinaryOp>
class EngineOut, class LayoutOut,
class BinaryOp>
CUTE_HOST_DEVICE constexpr
void
transform(Tensor<EngineIn1,LayoutIn1>& tensor_in1,
Tensor<EngineIn2,LayoutIn2>& tensor_in2,
Tensor<EngineOut,LayoutOut>& tensor_out,
transform(Tensor<EngineIn1,LayoutIn1> const& tensor_in1,
Tensor<EngineIn2,LayoutIn2> const& tensor_in2,
Tensor<EngineOut,LayoutOut> & tensor_out,
BinaryOp&& op)
{
CUTE_UNROLL
for (int i = 0; i < size(tensor_in1); ++i) {
tensor_out(i) = static_cast<BinaryOp&&>(op)(tensor_in1(i), tensor_in2(i));
tensor_out(i) = op(tensor_in1(i), tensor_in2(i));
}
}
// Accept mutable temporaries
template <class EngineIn1, class LayoutIn1,
class EngineIn2, class LayoutIn2,
class EngineOut, class LayoutOut, class BinaryOp>
class EngineOut, class LayoutOut,
class BinaryOp>
CUTE_HOST_DEVICE constexpr
void
transform(Tensor<EngineIn1,LayoutIn1>&& tensor_in1,
Tensor<EngineIn2,LayoutIn2>&& tensor_in2,
Tensor<EngineOut,LayoutOut>&& tensor_out,
transform(Tensor<EngineIn1,LayoutIn1> const& tensor_in1,
Tensor<EngineIn2,LayoutIn2> const& tensor_in2,
Tensor<EngineOut,LayoutOut> && tensor_out,
BinaryOp&& op)
{
return transform(tensor_in1, tensor_in2, tensor_out, op);

View File

@@ -204,36 +204,6 @@ for_each_leaf(T&& t, F&& f)
CUTE_GCC_UNREACHABLE;
}
//
// For Sequence
// (s, t, f) => (f(t[s_0]),f(t[s_1]),...,f(t[s_n]))
//
namespace detail {

// Calls f(Int<Is>{}) once for every index in the sequence, in sequence order.
// The seq argument only carries the indices; its value is not used.
template <int... Is, class F>
CUTE_HOST_DEVICE constexpr
void
for_sequence(seq<Is...> const&, F&& f)
{
  (f(Int<Is>{}), ...);
}

} // end namespace detail
// Calls f on get<i>(t) for every index i listed in s, in sequence order.
// t is forwarded with its original value category on every access.
template <int... I, class T, class F>
CUTE_HOST_DEVICE constexpr
void
for_sequence(seq<I...> const& s, T&& t, F&& f)
{
  // Turn each static index handed out by the detail helper into an element access on t
  auto visit_element = [&](auto&& idx) {
    using Index = remove_cvref_t<decltype(idx)>;
    f(get<Index::value>(static_cast<T&&>(t)));
  };

  detail::for_sequence(s, visit_element);
}
// Count-based form: applies f over the elements of t selected by make_seq<I>
// (NOTE(review): presumably indices 0..I-1 -- confirm against make_seq).
template <int I, class T, class F>
CUTE_HOST_DEVICE constexpr
void
for_sequence(T&& t, F&& f)
{
  using Indices = make_seq<I>;
  for_sequence(Indices{}, static_cast<T&&>(t), static_cast<F&&>(f));
}
//
// Transform
// (t, f) => (f(t_0),f(t_1),...,f(t_n))
@@ -551,15 +521,15 @@ take(T const& t)
template <int... I, class T>
CUTE_HOST_DEVICE constexpr
auto
select(T const & t)
select(T const& t)
{
return cute::make_tuple(get<I>(t)...);
}
template <class T, typename Indices>
template <class T, class Indices>
CUTE_HOST_DEVICE constexpr
auto
select(T const & t, Indices const & indices)
select(T const& t, Indices const& indices)
{
if constexpr (is_tuple<Indices>::value) {
return cute::transform(indices, [&t](auto i) { return select(t, i); });
@@ -655,7 +625,7 @@ flatten(T const& t)
namespace detail {
template<class FlatTuple, class TargetProfile>
template <class FlatTuple, class TargetProfile>
CUTE_HOST_DEVICE constexpr
auto
unflatten_impl(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
@@ -680,7 +650,7 @@ unflatten_impl(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
// @pre rank(flatten(@a target_profile)) == rank(@a flat_tuple)
// @post congruent(@a result, @a target_profile)
// @post flatten(@a result) == @a flat_tuple
template<class FlatTuple, class TargetProfile>
template <class FlatTuple, class TargetProfile>
CUTE_HOST_DEVICE constexpr
auto
unflatten(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
@@ -865,6 +835,7 @@ append(T const& a, X const& x)
CUTE_GCC_UNREACHABLE;
}
template <class T, class X>
CUTE_HOST_DEVICE constexpr
auto
@@ -902,6 +873,7 @@ prepend(T const& a, X const& x)
CUTE_GCC_UNREACHABLE;
}
template <class T, class X>
CUTE_HOST_DEVICE constexpr
auto
@@ -1105,14 +1077,13 @@ zip2_by(T const& t, TG const& guide)
/// @return A tuple of the elements of @c t in reverse order.
template <class T>
CUTE_HOST_DEVICE constexpr auto
reverse(T const& t) {
CUTE_HOST_DEVICE constexpr
auto
reverse(T const& t)
{
if constexpr (is_tuple<T>::value) {
return detail::apply(t, [] (auto const&... a) {
return cute::make_tuple(a...);
}, tuple_rseq<T>{});
}
else {
return detail::apply(t, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_rseq<T>{});
} else {
return t;
}
}