diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp index 63343df3a8..6f30bdaa73 100644 --- a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp +++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp @@ -221,8 +221,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co b0_tensors_device.emplace_back(std::make_unique( sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i])); - b1_tensors_device.emplace_back( - std::make_unique(sizeof(B1DataType) * problem_size.Ns[i])); + b1_tensors_device.emplace_back(std::make_unique( + sizeof(B1DataType) * problem_size.Ns[i] * problem_size.Ks[i])); d0_tensors_device.emplace_back( std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index d5f164c40f..703ab810d8 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -309,8 +309,8 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser, if(init_method == 0) { - ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); - ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); + ck_tile::FillUniformDistribution{-2.f, 2.f}(a_m_k); + ck_tile::FillUniformDistribution{-2.f, 2.f}(b_k_n); } else if(init_method == 1) { diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 9f86f6d88f..c22867fbed 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -40,6 +40,9 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_OPS profile_contraction_bilinear.cpp) list(APPEND PROFILER_OPS profile_contraction_scale.cpp) endif() +endif() + +if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]") if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND PROFILER_OPS profile_gemm_reduce.cpp) list(APPEND PROFILER_OPS profile_batched_gemm_add_relu_gemm_add.cpp) @@ -53,7 +56,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_OPS profile_grouped_gemm_tile_loop.cpp) list(APPEND PROFILER_OPS profile_grouped_gemm_multiply_tile_loop.cpp) endif() - if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]") + if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") list(APPEND PROFILER_OPS profile_gemm_multiply_multiply_wp.cpp) list(APPEND PROFILER_OPS profile_gemm_ab_scale.cpp) list(APPEND PROFILER_OPS profile_gemm_blockscale_wp.cpp) @@ -74,7 +77,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_OPS profile_conv_bwd_data.cpp) list(APPEND PROFILER_OPS profile_conv_fwd.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_fwd_outelementop.cpp) - endif() if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)) OR diff --git a/test/ck_tile/gemm/CMakeLists.txt b/test/ck_tile/gemm/CMakeLists.txt index 96c071cbc4..c08ab33b91 100644 --- a/test/ck_tile/gemm/CMakeLists.txt +++ b/test/ck_tile/gemm/CMakeLists.txt @@ -22,6 +22,12 @@ else() message(DEBUG "Skipping ck_tile_gemm tests for current target") endif() + +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") + add_gtest_executable(test_gemm_pipeline_compiler test_gemm_pipeline_compiler.cpp) + target_compile_options(test_gemm_pipeline_compiler PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +endif() + if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_gtest_executable(test_ck_tile_gemm_pipeline_universal_fp8 test_gemm_pipeline_universal_fp8.cpp) add_gtest_executable(test_ck_tile_gemm_pipeline_universal_bf8 test_gemm_pipeline_universal_bf8.cpp) diff --git a/test/ck_tile/gemm/test_gemm_pipeline_compiler.cpp b/test/ck_tile/gemm/test_gemm_pipeline_compiler.cpp new file mode 100644 index 0000000000..bf39e0b552 --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_compiler.cpp @@ -0,0 +1,900 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_gemm_pipeline_kernel_types.hpp" +#include "test_gemm_pipeline_util.hpp" +#include "gtest/gtest.h" + +// ============================================================================ +// Comprehensive GEMM Compiler Validation Test Suite +// This file consolidates all GEMM pipeline tests for compiler validation +// Covers essential combinations of data types, layouts, and pipeline types +// ============================================================================ + +// ---------------------------------------------------------------------------- +// Test Class Definitions for Different Pipeline Types +// ---------------------------------------------------------------------------- + +template +class TestGemmMem : public TestCkTileGemmPipeline> +{ +}; + +#if defined(CK_TILE_USE_WMMA) +template +class TestGemmMemWmma : public TestCkTileGemmPipeline> +{ +}; +#endif + +template +class TestGemmCompV3 : public TestCkTileGemmPipeline> +{ +}; + +#if defined(CK_TILE_USE_WMMA) +template +class TestGemmCompV3Wmma : public TestCkTileGemmPipeline> +{ +}; +#endif + +template +class TestGemmCompV4 : public TestCkTileGemmPipeline> +{ +}; + +#if defined(CK_TILE_USE_WMMA) +template +class TestGemmCompV4Wmma : public TestCkTileGemmPipeline> +{ +}; +#endif + +template +class TestGemmCompV6 : public TestCkTileGemmPipeline> +{ +}; + +template +class TestGemmPersistent : public TestCkTileGemmPipeline> +{ +}; + +#if defined(CK_TILE_USE_WMMA) +template +class TestGemmPersistentWmma : public TestCkTileGemmPipeline> +{ +}; +#endif + +// ---------------------------------------------------------------------------- +// Type Definitions for Each Pipeline Configuration +// ---------------------------------------------------------------------------- + +// Memory Pipeline Types +using MemTestTypes = ::testing::Types< + // Parameters: ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, + // M_BlockSize, N_BlockSize, K_BlockSize, M_TileSize, N_TileSize, K_TileSize, Scheduler, + // PipelineType + + std::tuple, + std::tuple>; + +#if defined(CK_TILE_USE_WMMA) +// Memory Pipeline WMMA Types +using MemWmmaTestTypes = ::testing::Types< + std::tuple, + std::tuple>; +#endif + +// CompV3 Pipeline Types +using CompV3TestTypes = ::testing::Types< + std::tuple, + std::tuple>; + +#if defined(CK_TILE_USE_WMMA) +// CompV3 Pipeline WMMA Types +using CompV3WmmaTestTypes = ::testing::Types< + std::tuple, + std::tuple>; +#endif + +// CompV4 Pipeline Types +using CompV4TestTypes = ::testing::Types< + std::tuple, + std::tuple>; + +#if defined(CK_TILE_USE_WMMA) +// CompV4 Pipeline WMMA Types +using CompV4WmmaTestTypes = ::testing::Types< + std::tuple, + std::tuple>; +#endif + +// CompV6 Pipeline Types +using CompV6TestTypes = ::testing::Types< + std::tuple, + std::tuple>; + +// Persistent CompV3 Pipeline Types +using PersistentTestTypes = ::testing::Types, + std::tuple>; + +#if defined(CK_TILE_USE_WMMA) +// Persistent CompV3 Pipeline WMMA Types +using PersistentWmmaTestTypes = ::testing::Types, + std::tuple>; +#endif + +// ---------------------------------------------------------------------------- +// Test Suite Registrations +// ---------------------------------------------------------------------------- + +TYPED_TEST_SUITE(TestGemmMem, MemTestTypes); +#if defined(CK_TILE_USE_WMMA) +TYPED_TEST_SUITE(TestGemmMemWmma, MemWmmaTestTypes); +#endif +TYPED_TEST_SUITE(TestGemmCompV3, CompV3TestTypes); +#if defined(CK_TILE_USE_WMMA) +TYPED_TEST_SUITE(TestGemmCompV3Wmma, CompV3WmmaTestTypes); +#endif +TYPED_TEST_SUITE(TestGemmCompV4, CompV4TestTypes); +#if defined(CK_TILE_USE_WMMA) +TYPED_TEST_SUITE(TestGemmCompV4Wmma, CompV4WmmaTestTypes); +#endif +TYPED_TEST_SUITE(TestGemmCompV6, CompV6TestTypes); +TYPED_TEST_SUITE(TestGemmPersistent, PersistentTestTypes); +#if defined(CK_TILE_USE_WMMA) +TYPED_TEST_SUITE(TestGemmPersistentWmma, PersistentWmmaTestTypes); +#endif + +// ============================================================================ +// Memory Pipeline Tests (Mem) +// ============================================================================ + +#define TEST_SUITE_NAME TestGemmMem + +TYPED_TEST(TEST_SUITE_NAME, SmallM_SingleRow) +{ + std::vector Ms{1}; + constexpr int N = 1024; + constexpr int K = TestFixture::K_Tile * 2; + + for(int M : Ms) + { + if constexpr(std::is_same_v) + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + else + { + this->Run(M, N, K); + } + } +} + +TYPED_TEST(TEST_SUITE_NAME, SingleTile) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile); +} + +TYPED_TEST(TEST_SUITE_NAME, ExactlyTwoTiles_M) +{ + this->Run(TestFixture::M_Tile * 2, TestFixture::N_Tile, TestFixture::K_Tile * 2); +} + +TYPED_TEST(TEST_SUITE_NAME, ExactlyTwoTiles_N) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile * 2, TestFixture::K_Tile * 2); +} + +TYPED_TEST(TEST_SUITE_NAME, ExactlyTwoTiles_K) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile * 2); +} + +TYPED_TEST(TEST_SUITE_NAME, Regular_512x1024x512) +{ + constexpr int M = 512; + constexpr int N = 1024; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, Square_1024x1024x1024) +{ + constexpr int M = 1024; + constexpr int N = 1024; + constexpr int K = 1024; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, LargeMatrix_2048x2048x2048) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 2048; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, VeryLargeMatrix_4096x4096x4096) +{ + constexpr int M = 4096; + constexpr int N = 4096; + constexpr int K = 4096; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, TallSkinny_4096x128x1024) +{ + constexpr int M = 4096; + constexpr int N = 128; + constexpr int K = 1024; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, ShortWide_128x4096x1024) +{ + constexpr int M = 128; + constexpr int N = 4096; + constexpr int K = 1024; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, DeepNarrow_2048x2048x8192) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 8192; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, StressTest_ExtremelyTallMatrix) +{ + constexpr int M = 16384; + constexpr int N = 64; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, StressTest_ExtremelyWideMatrix) +{ + constexpr int M = 64; + constexpr int N = 16384; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, StressTest_VeryDeepK) +{ + constexpr int M = 1024; + constexpr int N = 1024; + constexpr int K = 16384; + this->Run(M, N, K); +} + +#undef TEST_SUITE_NAME + +#if defined(CK_TILE_USE_WMMA) +// ============================================================================ +// Memory Pipeline Tests with WMMA +// ============================================================================ + +#define TEST_SUITE_NAME TestGemmMemWmma + +TYPED_TEST(TEST_SUITE_NAME, SingleTile_WMMA) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile); +} + +TYPED_TEST(TEST_SUITE_NAME, Regular_WMMA) +{ + constexpr int M = 512; + constexpr int N = 1024; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, LargeMatrix_WMMA) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 2048; + this->Run(M, N, K); +} + +#undef TEST_SUITE_NAME +#endif // CK_TILE_USE_WMMA + +// ============================================================================ +// Compute V3 Pipeline Tests +// ============================================================================ + +#define TEST_SUITE_NAME TestGemmCompV3 + +TYPED_TEST(TEST_SUITE_NAME, SmallM_CompV3) +{ + std::vector Ms{1, 2}; + constexpr int N = 1024; + std::vector Ks; + for(auto K_count : {2, 4}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } + + for(int M : Ms) + { + for(int K : Ks) + { + if constexpr(std::is_same_v) + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + else + { + this->Run(M, N, K); + } + } + } +} + +TYPED_TEST(TEST_SUITE_NAME, SingleTile_CompV3) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile); +} + +TYPED_TEST(TEST_SUITE_NAME, MidLargeM_CompV3) +{ + std::vector Ms{127, 255}; + constexpr int N = 1024; + + std::vector Ks; + for(auto K_count : {2, 4}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } + + constexpr int VecLoadSize = (std::is_same_v || + std::is_same_v || + std::is_same_v) + ? 16 + : 8; + + for(int M : Ms) + { + for(int K : Ks) + { + if constexpr(std::is_same_v) + { + if(M % VecLoadSize == 0) + { + this->Run(M, N, K); + } + else + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + } + else + { + this->Run(M, N, K); + } + } + } +} + +TYPED_TEST(TEST_SUITE_NAME, Regular_CompV3) +{ + constexpr int M = 512; + constexpr int N = 1024; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, LargeMatrix_CompV3) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 2048; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, BatchedSmall_CompV3) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 256; + this->Run(M, N, K); +} + +#undef TEST_SUITE_NAME + +#if defined(CK_TILE_USE_WMMA) +// ============================================================================ +// Compute V3 Pipeline Tests with WMMA +// ============================================================================ + +#define TEST_SUITE_NAME TestGemmCompV3Wmma + +TYPED_TEST(TEST_SUITE_NAME, SmallM_CompV3Wmma) +{ + std::vector Ms{1, 2}; + constexpr int N = 1024; + std::vector Ks; + for(auto K_count : {2, 4}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } + + for(int M : Ms) + { + for(int K : Ks) + { + if constexpr(std::is_same_v) + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + else + { + this->Run(M, N, K); + } + } + } +} + +TYPED_TEST(TEST_SUITE_NAME, SingleTile_CompV3Wmma) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile); +} + +TYPED_TEST(TEST_SUITE_NAME, Regular_CompV3Wmma) +{ + constexpr int M = 512; + constexpr int N = 1024; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, LargeMatrix_CompV3Wmma) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 2048; + this->Run(M, N, K); +} + +#undef TEST_SUITE_NAME +#endif // CK_TILE_USE_WMMA + +// ============================================================================ +// Compute V4 Pipeline Tests +// ============================================================================ + +#define TEST_SUITE_NAME TestGemmCompV4 + +TYPED_TEST(TEST_SUITE_NAME, SmallM_CompV4) +{ + std::vector Ms{1, 2}; + constexpr int N = 1024; + std::vector Ks; + for(auto K_count : {2, 4}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } + + for(int M : Ms) + { + for(int K : Ks) + { + if constexpr(std::is_same_v) + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + else + { + this->Run(M, N, K); + } + } + } +} + +TYPED_TEST(TEST_SUITE_NAME, SingleTile_CompV4) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile); +} + +TYPED_TEST(TEST_SUITE_NAME, Regular_CompV4) +{ + constexpr int M = 512; + constexpr int N = 1024; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, LargeMatrix_CompV4) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 2048; + this->Run(M, N, K); +} + +#undef TEST_SUITE_NAME + +#if defined(CK_TILE_USE_WMMA) +// ============================================================================ +// Compute V4 Pipeline Tests with WMMA +// ============================================================================ + +#define TEST_SUITE_NAME TestGemmCompV4Wmma + +TYPED_TEST(TEST_SUITE_NAME, SingleTile_CompV4Wmma) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile); +} + +TYPED_TEST(TEST_SUITE_NAME, Regular_CompV4Wmma) +{ + constexpr int M = 512; + constexpr int N = 1024; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, LargeMatrix_CompV4Wmma) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 2048; + this->Run(M, N, K); +} + +#undef TEST_SUITE_NAME +#endif // CK_TILE_USE_WMMA + +// ============================================================================ +// Compute V6 Pipeline Tests +// ============================================================================ + +#define TEST_SUITE_NAME TestGemmCompV6 + +TYPED_TEST(TEST_SUITE_NAME, SmallM_CompV6) +{ + std::vector Ms{1, 2}; + constexpr int N = 1024; + std::vector Ks; + for(auto K_count : {2, 4}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } + + for(int M : Ms) + { + for(int K : Ks) + { + if constexpr(std::is_same_v) + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + else + { + this->Run(M, N, K); + } + } + } +} + +TYPED_TEST(TEST_SUITE_NAME, SingleTile_CompV6) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile); +} + +TYPED_TEST(TEST_SUITE_NAME, MidLargeM_CompV6) +{ + std::vector Ms{127, 255}; + constexpr int N = 1024; + + std::vector Ks; + for(auto K_count : {2, 4}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } + + constexpr int VecLoadSize = (std::is_same_v || + std::is_same_v || + std::is_same_v) + ? 16 + : 8; + + for(int M : Ms) + { + for(int K : Ks) + { + if constexpr(std::is_same_v) + { + if(M % VecLoadSize == 0) + { + this->Run(M, N, K); + } + else + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + } + else + { + this->Run(M, N, K); + } + } + } +} + +TYPED_TEST(TEST_SUITE_NAME, Regular_CompV6) +{ + constexpr int M = 512; + constexpr int N = 1024; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, LargeMatrix_CompV6) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 2048; + this->Run(M, N, K); +} + +#undef TEST_SUITE_NAME + +// ============================================================================ +// Persistent Kernel Tests +// ============================================================================ + +#define TEST_SUITE_NAME TestGemmPersistent + +TYPED_TEST(TEST_SUITE_NAME, SmallM_Persistent) +{ + std::vector Ms{1, 2}; + constexpr int N = 1024; + std::vector Ks; + for(auto K_count : {2, 4}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } + + for(int M : Ms) + { + for(int K : Ks) + { + if constexpr(std::is_same_v) + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + else + { + this->Run(M, N, K); + } + } + } +} + +TYPED_TEST(TEST_SUITE_NAME, SingleTile_Persistent) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile); +} + +TYPED_TEST(TEST_SUITE_NAME, Regular_Persistent) +{ + constexpr int M = 512; + constexpr int N = 1024; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, LargeMatrix_Persistent) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 2048; + this->Run(M, N, K); +} + +#undef TEST_SUITE_NAME + +#if defined(CK_TILE_USE_WMMA) +// ============================================================================ +// Persistent Kernel Tests with WMMA +// ============================================================================ + +#define TEST_SUITE_NAME TestGemmPersistentWmma + +TYPED_TEST(TEST_SUITE_NAME, SmallM_PersistentWmma) +{ + std::vector Ms{1, 2}; + constexpr int N = 1024; + std::vector Ks; + for(auto K_count : {2, 4}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } + + for(int M : Ms) + { + for(int K : Ks) + { + if constexpr(std::is_same_v) + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + else + { + this->Run(M, N, K); + } + } + } +} + +TYPED_TEST(TEST_SUITE_NAME, SingleTile_PersistentWmma) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile); +} + +TYPED_TEST(TEST_SUITE_NAME, Regular_PersistentWmma) +{ + constexpr int M = 512; + constexpr int N = 1024; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, LargeMatrix_PersistentWmma) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 2048; + this->Run(M, N, K); +} + +#undef TEST_SUITE_NAME +#endif // CK_TILE_USE_WMMA