Addition of Stream-K tests using Tile Engine (#3514)

* Addition of Stream-K tests using Tile Engine

This change adds an implementation for generating Stream-K tests using Tile Engine.
This will generate various test executables for different combinations based on the
config files. This addition has simple tests running for bf16 and fp16, with both
atomic and reduction strategies and compv3 pipeline. The tests rely on the implementation
of Stream-K in Tile Engine.

* integrating addition of tree reduction and editing the README

* temporarily removing parallel and tree reduction from configs while bugs regarding them are being resolved
This commit is contained in:
arai713
2026-01-22 12:53:52 -08:00
committed by GitHub
parent 31a35ecab4
commit b9bb1db5d9
9 changed files with 723 additions and 2 deletions

View File

@@ -41,3 +41,4 @@ add_subdirectory(fmha)
add_subdirectory(gemm_tile_engine)
add_subdirectory(pooling)
add_subdirectory(grouped_conv)
add_subdirectory(gemm_streamk_tile_engine)

View File

@@ -0,0 +1,306 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# ============================================================================
# GEMM Tile Engine Unit Tests
#
# This CMake file creates unit tests for tile_engine generated GEMM kernels.
# It follows the exact same build patterns as tile_engine for consistency
# and reliability. Each kernel configuration gets its own test executable.
# ============================================================================
# Locate tile_engine GEMM scripts directory
set(TILE_ENGINE_GEMM_DIR "${PROJECT_SOURCE_DIR}/tile_engine/ops/gemm_streamk")
if(NOT EXISTS ${TILE_ENGINE_GEMM_DIR})
message(WARNING "Tile engine directory not found: ${TILE_ENGINE_GEMM_DIR}")
return()
endif()
# ============================================================================
# create_individual_gemm_test_target
#
# Creates a single test executable for a specific kernel configuration.
# Mirrors tile_engine's create_individual_gemm_target function for consistency.
#
# Parameters:
# datatype - Data type (fp16, bf16, fp32, etc.)
# layout - Matrix layout (rcr, rrr, ccr, crr)
# config_name - Configuration file name without .json extension
# trait - Kernel trait combination string
# tile_config - Tile configuration parameters
# config_json - Full path to JSON configuration file
# ============================================================================
function(create_individual_gemm_test_target datatype layout config_name trait tile_config config_json)
set(target_name "test_gemm_streamk_tile_engine_${datatype}_${layout}_${config_name}_${trait}_${tile_config}")
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}/${config_name}")
# Generated header path (already created during cmake configuration)
set(test_header "${working_path}/gemm_streamk_single_${datatype}_${layout}_${trait}_${tile_config}.hpp")
set(test_params_header "${working_path}/test_params.hpp")
# Verify header exists (should have been generated during cmake configuration)
if(NOT EXISTS ${test_header})
message(WARNING "Generated header not found: ${test_header}")
return()
endif()
# Verify test parameters header exists
if(NOT EXISTS ${test_params_header})
message(WARNING "Test parameters header not found: ${test_params_header}")
return()
endif()
# Create GTest executable for this kernel configuration
add_gtest_executable(${target_name}
${CMAKE_CURRENT_SOURCE_DIR}/test_gemm_streamk_simple.cpp
)
# Configure GPU architectures for HIP compilation
set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_TEST_GPU_TARGETS})
# Define preprocessor macros for generated header location and test parameters
target_compile_definitions(${target_name} PRIVATE
GEMM_SINGLE_INSTANCE_HPP="${test_header}"
GEMM_TEST_PARAMS_HPP="${test_params_header}"
)
# Include directories for headers and dependencies
target_include_directories(${target_name} PRIVATE
${PROJECT_SOURCE_DIR}/include
${PROJECT_BINARY_DIR}/include
${PROJECT_SOURCE_DIR} # Root directory for tile_engine access
${GTEST_INCLUDE_DIRS}
)
# Compiler options matching tile_engine requirements
target_compile_options(${target_name} PRIVATE
-Wno-undefined-func-template # Suppress template warnings
-Wno-float-equal # Allow floating point comparisons
--offload-compress # Enable GPU code compression
-include ${test_header} # Auto-include generated header
)
# Add FP8 format definitions for proper data type interpretation
if(CK_USE_OCP_FP8)
target_compile_options(${target_name} PRIVATE -DCK_TILE_USE_OCP_FP8)
endif()
message(STATUS " Created test target: ${target_name}")
endfunction()
# ============================================================================
# build_gemm_test_targets
#
# Builds all test targets for a specific datatype/layout/config combination.
# Uses tile_engine's two-step process: list kernels, then generate tests.
#
# Parameters:
# datatype - Data type (fp16, bf16, fp32, etc.)
# layout - Matrix layout (rcr, rrr, ccr, crr)
# config_name - Configuration file name without .json extension
# ============================================================================
function(build_gemm_test_targets datatype layout config_name)
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}/${config_name}")
# Locate and validate configuration file
set(config_filename "${config_name}.json")
set(json_blob "${CMAKE_CURRENT_SOURCE_DIR}/configs/${config_filename}")
if(NOT EXISTS ${json_blob})
message(WARNING "Test config file not found: ${json_blob}")
return()
endif()
# Prepare build directory for this configuration
file(MAKE_DIRECTORY ${working_path})
# STEP 1: Discovery phase - list all valid kernel configurations
execute_process(
COMMAND ${Python3_EXECUTABLE} -u ${TILE_ENGINE_GEMM_DIR}/gemm_streamk_instance_builder.py
--working_path ${working_path}
--datatype ${datatype}
--layout ${layout}
--config_json ${json_blob}
--list_kernels
WORKING_DIRECTORY ${TILE_ENGINE_GEMM_DIR}
RESULT_VARIABLE ret
OUTPUT_VARIABLE list_output
ERROR_VARIABLE list_error
)
if(NOT ret EQUAL 0)
message(WARNING "Failed to list kernels for ${datatype}_${layout}_${config_name}: ${list_error}")
return()
endif()
# Verify kernel list file was generated
if(NOT EXISTS ${working_path}/gemm_kernel_list.txt)
message(STATUS "No kernels found for ${datatype}_${layout}_${config_name} (validation filtered out all combinations)")
return()
endif()
message(STATUS "Building tests for ${datatype}_${layout}_${config_name}")
# STEP 2a: Extract test parameters from config
set(test_params_file "${working_path}/test_params.hpp")
execute_process(
COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_SOURCE_DIR}/extract_test_params.py
--config_file ${json_blob}
--output_file ${test_params_file}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
RESULT_VARIABLE extract_ret
OUTPUT_VARIABLE extract_output
ERROR_VARIABLE extract_error
)
if(NOT extract_ret EQUAL 0)
message(WARNING "Failed to extract test parameters for ${datatype}_${layout}: ${extract_error}")
return()
endif()
# STEP 2b: Header generation phase - generate headers using --gen_single
message(STATUS " Generating headers using --gen_single...")
file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines)
set(gen_count 0)
foreach(line IN LISTS kernel_lines)
# Parse kernel specification format: kernel_name|tile_config|trait_combo
string(REPLACE "|" ";" parts "${line}")
list(LENGTH parts parts_len)
if(parts_len EQUAL 3)
list(GET parts 0 kernel_name)
list(GET parts 1 tile_config)
list(GET parts 2 trait_combo)
# Generate header using --gen_single
execute_process(
COMMAND ${Python3_EXECUTABLE} -u ${TILE_ENGINE_GEMM_DIR}/gemm_streamk_instance_builder.py
--working_path ${working_path}
--datatype ${datatype}
--layout ${layout}
--config_json ${json_blob}
--gen_single
--kernel_name "${kernel_name}"
--tile_config "${tile_config}"
--trait_combo "${trait_combo}"
WORKING_DIRECTORY ${TILE_ENGINE_GEMM_DIR}
RESULT_VARIABLE gen_ret
OUTPUT_VARIABLE gen_output
ERROR_VARIABLE gen_error
)
if(NOT gen_ret EQUAL 0)
message(WARNING "Failed to generate header for ${kernel_name}: ${gen_error}")
else()
math(EXPR gen_count "${gen_count} + 1")
endif()
endif()
endforeach()
message(STATUS " Generated ${gen_count} headers for ${datatype}_${layout}")
# STEP 3: Target creation phase - create test targets
message(STATUS " Creating test targets...")
file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines)
set(test_count 0)
foreach(line IN LISTS kernel_lines)
# Parse kernel specification format: kernel_name|tile_config|trait_combo
string(REPLACE "|" ";" parts "${line}")
list(LENGTH parts parts_len)
if(parts_len EQUAL 3)
list(GET parts 0 kernel_name)
list(GET parts 1 tile_config)
list(GET parts 2 trait_combo)
# Generate test target for this kernel configuration
create_individual_gemm_test_target("${datatype}" "${layout}" "${config_name}" "${trait_combo}" "${tile_config}" "${json_blob}")
math(EXPR test_count "${test_count} + 1")
endif()
endforeach()
message(STATUS " Created ${test_count} test targets for ${datatype}_${layout}")
endfunction()# ============================================================================
# MAIN EXECUTION - Test Target Generation
# ============================================================================
message(STATUS "=== Starting StreamK GEMM Tile Engine Test Configuration ===")
message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
# GPU architecture filtering - only build tests for supported architectures
set(GEMM_TEST_GPU_TARGETS "")
set(DESIRED_TARGETS "gfx90a;gfx942")
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
if(target IN_LIST DESIRED_TARGETS)
list(APPEND GEMM_TEST_GPU_TARGETS ${target})
message(STATUS " Adding GPU target for tests: ${target}")
endif()
endforeach()
# Early exit if no compatible GPU architectures are available
if(NOT GEMM_TEST_GPU_TARGETS)
message(WARNING "Skipping StreamK GEMM Tile Engine tests: No supported GPU targets (gfx90a, gfx942) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
return()
endif()
message(STATUS "Building StreamK GEMM tile engine tests for GPU targets: ${GEMM_TEST_GPU_TARGETS}")
# Enable parallel compilation optimizations
# Set up job pools for better parallel compilation control
set_property(GLOBAL PROPERTY JOB_POOLS
compile_heavy=4 # Limit heavy compilations to prevent OOM
compile_normal=16 # Allow more parallel normal compilations
)
# Enable compiler cache if available and explicitly requested
# Disabled by default due to permission issues in CI environments
option(ENABLE_CCACHE_TESTS "Enable ccache for test compilation" OFF)
if(ENABLE_CCACHE_TESTS)
find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM)
set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM})
message(STATUS "Using ccache for faster test compilation")
else()
message(WARNING "ccache requested but not found")
endif()
else()
message(STATUS "ccache disabled for tests (use -DENABLE_CCACHE_TESTS=ON to enable)")
endif()
# ============================================================================
# Test Configuration Matrix - Clean Focused Design
# ============================================================================
# All supported data types and layouts for comprehensive testing
# Note: fp64 not included (no MFMA hardware support)
set(TEST_DATATYPES "fp16;bf16")
set(TEST_LAYOUTS "rcr;rrr;ccr;crr")
# ============================================================================
# Test Target Generation - Datatype-Specific Categories
# ============================================================================
# 1. SIMPLE TEST: Test for basic functionality with data types (fp16, bf16)
# These data types can use larger warp tiles due to smaller memory footprint
set(SIMPLE_TEST_CONFIG "simple_test_config")
set(SIMPLE_TEST_CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${SIMPLE_TEST_CONFIG}.json")
set(SIMPLE_DATATYPES "fp16;bf16")
if(EXISTS ${SIMPLE_TEST_CONFIG_FILE})
message(STATUS "Processing simple test config: ${SIMPLE_TEST_CONFIG} (fp16, bf16)")
foreach(datatype IN LISTS SIMPLE_DATATYPES)
# fp16, bf16: testing all layouts (rcr, rrr, ccr, crr)
foreach(layout IN LISTS TEST_LAYOUTS)
build_gemm_test_targets("${datatype}" "${layout}" "${SIMPLE_TEST_CONFIG}")
endforeach()
endforeach()
else()
message(WARNING "Simple test config file not found: ${SIMPLE_TEST_CONFIG_FILE}")
endif()
# ============================================================================
message(STATUS "StreamK GEMM tile engine tests configured with datatype-specific design:")
message(STATUS " - Simple test: fp16/bf16 (all layouts)")

View File

@@ -0,0 +1,56 @@
# Stream-K GEMM Tile Engine Unit Tests
## How It Works
This unit test system integrates **tile_engine's kernel generation** into automated testing:
1. **Uses tile_engine scripts directly**: Same Python scripts that generate tile_engine kernels
2. **JSON-based configuration**: Define test parameters in JSON files (like tile_engine)
3. **Build-time generation**: CMake calls tile_engine scripts to generate kernel headers
4. **Individual test executables**: Each kernel configuration becomes a separate test
5. **Tile_engine verification**: Uses exact same error thresholds and validation as tile_engine
## Tile Engine Integration
```
JSON Config → tile_engine Python scripts → Generated Headers → Test Executables
```
- **`--list_kernels`**: Get available kernel configurations from JSON
- **`--gen_individual`**: Generate all kernel headers in parallel during CMake configuration
- **`--gen_single`**: Generate individual kernel header for each configuration
- **Same verification**: Uses tile_engine's adaptive error thresholds and reference calculations
- **Same patterns**: Follows tile_engine's tensor initialization, stride calculation, and kernel launching
### Config-Specific Test Parameters
Each test configuration can specify optimized problem sizes in its JSON file:
- **`test_params.problem_sizes`**: Array of `{m, n, k, split_k}` configurations
- **CMake extraction**: `extract_test_params.py` generates config-specific test parameter files
- **Build integration**: Each test target uses parameters appropriate for its kernel configuration
- **Optimized testing**: Different configs test different problem sizes that showcase their strengths
The key idea: **Unit tests that use tile_engine's exact kernel generation and verification methodology** instead of creating separate test infrastructure.
## Test Configurations
### 1. **Simple Test** (`simple_test_config.json`)
- **Purpose**: Basic functionality validation for fp16/bf16 data types
- **Config**: 128x128x32, warp 2x2x1, warp_tile 32x32x16
- **Traits**: compv3 pipeline only
- **Coverage**: All 4 layouts (rcr, rrr, ccr, crr) for fp16, bf16
## Data Type Support
-**fp16, bf16**: Fully supported - all layouts (rcr, rrr, ccr, crr)
-**fp64**: Not supported (hardware MFMA limitation)
-**fp32, bf8, pk-int4-t**: Not yet supported by gemm_instance_builder (will be added later)
## Test Result Behavior
Tests automatically handle unsupported configurations through runtime validation:
- **PASSED**: Kernel executed correctly with results within error thresholds ✅
- **SKIPPED**: Kernel validation returned "Arguments not supported" (expected for certain problem sizes/configurations) ⚠️
- **FAILED**: Actual error or incorrect computation results ❌
When a kernel's `IsSupportedArgument()` check fails (e.g., due to vector alignment requirements, dimension constraints, or padding limitations), the test is automatically skipped rather than failed. This allows comprehensive testing across various problem sizes while gracefully handling configurations that don't meet specific kernel requirements.

View File

@@ -0,0 +1,35 @@
{
"problem": {
"description": "Basic functionality validation with moderate problem sizes"
},
"test_params": {
"problem_sizes": [
{"m": 256, "n": 256, "k": 128, "split_k": 1},
{"m": 512, "n": 256, "k": 256, "split_k": 1},
{"m": 256, "n": 512, "k": 256, "split_k": 1}
]
},
"tile_config": {
"tile_m": {"values": [128]},
"tile_n": {"values": [128]},
"tile_k": {"values": [64]},
"warp_m": {"values": [2]},
"warp_n": {"values": [2]},
"warp_k": {"values": [1]},
"warp_tile_m": {"values": [16]},
"warp_tile_n": {"values": [16]},
"warp_tile_k": {"values": [16]}
},
"trait_config": {
"pipeline": {"values": ["compv3"]},
"epilogue": {"values": ["default"]},
"scheduler": {"values": ["intrawave"]},
"pad_m": {"values": [false]},
"pad_n": {"values": [false]},
"pad_k": {"values": [false]},
"persistent": {"values": [false, true]},
"reduction_strategy": {"values": ["atomic"]}
},
"k_block_per_cu": 1,
"permute_n": false
}

View File

@@ -0,0 +1,74 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
import json
import argparse
import os
from pathlib import Path
def extract_test_params(config_file, output_file):
"""Extract test parameters from config JSON and write to output file"""
# Read config file
with open(config_file, "r") as f:
config = json.load(f)
# Extract test parameters
test_params = []
if "test_params" in config and "problem_sizes" in config["test_params"]:
test_params = config["test_params"]["problem_sizes"]
else:
# Default test parameters if none specified
test_params = [
{"m": 256, "n": 256, "k": 128, "split_k": 1},
{"m": 256, "n": 256, "k": 1024, "split_k": 1},
{"m": 256, "n": 512, "k": 512, "split_k": 1},
{"m": 512, "n": 256, "k": 512, "split_k": 1},
]
# Write to output file in C++ format
output_dir = Path(output_file).parent
output_dir.mkdir(parents=True, exist_ok=True)
with open(output_file, "w") as f:
f.write("// Generated test parameters for this configuration\n")
f.write("// This file is auto-generated during CMake configuration\n\n")
f.write("static const std::vector<GemmTestParams> CONFIG_TEST_PARAMS = {\n")
for i, params in enumerate(test_params):
comma = "," if i < len(test_params) - 1 else ""
f.write(
f" {{{params['m']}, {params['n']}, {params['k']}, {params['split_k']}}}{comma}\n"
)
f.write("};\n")
print(
f"Extracted {len(test_params)} test parameters from {config_file} -> {output_file}"
)
def main():
parser = argparse.ArgumentParser(
description="Extract test parameters from config JSON"
)
parser.add_argument("--config_file", required=True, help="Input config JSON file")
parser.add_argument(
"--output_file", required=True, help="Output test parameters file"
)
args = parser.parse_args()
if not os.path.exists(args.config_file):
print(f"Error: Config file not found: {args.config_file}")
return 1
extract_test_params(args.config_file, args.output_file)
return 0
if __name__ == "__main__":
exit(main())

View File

@@ -0,0 +1,240 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/**
* @file test_gemm_simple.cpp
* @brief Unit tests for GEMM kernels generated by gemm_instance_builder
*
* This test includes kernels generated during CMake configuration by
* gemm_instance_builder.py and tests them with problem sizes extracted
* from the corresponding JSON configuration files.
*/
#include <gtest/gtest.h>
#include <iostream>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp"
// The kernel header is included via compile command line with -include flag
// It defines SelectedKernel struct, KERNEL_NAME, and tensor data types
// Adaptive error threshold calculation matching tile_engine's implementation
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
/// @brief Function to compare the results of the device and host computations (from tile_engine)
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
bool compare_results(std::string instanceName,
ck_tile::index_t K,
ck_tile::index_t kbatch,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
ck_tile::HostTensor<CDataType>& c_m_n_host_result)
{
const float max_accumulated_value =
*std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
bool pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_result,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "For " << instanceName << " Relative error threshold is "
<< rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is "
<< rtol_atol.at(ck_tile::number<1>{}) << std::endl;
std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
return pass;
}
// Test parameter structure for matrix dimensions and split_k values
struct GemmTestParams
{
int m, n, k, split_k;
};
// Include config-specific test parameters (after GemmTestParams struct is defined)
#ifdef GEMM_TEST_PARAMS_HPP
#include GEMM_TEST_PARAMS_HPP
#endif
class StreamKGemmTileEngineTest : public ::testing::TestWithParam<GemmTestParams>
{
protected:
void SetUp() override
{
auto params = GetParam();
m_ = params.m;
n_ = params.n;
k_ = params.k;
split_k_ = params.split_k;
// Calculate strides (following tile_engine pattern)
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
stride_a_ = k_;
}
else
{
stride_a_ = m_;
}
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
stride_b_ = n_;
}
else
{
stride_b_ = k_;
}
if constexpr(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
stride_c_ = n_;
}
else
{
stride_c_ = m_;
}
}
// Test dimensions
int m_, n_, k_, split_k_;
int stride_a_, stride_b_, stride_c_;
};
TEST_P(StreamKGemmTileEngineTest, BasicFunctionality)
{
// Get tensor layouts from generated kernel
const ALayout layout_a = ALayout{};
const BLayout layout_b = BLayout{};
const CLayout layout_c = CLayout{};
// Use split_k from test parameters
int split_k = split_k_;
int stride_a_calc = ck_tile::get_default_stride(m_, k_, 0, is_row_major(layout_a));
int stride_b_calc = ck_tile::get_default_stride(k_, n_, 0, is_row_major(layout_b));
int stride_c_calc = ck_tile::get_default_stride(m_, n_, 0, is_row_major(layout_c));
// Create host tensors with proper descriptors
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(m_, k_, stride_a_calc, is_row_major(layout_a)));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(k_, n_, stride_b_calc, is_row_major(layout_b)));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(m_, n_, stride_c_calc, is_row_major(layout_c)));
ck_tile::HostTensor<CDataType> c_m_n_host_result(
ck_tile::host_tensor_descriptor(m_, n_, stride_c_calc, is_row_major(layout_c)));
// Initialize input tensors with uniform random distribution [-1.0, 1.0] (matches tile_engine)
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
// Allocate GPU device memory
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
// Copy data to device and zero output buffer
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
// Calculate reference result on host for verification
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_k_n, c_m_n_host_result);
// Create GEMM kernel arguments
ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
m_,
n_,
k_,
stride_a_calc,
stride_b_calc,
stride_c_calc};
// Configure kernel execution for maximum speed (no timing, no debug output)
ck_tile::stream_config stream_config{nullptr, // stream
false, // time_kernel (disable timing for speed)
0, // log_level (disable debug output)
0, // n_warmup
1, // n_repeat
false, // is_gpu_timer (unused when time_kernel=false)
false, // flush_cache
1}; // rotating_count
// Launch the generated kernel (no timing overhead for fastest execution)
try
{
SelectedKernel::launch(args, stream_config);
// Kernel launched successfully if no exception thrown
}
catch(const std::exception& e)
{
std::string error_msg(e.what());
// If arguments not supported, skip the test (configuration validation failure, not a bug)
if(error_msg.find("Arguments not supported") != std::string::npos)
{
GTEST_SKIP() << "Configuration not supported: " << e.what();
}
else
{
FAIL() << "Kernel launch failed: " << e.what();
}
}
// Copy result back from device
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
// Verify results using tile_engine's adaptive error thresholds
bool verification_passed = compare_results<ADataType, BDataType, AccDataType, CDataType>(
KERNEL_NAME, k_, split_k, c_m_n_dev_result, c_m_n_host_result);
EXPECT_TRUE(verification_passed) << "GEMM result verification failed";
}
TEST_P(StreamKGemmTileEngineTest, KernelInfo)
{
// Simple test to verify kernel information is available
EXPECT_TRUE(strlen(KERNEL_NAME) > 0) << "Kernel name should not be empty";
std::cout << "Testing kernel: " << KERNEL_NAME << std::endl;
std::cout << "Problem size: " << m_ << "x" << n_ << "x" << k_ << " with split_k=" << split_k_
<< std::endl;
}
// Use config-specific test parameters (included via compile flags)
// CONFIG_TEST_PARAMS is defined in the auto-generated test_params.hpp file
INSTANTIATE_TEST_SUITE_P(GemmVerification,
StreamKGemmTileEngineTest,
::testing::ValuesIn(CONFIG_TEST_PARAMS),
[](const ::testing::TestParamInfo<GemmTestParams>& param_info) {
return std::to_string(param_info.param.m) + "x" +
std::to_string(param_info.param.n) + "x" +
std::to_string(param_info.param.k) + "_splitk" +
std::to_string(param_info.param.split_k);
});

View File

@@ -98,7 +98,7 @@
},
"reduction_strategy": {
"values": [
"reduction", "atomic"
"atomic"
]
}
}

View File

@@ -377,6 +377,7 @@ class GemmKernelBuilder:
reduction_strategy_map = {
"atomic": "ck_tile::StreamKReductionStrategy::Atomic",
"reduction": "ck_tile::StreamKReductionStrategy::Reduction",
"tree": "ck_tile::StreamKReductionStrategy::TreeReduction",
}
# Determine accumulator type based on datatype
@@ -555,6 +556,11 @@ struct SelectedKernel {{
// Reset sk flags to zero before each repetition of the kernel
workspace_data.SetZero();
}}
else if(reduction_strategy == ck_tile::StreamKReductionStrategy::TreeReduction)
{{
// Reset sk flags to zero before each repetition of the kernel
workspace_data.SetZero();
}}
}};
// Launch kernel

View File

@@ -165,10 +165,13 @@ class GemmProfiler
auto [name, avg_time] = kernel_run_result;
auto dp_persistent =
SelectedKernel::UsePersistentKernel ? "PersistentKernel" : "NonPersistentKernel";
auto reduction_strategy =
SelectedKernel::reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic
? "Atomic"
: "Reduction";
: SelectedKernel::reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction
? "Reduction"
: "TreeReduction";
KernelInstance kernel_instance{
name, dp_persistent, reduction_strategy, gemm_problem, {-1.0f, -1.0f, -1.0f}};