Relax stream K gemm alignment constraints

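The current alignment requirements of the Stream-K GEMM are too strict: every
problem extent touched by an operand is checked against that operand's
alignment, regardless of the operand's layout. Make the checks identical to
those of the regular universal GEMM, which test only the extent selected by
each operand's layout.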
This commit is contained in:
mikeiovine
2022-12-02 07:49:27 -08:00
parent c975e2ccbb
commit 31e80a250e
3 changed files with 162 additions and 81 deletions

View File

@@ -30,7 +30,7 @@
**************************************************************************************************/
/*! \file
\brief
\brief
*/
#pragma once
@@ -575,11 +575,49 @@ public:
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) ||
(problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) ||
(problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC))
{
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand");
bool isAMisaligned = false;
bool isBMisaligned = false;
bool isCMisaligned = false;
if (platform::is_same<LayoutA, layout::RowMajor>::value) {
isAMisaligned = problem_size.k() % kAlignmentA;
} else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
isAMisaligned = problem_size.m() % kAlignmentA;
} else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value
|| platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
isAMisaligned = problem_size.k() % kAlignmentA;
}
if (platform::is_same<LayoutB, layout::RowMajor>::value) {
isBMisaligned = problem_size.n() % kAlignmentB;
} else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
isBMisaligned = problem_size.k() % kAlignmentB;
} else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value
|| platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
isBMisaligned = problem_size.k() % kAlignmentB;
}
if (platform::is_same<LayoutC, layout::RowMajor>::value) {
isCMisaligned = problem_size.n() % kAlignmentC;
} else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
isCMisaligned = problem_size.m() % kAlignmentC;
} else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value
|| platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
isCMisaligned = problem_size.n() % kAlignmentC;
}
if (isAMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand");
return Status::kErrorMisalignedOperand;
}
if (isBMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand");
return Status::kErrorMisalignedOperand;
}
if (isCMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand");
return Status::kErrorMisalignedOperand;
}

View File

@@ -33,10 +33,12 @@
*/
#include <iostream>
#include <vector>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
@@ -48,7 +50,7 @@
#include "testbed_universal.h"
#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x256x64_64x64x64, {
@@ -143,6 +145,37 @@ CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32n_tensor_op_f32, 128x128x64_64x64x
EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
} )
/// Stream-K universal GEMM with half-precision row-major A, half-precision column-major B and
/// float row-major output, run over hand-picked problem sizes that exercise the relaxed
/// alignment checks.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x128x64_64x64x64_sk, {

  using ElementAccumulator = float;
  using ElementOutput = float;

  // Tile shapes passed to the kernel, in the order the GemmUniversal template receives them.
  using ThreadblockTile = cutlass::gemm::GemmShape<128, 128, 64>;
  using WarpTile = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionTile = cutlass::gemm::GemmShape<16, 8, 16>;

  // Linear-combination epilogue; the element count per access is 128 bits worth of ElementOutput.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    128 / cutlass::sizeof_bits<ElementOutput>::value,
    ElementAccumulator,
    ElementAccumulator>;

  using Gemm = cutlass::gemm::device::GemmUniversal<
    cutlass::half_t, cutlass::layout::RowMajor,
    cutlass::half_t, cutlass::layout::ColumnMajor,
    ElementOutput, cutlass::layout::RowMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockTile,
    WarpTile,
    InstructionTile,
    EpilogueOp,
    cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
    3>;

  // Custom problem extents used to test the relaxed alignment constraints.
  std::vector<int> const m_extents{1, 3};
  std::vector<int> const n_extents{4};
  std::vector<int> const k_extents{512};

  bool const all_passed =
    test::gemm::device::TestAllGemmUniversalWithCustomProblemSizes<Gemm>(
      m_extents, n_extents, k_extents);

  EXPECT_TRUE(all_passed);
} )
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x64x64_64x64x64, {
using ElementOutput = float;
using ElementAccumulator = float;
@@ -381,4 +414,3 @@ CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32
////////////////////////////////////////////////////////////////////////////////
#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED

View File

@@ -91,7 +91,7 @@ struct TestbedUniversal {
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
@@ -117,11 +117,11 @@ struct TestbedUniversal {
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
@@ -130,7 +130,7 @@ struct TestbedUniversal {
cutlass::reference::host::BlockFillSequential(
view.data(), view.capacity());
}
}
else {
// TODO: Implement the rest
EXPECT_TRUE(false) << "Not implemented";
@@ -173,8 +173,8 @@ struct TestbedUniversal {
/// Compares computed reference with device reference and outputs to a file if incorrect
bool compare_reference(
cutlass::gemm::GemmCoord problem_size,
ElementCompute alpha,
cutlass::gemm::GemmCoord problem_size,
ElementCompute alpha,
ElementCompute beta) {
tensor_D.sync_host();
@@ -182,7 +182,7 @@ struct TestbedUniversal {
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0);
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0);
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0);
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0);
EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0);
@@ -199,11 +199,11 @@ struct TestbedUniversal {
<< problem_size.m() << "x"
<< problem_size.n() << "x"
<< problem_size.k() << "_"
<< Gemm::ThreadblockShape::kM << "x"
<< Gemm::ThreadblockShape::kN << "x"
<< Gemm::ThreadblockShape::kM << "x"
<< Gemm::ThreadblockShape::kN << "x"
<< Gemm::ThreadblockShape::kK << "_"
<< Gemm::WarpShape::kM << "x"
<< Gemm::WarpShape::kN << "x"
<< Gemm::WarpShape::kM << "x"
<< Gemm::WarpShape::kN << "x"
<< Gemm::WarpShape::kK << ".txt";
std::ofstream file(fname.str());
@@ -212,10 +212,10 @@ struct TestbedUniversal {
std::ofstream file("testbed_universal_errors.txt");
file
<< "problem: " << problem_size
<< "problem: " << problem_size
<< ", alpha: " << alpha << ", beta: " << beta << "\n\n";
file
file
<< "A =\n" << tensor_A.host_view()
<< "\nB =\n" << tensor_B.host_view()
<< "\nC =\n" << tensor_C.host_view()
@@ -228,8 +228,8 @@ struct TestbedUniversal {
/// Verifies the result is a GEMM
bool verify(
cutlass::gemm::GemmCoord problem_size,
ElementCompute alpha,
cutlass::gemm::GemmCoord problem_size,
ElementCompute alpha,
ElementCompute beta) {
//
@@ -239,18 +239,18 @@ struct TestbedUniversal {
cutlass::reference::host::GemmComplex<
typename Gemm::ElementA, typename Gemm::LayoutA,
typename Gemm::ElementB, typename Gemm::LayoutB,
typename Gemm::ElementC, typename Gemm::LayoutC,
typename Gemm::ElementC, typename Gemm::LayoutC,
ElementCompute, ElementAccumulator
>(
problem_size,
alpha,
alpha,
tensor_A.host_ref(),
Gemm::kTransformA,
tensor_B.host_ref(),
Gemm::kTransformB,
beta,
tensor_C.host_ref(),
reference_D.host_ref(),
beta,
tensor_C.host_ref(),
reference_D.host_ref(),
ElementAccumulator(0)
);
@@ -354,9 +354,11 @@ struct TestbedUniversal {
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
cutlass::Status status = gemm_op.initialize(arguments, workspace.get());
cutlass::Status status = gemm_op.can_implement(arguments);
EXPECT_EQ(status, cutlass::Status::kSuccess) << to_string(status);
EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status);
status = gemm_op.initialize(arguments, workspace.get());
EXPECT_EQ(status, cutlass::Status::kSuccess) << to_string(status);
//
// Run the GEMM
@@ -386,20 +388,20 @@ bool TestGemmUniversal(
cutlass::gemm::GemmCoord const & problem_size,
cutlass::gemm::GemmUniversalMode mode,
int batch_count,
double alpha = 1.0,
double alpha = 1.0,
double beta = 2.0) {
bool passed = true;
TestbedUniversal<Gemm, Relu> testbed;
using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute;
passed = testbed.run(
mode,
problem_size,
problem_size,
batch_count,
cutlass::from_real<ElementCompute>(alpha),
cutlass::from_real<ElementCompute>(alpha),
cutlass::from_real<ElementCompute>(beta)
);
@@ -407,54 +409,17 @@ bool TestGemmUniversal(
}
template <typename Gemm, bool Relu = false>
bool TestAllGemmUniversal() {
bool TestAllGemmUniversalWithCustomProblemSizes(
const std::vector<int>& problem_size_m,
const std::vector<int>& problem_size_n,
const std::vector<int>& problem_size_k
) {
bool passed = true;
int const kMinimumOperandElementSize =
std::min(
int(cutlass::sizeof_bits<typename Gemm::ElementA>::value),
int(cutlass::sizeof_bits<typename Gemm::ElementB>::value));
int const kAlignment = cutlass::platform::is_same<
typename Gemm::OperatorClass,
cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize;
// int8_t gemm alignment constraints
int const kAlignmentM = cutlass::platform::is_same<typename Gemm::OperatorClass, cutlass::arch::OpClassSimt>::value &&
cutlass::platform::is_same<typename Gemm::ElementA, int8_t>::value &&
cutlass::platform::is_same<typename Gemm::LayoutA, cutlass::layout::ColumnMajor>::value ? 4 : kAlignment;
int const kAlignmentN = cutlass::platform::is_same<typename Gemm::OperatorClass, cutlass::arch::OpClassSimt>::value &&
cutlass::platform::is_same<typename Gemm::ElementB, int8_t>::value &&
cutlass::platform::is_same<typename Gemm::LayoutB, cutlass::layout::RowMajor>::value ? 4 : kAlignment;
int const kAlignmentK = cutlass::platform::is_same<typename Gemm::OperatorClass, cutlass::arch::OpClassSimt>::value &&
cutlass::platform::is_same<typename Gemm::ElementA, int8_t>::value &&
cutlass::platform::is_same<typename Gemm::ElementB, int8_t>::value &&
(cutlass::platform::is_same<typename Gemm::LayoutA, cutlass::layout::RowMajor>::value ||
cutlass::platform::is_same<typename Gemm::LayoutB, cutlass::layout::ColumnMajor>::value) ? 4 : kAlignment;
cutlass::gemm::GemmUniversalMode modes[] = {
cutlass::gemm::GemmUniversalMode::kGemm,
};
int problem_size_m[] = {
kAlignmentM, 512 - 3*kAlignmentM
};
int problem_size_n[] = {
kAlignmentN, 512 - 2*kAlignmentN
};
int problem_size_k[] = {
kAlignmentK,
Gemm::ThreadblockShape::kK * Gemm::kStages - kAlignmentK,
Gemm::ThreadblockShape::kK * Gemm::kStages * 3 - kAlignmentK
};
int batch_counts[] = { // may be interpretted as batch count or split-K slices
1, 2, 3, 5, 7
};
@@ -494,9 +459,9 @@ bool TestAllGemmUniversal() {
passed = testbed.run(
mode,
problem_size,
problem_size,
batch_count,
cutlass::from_real<ElementCompute>(alpha),
cutlass::from_real<ElementCompute>(alpha),
cutlass::from_real<ElementCompute>(beta)
);
@@ -520,9 +485,9 @@ bool TestAllGemmUniversal() {
passed = testbed.run(
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size,
problem_size,
split_k_slices,
cutlass::from_real<ElementCompute>(1.0),
cutlass::from_real<ElementCompute>(1.0),
cutlass::from_real<ElementCompute>(2.0)
);
@@ -535,6 +500,53 @@ bool TestAllGemmUniversal() {
return passed;
}
template <typename Gemm, bool Relu = false>
bool TestAllGemmUniversal() {
int const kMinimumOperandElementSize =
std::min(
int(cutlass::sizeof_bits<typename Gemm::ElementA>::value),
int(cutlass::sizeof_bits<typename Gemm::ElementB>::value));
int const kAlignment = cutlass::platform::is_same<
typename Gemm::OperatorClass,
cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize;
// int8_t gemm alignment constraints
int const kAlignmentM = cutlass::platform::is_same<typename Gemm::OperatorClass, cutlass::arch::OpClassSimt>::value &&
cutlass::platform::is_same<typename Gemm::ElementA, int8_t>::value &&
cutlass::platform::is_same<typename Gemm::LayoutA, cutlass::layout::ColumnMajor>::value ? 4 : kAlignment;
int const kAlignmentN = cutlass::platform::is_same<typename Gemm::OperatorClass, cutlass::arch::OpClassSimt>::value &&
cutlass::platform::is_same<typename Gemm::ElementB, int8_t>::value &&
cutlass::platform::is_same<typename Gemm::LayoutB, cutlass::layout::RowMajor>::value ? 4 : kAlignment;
int const kAlignmentK = cutlass::platform::is_same<typename Gemm::OperatorClass, cutlass::arch::OpClassSimt>::value &&
cutlass::platform::is_same<typename Gemm::ElementA, int8_t>::value &&
cutlass::platform::is_same<typename Gemm::ElementB, int8_t>::value &&
(cutlass::platform::is_same<typename Gemm::LayoutA, cutlass::layout::RowMajor>::value ||
cutlass::platform::is_same<typename Gemm::LayoutB, cutlass::layout::ColumnMajor>::value) ? 4 : kAlignment;
std::vector<int> problem_size_m{
kAlignmentM, 512 - 3*kAlignmentM
};
std::vector<int> problem_size_n{
kAlignmentN, 512 - 2*kAlignmentN
};
std::vector<int> problem_size_k{
kAlignmentK,
Gemm::ThreadblockShape::kK * Gemm::kStages - kAlignmentK,
Gemm::ThreadblockShape::kK * Gemm::kStages * 3 - kAlignmentK
};
return TestAllGemmUniversalWithCustomProblemSizes<Gemm, Relu>(
problem_size_m, problem_size_n, problem_size_k
);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace device
@@ -542,4 +554,3 @@ bool TestAllGemmUniversal() {
} // namespace test
/////////////////////////////////////////////////////////////////////////////////////////////////