mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-11 17:00:05 +00:00
Relax stream K gemm alignment constraints
The current alignment requirements are too strict. Make them identical to the checks for the regular universal gemm.
This commit is contained in:
@@ -30,7 +30,7 @@
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
\brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
@@ -575,11 +575,49 @@ public:
|
||||
|
||||
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) ||
|
||||
(problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) ||
|
||||
(problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC))
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand");
|
||||
bool isAMisaligned = false;
|
||||
bool isBMisaligned = false;
|
||||
bool isCMisaligned = false;
|
||||
|
||||
if (platform::is_same<LayoutA, layout::RowMajor>::value) {
|
||||
isAMisaligned = problem_size.k() % kAlignmentA;
|
||||
} else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
|
||||
isAMisaligned = problem_size.m() % kAlignmentA;
|
||||
} else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value
|
||||
|| platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
|
||||
isAMisaligned = problem_size.k() % kAlignmentA;
|
||||
}
|
||||
|
||||
if (platform::is_same<LayoutB, layout::RowMajor>::value) {
|
||||
isBMisaligned = problem_size.n() % kAlignmentB;
|
||||
} else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
|
||||
isBMisaligned = problem_size.k() % kAlignmentB;
|
||||
} else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value
|
||||
|| platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
|
||||
isBMisaligned = problem_size.k() % kAlignmentB;
|
||||
}
|
||||
|
||||
if (platform::is_same<LayoutC, layout::RowMajor>::value) {
|
||||
isCMisaligned = problem_size.n() % kAlignmentC;
|
||||
} else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
|
||||
isCMisaligned = problem_size.m() % kAlignmentC;
|
||||
} else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value
|
||||
|| platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
|
||||
isCMisaligned = problem_size.n() % kAlignmentC;
|
||||
}
|
||||
|
||||
if (isAMisaligned) {
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (isBMisaligned) {
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (isCMisaligned) {
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
|
||||
@@ -33,10 +33,12 @@
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/gemm/device/gemm_universal.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
@@ -48,7 +50,7 @@
|
||||
#include "testbed_universal.h"
|
||||
|
||||
#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x256x64_64x64x64, {
|
||||
@@ -143,6 +145,37 @@ CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32n_tensor_op_f32, 128x128x64_64x64x
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x128x64_64x64x64_sk, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversal<
|
||||
cutlass::half_t, cutlass::layout::RowMajor,
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor,
|
||||
ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, 3>;
|
||||
|
||||
// Some custom problem sizes to test relaxed alignment
|
||||
std::vector<int> problem_m{
|
||||
1, 3
|
||||
};
|
||||
std::vector<int> problem_n{
|
||||
4,
|
||||
};
|
||||
std::vector<int> problem_k{
|
||||
512,
|
||||
};
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemmUniversalWithCustomProblemSizes<Gemm>(
|
||||
problem_m, problem_n, problem_k
|
||||
));
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x64x64_64x64x64, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
@@ -381,4 +414,3 @@ CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED
|
||||
|
||||
|
||||
@@ -91,7 +91,7 @@ struct TestbedUniversal {
|
||||
/// Helper to initialize a tensor view
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
@@ -117,11 +117,11 @@ struct TestbedUniversal {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
}
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
@@ -130,7 +130,7 @@ struct TestbedUniversal {
|
||||
|
||||
cutlass::reference::host::BlockFillSequential(
|
||||
view.data(), view.capacity());
|
||||
}
|
||||
}
|
||||
else {
|
||||
// TODO: Implement the rest
|
||||
EXPECT_TRUE(false) << "Not implemented";
|
||||
@@ -173,8 +173,8 @@ struct TestbedUniversal {
|
||||
|
||||
/// Compares computed reference with device reference and outputs to a file if incorrect
|
||||
bool compare_reference(
|
||||
cutlass::gemm::GemmCoord problem_size,
|
||||
ElementCompute alpha,
|
||||
cutlass::gemm::GemmCoord problem_size,
|
||||
ElementCompute alpha,
|
||||
ElementCompute beta) {
|
||||
|
||||
tensor_D.sync_host();
|
||||
@@ -182,7 +182,7 @@ struct TestbedUniversal {
|
||||
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0);
|
||||
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0);
|
||||
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0);
|
||||
|
||||
|
||||
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0);
|
||||
EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0);
|
||||
|
||||
@@ -199,11 +199,11 @@ struct TestbedUniversal {
|
||||
<< problem_size.m() << "x"
|
||||
<< problem_size.n() << "x"
|
||||
<< problem_size.k() << "_"
|
||||
<< Gemm::ThreadblockShape::kM << "x"
|
||||
<< Gemm::ThreadblockShape::kN << "x"
|
||||
<< Gemm::ThreadblockShape::kM << "x"
|
||||
<< Gemm::ThreadblockShape::kN << "x"
|
||||
<< Gemm::ThreadblockShape::kK << "_"
|
||||
<< Gemm::WarpShape::kM << "x"
|
||||
<< Gemm::WarpShape::kN << "x"
|
||||
<< Gemm::WarpShape::kM << "x"
|
||||
<< Gemm::WarpShape::kN << "x"
|
||||
<< Gemm::WarpShape::kK << ".txt";
|
||||
|
||||
std::ofstream file(fname.str());
|
||||
@@ -212,10 +212,10 @@ struct TestbedUniversal {
|
||||
std::ofstream file("testbed_universal_errors.txt");
|
||||
|
||||
file
|
||||
<< "problem: " << problem_size
|
||||
<< "problem: " << problem_size
|
||||
<< ", alpha: " << alpha << ", beta: " << beta << "\n\n";
|
||||
|
||||
file
|
||||
file
|
||||
<< "A =\n" << tensor_A.host_view()
|
||||
<< "\nB =\n" << tensor_B.host_view()
|
||||
<< "\nC =\n" << tensor_C.host_view()
|
||||
@@ -228,8 +228,8 @@ struct TestbedUniversal {
|
||||
|
||||
/// Verifies the result is a GEMM
|
||||
bool verify(
|
||||
cutlass::gemm::GemmCoord problem_size,
|
||||
ElementCompute alpha,
|
||||
cutlass::gemm::GemmCoord problem_size,
|
||||
ElementCompute alpha,
|
||||
ElementCompute beta) {
|
||||
|
||||
//
|
||||
@@ -239,18 +239,18 @@ struct TestbedUniversal {
|
||||
cutlass::reference::host::GemmComplex<
|
||||
typename Gemm::ElementA, typename Gemm::LayoutA,
|
||||
typename Gemm::ElementB, typename Gemm::LayoutB,
|
||||
typename Gemm::ElementC, typename Gemm::LayoutC,
|
||||
typename Gemm::ElementC, typename Gemm::LayoutC,
|
||||
ElementCompute, ElementAccumulator
|
||||
>(
|
||||
problem_size,
|
||||
alpha,
|
||||
alpha,
|
||||
tensor_A.host_ref(),
|
||||
Gemm::kTransformA,
|
||||
tensor_B.host_ref(),
|
||||
Gemm::kTransformB,
|
||||
beta,
|
||||
tensor_C.host_ref(),
|
||||
reference_D.host_ref(),
|
||||
beta,
|
||||
tensor_C.host_ref(),
|
||||
reference_D.host_ref(),
|
||||
ElementAccumulator(0)
|
||||
);
|
||||
|
||||
@@ -354,9 +354,11 @@ struct TestbedUniversal {
|
||||
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
cutlass::Status status = gemm_op.initialize(arguments, workspace.get());
|
||||
cutlass::Status status = gemm_op.can_implement(arguments);
|
||||
EXPECT_EQ(status, cutlass::Status::kSuccess) << to_string(status);
|
||||
|
||||
EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status);
|
||||
status = gemm_op.initialize(arguments, workspace.get());
|
||||
EXPECT_EQ(status, cutlass::Status::kSuccess) << to_string(status);
|
||||
|
||||
//
|
||||
// Run the GEMM
|
||||
@@ -386,20 +388,20 @@ bool TestGemmUniversal(
|
||||
cutlass::gemm::GemmCoord const & problem_size,
|
||||
cutlass::gemm::GemmUniversalMode mode,
|
||||
int batch_count,
|
||||
double alpha = 1.0,
|
||||
double alpha = 1.0,
|
||||
double beta = 2.0) {
|
||||
|
||||
bool passed = true;
|
||||
|
||||
TestbedUniversal<Gemm, Relu> testbed;
|
||||
|
||||
|
||||
using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute;
|
||||
|
||||
passed = testbed.run(
|
||||
mode,
|
||||
problem_size,
|
||||
problem_size,
|
||||
batch_count,
|
||||
cutlass::from_real<ElementCompute>(alpha),
|
||||
cutlass::from_real<ElementCompute>(alpha),
|
||||
cutlass::from_real<ElementCompute>(beta)
|
||||
);
|
||||
|
||||
@@ -407,54 +409,17 @@ bool TestGemmUniversal(
|
||||
}
|
||||
|
||||
template <typename Gemm, bool Relu = false>
|
||||
bool TestAllGemmUniversal() {
|
||||
bool TestAllGemmUniversalWithCustomProblemSizes(
|
||||
const std::vector<int>& problem_size_m,
|
||||
const std::vector<int>& problem_size_n,
|
||||
const std::vector<int>& problem_size_k
|
||||
) {
|
||||
bool passed = true;
|
||||
|
||||
|
||||
int const kMinimumOperandElementSize =
|
||||
std::min(
|
||||
int(cutlass::sizeof_bits<typename Gemm::ElementA>::value),
|
||||
int(cutlass::sizeof_bits<typename Gemm::ElementB>::value));
|
||||
|
||||
int const kAlignment = cutlass::platform::is_same<
|
||||
typename Gemm::OperatorClass,
|
||||
cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize;
|
||||
|
||||
// int8_t gemm alignment constraints
|
||||
int const kAlignmentM = cutlass::platform::is_same<typename Gemm::OperatorClass, cutlass::arch::OpClassSimt>::value &&
|
||||
cutlass::platform::is_same<typename Gemm::ElementA, int8_t>::value &&
|
||||
cutlass::platform::is_same<typename Gemm::LayoutA, cutlass::layout::ColumnMajor>::value ? 4 : kAlignment;
|
||||
|
||||
int const kAlignmentN = cutlass::platform::is_same<typename Gemm::OperatorClass, cutlass::arch::OpClassSimt>::value &&
|
||||
cutlass::platform::is_same<typename Gemm::ElementB, int8_t>::value &&
|
||||
cutlass::platform::is_same<typename Gemm::LayoutB, cutlass::layout::RowMajor>::value ? 4 : kAlignment;
|
||||
|
||||
int const kAlignmentK = cutlass::platform::is_same<typename Gemm::OperatorClass, cutlass::arch::OpClassSimt>::value &&
|
||||
cutlass::platform::is_same<typename Gemm::ElementA, int8_t>::value &&
|
||||
cutlass::platform::is_same<typename Gemm::ElementB, int8_t>::value &&
|
||||
(cutlass::platform::is_same<typename Gemm::LayoutA, cutlass::layout::RowMajor>::value ||
|
||||
cutlass::platform::is_same<typename Gemm::LayoutB, cutlass::layout::ColumnMajor>::value) ? 4 : kAlignment;
|
||||
|
||||
|
||||
|
||||
cutlass::gemm::GemmUniversalMode modes[] = {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
};
|
||||
|
||||
int problem_size_m[] = {
|
||||
kAlignmentM, 512 - 3*kAlignmentM
|
||||
};
|
||||
|
||||
int problem_size_n[] = {
|
||||
kAlignmentN, 512 - 2*kAlignmentN
|
||||
};
|
||||
|
||||
int problem_size_k[] = {
|
||||
kAlignmentK,
|
||||
Gemm::ThreadblockShape::kK * Gemm::kStages - kAlignmentK,
|
||||
Gemm::ThreadblockShape::kK * Gemm::kStages * 3 - kAlignmentK
|
||||
};
|
||||
|
||||
int batch_counts[] = { // may be interpretted as batch count or split-K slices
|
||||
1, 2, 3, 5, 7
|
||||
};
|
||||
@@ -494,9 +459,9 @@ bool TestAllGemmUniversal() {
|
||||
|
||||
passed = testbed.run(
|
||||
mode,
|
||||
problem_size,
|
||||
problem_size,
|
||||
batch_count,
|
||||
cutlass::from_real<ElementCompute>(alpha),
|
||||
cutlass::from_real<ElementCompute>(alpha),
|
||||
cutlass::from_real<ElementCompute>(beta)
|
||||
);
|
||||
|
||||
@@ -520,9 +485,9 @@ bool TestAllGemmUniversal() {
|
||||
|
||||
passed = testbed.run(
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
problem_size,
|
||||
problem_size,
|
||||
split_k_slices,
|
||||
cutlass::from_real<ElementCompute>(1.0),
|
||||
cutlass::from_real<ElementCompute>(1.0),
|
||||
cutlass::from_real<ElementCompute>(2.0)
|
||||
);
|
||||
|
||||
@@ -535,6 +500,53 @@ bool TestAllGemmUniversal() {
|
||||
return passed;
|
||||
}
|
||||
|
||||
template <typename Gemm, bool Relu = false>
|
||||
bool TestAllGemmUniversal() {
|
||||
int const kMinimumOperandElementSize =
|
||||
std::min(
|
||||
int(cutlass::sizeof_bits<typename Gemm::ElementA>::value),
|
||||
int(cutlass::sizeof_bits<typename Gemm::ElementB>::value));
|
||||
|
||||
int const kAlignment = cutlass::platform::is_same<
|
||||
typename Gemm::OperatorClass,
|
||||
cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize;
|
||||
|
||||
// int8_t gemm alignment constraints
|
||||
int const kAlignmentM = cutlass::platform::is_same<typename Gemm::OperatorClass, cutlass::arch::OpClassSimt>::value &&
|
||||
cutlass::platform::is_same<typename Gemm::ElementA, int8_t>::value &&
|
||||
cutlass::platform::is_same<typename Gemm::LayoutA, cutlass::layout::ColumnMajor>::value ? 4 : kAlignment;
|
||||
|
||||
int const kAlignmentN = cutlass::platform::is_same<typename Gemm::OperatorClass, cutlass::arch::OpClassSimt>::value &&
|
||||
cutlass::platform::is_same<typename Gemm::ElementB, int8_t>::value &&
|
||||
cutlass::platform::is_same<typename Gemm::LayoutB, cutlass::layout::RowMajor>::value ? 4 : kAlignment;
|
||||
|
||||
int const kAlignmentK = cutlass::platform::is_same<typename Gemm::OperatorClass, cutlass::arch::OpClassSimt>::value &&
|
||||
cutlass::platform::is_same<typename Gemm::ElementA, int8_t>::value &&
|
||||
cutlass::platform::is_same<typename Gemm::ElementB, int8_t>::value &&
|
||||
(cutlass::platform::is_same<typename Gemm::LayoutA, cutlass::layout::RowMajor>::value ||
|
||||
cutlass::platform::is_same<typename Gemm::LayoutB, cutlass::layout::ColumnMajor>::value) ? 4 : kAlignment;
|
||||
|
||||
|
||||
|
||||
std::vector<int> problem_size_m{
|
||||
kAlignmentM, 512 - 3*kAlignmentM
|
||||
};
|
||||
|
||||
std::vector<int> problem_size_n{
|
||||
kAlignmentN, 512 - 2*kAlignmentN
|
||||
};
|
||||
|
||||
std::vector<int> problem_size_k{
|
||||
kAlignmentK,
|
||||
Gemm::ThreadblockShape::kK * Gemm::kStages - kAlignmentK,
|
||||
Gemm::ThreadblockShape::kK * Gemm::kStages * 3 - kAlignmentK
|
||||
};
|
||||
|
||||
return TestAllGemmUniversalWithCustomProblemSizes<Gemm, Relu>(
|
||||
problem_size_m, problem_size_n, problem_size_k
|
||||
);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace device
|
||||
@@ -542,4 +554,3 @@ bool TestAllGemmUniversal() {
|
||||
} // namespace test
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
Reference in New Issue
Block a user