mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-11 17:00:05 +00:00
CUTLASS 2.7 (#318)
CUTLASS 2.7 Mainloop fusion for GEMM: summation over A or B Strided DGRAD (optimized iterators) Half-precision GELU_taylor activation functions Use these when accumulation and epilogue compute types are all cutlass::half_t Tuning and bug fixes to fused GEMM + GEMM example Support for smaller than 128b aligned Convolutions: see examples Caching of results to accelerate Convolution unit tests Can be enabled or disabled by running cmake .. -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=OFF Corrections and bug fixes reported by the CUTLASS community Thank you for filing these issues! authored-by: Haicheng Wu haichengw@nvidia.com, Manish Gupta manigupta@nvidia.com, Dustyn Blasig dblasig@nvidia.com, Andrew Kerr akerr@nvidia.com
This commit is contained in:
@@ -225,6 +225,34 @@ struct global_store;
|
||||
//
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
template <typename AccessType>
|
||||
struct global_store<AccessType, 64> {
|
||||
CUTLASS_DEVICE
|
||||
global_store(AccessType const &D, void *ptr, bool pred_guard) {
|
||||
uint4 const *data = reinterpret_cast<uint4 const *>(&D);
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %5, 0;\n"
|
||||
" @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n"
|
||||
" @p st.global.v4.u32 [%6], {%7, %8, %9, %10};\n"
|
||||
" @p st.global.v4.u32 [%11], {%12, %13, %14, %15};\n"
|
||||
" @p st.global.v4.u32 [%16], {%17, %18, %19, %20};\n"
|
||||
"}\n"
|
||||
:
|
||||
: "l"(ptr), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z),
|
||||
"r"(data[0].w), "r"((int)pred_guard), "l"(((uint8_t *)ptr) + 16),
|
||||
"r"(data[1].x), "r"(data[1].y), "r"(data[1].z), "r"(data[1].w),
|
||||
"l"(((uint8_t *)ptr) + 32),
|
||||
"r"(data[2].x), "r"(data[2].y), "r"(data[2].z), "r"(data[2].w),
|
||||
"l"(((uint8_t *)ptr) + 48),
|
||||
"r"(data[3].x), "r"(data[3].y), "r"(data[3].z), "r"(data[2].w));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename AccessType>
|
||||
struct global_store<AccessType, 32> {
|
||||
CUTLASS_DEVICE
|
||||
|
||||
@@ -65,7 +65,7 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
|
||||
@@ -675,6 +675,243 @@ struct DefaultConv2dDgrad <
|
||||
>;
|
||||
};
|
||||
|
||||
/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm Dgrad Strided and
|
||||
// multistage pipeline.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport::kStrided,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA,
|
||||
StrideSupport::kStrided,
|
||||
AccessTypeA
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB,
|
||||
StrideSupport::kStrided,
|
||||
AccessTypeB
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::ImplicitGemmMultistage<
|
||||
ThreadblockShape,
|
||||
IteratorA,
|
||||
SmemIteratorA,
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
CacheOpB,
|
||||
MmaPolicy,
|
||||
Stages
|
||||
>;
|
||||
|
||||
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOpStridedDgrad<
|
||||
ThreadblockShape,
|
||||
WarpMmaTensorOp,
|
||||
kPartitionsK,
|
||||
EpilogueOutputOp,
|
||||
EpilogueOutputOp::kCount
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
|
||||
Mma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kDgrad
|
||||
>;
|
||||
};
|
||||
|
||||
/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm Dgrad Strided
|
||||
// and 2 stage pipeline.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport::kStrided,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::TileIteratorStridedDgrad<
|
||||
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA,
|
||||
StrideSupport::kStrided,
|
||||
AccessTypeA
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::TileIteratorStridedDgrad<
|
||||
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB,
|
||||
StrideSupport::kStrided,
|
||||
AccessTypeB
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::ImplicitGemmPipelined<
|
||||
ThreadblockShape,
|
||||
IteratorA,
|
||||
SmemIteratorA,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
MmaPolicy
|
||||
>;
|
||||
|
||||
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename detail::DefaultConvEpilogueStridedDgrad<
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpMmaTensorOp,
|
||||
kPartitionsK,
|
||||
EpilogueOutputOp
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
|
||||
Mma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kDgrad
|
||||
>;
|
||||
};
|
||||
|
||||
/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm Dgrad Unity
|
||||
// 2 stage pipeline
|
||||
template <
|
||||
@@ -1126,6 +1363,112 @@ struct DefaultConv2dDgrad <
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassSimt,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
conv::StrideSupport::kStrided,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
|
||||
Stages, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA,
|
||||
conv::StrideSupport::kStrided
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB,
|
||||
conv::StrideSupport::kStrided
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::ImplicitGemmMultistage<
|
||||
ThreadblockShape,
|
||||
IteratorA,
|
||||
SmemIteratorA,
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
arch::CacheOperation::Always,
|
||||
MmaPolicy,
|
||||
Stages
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad<
|
||||
ThreadblockShape,
|
||||
WarpMmaSimtOp,
|
||||
EpilogueOutputOp,
|
||||
EpilogueOutputOp::kCount
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
|
||||
Mma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kDgrad
|
||||
>;
|
||||
|
||||
};
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm,
|
||||
@@ -1462,6 +1805,115 @@ struct DefaultConv2dDgrad <
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassSimt,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
conv::StrideSupport::kStrided,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
|
||||
2, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::TileIteratorStridedDgrad<
|
||||
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA,
|
||||
conv::StrideSupport::kStrided
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::TileIteratorStridedDgrad<
|
||||
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB,
|
||||
conv::StrideSupport::kStrided
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::ImplicitGemmPipelined<
|
||||
ThreadblockShape,
|
||||
IteratorA,
|
||||
SmemIteratorA,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
MmaPolicy
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad<
|
||||
ThreadblockShape,
|
||||
WarpMmaSimtOp,
|
||||
EpilogueOutputOp,
|
||||
EpilogueOutputOp::kCount
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
|
||||
Mma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kDgrad
|
||||
>;
|
||||
|
||||
};
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
@@ -65,7 +65,7 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
|
||||
|
||||
@@ -64,7 +64,7 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
|
||||
|
||||
@@ -65,7 +65,7 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
|
||||
|
||||
@@ -66,7 +66,7 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
|
||||
|
||||
@@ -66,7 +66,7 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided
|
||||
> struct DefaultConv3dDgrad;
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided
|
||||
> struct DefaultConv3dFprop;
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided
|
||||
> struct DefaultConv3dWgrad;
|
||||
|
||||
|
||||
@@ -210,9 +210,9 @@ public:
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorCoord at() const {
|
||||
|
||||
int c = offset_c_[iteration_contiguous_];
|
||||
int k = offset_k_[iteration_strided_];
|
||||
|
||||
int c = offset_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements;
|
||||
|
||||
return TensorCoord(k, filter_r_, filter_s_, c);
|
||||
}
|
||||
|
||||
@@ -222,7 +222,7 @@ public:
|
||||
|
||||
TensorCoord coord = at();
|
||||
|
||||
return coord.n() < problem_size_.K && (coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.C;
|
||||
return coord.n() < problem_size_.K && coord.c() < problem_size_.C;
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
@@ -232,7 +232,7 @@ public:
|
||||
TensorCoord coord = at();
|
||||
LongIndex offset = params_.layout(coord);
|
||||
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8) + iteration_vector_;
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
|
||||
|
||||
}
|
||||
|
||||
@@ -250,6 +250,7 @@ public:
|
||||
return *this;
|
||||
}
|
||||
iteration_contiguous_ = 0;
|
||||
|
||||
++iteration_strided_;
|
||||
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
||||
return *this;
|
||||
@@ -408,8 +409,8 @@ public:
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorCoord at() const {
|
||||
|
||||
int c = offset_c_[iteration_contiguous_];
|
||||
int k = offset_k_[iteration_strided_];
|
||||
int c = offset_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements;
|
||||
|
||||
return TensorCoord(k, filter_r_, filter_s_, c);
|
||||
}
|
||||
@@ -420,7 +421,7 @@ public:
|
||||
|
||||
TensorCoord coord = at();
|
||||
|
||||
return coord.n() < problem_size_.K && (coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.C;
|
||||
return coord.n() < problem_size_.K && coord.c() < problem_size_.C;
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
@@ -430,7 +431,7 @@ public:
|
||||
TensorCoord coord = at();
|
||||
LongIndex offset = params_.layout(coord);
|
||||
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8) + iteration_vector_;
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
|
||||
@@ -67,6 +67,282 @@ class Conv2dDgradFilterTileAccessIteratorOptimized;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Conv2dDgradFilterTileAccessIteratorOptimized unity strided dgrad is more performant for dgrad
|
||||
// on problem sizes with stride = {1x1}
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_,
|
||||
typename AccessType_
|
||||
>
|
||||
class Conv2dDgradFilterTileAccessIteratorOptimized <
|
||||
Shape_,
|
||||
Element_,
|
||||
ThreadMap_,
|
||||
conv::StrideSupport::kStrided,
|
||||
AccessType_
|
||||
> {
|
||||
public:
|
||||
|
||||
//
|
||||
// Types
|
||||
//
|
||||
|
||||
using Shape = Shape_;
|
||||
using Element = Element_;
|
||||
using Layout = layout::TensorNHWC;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized;
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
//
|
||||
|
||||
struct Params : Conv2dStridedDgradFilterIteratorOptimizedParams {
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Conv2dStridedDgradFilterIteratorOptimizedParams const &base):
|
||||
Conv2dStridedDgradFilterIteratorOptimizedParams(base) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Layout const &layout
|
||||
):
|
||||
Conv2dStridedDgradFilterIteratorOptimizedParams(
|
||||
problem_size,
|
||||
layout,
|
||||
sizeof_bits<Element>::value,
|
||||
{Shape::kRow, Shape::kColumn},
|
||||
ThreadMap::kThreads,
|
||||
ThreadMap::kElementsPerAccess,
|
||||
{ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided},
|
||||
{ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}
|
||||
) { }
|
||||
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
Conv2dStridedDgradFilterIteratorOptimizedParams const ¶ms_;
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
char const *pointer_;
|
||||
|
||||
uint32_t predicates_[kAccessesPerVector];
|
||||
int filter_k_;
|
||||
int filter_r_;
|
||||
int filter_s_;
|
||||
|
||||
int start_r_;
|
||||
int start_s_;
|
||||
|
||||
int64_t reset_bytes_s_;
|
||||
int64_t reset_bytes_r_;
|
||||
|
||||
//
|
||||
// Assertions
|
||||
//
|
||||
|
||||
// We map predicates into bits packed in this uint32_t container
|
||||
static_assert(ThreadMap::Iterations::kStrided *
|
||||
ThreadMap::Iterations::kContiguous < sizeof(predicates_) * 8,
|
||||
"Currently, the number of loads per iteration is limited by the size of the predicates container.");
|
||||
|
||||
public:
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dDgradFilterTileAccessIteratorOptimized(
|
||||
Conv2dStridedDgradFilterIteratorOptimizedParams const ¶ms,
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Element const *ptr,
|
||||
int thread_idx,
|
||||
int start_r, int start_s,
|
||||
MatrixCoord const &threadblock_offset = MatrixCoord()
|
||||
):
|
||||
params_(params),
|
||||
problem_size_(problem_size),
|
||||
pointer_(reinterpret_cast<char const *>(ptr)),
|
||||
predicates_{0},
|
||||
filter_r_(start_r),
|
||||
filter_s_(start_s),
|
||||
start_r_(start_r),
|
||||
start_s_(start_s) {
|
||||
|
||||
layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);
|
||||
|
||||
filter_k_ = threadblock_offset.row() + thread_coord.strided();
|
||||
Index column = threadblock_offset.column() + thread_coord.contiguous();
|
||||
|
||||
reset_bytes_s_ = (problem_size_.num_gemm_k_filter_s(start_s_) - 1) * params_.inc_next[0];
|
||||
reset_bytes_r_ = reset_bytes_s_ +
|
||||
(problem_size_.num_gemm_k_filter_r(start_r_) - 1) * params_.inc_next[1];
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
||||
|
||||
int filter_k = filter_k_ + s * ThreadMap::Delta::kStrided;
|
||||
int filter_c = column + c * ThreadMap::Delta::kContiguous;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < kAccessesPerVector; ++v) {
|
||||
|
||||
uint32_t pred = ((filter_k < problem_size_.K && (filter_c + v * AccessType::kElements) < problem_size_.C) ? 1u : 0);
|
||||
|
||||
int pred_idx = c + s * ThreadMap::Iterations::kContiguous;
|
||||
|
||||
predicates_[v] |= (pred << pred_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TensorCoord coord{filter_k_, filter_r_, filter_s_, column};
|
||||
|
||||
pointer_ += params_.layout(coord) * sizeof_bits<Element>::value / 8;
|
||||
|
||||
set_iteration_index(0);
|
||||
}
|
||||
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
CUTLASS_HOST_DEVICE
|
||||
void add_pointer_offset(LongIndex pointer_offset) {
|
||||
|
||||
pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void advance() {
|
||||
|
||||
int next_idx = 0;
|
||||
LongIndex reset_bytes = params_.reset_bytes;
|
||||
|
||||
// Move filter_s by stride_w
|
||||
filter_s_ += problem_size_.stride_w;
|
||||
if (filter_s_ >= problem_size_.S) {
|
||||
|
||||
// Restore filter_s
|
||||
filter_s_ = start_s_;
|
||||
|
||||
// Move filter_r by stride_h
|
||||
filter_r_ += problem_size_.stride_h;
|
||||
|
||||
bool check = (filter_r_ < problem_size_.R);
|
||||
|
||||
filter_r_ = check ? filter_r_ : start_r_;
|
||||
next_idx = check ? 1 : 2;
|
||||
reset_bytes += (check ? reset_bytes_s_ : reset_bytes_r_);
|
||||
}
|
||||
|
||||
// offset pointers by offset_bytes
|
||||
pointer_ += (params_.inc_next[next_idx] - reset_bytes);
|
||||
|
||||
if (next_idx == 2) {
|
||||
filter_k_ += params_.filter_k_delta;
|
||||
}
|
||||
|
||||
// Clear predicates if needed
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
if (filter_k_ + s * ThreadMap::Delta::kStrided >= problem_size_.K) {
|
||||
uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < kAccessesPerVector; ++v) {
|
||||
predicates_[v] = (predicates_[v] & (~kClearMask));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if the current coordinate is within the filter tensor W
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool valid() {
|
||||
LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous;
|
||||
return (predicates_[iteration_vector_] & (1u << pred_idx));
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
AccessType const *get() const {
|
||||
return reinterpret_cast<AccessType const *>(pointer_ +
|
||||
iteration_contiguous_ * ThreadMap::Delta::kContiguous * sizeof_bits<Element>::value / 8) + iteration_vector_;
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dDgradFilterTileAccessIteratorOptimized &operator++() {
|
||||
++iteration_vector_;
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
}
|
||||
iteration_contiguous_ = 0;
|
||||
|
||||
++iteration_strided_;
|
||||
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
||||
|
||||
// Move to the next K coordinate within the tile
|
||||
pointer_ += params_.inc_next_strided;
|
||||
|
||||
return *this;
|
||||
}
|
||||
iteration_strided_ = 0;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Determines whether the Implicit GEMM can execute the given problem.
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.C % AccessType::kElements) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Conv2dDgradFilterTileAccessIteratorOptimized unity strided dgrad is more performant for dgrad
|
||||
// on problem sizes with stride = {1x1}
|
||||
template <
|
||||
|
||||
@@ -268,11 +268,13 @@ public:
|
||||
p += (conv_sign * (filter_r_ / problem_size_.stride_h));
|
||||
q += (conv_sign * (filter_s_ / problem_size_.stride_w));
|
||||
|
||||
int k = filter_k_ + iteration_vector_ * AccessType::kElements;
|
||||
|
||||
return TensorCoord(
|
||||
n,
|
||||
p,
|
||||
q,
|
||||
filter_k_);
|
||||
k);
|
||||
}
|
||||
|
||||
|
||||
@@ -286,7 +288,7 @@ public:
|
||||
coord.n() < problem_size_.N &&
|
||||
coord.h() >= 0 && coord.h() < problem_size_.P &&
|
||||
coord.w() >= 0 && coord.w() < problem_size_.Q &&
|
||||
(coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.K;
|
||||
coord.c() < problem_size_.K;
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
@@ -296,7 +298,7 @@ public:
|
||||
TensorCoord coord = at();
|
||||
LongIndex offset = params_.layout(coord);
|
||||
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8) + iteration_vector_;
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
@@ -313,6 +315,7 @@ public:
|
||||
return *this;
|
||||
}
|
||||
iteration_contiguous_ = 0;
|
||||
|
||||
++iteration_strided_;
|
||||
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
||||
return *this;
|
||||
@@ -516,7 +519,9 @@ public:
|
||||
int p = (h + problem_size_.pad_h - r * problem_size_.dilation_h) / problem_size_.stride_h;
|
||||
int q = (w + problem_size_.pad_w - s * problem_size_.dilation_w) / problem_size_.stride_w;
|
||||
|
||||
return TensorCoord(n, p, q, filter_k_);
|
||||
int k = filter_k_ + iteration_vector_ * AccessType::kElements;
|
||||
|
||||
return TensorCoord(n, p, q, k);
|
||||
|
||||
}
|
||||
|
||||
@@ -529,7 +534,7 @@ public:
|
||||
return coord.n() < problem_size_.N &&
|
||||
coord.h() >= 0 && coord.h() < problem_size_.P &&
|
||||
coord.w() >= 0 && coord.w() < problem_size_.Q &&
|
||||
(coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.K;
|
||||
coord.c() < problem_size_.K;
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
@@ -539,7 +544,7 @@ public:
|
||||
TensorCoord coord = at();
|
||||
LongIndex offset = params_.layout(coord);
|
||||
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8) + iteration_vector_;
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
|
||||
@@ -67,6 +67,380 @@ template <
|
||||
class Conv2dDgradOutputGradientTileAccessIteratorOptimized;
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Conv2dDgradOutputGradientTileAccessIteratorOptimized strided dgrad needs special handling
|
||||
// to skip MMAs (Dx = Dy * w) on invalid filter positions
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Tile access iterator over the output-gradient tensor Dy for strided
/// (stride != 1x1) dgrad. Each filter position (r, s) that is congruent to
/// the starting position modulo the convolution stride is visited; validity
/// of each access is tracked in precomputed bit masks indexed by filter r
/// (masks_[..][..][0]) and filter s (masks_[..][..][1]).
template <
  typename Shape_,
  typename Element_,
  typename ThreadMap_,
  typename AccessType_
>
class Conv2dDgradOutputGradientTileAccessIteratorOptimized <
  Shape_,
  Element_,
  ThreadMap_,
  conv::StrideSupport::kStrided,
  AccessType_
> {
public:

  //
  // Types
  //
  using Shape = Shape_;
  using Element = Element_;
  using Layout = layout::TensorNHWC;
  using ThreadMap = ThreadMap_;
  using AccessType = AccessType_;
  using TensorRef = cutlass::TensorRef<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;
  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;
  static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized;
  static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
  static int const kConvDim = 2;
  using ConvProblemSize = typename conv::Conv2dProblemSize;

  // Number of AccessType-sized loads needed to cover one thread-map access.
  static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;

  static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
    "Vectors implied by the thread map must be divisible by the access type.");

  using Mask = uint64_t;

  static_assert(sizeof_bits<Element>::value >= 8,
    "DGRAD requires elements of size 8b or greater.");

  //
  // Simplifying assertions
  //
  static_assert(ThreadMap::Iterations::kContiguous == 1,
    "Require Iterations::kContiguous == 1");

  //
  // Parameters structure
  //
  using Params = Conv2dStridedDgradOutputGradientIteratorOptimizedParams;

private:

  Params const &params_;
  Conv2dProblemSize const &problem_size_;
  LongIndex iteration_contiguous_;
  LongIndex iteration_strided_;
  LongIndex iteration_vector_;

  // One pointer per access
  char const *pointer_[ThreadMap::Iterations::kStrided];

  // Current logical coordinates: GEMM-K channel and filter position (r, s).
  int filter_k_;
  int filter_r_;
  int filter_s_;
  // First filter position handled by this iterator (restored on wrap).
  int start_r_;
  int start_s_;
  // Byte distances used to rewind pointers when s (or r and s) wrap.
  int64_t reset_bytes_s_;
  int64_t reset_bytes_r_;

  // Validity bits per (strided iteration, vector access):
  // word [0] is indexed by filter r, word [1] by filter s.
  // can_implement() limits R, S <= 32 so the bit positions fit.
  Index masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2];

public:

  /// Constructs the iterator: remaps NHW rows for the starting filter
  /// position, initializes one pointer per strided access, and precomputes
  /// the r/s validity masks.
  CUTLASS_HOST_DEVICE
  Conv2dDgradOutputGradientTileAccessIteratorOptimized(
    Params const &params,
    Conv2dProblemSize const &problem_size,
    Element const *ptr,
    int thread_idx,
    FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod,
    int start_r, int start_s,
    MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles
  ):
    params_(params),
    problem_size_(problem_size),
    filter_k_(0),
    filter_r_(start_r),
    filter_s_(start_s),
    start_r_(start_r),
    start_s_(start_s) {

    layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);

    filter_k_ = threadblock_offset.column() + thread_coord.contiguous();

    // Byte distance traversed across all s steps for this starting s
    // (used to rewind when s wraps back to start_s_).
    reset_bytes_s_ = (problem_size_.num_gemm_k_filter_s(start_s_) - 1) * params_.inc_next[0];

    // Byte distance traversed across all s and r steps (used when both wrap).
    reset_bytes_r_ = (problem_size_.num_gemm_k_filter_s(start_s_) - 1) * params_.inc_next[0] +
                     (problem_size_.num_gemm_k_filter_r(start_r_) - 1) * params_.inc_next[1];

    int offset_n[ThreadMap::Iterations::kStrided];
    int offset_p[ThreadMap::Iterations::kStrided];
    int offset_q[ThreadMap::Iterations::kStrided];

    int filter_r = filter_r_;
    int filter_s = filter_s_;

    // Convolution (as opposed to cross-correlation) flips the filter.
    if (problem_size_.mode == Mode::kConvolution) {
      filter_r = (problem_size_.R - 1 - filter_r);
      filter_s = (problem_size_.S - 1 - filter_s);
    }

    // Starting h, w positions for filter position in gemm_k=0
    int start_h, start_w;
    strided_dgrad_starting_coords(
      problem_size_,
      stride_h_divmod, stride_w_divmod,
      filter_r, filter_s,
      start_h, start_w);

    // Effective starting P and Q for filter position required for remapping NHW rows
    int P = (problem_size_.H - start_h + problem_size_.stride_h - 1) / problem_size_.stride_h;
    int Q = (problem_size_.W - start_w + problem_size_.stride_w - 1) / problem_size_.stride_w;

    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {

      pointer_[s] = reinterpret_cast<char const *>(ptr);

      int offset_npq = (threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided) % params_.tiled_rows_per_filter;

      // (STEP 1) [reorder NHW rows to start with same filter positions]
      offset_n[s] = offset_npq / (P * Q);
      int residual = offset_npq % (P * Q);

      int p = (residual / Q);
      int q = (residual % Q);

      int mapped_h = (start_h + p * problem_size_.stride_h);
      int mapped_w = (start_w + q * problem_size_.stride_w);

      // Access (p, q) coordinates for Dy tensor for filter position in gemm_k=0
      // note that (h + pad_h - filter_r) and (w + pad_w - filter_s) are ensured to be
      // divisible by stride_h and stride_w
      offset_p[s] = (mapped_h + problem_size_.pad_h - filter_r) / problem_size_.stride_h;
      offset_q[s] = (mapped_w + problem_size_.pad_w - filter_s) / problem_size_.stride_w;

      // Initialize pointers for gemm_k=0
      TensorCoord coord{offset_n[s], offset_p[s], offset_q[s], filter_k_};

      pointer_[s] += params_.layout(coord) * sizeof_bits<Element>::value / 8;
    }

    //
    // Precompute mask predicates
    //
    clear_mask();

    // Set bit r in mask word [0] when the p coordinate reached at filter
    // position r lies inside [0, P) and n is in range.
    CUTLASS_PRAGMA_NO_UNROLL
    for (int r = start_r; r < problem_size_.R; r += problem_size_.stride_h) {
      CUTLASS_PRAGMA_UNROLL
      for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) {

        int p = offset_p[s_idx] ;

        p += (params_.conv_sign * (r / problem_size_.stride_h));

        bool pred = (offset_n[s_idx] < problem_size_.N && p >= 0 && p < problem_size_.P);

        CUTLASS_PRAGMA_UNROLL
        for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
          masks_[s_idx][v_idx][0] |= (pred << r);
        }
      }
    }

    // Set bit s in mask word [1] when the q coordinate reached at filter
    // position s lies inside [0, Q).
    CUTLASS_PRAGMA_NO_UNROLL
    for(int s = start_s; s < problem_size_.S; s += problem_size_.stride_w) {
      CUTLASS_PRAGMA_UNROLL
      for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) {

        int q = offset_q[s_idx];
        q += (params_.conv_sign * (s / problem_size_.stride_w));

        bool pred = (q >=0 && q < problem_size_.Q);

        CUTLASS_PRAGMA_UNROLL
        for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
          masks_[s_idx][v_idx][1] |= (pred << s);
        }
      }
    }

    // Disable any vector access whose K coordinate is out of range.
    CUTLASS_PRAGMA_UNROLL
    for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
      clear_mask(v_idx, (filter_k_ + v_idx * AccessType::kElements) >= problem_size.K);
    }

    set_iteration_index(0);
  }

  /// Builds the precomputed Params object for this iterator.
  CUTLASS_HOST_DEVICE
  static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) {
    return Params(problem_size,
                  layout,
                  sizeof_bits<Element>::value,
                  {Shape::kRow, Shape::kColumn});
  }

private:

  /// Adds a pointer offset in units of element
  CUTLASS_HOST_DEVICE
  void add_byte_offset_(LongIndex byte_offset, LongIndex byte_reset = 0) {

    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      pointer_[s] += byte_offset - byte_reset;
    }
  }

public:

  /// Overrides the internal iteration index
  CUTLASS_HOST_DEVICE
  void set_iteration_index(Index index) {
    // Decompose a linear access index into (vector, contiguous, strided).
    iteration_vector_ = index % kAccessesPerVector;
    int residual_access = index / kAccessesPerVector;
    iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
    iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
  }

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    add_byte_offset_(pointer_offset * sizeof_bits<Element>::value / 8);
  }

  /// Advances to the next GEMM-K tile: step s by stride_w; on wrap, step r
  /// by stride_h; on wrap of both, advance the K channel.
  CUTLASS_HOST_DEVICE
  void advance() {

    int next_idx = 0;
    int64_t reset_bytes = 0;

    // Move filter_s by stride_w
    filter_s_ += problem_size_.stride_w;
    if (filter_s_ >= problem_size_.S) {

      // Restore filter_s
      filter_s_ = start_s_;

      // Move filter_r by stride_h
      filter_r_ += problem_size_.stride_h;
      if (filter_r_ < problem_size_.R) {

        next_idx = 1;

        // Restore bytes in q coordinate (Mma in filter s dimension)
        reset_bytes = reset_bytes_s_;

      } else {

        // Restore filter_r
        filter_r_ = start_r_;

        next_idx = 2;

        // Restore bytes in p and q coordinate (Mma in filter s and r dimension)
        reset_bytes = reset_bytes_r_;
      }
    }

    // offset pointers by offset_bytes
    add_byte_offset_(params_.inc_next[next_idx] - reset_bytes);

    if (next_idx == 2) {
      filter_k_ += params_.filter_k_delta;
    }

    // Clear accesses whose new K coordinate falls out of range.
    CUTLASS_PRAGMA_UNROLL
    for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
      clear_mask(v_idx, (filter_k_ + v_idx * AccessType::kElements) >= problem_size_.K);
    }
  }

  /// Clears the predicates
  CUTLASS_HOST_DEVICE
  void clear_mask(bool clear = true) {
    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      CUTLASS_PRAGMA_UNROLL
      for (int v = 0; v < kAccessesPerVector; ++v) {
        // Branchless form keeps the unrolled loop free of divergence.
        masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0];
        masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1];
      }
    }
  }

  /// Clears the predicates for one vector access only
  CUTLASS_HOST_DEVICE
  void clear_mask(int v, bool clear = true) {
    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0];
      masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1];
    }
  }

  /// Returns true if the current coordinate is within the output tensor Dy
  CUTLASS_HOST_DEVICE
  bool valid() const {
    // Valid only when both the r-bit and the s-bit are set for this access.
    return
      (masks_[iteration_strided_][iteration_vector_][0] & (Index(1) << filter_r_)) &&
      (masks_[iteration_strided_][iteration_vector_][1] & (Index(1) << filter_s_));
  }

  /// Returns a pointer to the vector starting at the current coordinate
  CUTLASS_HOST_DEVICE
  AccessType const *get() const {

    return reinterpret_cast<AccessType const *>(pointer_[iteration_strided_]) + iteration_vector_;
  }

  /// Increments to the next memory access
  CUTLASS_HOST_DEVICE
  Conv2dDgradOutputGradientTileAccessIteratorOptimized &operator++() {
    ++iteration_vector_;
    if (iteration_vector_ < kAccessesPerVector) {
      return *this;
    }
    iteration_vector_ = 0;

    ++iteration_contiguous_;
    if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
      return *this;
    }
    iteration_contiguous_ = 0;
    ++iteration_strided_;
    if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
      return *this;
    }
    iteration_strided_ = 0;

    return *this;
  }

  /// Determines whether the Implicit GEMM can execute the given problem.
  CUTLASS_HOST_DEVICE
  static Status can_implement(Conv2dProblemSize const &problem_size) {

    // check alignment constraint on iterator's contiguous dimension
    if (problem_size.K % AccessType::kElements) {
      return Status::kErrorInvalidProblem;
    }

    // Limit on filter size: mask bits are indexed directly by r and s.
    if (problem_size.R > 32 || problem_size.S > 32) {
      return Status::kErrorNotSupported;
    }

    return Status::kSuccess;
  }
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Conv2dDgradOutputGradientTileAccessIteratorOptimized unity stride dgrad is optimized for dgrad
|
||||
// with problem stride = {1x1}
|
||||
|
||||
@@ -209,7 +209,9 @@ public:
|
||||
int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h;
|
||||
int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w;
|
||||
|
||||
return TensorCoord(n, h, w, filter_c_);
|
||||
int c = filter_c_ + iteration_vector_ * AccessType::kElements;
|
||||
|
||||
return TensorCoord(n, h, w, c);
|
||||
}
|
||||
|
||||
/// Returns true if the current coordinate is within the activations tensor X
|
||||
@@ -221,7 +223,7 @@ public:
|
||||
return coord.n() < problem_size_.N &&
|
||||
coord.h() >= 0 && coord.h() < problem_size_.H &&
|
||||
coord.w() >= 0 && coord.w() < problem_size_.W &&
|
||||
(coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.C;
|
||||
coord.c() < problem_size_.C;
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
@@ -231,7 +233,7 @@ public:
|
||||
TensorCoord coord = at();
|
||||
LongIndex offset = params_.layout(coord);
|
||||
|
||||
AccessType const *ptr = reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8) + iteration_vector_;
|
||||
AccessType const *ptr = reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
|
||||
|
||||
return ptr;
|
||||
}
|
||||
|
||||
@@ -183,8 +183,9 @@ public:
|
||||
TensorCoord at() const {
|
||||
|
||||
int k = offset_k_[iteration_strided_];
|
||||
int c = filter_c_ + iteration_vector_ * AccessType::kElements;
|
||||
|
||||
return TensorCoord(k, filter_r_, filter_s_, filter_c_);
|
||||
return TensorCoord(k, filter_r_, filter_s_, c);
|
||||
}
|
||||
|
||||
/// Returns true if the current coordinate is within the activations tensor W
|
||||
@@ -194,7 +195,7 @@ public:
|
||||
TensorCoord coord = at();
|
||||
|
||||
return coord.n() < problem_size_.K &&
|
||||
(coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.C;
|
||||
coord.c() < problem_size_.C;
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
@@ -204,7 +205,7 @@ public:
|
||||
TensorCoord coord = at();
|
||||
LongIndex offset = params_.layout(coord);
|
||||
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8) + iteration_vector_;
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
|
||||
@@ -527,6 +527,64 @@ struct Conv2dDgradOutputGradientIteratorOptimizedParams {
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Strided Dgrad Optimized Dy params (layout::TensorNHWC)
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Precomputed parameters for the strided-dgrad output-gradient (Dy)
/// iterator: byte increments between consecutive filter positions and the
/// NHW-row remapping constants. All increments are in bytes.
struct Conv2dStridedDgradOutputGradientIteratorOptimizedParams {

  using Layout = layout::TensorNHWC;

  Layout layout;

  int64_t inc_next[3];    // {next S, next R, next K}

  int filter_k_delta;     // number of logical elements to add to filter_k_

  // Number of GEMM-M rows covered per filter position after row remapping.
  int tiled_rows_per_filter;

  // +1 for convolution mode, -1 for cross-correlation (flips traversal direction).
  int conv_sign;
  //
  // Methods
  //

  CUTLASS_HOST_DEVICE
  Conv2dStridedDgradOutputGradientIteratorOptimizedParams() { }

  /// Precomputes the byte increments from the problem size and Dy layout.
  CUTLASS_HOST_DEVICE
  Conv2dStridedDgradOutputGradientIteratorOptimizedParams(
    Conv2dProblemSize const &problem_size,
    Layout const &layout,             ///< layout object
    int element_size_bits,            ///< size of each element in bits
    MatrixCoord threadblock_shape
  ): layout(layout) {

    int tile_m_per_filter = strided_dgrad_tile_m_per_filter(problem_size, threadblock_shape.row());

    tiled_rows_per_filter = tile_m_per_filter * threadblock_shape.row();

    conv_sign = (problem_size.mode == Mode::kConvolution ? 1 : -1);

    // next S: one dilation step along W (stride()[0] is the W stride in elements)
    inc_next[0] = conv_sign * (
      layout.stride()[0] * problem_size.dilation_w
    ) * element_size_bits / 8;

    // next R: one dilation step along H
    inc_next[1] = conv_sign * (
      layout.stride()[1] * problem_size.dilation_h
    ) * element_size_bits / 8;

    // next K: advance the channel coordinate by one K-tile across split-K slices
    inc_next[2] = (
      threadblock_shape.column() * problem_size.split_k_slices
    ) * element_size_bits / 8;

    // logical offset added to internal channel counter - units are elements, not bytes
    filter_k_delta = threadblock_shape.column() * problem_size.split_k_slices;
  }
};
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Dgrad Optimized w params (layout::TensorNHWC)
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -584,6 +642,73 @@ struct Conv2dDgradFilterIteratorOptimizedParams {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// StridedDgrad Optimized w params (layout::TensorNHWC)
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Precomputed parameters for the strided-dgrad filter (w) iterator:
/// byte increments between K coordinates and between filter positions.
/// All inc_* / reset values are in bytes.
struct Conv2dStridedDgradFilterIteratorOptimizedParams {

  using Layout = layout::TensorNHWC;

  Layout layout;
  int RS;               // total filter positions R * S
  int filter_k_delta;   // logical elements added to filter_k_ per K step

  int64_t inc_next_strided;   // offset in units of bytes to next K coordinate within tile
  int64_t inc_next[3];        // {next S, next R, next K}
  int64_t reset_bytes;        // offset in units of bytes to move back the pointer
  //
  // Methods
  //
  CUTLASS_HOST_DEVICE
  Conv2dStridedDgradFilterIteratorOptimizedParams() { }

  /// Precomputes the byte increments from the problem size and filter layout.
  CUTLASS_HOST_DEVICE
  Conv2dStridedDgradFilterIteratorOptimizedParams(
    Conv2dProblemSize const &problem_size,
    Layout const &layout,
    int element_size_bits,              ///< size of each element in bits
    MatrixCoord threadblock_shape,
    int thread_count,
    int access_size,
    layout::PitchLinearCoord threadmap_iterations,
    layout::PitchLinearCoord threadmap_delta
  ):
    layout(layout), RS(problem_size.R * problem_size.S) {

    TRACE_CONV_INITIALIZERS("conv2d_dgrad", "filter",
      element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta);

    // Step between consecutive strided iterations (stride()[2] is the K stride).
    inc_next_strided = (layout.stride()[2] * threadmap_delta.strided() * element_size_bits) / 8;

    // next S: advance by stride_w filter columns. The commented term is the
    // strided-iteration rewind, applied separately via reset_bytes.
    inc_next[0] =
      ( layout.stride()[0] * problem_size.stride_w
        //- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
      ) * element_size_bits / 8;

    // next R: advance by stride_h filter rows.
    inc_next[1] =
      ( layout.stride()[1] * problem_size.stride_h
        //- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
      ) * element_size_bits / 8;

    // next K: advance the K coordinate by one tile across split-K slices.
    inc_next[2] =
      (
        threadblock_shape.row() * problem_size.split_k_slices * layout.stride()[2]
        //- (problem_size.R * problem_size.S - 1) * layout.stride()[0]
        //- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
      ) * element_size_bits / 8;

    // offset in units of bytes to move the pointer in backward direction
    reset_bytes = (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
                  * element_size_bits / 8;

    filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices;
  }
};
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Parameters object for Conv2d WGRAD Output Gradient (dy) iterator
|
||||
struct Conv2dWgradOutputGradientIteratorOptimizedParams {
|
||||
|
||||
|
||||
@@ -183,10 +183,13 @@ public:
|
||||
int r, s, c;
|
||||
|
||||
if (kAccessesPerVector == 1) {
|
||||
/// One 128b aligned access fetching more than one element
|
||||
c = filter_c_[iteration_contiguous_];
|
||||
r = filter_r_[iteration_contiguous_];
|
||||
s = filter_s_[iteration_contiguous_];
|
||||
c = filter_c_[iteration_contiguous_];
|
||||
} else {
|
||||
}
|
||||
else {
|
||||
/// Multiple access to support non-128b alignment in contiguous dimension
|
||||
c = (filter_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements) % problem_size_.C;
|
||||
int wrap_c = (filter_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements) / problem_size_.C;
|
||||
s = (filter_s_[iteration_contiguous_] + wrap_c) % problem_size_.S;
|
||||
|
||||
@@ -205,6 +205,8 @@ public:
|
||||
int c = filter_c_[iteration_contiguous_];
|
||||
|
||||
if (kAccessesPerVector > 1) {
|
||||
// This code section is only to support non-128b alignment
|
||||
// Multiple access to support non-128b alignment in contiguous dimension
|
||||
int wrap_c;
|
||||
params_.c_divmod(wrap_c, c, c + iteration_vector_ * AccessType::kElements);
|
||||
|
||||
|
||||
@@ -182,7 +182,9 @@ public:
|
||||
int p = residual / problem_size_.Q;
|
||||
int q = residual % problem_size_.Q;
|
||||
|
||||
return TensorCoord(n, p, q, filter_k_[iteration_contiguous_]);
|
||||
int k = filter_k_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements;
|
||||
|
||||
return TensorCoord(n, p, q, k);
|
||||
}
|
||||
|
||||
|
||||
@@ -194,7 +196,7 @@ public:
|
||||
return coord.n() < problem_size_.N &&
|
||||
coord.h() < problem_size_.P &&
|
||||
coord.w() < problem_size_.Q &&
|
||||
(coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.K;
|
||||
coord.c() < problem_size_.K;
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
@@ -204,7 +206,7 @@ public:
|
||||
TensorCoord coord = at();
|
||||
LongIndex offset = params_.layout(coord);
|
||||
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8) + iteration_vector_;
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
|
||||
@@ -192,6 +192,32 @@ struct GELU_taylor {
|
||||
}
|
||||
};
|
||||
|
||||
/// half_t specialization of the tanh-approximation GELU:
///   GELU(z) ~= 0.5 * z * (1 + tanh( sqrt(2/pi) * (z + 0.044715 * z^3) ))
/// Computed entirely in half precision using the fast tanh operator.
template <int N>
struct GELU_taylor<Array<half_t, N> > {
  static const bool kIsHeavy=true;
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const &z) const {

    using T = half_t;
    Array<half_t, N> y;

    // k0 = sqrt(2/pi), k1 = GELU cubic-term coefficient
    half_t k0 = half_t(0.7978845608028654);
    half_t k1 = half_t(0.044715);

    multiply_add<Array<half_t, N>> fma;
    multiplies<Array<half_t, N>> mul;
    plus<Array<half_t, N>> add;

    fast_tanh_op<Array<half_t, N>> tanh;

    // u = k0 * z * (k1 * z^2 + 1) == sqrt(2/pi) * (z + k1 * z^3)
    Array<half_t, N> u = mul(mul(k0, z), fma(mul(k1, z), z, cutlass::constants::one<T>()));

    // y = 0.5 * z * (1 + tanh(u))
    y = mul(mul(z, cutlass::constants::half<T>()), add(cutlass::constants::one<T>(), tanh(u)));

    return y;
  }
};
|
||||
|
||||
template <typename T, int N>
|
||||
struct GELU_taylor<Array<T, N> > {
|
||||
static const bool kIsHeavy=true;
|
||||
|
||||
@@ -234,8 +234,9 @@ public:
|
||||
if (WarpShape::kN == 64) {
|
||||
ptr = pointers_[n / 4];
|
||||
}
|
||||
|
||||
#else
|
||||
else
|
||||
#endif
|
||||
{
|
||||
// This is the reference implementation
|
||||
int column_idx = warp_column_ + n * Detail::kLanesInQuad * Policy::kElementsPerAccess;
|
||||
int ptr_idx = ((column_idx * sizeof_bits<Element>::value) / 1024) % Detail::kPointerCount;
|
||||
@@ -252,7 +253,8 @@ public:
|
||||
else if (ptr_idx == 3) {
|
||||
ptr = pointers_[3 % Detail::kPointerCount];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
|
||||
int offset = n * Detail::kLanesInQuad + pointer_offset / Policy::kElementsPerAccess;
|
||||
|
||||
@@ -34,6 +34,7 @@
|
||||
#endif
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/uint128.h"
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
@@ -724,7 +725,13 @@ double fast_log(double x) {
|
||||
CUTLASS_HOST_DEVICE
|
||||
float fast_tanh(float x) {
|
||||
#if defined(__CUDA_ARCH__)
|
||||
return ::tanhf(x);
|
||||
#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDA_ARCH__ >= 750)
|
||||
float y;
|
||||
asm volatile ( "tanh.approx.f32 %0, %1; " : "=f"(y) : "f"(x));
|
||||
return y;
|
||||
#else
|
||||
return ::tanhf(x);
|
||||
#endif
|
||||
#else
|
||||
return std::tanh(x);
|
||||
#endif
|
||||
@@ -739,6 +746,74 @@ double fast_tanh(double x) {
|
||||
#endif
|
||||
}
|
||||
|
||||
/// Fast hyperbolic tangent for half_t. On SM75+ devices compiled with
/// CUDA 11+, uses the hardware tanh.approx.f16 instruction; otherwise
/// falls back to the float implementation and converts back to half.
CUTLASS_HOST_DEVICE
half_t fast_tanh(half_t x) {
  #if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 11) && (__CUDA_ARCH__ >= 750)

  // In-place: reads and overwrites the raw 16-bit representation of x.
  asm volatile ( "tanh.approx.f16 %0, %1;" : "=h"(x.raw()) : "h"(x.raw()));
  return x;

  #else
  return half_t(fast_tanh(float(x)));
  #endif
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Functor wrapper around fast_tanh for a scalar type T; specialized below
/// for arrays.
template <typename T>
struct fast_tanh_op {
  CUTLASS_HOST_DEVICE
  T operator()(T const &rhs) const {
    return fast_tanh(rhs);
  }
};
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 11) && (__CUDA_ARCH__ >= 750)
|
||||
template <int N>
|
||||
struct fast_tanh_op<Array<half_t, N>> {
|
||||
CUTLASS_DEVICE
|
||||
Array<half_t, N> operator()(Array<half_t, N> const &rhs) const {
|
||||
|
||||
Array<half_t, N> result;
|
||||
|
||||
// use x2 specialization
|
||||
uint32_t const *in = reinterpret_cast<uint32_t const *>(&rhs);
|
||||
uint32_t *out = reinterpret_cast<uint32_t *>(&result);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N / 2; ++i) {
|
||||
asm volatile ("tanh.approx.f16x2 %0, %1;" : "=r"(out[i]) : "r"(in[i]));
|
||||
}
|
||||
|
||||
// residual
|
||||
if (N % 2) {
|
||||
uint16_t const *in = reinterpret_cast<uint16_t const *>(&rhs);
|
||||
uint16_t *out = reinterpret_cast<uint16_t *>(&result);
|
||||
asm volatile ("tanh.approx.f16 %0, %1;" : "=h"(out[N - 1]) : "h"(in[N - 1]));
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
#endif // #if defined(__CUDA_ARCH__)
|
||||
|
||||
template <typename T, int N>
|
||||
struct fast_tanh_op<Array<T, N>> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &rhs) const {
|
||||
|
||||
fast_tanh_op<T> fast_op;
|
||||
Array<T, N> y;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N; ++i) {
|
||||
y[i] = fast_op(rhs[i]);
|
||||
}
|
||||
|
||||
return y;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
@@ -126,7 +126,7 @@ struct DefaultGemmWithKReduction {
|
||||
ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp,
|
||||
EpilogueOutputOp::kCount>::Epilogue;
|
||||
|
||||
/// Define the epilogue
|
||||
/// Define the epilogue of the reduction vector
|
||||
using EpilogueGemmKReduction =
|
||||
typename cutlass::epilogue::threadblock::EpilogueGemmKReduction<
|
||||
ElementAccumulator, ElementC, ThreadblockShape, typename Mma::Operator, kReduceKForA>;
|
||||
|
||||
@@ -582,6 +582,13 @@ public:
|
||||
__threadfence();
|
||||
}
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
epilogue(
|
||||
output_op,
|
||||
iterator_D,
|
||||
accumulators,
|
||||
iterator_C);
|
||||
|
||||
if ((kReduceKForA && threadblock_tile_offset.n() == 0)
|
||||
|| (!kReduceKForA && threadblock_tile_offset.m() == 0)) {
|
||||
|
||||
@@ -610,14 +617,7 @@ public:
|
||||
&& (threadblock_tile_offset.k() > 0));
|
||||
}
|
||||
}
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
epilogue(
|
||||
output_op,
|
||||
iterator_D,
|
||||
accumulators,
|
||||
iterator_C);
|
||||
|
||||
|
||||
//
|
||||
// Release the semaphore
|
||||
//
|
||||
|
||||
@@ -378,11 +378,21 @@ struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
|| platform::is_same<LayoutC, layout::AffineRankN<2>>::value,
|
||||
"simt epilogue must be row major");
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpA =
|
||||
((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator, LayoutC, arch::OpClassSimt,
|
||||
Stages, Operator>;
|
||||
Stages, Operator, false, CacheOpA, CacheOpB>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
|
||||
@@ -1111,8 +1111,8 @@ struct DefaultMmaCore<
|
||||
using ElementC = complex<double>;
|
||||
using LayoutC = LayoutC_;
|
||||
static int const kStages = Stages;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Global;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Global;
|
||||
static const ComplexTransform TransformA = TransformA_;
|
||||
static const ComplexTransform TransformB = TransformB_;
|
||||
|
||||
|
||||
@@ -116,11 +116,22 @@ struct DefaultMultistageMmaComplex<ElementA, LayoutA, ElementB, LayoutB,
|
||||
ElementAccumulator, layout::RowMajor, OperatorClass,
|
||||
ArchTag, ThreadblockShape, WarpShape,
|
||||
InstructionShape, Stages, TransformA, TransformB, Operator> {
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpA =
|
||||
(sizeof_bits<ElementA>::value == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
(sizeof_bits<ElementB>::value == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplexCore<
|
||||
ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass,
|
||||
Stages, TransformA, TransformB, Operator>;
|
||||
Stages, TransformA, TransformB, Operator, CacheOpA, CacheOpB>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
|
||||
@@ -113,8 +113,8 @@ struct DefaultMultistageMmaComplexCore<
|
||||
static ComplexTransform const kTransformA = TransformA;
|
||||
static ComplexTransform const kTransformB = TransformB;
|
||||
using Operator = Operator_;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Global;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Global;
|
||||
|
||||
/// Number of warps present
|
||||
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
|
||||
@@ -242,8 +242,8 @@ struct DefaultMultistageMmaComplexCore<
|
||||
using Operator = Operator_;
|
||||
static ComplexTransform const kTransformA = TransformA;
|
||||
static ComplexTransform const kTransformB = TransformB;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Global;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Global;
|
||||
|
||||
/// Number of warps present
|
||||
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
|
||||
@@ -371,8 +371,8 @@ struct DefaultMultistageMmaComplexCore<
|
||||
static ComplexTransform const kTransformA = TransformA;
|
||||
static ComplexTransform const kTransformB = TransformB;
|
||||
using Operator = Operator_;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Global;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Global;
|
||||
|
||||
/// Number of warps present
|
||||
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
|
||||
@@ -501,8 +501,8 @@ struct DefaultMultistageMmaComplexCore<
|
||||
static ComplexTransform const kTransformA = TransformA;
|
||||
static ComplexTransform const kTransformB = TransformB;
|
||||
using Operator = Operator_;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Global;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Global;
|
||||
|
||||
/// Number of warps present
|
||||
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
|
||||
@@ -1159,8 +1159,8 @@ struct DefaultMultistageMmaComplexCore<
|
||||
static ComplexTransform const kTransformA = TransformA;
|
||||
static ComplexTransform const kTransformB = TransformB;
|
||||
using Operator = Operator_;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
|
||||
|
||||
/// Number of warps present
|
||||
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
|
||||
@@ -1326,8 +1326,8 @@ struct DefaultMultistageMmaComplexCore<
|
||||
static ComplexTransform const kTransformA = TransformA;
|
||||
static ComplexTransform const kTransformB = TransformB;
|
||||
using Operator = Operator_;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
|
||||
|
||||
/// Number of warps present
|
||||
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
|
||||
@@ -1490,8 +1490,8 @@ struct DefaultMultistageMmaComplexCore<
|
||||
static ComplexTransform const kTransformA = TransformA;
|
||||
static ComplexTransform const kTransformB = TransformB;
|
||||
using Operator = Operator_;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
|
||||
|
||||
/// Number of warps present
|
||||
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
|
||||
@@ -1660,8 +1660,8 @@ struct DefaultMultistageMmaComplexCore<
|
||||
static ComplexTransform const kTransformA = TransformA;
|
||||
static ComplexTransform const kTransformB = TransformB;
|
||||
using Operator = Operator_;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
|
||||
|
||||
/// Number of warps present
|
||||
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
|
||||
@@ -1775,7 +1775,6 @@ struct DefaultMultistageMmaComplexCore<
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
@@ -298,7 +298,6 @@ class PredicatedTileAccessIteratorPredicates {
|
||||
return pred;
|
||||
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
Reference in New Issue
Block a user