Files
composable_kernel/test/ck_tile/mhc/test_mhc.cpp
2026-02-06 10:59:21 +00:00

97 lines
3.8 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <gtest/gtest.h>
#include <vector>
#include <cmath>
#include <tuple>
#include <iostream>
#include <cstring>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "test_mhc_impl.hpp"
// Shape parameters for different test configurations
using Shape1_BlockWarps = ck_tile::sequence<4, 1>;
using Shape1_BlockTile = ck_tile::sequence<128, 128>;
using Shape1_WarpTile = ck_tile::sequence<32, 128>;
using Shape1_ThreadTile = ck_tile::sequence<8, 8>;
// Test configurations for different data types
// Tuple format: <XDataType, PhiDataType, YDataType, ComputeDataType, BlockWarps, BlockTile,
// WarpTile, ThreadTile>
using TestConfig_F32 = std::tuple<float, // XDataType
float, // PhiDataType
float, // CDataType
float, // ComputeDataType
Shape1_BlockWarps,
Shape1_BlockTile,
Shape1_WarpTile,
Shape1_ThreadTile>;
using TestConfig_BF16 = std::tuple<ck_tile::bf16_t, // XDataType
ck_tile::bf16_t, // PhiDataType
float, // CDataType
float, // ComputeDataType
Shape1_BlockWarps,
Shape1_BlockTile,
Shape1_WarpTile,
Shape1_ThreadTile>;
using TestTypes = ::testing::Types<TestConfig_F32, TestConfig_BF16>;
TYPED_TEST_SUITE(TestCkTileMHC, TestTypes);
// Temporarily disabled: these older tests take runtime parameters rather than the
// template-argument form of RunGemmPipeline<...>() used below.
// TYPED_TEST(TestCkTileMHC, TestBasic) { this->RunExpansionParallelTest(); }
// TYPED_TEST(TestCkTileMHC, TestArbitraryBatchSize) { this->RunArbitraryBatchSizeTest(); }
// Batch-size sweep: only the batch size is given as a template argument; the remaining
// RunGemmPipeline template parameters keep their defaults.
TYPED_TEST(TestCkTileMHC, TestBatchSize1)
{
    this->template RunGemmPipeline<1>();
}
TYPED_TEST(TestCkTileMHC, TestBatchSize16)
{
    this->template RunGemmPipeline<16>();
}
TYPED_TEST(TestCkTileMHC, TestBatchSize17)
{
    this->template RunGemmPipeline<17>();
}
TYPED_TEST(TestCkTileMHC, TestBatchSize32)
{
    this->template RunGemmPipeline<32>();
}
TYPED_TEST(TestCkTileMHC, TestBatchSize64)
{
    this->template RunGemmPipeline<64>();
}
TYPED_TEST(TestCkTileMHC, TestBatchSize100)
{
    this->template RunGemmPipeline<100>();
}
// Expansion-factor sweep with template arguments <BatchSize, N, C>; every case keeps
// nC <= 256 (2*128 = 256, 3*85 = 255, 8*32 = 256).
TYPED_TEST(TestCkTileMHC, TestBatchSize16N2C128)
{
    this->template RunGemmPipeline<16, 2, 128>();
}
TYPED_TEST(TestCkTileMHC, TestBatchSize32N3C85)
{
    this->template RunGemmPipeline<32, 3, 85>();
}
TYPED_TEST(TestCkTileMHC, TestBatchSize8N8C32)
{
    this->template RunGemmPipeline<8, 8, 32>();
}
// Large C values that require K-tiling (nC > 256).
// Template arguments are <BatchSize, N, C>, matching the numbers in each test's name.
TYPED_TEST(TestCkTileMHC, TestBatchSize2N4C512) { this->template RunGemmPipeline<2, 4, 512>(); }
TYPED_TEST(TestCkTileMHC, TestBatchSize2N4C1024) { this->template RunGemmPipeline<2, 4, 1024>(); }
TYPED_TEST(TestCkTileMHC, TestBatchSize2N4C4096) { this->template RunGemmPipeline<2, 4, 4096>(); }
// Larger batch combined with K-tiling (nC = 4 * 4096 = 16384).
TYPED_TEST(TestCkTileMHC, TestBatchSize16N4C4096) { this->template RunGemmPipeline<16, 4, 4096>(); }

// Different activation functions, all at BatchSize = 16, N = 4, C = 64
// (nC = 256, so no K-tiling is required).
TYPED_TEST(TestCkTileMHC, TestBatchSize16WithTanh)
{
    this->template RunGemmPipeline<16, 4, 64, ck_tile::element_wise::TanH>();
}
TYPED_TEST(TestCkTileMHC, TestBatchSize16WithRelu)
{
    this->template RunGemmPipeline<16, 4, 64, ck_tile::element_wise::Relu>();
}
TYPED_TEST(TestCkTileMHC, TestBatchSize16WithSilu)
{
    this->template RunGemmPipeline<16, 4, 64, ck_tile::element_wise::Silu>();
}