From 31e80a250e2b0ac4bda2e4b437b39dc5bcd5e845 Mon Sep 17 00:00:00 2001 From: mikeiovine Date: Fri, 2 Dec 2022 07:49:27 -0800 Subject: [PATCH] Relax stream K gemm alignment constraints The current alignment requirements are too strict. Make them identical to the checks for the regular universal gemm. --- .../gemm/kernel/gemm_universal_streamk.h | 50 +++++- .../gemm_f16n_f16t_f32t_tensor_op_f32_sm80.cu | 36 +++- test/unit/gemm/device/testbed_universal.h | 157 ++++++++++-------- 3 files changed, 162 insertions(+), 81 deletions(-) diff --git a/include/cutlass/gemm/kernel/gemm_universal_streamk.h b/include/cutlass/gemm/kernel/gemm_universal_streamk.h index b8bf3f800..3b71b3eff 100644 --- a/include/cutlass/gemm/kernel/gemm_universal_streamk.h +++ b/include/cutlass/gemm/kernel/gemm_universal_streamk.h @@ -30,7 +30,7 @@ **************************************************************************************************/ /*! \file - \brief + \brief */ #pragma once @@ -575,11 +575,49 @@ public: static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; - if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) || - (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || - (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) - { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand"); + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + + if (platform::is_same::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } else if (platform::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } + + if (platform::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (platform::is_same::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } + + if (platform::is_same::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } else if (platform::is_same::value) { + isCMisaligned = problem_size.m() % kAlignmentC; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); return Status::kErrorMisalignedOperand; } diff --git a/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sm80.cu index 785840b7c..9b476acde 100644 --- a/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sm80.cu +++ b/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sm80.cu @@ -33,10 +33,12 @@ */ #include +#include #include "../../common/cutlass_unit_test.h" #include "cutlass/cutlass.h" #include "cutlass/gemm/device/gemm.h" +#include "cutlass/gemm/device/gemm_universal.h" #include "cutlass/util/host_tensor.h" #include "cutlass/util/reference/host/gemm.h" #include "cutlass/util/reference/host/tensor_compare.h" @@ -48,7 +50,7 @@ #include "testbed_universal.h" #if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) - + //////////////////////////////////////////////////////////////////////////////// CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x256x64_64x64x64, { @@ -143,6 +145,37 @@ CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32n_tensor_op_f32, 128x128x64_64x64x EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); } ) +CUTLASS_TEST_L1(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x128x64_64x64x64_sk, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + cutlass::half_t, cutlass::layout::RowMajor, + cutlass::half_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, 3>; + + // Some custom problem sizes to test relaxed alignment + std::vector problem_m{ + 1, 3 + }; + std::vector problem_n{ + 4, + }; + std::vector problem_k{ + 512, + }; + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversalWithCustomProblemSizes( + problem_m, problem_n, problem_k + )); +} ) + CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x64x64_64x64x64, { using ElementOutput = float; using ElementAccumulator = float; @@ -381,4 +414,3 @@ CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32 //////////////////////////////////////////////////////////////////////////////// #endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED - diff --git a/test/unit/gemm/device/testbed_universal.h b/test/unit/gemm/device/testbed_universal.h index 8de39e019..3b1e40d09 100644 --- a/test/unit/gemm/device/testbed_universal.h +++ b/test/unit/gemm/device/testbed_universal.h @@ -91,7 +91,7 @@ struct TestbedUniversal { /// Helper to initialize a tensor view template bool initialize_tensor( - cutlass::TensorView view, + cutlass::TensorView view, cutlass::Distribution::Kind dist_kind, uint64_t seed) { @@ -117,11 +117,11 @@ struct TestbedUniversal { cutlass::reference::host::TensorFillRandomUniform( view, seed, scope_max, scope_min, 0); - } + } else if (dist_kind == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(view); - } + } else if (dist_kind == cutlass::Distribution::Gaussian) { cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); @@ -130,7 +130,7 @@ struct TestbedUniversal { cutlass::reference::host::BlockFillSequential( view.data(), view.capacity()); - } + } else { // TODO: Implement the rest EXPECT_TRUE(false) << "Not implemented"; @@ -173,8 +173,8 @@ struct TestbedUniversal { /// Compares computed reference with device reference and outputs to a file if incorrect bool compare_reference( - cutlass::gemm::GemmCoord problem_size, - ElementCompute alpha, + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, ElementCompute beta) { tensor_D.sync_host(); @@ -182,7 +182,7 @@ struct TestbedUniversal { EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); - + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); @@ -199,11 +199,11 @@ struct TestbedUniversal { << problem_size.m() << "x" << problem_size.n() << "x" << problem_size.k() << "_" - << Gemm::ThreadblockShape::kM << "x" - << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" << Gemm::ThreadblockShape::kK << "_" - << Gemm::WarpShape::kM << "x" - << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" << Gemm::WarpShape::kK << ".txt"; std::ofstream file(fname.str()); @@ -212,10 +212,10 @@ struct TestbedUniversal { std::ofstream file("testbed_universal_errors.txt"); file - << "problem: " << problem_size + << "problem: " << problem_size << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; - file + file << "A =\n" << tensor_A.host_view() << "\nB =\n" << tensor_B.host_view() << "\nC =\n" << tensor_C.host_view() @@ -228,8 +228,8 @@ struct TestbedUniversal { /// Verifies the result is a GEMM bool verify( - cutlass::gemm::GemmCoord problem_size, - ElementCompute alpha, + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, ElementCompute beta) { // @@ -239,18 +239,18 @@ struct TestbedUniversal { cutlass::reference::host::GemmComplex< typename Gemm::ElementA, typename Gemm::LayoutA, typename Gemm::ElementB, typename Gemm::LayoutB, - typename Gemm::ElementC, typename Gemm::LayoutC, + typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, ElementAccumulator >( problem_size, - alpha, + alpha, tensor_A.host_ref(), Gemm::kTransformA, tensor_B.host_ref(), Gemm::kTransformB, - beta, - tensor_C.host_ref(), - reference_D.host_ref(), + beta, + tensor_C.host_ref(), + reference_D.host_ref(), ElementAccumulator(0) ); @@ -354,9 +354,11 @@ struct TestbedUniversal { cutlass::device_memory::allocation workspace(workspace_size); - cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + cutlass::Status status = gemm_op.can_implement(arguments); + EXPECT_EQ(status, cutlass::Status::kSuccess) << to_string(status); - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + status = gemm_op.initialize(arguments, workspace.get()); + EXPECT_EQ(status, cutlass::Status::kSuccess) << to_string(status); // // Run the GEMM @@ -386,20 +388,20 @@ bool TestGemmUniversal( cutlass::gemm::GemmCoord const & problem_size, cutlass::gemm::GemmUniversalMode mode, int batch_count, - double alpha = 1.0, + double alpha = 1.0, double beta = 2.0) { bool passed = true; TestbedUniversal testbed; - + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; passed = testbed.run( mode, - problem_size, + problem_size, batch_count, - cutlass::from_real(alpha), + cutlass::from_real(alpha), cutlass::from_real(beta) ); @@ -407,54 +409,17 @@ bool TestGemmUniversal( } template -bool TestAllGemmUniversal() { +bool TestAllGemmUniversalWithCustomProblemSizes( + const std::vector& problem_size_m, + const std::vector& problem_size_n, + const std::vector& problem_size_k +) { bool passed = true; - - int const kMinimumOperandElementSize = - std::min( - int(cutlass::sizeof_bits::value), - int(cutlass::sizeof_bits::value)); - - int const kAlignment = cutlass::platform::is_same< - typename Gemm::OperatorClass, - cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; - - // int8_t gemm alignment constraints - int const kAlignmentM = cutlass::platform::is_same::value && - cutlass::platform::is_same::value && - cutlass::platform::is_same::value ? 4 : kAlignment; - - int const kAlignmentN = cutlass::platform::is_same::value && - cutlass::platform::is_same::value && - cutlass::platform::is_same::value ? 4 : kAlignment; - - int const kAlignmentK = cutlass::platform::is_same::value && - cutlass::platform::is_same::value && - cutlass::platform::is_same::value && - (cutlass::platform::is_same::value || - cutlass::platform::is_same::value) ? 4 : kAlignment; - - - cutlass::gemm::GemmUniversalMode modes[] = { cutlass::gemm::GemmUniversalMode::kGemm, }; - int problem_size_m[] = { - kAlignmentM, 512 - 3*kAlignmentM - }; - - int problem_size_n[] = { - kAlignmentN, 512 - 2*kAlignmentN - }; - - int problem_size_k[] = { - kAlignmentK, - Gemm::ThreadblockShape::kK * Gemm::kStages - kAlignmentK, - Gemm::ThreadblockShape::kK * Gemm::kStages * 3 - kAlignmentK - }; - int batch_counts[] = { // may be interpretted as batch count or split-K slices 1, 2, 3, 5, 7 }; @@ -494,9 +459,9 @@ bool TestAllGemmUniversal() { passed = testbed.run( mode, - problem_size, + problem_size, batch_count, - cutlass::from_real(alpha), + cutlass::from_real(alpha), cutlass::from_real(beta) ); @@ -520,9 +485,9 @@ bool TestAllGemmUniversal() { passed = testbed.run( cutlass::gemm::GemmUniversalMode::kGemm, - problem_size, + problem_size, split_k_slices, - cutlass::from_real(1.0), + cutlass::from_real(1.0), cutlass::from_real(2.0) ); @@ -535,6 +500,53 @@ bool TestAllGemmUniversal() { return passed; } +template +bool TestAllGemmUniversal() { + int const kMinimumOperandElementSize = + std::min( + int(cutlass::sizeof_bits::value), + int(cutlass::sizeof_bits::value)); + + int const kAlignment = cutlass::platform::is_same< + typename Gemm::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + // int8_t gemm alignment constraints + int const kAlignmentM = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentN = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentK = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + (cutlass::platform::is_same::value || + cutlass::platform::is_same::value) ? 4 : kAlignment; + + + + std::vector problem_size_m{ + kAlignmentM, 512 - 3*kAlignmentM + }; + + std::vector problem_size_n{ + kAlignmentN, 512 - 2*kAlignmentN + }; + + std::vector problem_size_k{ + kAlignmentK, + Gemm::ThreadblockShape::kK * Gemm::kStages - kAlignmentK, + Gemm::ThreadblockShape::kK * Gemm::kStages * 3 - kAlignmentK + }; + + return TestAllGemmUniversalWithCustomProblemSizes( + problem_size_m, problem_size_n, problem_size_k + ); +} + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace device @@ -542,4 +554,3 @@ bool TestAllGemmUniversal() { } // namespace test ///////////////////////////////////////////////////////////////////////////////////////////////// -