CUTLASS 2.7 (#318)

CUTLASS 2.7

Mainloop fusion for GEMM: summation over A or B
Strided DGRAD (optimized iterators)
Half-precision GELU_taylor activation functions
Use these when accumulation and epilogue compute types are all cutlass::half_t
Tuning and bug fixes to fused GEMM + GEMM example
Support for smaller than 128b aligned Convolutions: see examples
Caching of results to accelerate Convolution unit tests
Can be enabled or disabled by running cmake .. -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=OFF
Corrections and bug fixes reported by the CUTLASS community
Thank you for filing these issues!

authored-by: Haicheng Wu haichengw@nvidia.com, Manish Gupta manigupta@nvidia.com, Dustyn Blasig dblasig@nvidia.com, Andrew Kerr akerr@nvidia.com
This commit is contained in:
Manish Gupta
2021-09-20 11:02:22 -07:00
committed by GitHub
parent 9ac255863f
commit 2e07c4cc2f
62 changed files with 5611 additions and 186 deletions

View File

@@ -225,6 +225,34 @@ struct global_store;
//
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization of global_store for 64-byte stores.
///
/// Emits four predicated 16B vector stores (st.global.v4.u32) covering bytes
/// [0,16), [16,32), [32,48), [48,64) of the destination. All four stores share
/// one predicate derived from pred_guard, so the whole 64B store is all-or-nothing.
template <typename AccessType>
struct global_store<AccessType, 64> {
  CUTLASS_DEVICE
  global_store(AccessType const &D, void *ptr, bool pred_guard) {
    // Reinterpret the source fragment as four uint4 (4 x 16B) words.
    uint4 const *data = reinterpret_cast<uint4 const *>(&D);
    asm volatile(
        "{\n"
        " .reg .pred p;\n"
        " setp.ne.b32 p, %5, 0;\n"
        " @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n"
        " @p st.global.v4.u32 [%6], {%7, %8, %9, %10};\n"
        " @p st.global.v4.u32 [%11], {%12, %13, %14, %15};\n"
        " @p st.global.v4.u32 [%16], {%17, %18, %19, %20};\n"
        "}\n"
        :
        : "l"(ptr), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z),
          "r"(data[0].w), "r"((int)pred_guard), "l"(((uint8_t *)ptr) + 16),
          "r"(data[1].x), "r"(data[1].y), "r"(data[1].z), "r"(data[1].w),
          "l"(((uint8_t *)ptr) + 32),
          "r"(data[2].x), "r"(data[2].y), "r"(data[2].z), "r"(data[2].w),
          "l"(((uint8_t *)ptr) + 48),
          // Bug fix: the final operand previously passed data[2].w again instead
          // of data[3].w, so the last 4 bytes of the 64B store were written with
          // the wrong value.
          "r"(data[3].x), "r"(data[3].y), "r"(data[3].z), "r"(data[3].w));
  }
};
template <typename AccessType>
struct global_store<AccessType, 32> {
CUTLASS_DEVICE

View File

@@ -65,7 +65,7 @@ template <
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag,
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
/// Access granularity of A matrix in units of elements
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
@@ -675,6 +675,243 @@ struct DefaultConv2dDgrad <
>;
};
/// Defines a kernel for the Conv2dDgrad specialization for the Optimized IteratorAlgorithm,
/// strided Dgrad, and a multistage pipeline.
// Partial specialization: TensorOp math class + Optimized iterator algorithm +
// strided dgrad + multistage (Stages-deep) software pipeline. Composes the A/B
// global-memory tile iterators, shared-memory iterators, threadblock-scoped
// multistage MMA, the strided-dgrad epilogue, and the final kernel type.
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementAccumulator,
typename ArchTag,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag,
int AlignmentA,
int AlignmentB
>
struct DefaultConv2dDgrad <
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementAccumulator,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
MathOperatorTag,
IteratorAlgorithm::kOptimized,
StrideSupport::kStrided,
AlignmentA,
AlignmentB
> {
// Define the core components from GEMM
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
Stages, MathOperatorTag>;
// Define iterators over tiles from the A operand (output gradient Dy),
// vectorized by AlignmentA elements per access.
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
using IteratorA =
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
ThreadMapA,
StrideSupport::kStrided,
AccessTypeA
>;
using SmemIteratorA = typename MmaCore::SmemIteratorA;
// Define iterators over tiles from the B operand (filter W),
// vectorized by AlignmentB elements per access.
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
using IteratorB =
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
ThreadMapB,
StrideSupport::kStrided,
AccessTypeB
>;
using SmemIteratorB = typename MmaCore::SmemIteratorB;
// Warp-level GEMM components
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
using MmaPolicy = typename MmaCore::MmaPolicy;
// Select B's cache operation: CacheOperation::Global only when each B access
// is a full 128 bits; otherwise fall back to CacheOperation::Always.
// NOTE(review): operand A below always uses CacheOperation::Always and does not
// get the same alignment-dependent selection — presumably intentional; confirm.
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the Mma
using Mma = threadblock::ImplicitGemmMultistage<
ThreadblockShape,
IteratorA,
SmemIteratorA,
arch::CacheOperation::Always,
IteratorB,
SmemIteratorB,
CacheOpB,
MmaPolicy,
Stages
>;
// Number of warp-level K partitions within the threadblock tile.
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
// Define the epilogue (strided-dgrad variant of the TensorOp epilogue)
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOpStridedDgrad<
ThreadblockShape,
WarpMmaTensorOp,
kPartitionsK,
EpilogueOutputOp,
EpilogueOutputOp::kCount
>::Epilogue;
// Define the kernel
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
Mma,
Epilogue,
ThreadblockSwizzle,
conv::Operator::kDgrad
>;
};
/// Defines a kernel for the Conv2dDgrad specialization for the Optimized IteratorAlgorithm,
/// strided Dgrad, and a 2-stage pipeline.
// Partial specialization: TensorOp math class + Optimized iterator algorithm +
// strided dgrad + fixed 2-stage pipeline. Differs from the multistage variant in
// that the tile-access iterators are wrapped in TileIteratorStridedDgrad (the
// tile-level interface expected by the 2-stage pipelined MMA) and the MMA is
// ImplicitGemmPipelined.
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementAccumulator,
typename ArchTag,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
typename MathOperatorTag,
int AlignmentA,
int AlignmentB
>
struct DefaultConv2dDgrad <
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementAccumulator,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
2,
MathOperatorTag,
IteratorAlgorithm::kOptimized,
StrideSupport::kStrided,
AlignmentA,
AlignmentB
> {
// Define the core components from GEMM (2-stage pipeline)
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
2, MathOperatorTag>;
// Define iterators over tiles from the A operand (output gradient Dy)
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
using IteratorA =
cutlass::conv::threadblock::TileIteratorStridedDgrad<
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
ThreadMapA,
StrideSupport::kStrided,
AccessTypeA
>
>;
using SmemIteratorA = typename MmaCore::SmemIteratorA;
// Define iterators over tiles from the B operand (filter W)
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
using IteratorB =
cutlass::conv::threadblock::TileIteratorStridedDgrad<
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
ThreadMapB,
StrideSupport::kStrided,
AccessTypeB
>
>;
using SmemIteratorB = typename MmaCore::SmemIteratorB;
// Warp-level GEMM components
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
using MmaPolicy = typename MmaCore::MmaPolicy;
// Define the Mma (2-stage software pipeline)
using Mma = threadblock::ImplicitGemmPipelined<
ThreadblockShape,
IteratorA,
SmemIteratorA,
IteratorB,
SmemIteratorB,
ElementC,
LayoutC,
MmaPolicy
>;
// Number of warp-level K partitions within the threadblock tile.
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
// Define the epilogue via the detail helper, which picks the arch-appropriate
// strided-dgrad epilogue for ArchTag.
using Epilogue = typename detail::DefaultConvEpilogueStridedDgrad<
ArchTag,
ThreadblockShape,
WarpMmaTensorOp,
kPartitionsK,
EpilogueOutputOp
>::Epilogue;
// Define the kernel
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
Mma,
Epilogue,
ThreadblockSwizzle,
conv::Operator::kDgrad
>;
};
/// Defines a kernel for the Conv2dDgrad specialization for the Optimized IteratorAlgorithm,
/// unity-stride Dgrad, and a 2-stage pipeline.
template <
@@ -1126,6 +1363,112 @@ struct DefaultConv2dDgrad <
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Partial specialization: SIMT math class + Optimized iterator algorithm +
// strided dgrad + multistage (Stages-deep) software pipeline. Unlike the
// TensorOp variants, the iterators use their default access type (no explicit
// AccessType argument) and the epilogue is the SIMT strided-dgrad epilogue.
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementAccumulator,
typename ArchTag,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag,
int AlignmentA,
int AlignmentB
>
struct DefaultConv2dDgrad <
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementAccumulator,
arch::OpClassSimt,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
MathOperatorTag,
IteratorAlgorithm::kOptimized,
conv::StrideSupport::kStrided,
AlignmentA,
AlignmentB
> {
// Define the core components from GEMM
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
Stages, MathOperatorTag>;
// Define iterators over tiles from the A operand (output gradient Dy)
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using IteratorA =
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
ThreadMapA,
conv::StrideSupport::kStrided
>;
using SmemIteratorA = typename MmaCore::SmemIteratorA;
// Define iterators over tiles from the B operand (filter W)
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using IteratorB =
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
ThreadMapB,
conv::StrideSupport::kStrided
>;
using SmemIteratorB = typename MmaCore::SmemIteratorB;
// Warp-level GEMM components
using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
using MmaPolicy = typename MmaCore::MmaPolicy;
// Define the Mma; both operands use CacheOperation::Always (no alignment-based
// cache-op selection for SIMT).
using Mma = threadblock::ImplicitGemmMultistage<
ThreadblockShape,
IteratorA,
SmemIteratorA,
arch::CacheOperation::Always,
IteratorB,
SmemIteratorB,
arch::CacheOperation::Always,
MmaPolicy,
Stages
>;
// Define the epilogue (SIMT strided-dgrad variant; no K-partition parameter)
using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad<
ThreadblockShape,
WarpMmaSimtOp,
EpilogueOutputOp,
EpilogueOutputOp::kCount
>::Epilogue;
// Define the kernel
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
Mma,
Epilogue,
ThreadblockSwizzle,
conv::Operator::kDgrad
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for the Conv2dDgrad specialization for the Analytic IteratorAlgorithm,
@@ -1462,6 +1805,115 @@ struct DefaultConv2dDgrad <
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Partial specialization: SIMT math class + Optimized iterator algorithm +
// strided dgrad + fixed 2-stage pipeline. The tile-access iterators are wrapped
// in TileIteratorStridedDgrad (the tile-level interface expected by the 2-stage
// pipelined MMA) and the MMA is ImplicitGemmPipelined.
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementAccumulator,
typename ArchTag,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
typename MathOperatorTag,
int AlignmentA,
int AlignmentB
>
struct DefaultConv2dDgrad <
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementAccumulator,
arch::OpClassSimt,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
2,
MathOperatorTag,
IteratorAlgorithm::kOptimized,
conv::StrideSupport::kStrided,
AlignmentA,
AlignmentB
> {
// Define the core components from GEMM (2-stage pipeline)
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
2, MathOperatorTag>;
// Define iterators over tiles from the A operand (output gradient Dy)
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using IteratorA =
cutlass::conv::threadblock::TileIteratorStridedDgrad<
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
ThreadMapA,
conv::StrideSupport::kStrided
>
>;
using SmemIteratorA = typename MmaCore::SmemIteratorA;
// Define iterators over tiles from the B operand (filter W)
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using IteratorB =
cutlass::conv::threadblock::TileIteratorStridedDgrad<
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
ThreadMapB,
conv::StrideSupport::kStrided
>
>;
using SmemIteratorB = typename MmaCore::SmemIteratorB;
// Warp-level GEMM components
using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
using MmaPolicy = typename MmaCore::MmaPolicy;
// Define the Mma (2-stage software pipeline)
using Mma = threadblock::ImplicitGemmPipelined<
ThreadblockShape,
IteratorA,
SmemIteratorA,
IteratorB,
SmemIteratorB,
ElementC,
LayoutC,
MmaPolicy
>;
// Define the epilogue (SIMT strided-dgrad variant; no K-partition parameter)
using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad<
ThreadblockShape,
WarpMmaSimtOp,
EpilogueOutputOp,
EpilogueOutputOp::kCount
>::Epilogue;
// Define the kernel
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
Mma,
Epilogue,
ThreadblockSwizzle,
conv::Operator::kDgrad
>;
};
} // namespace kernel
} // namespace conv
} // namespace cutlass

View File

@@ -65,7 +65,7 @@ template <
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag,
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
/// Access granularity of A matrix in units of elements
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,

View File

@@ -64,7 +64,7 @@ template <
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag,
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
/// Access granularity of A matrix in units of elements
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,

View File

@@ -65,7 +65,7 @@ template <
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag,
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
/// Access granularity of A matrix in units of elements
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,

View File

@@ -66,7 +66,7 @@ template <
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag,
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
/// Access granularity of A matrix in units of elements
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,

View File

@@ -66,7 +66,7 @@ template <
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag,
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
conv::StrideSupport StrideSupport = StrideSupport::kStrided
> struct DefaultConv3dDgrad;

View File

@@ -66,7 +66,7 @@ template <
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag,
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
conv::StrideSupport StrideSupport = StrideSupport::kStrided
> struct DefaultConv3dFprop;

View File

@@ -65,7 +65,7 @@ template <
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag,
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
conv::StrideSupport StrideSupport = StrideSupport::kStrided
> struct DefaultConv3dWgrad;

View File

@@ -210,9 +210,9 @@ public:
CUTLASS_HOST_DEVICE
TensorCoord at() const {
int c = offset_c_[iteration_contiguous_];
int k = offset_k_[iteration_strided_];
int c = offset_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements;
return TensorCoord(k, filter_r_, filter_s_, c);
}
@@ -222,7 +222,7 @@ public:
TensorCoord coord = at();
return coord.n() < problem_size_.K && (coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.C;
return coord.n() < problem_size_.K && coord.c() < problem_size_.C;
}
/// Returns a pointer to the vector starting at the current coordinate
@@ -232,7 +232,7 @@ public:
TensorCoord coord = at();
LongIndex offset = params_.layout(coord);
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8) + iteration_vector_;
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
}
@@ -250,6 +250,7 @@ public:
return *this;
}
iteration_contiguous_ = 0;
++iteration_strided_;
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
return *this;
@@ -408,8 +409,8 @@ public:
CUTLASS_HOST_DEVICE
TensorCoord at() const {
int c = offset_c_[iteration_contiguous_];
int k = offset_k_[iteration_strided_];
int c = offset_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements;
return TensorCoord(k, filter_r_, filter_s_, c);
}
@@ -420,7 +421,7 @@ public:
TensorCoord coord = at();
return coord.n() < problem_size_.K && (coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.C;
return coord.n() < problem_size_.K && coord.c() < problem_size_.C;
}
/// Returns a pointer to the vector starting at the current coordinate
@@ -430,7 +431,7 @@ public:
TensorCoord coord = at();
LongIndex offset = params_.layout(coord);
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8) + iteration_vector_;
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
}
/// Increments to the next memory access

View File

@@ -67,6 +67,282 @@ class Conv2dDgradFilterTileAccessIteratorOptimized;
/////////////////////////////////////////////////////////////////////////////////////////////////
// Conv2dDgradFilterTileAccessIteratorOptimized specialization for strided dgrad.
// (The unity-stride specialization further below is more performant for problem
// sizes with stride = {1x1}.)
// Optimized filter (W) tile-access iterator for strided dgrad. Walks the filter
// tensor over gemm-k iterations, visiting only the filter positions (r, s)
// reachable from (start_r, start_s) in steps of (stride_h, stride_w), using
// precomputed byte increments from Params and a packed predicate bitmask
// instead of recomputing coordinates per access.
template <
typename Shape_,
typename Element_,
typename ThreadMap_,
typename AccessType_
>
class Conv2dDgradFilterTileAccessIteratorOptimized <
Shape_,
Element_,
ThreadMap_,
conv::StrideSupport::kStrided,
AccessType_
> {
public:
//
// Types
//
using Shape = Shape_;
using Element = Element_;
using Layout = layout::TensorNHWC;
using ThreadMap = ThreadMap_;
using AccessType = AccessType_;
using TensorRef = cutlass::TensorRef<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized;
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 2;
using ConvProblemSize = typename conv::Conv2dProblemSize;
// Number of AccessType-sized accesses composing one thread-map vector.
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
"Vectors implied by the thread map must be divisible by the access type.");
//
// Parameters structure
//
// Thin wrapper adding constructors around the precomputed strided-dgrad
// filter-iterator parameters (layout, per-step byte increments, etc.).
struct Params : Conv2dStridedDgradFilterIteratorOptimizedParams {
//
// Methods
//
CUTLASS_HOST_DEVICE
Params() { }
CUTLASS_HOST_DEVICE
Params(Conv2dStridedDgradFilterIteratorOptimizedParams const &base):
Conv2dStridedDgradFilterIteratorOptimizedParams(base) { }
CUTLASS_HOST_DEVICE
Params(
Conv2dProblemSize const &problem_size,
Layout const &layout
):
Conv2dStridedDgradFilterIteratorOptimizedParams(
problem_size,
layout,
sizeof_bits<Element>::value,
{Shape::kRow, Shape::kColumn},
ThreadMap::kThreads,
ThreadMap::kElementsPerAccess,
{ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided},
{ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}
) { }
};
private:
Conv2dStridedDgradFilterIteratorOptimizedParams const &params_;
Conv2dProblemSize const &problem_size_;
// Current position within the (vector, contiguous, strided) access iteration space.
LongIndex iteration_contiguous_;
LongIndex iteration_strided_;
LongIndex iteration_vector_;
// Byte pointer to the current tile position within the filter tensor.
char const *pointer_;
// One predicate bitmask per vector access; bit (c + s * kContiguous) guards
// the access at (contiguous c, strided s).
uint32_t predicates_[kAccessesPerVector];
// Current filter coordinates; (filter_r_, filter_s_) advance by the problem's
// (stride_h, stride_w) and wrap back to (start_r_, start_s_).
int filter_k_;
int filter_r_;
int filter_s_;
int start_r_;
int start_s_;
// Byte offsets used to rewind the pointer when filter_s_ (resp. filter_r_) wraps.
int64_t reset_bytes_s_;
int64_t reset_bytes_r_;
//
// Assertions
//
// We map predicates into bits packed in this uint32_t container
// NOTE(review): each predicates_[v] is 32 bits wide and holds
// kStrided * kContiguous bits, but this bound uses sizeof(predicates_) * 8,
// which overcounts when kAccessesPerVector > 1 — presumably the intended bound
// is 32 bits per element; confirm.
static_assert(ThreadMap::Iterations::kStrided *
ThreadMap::Iterations::kContiguous < sizeof(predicates_) * 8,
"Currently, the number of loads per iteration is limited by the size of the predicates container.");
public:
// Constructs the iterator at a given threadblock offset and starting filter
// position (start_r, start_s), precomputing the predicate masks and the byte
// pointer for gemm_k = 0.
CUTLASS_HOST_DEVICE
Conv2dDgradFilterTileAccessIteratorOptimized(
Conv2dStridedDgradFilterIteratorOptimizedParams const &params,
Conv2dProblemSize const &problem_size,
Element const *ptr,
int thread_idx,
int start_r, int start_s,
MatrixCoord const &threadblock_offset = MatrixCoord()
):
params_(params),
problem_size_(problem_size),
pointer_(reinterpret_cast<char const *>(ptr)),
predicates_{0},
filter_r_(start_r),
filter_s_(start_s),
start_r_(start_r),
start_s_(start_s) {
layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);
// Strided dimension maps to K (GEMM rows); contiguous dimension maps to C.
filter_k_ = threadblock_offset.row() + thread_coord.strided();
Index column = threadblock_offset.column() + thread_coord.contiguous();
// Rewind distances for a wrap of s only, and for a wrap of both s and r.
reset_bytes_s_ = (problem_size_.num_gemm_k_filter_s(start_s_) - 1) * params_.inc_next[0];
reset_bytes_r_ = reset_bytes_s_ +
(problem_size_.num_gemm_k_filter_r(start_r_) - 1) * params_.inc_next[1];
// Precompute one predicate bit per (strided, contiguous, vector) access,
// guarding against K and C bounds of the filter tensor.
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
int filter_k = filter_k_ + s * ThreadMap::Delta::kStrided;
int filter_c = column + c * ThreadMap::Delta::kContiguous;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < kAccessesPerVector; ++v) {
uint32_t pred = ((filter_k < problem_size_.K && (filter_c + v * AccessType::kElements) < problem_size_.C) ? 1u : 0);
int pred_idx = c + s * ThreadMap::Iterations::kContiguous;
predicates_[v] |= (pred << pred_idx);
}
}
}
// Resolve the starting coordinate to a byte offset once; subsequent moves use
// only the precomputed byte increments.
TensorCoord coord{filter_k_, filter_r_, filter_s_, column};
pointer_ += params_.layout(coord) * sizeof_bits<Element>::value / 8;
set_iteration_index(0);
}
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(Index index) {
// Decompose the flat index as (vector, contiguous, strided).
iteration_vector_ = index % kAccessesPerVector;
int residual_access = index / kAccessesPerVector;
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
}
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
}
// Advances to the next gemm-k filter position: step s by stride_w; on wrap,
// restore s and step r by stride_h; on a second wrap, restore r and move to
// the next K block. Pointer moves by the matching precomputed byte increment
// (inc_next[0]=s step, [1]=r step, [2]=k step) minus the rewind amount.
CUTLASS_HOST_DEVICE
void advance() {
int next_idx = 0;
LongIndex reset_bytes = params_.reset_bytes;
// Move filter_s by stride_w
filter_s_ += problem_size_.stride_w;
if (filter_s_ >= problem_size_.S) {
// Restore filter_s
filter_s_ = start_s_;
// Move filter_r by stride_h
filter_r_ += problem_size_.stride_h;
bool check = (filter_r_ < problem_size_.R);
filter_r_ = check ? filter_r_ : start_r_;
next_idx = check ? 1 : 2;
reset_bytes += (check ? reset_bytes_s_ : reset_bytes_r_);
}
// offset pointers by offset_bytes
pointer_ += (params_.inc_next[next_idx] - reset_bytes);
if (next_idx == 2) {
filter_k_ += params_.filter_k_delta;
}
// Clear predicates if needed: after a K advance, whole strided rows may fall
// outside problem_size_.K and must be masked off in every vector's bitmask.
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
if (filter_k_ + s * ThreadMap::Delta::kStrided >= problem_size_.K) {
uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous);
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < kAccessesPerVector; ++v) {
predicates_[v] = (predicates_[v] & (~kClearMask));
}
}
}
}
/// Returns true if the current coordinate is within the filter tensor W
CUTLASS_HOST_DEVICE
bool valid() {
// Look up the precomputed predicate bit for the current access.
LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous;
return (predicates_[iteration_vector_] & (1u << pred_idx));
}
/// Returns a pointer to the vector starting at the current coordinate
CUTLASS_HOST_DEVICE
AccessType const *get() const {
// Contiguous delta is applied here; strided delta is folded into pointer_
// by operator++ (inc_next_strided).
return reinterpret_cast<AccessType const *>(pointer_ +
iteration_contiguous_ * ThreadMap::Delta::kContiguous * sizeof_bits<Element>::value / 8) + iteration_vector_;
}
/// Increments to the next memory access
CUTLASS_HOST_DEVICE
Conv2dDgradFilterTileAccessIteratorOptimized &operator++() {
// Iterate vector -> contiguous -> strided, fastest to slowest.
++iteration_vector_;
if (iteration_vector_ < kAccessesPerVector) {
return *this;
}
iteration_vector_ = 0;
++iteration_contiguous_;
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
return *this;
}
iteration_contiguous_ = 0;
++iteration_strided_;
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
// Move to the next K coordinate within the tile
pointer_ += params_.inc_next_strided;
return *this;
}
iteration_strided_ = 0;
return *this;
}
/// Determines whether the Implicit GEMM can execute the given problem.
CUTLASS_HOST_DEVICE
static Status can_implement(Conv2dProblemSize const &problem_size) {
// check alignment constraint on iterator's contiguous dimension
if (problem_size.C % AccessType::kElements) {
return Status::kErrorInvalidProblem;
}
return Status::kSuccess;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Conv2dDgradFilterTileAccessIteratorOptimized unity strided dgrad is more performant for dgrad
// on problem sizes with stride = {1x1}
template <

View File

@@ -268,11 +268,13 @@ public:
p += (conv_sign * (filter_r_ / problem_size_.stride_h));
q += (conv_sign * (filter_s_ / problem_size_.stride_w));
int k = filter_k_ + iteration_vector_ * AccessType::kElements;
return TensorCoord(
n,
p,
q,
filter_k_);
k);
}
@@ -286,7 +288,7 @@ public:
coord.n() < problem_size_.N &&
coord.h() >= 0 && coord.h() < problem_size_.P &&
coord.w() >= 0 && coord.w() < problem_size_.Q &&
(coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.K;
coord.c() < problem_size_.K;
}
/// Returns a pointer to the vector starting at the current coordinate
@@ -296,7 +298,7 @@ public:
TensorCoord coord = at();
LongIndex offset = params_.layout(coord);
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8) + iteration_vector_;
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
}
/// Increments to the next memory access
@@ -313,6 +315,7 @@ public:
return *this;
}
iteration_contiguous_ = 0;
++iteration_strided_;
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
return *this;
@@ -516,7 +519,9 @@ public:
int p = (h + problem_size_.pad_h - r * problem_size_.dilation_h) / problem_size_.stride_h;
int q = (w + problem_size_.pad_w - s * problem_size_.dilation_w) / problem_size_.stride_w;
return TensorCoord(n, p, q, filter_k_);
int k = filter_k_ + iteration_vector_ * AccessType::kElements;
return TensorCoord(n, p, q, k);
}
@@ -529,7 +534,7 @@ public:
return coord.n() < problem_size_.N &&
coord.h() >= 0 && coord.h() < problem_size_.P &&
coord.w() >= 0 && coord.w() < problem_size_.Q &&
(coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.K;
coord.c() < problem_size_.K;
}
/// Returns a pointer to the vector starting at the current coordinate
@@ -539,7 +544,7 @@ public:
TensorCoord coord = at();
LongIndex offset = params_.layout(coord);
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8) + iteration_vector_;
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
}
/// Increments to the next memory access

View File

@@ -67,6 +67,380 @@ template <
class Conv2dDgradOutputGradientTileAccessIteratorOptimized;
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
// Conv2dDgradOutputGradientTileAccessIteratorOptimized strided dgrad needs special handling
// to skip MMAs (Dx = Dy * w) on invalid filter positions
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Shape_,
typename Element_,
typename ThreadMap_,
typename AccessType_
>
class Conv2dDgradOutputGradientTileAccessIteratorOptimized <
Shape_,
Element_,
ThreadMap_,
conv::StrideSupport::kStrided,
AccessType_
> {
public:
//
// Types
//
using Shape = Shape_;
using Element = Element_;
using Layout = layout::TensorNHWC;
using ThreadMap = ThreadMap_;
using AccessType = AccessType_;
using TensorRef = cutlass::TensorRef<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized;
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 2;
using ConvProblemSize = typename conv::Conv2dProblemSize;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
"Vectors implied by the thread map must be divisible by the access type.");
using Mask = uint64_t;
static_assert(sizeof_bits<Element>::value >= 8,
"DGRAD requires elements of size 8b or greater.");
//
// Simplifying assertions
//
static_assert(ThreadMap::Iterations::kContiguous == 1,
"Require Iterations::kContiguous == 1");
//
// Parameters structure
//
using Params = Conv2dStridedDgradOutputGradientIteratorOptimizedParams;
private:
Params const &params_;
Conv2dProblemSize const &problem_size_;
LongIndex iteration_contiguous_;
LongIndex iteration_strided_;
LongIndex iteration_vector_;
// One pointer per access
char const *pointer_[ThreadMap::Iterations::kStrided];
int filter_k_;
int filter_r_;
int filter_s_;
int start_r_;
int start_s_;
int64_t reset_bytes_s_;
int64_t reset_bytes_r_;
Index masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2];
public:
/// Constructs the iterator for a given (start_r, start_s) filter starting position.
/// Precomputes one global-memory pointer per strided access and the r/s predicate
/// masks, so advance() reduces to pointer arithmetic and valid() to mask tests.
CUTLASS_HOST_DEVICE
Conv2dDgradOutputGradientTileAccessIteratorOptimized(
Params const &params,
Conv2dProblemSize const &problem_size,
Element const *ptr,
int thread_idx,
FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod,
int start_r, int start_s,
MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles
):
params_(params),
problem_size_(problem_size),
filter_k_(0),
filter_r_(start_r),
filter_s_(start_s),
start_r_(start_r),
start_s_(start_s) {
layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);
filter_k_ = threadblock_offset.column() + thread_coord.contiguous();
// Byte offsets advance() uses to rewind pointers when filter_s_ (and filter_r_)
// wrap back to their starting positions
reset_bytes_s_ = (problem_size_.num_gemm_k_filter_s(start_s_) - 1) * params_.inc_next[0];
reset_bytes_r_ = (problem_size_.num_gemm_k_filter_s(start_s_) - 1) * params_.inc_next[0] +
(problem_size_.num_gemm_k_filter_r(start_r_) - 1) * params_.inc_next[1];
int offset_n[ThreadMap::Iterations::kStrided];
int offset_p[ThreadMap::Iterations::kStrided];
int offset_q[ThreadMap::Iterations::kStrided];
int filter_r = filter_r_;
int filter_s = filter_s_;
// Convolution mode flips the filter relative to cross-correlation
if (problem_size_.mode == Mode::kConvolution) {
filter_r = (problem_size_.R - 1 - filter_r);
filter_s = (problem_size_.S - 1 - filter_s);
}
// Starting h, w positions for filter position in gemm_k=0
int start_h, start_w;
strided_dgrad_starting_coords(
problem_size_,
stride_h_divmod, stride_w_divmod,
filter_r, filter_s,
start_h, start_w);
// Effective starting P and Q for filter position required for remapping NHW rows
int P = (problem_size_.H - start_h + problem_size_.stride_h - 1) / problem_size_.stride_h;
int Q = (problem_size_.W - start_w + problem_size_.stride_w - 1) / problem_size_.stride_w;
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
pointer_[s] = reinterpret_cast<char const *>(ptr);
int offset_npq = (threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided) % params_.tiled_rows_per_filter;
// (STEP 1) [reorder NHW rows to start with same filter positions]
offset_n[s] = offset_npq / (P * Q);
int residual = offset_npq % (P * Q);
int p = (residual / Q);
int q = (residual % Q);
// Map the reordered (p, q) onto input coordinates covered by this filter tap
int mapped_h = (start_h + p * problem_size_.stride_h);
int mapped_w = (start_w + q * problem_size_.stride_w);
// Access (p, q) coordinates for Dy tensor for filter position in gemm_k=0
// note that (h + pad_h - filter_r) and (w + pad_w - filter_s) are ensured to be
// divisible by stride_h and stride_w
offset_p[s] = (mapped_h + problem_size_.pad_h - filter_r) / problem_size_.stride_h;
offset_q[s] = (mapped_w + problem_size_.pad_w - filter_s) / problem_size_.stride_w;
// Initialize pointers for gemm_k=0
TensorCoord coord{offset_n[s], offset_p[s], offset_q[s], filter_k_};
pointer_[s] += params_.layout(coord) * sizeof_bits<Element>::value / 8;
}
//
// Precompute mask predicates
//
clear_mask();
// r-mask: one bit per filter row this threadblock visits (rows are stride_h apart)
CUTLASS_PRAGMA_NO_UNROLL
for (int r = start_r; r < problem_size_.R; r += problem_size_.stride_h) {
CUTLASS_PRAGMA_UNROLL
for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) {
int p = offset_p[s_idx] ;
p += (params_.conv_sign * (r / problem_size_.stride_h));
bool pred = (offset_n[s_idx] < problem_size_.N && p >= 0 && p < problem_size_.P);
CUTLASS_PRAGMA_UNROLL
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
masks_[s_idx][v_idx][0] |= (pred << r);
}
}
}
// s-mask: one bit per filter column this threadblock visits (columns are stride_w apart)
CUTLASS_PRAGMA_NO_UNROLL
for(int s = start_s; s < problem_size_.S; s += problem_size_.stride_w) {
CUTLASS_PRAGMA_UNROLL
for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) {
int q = offset_q[s_idx];
q += (params_.conv_sign * (s / problem_size_.stride_w));
bool pred = (q >=0 && q < problem_size_.Q);
CUTLASS_PRAGMA_UNROLL
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
masks_[s_idx][v_idx][1] |= (pred << s);
}
}
}
// Disable any vector access whose channel coordinate exceeds the K extent
CUTLASS_PRAGMA_UNROLL
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
clear_mask(v_idx, (filter_k_ + v_idx * AccessType::kElements) >= problem_size.K);
}
set_iteration_index(0);
}
/// Builds the precomputed Params object for this iterator from the problem
/// size and the Dy tensor layout.
CUTLASS_HOST_DEVICE
static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) {
  MatrixCoord threadblock_shape{Shape::kRow, Shape::kColumn};
  return Params(problem_size, layout, sizeof_bits<Element>::value, threadblock_shape);
}
private:
/// Adds an offset in units of bytes (not elements) to every per-access pointer,
/// optionally subtracting a rewind amount (byte_reset) in the same step
CUTLASS_HOST_DEVICE
void add_byte_offset_(LongIndex byte_offset, LongIndex byte_reset = 0) {
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
pointer_[s] += byte_offset - byte_reset;
}
}
public:
/// Overrides the internal iteration index, decomposing a flat access index
/// into its (vector, contiguous, strided) coordinates.
CUTLASS_HOST_DEVICE
void set_iteration_index(Index index) {
  int remaining = index / kAccessesPerVector;
  iteration_vector_ = index % kAccessesPerVector;
  iteration_contiguous_ = remaining % ThreadMap::Iterations::kContiguous;
  iteration_strided_ = remaining / ThreadMap::Iterations::kContiguous;
}
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
  // Convert the element count to a byte count, then apply it to every pointer.
  LongIndex byte_offset = pointer_offset * sizeof_bits<Element>::value / 8;
  add_byte_offset_(byte_offset);
}
/// Advances the iterator to the next gemm-k tile: steps filter_s_ first, then
/// filter_r_, and finally the K (channel) dimension, adjusting all per-access
/// pointers by the byte increments precomputed in params_.inc_next[].
CUTLASS_HOST_DEVICE
void advance() {
// Index into params_.inc_next: 0 = next S, 1 = next R, 2 = next K
int next_idx = 0;
int64_t reset_bytes = 0;
// Move filter_s by stride_w
filter_s_ += problem_size_.stride_w;
if (filter_s_ >= problem_size_.S) {
// Restore filter_s
filter_s_ = start_s_;
// Move filter_r by stride_h
filter_r_ += problem_size_.stride_h;
if (filter_r_ < problem_size_.R) {
next_idx = 1;
// Restore bytes in q coordinate (Mma in filter s dimension)
reset_bytes = reset_bytes_s_;
} else {
// Restore filter_r
filter_r_ = start_r_;
next_idx = 2;
// Restore bytes in p and q coordinate (Mma in filter s and r dimension)
reset_bytes = reset_bytes_r_;
}
}
// offset pointers by offset_bytes
add_byte_offset_(params_.inc_next[next_idx] - reset_bytes);
if (next_idx == 2) {
// Moved to the next K tile (split-k slice stride)
filter_k_ += params_.filter_k_delta;
}
// Re-evaluate the channel predicate for the (possibly updated) filter_k_
CUTLASS_PRAGMA_UNROLL
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
clear_mask(v_idx, (filter_k_ + v_idx * AccessType::kElements) >= problem_size_.K);
}
}
/// Clears every predicate mask bit (both r- and s-planes) when `clear` is true;
/// otherwise leaves the masks untouched.
CUTLASS_HOST_DEVICE
void clear_mask(bool clear = true) {
  CUTLASS_PRAGMA_UNROLL
  for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) {
    CUTLASS_PRAGMA_UNROLL
    for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
      CUTLASS_PRAGMA_UNROLL
      for (int rs = 0; rs < 2; ++rs) {
        // Ternary (rather than a branch) keeps this predicated on device
        masks_[s_idx][v_idx][rs] = clear ? Mask(0) : masks_[s_idx][v_idx][rs];
      }
    }
  }
}
/// Clears the predicate masks of a single vector access index `v` (both r- and
/// s-planes) when `clear` is true; otherwise leaves them untouched.
CUTLASS_HOST_DEVICE
void clear_mask(int v, bool clear = true) {
  CUTLASS_PRAGMA_UNROLL
  for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) {
    CUTLASS_PRAGMA_UNROLL
    for (int rs = 0; rs < 2; ++rs) {
      // Ternary (rather than a branch) keeps this predicated on device
      masks_[s_idx][v][rs] = clear ? Mask(0) : masks_[s_idx][v][rs];
    }
  }
}
/// Returns true if the current coordinate is within the output tensor Dy,
/// i.e. both the r-plane and s-plane predicate bits are set for the current
/// filter position.
CUTLASS_HOST_DEVICE
bool valid() const {
  Index r_bit = Index(1) << filter_r_;
  Index s_bit = Index(1) << filter_s_;
  bool r_valid = (masks_[iteration_strided_][iteration_vector_][0] & r_bit) != 0;
  bool s_valid = (masks_[iteration_strided_][iteration_vector_][1] & s_bit) != 0;
  return r_valid && s_valid;
}
/// Returns a pointer to the vector starting at the current coordinate
CUTLASS_HOST_DEVICE
AccessType const *get() const {
  AccessType const *base = reinterpret_cast<AccessType const *>(pointer_[iteration_strided_]);
  return base + iteration_vector_;
}
/// Increments to the next memory access in (vector, contiguous, strided)
/// order, wrapping each index into the next one when it reaches its extent.
CUTLASS_HOST_DEVICE
Conv2dDgradOutputGradientTileAccessIteratorOptimized &operator++() {
  ++iteration_vector_;
  if (iteration_vector_ == kAccessesPerVector) {
    iteration_vector_ = 0;
    ++iteration_contiguous_;
    if (iteration_contiguous_ == ThreadMap::Iterations::kContiguous) {
      iteration_contiguous_ = 0;
      ++iteration_strided_;
      if (iteration_strided_ == ThreadMap::Iterations::kStrided) {
        iteration_strided_ = 0;
      }
    }
  }
  return *this;
}
/// Determines whether the Implicit GEMM can execute the given problem.
CUTLASS_HOST_DEVICE
static Status can_implement(Conv2dProblemSize const &problem_size) {
  // The channel extent must be a multiple of the vector access width.
  bool k_aligned = (problem_size.K % AccessType::kElements) == 0;
  if (!k_aligned) {
    return Status::kErrorInvalidProblem;
  }
  // Each filter extent maps onto bits of a fixed-width predicate mask.
  bool filter_fits = (problem_size.R <= 32) && (problem_size.S <= 32);
  if (!filter_fits) {
    return Status::kErrorNotSupported;
  }
  return Status::kSuccess;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Conv2dDgradOutputGradientTileAccessIteratorOptimized unity stride dgrad is optimized for dgrad
// with problem stride = {1x1}

View File

@@ -209,7 +209,9 @@ public:
int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h;
int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w;
return TensorCoord(n, h, w, filter_c_);
int c = filter_c_ + iteration_vector_ * AccessType::kElements;
return TensorCoord(n, h, w, c);
}
/// Returns true if the current coordinate is within the activations tensor X
@@ -221,7 +223,7 @@ public:
return coord.n() < problem_size_.N &&
coord.h() >= 0 && coord.h() < problem_size_.H &&
coord.w() >= 0 && coord.w() < problem_size_.W &&
(coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.C;
coord.c() < problem_size_.C;
}
/// Returns a pointer to the vector starting at the current coordinate
@@ -231,7 +233,7 @@ public:
TensorCoord coord = at();
LongIndex offset = params_.layout(coord);
AccessType const *ptr = reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8) + iteration_vector_;
AccessType const *ptr = reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
return ptr;
}

View File

@@ -183,8 +183,9 @@ public:
TensorCoord at() const {
int k = offset_k_[iteration_strided_];
int c = filter_c_ + iteration_vector_ * AccessType::kElements;
return TensorCoord(k, filter_r_, filter_s_, filter_c_);
return TensorCoord(k, filter_r_, filter_s_, c);
}
/// Returns true if the current coordinate is within the activations tensor W
@@ -194,7 +195,7 @@ public:
TensorCoord coord = at();
return coord.n() < problem_size_.K &&
(coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.C;
coord.c() < problem_size_.C;
}
/// Returns a pointer to the vector starting at the current coordinate
@@ -204,7 +205,7 @@ public:
TensorCoord coord = at();
LongIndex offset = params_.layout(coord);
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8) + iteration_vector_;
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
}
/// Increments to the next memory access

View File

@@ -527,6 +527,64 @@ struct Conv2dDgradOutputGradientIteratorOptimizedParams {
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Strided Dgrad Optimized Dy params (layout::TensorNHWC)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Parameters object for the optimized strided-dgrad output gradient (Dy) iterator
struct Conv2dStridedDgradOutputGradientIteratorOptimizedParams {
using Layout = layout::TensorNHWC;
// Layout of the Dy tensor
Layout layout;
int64_t inc_next[3]; // {next S, next R, next K}
int filter_k_delta; // number of logical elements to add to filter_k_
// Number of gemm-M rows mapped to each filter starting position
int tiled_rows_per_filter;
// +1 for convolution (flipped filter), -1 for cross-correlation
int conv_sign;
//
// Methods
//
CUTLASS_HOST_DEVICE
Conv2dStridedDgradOutputGradientIteratorOptimizedParams() { }
CUTLASS_HOST_DEVICE
Conv2dStridedDgradOutputGradientIteratorOptimizedParams(
Conv2dProblemSize const &problem_size,
Layout const &layout, ///< layout object
int element_size_bits, ///< size of each element in bits
MatrixCoord threadblock_shape
): layout(layout) {
int tile_m_per_filter = strided_dgrad_tile_m_per_filter(problem_size, threadblock_shape.row());
tiled_rows_per_filter = tile_m_per_filter * threadblock_shape.row();
conv_sign = (problem_size.mode == Mode::kConvolution ? 1 : -1);
// next S - byte increment to the next filter column position
inc_next[0] = conv_sign * (
layout.stride()[0] * problem_size.dilation_w
) * element_size_bits / 8;
// next R - byte increment to the next filter row position
inc_next[1] = conv_sign * (
layout.stride()[1] * problem_size.dilation_h
) * element_size_bits / 8;
// next K - byte increment to the next channel tile (split-k aware)
inc_next[2] = (
threadblock_shape.column() * problem_size.split_k_slices
) * element_size_bits / 8;
// logical offset added to internal channel counter - units are elements, not bytes
filter_k_delta = threadblock_shape.column() * problem_size.split_k_slices;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////
// Dgrad Optimized w params (layout::TensorNHWC)
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -584,6 +642,73 @@ struct Conv2dDgradFilterIteratorOptimizedParams {
/////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////
// StridedDgrad Optimized w params (layout::TensorNHWC)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Parameters object for the optimized strided-dgrad filter (w) iterator
struct Conv2dStridedDgradFilterIteratorOptimizedParams {
using Layout = layout::TensorNHWC;
Layout layout;
// Total number of filter taps (R * S)
int RS;
// Logical K offset added per gemm-k tile - units are elements, not bytes
int filter_k_delta;
int64_t inc_next_strided; // offset in units of bytes to next K coordinate within tile
int64_t inc_next[3]; // {next S, next R, next K}
int64_t reset_bytes; // offset in units of bytes to move back the pointer
//
// Methods
//
CUTLASS_HOST_DEVICE
Conv2dStridedDgradFilterIteratorOptimizedParams() { }
CUTLASS_HOST_DEVICE
Conv2dStridedDgradFilterIteratorOptimizedParams(
Conv2dProblemSize const &problem_size,
Layout const &layout,
int element_size_bits, ///< size of each element in bits
MatrixCoord threadblock_shape,
int thread_count,
int access_size,
layout::PitchLinearCoord threadmap_iterations,
layout::PitchLinearCoord threadmap_delta
):
layout(layout), RS(problem_size.R * problem_size.S) {
TRACE_CONV_INITIALIZERS("conv2d_dgrad", "filter",
element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta);
// Byte increment between successive strided accesses within one tile
inc_next_strided = (layout.stride()[2] * threadmap_delta.strided() * element_size_bits) / 8;
// next S - byte increment to the next filter column (stride_w taps apart)
inc_next[0] =
( layout.stride()[0] * problem_size.stride_w
//- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
) * element_size_bits / 8;
// next R - byte increment to the next filter row (stride_h taps apart)
inc_next[1] =
( layout.stride()[1] * problem_size.stride_h
//- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
) * element_size_bits / 8;
// next K - byte increment to the next channel tile (split-k aware)
inc_next[2] =
(
threadblock_shape.row() * problem_size.split_k_slices * layout.stride()[2]
//- (problem_size.R * problem_size.S - 1) * layout.stride()[0]
//- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
) * element_size_bits / 8;
// offset in units of bytes to move the pointer in backward direction
reset_bytes = (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
* element_size_bits / 8;
filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Parameters object for Conv2d WGRAD Output Gradient (dy) iterator
struct Conv2dWgradOutputGradientIteratorOptimizedParams {

View File

@@ -183,10 +183,13 @@ public:
int r, s, c;
if (kAccessesPerVector == 1) {
/// One 128b aligned access fetching more than one element
c = filter_c_[iteration_contiguous_];
r = filter_r_[iteration_contiguous_];
s = filter_s_[iteration_contiguous_];
c = filter_c_[iteration_contiguous_];
} else {
}
else {
/// Multiple access to support non-128b alignment in contiguous dimenstion
c = (filter_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements) % problem_size_.C;
int wrap_c = (filter_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements) / problem_size_.C;
s = (filter_s_[iteration_contiguous_] + wrap_c) % problem_size_.S;

View File

@@ -205,6 +205,8 @@ public:
int c = filter_c_[iteration_contiguous_];
if (kAccessesPerVector > 1) {
// This code section is only to support non-128b alignment
// Multiple access to support non-128b alignment in contiguous dimenstion
int wrap_c;
params_.c_divmod(wrap_c, c, c + iteration_vector_ * AccessType::kElements);

View File

@@ -182,7 +182,9 @@ public:
int p = residual / problem_size_.Q;
int q = residual % problem_size_.Q;
return TensorCoord(n, p, q, filter_k_[iteration_contiguous_]);
int k = filter_k_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements;
return TensorCoord(n, p, q, k);
}
@@ -194,7 +196,7 @@ public:
return coord.n() < problem_size_.N &&
coord.h() < problem_size_.P &&
coord.w() < problem_size_.Q &&
(coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.K;
coord.c() < problem_size_.K;
}
/// Returns a pointer to the vector starting at the current coordinate
@@ -204,7 +206,7 @@ public:
TensorCoord coord = at();
LongIndex offset = params_.layout(coord);
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8) + iteration_vector_;
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
}
/// Increments to the next memory access

View File

@@ -192,6 +192,32 @@ struct GELU_taylor {
}
};
/// GELU tanh-approximation specialized for arrays of half_t; routes the tanh
/// through fast_tanh_op so hardware approximation instructions are used when
/// available. Computes gelu(z) = 0.5 * z * (1 + tanh(k0 * z * (1 + k1 * z^2))).
template <int N>
struct GELU_taylor<Array<half_t, N> > {
  static const bool kIsHeavy=true;
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const &z) const {
    using T = half_t;
    half_t const kAlpha = half_t(0.7978845608028654);  // sqrt(2 / pi)
    half_t const kGamma = half_t(0.044715);            // cubic-term coefficient
    multiply_add<Array<half_t, N>> fma;
    multiplies<Array<half_t, N>> mul;
    plus<Array<half_t, N>> add;
    fast_tanh_op<Array<half_t, N>> tanh_op;
    // inner = kAlpha * z * (1 + kGamma * z^2)
    Array<half_t, N> inner = mul(mul(kAlpha, z), fma(mul(kGamma, z), z, cutlass::constants::one<T>()));
    // result = 0.5 * z * (1 + tanh(inner))
    return mul(mul(z, cutlass::constants::half<T>()), add(cutlass::constants::one<T>(), tanh_op(inner)));
  }
};
template <typename T, int N>
struct GELU_taylor<Array<T, N> > {
static const bool kIsHeavy=true;

View File

@@ -234,8 +234,9 @@ public:
if (WarpShape::kN == 64) {
ptr = pointers_[n / 4];
}
#else
else
#endif
{
// This is the reference implementation
int column_idx = warp_column_ + n * Detail::kLanesInQuad * Policy::kElementsPerAccess;
int ptr_idx = ((column_idx * sizeof_bits<Element>::value) / 1024) % Detail::kPointerCount;
@@ -252,7 +253,8 @@ public:
else if (ptr_idx == 3) {
ptr = pointers_[3 % Detail::kPointerCount];
}
#endif
}
int offset = n * Detail::kLanesInQuad + pointer_offset / Policy::kElementsPerAccess;

View File

@@ -34,6 +34,7 @@
#endif
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/uint128.h"
#include "cutlass/coord.h"
#include "cutlass/numeric_types.h"
@@ -724,7 +725,13 @@ double fast_log(double x) {
CUTLASS_HOST_DEVICE
float fast_tanh(float x) {
#if defined(__CUDA_ARCH__)
return ::tanhf(x);
#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDA_ARCH__ >= 750)
float y;
asm volatile ( "tanh.approx.f32 %0, %1; " : "=f"(y) : "f"(x));
return y;
#else
return ::tanhf(x);
#endif
#else
return std::tanh(x);
#endif
@@ -739,6 +746,74 @@ double fast_tanh(double x) {
#endif
}
/// Fast tanh for half_t; uses the native tanh approximation instruction when
/// compiling device code for Turing or newer with CUDA 11+, otherwise falls
/// back to the fp32 path and rounds back to half.
CUTLASS_HOST_DEVICE
half_t fast_tanh(half_t x) {
#if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 11) && (__CUDA_ARCH__ >= 750)
asm volatile ( "tanh.approx.f16 %0, %1;" : "=h"(x.raw()) : "h"(x.raw()));
return x;
#else
return half_t(fast_tanh(float(x)));
#endif
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Functor wrapping fast_tanh - generic scalar case; specializations below
/// handle arrays.
template <typename T>
struct fast_tanh_op {
  CUTLASS_HOST_DEVICE
  T operator()(T const &rhs) const {
    T result = fast_tanh(rhs);
    return result;
  }
};
#if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 11) && (__CUDA_ARCH__ >= 750)
/// Specialization for half_t arrays using the packed tanh.approx.f16x2
/// instruction (two halves per instruction); device-only, Turing+ with CUDA 11+
template <int N>
struct fast_tanh_op<Array<half_t, N>> {
CUTLASS_DEVICE
Array<half_t, N> operator()(Array<half_t, N> const &rhs) const {
Array<half_t, N> result;
// use x2 specialization
uint32_t const *in = reinterpret_cast<uint32_t const *>(&rhs);
uint32_t *out = reinterpret_cast<uint32_t *>(&result);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 2; ++i) {
asm volatile ("tanh.approx.f16x2 %0, %1;" : "=r"(out[i]) : "r"(in[i]));
}
// residual - odd N: handle the final element with the scalar f16 variant
if (N % 2) {
uint16_t const *in = reinterpret_cast<uint16_t const *>(&rhs);
uint16_t *out = reinterpret_cast<uint16_t *>(&result);
asm volatile ("tanh.approx.f16 %0, %1;" : "=h"(out[N - 1]) : "h"(in[N - 1]));
}
return result;
}
};
#endif // #if defined(__CUDA_ARCH__)
/// Elementwise fast_tanh over an array of arbitrary scalar type.
template <typename T, int N>
struct fast_tanh_op<Array<T, N>> {
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &rhs) const {
    fast_tanh_op<T> scalar_op;
    Array<T, N> result;
    CUTLASS_PRAGMA_UNROLL
    for (int idx = 0; idx < N; ++idx) {
      result[idx] = scalar_op(rhs[idx]);
    }
    return result;
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@@ -126,7 +126,7 @@ struct DefaultGemmWithKReduction {
ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp,
EpilogueOutputOp::kCount>::Epilogue;
/// Define the epilogue
/// Define the epilogue of the reduction vector
using EpilogueGemmKReduction =
typename cutlass::epilogue::threadblock::EpilogueGemmKReduction<
ElementAccumulator, ElementC, ThreadblockShape, typename Mma::Operator, kReduceKForA>;

View File

@@ -582,6 +582,13 @@ public:
__threadfence();
}
// Execute the epilogue operator to update the destination tensor.
epilogue(
output_op,
iterator_D,
accumulators,
iterator_C);
if ((kReduceKForA && threadblock_tile_offset.n() == 0)
|| (!kReduceKForA && threadblock_tile_offset.m() == 0)) {
@@ -610,14 +617,7 @@ public:
&& (threadblock_tile_offset.k() > 0));
}
}
// Execute the epilogue operator to update the destination tensor.
epilogue(
output_op,
iterator_D,
accumulators,
iterator_C);
//
// Release the semaphore
//

View File

@@ -378,11 +378,21 @@ struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|| platform::is_same<LayoutC, layout::AffineRankN<2>>::value,
"simt epilogue must be row major");
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the MmaCore components
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
ElementB, LayoutB, ElementAccumulator, LayoutC, arch::OpClassSimt,
Stages, Operator>;
Stages, Operator, false, CacheOpA, CacheOpB>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;

View File

@@ -1111,8 +1111,8 @@ struct DefaultMmaCore<
using ElementC = complex<double>;
using LayoutC = LayoutC_;
static int const kStages = Stages;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Global;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Global;
static const ComplexTransform TransformA = TransformA_;
static const ComplexTransform TransformB = TransformB_;

View File

@@ -116,11 +116,22 @@ struct DefaultMultistageMmaComplex<ElementA, LayoutA, ElementB, LayoutB,
ElementAccumulator, layout::RowMajor, OperatorClass,
ArchTag, ThreadblockShape, WarpShape,
InstructionShape, Stages, TransformA, TransformB, Operator> {
static cutlass::arch::CacheOperation::Kind const CacheOpA =
(sizeof_bits<ElementA>::value == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
(sizeof_bits<ElementB>::value == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the MmaCore components
using MmaCore = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplexCore<
ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass,
Stages, TransformA, TransformB, Operator>;
Stages, TransformA, TransformB, Operator, CacheOpA, CacheOpB>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;

View File

@@ -113,8 +113,8 @@ struct DefaultMultistageMmaComplexCore<
static ComplexTransform const kTransformA = TransformA;
static ComplexTransform const kTransformB = TransformB;
using Operator = Operator_;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Global;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Global;
/// Number of warps present
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
@@ -242,8 +242,8 @@ struct DefaultMultistageMmaComplexCore<
using Operator = Operator_;
static ComplexTransform const kTransformA = TransformA;
static ComplexTransform const kTransformB = TransformB;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Global;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Global;
/// Number of warps present
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
@@ -371,8 +371,8 @@ struct DefaultMultistageMmaComplexCore<
static ComplexTransform const kTransformA = TransformA;
static ComplexTransform const kTransformB = TransformB;
using Operator = Operator_;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Global;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Global;
/// Number of warps present
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
@@ -501,8 +501,8 @@ struct DefaultMultistageMmaComplexCore<
static ComplexTransform const kTransformA = TransformA;
static ComplexTransform const kTransformB = TransformB;
using Operator = Operator_;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Global;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Global;
/// Number of warps present
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
@@ -1159,8 +1159,8 @@ struct DefaultMultistageMmaComplexCore<
static ComplexTransform const kTransformA = TransformA;
static ComplexTransform const kTransformB = TransformB;
using Operator = Operator_;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
/// Number of warps present
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
@@ -1326,8 +1326,8 @@ struct DefaultMultistageMmaComplexCore<
static ComplexTransform const kTransformA = TransformA;
static ComplexTransform const kTransformB = TransformB;
using Operator = Operator_;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
/// Number of warps present
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
@@ -1490,8 +1490,8 @@ struct DefaultMultistageMmaComplexCore<
static ComplexTransform const kTransformA = TransformA;
static ComplexTransform const kTransformB = TransformB;
using Operator = Operator_;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
/// Number of warps present
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
@@ -1660,8 +1660,8 @@ struct DefaultMultistageMmaComplexCore<
static ComplexTransform const kTransformA = TransformA;
static ComplexTransform const kTransformB = TransformB;
using Operator = Operator_;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
/// Number of warps present
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
@@ -1775,7 +1775,6 @@ struct DefaultMultistageMmaComplexCore<
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass

View File

@@ -298,7 +298,6 @@ class PredicatedTileAccessIteratorPredicates {
return pred;
}
};
////////////////////////////////////////////////////////////////////////////////