mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-11 17:00:05 +00:00
refine the implementation
This commit is contained in:
@@ -105,6 +105,18 @@ public:
|
||||
return status;
|
||||
}
|
||||
|
||||
static int const kAlignmentC = ImplicitGemmKernel::Epilogue::OutputTileIterator::kElementsPerAccess;
|
||||
if (kConvolutionalOperator == conv::Operator::kFprop) {
|
||||
if (args.problem_size.K % kAlignmentC)
|
||||
return Status::kErrorMisalignedOperand;
|
||||
} else if (kConvolutionalOperator == conv::Operator::kDgrad) {
|
||||
if (args.problem_size.C % kAlignmentC)
|
||||
return Status::kErrorMisalignedOperand;
|
||||
} else if (kConvolutionalOperator == conv::Operator::kWgrad) {
|
||||
if (args.problem_size.C % kAlignmentC)
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
// check for unsupported problem sizes for strided dgrad implementation
|
||||
if (kConvolutionalOperator == conv::Operator::kDgrad &&
|
||||
kStrideSupport == conv::StrideSupport::kStrided) {
|
||||
|
||||
@@ -66,7 +66,11 @@ template <
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value
|
||||
> struct DefaultConv2dDgrad;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -90,7 +94,9 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
@@ -110,7 +116,9 @@ struct DefaultConv2dDgrad <
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
StrideSupport::kStrided
|
||||
StrideSupport::kStrided,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@@ -121,24 +129,28 @@ struct DefaultConv2dDgrad <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA,
|
||||
StrideSupport::kStrided
|
||||
StrideSupport::kStrided,
|
||||
AccessTypeA
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB,
|
||||
StrideSupport::kStrided
|
||||
StrideSupport::kStrided,
|
||||
AccessTypeB
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
@@ -147,6 +159,11 @@ struct DefaultConv2dDgrad <
|
||||
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::ImplicitGemmMultistage<
|
||||
ThreadblockShape,
|
||||
@@ -155,7 +172,7 @@ struct DefaultConv2dDgrad <
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
arch::CacheOperation::Global,
|
||||
CacheOpB,
|
||||
MmaPolicy,
|
||||
Stages
|
||||
>;
|
||||
@@ -196,7 +213,9 @@ template <
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
@@ -216,7 +235,9 @@ struct DefaultConv2dDgrad <
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
StrideSupport::kStrided
|
||||
StrideSupport::kStrided,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@@ -227,13 +248,15 @@ struct DefaultConv2dDgrad <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::TileIteratorStridedDgrad<
|
||||
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA,
|
||||
StrideSupport::kStrided
|
||||
StrideSupport::kStrided,
|
||||
AccessTypeA
|
||||
>
|
||||
>;
|
||||
|
||||
@@ -241,13 +264,15 @@ struct DefaultConv2dDgrad <
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::TileIteratorStridedDgrad<
|
||||
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB,
|
||||
StrideSupport::kStrided
|
||||
StrideSupport::kStrided,
|
||||
AccessTypeB
|
||||
>
|
||||
>;
|
||||
|
||||
@@ -308,7 +333,9 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
@@ -328,7 +355,9 @@ struct DefaultConv2dDgrad <
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
StrideSupport::kUnity
|
||||
StrideSupport::kUnity,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@@ -339,24 +368,28 @@ struct DefaultConv2dDgrad <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA,
|
||||
StrideSupport::kUnity
|
||||
StrideSupport::kUnity,
|
||||
AccessTypeA
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB,
|
||||
StrideSupport::kUnity
|
||||
StrideSupport::kUnity,
|
||||
AccessTypeB
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
@@ -365,6 +398,11 @@ struct DefaultConv2dDgrad <
|
||||
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::ImplicitGemmMultistage<
|
||||
ThreadblockShape,
|
||||
@@ -373,7 +411,7 @@ struct DefaultConv2dDgrad <
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
arch::CacheOperation::Global,
|
||||
CacheOpB,
|
||||
MmaPolicy,
|
||||
Stages
|
||||
>;
|
||||
@@ -414,7 +452,9 @@ template <
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
@@ -434,7 +474,9 @@ struct DefaultConv2dDgrad <
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
StrideSupport::kUnity
|
||||
StrideSupport::kUnity,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@@ -445,13 +487,15 @@ struct DefaultConv2dDgrad <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA,
|
||||
StrideSupport::kUnity
|
||||
StrideSupport::kUnity,
|
||||
AccessTypeA
|
||||
>
|
||||
>;
|
||||
|
||||
@@ -459,13 +503,15 @@ struct DefaultConv2dDgrad <
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB,
|
||||
StrideSupport::kUnity
|
||||
StrideSupport::kUnity,
|
||||
AccessTypeB
|
||||
>
|
||||
>;
|
||||
|
||||
@@ -526,7 +572,9 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
@@ -546,7 +594,9 @@ struct DefaultConv2dDgrad <
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport::kUnity
|
||||
StrideSupport::kUnity,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@@ -557,23 +607,28 @@ struct DefaultConv2dDgrad <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA,
|
||||
StrideSupport::kUnity
|
||||
StrideSupport::kUnity,
|
||||
AccessTypeA
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB
|
||||
ThreadMapB,
|
||||
StrideSupport::kUnity,
|
||||
AccessTypeB
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
@@ -582,6 +637,11 @@ struct DefaultConv2dDgrad <
|
||||
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::ImplicitGemmMultistage<
|
||||
ThreadblockShape,
|
||||
@@ -590,7 +650,7 @@ struct DefaultConv2dDgrad <
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
arch::CacheOperation::Global,
|
||||
CacheOpB,
|
||||
MmaPolicy,
|
||||
Stages
|
||||
>;
|
||||
@@ -631,7 +691,9 @@ template <
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
@@ -651,7 +713,9 @@ struct DefaultConv2dDgrad <
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport::kUnity
|
||||
StrideSupport::kUnity,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@@ -662,13 +726,15 @@ struct DefaultConv2dDgrad <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA,
|
||||
StrideSupport::kUnity
|
||||
StrideSupport::kUnity,
|
||||
AccessTypeA
|
||||
>
|
||||
>;
|
||||
|
||||
@@ -676,12 +742,15 @@ struct DefaultConv2dDgrad <
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB
|
||||
ThreadMapB,
|
||||
StrideSupport::kUnity,
|
||||
AccessTypeB
|
||||
>
|
||||
>;
|
||||
|
||||
@@ -744,7 +813,10 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag>
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
@@ -763,7 +835,9 @@ struct DefaultConv2dDgrad <
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
conv::StrideSupport::kUnity
|
||||
conv::StrideSupport::kUnity,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@@ -848,7 +922,10 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag>
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
@@ -867,7 +944,9 @@ struct DefaultConv2dDgrad <
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
conv::StrideSupport::kStrided
|
||||
conv::StrideSupport::kStrided,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@@ -955,7 +1034,9 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
@@ -975,7 +1056,9 @@ struct DefaultConv2dDgrad <
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport::kUnity
|
||||
StrideSupport::kUnity,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@@ -1040,10 +1123,8 @@ struct DefaultConv2dDgrad <
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kDgrad
|
||||
>;
|
||||
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -1063,7 +1144,9 @@ template <
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
@@ -1083,7 +1166,9 @@ struct DefaultConv2dDgrad <
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
conv::StrideSupport::kUnity
|
||||
conv::StrideSupport::kUnity,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@@ -1169,7 +1254,9 @@ template <
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
@@ -1189,7 +1276,9 @@ struct DefaultConv2dDgrad <
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
conv::StrideSupport::kStrided
|
||||
conv::StrideSupport::kStrided,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@@ -1257,7 +1346,6 @@ struct DefaultConv2dDgrad <
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kDgrad
|
||||
>;
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -1278,7 +1366,9 @@ template <
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
@@ -1298,7 +1388,9 @@ struct DefaultConv2dDgrad <
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport::kUnity
|
||||
StrideSupport::kUnity,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@@ -1368,10 +1460,10 @@ struct DefaultConv2dDgrad <
|
||||
>;
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
} // namespace kernel
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@@ -66,11 +66,11 @@ template <
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided
|
||||
int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value
|
||||
> struct DefaultConv2dFprop;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -95,6 +95,7 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
@@ -116,6 +117,7 @@ struct DefaultConv2dFprop <
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
@@ -128,22 +130,26 @@ struct DefaultConv2dFprop <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA, LayoutA,
|
||||
ThreadMapA
|
||||
ThreadMapA,
|
||||
AccessTypeA
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB
|
||||
ThreadMapB,
|
||||
AccessTypeB
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
@@ -152,6 +158,11 @@ struct DefaultConv2dFprop <
|
||||
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::ImplicitGemmMultistage<
|
||||
ThreadblockShape,
|
||||
@@ -160,7 +171,7 @@ struct DefaultConv2dFprop <
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
arch::CacheOperation::Global,
|
||||
CacheOpB,
|
||||
MmaPolicy,
|
||||
Stages
|
||||
>;
|
||||
@@ -203,9 +214,10 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
int InterleavedK,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
int AlignmentB,
|
||||
int InterleavedK
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
@@ -225,6 +237,7 @@ struct DefaultConv2dFprop <
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
@@ -325,6 +338,7 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
@@ -346,6 +360,7 @@ struct DefaultConv2dFprop <
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
@@ -358,12 +373,14 @@ struct DefaultConv2dFprop <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA, LayoutA,
|
||||
ThreadMapA
|
||||
ThreadMapA,
|
||||
AccessTypeA
|
||||
>
|
||||
>;
|
||||
|
||||
@@ -371,12 +388,14 @@ struct DefaultConv2dFprop <
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB
|
||||
ThreadMapB,
|
||||
AccessTypeB
|
||||
>
|
||||
>;
|
||||
|
||||
@@ -435,9 +454,10 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag,
|
||||
int InterleavedK,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
int AlignmentB,
|
||||
int InterleavedK
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
@@ -457,6 +477,7 @@ struct DefaultConv2dFprop <
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
@@ -561,6 +582,7 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
@@ -582,6 +604,7 @@ struct DefaultConv2dFprop <
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
@@ -595,26 +618,28 @@ struct DefaultConv2dFprop <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ThreadMapA,
|
||||
AlignmentA
|
||||
AccessTypeA
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ThreadMapB,
|
||||
AlignmentB
|
||||
AccessTypeB
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
@@ -624,7 +649,7 @@ struct DefaultConv2dFprop <
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<ElementA>::value * AlignmentB) == 128)
|
||||
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
@@ -679,9 +704,10 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
int InterleavedK,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
int AlignmentB,
|
||||
int InterleavedK
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
@@ -701,6 +727,7 @@ struct DefaultConv2dFprop <
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
@@ -774,6 +801,8 @@ struct DefaultConv2dFprop <
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm
|
||||
/// and 2 stage pipeline.
|
||||
template <
|
||||
@@ -791,6 +820,7 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
@@ -812,6 +842,7 @@ struct DefaultConv2dFprop <
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
@@ -824,6 +855,7 @@ struct DefaultConv2dFprop <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
||||
@@ -831,7 +863,7 @@ struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ThreadMapA,
|
||||
AlignmentA
|
||||
AccessTypeA
|
||||
>
|
||||
>;
|
||||
|
||||
@@ -839,6 +871,7 @@ struct DefaultConv2dFprop <
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
@@ -846,7 +879,7 @@ struct DefaultConv2dFprop <
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ThreadMapB,
|
||||
AlignmentB
|
||||
AccessTypeB
|
||||
>
|
||||
>;
|
||||
|
||||
@@ -905,9 +938,10 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag,
|
||||
int InterleavedK,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
int AlignmentB,
|
||||
int InterleavedK
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
@@ -927,6 +961,7 @@ struct DefaultConv2dFprop <
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
@@ -1023,6 +1058,7 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
@@ -1044,6 +1080,7 @@ struct DefaultConv2dFprop <
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
@@ -1132,6 +1169,7 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
@@ -1153,6 +1191,7 @@ struct DefaultConv2dFprop <
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
@@ -1241,6 +1280,7 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
@@ -1262,6 +1302,7 @@ struct DefaultConv2dFprop <
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
@@ -1351,6 +1392,7 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
@@ -1372,6 +1414,7 @@ struct DefaultConv2dFprop <
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
@@ -65,7 +65,11 @@ template <
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value
|
||||
>
|
||||
struct DefaultConv2dFpropWithBroadcast {
|
||||
|
||||
@@ -84,7 +88,9 @@ struct DefaultConv2dFpropWithBroadcast {
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm,
|
||||
StrideSupport
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
>::Kernel;
|
||||
|
||||
// Replace epilogue
|
||||
|
||||
@@ -66,7 +66,11 @@ template <
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value
|
||||
>
|
||||
struct DefaultConv2dFpropWithReduction {
|
||||
|
||||
@@ -85,7 +89,9 @@ struct DefaultConv2dFpropWithReduction {
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm,
|
||||
StrideSupport
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
>::Kernel;
|
||||
|
||||
// Replace epilogue
|
||||
|
||||
@@ -67,8 +67,13 @@ template <
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value
|
||||
> struct DefaultConv2dWgrad;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -93,7 +98,10 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dWgrad <
|
||||
ElementA,
|
||||
@@ -112,7 +120,10 @@ struct DefaultConv2dWgrad <
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@@ -123,22 +134,26 @@ struct DefaultConv2dWgrad <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA
|
||||
ThreadMapA,
|
||||
AccessTypeA
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB
|
||||
ThreadMapB,
|
||||
AccessTypeB
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
@@ -179,6 +194,7 @@ struct DefaultConv2dWgrad <
|
||||
conv::Operator::kWgrad
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm and two
|
||||
@@ -198,7 +214,10 @@ template <
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dWgrad <
|
||||
ElementA,
|
||||
@@ -217,7 +236,10 @@ struct DefaultConv2dWgrad <
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@@ -228,12 +250,14 @@ struct DefaultConv2dWgrad <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA
|
||||
ThreadMapA,
|
||||
AccessTypeA
|
||||
>
|
||||
>;
|
||||
|
||||
@@ -241,12 +265,14 @@ struct DefaultConv2dWgrad <
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB
|
||||
ThreadMapB,
|
||||
AccessTypeB
|
||||
>
|
||||
>;
|
||||
|
||||
@@ -308,7 +334,10 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dWgrad <
|
||||
ElementA,
|
||||
@@ -327,7 +356,10 @@ struct DefaultConv2dWgrad <
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@@ -338,22 +370,26 @@ struct DefaultConv2dWgrad <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA
|
||||
ThreadMapA,
|
||||
AccessTypeA
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB
|
||||
ThreadMapB,
|
||||
AccessTypeB
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
@@ -394,6 +430,7 @@ struct DefaultConv2dWgrad <
|
||||
conv::Operator::kWgrad
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm and two
|
||||
@@ -413,7 +450,10 @@ template <
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dWgrad <
|
||||
ElementA,
|
||||
@@ -432,7 +472,10 @@ struct DefaultConv2dWgrad <
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@@ -443,12 +486,14 @@ struct DefaultConv2dWgrad <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA
|
||||
ThreadMapA,
|
||||
AccessTypeA
|
||||
>
|
||||
>;
|
||||
|
||||
@@ -456,12 +501,14 @@ struct DefaultConv2dWgrad <
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB
|
||||
ThreadMapB,
|
||||
AccessTypeB
|
||||
>
|
||||
>;
|
||||
|
||||
@@ -524,7 +571,10 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AccessTypeA,
|
||||
int AccessTypeB
|
||||
>
|
||||
struct DefaultConv2dWgrad <
|
||||
ElementA,
|
||||
@@ -543,7 +593,10 @@ struct DefaultConv2dWgrad <
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
StrideSupport,
|
||||
AccessTypeA,
|
||||
AccessTypeB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@@ -629,7 +682,10 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AccessTypeA,
|
||||
int AccessTypeB
|
||||
>
|
||||
struct DefaultConv2dWgrad <
|
||||
ElementA,
|
||||
@@ -648,7 +704,10 @@ struct DefaultConv2dWgrad <
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport,
|
||||
AccessTypeA,
|
||||
AccessTypeB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@@ -732,7 +791,10 @@ template <
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AccessTypeA,
|
||||
int AccessTypeB
|
||||
>
|
||||
struct DefaultConv2dWgrad <
|
||||
ElementA,
|
||||
@@ -751,7 +813,10 @@ struct DefaultConv2dWgrad <
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
StrideSupport,
|
||||
AccessTypeA,
|
||||
AccessTypeB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@@ -817,7 +882,6 @@ struct DefaultConv2dWgrad <
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kWgrad
|
||||
>;
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -838,7 +902,10 @@ template <
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AccessTypeA,
|
||||
int AccessTypeB
|
||||
>
|
||||
struct DefaultConv2dWgrad <
|
||||
ElementA,
|
||||
@@ -857,7 +924,10 @@ struct DefaultConv2dWgrad <
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport,
|
||||
AccessTypeA,
|
||||
AccessTypeB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@@ -925,12 +995,11 @@ struct DefaultConv2dWgrad <
|
||||
>;
|
||||
|
||||
};
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@@ -228,7 +228,6 @@ struct DefaultConv3dDgrad <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
|
||||
@@ -501,4 +501,3 @@ struct DefaultConv3dWgrad <
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@@ -59,7 +59,8 @@ template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_,
|
||||
conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity
|
||||
conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity,
|
||||
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
|
||||
>
|
||||
class Conv2dDgradFilterTileAccessIteratorAnalytic;
|
||||
|
||||
@@ -70,13 +71,15 @@ class Conv2dDgradFilterTileAccessIteratorAnalytic;
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_
|
||||
typename ThreadMap_,
|
||||
typename AccessType_
|
||||
>
|
||||
class Conv2dDgradFilterTileAccessIteratorAnalytic <
|
||||
Shape_,
|
||||
Element_,
|
||||
ThreadMap_,
|
||||
conv::StrideSupport::kStrided
|
||||
conv::StrideSupport::kStrided,
|
||||
AccessType_
|
||||
> {
|
||||
public:
|
||||
|
||||
@@ -88,7 +91,7 @@ public:
|
||||
using Element = Element_;
|
||||
using Layout = layout::TensorNHWC;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
@@ -97,7 +100,12 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"DGRAD requires elements of size 8b or larger.");
|
||||
|
||||
@@ -107,14 +115,13 @@ public:
|
||||
|
||||
using Params = Conv2dAnalyticParams<Layout>;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
private:
|
||||
|
||||
Params const &params_;
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
char const *pointer_;
|
||||
|
||||
// For a fixed filter position (r,s) find and fill offset_k_, offset_c_ in strided and contiguous dimension
|
||||
@@ -162,8 +169,10 @@ public:
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
@@ -213,7 +222,7 @@ public:
|
||||
|
||||
TensorCoord coord = at();
|
||||
|
||||
return coord.n() < problem_size_.K && coord.c() < problem_size_.C;
|
||||
return coord.n() < problem_size_.K && (coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.C;
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
@@ -223,13 +232,19 @@ public:
|
||||
TensorCoord coord = at();
|
||||
LongIndex offset = params_.layout(coord);
|
||||
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8) + iteration_vector_;
|
||||
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dDgradFilterTileAccessIteratorAnalytic &operator++() {
|
||||
++iteration_vector_;
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
@@ -249,7 +264,7 @@ public:
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
|
||||
if (problem_size.C % AccessType::kElements) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
@@ -263,13 +278,15 @@ public:
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_
|
||||
typename ThreadMap_,
|
||||
typename AccessType_
|
||||
>
|
||||
class Conv2dDgradFilterTileAccessIteratorAnalytic <
|
||||
Shape_,
|
||||
Element_,
|
||||
ThreadMap_,
|
||||
conv::StrideSupport::kUnity
|
||||
conv::StrideSupport::kUnity,
|
||||
AccessType_
|
||||
>{
|
||||
public:
|
||||
|
||||
@@ -281,7 +298,7 @@ public:
|
||||
using Element = Element_;
|
||||
using Layout = layout::TensorNHWC;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
@@ -290,7 +307,12 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"DGRAD requires elements of size 8b or larger.");
|
||||
|
||||
@@ -306,6 +328,7 @@ private:
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
char const *pointer_;
|
||||
|
||||
// For a fixed filter position (r,s) find and fill offset_k_, offset_c_ in strided and contiguous dimension
|
||||
@@ -348,8 +371,10 @@ public:
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
@@ -395,7 +420,7 @@ public:
|
||||
|
||||
TensorCoord coord = at();
|
||||
|
||||
return coord.n() < problem_size_.K && coord.c() < problem_size_.C;
|
||||
return coord.n() < problem_size_.K && (coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.C;
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
@@ -405,13 +430,18 @@ public:
|
||||
TensorCoord coord = at();
|
||||
LongIndex offset = params_.layout(coord);
|
||||
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
|
||||
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8) + iteration_vector_;
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dDgradFilterTileAccessIteratorAnalytic &operator++() {
|
||||
++iteration_vector_;
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
@@ -431,7 +461,7 @@ public:
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
|
||||
if (problem_size.C % AccessType::kElements) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
@@ -446,5 +476,3 @@ public:
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
|
||||
@@ -60,7 +60,8 @@ template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_,
|
||||
conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity
|
||||
conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity,
|
||||
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
|
||||
>
|
||||
class Conv2dDgradFilterTileAccessIteratorOptimized;
|
||||
|
||||
@@ -71,13 +72,15 @@ class Conv2dDgradFilterTileAccessIteratorOptimized;
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_
|
||||
typename ThreadMap_,
|
||||
typename AccessType_
|
||||
>
|
||||
class Conv2dDgradFilterTileAccessIteratorOptimized <
|
||||
Shape_,
|
||||
Element_,
|
||||
ThreadMap_,
|
||||
conv::StrideSupport::kUnity
|
||||
conv::StrideSupport::kUnity,
|
||||
AccessType_
|
||||
> {
|
||||
public:
|
||||
|
||||
@@ -89,7 +92,7 @@ public:
|
||||
using Element = Element_;
|
||||
using Layout = layout::TensorNHWC;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
@@ -98,9 +101,12 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
//
|
||||
@@ -141,9 +147,10 @@ private:
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
char const *pointer_;
|
||||
|
||||
uint32_t predicates_;
|
||||
uint32_t predicates_[kAccessesPerVector];
|
||||
int filter_rs_;
|
||||
int filter_k_;
|
||||
|
||||
@@ -169,7 +176,7 @@ public:
|
||||
params_(params),
|
||||
problem_size_(problem_size),
|
||||
pointer_(reinterpret_cast<char const *>(ptr)),
|
||||
predicates_(0),
|
||||
predicates_{0},
|
||||
filter_rs_(0),
|
||||
filter_k_(0) {
|
||||
|
||||
@@ -186,11 +193,15 @@ public:
|
||||
int filter_k = filter_k_ + s * ThreadMap::Delta::kStrided;
|
||||
int filter_c = column + c * ThreadMap::Delta::kContiguous;
|
||||
|
||||
uint32_t pred = ((filter_k < problem_size_.K && filter_c < problem_size_.C) ? 1u : 0);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < kAccessesPerVector; ++v) {
|
||||
|
||||
int pred_idx = c + s * ThreadMap::Iterations::kContiguous;
|
||||
|
||||
predicates_ |= (pred << pred_idx);
|
||||
uint32_t pred = ((filter_k < problem_size_.K && (filter_c + v * AccessType::kElements) < problem_size_.C) ? 1u : 0);
|
||||
|
||||
int pred_idx = c + s * ThreadMap::Iterations::kContiguous;
|
||||
|
||||
predicates_[v] |= (pred << pred_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -204,8 +215,10 @@ public:
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
@@ -234,7 +247,11 @@ public:
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
if (filter_k_ + s * ThreadMap::Delta::kStrided >= problem_size_.K) {
|
||||
uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous);
|
||||
predicates_ = (predicates_ & (~kClearMask));
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < kAccessesPerVector; ++v) {
|
||||
predicates_[v] = (predicates_[v] & (~kClearMask));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -245,19 +262,25 @@ public:
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool valid() {
|
||||
LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous;
|
||||
return (predicates_ & (1u << pred_idx));
|
||||
return (predicates_[iteration_vector_] & (1u << pred_idx));
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
AccessType const *get() const {
|
||||
return reinterpret_cast<AccessType const *>(pointer_ +
|
||||
iteration_contiguous_ * ThreadMap::Delta::kContiguous * sizeof_bits<Element>::value / 8);
|
||||
iteration_contiguous_ * ThreadMap::Delta::kContiguous * sizeof_bits<Element>::value / 8) + iteration_vector_;
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dDgradFilterTileAccessIteratorOptimized &operator++() {
|
||||
++iteration_vector_;
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
@@ -282,7 +305,7 @@ public:
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
|
||||
if (problem_size.C % AccessType::kElements) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
@@ -297,5 +320,3 @@ public:
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
|
||||
@@ -59,7 +59,8 @@ template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_,
|
||||
conv::StrideSupport StrideSupport_ = conv::StrideSupport::kStrided
|
||||
conv::StrideSupport StrideSupport_ = conv::StrideSupport::kStrided,
|
||||
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
|
||||
>
|
||||
class Conv2dDgradOutputGradientTileAccessIteratorAnalytic;
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -69,13 +70,15 @@ class Conv2dDgradOutputGradientTileAccessIteratorAnalytic;
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_
|
||||
typename ThreadMap_,
|
||||
typename AccessType_
|
||||
>
|
||||
class Conv2dDgradOutputGradientTileAccessIteratorAnalytic <
|
||||
Shape_,
|
||||
Element_,
|
||||
ThreadMap_,
|
||||
conv::StrideSupport::kStrided
|
||||
conv::StrideSupport::kStrided,
|
||||
AccessType_
|
||||
> {
|
||||
public:
|
||||
|
||||
@@ -86,7 +89,7 @@ public:
|
||||
using Element = Element_;
|
||||
using Layout = layout::TensorNHWC;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
@@ -95,7 +98,12 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"DGRAD requires elements of size 8b or greater.");
|
||||
|
||||
@@ -112,14 +120,13 @@ public:
|
||||
|
||||
using Params = Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
private:
|
||||
|
||||
Params const &params_;
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
char const *pointer_;
|
||||
|
||||
int filter_k_;
|
||||
@@ -211,8 +218,10 @@ public:
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
@@ -277,7 +286,7 @@ public:
|
||||
coord.n() < problem_size_.N &&
|
||||
coord.h() >= 0 && coord.h() < problem_size_.P &&
|
||||
coord.w() >= 0 && coord.w() < problem_size_.Q &&
|
||||
coord.c() < problem_size_.K;
|
||||
(coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.K;
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
@@ -287,12 +296,18 @@ public:
|
||||
TensorCoord coord = at();
|
||||
LongIndex offset = params_.layout(coord);
|
||||
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8) + iteration_vector_;
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dDgradOutputGradientTileAccessIteratorAnalytic &operator++() {
|
||||
++iteration_vector_;
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
@@ -312,14 +327,14 @@ public:
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.K % (128/sizeof_bits<Element>::value)) {
|
||||
if (problem_size.K % AccessType::kElements) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Conv2dDgradOutputGradientTileAccessIteratorAnalytic for unity strides can be optimized by
|
||||
@@ -327,13 +342,15 @@ public:
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_
|
||||
typename ThreadMap_,
|
||||
typename AccessType_
|
||||
>
|
||||
class Conv2dDgradOutputGradientTileAccessIteratorAnalytic <
|
||||
Shape_,
|
||||
Element_,
|
||||
ThreadMap_,
|
||||
conv::StrideSupport::kUnity
|
||||
conv::StrideSupport::kUnity,
|
||||
AccessType_
|
||||
> {
|
||||
public:
|
||||
|
||||
@@ -344,7 +361,7 @@ public:
|
||||
using Element = Element_;
|
||||
using Layout = layout::TensorNHWC;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
@@ -353,7 +370,12 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"DGRAD requires elements of size 8b or greater.");
|
||||
|
||||
@@ -368,8 +390,6 @@ public:
|
||||
// Parameters structure
|
||||
//
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
struct Params {
|
||||
|
||||
Layout layout;
|
||||
@@ -395,6 +415,7 @@ private:
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
char const *pointer_;
|
||||
|
||||
int filter_k_;
|
||||
@@ -446,8 +467,10 @@ public:
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
@@ -497,7 +520,6 @@ public:
|
||||
|
||||
}
|
||||
|
||||
|
||||
/// Returns true if the current coordinate is within the output tensor Dy
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool valid() const {
|
||||
@@ -507,7 +529,7 @@ public:
|
||||
return coord.n() < problem_size_.N &&
|
||||
coord.h() >= 0 && coord.h() < problem_size_.P &&
|
||||
coord.w() >= 0 && coord.w() < problem_size_.Q &&
|
||||
coord.c() < problem_size_.K;
|
||||
(coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.K;
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
@@ -517,12 +539,18 @@ public:
|
||||
TensorCoord coord = at();
|
||||
LongIndex offset = params_.layout(coord);
|
||||
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8) + iteration_vector_;
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dDgradOutputGradientTileAccessIteratorAnalytic &operator++() {
|
||||
++iteration_vector_;
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
@@ -548,7 +576,7 @@ public:
|
||||
}
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.K % (128/sizeof_bits<Element>::value)) {
|
||||
if (problem_size.K % AccessType::kElements) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
@@ -556,7 +584,9 @@ public:
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
@@ -61,7 +61,8 @@ template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_,
|
||||
conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity
|
||||
conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity,
|
||||
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
|
||||
>
|
||||
class Conv2dDgradOutputGradientTileAccessIteratorOptimized;
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -74,14 +75,16 @@ class Conv2dDgradOutputGradientTileAccessIteratorOptimized;
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_
|
||||
typename ThreadMap_,
|
||||
typename AccessType_
|
||||
>
|
||||
class Conv2dDgradOutputGradientTileAccessIteratorOptimized <
|
||||
Shape_,
|
||||
Element_,
|
||||
ThreadMap_,
|
||||
conv::StrideSupport::kUnity
|
||||
> {
|
||||
conv::StrideSupport::kUnity,
|
||||
AccessType_
|
||||
> {
|
||||
public:
|
||||
|
||||
//
|
||||
@@ -93,7 +96,7 @@ public:
|
||||
using Layout = layout::TensorNHWC;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
@@ -101,7 +104,12 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
using Mask = uint64_t;
|
||||
|
||||
//
|
||||
@@ -116,14 +124,13 @@ public:
|
||||
|
||||
using Params = Conv2dDgradOutputGradientIteratorOptimizedParams;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
private:
|
||||
|
||||
Conv2dDgradOutputGradientIteratorOptimizedParams const &params_;
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
|
||||
// One pointer per access
|
||||
char const *pointer_[ThreadMap::Iterations::kStrided];
|
||||
@@ -133,7 +140,7 @@ private:
|
||||
int filter_s_;
|
||||
int filter_k_;
|
||||
|
||||
Index masks_[ThreadMap::Iterations::kStrided][2];
|
||||
Index masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2];
|
||||
|
||||
public:
|
||||
|
||||
@@ -201,7 +208,11 @@ public:
|
||||
int p = offset_h[s_idx] + problem_size_.pad_h - r_ * problem_size_.dilation_h;
|
||||
|
||||
bool pred = (offset_n[s_idx] < problem_size_.N && p >= 0 && p < problem_size_.P);
|
||||
masks_[s_idx][0] |= (pred << r);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
masks_[s_idx][v_idx][0] |= (pred << r);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -218,12 +229,17 @@ public:
|
||||
int q = offset_w[s_idx] + problem_size_.pad_w - s_ * problem_size_.dilation_w;
|
||||
|
||||
bool pred = (q >= 0 && q < problem_size_.Q);
|
||||
masks_[s_idx][1] |= (pred << s);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
masks_[s_idx][v_idx][1] |= (pred << s);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (filter_k_ >= problem_size.K) {
|
||||
clear_mask();
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
clear_mask(v_idx, filter_k_ >= problem_size.K);
|
||||
}
|
||||
|
||||
set_iteration_index(0);
|
||||
@@ -269,62 +285,15 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
/// Clears the predicates
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask_(bool clear) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
|
||||
// We are using inline PTX assembly here to avoid an CUDA C++ compilation
|
||||
// artifact in which control flow instructions are generated. Instead, our
|
||||
// intent is to predicate the mov instructions.
|
||||
#if defined(__CUDA_ARCH__)
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" .reg .u32 m;"
|
||||
" mov.u32 m, %2;"
|
||||
" setp.ne.b32 p, %1, 0;\n"
|
||||
" @p mov.u32 m, 0;\n"
|
||||
" mov.u32 %0, m;\n"
|
||||
"}\n"
|
||||
:
|
||||
"=r"(masks_[s][0])
|
||||
:
|
||||
"r"((int)clear),
|
||||
"r"(masks_[s][0])
|
||||
);
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" .reg .u32 m;"
|
||||
" mov.u32 m, %2;"
|
||||
" setp.ne.b32 p, %1, 0;\n"
|
||||
" @p mov.u32 m, 0;\n"
|
||||
" mov.u32 %0, m;\n"
|
||||
"}\n"
|
||||
:
|
||||
"=r"(masks_[s][1])
|
||||
:
|
||||
"r"((int)clear),
|
||||
"r"(masks_[s][1])
|
||||
);
|
||||
#else
|
||||
if (clear) {
|
||||
masks_[s][0] = 0;
|
||||
masks_[s][1] = 0;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of element
|
||||
@@ -359,16 +328,32 @@ public:
|
||||
filter_k_ += params_.filter_k_delta;
|
||||
}
|
||||
|
||||
clear_mask_(filter_k_ >= problem_size_.K);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
clear_mask(v_idx, (filter_k_ + v_idx * AccessType::kElements) >= problem_size_.K);
|
||||
}
|
||||
}
|
||||
|
||||
/// Clears the predicates
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() {
|
||||
void clear_mask(bool clear = true) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
masks_[s][0] = Mask(0);
|
||||
masks_[s][1] = Mask(0);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < kAccessesPerVector; ++v) {
|
||||
masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0];
|
||||
masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Clears the predicates
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask(int v, bool clear = true) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0];
|
||||
masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -376,20 +361,25 @@ public:
|
||||
bool valid() {
|
||||
|
||||
return
|
||||
(masks_[iteration_strided_][0] & (Index(1) << filter_r_)) &&
|
||||
(masks_[iteration_strided_][1] & (Index(1) << filter_s_));
|
||||
(masks_[iteration_strided_][iteration_vector_][0] & (Index(1) << filter_r_)) &&
|
||||
(masks_[iteration_strided_][iteration_vector_][1] & (Index(1) << filter_s_));
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
AccessType const *get() const {
|
||||
|
||||
return reinterpret_cast<AccessType const *>(pointer_[iteration_strided_]);
|
||||
return reinterpret_cast<AccessType const *>(pointer_[iteration_strided_]) + iteration_vector_;
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dDgradOutputGradientTileAccessIteratorOptimized &operator++() {
|
||||
++iteration_vector_;
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
@@ -416,7 +406,7 @@ public:
|
||||
}
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.K % (128/sizeof_bits<Element>::value)) {
|
||||
if (problem_size.K % AccessType::kElements) {
|
||||
return Status::kErrorNotSupported;
|
||||
}
|
||||
|
||||
|
||||
@@ -60,7 +60,8 @@ template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename Layout_,
|
||||
typename ThreadMap_
|
||||
typename ThreadMap_,
|
||||
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
|
||||
>
|
||||
class Conv2dFpropActivationTileAccessIteratorAnalytic {
|
||||
public:
|
||||
@@ -74,7 +75,7 @@ public:
|
||||
using Layout = Layout_;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
@@ -82,7 +83,12 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
//
|
||||
// Simplifying assertions
|
||||
//
|
||||
@@ -95,14 +101,13 @@ public:
|
||||
|
||||
using Params = Conv2dAnalyticParams<Layout>;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
private:
|
||||
|
||||
Params const &params_;
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
char const *pointer_;
|
||||
|
||||
int filter_c_;
|
||||
@@ -156,8 +161,10 @@ public:
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
@@ -214,7 +221,7 @@ public:
|
||||
return coord.n() < problem_size_.N &&
|
||||
coord.h() >= 0 && coord.h() < problem_size_.H &&
|
||||
coord.w() >= 0 && coord.w() < problem_size_.W &&
|
||||
coord.c() < problem_size_.C;
|
||||
(coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.C;
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
@@ -224,7 +231,7 @@ public:
|
||||
TensorCoord coord = at();
|
||||
LongIndex offset = params_.layout(coord);
|
||||
|
||||
AccessType const *ptr = reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
|
||||
AccessType const *ptr = reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8) + iteration_vector_;
|
||||
|
||||
return ptr;
|
||||
}
|
||||
@@ -232,6 +239,12 @@ public:
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dFpropActivationTileAccessIteratorAnalytic &operator++() {
|
||||
++iteration_vector_;
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
@@ -252,7 +265,7 @@ public:
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
|
||||
if (problem_size.C % AccessType::kElements) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ template <
|
||||
typename Element_,
|
||||
typename Layout_,
|
||||
typename ThreadMap_,
|
||||
int AccessSize = ThreadMap_::kElementsPerAccess
|
||||
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
|
||||
>
|
||||
class Conv2dFpropActivationTileAccessIteratorOptimized {
|
||||
public:
|
||||
@@ -75,7 +75,7 @@ public:
|
||||
using Layout = Layout_;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, AccessSize>;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
@@ -86,6 +86,11 @@ public:
|
||||
|
||||
using Mask = uint64_t;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
//
|
||||
// Simplifying assertions
|
||||
//
|
||||
@@ -98,8 +103,6 @@ public:
|
||||
|
||||
using Params = Conv2dFpropActivationIteratorOptimizedParams<Layout>;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
private:
|
||||
|
||||
Params const &params_;
|
||||
@@ -213,10 +216,10 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
clear_mask_(filter_c_ + v_idx * AccessSize >= problem_size_.C, v_idx);
|
||||
}
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C);
|
||||
}
|
||||
|
||||
set_iteration_index(0);
|
||||
}
|
||||
@@ -260,56 +263,7 @@ private:
|
||||
pointer_[s] += byte_offset;
|
||||
}
|
||||
}
|
||||
|
||||
/// Clears the predicates
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask_(bool clear, int index) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
|
||||
// We are using inline PTX assembly here to avoid an CUDA C++ compilation
|
||||
// artifact in which control flow instructions are generated. Instead, our
|
||||
// intent is to predicate the mov instructions.
|
||||
#if defined(__CUDA_ARCH__)
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" .reg .u32 m;"
|
||||
" mov.u32 m, %2;"
|
||||
" setp.ne.b32 p, %1, 0;\n"
|
||||
" @p mov.u32 m, 0;\n"
|
||||
" mov.u32 %0, m;\n"
|
||||
"}\n"
|
||||
:
|
||||
"=r"(masks_[s][index][0])
|
||||
:
|
||||
"r"((int)clear),
|
||||
"r"(masks_[s][index][0])
|
||||
);
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" .reg .u32 m;"
|
||||
" mov.u32 m, %2;"
|
||||
" setp.ne.b32 p, %1, 0;\n"
|
||||
" @p mov.u32 m, 0;\n"
|
||||
" mov.u32 %0, m;\n"
|
||||
"}\n"
|
||||
:
|
||||
"=r"(masks_[s][index][1])
|
||||
:
|
||||
"r"((int)clear),
|
||||
"r"(masks_[s][index][1])
|
||||
);
|
||||
#else
|
||||
if (clear) {
|
||||
masks_[s][index][0] = 0;
|
||||
masks_[s][index][1] = 0;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
/// Overrides the internal iteration index
|
||||
@@ -354,23 +308,33 @@ public:
|
||||
filter_c_ += params_.filter_c_delta;
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
clear_mask_(filter_c_ + v_idx * AccessSize >= problem_size_.C, v_idx);
|
||||
}
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Clears the predicates
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() {
|
||||
void clear_mask(bool clear = true) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < kAccessesPerVector; ++v) {
|
||||
masks_[s][v][0] = Mask(0);
|
||||
masks_[s][v][1] = Mask(0);
|
||||
masks_[s][v][0] = clear ? 0 : masks_[s][v][0];
|
||||
masks_[s][v][1] = clear ? 0 : masks_[s][v][1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Clears the predicates
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask(int v, bool clear = true) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
masks_[s][v][0] = clear ? 0 : masks_[s][v][0];
|
||||
masks_[s][v][1] = clear ? 0 : masks_[s][v][1];
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
@@ -396,7 +360,6 @@ public:
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
@@ -419,7 +382,7 @@ public:
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.C % AccessSize) {
|
||||
if (problem_size.C % AccessType::kElements) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
|
||||
@@ -59,7 +59,8 @@ template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename Layout_,
|
||||
typename ThreadMap_
|
||||
typename ThreadMap_,
|
||||
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
|
||||
>
|
||||
class Conv2dFpropFilterTileAccessIteratorAnalytic {
|
||||
public:
|
||||
@@ -72,7 +73,7 @@ public:
|
||||
using Element = Element_;
|
||||
using Layout = Layout_;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
@@ -81,7 +82,12 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
//
|
||||
// Simplifying assertions
|
||||
//
|
||||
@@ -94,14 +100,13 @@ public:
|
||||
|
||||
using Params = Conv2dAnalyticParams<Layout>;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
private:
|
||||
|
||||
Params const &params_;
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
char const *pointer_;
|
||||
|
||||
int filter_r_;
|
||||
@@ -142,8 +147,10 @@ public:
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
@@ -187,7 +194,7 @@ public:
|
||||
TensorCoord coord = at();
|
||||
|
||||
return coord.n() < problem_size_.K &&
|
||||
coord.c() < problem_size_.C;
|
||||
(coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.C;
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
@@ -197,12 +204,18 @@ public:
|
||||
TensorCoord coord = at();
|
||||
LongIndex offset = params_.layout(coord);
|
||||
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8) + iteration_vector_;
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dFpropFilterTileAccessIteratorAnalytic &operator++() {
|
||||
++iteration_vector_;
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
@@ -223,7 +236,7 @@ public:
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.K % (128/sizeof_bits<Element>::value)) {
|
||||
if (problem_size.K % AccessType::kElements) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
@@ -250,5 +263,3 @@ public:
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ template <
|
||||
typename Element_,
|
||||
typename Layout_,
|
||||
typename ThreadMap_,
|
||||
int AccessSize = ThreadMap_::kElementsPerAccess
|
||||
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
|
||||
>
|
||||
class Conv2dFpropFilterTileAccessIteratorOptimized{
|
||||
public:
|
||||
@@ -74,7 +74,7 @@ public:
|
||||
using Element = Element_;
|
||||
using Layout = Layout_;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, AccessSize>;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
@@ -83,15 +83,18 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
//
|
||||
// Simplifying assertions
|
||||
//
|
||||
static_assert(ThreadMap::Iterations::kContiguous == 1,
|
||||
"Require Iterations::kContiguous == 1");
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
//
|
||||
@@ -170,6 +173,7 @@ public:
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < problem_size_.K) ? 1u : 0);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
predicates_[v_idx] |= (pred << s);
|
||||
@@ -178,7 +182,7 @@ public:
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
clear_mask_(filter_c_ + v_idx * AccessSize >= problem_size_.C, v_idx);
|
||||
clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C);
|
||||
}
|
||||
|
||||
pointer_ += (
|
||||
@@ -188,41 +192,11 @@ public:
|
||||
set_iteration_index(0);
|
||||
}
|
||||
|
||||
/// Clears the predicates
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask_(bool clear, int index) {
|
||||
// We are using inline PTX assembly here to avoid a CUDA C++ compilation
|
||||
// artifact in which control flow instructions are generated. Instead, our
|
||||
// intent is to predicate the mov instructions.
|
||||
#if defined(__CUDA_ARCH__)
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" .reg .u32 m;"
|
||||
" mov.u32 m, %2;"
|
||||
" setp.ne.b32 p, %1, 0;\n"
|
||||
" @p mov.u32 m, 0;\n"
|
||||
" mov.u32 %0, m;\n"
|
||||
"}\n"
|
||||
:
|
||||
"=r"(predicates_[index])
|
||||
:
|
||||
"r"((int)clear),
|
||||
"r"(predicates_[index])
|
||||
);
|
||||
#else
|
||||
if (clear) {
|
||||
predicates_[index] = 0;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
@@ -246,15 +220,21 @@ public:
|
||||
next = params_.inc_next_c;
|
||||
filter_c_ += params_.filter_c_delta;
|
||||
}
|
||||
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
clear_mask_(filter_c_ + v_idx * AccessSize >= problem_size_.C, v_idx);
|
||||
clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C);
|
||||
}
|
||||
|
||||
pointer_ += next;
|
||||
}
|
||||
|
||||
/// Clears the predicates
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask(int v, bool clear = true) {
|
||||
predicates_[v] = clear ? 0u : predicates_[v];
|
||||
}
|
||||
|
||||
/// Returns true if the current coordinate is within the filter tensor W
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool valid() {
|
||||
@@ -274,7 +254,6 @@ public:
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
@@ -301,7 +280,7 @@ public:
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.C % AccessSize) {
|
||||
if (problem_size.C % AccessType::kElements) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
|
||||
@@ -68,6 +68,7 @@ public:
|
||||
using Params = typename TileAccessIterator::Params;
|
||||
static int const kConvDim = TileAccessIterator::kConvDim;
|
||||
using ConvProblemSize = typename TileAccessIterator::ConvProblemSize;
|
||||
static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector;
|
||||
|
||||
/// Fragment object to be loaded or stored
|
||||
using Fragment = cutlass::Array<
|
||||
@@ -130,18 +131,20 @@ public:
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < tile_access_iterator_.kAccessesPerVector; ++v) {
|
||||
for (int v = 0; v < kAccessesPerVector; ++v) {
|
||||
|
||||
int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
|
||||
|
||||
cutlass::arch::global_load<
|
||||
AccessType,
|
||||
sizeof(AccessType)
|
||||
>(
|
||||
frag_ptr[(c + s * ThreadMap::Iterations::kContiguous) * tile_access_iterator_.kAccessesPerVector + v],
|
||||
frag_ptr[idx],
|
||||
tile_access_iterator_.get() + pointer_offset,
|
||||
tile_access_iterator_.valid()
|
||||
);
|
||||
|
||||
|
||||
++tile_access_iterator_;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,7 +58,8 @@ namespace threadblock {
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_
|
||||
typename ThreadMap_,
|
||||
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
|
||||
>
|
||||
class Conv2dWgradActivationTileAccessIteratorAnalytic {
|
||||
public:
|
||||
@@ -70,7 +71,7 @@ public:
|
||||
using Element = Element_;
|
||||
using Layout = layout::TensorNHWC;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
@@ -79,7 +80,12 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"WGRAD requires elements of size 8b or greater.");
|
||||
|
||||
@@ -89,14 +95,13 @@ public:
|
||||
|
||||
using Params = Conv2dAnalyticParams<Layout>;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
private:
|
||||
|
||||
Params const &params_;
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
char const *pointer_;
|
||||
|
||||
// Filter position (r,s,c) in contiguous dimension stays constant for each gemm_iteration_k
|
||||
@@ -149,8 +154,10 @@ public:
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
@@ -173,9 +180,19 @@ public:
|
||||
/// by the iterator.
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorCoord at() const {
|
||||
int r, s, c;
|
||||
|
||||
int r = filter_r_[iteration_contiguous_];
|
||||
int s = filter_s_[iteration_contiguous_];
|
||||
if (kAccessesPerVector == 1) {
|
||||
r = filter_r_[iteration_contiguous_];
|
||||
s = filter_s_[iteration_contiguous_];
|
||||
c = filter_c_[iteration_contiguous_];
|
||||
} else {
|
||||
c = (filter_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements) % problem_size_.C;
|
||||
int wrap_c = (filter_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements) / problem_size_.C;
|
||||
s = (filter_s_[iteration_contiguous_] + wrap_c) % problem_size_.S;
|
||||
int wrap_s = (filter_s_[iteration_contiguous_] + wrap_c) / problem_size_.S;
|
||||
r = filter_r_[iteration_contiguous_] + wrap_s;
|
||||
}
|
||||
|
||||
if (problem_size_.mode == Mode::kConvolution) {
|
||||
r = (problem_size_.R - 1 - r);
|
||||
@@ -184,14 +201,14 @@ public:
|
||||
|
||||
int n = offset_npq_[iteration_strided_] / (problem_size_.P * problem_size_.Q);
|
||||
int residual = offset_npq_[iteration_strided_] % (problem_size_.P * problem_size_.Q);
|
||||
|
||||
|
||||
int p = residual / problem_size_.Q;
|
||||
int q = residual % problem_size_.Q;
|
||||
|
||||
|
||||
int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h;
|
||||
int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w;
|
||||
|
||||
return TensorCoord(n, h, w, filter_c_[iteration_contiguous_]);
|
||||
|
||||
return TensorCoord(n, h, w, c);
|
||||
}
|
||||
|
||||
/// Returns true if the current coordinate is within the activation tensor x
|
||||
@@ -201,8 +218,7 @@ public:
|
||||
|
||||
return coord.n() < problem_size_.N &&
|
||||
coord.h() >= 0 && coord.h() < problem_size_.H &&
|
||||
coord.w() >= 0 && coord.w() < problem_size_.W &&
|
||||
coord.c() < problem_size_.C;
|
||||
coord.w() >= 0 && coord.w() < problem_size_.W;
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
@@ -218,6 +234,12 @@ public:
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dWgradActivationTileAccessIteratorAnalytic &operator++() {
|
||||
++iteration_vector_;
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
@@ -237,13 +259,12 @@ public:
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.K % (128/sizeof_bits<Element>::value)) {
|
||||
if (problem_size.K % AccessType::kElements) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -57,7 +57,8 @@ namespace threadblock {
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_
|
||||
typename ThreadMap_,
|
||||
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
|
||||
>
|
||||
class Conv2dWgradActivationTileAccessIteratorOptimized {
|
||||
public:
|
||||
@@ -69,7 +70,7 @@ public:
|
||||
using Element = Element_;
|
||||
using Layout = layout::TensorNHWC;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
@@ -78,7 +79,12 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"WGRAD requires elements of size 8b or greater.");
|
||||
|
||||
@@ -88,14 +94,13 @@ public:
|
||||
|
||||
using Params = Conv2dWgradActivationIteratorOptimizedParams;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
private:
|
||||
|
||||
Conv2dWgradActivationIteratorOptimizedParams const &params_;
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
char const *pointer_;
|
||||
|
||||
// Precomputed effective filter position (r,s) in contiguous dimension stays constant for each gemm_iteration_k
|
||||
@@ -153,9 +158,8 @@ public:
|
||||
s = (problem_size_.S - 1 - s);
|
||||
}
|
||||
|
||||
precomputed_filter_r_[c] = - problem_size_.pad_h + r * problem_size_.dilation_h;
|
||||
precomputed_filter_s_[c] = - problem_size_.pad_w + s * problem_size_.dilation_w;
|
||||
|
||||
precomputed_filter_r_[c] = -problem_size_.pad_h + r * problem_size_.dilation_h;
|
||||
precomputed_filter_s_[c] = -problem_size_.pad_w + s * problem_size_.dilation_w;
|
||||
}
|
||||
|
||||
// initialize n, p, q offset for every strided iteration
|
||||
@@ -170,8 +174,10 @@ public:
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
@@ -194,6 +200,31 @@ public:
|
||||
/// by the iterator.
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorCoord at() const {
|
||||
int r = precomputed_filter_r_[iteration_contiguous_];
|
||||
int s = precomputed_filter_s_[iteration_contiguous_];
|
||||
int c = filter_c_[iteration_contiguous_];
|
||||
|
||||
if (kAccessesPerVector > 1) {
|
||||
int wrap_c;
|
||||
params_.c_divmod(wrap_c, c, c + iteration_vector_ * AccessType::kElements);
|
||||
|
||||
if (problem_size_.mode == Mode::kConvolution) {
|
||||
s -= (problem_size_.dilation_w * wrap_c);
|
||||
|
||||
int wrap_s = (s == -problem_size_.pad_w - problem_size_.dilation_w);
|
||||
s = wrap_s ? (-problem_size_.pad_w + (problem_size_.S - 1) * problem_size_.dilation_w): s;
|
||||
|
||||
r -= (problem_size_.dilation_h * wrap_s);
|
||||
|
||||
} else {
|
||||
s += (problem_size_.dilation_w * wrap_c);
|
||||
|
||||
int wrap_s = (s == (-problem_size_.pad_w + problem_size_.S * problem_size_.dilation_w));
|
||||
s = wrap_s ? -problem_size_.pad_w : s;
|
||||
|
||||
r += (problem_size_.dilation_h * wrap_s);
|
||||
}
|
||||
}
|
||||
|
||||
// The subsequent fast_divmod() operations are equivalent to the following logical computation:
|
||||
//
|
||||
@@ -209,10 +240,10 @@ public:
|
||||
params_.pq_divmod(n, residual, offset_npq_[iteration_strided_]);
|
||||
params_.q_divmod(p, q, residual);
|
||||
|
||||
int h = p * problem_size_.stride_h + precomputed_filter_r_[iteration_contiguous_];
|
||||
int w = q * problem_size_.stride_w + precomputed_filter_s_[iteration_contiguous_];
|
||||
int h = p * problem_size_.stride_h + r;
|
||||
int w = q * problem_size_.stride_w + s;
|
||||
|
||||
return TensorCoord(n, h, w, filter_c_[iteration_contiguous_]);
|
||||
return TensorCoord(n, h, w, c);
|
||||
}
|
||||
|
||||
/// Returns true if the current coordinate is within the activation tensor x
|
||||
@@ -222,8 +253,7 @@ public:
|
||||
|
||||
return coord.n() < problem_size_.N &&
|
||||
coord.h() >= 0 && coord.h() < problem_size_.H &&
|
||||
coord.w() >= 0 && coord.w() < problem_size_.W &&
|
||||
coord.c() < problem_size_.C;
|
||||
coord.w() >= 0 && coord.w() < problem_size_.W;
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
@@ -239,6 +269,12 @@ public:
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dWgradActivationTileAccessIteratorOptimized &operator++() {
|
||||
++iteration_vector_;
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
@@ -258,14 +294,14 @@ public:
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.K % (128/sizeof_bits<Element>::value)) {
|
||||
if (problem_size.K % AccessType::kElements) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
|
||||
@@ -58,7 +58,8 @@ namespace threadblock {
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_
|
||||
typename ThreadMap_,
|
||||
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
|
||||
>
|
||||
class Conv2dWgradOutputGradientTileAccessIteratorAnalytic {
|
||||
public:
|
||||
@@ -70,7 +71,7 @@ public:
|
||||
using Element = Element_;
|
||||
using Layout = layout::TensorNHWC;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
@@ -80,6 +81,11 @@ public:
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"WGRAD requires elements of size 8b or greater.");
|
||||
|
||||
@@ -89,14 +95,13 @@ public:
|
||||
|
||||
using Params = Conv2dAnalyticParams<Layout>;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
private:
|
||||
|
||||
Params const &params_;
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
char const *pointer_;
|
||||
|
||||
int filter_k_[ThreadMap::Iterations::kContiguous];
|
||||
@@ -143,8 +148,10 @@ public:
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
@@ -187,7 +194,7 @@ public:
|
||||
return coord.n() < problem_size_.N &&
|
||||
coord.h() < problem_size_.P &&
|
||||
coord.w() < problem_size_.Q &&
|
||||
coord.c() < problem_size_.K;
|
||||
(coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.K;
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
@@ -197,12 +204,18 @@ public:
|
||||
TensorCoord coord = at();
|
||||
LongIndex offset = params_.layout(coord);
|
||||
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8) + iteration_vector_;
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dWgradOutputGradientTileAccessIteratorAnalytic &operator++() {
|
||||
++iteration_vector_;
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
@@ -222,14 +235,14 @@ public:
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
|
||||
if (problem_size.C % AccessType::kElements) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
@@ -237,5 +250,3 @@ public:
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
|
||||
@@ -57,7 +57,8 @@ namespace threadblock {
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_
|
||||
typename ThreadMap_,
|
||||
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
|
||||
>
|
||||
class Conv2dWgradOutputGradientTileAccessIteratorOptimized {
|
||||
public:
|
||||
@@ -69,7 +70,7 @@ public:
|
||||
using Element = Element_;
|
||||
using Layout = layout::TensorNHWC;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
@@ -79,6 +80,11 @@ public:
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"WGRAD requires elements of size 8b or greater.");
|
||||
|
||||
@@ -88,17 +94,16 @@ public:
|
||||
|
||||
using Params = Conv2dWgradOutputGradientIteratorOptimizedParams;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
private:
|
||||
|
||||
Conv2dWgradOutputGradientIteratorOptimizedParams const &params_;
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
char const *pointer_;
|
||||
|
||||
uint32_t predicates_;
|
||||
uint32_t predicates_[kAccessesPerVector];
|
||||
int filter_k_;
|
||||
int offset_npq_;
|
||||
|
||||
@@ -115,7 +120,7 @@ public:
|
||||
params_(params),
|
||||
problem_size_(problem_size),
|
||||
pointer_(reinterpret_cast<char const *>(ptr)),
|
||||
predicates_(0),
|
||||
predicates_{0},
|
||||
filter_k_(0),
|
||||
offset_npq_(0) {
|
||||
|
||||
@@ -132,13 +137,16 @@ public:
|
||||
int filter_k = filter_k_ + c * ThreadMap::Delta::kContiguous;
|
||||
int offset_npq = offset_npq_ + s * ThreadMap::Delta::kStrided;
|
||||
|
||||
bool predicate = valid_(at_(offset_npq, filter_k));
|
||||
|
||||
uint32_t pred = (predicate ? 1u : 0);
|
||||
|
||||
int pred_idx = c + s * ThreadMap::Iterations::kContiguous;
|
||||
|
||||
predicates_ |= (pred << pred_idx);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < kAccessesPerVector; ++v) {
|
||||
bool predicate = valid_(at_(offset_npq, filter_k + v * AccessType::kElements));
|
||||
|
||||
uint32_t pred = (predicate ? 1u : 0);
|
||||
|
||||
int pred_idx = c + s * ThreadMap::Iterations::kContiguous;
|
||||
|
||||
predicates_[v] |= (pred << pred_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -165,8 +173,10 @@ public:
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
@@ -185,7 +195,11 @@ public:
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
if (offset_npq_ + s * ThreadMap::Delta::kStrided >= params_.NPQ) {
|
||||
uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous);
|
||||
predicates_ = (predicates_ & (~kClearMask));
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < kAccessesPerVector; ++v) {
|
||||
predicates_[v] = (predicates_[v] & (~kClearMask));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -231,7 +245,7 @@ public:
|
||||
bool valid() const {
|
||||
|
||||
LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous;
|
||||
return (predicates_ & (1u << pred_idx));
|
||||
return (predicates_[iteration_vector_] & (1u << pred_idx));
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
@@ -242,12 +256,18 @@ public:
|
||||
pointer_ +
|
||||
iteration_strided_ * params_.offset_next_strided +
|
||||
iteration_contiguous_ * params_.offset_next_contiguous
|
||||
);
|
||||
) + iteration_vector_;
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dWgradOutputGradientTileAccessIteratorOptimized &operator++() {
|
||||
++iteration_vector_;
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
@@ -267,14 +287,14 @@ public:
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
|
||||
if (problem_size.C % AccessType::kElements) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
@@ -282,5 +302,3 @@ public:
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
|
||||
@@ -79,11 +79,10 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 3;
|
||||
using ConvProblemSize = typename conv::Conv3dProblemSize;
|
||||
static int const kAccessesPerVector = 1;
|
||||
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"DGRAD requires elements of size 8b or larger.");
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
@@ -261,5 +260,3 @@ public:
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
|
||||
@@ -82,8 +82,7 @@ public:
|
||||
static StrideSupport const kStrideSupport = StrideSupport_;
|
||||
static int const kConvDim = 3;
|
||||
using ConvProblemSize = typename conv::Conv3dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
static int const kAccessesPerVector = 1;
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
@@ -217,7 +216,8 @@ public:
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
if (filter_k_ + s * ThreadMap::Delta::kStrided >= problem_size_.K) {
|
||||
uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous);
|
||||
uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous);
|
||||
|
||||
predicates_ = (predicates_ & (~kClearMask));
|
||||
}
|
||||
}
|
||||
@@ -281,5 +281,3 @@ public:
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
|
||||
@@ -93,11 +93,10 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 3;
|
||||
using ConvProblemSize = typename conv::Conv3dProblemSize;
|
||||
static int const kAccessesPerVector = 1;
|
||||
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"DGRAD requires elements of size 8b or greater.");
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
//
|
||||
// Simplifying assertions
|
||||
@@ -328,11 +327,11 @@ public:
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ public:
|
||||
static int const kConvDim = 3;
|
||||
using ConvProblemSize = typename conv::Conv3dProblemSize;
|
||||
using Coord3D = Coord<3>;
|
||||
|
||||
static int const kAccessesPerVector = 1;
|
||||
using Mask = uint64_t;
|
||||
|
||||
//
|
||||
@@ -101,8 +101,6 @@ public:
|
||||
|
||||
using Params = Conv3dDgradOutputGradientIteratorOptimizedParams;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
private:
|
||||
|
||||
Params const &params_;
|
||||
@@ -403,7 +401,6 @@ public:
|
||||
}
|
||||
|
||||
clear_mask_(filter_k_ >= problem_size_.K);
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -81,6 +81,7 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 3;
|
||||
using ConvProblemSize = typename conv::Conv3dProblemSize;
|
||||
static int const kAccessesPerVector = 1;
|
||||
|
||||
//
|
||||
// Simplifying assertions
|
||||
@@ -94,8 +95,6 @@ public:
|
||||
|
||||
using Params = Conv3dAnalyticParams<Layout>;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
private:
|
||||
|
||||
Params const &params_;
|
||||
|
||||
@@ -82,7 +82,7 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 3;
|
||||
using ConvProblemSize = typename conv::Conv3dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = 1;
|
||||
using Mask = uint64_t;
|
||||
|
||||
//
|
||||
@@ -97,8 +97,6 @@ public:
|
||||
|
||||
using Params = Conv3dFpropActivationIteratorOptimizedParams<Layout>;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
private:
|
||||
|
||||
Conv3dFpropActivationIteratorOptimizedParams<Layout> const &params_;
|
||||
|
||||
@@ -80,6 +80,7 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 3;
|
||||
using ConvProblemSize = typename conv::Conv3dProblemSize;
|
||||
static int const kAccessesPerVector = 1;
|
||||
|
||||
//
|
||||
// Simplifying assertions
|
||||
@@ -93,8 +94,6 @@ public:
|
||||
|
||||
using Params = Conv3dAnalyticParams<Layout>;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
private:
|
||||
|
||||
Params const &params_;
|
||||
|
||||
@@ -82,6 +82,7 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 3;
|
||||
using ConvProblemSize = typename conv::Conv3dProblemSize;
|
||||
static int const kAccessesPerVector = 1;
|
||||
|
||||
//
|
||||
// Simplifying assertions
|
||||
@@ -89,8 +90,6 @@ public:
|
||||
static_assert(ThreadMap::Iterations::kContiguous == 1,
|
||||
"Require Iterations::kContiguous == 1");
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
//
|
||||
@@ -156,7 +155,7 @@ public:
|
||||
params_(params),
|
||||
problem_size_(problem_size),
|
||||
pointer_(reinterpret_cast<char const *>(ptr)),
|
||||
predicates_(0),
|
||||
predicates_{0},
|
||||
filter_trs_(0),
|
||||
filter_c_(0) {
|
||||
|
||||
|
||||
@@ -79,11 +79,11 @@ public:
|
||||
static int const kConvDim = 3;
|
||||
using ConvProblemSize = typename conv::Conv3dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = 1;
|
||||
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"WGRAD requires elements of size 8b or greater.");
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
//
|
||||
|
||||
@@ -79,12 +79,10 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 3;
|
||||
using ConvProblemSize = typename conv::Conv3dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = 1;
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"WGRAD requires elements of size 8b or greater.");
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
//
|
||||
|
||||
@@ -78,12 +78,10 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 3;
|
||||
using ConvProblemSize = typename conv::Conv3dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = 1;
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"WGRAD requires elements of size 8b or greater.");
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
//
|
||||
|
||||
@@ -79,12 +79,10 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 3;
|
||||
using ConvProblemSize = typename conv::Conv3dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = 1;
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"WGRAD requires elements of size 8b or greater.");
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
//
|
||||
|
||||
@@ -216,6 +216,7 @@ public:
|
||||
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr + v, iterator_A.get(), iterator_A.valid());
|
||||
|
||||
++iterator_A;
|
||||
}
|
||||
|
||||
@@ -244,6 +245,7 @@ public:
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr + v, iterator_B.get(), iterator_B.valid());
|
||||
|
||||
++iterator_B;
|
||||
}
|
||||
++this->smem_iterator_B_;
|
||||
@@ -289,16 +291,17 @@ public:
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorA::Element>::value *
|
||||
IteratorA::ThreadMap::kElementsPerAccess /
|
||||
IteratorA::kAccessesPerVector / 8;
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorA::Element>::value *
|
||||
IteratorA::ThreadMap::kElementsPerAccess /
|
||||
IteratorA::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr + v, iterator_A.get(), iterator_A.valid());
|
||||
|
||||
++iterator_A;
|
||||
}
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
}
|
||||
|
||||
@@ -313,17 +316,18 @@ public:
|
||||
this->smem_iterator_B_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
|
||||
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorB::Element>::value *
|
||||
IteratorB::ThreadMap::kElementsPerAccess /
|
||||
IteratorB::kAccessesPerVector / 8;
|
||||
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr + v, iterator_B.get(), iterator_B.valid());
|
||||
|
||||
|
||||
++iterator_B;
|
||||
}
|
||||
|
||||
++this->smem_iterator_B_;
|
||||
}
|
||||
|
||||
|
||||
@@ -39,6 +39,7 @@
|
||||
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16,
|
||||
128x128_64x3_64x64x64) {
|
||||
|
||||
@@ -80,6 +81,7 @@ TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tens
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16,
|
||||
128x128_64x3_64x64x64) {
|
||||
|
||||
@@ -121,4 +123,115 @@ TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ten
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Dgrad implicit-GEMM test for SM80 using the analytic iterator algorithm with
// reduced operand alignment. Builds DefaultConv2dDgrad for half_t A/B/C,
// accumulator and compute types in TensorNHWC layout, 3 stages,
// StrideSupport::kUnity, and AlignmentA = AlignmentB = 4, then passes one
// explicit problem size to TestAllConv2d.
TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align4,
|
||||
128x128_64x3_64x64x64) {
|
||||
|
||||
/// Conv operation element types for the Gemm equivalent (ImplicitGemm)
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
/// Device-level Conv2d instance
|
||||
using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementC,
|
||||
// presumably the epilogue access width in elements (4 instead of a full 128-bit vector) - TODO confirm
4,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
// Stages
3,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kAnalytic,
|
||||
cutlass::conv::StrideSupport::kUnity,
|
||||
// AlignmentA (elements)
4,
|
||||
// AlignmentB (elements)
4
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv2dDgradKernel>;
|
||||
|
||||
test::conv::device::Conv2dProblemVector problem_size_list;
|
||||
|
||||
// NOTE(review): this problem uses stride {3, 3} while the kernel is built with
// StrideSupport::kUnity - confirm the testbed actually exercises it for Dgrad.
// run specific problem size in the unit test first
|
||||
problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
|
||||
{1, 4, 4, 12}, // input size (NHWC)
|
||||
{8, 3, 3, 12}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{3, 3}, // stride (stride_h, stride_w)
|
||||
{1, 1} // dilation (dilation_h, dilation_w)
|
||||
));
|
||||
|
||||
/// Run all unit test sizes with device-level Conv2d instance
|
||||
EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dDgrad>(problem_size_list));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Dgrad implicit-GEMM test for SM80 using the optimized iterator algorithm with
// reduced operand alignment. Builds DefaultConv2dDgrad for half_t A/B/C,
// accumulator and compute types in TensorNHWC layout, 3 stages,
// StrideSupport::kUnity, and AlignmentA = AlignmentB = 4, then passes one
// explicit problem size to TestAllConv2d.
TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align4,
|
||||
128x128_64x3_64x64x64) {
|
||||
|
||||
/// Conv operation element types for the Gemm equivalent (ImplicitGemm)
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
/// Device-level Conv2d instance
|
||||
using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementC,
|
||||
// presumably the epilogue access width in elements (4 instead of a full 128-bit vector) - TODO confirm
4,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
// Stages
3,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized,
|
||||
cutlass::conv::StrideSupport::kUnity,
|
||||
// AlignmentA (elements)
4,
|
||||
// AlignmentB (elements)
4
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv2dDgradKernel>;
|
||||
|
||||
test::conv::device::Conv2dProblemVector problem_size_list;
|
||||
|
||||
// NOTE(review): this problem uses stride {3, 3} while the kernel is built with
// StrideSupport::kUnity - confirm the testbed actually exercises it for Dgrad.
// run specific problem size in the unit test first
|
||||
problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
|
||||
{1, 4, 4, 12}, // input size (NHWC)
|
||||
{8, 3, 3, 12}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{3, 3}, // stride (stride_h, stride_w)
|
||||
{1, 1} // dilation (dilation_h, dilation_w)
|
||||
));
|
||||
|
||||
/// Run all unit test sizes with device-level Conv2d instance
|
||||
EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dDgrad>(problem_size_list));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED
|
||||
|
||||
@@ -38,46 +38,7 @@
|
||||
#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST(SM75_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32,
|
||||
128x128_32x2_64x64x32) {
|
||||
|
||||
/// Conv operation element types for the Gemm equivalent (ImplicitGemm)
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = float;
|
||||
using ElementAccumulator = float;
|
||||
using ElementCompute = float;
|
||||
|
||||
/// Device-level Conv2d instance
|
||||
using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
cutlass::gemm::GemmShape<128, 128, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>,
|
||||
cutlass::gemm::GemmShape<16, 8, 8>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kAnalytic,
|
||||
cutlass::conv::StrideSupport::kUnity
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv2dDgradKernel>;
|
||||
|
||||
EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dDgrad>());
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST(SM75_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride,
|
||||
128x128_32x2_64x64x32) {
|
||||
|
||||
@@ -118,6 +79,7 @@ TEST(SM75_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tens
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM75_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride,
|
||||
128x128_32x2_64x64x32) {
|
||||
|
||||
@@ -158,4 +120,113 @@ TEST(SM75_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_ten
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Dgrad implicit-GEMM test for SM75 using the analytic iterator algorithm with
// reduced operand alignment. Builds DefaultConv2dDgrad for half_t A/B and float
// C/accumulator/compute in TensorNHWC layout, 2 stages, StrideSupport::kUnity,
// and AlignmentA = AlignmentB = 2, then passes one explicit problem size to
// TestAllConv2d.
TEST(SM75_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride_align2,
|
||||
128x128_32x2_64x64x32) {
|
||||
|
||||
/// Conv operation element types for the Gemm equivalent (ImplicitGemm)
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = float;
|
||||
using ElementAccumulator = float;
|
||||
using ElementCompute = float;
|
||||
|
||||
/// Device-level Conv2d instance
|
||||
using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
cutlass::gemm::GemmShape<128, 128, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>,
|
||||
cutlass::gemm::GemmShape<16, 8, 8>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementC,
|
||||
// presumably the epilogue access width in elements - TODO confirm
4,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
// Stages
2,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kAnalytic,
|
||||
cutlass::conv::StrideSupport::kUnity,
|
||||
// AlignmentA (elements)
2,
|
||||
// AlignmentB (elements)
2
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv2dDgradKernel>;
|
||||
|
||||
test::conv::device::Conv2dProblemVector problem_size_list;
|
||||
|
||||
// NOTE(review): this problem uses stride {3, 3} while the kernel is built with
// StrideSupport::kUnity - confirm the testbed actually exercises it for Dgrad.
// run specific problem size in the unit test first
|
||||
problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
|
||||
{1, 4, 4, 12}, // input size (NHWC)
|
||||
{8, 3, 3, 12}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{3, 3}, // stride (stride_h, stride_w)
|
||||
{1, 1} // dilation (dilation_h, dilation_w)
|
||||
));
|
||||
|
||||
// Pass the explicit problem list to the conv2d testbed and require success.
EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dDgrad>(problem_size_list));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Dgrad implicit-GEMM test for SM75 using the optimized iterator algorithm with
// reduced operand alignment. Builds DefaultConv2dDgrad for half_t A/B and float
// C/accumulator/compute in TensorNHWC layout, 2 stages, StrideSupport::kUnity,
// and AlignmentA = AlignmentB = 2, then passes one explicit problem size to
// TestAllConv2d.
TEST(SM75_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride_align2,
|
||||
128x128_32x2_64x64x32) {
|
||||
|
||||
/// Conv operation element types for the Gemm equivalent (ImplicitGemm)
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = float;
|
||||
using ElementAccumulator = float;
|
||||
using ElementCompute = float;
|
||||
|
||||
/// Device-level Conv2d instance
|
||||
using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
cutlass::gemm::GemmShape<128, 128, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>,
|
||||
cutlass::gemm::GemmShape<16, 8, 8>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementC,
|
||||
// presumably the epilogue access width in elements - TODO confirm
4,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
// Stages
2,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized,
|
||||
cutlass::conv::StrideSupport::kUnity,
|
||||
// AlignmentA (elements)
2,
|
||||
// AlignmentB (elements)
2
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv2dDgradKernel>;
|
||||
|
||||
test::conv::device::Conv2dProblemVector problem_size_list;
|
||||
|
||||
// NOTE(review): this problem uses stride {3, 3} while the kernel is built with
// StrideSupport::kUnity - confirm the testbed actually exercises it for Dgrad.
// run specific problem size in the unit test first
|
||||
problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
|
||||
{1, 4, 4, 12}, // input size (NHWC)
|
||||
{8, 3, 3, 12}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{3, 3}, // stride (stride_h, stride_w)
|
||||
{1, 1} // dilation (dilation_h, dilation_w)
|
||||
));
|
||||
|
||||
// Pass the explicit problem list to the conv2d testbed and require success.
EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dDgrad>(problem_size_list));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED
|
||||
|
||||
@@ -38,6 +38,7 @@
|
||||
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16,
|
||||
128x128_64x3_64x64x64) {
|
||||
|
||||
@@ -78,6 +79,7 @@ TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tens
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16,
|
||||
128x128_64x3_64x64x64) {
|
||||
|
||||
@@ -118,9 +120,10 @@ TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ten
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align2,
|
||||
|
||||
TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align2,
|
||||
128x128_64x3_64x64x64) {
|
||||
|
||||
|
||||
/// Conv operation element types for the Gemm equivalent (ImplicitGemm)
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
@@ -141,28 +144,41 @@ TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ten
|
||||
cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
8,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
3,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized,
|
||||
cutlass::conv::IteratorAlgorithm::kAnalytic,
|
||||
cutlass::conv::StrideSupport::kStrided,
|
||||
2,
|
||||
2
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
|
||||
|
||||
|
||||
test::conv::device::Conv2dProblemVector problem_size_list;
|
||||
|
||||
// run specific problem size in the unit test first
|
||||
problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
|
||||
{1, 4, 4, 12}, // input size (NHWC)
|
||||
{8, 3, 3, 12}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{3, 3}, // stride (stride_h, stride_w)
|
||||
{1, 1} // dilation (dilation_h, dilation_w)
|
||||
));
|
||||
|
||||
/// Run all unit test sizes with device-level Conv2d instance
|
||||
EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>());
|
||||
EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>(problem_size_list));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align4,
|
||||
|
||||
TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align2,
|
||||
128x128_64x3_64x64x64) {
|
||||
|
||||
|
||||
/// Conv operation element types for the Gemm equivalent (ImplicitGemm)
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
@@ -183,7 +199,7 @@ TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ten
|
||||
cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
8,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
@@ -191,14 +207,81 @@ TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ten
|
||||
3,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized,
|
||||
cutlass::conv::StrideSupport::kStrided,
|
||||
2,
|
||||
2
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
|
||||
|
||||
test::conv::device::Conv2dProblemVector problem_size_list;
|
||||
|
||||
// run specific problem size in the unit test first
|
||||
problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
|
||||
{1, 4, 4, 12}, // input size (NHWC)
|
||||
{8, 3, 3, 12}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{3, 3}, // stride (stride_h, stride_w)
|
||||
{1, 1} // dilation (dilation_h, dilation_w)
|
||||
));
|
||||
|
||||
/// Run all unit test sizes with device-level Conv2d instance
|
||||
EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>(problem_size_list));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align4,
|
||||
128x128_64x3_64x64x64) {
|
||||
|
||||
/// Conv operation element types for the Gemm equivalent (ImplicitGemm)
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
/// Device-level Conv2d instance
|
||||
using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementC,
|
||||
8,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
3,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized,
|
||||
cutlass::conv::StrideSupport::kStrided,
|
||||
4,
|
||||
4
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
|
||||
|
||||
|
||||
test::conv::device::Conv2dProblemVector problem_size_list;
|
||||
|
||||
// run specific problem size in the unit test first
|
||||
problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
|
||||
{1, 4, 4, 12}, // input size (NHWC)
|
||||
{8, 3, 3, 12}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{3, 3}, // stride (stride_h, stride_w)
|
||||
{1, 1} // dilation (dilation_h, dilation_w)
|
||||
));
|
||||
|
||||
/// Run all unit test sizes with device-level Conv2d instance
|
||||
EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>());
|
||||
EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>(problem_size_list));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -119,49 +119,6 @@ TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_ten
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align1,
|
||||
128x128_32x2_64x64x32) {
|
||||
|
||||
/// Conv operation element types for the Gemm equivalent (ImplicitGemm)
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = float;
|
||||
using ElementAccumulator = float;
|
||||
using ElementCompute = float;
|
||||
|
||||
/// Device-level Conv2d instance
|
||||
using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
cutlass::gemm::GemmShape<128, 128, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>,
|
||||
cutlass::gemm::GemmShape<16, 8, 8>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized,
|
||||
1,
|
||||
1
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
|
||||
|
||||
/// Run all unit test sizes with device-level Conv2d instance
|
||||
EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>());
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align2,
|
||||
128x128_32x2_64x64x32) {
|
||||
|
||||
@@ -185,7 +142,7 @@ TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_ten
|
||||
cutlass::gemm::GemmShape<16, 8, 8>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
8,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
@@ -193,14 +150,26 @@ TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_ten
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized,
|
||||
cutlass::conv::StrideSupport::kStrided,
|
||||
2,
|
||||
2
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
|
||||
|
||||
test::conv::device::Conv2dProblemVector problem_size_list;
|
||||
|
||||
// run specific problem size in the unit test first
|
||||
problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
|
||||
{1, 4, 4, 12}, // input size (NHWC)
|
||||
{8, 3, 3, 12}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{3, 3}, // stride (stride_h, stride_w)
|
||||
{1, 1} // dilation (dilation_h, dilation_w)
|
||||
));
|
||||
|
||||
/// Run all unit test sizes with device-level Conv2d instance
|
||||
EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>());
|
||||
EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>(problem_size_list));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -228,7 +197,7 @@ TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_ten
|
||||
cutlass::gemm::GemmShape<16, 8, 8>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
8,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
@@ -236,15 +205,83 @@ TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_ten
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized,
|
||||
cutlass::conv::StrideSupport::kStrided,
|
||||
4,
|
||||
4
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
|
||||
|
||||
test::conv::device::Conv2dProblemVector problem_size_list;
|
||||
|
||||
// run specific problem size in the unit test first
|
||||
problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
|
||||
{1, 4, 4, 12}, // input size (NHWC)
|
||||
{8, 3, 3, 12}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{3, 3}, // stride (stride_h, stride_w)
|
||||
{1, 1} // dilation (dilation_h, dilation_w)
|
||||
));
|
||||
|
||||
/// Run all unit test sizes with device-level Conv2d instance
|
||||
EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>());
|
||||
EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>(problem_size_list));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Fprop implicit-GEMM test for SM75 using the analytic iterator algorithm with
// reduced operand alignment. Builds DefaultConv2dFprop for half_t A/B and float
// C/accumulator/compute in TensorNHWC layout, StrideSupport::kStrided, and two
// trailing integer arguments of 4, then passes one explicit problem size to
// TestAllConv2d.
TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4,
|
||||
128x128_32x2_64x64x32) {
|
||||
|
||||
/// Conv operation element types for the Gemm equivalent (ImplicitGemm)
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = float;
|
||||
using ElementAccumulator = float;
|
||||
using ElementCompute = float;
|
||||
|
||||
/// Device-level Conv2d instance
|
||||
using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
cutlass::gemm::GemmShape<128, 128, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>,
|
||||
cutlass::gemm::GemmShape<16, 8, 8>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementC,
|
||||
// NOTE(review): 8 with a float ElementC differs from the
// 128 / cutlass::sizeof_bits<ElementC>::value form used by other tests here - confirm intended.
8,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
// presumably the stage count, as in DefaultConv2dDgrad - TODO confirm
2,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kAnalytic,
|
||||
cutlass::conv::StrideSupport::kStrided,
|
||||
// presumably AlignmentA, mirroring DefaultConv2dDgrad's parameter order - TODO confirm
4,
|
||||
// presumably AlignmentB - TODO confirm
4
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
|
||||
|
||||
test::conv::device::Conv2dProblemVector problem_size_list;
|
||||
|
||||
// run specific problem size in the unit test first
|
||||
problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
|
||||
{1, 4, 4, 12}, // input size (NHWC)
|
||||
{8, 3, 3, 12}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{3, 3}, // stride (stride_h, stride_w)
|
||||
{1, 1} // dilation (dilation_h, dilation_w)
|
||||
));
|
||||
|
||||
/// Run all unit test sizes with device-level Conv2d instance
|
||||
EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>(problem_size_list));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED
|
||||
|
||||
@@ -60,7 +60,7 @@ TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_te
|
||||
cutlass::gemm::GemmShape<16, 8, 8>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
8,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
@@ -79,53 +79,9 @@ TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_te
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align1,
|
||||
128x128_32x3_64x64x32) {
|
||||
|
||||
/// Conv operation element types for the Gemm equivalent (ImplicitGemm)
|
||||
using ElementA = cutlass::tfloat32_t;
|
||||
using ElementB = cutlass::tfloat32_t;
|
||||
using ElementC = float;
|
||||
using ElementAccumulator = float;
|
||||
using ElementCompute = float;
|
||||
|
||||
/// Device-level Conv2d instance
|
||||
using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 16>,
|
||||
cutlass::gemm::GemmShape<64, 64, 16>,
|
||||
cutlass::gemm::GemmShape<16, 8, 8>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
3,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized,
|
||||
1,
|
||||
1
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
|
||||
|
||||
|
||||
/// Run all unit test sizes with device-level Conv2d instance
|
||||
EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>());
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align2,
|
||||
128x128_32x3_64x64x32) {
|
||||
|
||||
|
||||
/// Conv operation element types for the Gemm equivalent (ImplicitGemm)
|
||||
using ElementA = cutlass::tfloat32_t;
|
||||
using ElementB = cutlass::tfloat32_t;
|
||||
@@ -146,7 +102,7 @@ TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_t
|
||||
cutlass::gemm::GemmShape<16, 8, 8>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
8,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
@@ -154,15 +110,26 @@ TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_t
|
||||
3,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized,
|
||||
cutlass::conv::StrideSupport::kStrided,
|
||||
2,
|
||||
2
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
|
||||
|
||||
test::conv::device::Conv2dProblemVector problem_size_list;
|
||||
|
||||
// run specific problem size in the unit test first
|
||||
problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
|
||||
{1, 4, 4, 12}, // input size (NHWC)
|
||||
{8, 3, 3, 12}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{3, 3}, // stride (stride_h, stride_w)
|
||||
{1, 1} // dilation (dilation_h, dilation_w)
|
||||
));
|
||||
|
||||
/// Run all unit test sizes with device-level Conv2d instance
|
||||
EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>());
|
||||
EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>(problem_size_list));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -174,6 +174,61 @@ TEST(SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32n
|
||||
/// Run all unit test sizes with device-level Conv2d instance
|
||||
EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dDgrad>());
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4,
|
||||
128x128_32x3_64x64x32) {
|
||||
|
||||
/// Conv operation element types for the Gemm equivalent (ImplicitGemm)
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = float;
|
||||
using ElementAccumulator = float;
|
||||
using ElementCompute = float;
|
||||
|
||||
/// Device-level Conv2d instance
|
||||
using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>,
|
||||
cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementC,
|
||||
4,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<>,
|
||||
3,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kAnalytic,
|
||||
cutlass::conv::StrideSupport::kStrided,
|
||||
4,
|
||||
4
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv2dDgradKernel>;
|
||||
|
||||
test::conv::device::Conv2dProblemVector problem_size_list;
|
||||
|
||||
// run specific problem size in the unit test first
|
||||
problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
|
||||
{1, 4, 4, 12}, // input size (NHWC)
|
||||
{8, 3, 3, 12}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{3, 3}, // stride (stride_h, stride_w)
|
||||
{1, 1} // dilation (dilation_h, dilation_w)
|
||||
));
|
||||
|
||||
/// Run all unit test sizes with device-level Conv2d instance
|
||||
EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dDgrad>(problem_size_list));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED
|
||||
|
||||
@@ -116,6 +116,7 @@ public:
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
//cutlass::reference::host::TensorFill(view, Element(1.0f));
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
int scope;
|
||||
|
||||
@@ -36,6 +36,8 @@
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM75_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32,
|
||||
128x128_32x2_64x64x32) {
|
||||
|
||||
@@ -74,5 +76,114 @@ TEST(SM75_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tens
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM75_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4,
|
||||
128x128_32x2_64x64x32) {
|
||||
|
||||
/// Conv operation element types for the Gemm equivalent (ImplicitGemm)
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = float;
|
||||
using ElementAccumulator = float;
|
||||
using ElementCompute = float;
|
||||
|
||||
using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
cutlass::gemm::GemmShape<128, 128, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>,
|
||||
cutlass::gemm::GemmShape<16, 8, 8>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementC,
|
||||
4,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kAnalytic,
|
||||
cutlass::conv::StrideSupport::kStrided,
|
||||
4,
|
||||
4
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv2dWgradKernel>;
|
||||
|
||||
test::conv::device::Conv2dProblemVector problem_size_list;
|
||||
|
||||
// run specific problem size in the unit test first
|
||||
problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
|
||||
{1, 4, 4, 12}, // input size (NHWC)
|
||||
{8, 3, 3, 12}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{3, 3}, // stride (stride_h, stride_w)
|
||||
{1, 1} // dilation (dilation_h, dilation_w)
|
||||
));
|
||||
|
||||
/// Run all unit test sizes with device-level Conv2d instance
|
||||
EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dWgrad>(problem_size_list));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM75_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4,
|
||||
128x128_32x2_64x64x32) {
|
||||
|
||||
/// Conv operation element types for the Gemm equivalent (ImplicitGemm)
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = float;
|
||||
using ElementAccumulator = float;
|
||||
using ElementCompute = float;
|
||||
|
||||
using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
cutlass::gemm::GemmShape<128, 128, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>,
|
||||
cutlass::gemm::GemmShape<16, 8, 8>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementC,
|
||||
4,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized,
|
||||
cutlass::conv::StrideSupport::kStrided,
|
||||
4,
|
||||
4
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv2dWgradKernel>;
|
||||
|
||||
test::conv::device::Conv2dProblemVector problem_size_list;
|
||||
|
||||
// run specific problem size in the unit test first
|
||||
problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
|
||||
{1, 4, 4, 12}, // input size (NHWC)
|
||||
{8, 3, 3, 12}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{3, 3}, // stride (stride_h, stride_w)
|
||||
{1, 1} // dilation (dilation_h, dilation_w)
|
||||
));
|
||||
|
||||
/// Run all unit test sizes with device-level Conv2d instance
|
||||
EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dWgrad>(problem_size_list));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED
|
||||
|
||||
|
||||
@@ -146,8 +146,7 @@ TEST(SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_ten
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
4,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized,
|
||||
cutlass::conv::StrideSupport::kStrided
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv2dWgradKernel>;
|
||||
@@ -157,5 +156,113 @@ TEST(SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_ten
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED
|
||||
|
||||
TEST(SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4,
|
||||
128x128_32x3_64x64x32) {
|
||||
|
||||
/// Conv operation element types for the Gemm equivalent (ImplicitGemm)
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = float;
|
||||
using ElementAccumulator = float;
|
||||
using ElementCompute = float;
|
||||
|
||||
using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 16>,
|
||||
cutlass::gemm::GemmShape<64, 64, 16>,
|
||||
cutlass::gemm::GemmShape<16, 8, 8>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementC,
|
||||
4,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
3,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kAnalytic,
|
||||
cutlass::conv::StrideSupport::kStrided,
|
||||
4,
|
||||
4
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv2dWgradKernel>;
|
||||
|
||||
test::conv::device::Conv2dProblemVector problem_size_list;
|
||||
|
||||
// run specific problem size in the unit test first
|
||||
problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
|
||||
{1, 4, 4, 12}, // input size (NHWC)
|
||||
{8, 3, 3, 12}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{3, 3}, // stride (stride_h, stride_w)
|
||||
{1, 1} // dilation (dilation_h, dilation_w)
|
||||
));
|
||||
|
||||
/// Run all unit test sizes with device-level Conv2d instance
|
||||
EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dWgrad>(problem_size_list));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4,
|
||||
128x128_32x3_64x64x32) {
|
||||
|
||||
/// Conv operation element types for the Gemm equivalent (ImplicitGemm)
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = float;
|
||||
using ElementAccumulator = float;
|
||||
using ElementCompute = float;
|
||||
|
||||
using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 16>,
|
||||
cutlass::gemm::GemmShape<64, 64, 16>,
|
||||
cutlass::gemm::GemmShape<16, 8, 8>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementC,
|
||||
4,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
3,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized,
|
||||
cutlass::conv::StrideSupport::kStrided,
|
||||
4,
|
||||
4
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv2dWgradKernel>;
|
||||
|
||||
test::conv::device::Conv2dProblemVector problem_size_list;
|
||||
|
||||
// run specific problem size in the unit test first
|
||||
problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
|
||||
{1, 4, 4, 12}, // input size (NHWC)
|
||||
{8, 3, 3, 12}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{3, 3}, // stride (stride_h, stride_w)
|
||||
{1, 1} // dilation (dilation_h, dilation_w)
|
||||
));
|
||||
|
||||
/// Run all unit test sizes with device-level Conv2d instance
|
||||
EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dWgrad>(problem_size_list));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED
|
||||
|
||||
@@ -103,9 +103,9 @@ class Conv2dOperation:
|
||||
)
|
||||
|
||||
if self.stride_support == StrideSupport.Unity:
|
||||
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride"
|
||||
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}_unity_stride"
|
||||
else:
|
||||
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}"
|
||||
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}"
|
||||
|
||||
return SubstituteTemplate(
|
||||
configuration_name,
|
||||
@@ -114,6 +114,7 @@ class Conv2dOperation:
|
||||
'extended_name': self.extended_name(),
|
||||
'threadblock': threadblock,
|
||||
'layout': self.layout_name(),
|
||||
'alignment': "%d" % self.A.alignment,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -156,7 +157,9 @@ class EmitConv2dInstance:
|
||||
${stages},
|
||||
${math_operator},
|
||||
${iterator_algorithm},
|
||||
${stride_support}
|
||||
${stride_support},
|
||||
${align_a},
|
||||
${align_b}
|
||||
>::Kernel;
|
||||
"""
|
||||
|
||||
@@ -198,7 +201,9 @@ class EmitConv2dInstance:
|
||||
'iterator_algorithm_name': IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(),
|
||||
'stride_support': StrideSupportTag[operation.stride_support],
|
||||
'math_operator': 'cutlass::arch::OpMultiplyAddComplex' if operation.is_complex() else \
|
||||
MathOperationTag[operation.tile_description.math_instruction.math_operation]
|
||||
MathOperationTag[operation.tile_description.math_instruction.math_operation],
|
||||
'align_a': str(operation.A.alignment),
|
||||
'align_b': str(operation.B.alignment),
|
||||
}
|
||||
|
||||
return SubstituteTemplate(self.template, values)
|
||||
@@ -341,4 +346,3 @@ void initialize_${configuration_name}(Manifest &manifest) {
|
||||
|
||||
###################################################################################################
|
||||
###################################################################################################
|
||||
|
||||
|
||||
@@ -151,14 +151,13 @@ def CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, data_t
|
||||
# Note : Operator marked (*) are supported but not generated to keep the instantiated kernel count low
|
||||
###########################################################################################################
|
||||
# Convolution for 2D operations
|
||||
def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment, \
|
||||
def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment_constraints, \
|
||||
conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], \
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4):
|
||||
|
||||
element_a, element_b, element_c, element_epilogue = data_type
|
||||
|
||||
# one exceptional case
|
||||
alignment_c = min(8, alignment)
|
||||
|
||||
# iterator algorithm (analytic and optimized)
|
||||
iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized]
|
||||
@@ -166,66 +165,71 @@ def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignme
|
||||
# by default, only generate the largest tile size
|
||||
if manifest.args.kernels == '':
|
||||
tile_descriptions = [tile_descriptions[0],]
|
||||
alignment_constraints = [alignment_constraints[0],]
|
||||
|
||||
operations = []
|
||||
|
||||
for tile in tile_descriptions:
|
||||
A = TensorDescription(element_a, layout[0], alignment)
|
||||
B = TensorDescription(element_b, layout[1], alignment)
|
||||
C = TensorDescription(element_c, layout[2], alignment_c)
|
||||
|
||||
swizzling_functor_ = swizzling_functor
|
||||
for alignment in alignment_constraints:
|
||||
|
||||
#
|
||||
# Conv2d Fprop
|
||||
#
|
||||
if ConvKind.Fprop in conv_kinds:
|
||||
alignment_c = min(8, alignment)
|
||||
|
||||
# Strided support for Analytic and Optimized Fprop
|
||||
for iterator_algorithm in iterator_algorithms:
|
||||
new_operation = Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\
|
||||
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_)
|
||||
|
||||
manifest.append(new_operation)
|
||||
operations.append(new_operation)
|
||||
|
||||
#
|
||||
# Conv2d Dgrad
|
||||
#
|
||||
if ConvKind.Dgrad in conv_kinds:
|
||||
|
||||
# Unity stride for Analytic and Optimized Dgrad
|
||||
for iterator_algorithm in iterator_algorithms:
|
||||
new_operation = Conv2dOperation(ConvKind.Dgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\
|
||||
A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_)
|
||||
|
||||
manifest.append(new_operation)
|
||||
operations.append(new_operation)
|
||||
|
||||
# Strided support for Analytic Dgrad
|
||||
# strided dgrad uses a special threadblock swizzle
|
||||
# note that SwizzlingFunctor.StridedDgradHorizontal might be
|
||||
# better for problem sizes with large activation channel count
|
||||
swizzling_functor_strided_dgrad_ = SwizzlingFunctor.StridedDgradIdentity1
|
||||
|
||||
new_operation = Conv2dOperation(ConvKind.Dgrad, IteratorAlgorithm.Analytic, tile.minimum_compute_capability, tile,\
|
||||
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_strided_dgrad_)
|
||||
|
||||
manifest.append(new_operation)
|
||||
operations.append(new_operation)
|
||||
A = TensorDescription(element_a, layout[0], alignment)
|
||||
B = TensorDescription(element_b, layout[1], alignment)
|
||||
C = TensorDescription(element_c, layout[2], alignment_c)
|
||||
|
||||
#
|
||||
# Conv2d Wgrad
|
||||
#
|
||||
if ConvKind.Wgrad in conv_kinds:
|
||||
|
||||
# Strided support for Analytic and Optimized Wgrad
|
||||
for iterator_algorithm in iterator_algorithms:
|
||||
new_operation = Conv2dOperation(ConvKind.Wgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\
|
||||
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_)
|
||||
|
||||
swizzling_functor_ = swizzling_functor
|
||||
|
||||
#
|
||||
# Conv2d Fprop
|
||||
#
|
||||
if ConvKind.Fprop in conv_kinds:
|
||||
|
||||
# Strided support for Analytic and Optimized Fprop
|
||||
for iterator_algorithm in iterator_algorithms:
|
||||
new_operation = Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\
|
||||
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_)
|
||||
|
||||
manifest.append(new_operation)
|
||||
operations.append(new_operation)
|
||||
|
||||
#
|
||||
# Conv2d Dgrad
|
||||
#
|
||||
if ConvKind.Dgrad in conv_kinds:
|
||||
|
||||
# Unity stride for Analytic and Optimized Dgrad
|
||||
for iterator_algorithm in iterator_algorithms:
|
||||
new_operation = Conv2dOperation(ConvKind.Dgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\
|
||||
A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_)
|
||||
|
||||
manifest.append(new_operation)
|
||||
operations.append(new_operation)
|
||||
|
||||
# Strided support for Analytic Dgrad
|
||||
# strided dgrad uses a special threadblock swizzle
|
||||
# note that SwizzlingFunctor.StridedDgradHorizontal might be
|
||||
# better for problem sizes with large activation channel count
|
||||
swizzling_functor_strided_dgrad_ = SwizzlingFunctor.StridedDgradIdentity1
|
||||
|
||||
new_operation = Conv2dOperation(ConvKind.Dgrad, IteratorAlgorithm.Analytic, tile.minimum_compute_capability, tile,\
|
||||
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_strided_dgrad_)
|
||||
|
||||
manifest.append(new_operation)
|
||||
operations.append(new_operation)
|
||||
|
||||
#
|
||||
# Conv2d Wgrad
|
||||
#
|
||||
if ConvKind.Wgrad in conv_kinds:
|
||||
|
||||
# Strided support for Analytic and Optimized Wgrad
|
||||
for iterator_algorithm in iterator_algorithms:
|
||||
new_operation = Conv2dOperation(ConvKind.Wgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\
|
||||
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_)
|
||||
|
||||
manifest.append(new_operation)
|
||||
operations.append(new_operation)
|
||||
|
||||
return operations
|
||||
|
||||
@@ -322,7 +326,7 @@ def GenerateSM50_Simt(manifest, args):
|
||||
|
||||
if math_inst.element_a == DataType.f32:
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
|
||||
#
|
||||
|
||||
#
|
||||
@@ -369,7 +373,7 @@ def GenerateSM50_Simt_complex(manifest, args):
|
||||
data_type, alignment_constraints)
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
|
||||
#
|
||||
|
||||
#
|
||||
@@ -543,7 +547,7 @@ def GenerateSM70_TensorOp_884(manifest, args):
|
||||
data_type, alignment_constraints)
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 8)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
|
||||
|
||||
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
|
||||
if math_inst.element_a != math_inst.element_accumulator:
|
||||
@@ -558,7 +562,7 @@ def GenerateSM70_TensorOp_884(manifest, args):
|
||||
CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type_mixed, alignment_constraints)
|
||||
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, 8)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints)
|
||||
|
||||
#
|
||||
def GenerateSM70_PlanarComplexTensorOp_884(manifest, args):
|
||||
@@ -754,7 +758,7 @@ def GenerateSM75_TensorOp_1688(manifest, args):
|
||||
data_type, alignment_constraints)
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 8)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
|
||||
|
||||
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
|
||||
if math_inst.element_a != math_inst.element_accumulator:
|
||||
@@ -769,7 +773,7 @@ def GenerateSM75_TensorOp_1688(manifest, args):
|
||||
CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type_mixed, alignment_constraints)
|
||||
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, 8)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints)
|
||||
|
||||
#
|
||||
|
||||
@@ -891,7 +895,7 @@ def GenerateSM75_TensorOp_8816_TN(manifest, args):
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
|
||||
if math_inst.element_a != math_inst.element_accumulator:
|
||||
@@ -909,7 +913,7 @@ def GenerateSM75_TensorOp_8816_TN(manifest, args):
|
||||
data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type_mixed, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
for op in operations:
|
||||
if op.tile_description.threadblock_shape[1] >= 128:
|
||||
@@ -972,7 +976,7 @@ def GenerateSM75_TensorOp_8816_Interleaved(manifest, args):
|
||||
conv_layout = (LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32, LayoutType.TensorNC32HW32)
|
||||
|
||||
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type_mixed, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
for op in operations:
|
||||
op.C.alignment = 8
|
||||
@@ -1028,7 +1032,7 @@ def GenerateSM75_TensorOp_8832_TN(manifest, args):
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
|
||||
if math_inst.element_a != math_inst.element_accumulator:
|
||||
@@ -1046,7 +1050,7 @@ def GenerateSM75_TensorOp_8832_TN(manifest, args):
|
||||
data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type_mixed, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
for op in operations:
|
||||
if op.tile_description.threadblock_shape[1] >= 128:
|
||||
@@ -1112,7 +1116,7 @@ def GenerateSM75_TensorOp_8832_Interleaved(manifest, args):
|
||||
conv_layout = (LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64, LayoutType.TensorNC64HW64)
|
||||
|
||||
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type_mixed, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
for op in operations:
|
||||
op.C.alignment = 16
|
||||
@@ -1250,7 +1254,7 @@ def GenerateSM75_Simt_complex(manifest, args):
|
||||
]
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
|
||||
#
|
||||
|
||||
def GenerateSM75(manifest, args):
|
||||
@@ -1338,7 +1342,7 @@ def GenerateSM80_TensorOp_16816(manifest, args):
|
||||
data_type, alignment_constraints)
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 8)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
|
||||
CreateConv3dOperator(manifest, LayoutType.TensorNDHWC, tile_descriptions, data_type, 8)
|
||||
|
||||
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
|
||||
@@ -1354,7 +1358,7 @@ def GenerateSM80_TensorOp_16816(manifest, args):
|
||||
CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type_mixed, alignment_constraints)
|
||||
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, 8)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints)
|
||||
CreateConv3dOperator(manifest, LayoutType.TensorNDHWC, tile_descriptions, data_type_mixed, 8)
|
||||
#
|
||||
|
||||
@@ -1572,10 +1576,10 @@ def GenerateSM80_TensorOp_16832_TN(manifest, args):
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type_mixed, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
for op in operations:
|
||||
if op.tile_description.threadblock_shape[1] >= 128:
|
||||
@@ -1689,7 +1693,7 @@ def GenerateSM80_TensorOp_16832_Interleaved(manifest, args):
|
||||
conv_layout = (LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32, LayoutType.TensorNC32HW32)
|
||||
|
||||
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type_mixed, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
for op in operations:
|
||||
op.C.alignment = 8
|
||||
@@ -1758,10 +1762,10 @@ def GenerateSM80_TensorOp_16864_TN(manifest, args):
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type_mixed, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
for op in operations:
|
||||
if op.tile_description.threadblock_shape[1] >= 128:
|
||||
@@ -1878,7 +1882,7 @@ def GenerateSM80_TensorOp_16864_Interleaved(manifest, args):
|
||||
conv_layout = (LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64, LayoutType.TensorNC64HW64)
|
||||
|
||||
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type_mixed, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
for op in operations:
|
||||
op.C.alignment = 16
|
||||
@@ -2005,9 +2009,9 @@ def GenerateSM80_TensorOp_1688(manifest, args):
|
||||
data_type_mixed, alignment_constraints)
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 4)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
|
||||
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, 4)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints)
|
||||
#
|
||||
|
||||
#
|
||||
@@ -2076,7 +2080,7 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, args):
|
||||
data_type, alignment_constraints)
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 4)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
|
||||
#
|
||||
|
||||
#
|
||||
@@ -2366,7 +2370,7 @@ def GenerateSM80_Simt_f32(manifest, args):
|
||||
data_type, alignment_constraints)
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
|
||||
#
|
||||
|
||||
|
||||
@@ -2467,7 +2471,7 @@ def GenerateSM80_Simt_complex(manifest, args):
|
||||
CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, alignment_constraints, complex_transforms)
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
|
||||
#
|
||||
|
||||
###################################################################################################
|
||||
|
||||
Reference in New Issue
Block a user