// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #include #include #include #include #include #include #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "test_mhc_impl.hpp" // Shape parameters for different test configurations using Shape1_BlockWarps = ck_tile::sequence<4, 1>; using Shape1_BlockTile = ck_tile::sequence<128, 128>; using Shape1_WarpTile = ck_tile::sequence<32, 128>; using Shape1_ThreadTile = ck_tile::sequence<8, 8>; // Test configurations for different data types // Tuple format: using TestConfig_F32 = std::tuple; using TestConfig_BF16 = std::tuple; using TestTypes = ::testing::Types; TYPED_TEST_SUITE(TestCkTileMHC, TestTypes); // Temporarily disable old tests that use runtime parameters // TYPED_TEST(TestCkTileMHC, TestBasic) { this->RunExpansionParallelTest(); } // TYPED_TEST(TestCkTileMHC, TestArbitraryBatchSize) { this->RunArbitraryBatchSizeTest(); } // Explicit test cases for specific batch sizes (using template syntax) TYPED_TEST(TestCkTileMHC, TestBatchSize1) { this->template RunGemmPipeline<1>(); } TYPED_TEST(TestCkTileMHC, TestBatchSize16) { this->template RunGemmPipeline<16>(); } TYPED_TEST(TestCkTileMHC, TestBatchSize17) { this->template RunGemmPipeline<17>(); } TYPED_TEST(TestCkTileMHC, TestBatchSize32) { this->template RunGemmPipeline<32>(); } TYPED_TEST(TestCkTileMHC, TestBatchSize64) { this->template RunGemmPipeline<64>(); } TYPED_TEST(TestCkTileMHC, TestBatchSize100) { this->template RunGemmPipeline<100>(); } // Test with different expansion factors (keeping nC <= 256) TYPED_TEST(TestCkTileMHC, TestBatchSize16N2C128) { this->template RunGemmPipeline<16, 2, 128>(); } TYPED_TEST(TestCkTileMHC, TestBatchSize32N3C85) { this->template RunGemmPipeline<32, 3, 85>(); } TYPED_TEST(TestCkTileMHC, TestBatchSize8N8C32) { this->template RunGemmPipeline<8, 8, 32>(); } // Test with large C values that require K-tiling (nC > 256) 
// K-tiling cases: template arguments are <batch, n, C>; with n = 4 the reduction length nC is
// 2048, 4096 and 16384, all above the 256 threshold that forces K-tiling.
TYPED_TEST(TestCkTileMHC, TestBatchSize2N4C512) { this->template RunGemmPipeline<2, 4, 512>(); }
TYPED_TEST(TestCkTileMHC, TestBatchSize2N4C1024) { this->template RunGemmPipeline<2, 4, 1024>(); }
TYPED_TEST(TestCkTileMHC, TestBatchSize2N4C4096) { this->template RunGemmPipeline<2, 4, 4096>(); }
// Same deep K-tiling shape (nC = 16384) with a larger batch. It passes no explicit activation
// argument, so it belongs with the K-tiling cases rather than the activation-function group.
TYPED_TEST(TestCkTileMHC, TestBatchSize16N4C4096) { this->template RunGemmPipeline<16, 4, 4096>(); }

// Test with different activation functions.
// The fourth RunGemmPipeline template argument selects the element-wise activation. n = 4,
// C = 64 gives nC = 256, which stays within the non-K-tiled limit, so only the activation varies.
TYPED_TEST(TestCkTileMHC, TestBatchSize16WithTanh)
{
    this->template RunGemmPipeline<16, 4, 64, ck_tile::element_wise::TanH>();
}
TYPED_TEST(TestCkTileMHC, TestBatchSize16WithRelu)
{
    this->template RunGemmPipeline<16, 4, 64, ck_tile::element_wise::Relu>();
}
TYPED_TEST(TestCkTileMHC, TestBatchSize16WithSilu)
{
    this->template RunGemmPipeline<16, 4, 64, ck_tile::element_wise::Silu>();
}