mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
ck tile engine integrate with gemm unit tests (#2601)
* first try to understand how tile engine works * 1st implemented unit tests * manage different types for unit tests * manage using different config files to have different unit tests * manage different layouts * making instances and running them by unit test * Add reference calculation * manage different input dimension combination * add splitk to unit tests. clean code. * remove unused files * clean and test with a simple json file
This commit is contained in:
237
test/ck_tile/gemm_tile_engine/CMakeLists.txt
Normal file
237
test/ck_tile/gemm_tile_engine/CMakeLists.txt
Normal file
@@ -0,0 +1,237 @@
|
||||
# ============================================================================
|
||||
# GEMM Tile Engine Unit Tests
|
||||
#
|
||||
# This CMake file creates unit tests for tile_engine generated GEMM kernels.
|
||||
# It follows the exact same build patterns as tile_engine for consistency
|
||||
# and reliability. Each kernel configuration gets its own test executable.
|
||||
# ============================================================================
|
||||
|
||||
# Locate tile_engine GEMM scripts directory
|
||||
set(TILE_ENGINE_GEMM_DIR "${PROJECT_SOURCE_DIR}/tile_engine/ops/gemm")
|
||||
|
||||
if(NOT EXISTS ${TILE_ENGINE_GEMM_DIR})
|
||||
message(WARNING "Tile engine directory not found: ${TILE_ENGINE_GEMM_DIR}")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# ============================================================================
|
||||
# create_individual_gemm_test_target
|
||||
#
|
||||
# Creates a single test executable for a specific kernel configuration.
|
||||
# Mirrors tile_engine's create_individual_gemm_target function for consistency.
|
||||
#
|
||||
# Parameters:
|
||||
# datatype - Data type (fp16, bf16, fp32, etc.)
|
||||
# layout - Matrix layout (rcr, rrr, ccr, crr)
|
||||
# config_name - Configuration file name without .json extension
|
||||
# trait - Kernel trait combination string
|
||||
# tile_config - Tile configuration parameters
|
||||
# config_json - Full path to JSON configuration file
|
||||
# ============================================================================
|
||||
function(create_individual_gemm_test_target datatype layout config_name trait tile_config config_json)
|
||||
set(target_name "test_gemm_tile_engine_${datatype}_${layout}_${config_name}_${trait}_${tile_config}")
|
||||
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}/${config_name}")
|
||||
|
||||
# Generated header path for this specific kernel configuration
|
||||
set(test_header "${working_path}/gemm_single_${datatype}_${layout}_${trait}_${tile_config}.hpp")
|
||||
|
||||
# Generate kernel header using tile_engine's Python script
|
||||
add_custom_command(
|
||||
OUTPUT ${test_header}
|
||||
COMMAND ${Python3_EXECUTABLE} ${TILE_ENGINE_GEMM_DIR}/gemm_instance_builder.py
|
||||
--working_path ${working_path}
|
||||
--datatype ${datatype}
|
||||
--layout ${layout}
|
||||
--config_json ${config_json}
|
||||
--gen_single
|
||||
--kernel_name "test_gemm_${datatype}_${layout}_${trait}_${tile_config}"
|
||||
--tile_config "${tile_config}"
|
||||
--trait_combo "${trait}"
|
||||
DEPENDS ${TILE_ENGINE_GEMM_DIR}/gemm_instance_builder.py ${config_json}
|
||||
COMMENT "Generating test header ${test_header}"
|
||||
VERBATIM
|
||||
)
|
||||
|
||||
# Create GTest executable for this kernel configuration
|
||||
add_gtest_executable(${target_name}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/test_gemm_simple.cpp
|
||||
)
|
||||
|
||||
# Ensure header is generated before compilation
|
||||
set(header_target "${target_name}_header")
|
||||
add_custom_target(${header_target} DEPENDS ${test_header})
|
||||
add_dependencies(${target_name} ${header_target})
|
||||
|
||||
# Configure GPU architectures for HIP compilation
|
||||
set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_TEST_GPU_TARGETS})
|
||||
|
||||
# Define preprocessor macros for generated header location
|
||||
target_compile_definitions(${target_name} PRIVATE
|
||||
GEMM_SINGLE_INSTANCE_HPP="${test_header}"
|
||||
)
|
||||
|
||||
# Include directories for headers and dependencies
|
||||
target_include_directories(${target_name} PRIVATE
|
||||
${PROJECT_SOURCE_DIR}/include
|
||||
${PROJECT_BINARY_DIR}/include
|
||||
${PROJECT_SOURCE_DIR} # Root directory for tile_engine access
|
||||
${GTEST_INCLUDE_DIRS}
|
||||
)
|
||||
|
||||
# Compiler options matching tile_engine requirements
|
||||
target_compile_options(${target_name} PRIVATE
|
||||
-Wno-undefined-func-template # Suppress template warnings
|
||||
-Wno-float-equal # Allow floating point comparisons
|
||||
--offload-compress # Enable GPU code compression
|
||||
-include ${test_header} # Auto-include generated header
|
||||
)
|
||||
|
||||
message(STATUS " Created test target: ${target_name}")
|
||||
endfunction()
|
||||
|
||||
# ============================================================================
|
||||
# build_gemm_test_targets
|
||||
#
|
||||
# Builds all test targets for a specific datatype/layout/config combination.
|
||||
# Uses tile_engine's two-step process: list kernels, then generate tests.
|
||||
#
|
||||
# Parameters:
|
||||
# datatype - Data type (fp16, bf16, fp32, etc.)
|
||||
# layout - Matrix layout (rcr, rrr, ccr, crr)
|
||||
# config_name - Configuration file name without .json extension
|
||||
# ============================================================================
|
||||
function(build_gemm_test_targets datatype layout config_name)
|
||||
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}/${config_name}")
|
||||
|
||||
# Locate and validate configuration file
|
||||
set(config_filename "${config_name}.json")
|
||||
set(json_blob "${CMAKE_CURRENT_SOURCE_DIR}/configs/${config_filename}")
|
||||
message(STATUS " Using test config: ${config_filename}")
|
||||
|
||||
if(NOT EXISTS ${json_blob})
|
||||
message(WARNING "Test config file not found: ${json_blob}")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# Prepare build directory for this configuration
|
||||
file(MAKE_DIRECTORY ${working_path})
|
||||
|
||||
# STEP 1: Discovery phase - list all valid kernel configurations
|
||||
message(STATUS " Listing kernel configurations for ${datatype}_${layout}...")
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} -u ${TILE_ENGINE_GEMM_DIR}/gemm_instance_builder.py
|
||||
--working_path ${working_path}
|
||||
--datatype ${datatype}
|
||||
--layout ${layout}
|
||||
--config_json ${json_blob}
|
||||
--list_kernels
|
||||
WORKING_DIRECTORY ${TILE_ENGINE_GEMM_DIR}
|
||||
RESULT_VARIABLE ret
|
||||
OUTPUT_VARIABLE list_output
|
||||
ERROR_VARIABLE list_error
|
||||
)
|
||||
|
||||
if(NOT ret EQUAL 0)
|
||||
message(WARNING "Failed to list kernels for ${datatype}_${layout}: ${list_error}")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# Validate kernel discovery results
|
||||
if(EXISTS ${working_path}/gemm_kernel_count.txt)
|
||||
file(READ ${working_path}/gemm_kernel_count.txt kernel_count)
|
||||
string(STRIP "${kernel_count}" kernel_count)
|
||||
message(STATUS " Found ${kernel_count} test configurations for ${datatype}_${layout}")
|
||||
else()
|
||||
message(WARNING "Kernel count file not found for ${datatype}_${layout}")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# STEP 2: Generation phase - create test targets for each discovered kernel
|
||||
if(EXISTS ${working_path}/gemm_kernel_list.txt)
|
||||
file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines)
|
||||
set(test_count 0)
|
||||
foreach(line IN LISTS kernel_lines)
|
||||
# Parse kernel specification format: kernel_name|tile_config|trait_combo
|
||||
string(REPLACE "|" ";" parts "${line}")
|
||||
list(LENGTH parts parts_len)
|
||||
if(parts_len EQUAL 3)
|
||||
list(GET parts 0 kernel_name)
|
||||
list(GET parts 1 tile_config)
|
||||
list(GET parts 2 trait_combo)
|
||||
|
||||
# Generate test target for this kernel configuration
|
||||
create_individual_gemm_test_target("${datatype}" "${layout}" "${config_name}" "${trait_combo}" "${tile_config}" "${json_blob}")
|
||||
math(EXPR test_count "${test_count} + 1")
|
||||
endif()
|
||||
endforeach()
|
||||
message(STATUS " Created ${test_count} test targets for ${datatype}_${layout}")
|
||||
else()
|
||||
message(WARNING "Kernel list file not found for ${datatype}_${layout}")
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
# ============================================================================
|
||||
# MAIN EXECUTION - Test Target Generation
|
||||
# ============================================================================
|
||||
|
||||
message(STATUS "=== Starting GEMM Tile Engine Test Configuration ===")
|
||||
message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
|
||||
# GPU architecture filtering - only build tests for supported architectures
|
||||
set(GEMM_TEST_GPU_TARGETS "")
|
||||
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201")
|
||||
|
||||
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
|
||||
if(target IN_LIST DESIRED_TARGETS)
|
||||
list(APPEND GEMM_TEST_GPU_TARGETS ${target})
|
||||
message(STATUS " Adding GPU target for tests: ${target}")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
# Early exit if no compatible GPU architectures are available
|
||||
if(NOT GEMM_TEST_GPU_TARGETS)
|
||||
message(WARNING "Skipping GEMM Tile Engine tests: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
return()
|
||||
endif()
|
||||
|
||||
message(STATUS "Building GEMM tile engine tests for GPU targets: ${GEMM_TEST_GPU_TARGETS}")
|
||||
|
||||
# ============================================================================
|
||||
# Test Configuration Matrix
|
||||
# ============================================================================
|
||||
|
||||
# Available test configurations (minimal set for fast CI/testing)
|
||||
set(TEST_CONFIGS
|
||||
"simple_test_config"
|
||||
# "medium_tiles_config" # Uncomment for broader testing
|
||||
)
|
||||
|
||||
# Data types for testing (core precision types)
|
||||
set(TEST_DATATYPES "fp16" "bf16")
|
||||
# Extended data type options:
|
||||
# set(TEST_DATATYPES "fp16" "bf16" "fp32" "fp64" "int8")
|
||||
|
||||
# Matrix layouts for testing (row-column-row is most common)
|
||||
set(TEST_LAYOUTS "rcr")
|
||||
# Extended layout options:
|
||||
# set(TEST_LAYOUTS "rcr" "rrr" "ccr" "crr")
|
||||
|
||||
# ============================================================================
|
||||
# Test Target Generation Loop
|
||||
# ============================================================================
|
||||
|
||||
foreach(datatype IN LISTS TEST_DATATYPES)
|
||||
foreach(layout IN LISTS TEST_LAYOUTS)
|
||||
foreach(config IN LISTS TEST_CONFIGS)
|
||||
set(CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${config}.json")
|
||||
if(EXISTS ${CONFIG_FILE})
|
||||
message(STATUS "Building tests for ${datatype}_${layout}_${config}")
|
||||
build_gemm_test_targets("${datatype}" "${layout}" "${config}")
|
||||
else()
|
||||
message(WARNING "Config file not found: ${CONFIG_FILE}")
|
||||
endif()
|
||||
endforeach()
|
||||
endforeach()
|
||||
endforeach()
|
||||
|
||||
message(STATUS "GEMM tile engine tests configured for ${TEST_DATATYPES} with ${TEST_LAYOUTS} layouts using ${TEST_CONFIGS} configurations")
|
||||
27
test/ck_tile/gemm_tile_engine/README.md
Normal file
27
test/ck_tile/gemm_tile_engine/README.md
Normal file
@@ -0,0 +1,27 @@
|
||||
# GEMM Tile Engine Unit Tests
|
||||
|
||||
## How It Works
|
||||
|
||||
This unit test system integrates **tile_engine's kernel generation** into automated testing:
|
||||
|
||||
1. **Uses tile_engine scripts directly**: Same Python scripts that generate tile_engine kernels
|
||||
2. **JSON-based configuration**: Define test parameters in JSON files (like tile_engine)
|
||||
3. **Build-time generation**: CMake calls tile_engine scripts to generate kernel headers
|
||||
4. **Individual test executables**: Each kernel configuration becomes a separate test
|
||||
5. **Tile_engine verification**: Uses exact same error thresholds and validation as tile_engine
|
||||
|
||||
## Tile Engine Integration
|
||||
|
||||
```
|
||||
JSON Config → tile_engine Python scripts → Generated Headers → Test Executables
|
||||
```
|
||||
|
||||
- **`--list_kernels`**: Get available kernel configurations from JSON
|
||||
- **`--gen_single`**: Generate individual kernel header for each configuration
|
||||
- **Same verification**: Uses tile_engine's adaptive error thresholds and reference calculations
|
||||
- **Same patterns**: Follows tile_engine's tensor initialization, stride calculation, and kernel launching
|
||||
|
||||
|
||||
|
||||
|
||||
The key idea: **Unit tests that use tile_engine's exact kernel generation and verification methodology** instead of creating separate test infrastructure.
|
||||
@@ -0,0 +1,89 @@
|
||||
{
|
||||
"problem": {
|
||||
},
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"values": [
|
||||
128
|
||||
]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [
|
||||
128
|
||||
]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [
|
||||
64
|
||||
]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv3",
|
||||
"compv4"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"default"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_n": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_k": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"persistent": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
223
test/ck_tile/gemm_tile_engine/test_gemm_simple.cpp
Normal file
223
test/ck_tile/gemm_tile_engine/test_gemm_simple.cpp
Normal file
@@ -0,0 +1,223 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// Unit tests for tile_engine generated GEMM kernels
|
||||
// Tests kernel correctness using tile_engine's verification methodology
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <iostream>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "tile_engine/ops/gemm/gemm_common.hpp"
|
||||
|
||||
// The kernel header is included via compile command line with -include flag
|
||||
// It defines SelectedKernel struct, KERNEL_NAME, and tensor data types
|
||||
|
||||
// Adaptive error threshold calculation matching tile_engine's implementation
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
/// @brief Function to compare the results of the device and host computations (from tile_engine)
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
bool compare_results(std::string instanceName,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t kbatch,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result)
|
||||
{
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "For " << instanceName << " Relative error threshold is "
|
||||
<< rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is "
|
||||
<< rtol_atol.at(ck_tile::number<1>{}) << std::endl;
|
||||
std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
// Test parameter structure for matrix dimensions and split_k values
|
||||
struct GemmTestParams
|
||||
{
|
||||
int m, n, k, split_k;
|
||||
};
|
||||
|
||||
class GemmTileEngineTest : public ::testing::TestWithParam<GemmTestParams>
|
||||
{
|
||||
protected:
|
||||
void SetUp() override
|
||||
{
|
||||
auto params = GetParam();
|
||||
m_ = params.m;
|
||||
n_ = params.n;
|
||||
k_ = params.k;
|
||||
split_k_ = params.split_k;
|
||||
|
||||
// Calculate strides (following tile_engine pattern)
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
stride_a_ = k_;
|
||||
}
|
||||
else
|
||||
{
|
||||
stride_a_ = m_;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
stride_b_ = n_;
|
||||
}
|
||||
else
|
||||
{
|
||||
stride_b_ = k_;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
stride_c_ = n_;
|
||||
}
|
||||
else
|
||||
{
|
||||
stride_c_ = m_;
|
||||
}
|
||||
}
|
||||
|
||||
// Test dimensions
|
||||
int m_, n_, k_, split_k_;
|
||||
int stride_a_, stride_b_, stride_c_;
|
||||
};
|
||||
|
||||
TEST_P(GemmTileEngineTest, BasicFunctionality)
|
||||
{
|
||||
// Get tensor layouts from generated kernel
|
||||
const ALayout layout_a = ALayout{};
|
||||
const BLayout layout_b = BLayout{};
|
||||
const CLayout layout_c = CLayout{};
|
||||
|
||||
// Use split_k from test parameters
|
||||
int split_k = split_k_;
|
||||
int stride_a_calc = ck_tile::get_default_stride(m_, k_, 0, is_row_major(layout_a));
|
||||
int stride_b_calc = ck_tile::get_default_stride(k_, n_, 0, is_row_major(layout_b));
|
||||
int stride_c_calc = ck_tile::get_default_stride(m_, n_, 0, is_row_major(layout_c));
|
||||
|
||||
// Create host tensors with proper descriptors
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(m_, k_, stride_a_calc, is_row_major(layout_a)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(k_, n_, stride_b_calc, is_row_major(layout_b)));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
ck_tile::host_tensor_descriptor(m_, n_, stride_c_calc, is_row_major(layout_c)));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_result(
|
||||
ck_tile::host_tensor_descriptor(m_, n_, stride_c_calc, is_row_major(layout_c)));
|
||||
|
||||
// Initialize input tensors with uniform random distribution [-1.0, 1.0] (matches tile_engine)
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
|
||||
|
||||
// Allocate GPU device memory
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
// Copy data to device and zero output buffer
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
// Calculate reference result on host for verification
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k, b_k_n, c_m_n_host_result);
|
||||
|
||||
// Create GEMM kernel arguments
|
||||
ck_tile::GemmHostArgs gemm_args(a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
split_k,
|
||||
m_,
|
||||
n_,
|
||||
k_,
|
||||
stride_a_calc,
|
||||
stride_b_calc,
|
||||
stride_c_calc);
|
||||
|
||||
// Configure kernel execution for maximum speed (no timing, no debug output)
|
||||
ck_tile::stream_config stream_config{nullptr, // stream
|
||||
false, // time_kernel (disable timing for speed)
|
||||
0, // log_level (disable debug output)
|
||||
0, // n_warmup
|
||||
1, // n_repeat
|
||||
false, // is_gpu_timer (unused when time_kernel=false)
|
||||
false, // flush_cache
|
||||
1}; // rotating_count
|
||||
|
||||
// Launch the generated kernel (no timing overhead for fastest execution)
|
||||
try
|
||||
{
|
||||
SelectedKernel::launch(gemm_args, stream_config);
|
||||
// Kernel launched successfully if no exception thrown
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
FAIL() << "Kernel launch failed: " << e.what();
|
||||
}
|
||||
|
||||
// Copy result back from device
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
// Verify results using tile_engine's adaptive error thresholds
|
||||
bool verification_passed = compare_results<ADataType, BDataType, AccDataType, CDataType>(
|
||||
KERNEL_NAME, k_, split_k, c_m_n_dev_result, c_m_n_host_result);
|
||||
|
||||
EXPECT_TRUE(verification_passed) << "GEMM result verification failed";
|
||||
}
|
||||
|
||||
TEST_P(GemmTileEngineTest, KernelInfo)
|
||||
{
|
||||
// Simple test to verify kernel information is available
|
||||
EXPECT_TRUE(strlen(KERNEL_NAME) > 0) << "Kernel name should not be empty";
|
||||
|
||||
std::cout << "Testing kernel: " << KERNEL_NAME << std::endl;
|
||||
std::cout << "Problem size: " << m_ << "x" << n_ << "x" << k_ << " with split_k=" << split_k_
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// Define test parameters for GEMM verification
|
||||
INSTANTIATE_TEST_SUITE_P(GemmVerification,
|
||||
GemmTileEngineTest,
|
||||
::testing::Values(GemmTestParams{256, 256, 128, 1},
|
||||
GemmTestParams{256, 256, 1024, 1},
|
||||
GemmTestParams{256, 512, 512, 1},
|
||||
GemmTestParams{512, 256, 512, 1}),
|
||||
[](const ::testing::TestParamInfo<GemmTestParams>& param_info) {
|
||||
return std::to_string(param_info.param.m) + "x" +
|
||||
std::to_string(param_info.param.n) + "x" +
|
||||
std::to_string(param_info.param.k) + "_splitk" +
|
||||
std::to_string(param_info.param.split_k);
|
||||
});
|
||||
Reference in New Issue
Block a user