diff --git a/test/ck_tile/gemm/CMakeLists.txt b/test/ck_tile/gemm/CMakeLists.txt index f96ad9c6e0..ecfbd4e55b 100644 --- a/test/ck_tile/gemm/CMakeLists.txt +++ b/test/ck_tile/gemm/CMakeLists.txt @@ -1,4 +1,4 @@ # Currently ck_tile is only built on gfx9 if(GPU_TARGETS MATCHES "gfx9") - add_gtest_executable(test_ck_tile_gemm_mem_pipeline test_gemm_mem_pipeline.cpp) + add_gtest_executable(test_ck_tile_gemm_pipeline test_gemm_pipeline.cpp) endif() diff --git a/test/ck_tile/gemm/test_gemm_mem_pipeline.cpp b/test/ck_tile/gemm/test_gemm_mem_pipeline.cpp deleted file mode 100644 index aeb383c87d..0000000000 --- a/test/ck_tile/gemm/test_gemm_mem_pipeline.cpp +++ /dev/null @@ -1,36 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "gtest/gtest.h" - -#include "ck_tile/host.hpp" -#include "test_gemm_mem_pipeline_util.hpp" - -using F16 = ck_tile::half_t; -using F32 = float; -using Row = ck_tile::tensor_layout::gemm::RowMajor; -using Col = ck_tile::tensor_layout::gemm::ColumnMajor; -using Intrawave = ck_tile::integral_constant; -using Interwave = ck_tile::integral_constant; - -// clang-format off -using KernelTypes = ::testing::Types< - // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler - std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave>, - std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave>, - std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave>, - std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave>, - std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave>, - std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave>, - std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave>, - std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave> - >; -// clang-format on - -TYPED_TEST_SUITE(TestCkTileGemmMemPipeline, KernelTypes); - -#include "test_gemm_mem_pipeline_ut_cases.inc" diff --git a/test/ck_tile/gemm/test_gemm_pipeline.cpp b/test/ck_tile/gemm/test_gemm_pipeline.cpp new file mode 100644 index 0000000000..48a2b86a63 --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_gemm_pipeline_util.hpp" + +using F16 = ck_tile::half_t; +using F32 = float; +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; +using Intrawave = ck_tile::integral_constant; +using Interwave = ck_tile::integral_constant; +using Mem = ck_tile::integral_constant; +using Comp = ck_tile::integral_constant; + +// clang-format off +using KernelTypes = ::testing::Types< + // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType + std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, + std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Comp>, + std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Comp>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Comp>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGemmPipeline, KernelTypes); + +#include "test_gemm_pipeline_ut_cases.inc" diff --git a/test/ck_tile/gemm/test_gemm_mem_pipeline_ut_cases.inc b/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc similarity index 79% rename from test/ck_tile/gemm/test_gemm_mem_pipeline_ut_cases.inc rename to test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc index af94d68f2c..c78d69601c 100644 --- a/test/ck_tile/gemm/test_gemm_mem_pipeline_ut_cases.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc @@ -3,7 +3,7 @@ #pragma once -TYPED_TEST(TestCkTileGemmMemPipeline, SmallM) +TYPED_TEST(TestCkTileGemmPipeline, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 1024; @@ -13,7 +13,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, SmallM) this->Run(M, N, K); } -TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM) +TYPED_TEST(TestCkTileGemmPipeline, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 1024; @@ -23,7 +23,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM) this->Run(M, N, K); } -TYPED_TEST(TestCkTileGemmMemPipeline, PaddK) +TYPED_TEST(TestCkTileGemmPipeline, PaddK) { std::vector Ms{127}; constexpr int N = 1024; @@ -33,7 +33,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, PaddK) this->Run(M, N, K); } -TYPED_TEST(TestCkTileGemmMemPipeline, Regular) +TYPED_TEST(TestCkTileGemmPipeline, Regular) { std::vector Ms{512}; constexpr int N = 1024; @@ -43,7 +43,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, Regular) this->Run(M, N, K); } -TYPED_TEST(TestCkTileGemmMemPipeline, NotSupportedArgument) +TYPED_TEST(TestCkTileGemmPipeline, NotSupportedArgument) { constexpr int M = 512; constexpr int N = 1025; diff --git a/test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp similarity index 80% rename from test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp rename to test/ck_tile/gemm/test_gemm_pipeline_util.hpp index 6941a7596a..a514986024 100644 --- a/test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -11,18 +11,24 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" +enum struct GemmPipelineType +{ + Mem, + Comp +}; template -class TestCkTileGemmMemPipeline : public ::testing::Test +class TestCkTileGemmPipeline : public ::testing::Test { protected: - using ALayout = std::tuple_element_t<0, Tuple>; - using BLayout = std::tuple_element_t<1, Tuple>; - using CLayout = std::tuple_element_t<2, Tuple>; - using ADataType = std::tuple_element_t<3, Tuple>; - using BDataType = std::tuple_element_t<4, Tuple>; - using AccDataType = std::tuple_element_t<5, Tuple>; - using CDataType = std::tuple_element_t<6, Tuple>; - static constexpr auto Scheduler = std::tuple_element_t<7, Tuple>::value; + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = std::tuple_element_t<2, Tuple>; + using ADataType = std::tuple_element_t<3, Tuple>; + using BDataType = std::tuple_element_t<4, Tuple>; + using AccDataType = std::tuple_element_t<5, Tuple>; + using CDataType = std::tuple_element_t<6, Tuple>; + static constexpr auto Scheduler = std::tuple_element_t<7, Tuple>::value; + static constexpr auto PipelineType = std::tuple_element_t<8, Tuple>::value; // TODO: expose tile size through test t-param ? struct gemm_args @@ -74,8 +80,13 @@ class TestCkTileGemmMemPipeline : public ::testing::Test using Traits = ck_tile::TileGemmTraits; - using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< - ck_tile::GemmPipelineProblem>; + using BaseGemmPipeline = std::conditional_t< + PipelineType == GemmPipelineType::Mem, + ck_tile::BaseGemmPipelineAgBgCrMem< + ck_tile::GemmPipelineProblem>, + ck_tile::BaseGemmPipelineAgBgCrCompV3< + ck_tile:: + GemmPipelineProblem>>; const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K); const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); @@ -85,15 +96,26 @@ class TestCkTileGemmMemPipeline : public ::testing::Test constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr auto tail_number_v = tail_number_.value; - using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem< - ck_tile::UniversalGemmPipelineProblem>; + using GemmPipeline = + std::conditional_t>, + ck_tile::GemmPipelineAgBgCrCompV3< + ck_tile::UniversalGemmPipelineProblem>>; using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKargs(args.p_a, args.p_b,