diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index b7d8693442..f5260c306e 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -92,6 +92,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); constexpr dim3 blocks = Kernel::BlockSize(); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + if(s.log_level_ > 0) { std::cout << "Launching kernel with args:" diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index eaafc13b98..6c87ca0087 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -119,6 +119,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); constexpr dim3 blocks = Kernel::BlockSize(); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + if(s.log_level_ > 0) { std::cout << "Launching kernel with args:" diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 96af6e8260..763d8cad9c 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -66,6 +66,79 @@ struct GemmKernel return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); } + CK_TILE_HOST static bool IsSupportedArgument(const GemmCommonKargs& kargs) + { + if constexpr(std::is_same_v) + { + if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false) + { + return false; + } + if(kargs.K % GemmPipeline::VectorSizeA != 0) + { + return false; + } + } + else + { + if(kargs.M % TilePartitioner::kM != 0 && GemmPipeline::kPadM == false) + { + return false; + } + if(kargs.M % GemmPipeline::VectorSizeA != 0) + { + return false; + } + } + + if constexpr(std::is_same_v) + { + if(kargs.N % TilePartitioner::kN != 0 && GemmPipeline::kPadN == false) + { + return false; + } + if(kargs.N % GemmPipeline::VectorSizeB != 0) + { + return false; + } + } + else + { + if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false) + { + return false; + } + if(kargs.K % GemmPipeline::VectorSizeB != 0) + { + return false; + } + } + + if constexpr(std::is_same_v) + { + if(kargs.N % TilePartitioner::kN != 0 && GemmPipeline::kPadN == false) + { + return false; + } + if(kargs.N % GemmPipeline::VectorSizeC != 0) + { + return false; + } + } + else + { + if(kargs.M % TilePartitioner::kM != 0 && GemmPipeline::kPadM == false) + { + return false; + } + if(kargs.M % GemmPipeline::VectorSizeC != 0) + { + return false; + } + } + return true; + } + CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const { const auto [i_m, i_n] = TilePartitioner{}(); diff --git a/test/ck_tile/gemm/test_gemm_mem_pipeline.cpp b/test/ck_tile/gemm/test_gemm_mem_pipeline.cpp index a1c80fee4b..aeb383c87d 100644 --- a/test/ck_tile/gemm/test_gemm_mem_pipeline.cpp +++ b/test/ck_tile/gemm/test_gemm_mem_pipeline.cpp @@ -8,35 +8,29 @@ #include "ck_tile/host.hpp" #include "test_gemm_mem_pipeline_util.hpp" -using F16 = ck_tile::half_t; -using F32 = float; - -using Row = ck_tile::tensor_layout::gemm::RowMajor; -using Col = ck_tile::tensor_layout::gemm::ColumnMajor; -static constexpr auto Intrawave = ck_tile::GemmPipelineScheduler::Intrawave; -static constexpr auto Interwave = ck_tile::GemmPipelineScheduler::Interwave; - -template -class TestCkTileGemmMemPipelineIntrawave : public TestCkTileGemmMemPipeline -{ -}; - -template -class TestCkTileGemmMemPipelineInterwave : public TestCkTileGemmMemPipeline -{ -}; +using F16 = ck_tile::half_t; +using F32 = float; +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; +using Intrawave = ck_tile::integral_constant; +using Interwave = ck_tile::integral_constant; // clang-format off using KernelTypes = ::testing::Types< - // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType - std::tuple< Row, Col, Row, F16, F16, F32, F16>, - std::tuple< Col, Row, Row, F16, F16, F32, F16>, - std::tuple< Row, Row, Row, F16, F16, F32, F16>, - std::tuple< Col, Col, Row, F16, F16, F32, F16> + // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler + std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave>, + std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave> >; // clang-format on -TYPED_TEST_SUITE(TestCkTileGemmMemPipelineIntrawave, KernelTypes); -TYPED_TEST_SUITE(TestCkTileGemmMemPipelineInterwave, KernelTypes); +TYPED_TEST_SUITE(TestCkTileGemmMemPipeline, KernelTypes); #include "test_gemm_mem_pipeline_ut_cases.inc" diff --git a/test/ck_tile/gemm/test_gemm_mem_pipeline_ut_cases.inc b/test/ck_tile/gemm/test_gemm_mem_pipeline_ut_cases.inc index 6b914e7975..af94d68f2c 100644 --- a/test/ck_tile/gemm/test_gemm_mem_pipeline_ut_cases.inc +++ b/test/ck_tile/gemm/test_gemm_mem_pipeline_ut_cases.inc @@ -3,11 +3,7 @@ #pragma once -//------------------------------------------------------------------------------------------------ -// INTERWAVE SCHEDULER -//------------------------------------------------------------------------------------------------ - -TYPED_TEST(TestCkTileGemmMemPipelineInterwave, SmallM) +TYPED_TEST(TestCkTileGemmMemPipeline, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 1024; @@ -17,7 +13,7 @@ TYPED_TEST(TestCkTileGemmMemPipelineInterwave, SmallM) this->Run(M, N, K); } -TYPED_TEST(TestCkTileGemmMemPipelineInterwave, MidLargeM) +TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 1024; @@ -27,7 +23,7 @@ TYPED_TEST(TestCkTileGemmMemPipelineInterwave, MidLargeM) this->Run(M, N, K); } -TYPED_TEST(TestCkTileGemmMemPipelineInterwave, PaddK) +TYPED_TEST(TestCkTileGemmMemPipeline, PaddK) { std::vector Ms{127}; constexpr int N = 1024; @@ -37,7 +33,7 @@ TYPED_TEST(TestCkTileGemmMemPipelineInterwave, PaddK) this->Run(M, N, K); } -TYPED_TEST(TestCkTileGemmMemPipelineInterwave, Regular) +TYPED_TEST(TestCkTileGemmMemPipeline, Regular) { std::vector Ms{512}; constexpr int N = 1024; @@ -47,46 +43,15 @@ TYPED_TEST(TestCkTileGemmMemPipelineInterwave, Regular) this->Run(M, N, K); } -//------------------------------------------------------------------------------------------------ -// INTRAWAVE SCHEDULER -//------------------------------------------------------------------------------------------------ - -TYPED_TEST(TestCkTileGemmMemPipelineIntrawave, SmallM) +TYPED_TEST(TestCkTileGemmMemPipeline, NotSupportedArgument) { - std::vector Ms{1, 2, 3, 4, 5, 6}; - constexpr int N = 1024; - constexpr int K = 320; + constexpr int M = 512; + constexpr int N = 1025; + constexpr int K = 513; - for(int M : Ms) - this->Run(M, N, K); -} - -TYPED_TEST(TestCkTileGemmMemPipelineIntrawave, MidLargeM) -{ - std::vector Ms{127, 255, 312, 799, 1573}; - constexpr int N = 1024; - constexpr int K = 320; - - for(int M : Ms) - this->Run(M, N, K); -} - -TYPED_TEST(TestCkTileGemmMemPipelineIntrawave, PaddK) -{ - std::vector Ms{127}; - constexpr int N = 1024; - constexpr int K = 432; - - for(int M : Ms) - this->Run(M, N, K); -} - -TYPED_TEST(TestCkTileGemmMemPipelineIntrawave, Regular) -{ - std::vector Ms{512}; - constexpr int N = 1024; - constexpr int K = 512; - - for(int M : Ms) - this->Run(M, N, K); + constexpr bool PadM = false; + constexpr bool PadN = false; + constexpr bool PadK = false; + + EXPECT_THROW((this->template Run(M, N, K)), std::runtime_error); } diff --git a/test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp index 15f9f516ee..6941a7596a 100644 --- a/test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp @@ -11,7 +11,7 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" -template +template class TestCkTileGemmMemPipeline : public ::testing::Test { protected: @@ -22,7 +22,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test using BDataType = std::tuple_element_t<4, Tuple>; using AccDataType = std::tuple_element_t<5, Tuple>; using CDataType = std::tuple_element_t<6, Tuple>; - static constexpr auto Scheduler = Scheduler_; + static constexpr auto Scheduler = std::tuple_element_t<7, Tuple>::value; // TODO: expose tile size through test t-param ? struct gemm_args @@ -39,6 +39,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ck_tile::index_t stride_C; }; + template void invoke_gemm(const gemm_args& args, const ck_tile::stream_config& s) { // TODO: This should be parameterized in tests @@ -54,9 +55,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t K_Warp_Tile = 8; - constexpr bool kPadM = true; - constexpr bool kPadN = true; - constexpr bool kPadK = true; + constexpr bool kPadM = PadM; + constexpr bool kPadN = PadN; + constexpr bool kPadK = PadK; constexpr int kBlockPerCu = 1; @@ -107,6 +108,11 @@ class TestCkTileGemmMemPipeline : public ::testing::Test const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); constexpr dim3 blocks = Kernel::BlockSize(); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + if(s.log_level_ > 0) { std::cout << "Launching kernel with args:" @@ -212,6 +218,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test void SetUp() override { k_batches_ = {1}; } + template void Run(const int M, const int N, const int K, @@ -221,10 +228,11 @@ class TestCkTileGemmMemPipeline : public ::testing::Test { for(auto kb : k_batches_) { - RunSingle(M, N, K, StrideA, StrideB, StrideC, kb); + RunSingle(M, N, K, StrideA, StrideB, StrideC, kb); } } + template void RunSingle(const int M, const int N, const int K, @@ -301,7 +309,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test args.stride_B = stride_B; args.stride_C = stride_C; - invoke_gemm(args, ck_tile::stream_config{nullptr, false}); + invoke_gemm(args, ck_tile::stream_config{nullptr, false}); c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); bool pass = true;