mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-11 08:50:09 +00:00
Relax stream K gemm alignment constraints (#717)
* Relax stream K gemm alignment constraints
The current alignment requirements are too strict. Make them identical
to the checks for the regular universal gemm.
* Revert "Relax stream K gemm alignment constraints"
This reverts commit 31e80a250e.
* Relax stream K gemm alignment constraints
The current alignment requirements are too strict. Make them identical
to the checks for the regular universal gemm.
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
@@ -30,7 +30,7 @@
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
\brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
@@ -555,31 +555,73 @@ public:
|
||||
static Status can_implement(
|
||||
cutlass::gemm::GemmCoord const & problem_size)
|
||||
{
|
||||
CUTLASS_TRACE_HOST("GemmUniversalStreamk::can_implement()");
|
||||
CUTLASS_TRACE_HOST("GemmUniversal::can_implement()");
|
||||
|
||||
static int const kAlignmentA = (platform::is_same<typename Mma::IteratorA::Layout,
|
||||
static int const kAlignmentA = (platform::is_same<LayoutA,
|
||||
layout::ColumnMajorInterleaved<32>>::value)
|
||||
? 32
|
||||
: (platform::is_same<typename Mma::IteratorA::Layout,
|
||||
: (platform::is_same<LayoutA,
|
||||
layout::ColumnMajorInterleaved<64>>::value)
|
||||
? 64
|
||||
: Mma::IteratorA::AccessType::kElements;
|
||||
|
||||
static int const kAlignmentB = (platform::is_same<typename Mma::IteratorB::Layout,
|
||||
layout::RowMajorInterleaved<32>>::value)
|
||||
static int const kAlignmentB = (platform::is_same<LayoutB,
|
||||
layout::RowMajorInterleaved<32>>::value)
|
||||
? 32
|
||||
: (platform::is_same<typename Mma::IteratorB::Layout,
|
||||
: (platform::is_same<LayoutB,
|
||||
layout::RowMajorInterleaved<64>>::value)
|
||||
? 64
|
||||
: Mma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentC = (platform::is_same<LayoutC,
|
||||
layout::ColumnMajorInterleaved<32>>::value)
|
||||
? 32
|
||||
: (platform::is_same<LayoutC,
|
||||
layout::ColumnMajorInterleaved<64>>::value)
|
||||
? 64
|
||||
: Epilogue::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
|
||||
bool isAMisaligned = false;
|
||||
bool isBMisaligned = false;
|
||||
bool isCMisaligned = false;
|
||||
|
||||
if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) ||
|
||||
(problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) ||
|
||||
(problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC))
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand");
|
||||
if (platform::is_same<LayoutA, layout::RowMajor>::value) {
|
||||
isAMisaligned = problem_size.k() % kAlignmentA;
|
||||
} else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
|
||||
isAMisaligned = problem_size.m() % kAlignmentA;
|
||||
} else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value
|
||||
|| platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
|
||||
isAMisaligned = problem_size.k() % kAlignmentA;
|
||||
}
|
||||
|
||||
if (platform::is_same<LayoutB, layout::RowMajor>::value) {
|
||||
isBMisaligned = problem_size.n() % kAlignmentB;
|
||||
} else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
|
||||
isBMisaligned = problem_size.k() % kAlignmentB;
|
||||
} else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value
|
||||
|| platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
|
||||
isBMisaligned = problem_size.k() % kAlignmentB;
|
||||
}
|
||||
|
||||
if (platform::is_same<LayoutC, layout::RowMajor>::value) {
|
||||
isCMisaligned = problem_size.n() % kAlignmentC;
|
||||
} else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
|
||||
isCMisaligned = problem_size.m() % kAlignmentC;
|
||||
} else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value
|
||||
|| platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
|
||||
isCMisaligned = problem_size.n() % kAlignmentC;
|
||||
}
|
||||
|
||||
if (isAMisaligned) {
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (isBMisaligned) {
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (isCMisaligned) {
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user