[CK Tile] Stream-K gtest Code Gen (#5722)

## Motivation

Stream-K was using the tile engine infrastructure for smoke tests.
However, tile engine creates a different target per kernel instance,
which has resulted in scalability issues when used in the context of
unit tests. To avoid burdens on cmake configuration and build time, we
have opted to remove our Stream-K tile engine tests. Instead, we use
pure gtests with code gen to generate repetitive .cpp files.

**Note: This appears to change a lot of files because many files are
removed since they are now generated at build time.**

## Technical Details
We originally used Tile Engine to facilitate code gen for unit tests
since we found that pure gtests required the addition of many repetitive
.cpp files of the following form:
```cpp
#include "test_gemm_streamk_common_includes.hpp"

template <typename Tuple>
class TestCkTileStreamKBf8 : public TestCkTileStreamK<Tuple>
{
};

#define TEST_SUITE_NAME TestCkTileStreamKBf8

TYPED_TEST_SUITE(TestCkTileStreamKBf8, KernelTypesStreamKBf8);

#include "test_gemm_streamk_atomic_cases.inc"

#undef TEST_SUITE_NAME

```
Due to issues encountered with tile engine, we instead use pure gtests
to generate the repetitive .cpp files. The code generator parses
`KernelTypesStreamK*` type aliases from the types header using a
two-phase approach:
1. At **configure time**, CMake runs the Python script with
`--list_files` to extract the type alias names from the header
(test_gemm_streamk_types.hpp) and compute the list of .cpp file paths
that will be generated. This lets CMake know the exact set of source
files for each target.
2. At **build time**, `add_custom_command` runs the script again with
`--gen_files` to actually emit the .cpp files into the build directory,
triggered only when the types header or generator script changes.

With these changes, we've removed all Stream-K tile engine tests. There
are now 5 targets for Stream-K GEMM tests:
1. test_ck_tile_streamk_atomic_smoke: smoke tests for Atomic reduction
strategy (pipeline: compv3)
2. test_ck_tile_streamk_linear_smoke: smoke tests for Linear reduction
strategy (pipeline: compv3)
3. test_ck_tile_streamk_tree_smoke: smoke tests for Tree reduction
strategy (pipeline: compv3)
4. test_ck_tile_streamk_pipelines_smoke: smoke tests (smaller set) for
pipelines other than compv3
- Since Stream-K can be thought of as a wrapper around universal GEMM,
we don't need to extensively test each pipeline. So, we opt to run a few
tests for different pipelines. Currently, this just consists of the mem
pipeline, but compv4 is coming soon.
5. test_ck_tile_streamk_extended: extended tests

## Test Plan

I have tests the gtests locally on gfx90a, gfx942, and gfx950.

## Test Result

All local tests pass.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Emily Martins
2026-04-02 15:05:44 -06:00
committed by GitHub
parent d55186f63a
commit d80fa8831f
39 changed files with 738 additions and 1775 deletions

View File

@@ -69,5 +69,4 @@ add_subdirectory(fmha)
add_subdirectory(gemm_tile_engine)
add_subdirectory(pooling)
add_subdirectory(grouped_conv)
add_subdirectory(gemm_streamk_tile_engine)
add_subdirectory(pooling_tile_engine)

View File

@@ -19,55 +19,93 @@ set(EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4
if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")
include_directories(BEFORE ${CMAKE_CURRENT_SOURCE_DIR})
#TODO: support all arches
#TODO: current c-shuffle only supports C layout as R
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
set(STREAMK_EXTENDED_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_mem.cpp
test_gemm_streamk_util.cpp)
# We only test fp8 and bf8 on gfx942 and gfx950 since these types are not natively supported on gfx90a
if(GPU_TARGETS MATCHES "gfx942|gfx950")
list(APPEND STREAMK_EXTENDED_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_mem.cpp)
endif()
# ---- Code-generate test .cpp files from types header ----
set(STREAMK_TYPES_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/test_gemm_streamk_types.hpp)
set(STREAMK_GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/generate_test_files.py)
add_gtest_executable(test_ck_tile_streamk_extended ${STREAMK_EXTENDED_SOURCES})
target_compile_options(test_ck_tile_streamk_extended PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
# Re-run configure automatically if the types header changes (e.g. types added/removed)
# or if the generation script changes
set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS ${STREAMK_TYPES_HEADER} ${STREAMK_GEN_SCRIPT})
# Define the targets and their corresponding executable names
set(STREAMK_GEN_TARGETS extended atomic_smoke linear_smoke tree_smoke pipelines_smoke)
set(STREAMK_GEN_EXEC_EXTENDED test_ck_tile_streamk_extended)
set(STREAMK_GEN_EXEC_ATOMIC_SMOKE test_ck_tile_streamk_atomic_smoke)
set(STREAMK_GEN_EXEC_LINEAR_SMOKE test_ck_tile_streamk_linear_smoke)
set(STREAMK_GEN_EXEC_TREE_SMOKE test_ck_tile_streamk_tree_smoke)
set(STREAMK_GEN_EXEC_PIPELINES_SMOKE test_ck_tile_streamk_pipelines_smoke)
# Collect all test targets for umbrella label
set(CK_TILE_GEMM_STREAMK_TEST_TARGETS
test_ck_tile_streamk_tile_partitioner
test_ck_tile_streamk_extended
test_ck_tile_streamk_tile_partitioner)
foreach(target IN LISTS STREAMK_GEN_TARGETS)
string(TOUPPER ${target} TARGET_UPPER)
set(GEN_DIR ${CMAKE_CURRENT_BINARY_DIR}/${target})
set(EXEC_NAME ${STREAMK_GEN_EXEC_${TARGET_UPPER}})
set(LIST_FILE ${CMAKE_CURRENT_BINARY_DIR}/${target}_files.txt)
# Phase 1 (configure time): discover the list of files that will be generated
execute_process(
COMMAND ${Python3_EXECUTABLE} ${STREAMK_GEN_SCRIPT}
--types_header ${STREAMK_TYPES_HEADER}
--output_dir ${GEN_DIR}
--target ${target}
--list_files ${LIST_FILE}
RESULT_VARIABLE ret
ERROR_VARIABLE list_files_stderr)
if(ret AND NOT ret EQUAL 0)
message(FATAL_ERROR
"Failed to list ${target} test files via Python: ${ret}\n"
"stderr: ${list_files_stderr}"
)
endif()
file(STRINGS ${LIST_FILE} ALL_SOURCES_${target})
# Phase 2 (build time): generate the .cpp files when the types header changes
add_custom_command(
OUTPUT ${ALL_SOURCES_${target}}
COMMAND ${Python3_EXECUTABLE} ${STREAMK_GEN_SCRIPT}
--types_header ${STREAMK_TYPES_HEADER}
--output_dir ${GEN_DIR}
--target ${target}
--gen_files
DEPENDS ${STREAMK_TYPES_HEADER} ${STREAMK_GEN_SCRIPT}
COMMENT "Generating StreamK ${target} test sources from types header")
# Filter out fp8/bf8 sources on gfx90a since those types are not natively supported
set(FILTERED_SOURCES)
foreach(src IN LISTS ALL_SOURCES_${target})
if(NOT src MATCHES "_(fp8|bf8)_" OR GPU_TARGETS MATCHES "gfx942|gfx950")
list(APPEND FILTERED_SOURCES ${src})
endif()
endforeach()
list(APPEND FILTERED_SOURCES test_gemm_streamk_util.cpp)
add_gtest_executable(${EXEC_NAME} ${FILTERED_SOURCES})
target_compile_options(${EXEC_NAME} PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
list(APPEND CK_TILE_GEMM_STREAMK_TEST_TARGETS ${EXEC_NAME})
endforeach()
# Add python unit tests to validate the code gen logic in generate_test_files.py
add_test(
NAME test_ck_tile_streamk_generate_test_files
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/test_generate_test_files.py -v
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..
)
# Label all ck_tile gemm_streamk tests with CK_TILE_GEMM_STREAMK_TESTS for selective execution
foreach(test_target ${CK_TILE_GEMM_STREAMK_TEST_TARGETS})
set_tests_properties(${test_target} PROPERTIES LABELS "CK_TILE_GEMM_STREAMK_TESTS")
endforeach()
# Also label the Python test
set_tests_properties(test_ck_tile_streamk_generate_test_files PROPERTIES LABELS "CK_TILE_GEMM_STREAMK_TESTS")
# Umbrella target to build and run all ck_tile gemm_streamk tests
# Usage: ninja ck_tile_gemm_streamk_tests

View File

@@ -1,18 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf16NonPersistentCompV3 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistentCompV3
TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistentCompV3,
KernelTypesStreamKBf16NonPersistentCompV3);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,18 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf16NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistentCompV4
TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistentCompV4,
KernelTypesStreamKBf16NonPersistentCompV4);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf16NonPersistentMem : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistentMem
TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistentMem, KernelTypesStreamKBf16NonPersistentMem);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf16PersistentCompV3 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf16PersistentCompV3
TYPED_TEST_SUITE(TestCkTileStreamKBf16PersistentCompV3, KernelTypesStreamKBf16PersistentCompV3);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf16PersistentCompV4 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf16PersistentCompV4
TYPED_TEST_SUITE(TestCkTileStreamKBf16PersistentCompV4, KernelTypesStreamKBf16PersistentCompV4);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf16PersistentMem : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf16PersistentMem
TYPED_TEST_SUITE(TestCkTileStreamKBf16PersistentMem, KernelTypesStreamKBf16PersistentMem);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf8NonPersistentCompV3 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistentCompV3
TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistentCompV3, KernelTypesStreamKBf8NonPersistentCompV3);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf8NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistentCompV4
TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistentCompV4, KernelTypesStreamKBf8NonPersistentCompV4);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf8NonPersistentMem : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistentMem
TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistentMem, KernelTypesStreamKBf8NonPersistentMem);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf8PersistentCompV3 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf8PersistentCompV3
TYPED_TEST_SUITE(TestCkTileStreamKBf8PersistentCompV3, KernelTypesStreamKBf8PersistentCompV3);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf8PersistentCompV4 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf8PersistentCompV4
TYPED_TEST_SUITE(TestCkTileStreamKBf8PersistentCompV4, KernelTypesStreamKBf8PersistentCompV4);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf8PersistentMem : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf8PersistentMem
TYPED_TEST_SUITE(TestCkTileStreamKBf8PersistentMem, KernelTypesStreamKBf8PersistentMem);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,18 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp16NonPersistentCompV3 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp16NonPersistentCompV3
TYPED_TEST_SUITE(TestCkTileStreamKFp16NonPersistentCompV3,
KernelTypesStreamKFp16NonPersistentCompV3);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,18 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp16NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp16NonPersistentCompV4
TYPED_TEST_SUITE(TestCkTileStreamKFp16NonPersistentCompV4,
KernelTypesStreamKFp16NonPersistentCompV4);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp16NonPersistentMem : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp16NonPersistentMem
TYPED_TEST_SUITE(TestCkTileStreamKFp16NonPersistentMem, KernelTypesStreamKFp16NonPersistentMem);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp16PersistentCompV3 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp16PersistentCompV3
TYPED_TEST_SUITE(TestCkTileStreamKFp16PersistentCompV3, KernelTypesStreamKFp16PersistentCompV3);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp16PersistentCompV4 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp16PersistentCompV4
TYPED_TEST_SUITE(TestCkTileStreamKFp16PersistentCompV4, KernelTypesStreamKFp16PersistentCompV4);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp16PersistentMem : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp16PersistentMem
TYPED_TEST_SUITE(TestCkTileStreamKFp16PersistentMem, KernelTypesStreamKFp16PersistentMem);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp8NonPersistentCompV3 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp8NonPersistentCompV3
TYPED_TEST_SUITE(TestCkTileStreamKFp8NonPersistentCompV3, KernelTypesStreamKFp8NonPersistentCompV3);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp8NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp8NonPersistentCompV4
TYPED_TEST_SUITE(TestCkTileStreamKFp8NonPersistentCompV4, KernelTypesStreamKFp8NonPersistentCompV4);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp8NonPersistentMem : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp8NonPersistentMem
TYPED_TEST_SUITE(TestCkTileStreamKFp8NonPersistentMem, KernelTypesStreamKFp8NonPersistentMem);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp8PersistentCompV3 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp8PersistentCompV3
TYPED_TEST_SUITE(TestCkTileStreamKFp8PersistentCompV3, KernelTypesStreamKFp8PersistentCompV3);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp8PersistentCompV4 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp8PersistentCompV4
TYPED_TEST_SUITE(TestCkTileStreamKFp8PersistentCompV4, KernelTypesStreamKFp8PersistentCompV4);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp8PersistentMem : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp8PersistentMem
TYPED_TEST_SUITE(TestCkTileStreamKFp8PersistentMem, KernelTypesStreamKFp8PersistentMem);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -0,0 +1,215 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""Generate test .cpp files from KernelTypes definitions in
test_gemm_streamk_types.hpp.
Two modes:
--list_files FILE Write the list of output file paths to FILE (one per line)
without generating the files. Used at CMake configure time.
--gen_files Actually emit the .cpp files into --output_dir.
Used at build time via add_custom_command.
Target selection (--target):
extended Kernel types containing 'Atomic' or 'Pipelines'
-> includes test_gemm_streamk_extended_cases.inc
atomic_smoke Kernel types containing 'Atomic' (not 'Pipelines')
-> includes test_gemm_streamk_atomic_cases.inc
linear_smoke Kernel types containing 'Linear' (not 'Pipelines')
-> includes test_gemm_streamk_reduction_cases.inc
tree_smoke Kernel types containing 'Tree' (not 'Pipelines')
-> includes test_gemm_streamk_reduction_cases.inc
pipelines_smoke Kernel types matching 'Pipelines'
-> includes test_gemm_streamk_reduction_cases.inc
and test_gemm_streamk_atomic_cases.inc
"""
import argparse
import os
import re
import sys
# --------------------------------------------------------------------------- #
# Template for every generated .cpp file
# --------------------------------------------------------------------------- #
CPP_TEMPLATE = """\
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class {class_name} : public TestCkTileStreamK<Tuple>
{{
}};
#define TEST_SUITE_NAME {class_name}
TYPED_TEST_SUITE({class_name}, {type_alias});
{inc_includes}
#undef TEST_SUITE_NAME
"""
# --------------------------------------------------------------------------- #
# Target definitions: filter predicate and .inc files
# --------------------------------------------------------------------------- #
TARGETS = {
"extended": {
"filter": lambda suffix: "Atomic" in suffix or suffix == "Pipelines",
"inc_files": ["test_gemm_streamk_extended_cases.inc"],
},
"atomic_smoke": {
"filter": lambda suffix: "Atomic" in suffix and suffix != "Pipelines",
"inc_files": ["test_gemm_streamk_atomic_cases.inc"],
},
"linear_smoke": {
"filter": lambda suffix: "Linear" in suffix and suffix != "Pipelines",
"inc_files": ["test_gemm_streamk_reduction_cases.inc"],
},
"tree_smoke": {
"filter": lambda suffix: "Tree" in suffix and suffix != "Pipelines",
"inc_files": ["test_gemm_streamk_reduction_cases.inc"],
},
"pipelines_smoke": {
"filter": lambda suffix: suffix == "Pipelines",
"inc_files": [
"test_gemm_streamk_reduction_cases.inc",
"test_gemm_streamk_atomic_cases.inc",
],
},
}
# --------------------------------------------------------------------------- #
# Mapping from CamelCase suffix fragments to file-name fragments
# --------------------------------------------------------------------------- #
KNOWN_TOKENS = [
("Fp16", "fp16"),
("Bf16", "bf16"),
("Fp8", "fp8"),
("Bf8", "bf8"),
("NonPersistent", "nonpersistent"),
("Persistent", "persistent"),
("Atomic", "atomic"),
("Linear", "linear"),
("Tree", "tree"),
("CompV3", "compv3"),
("Pipelines", "pipelines"),
]
def suffix_to_file_tag(suffix: str) -> str:
"""Convert a CamelCase suffix like 'Fp16PersistentAtomicCompV3' to
'fp16_persistent_atomic_compv3'."""
parts: list[str] = []
remaining = suffix
while remaining:
matched = False
for token, replacement in KNOWN_TOKENS:
if remaining.startswith(token):
parts.append(replacement)
remaining = remaining[len(token) :]
matched = True
break
if not matched:
raise ValueError(
f"Unrecognised token in KernelTypes suffix: '{remaining}' "
f"(from '{suffix}')"
)
return "_".join(parts)
def parse_types_header(header_path: str, target: str) -> list[dict]:
"""Return a list of dicts with keys: type_alias, class_name, file_tag, suffix."""
target_def = TARGETS[target]
# Pattern matches lines like: using KernelTypesStreamKFp16PersistentAtomicCompV3 = ...
pattern = re.compile(r"using\s+(KernelTypesStreamK(\w+))\s*=")
entries: list[dict] = []
with open(header_path) as f:
for line in f:
match = pattern.search(line)
if match:
# If the match is: using KernelTypesStreamKFp16PersistentAtomicCompV3 = ...
# type_alias is KernelTypesStreamKFp16PersistentAtomicCompV3
# suffix is Fp16PersistentAtomicCompV3
type_alias = match.group(1)
suffix = match.group(2)
if not target_def["filter"](suffix):
continue
entries.append(
{
"type_alias": type_alias,
"class_name": f"TestCkTileStreamK{suffix}",
"file_tag": suffix_to_file_tag(suffix),
}
)
return entries
def output_path(output_dir: str, entry: dict) -> str:
return os.path.join(output_dir, f"test_gemm_streamk_{entry['file_tag']}.cpp")
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--types_header", required=True, help="Path to test_gemm_streamk_types.hpp"
)
parser.add_argument(
"--output_dir", required=True, help="Directory for generated .cpp files"
)
parser.add_argument(
"--target",
required=True,
choices=list(TARGETS.keys()),
help="Which target to generate files for",
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"--list_files",
metavar="FILE",
help="Write output file paths to FILE then exit",
)
group.add_argument(
"--gen_files", action="store_true", help="Generate the .cpp files"
)
return parser.parse_args()
def main() -> None:
args = parse_args()
entries = parse_types_header(args.types_header, args.target)
if not entries:
print(
f"ERROR: no KernelTypesStreamK* definitions found for target "
f"'{args.target}' in {args.types_header}",
file=sys.stderr,
)
sys.exit(1)
inc_files = TARGETS[args.target]["inc_files"]
inc_includes = "\n".join(f'#include "{f}"' for f in inc_files)
if args.list_files:
os.makedirs(os.path.dirname(args.list_files) or ".", exist_ok=True)
with open(args.list_files, "w") as f:
for entry in entries:
f.write(output_path(args.output_dir, entry) + "\n")
else:
os.makedirs(args.output_dir, exist_ok=True)
for entry in entries:
path = output_path(args.output_dir, entry)
content = CPP_TEMPLATE.format(
class_name=entry["class_name"],
type_alias=entry["type_alias"],
inc_includes=inc_includes,
)
with open(path, "w") as f:
f.write(content)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,47 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
TYPED_TEST(TEST_SUITE_NAME, StreamK_EdgeCase)
{
ck_tile::index_t M = 256;
ck_tile::index_t N = 256;
ck_tile::index_t K = 256;
this->Run(M, N, K);
}
TYPED_TEST(TEST_SUITE_NAME, StreamK_DPOnly)
{
const ck_tile::index_t num_cu = get_cu_count();
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
// For DP only, we ensure that the number of tiles is a multiple of the number of CUs. This
// assumes tile sizes are large enough such that occupancy is 1.
ck_tile::index_t M = M_Tile * num_cu;
ck_tile::index_t N = N_Tile;
ck_tile::index_t K = K_Tile;
this->Run(M, N, K);
}
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly)
{
const ck_tile::index_t num_cu = get_cu_count();
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
// For SK only, we have 4 macro tiles in C. But, we need to make sure there is enough work along
// the K dimension to avoid falling into the edge case. Thus, we always have at least num_cu
// macro tiles in the K dimension. This assumes tile sizes are large enough such that occupancy
// is 1.
ck_tile::index_t M = M_Tile * 2;
ck_tile::index_t N = N_Tile * 2;
ck_tile::index_t K = K_Tile * num_cu;
this->Run(M, N, K);
}

View File

@@ -0,0 +1,46 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_OneTile)
{
const ck_tile::index_t num_cu = get_cu_count();
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
ck_tile::index_t M = M_Tile;
ck_tile::index_t N = N_Tile;
ck_tile::index_t K = K_Tile * num_cu;
this->Run(M, N, K);
}
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_4Tiles_Reduction)
{
const ck_tile::index_t num_cu = get_cu_count();
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
ck_tile::index_t M = M_Tile * 4;
ck_tile::index_t N = N_Tile;
ck_tile::index_t K = K_Tile * num_cu + (25 * K_Tile);
this->Run(M, N, K);
}
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_21Tiles)
{
const ck_tile::index_t num_cu = get_cu_count();
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
ck_tile::index_t M = M_Tile * 3;
ck_tile::index_t N = N_Tile * 7;
ck_tile::index_t K = K_Tile * num_cu + (30 * K_Tile);
this->Run(M, N, K);
}

View File

@@ -14,16 +14,28 @@ using BF16 = ck_tile::bf16_t;
using BF8 = ck_tile::bf8_t;
using F32 = float;
// Layouts
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// Persistence
using Persistent = std::true_type;
using NonPersistent = std::false_type;
// Pipelines
using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
using CompV3 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV3>;
using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV4>;
using Persistent = std::true_type;
using NonPersistent = std::false_type;
// Reduction Strategies
using Atomic = ck_tile::integral_constant<ck_tile::StreamKReductionStrategy,
ck_tile::StreamKReductionStrategy::Atomic>;
using Linear = ck_tile::integral_constant<ck_tile::StreamKReductionStrategy,
ck_tile::StreamKReductionStrategy::Linear>;
using Tree = ck_tile::integral_constant<ck_tile::StreamKReductionStrategy,
ck_tile::StreamKReductionStrategy::Tree>;
using I16 = ck_tile::number<16>;
using I32 = ck_tile::number<32>;
using I128 = ck_tile::number<128>;
using I256 = ck_tile::number<256>;
@@ -32,180 +44,157 @@ using I256 = ck_tile::number<256>;
// ========================== CompV3 Pipeline ==========================
using KernelTypesStreamKFp16PersistentCompV3 = ::testing::Types<
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent Pipeline
// Atomics
using KernelTypesStreamKFp16PersistentAtomicCompV3 = ::testing::Types<
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile M_WaveTile N_WaveTile K_WaveTile Persistent Pipeline ReductionStrategy
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV3>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV3>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV3>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV3>
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Atomic>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Atomic>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Atomic>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Atomic>
>;
using KernelTypesStreamKBf16PersistentCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV3>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV3>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV3>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV3>
using KernelTypesStreamKBf16PersistentAtomicCompV3 = ::testing::Types<
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Atomic>
>;
using KernelTypesStreamKBf8PersistentCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV3>,
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV3>,
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV3>,
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV3>
using KernelTypesStreamKBf8PersistentAtomicCompV3 = ::testing::Types<
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Atomic>
>;
using KernelTypesStreamKFp8PersistentCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV3>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV3>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV3>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV3>
using KernelTypesStreamKFp8PersistentAtomicCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Atomic>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Atomic>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Atomic>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Atomic>
>;
using KernelTypesStreamKFp16NonPersistentCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV3>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV3>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV3>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV3>
using KernelTypesStreamKFp16NonPersistentAtomicCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>
>;
using KernelTypesStreamKBf16NonPersistentCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV3>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV3>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV3>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV3>
using KernelTypesStreamKBf16NonPersistentAtomicCompV3 = ::testing::Types<
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>
>;
using KernelTypesStreamKBf8NonPersistentCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV3>,
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV3>,
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV3>,
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV3>
using KernelTypesStreamKBf8NonPersistentAtomicCompV3 = ::testing::Types<
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>
>;
using KernelTypesStreamKFp8NonPersistentCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3>
using KernelTypesStreamKFp8NonPersistentAtomicCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>
>;
// ========================== CompV4 Pipeline ==========================
// Linear
using KernelTypesStreamKFp16PersistentLinearCompV3 = ::testing::Types<
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile M_WaveTile N_WaveTile K_WaveTile Persistent Pipeline ReductionStrategy
using KernelTypesStreamKFp16PersistentCompV4 = ::testing::Types<
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent Pipeline
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Linear>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I16, I16, I16, Persistent, CompV3, Linear>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Linear>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Linear>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Linear>
>;
using KernelTypesStreamKBf16PersistentCompV4 = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>
using KernelTypesStreamKBf16PersistentLinearCompV3 = ::testing::Types<
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Linear>
>;
using KernelTypesStreamKBf8PersistentCompV4 = ::testing::Types<
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>,
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>,
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>,
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>
using KernelTypesStreamKBf8PersistentLinearCompV3 = ::testing::Types<
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Linear>
>;
using KernelTypesStreamKFp8PersistentCompV4 = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>
using KernelTypesStreamKFp8PersistentLinearCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Linear>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Linear>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Linear>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Linear>
>;
using KernelTypesStreamKFp16NonPersistentCompV4 = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>
using KernelTypesStreamKFp16NonPersistentLinearCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Linear>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Linear>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Linear>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Linear>
>;
using KernelTypesStreamKBf16NonPersistentCompV4 = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>
using KernelTypesStreamKBf16NonPersistentLinearCompV3 = ::testing::Types<
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Linear>
>;
using KernelTypesStreamKBf8NonPersistentCompV4 = ::testing::Types<
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>,
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>,
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>,
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>
using KernelTypesStreamKBf8NonPersistentLinearCompV3 = ::testing::Types<
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Linear>
>;
using KernelTypesStreamKFp8NonPersistentCompV4 = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>
using KernelTypesStreamKFp8NonPersistentLinearCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Linear>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Linear>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Linear>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Linear>
>;
// ============================= Mem Pipeline =============================
// Tree
using KernelTypesStreamKFp16PersistentTreeCompV3 = ::testing::Types<
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile M_WaveTile N_WaveTile K_WaveTile Persistent Pipeline ReductionStrategy
using KernelTypesStreamKFp16PersistentMem = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, Mem>
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Tree>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Tree>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I16, I16, I16, Persistent, CompV3, Tree>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Tree>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Tree>
>;
using KernelTypesStreamKBf16PersistentMem = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, Mem>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, Mem>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, Mem>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, Mem>
using KernelTypesStreamKBf16PersistentTreeCompV3 = ::testing::Types<
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Tree>
>;
using KernelTypesStreamKBf8PersistentMem = ::testing::Types<
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, Mem>,
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, Mem>,
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, Mem>,
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, Mem>
using KernelTypesStreamKBf8PersistentTreeCompV3 = ::testing::Types<
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Tree>
>;
using KernelTypesStreamKFp8PersistentMem = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, Mem>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, Mem>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, Mem>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, Mem>
using KernelTypesStreamKFp8PersistentTreeCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Tree>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Tree>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Tree>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Tree>
>;
using KernelTypesStreamKFp16NonPersistentMem = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, Mem>
using KernelTypesStreamKFp16NonPersistentTreeCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Tree>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Tree>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Tree>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Tree>
>;
using KernelTypesStreamKBf16NonPersistentMem = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, Mem>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, Mem>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, Mem>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, Mem>
using KernelTypesStreamKBf16NonPersistentTreeCompV3 = ::testing::Types<
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Tree>
>;
using KernelTypesStreamKBf8NonPersistentMem = ::testing::Types<
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, Mem>,
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, Mem>,
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, Mem>,
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, Mem>
using KernelTypesStreamKBf8NonPersistentTreeCompV3 = ::testing::Types<
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Tree>
>;
using KernelTypesStreamKFp8NonPersistentMem = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, Mem>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, Mem>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, Mem>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, Mem>
using KernelTypesStreamKFp8NonPersistentTreeCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Tree>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Tree>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Tree>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Tree>
>;
// ============================= Other Pipelines =============================
using KernelTypesStreamKPipelines = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, Mem, Atomic>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, Mem, Tree>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, Mem, Linear>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV4, Atomic>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV4, Tree>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV4, Linear>
>;
// clang-format on

View File

@@ -71,23 +71,27 @@ template <typename Tuple>
class TestCkTileStreamK : public ::testing::Test
{
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>;
using BDataType = std::tuple_element_t<4, Tuple>;
using AccDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>;
using DsLayout = ck_tile::tuple<>;
using DsDataType = ck_tile::tuple<>;
static constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, Tuple>::value;
static constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, Tuple>::value;
static constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, Tuple>::value;
static constexpr bool Persistent = std::tuple_element_t<10, Tuple>::value;
static constexpr auto PipelineType = std::tuple_element_t<11, Tuple>::value;
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>;
using BDataType = std::tuple_element_t<4, Tuple>;
using AccDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>;
using DsLayout = ck_tile::tuple<>;
using DsDataType = ck_tile::tuple<>;
static constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, Tuple>::value;
static constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, Tuple>::value;
static constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, Tuple>::value;
static constexpr ck_tile::index_t M_Warp_Tile = std::tuple_element_t<10, Tuple>::value;
static constexpr ck_tile::index_t N_Warp_Tile = std::tuple_element_t<11, Tuple>::value;
static constexpr ck_tile::index_t K_Warp_Tile = std::tuple_element_t<12, Tuple>::value;
template <ck_tile::StreamKReductionStrategy ReductionStrategy,
bool PadM = true,
static constexpr bool Persistent = std::tuple_element_t<13, Tuple>::value;
static constexpr auto PipelineType = std::tuple_element_t<14, Tuple>::value;
static constexpr auto ReductionStrategy = std::tuple_element_t<15, Tuple>::value;
template <bool PadM = true,
bool PadN = true,
bool PadK = true,
bool Preshuffle = false,
@@ -99,10 +103,6 @@ class TestCkTileStreamK : public ::testing::Test
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr bool kPadM = PadM;
constexpr bool kPadN = PadN;
constexpr bool kPadK = PadK;
@@ -269,8 +269,7 @@ class TestCkTileStreamK : public ::testing::Test
stride_C};
ck_tile::index_t num_accumulations_per_tile =
invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
invoke_streamk<>(args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());

View File

@@ -0,0 +1,220 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
import unittest
from unittest.mock import mock_open, patch
from generate_test_files import suffix_to_file_tag, parse_types_header, output_path
# ------------------------------------------------------------ #
# Unit tests for helper functions in generate_test_files.py
# ------------------------------------------------------------ #
class TestSuffixToFileTag(unittest.TestCase):
def test_fp16_token(self):
suffix = "Fp16"
expected_tag = "fp16"
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
def test_bf16_token(self):
suffix = "Bf16"
expected_tag = "bf16"
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
def test_fp8_token(self):
suffix = "Fp8"
expected_tag = "fp8"
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
def test_bf8_token(self):
suffix = "Bf8"
expected_tag = "bf8"
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
def test_nonpersistent_token(self):
suffix = "NonPersistent"
expected_tag = "nonpersistent"
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
def test_persistent_token(self):
suffix = "Persistent"
expected_tag = "persistent"
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
def test_atomic_token(self):
suffix = "Atomic"
expected_tag = "atomic"
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
def test_linear_token(self):
suffix = "Linear"
expected_tag = "linear"
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
def test_tree_token(self):
suffix = "Tree"
expected_tag = "tree"
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
def test_compv3_token(self):
suffix = "CompV3"
expected_tag = "compv3"
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
def test_pipelines_token(self):
suffix = "Pipelines"
expected_tag = "pipelines"
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
def test_unknown_token(self):
suffix = "unknown"
with self.assertRaises(ValueError):
suffix_to_file_tag(suffix)
def test_multiple_valid_tokens(self):
suffix = "Fp16PersistentAtomicCompV3"
expected_tag = "fp16_persistent_atomic_compv3"
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
def test_multiple_tokens_with_unknown(self):
suffix = "Fp16PersistentUnknownCompV3"
with self.assertRaises(ValueError):
suffix_to_file_tag(suffix)
class TestParseTypesHeader(unittest.TestCase):
def validate_entries(self, entries, expected_entries):
self.assertEqual(len(entries), len(expected_entries))
for idx in range(len(entries)):
self.assertDictEqual(entries[idx], expected_entries[idx])
def test_empty_entry(self):
"""Test that an empty file returns no entries."""
mock_content = ""
with patch("builtins.open", mock_open(read_data=mock_content)):
entries = parse_types_header("fake_path.hpp", "atomic_smoke")
self.assertEqual(len(entries), 0)
def test_pipelines_smoke(self):
"""Test pipelines_smoke target: matches suffix == 'Pipelines'.
Includes: Pipelines
Excludes: Fp8NonPersistentTreeCompV3
"""
mock_content = (
"using KernelTypesStreamKPipelines = ...\n"
"using KernelTypesStreamKFp8NonPersistentTreeCompV3 = ...\n"
)
with patch("builtins.open", mock_open(read_data=mock_content)):
entries = parse_types_header("fake_path.hpp", "pipelines_smoke")
expected = [
{
"type_alias": "KernelTypesStreamKPipelines",
"class_name": "TestCkTileStreamKPipelines",
"file_tag": "pipelines",
}
]
self.validate_entries(entries, expected)
def test_extended(self):
"""Test extended target: matches 'Atomic' in suffix OR suffix == 'Pipelines'.
Includes: Fp16PersistentAtomic, Pipelines
Excludes: Bf16Linear
"""
mock_content = (
"using KernelTypesStreamKFp16PersistentAtomic = ...\n"
"using KernelTypesStreamKPipelines = ...\n"
"using KernelTypesStreamKBf16Linear = ...\n"
)
with patch("builtins.open", mock_open(read_data=mock_content)):
entries = parse_types_header("fake_path.hpp", "extended")
expected = [
{
"type_alias": "KernelTypesStreamKFp16PersistentAtomic",
"class_name": "TestCkTileStreamKFp16PersistentAtomic",
"file_tag": "fp16_persistent_atomic",
},
{
"type_alias": "KernelTypesStreamKPipelines",
"class_name": "TestCkTileStreamKPipelines",
"file_tag": "pipelines",
},
]
self.validate_entries(entries, expected)
def test_atomic_smoke(self):
"""Test atomic_smoke target: matches 'Atomic' in suffix AND suffix != 'Pipelines'.
Includes: Fp16PersistentAtomic
Excludes: Bf16Linear, Pipelines
"""
mock_content = (
"using KernelTypesStreamKFp16PersistentAtomic = ...\n"
"using KernelTypesStreamKBf16Linear = ...\n"
"using KernelTypesStreamKPipelines = ...\n"
)
with patch("builtins.open", mock_open(read_data=mock_content)):
entries = parse_types_header("fake_path.hpp", "atomic_smoke")
expected = [
{
"type_alias": "KernelTypesStreamKFp16PersistentAtomic",
"class_name": "TestCkTileStreamKFp16PersistentAtomic",
"file_tag": "fp16_persistent_atomic",
}
]
self.validate_entries(entries, expected)
def test_linear_smoke(self):
"""Test linear_smoke target: matches 'Linear' in suffix AND suffix != 'Pipelines'.
Includes: Fp8NonPersistentLinear
Excludes: Bf16PersistentAtomic, Pipelines
"""
mock_content = (
"using KernelTypesStreamKFp8NonPersistentLinear = ...\n"
"using KernelTypesStreamKBf16PersistentAtomic = ...\n"
"using KernelTypesStreamKPipelines = ...\n"
)
with patch("builtins.open", mock_open(read_data=mock_content)):
entries = parse_types_header("fake_path.hpp", "linear_smoke")
expected = [
{
"type_alias": "KernelTypesStreamKFp8NonPersistentLinear",
"class_name": "TestCkTileStreamKFp8NonPersistentLinear",
"file_tag": "fp8_nonpersistent_linear",
}
]
self.validate_entries(entries, expected)
def test_tree_smoke(self):
"""Test tree_smoke target: matches 'Tree' in suffix AND suffix != 'Pipelines'.
Includes: Bf8PersistentTreeCompV3
Excludes: Fp16Linear, Pipelines
"""
mock_content = (
"using KernelTypesStreamKBf8PersistentTreeCompV3 = ...\n"
"using KernelTypesStreamKFp16Linear = ...\n"
"using KernelTypesStreamKPipelines = ...\n"
)
with patch("builtins.open", mock_open(read_data=mock_content)):
entries = parse_types_header("fake_path.hpp", "tree_smoke")
expected = [
{
"type_alias": "KernelTypesStreamKBf8PersistentTreeCompV3",
"class_name": "TestCkTileStreamKBf8PersistentTreeCompV3",
"file_tag": "bf8_persistent_tree_compv3",
}
]
self.validate_entries(entries, expected)
class TestOutputPath(unittest.TestCase):
def test_output_path(self):
"""Test that output_path generates the correct file path."""
entry = {"file_tag": "fp16_persistent_atomic"}
output_dir = "/some/output/dir"
expected = "/some/output/dir/test_gemm_streamk_fp16_persistent_atomic.cpp"
self.assertEqual(output_path(output_dir, entry), expected)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,324 +0,0 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
include(generate_configs.cmake)
# ============================================================================
# GEMM Tile Engine Unit Tests
#
# This CMake file creates unit tests for tile_engine generated GEMM kernels.
# It follows the exact same build patterns as tile_engine for consistency
# and reliability. Each kernel configuration gets its own test executable.
# ============================================================================
# Locate tile_engine GEMM scripts directory
set(TILE_ENGINE_GEMM_DIR "${PROJECT_SOURCE_DIR}/tile_engine/ops/gemm_streamk")
if(NOT EXISTS ${TILE_ENGINE_GEMM_DIR})
message(WARNING "Tile engine directory not found: ${TILE_ENGINE_GEMM_DIR}")
return()
endif()
# ============================================================================
# create_individual_gemm_test_target
#
# Creates a single test executable for a specific kernel configuration.
# Mirrors tile_engine's create_individual_gemm_target function for consistency.
#
# Parameters:
# datatype - Data type (fp16, bf16, fp32, etc.)
# layout - Matrix layout (rcr, rrr, ccr, crr)
# config_name - Configuration file name without .json extension
# trait - Kernel trait combination string
# tile_config - Tile configuration parameters
# config_json - Full path to JSON configuration file
# ============================================================================
function(create_individual_gemm_test_target datatype layout config_name trait tile_config config_json)
set(target_name "test_gemm_streamk_tile_engine_${datatype}_${layout}_${config_name}_${trait}_${tile_config}")
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}/${config_name}")
# Generated header path (already created during cmake configuration)
set(test_header "${working_path}/gemm_streamk_single_${datatype}_${layout}_${trait}_${tile_config}.hpp")
set(test_params_header "${working_path}/test_params.hpp")
# Verify header exists (should have been generated during cmake configuration)
if(NOT EXISTS ${test_header})
message(WARNING "Generated header not found: ${test_header}")
return()
endif()
# Verify test parameters header exists
if(NOT EXISTS ${test_params_header})
message(WARNING "Test parameters header not found: ${test_params_header}")
return()
endif()
# Create GTest executable for this kernel configuration
add_gtest_executable(${target_name}
${CMAKE_CURRENT_SOURCE_DIR}/test_gemm_streamk_simple.cpp
)
# Configure GPU architectures for HIP compilation
set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_TEST_GPU_TARGETS})
# Define preprocessor macros for generated header location and test parameters
target_compile_definitions(${target_name} PRIVATE
GEMM_SINGLE_INSTANCE_HPP="${test_header}"
GEMM_TEST_PARAMS_HPP="${test_params_header}"
)
# Include directories for headers and dependencies
target_include_directories(${target_name} PRIVATE
${PROJECT_SOURCE_DIR}/include
${PROJECT_BINARY_DIR}/include
${PROJECT_SOURCE_DIR} # Root directory for tile_engine access
${GTEST_INCLUDE_DIRS}
)
# Compiler options matching tile_engine requirements
target_compile_options(${target_name} PRIVATE
-Wno-undefined-func-template # Suppress template warnings
-Wno-float-equal # Allow floating point comparisons
--offload-compress # Enable GPU code compression
-include ${test_header} # Auto-include generated header
)
# Add FP8 format definitions for proper data type interpretation
if(CK_USE_OCP_FP8)
target_compile_options(${target_name} PRIVATE -DCK_TILE_USE_OCP_FP8)
endif()
message(DEBUG " Created test target: ${target_name}")
endfunction()
# ============================================================================
# build_gemm_test_targets
#
# Builds all test targets for a specific datatype/layout/config combination.
# Uses tile_engine's two-step process: list kernels, then generate tests.
#
# Parameters:
# datatype - Data type (fp16, bf16, fp32, etc.)
# layout - Matrix layout (rcr, rrr, ccr, crr)
# config_name - Configuration file name without .json extension
# ============================================================================
function(build_gemm_test_targets datatype layout config_name configs_dir_path)
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}/${config_name}")
# Locate and validate configuration file
set(config_filename "${config_name}.json")
set(json_blob "${configs_dir_path}/${config_filename}")
if(NOT EXISTS ${json_blob})
message(WARNING "Test config file not found: ${json_blob}")
return()
endif()
# Prepare build directory for this configuration
file(MAKE_DIRECTORY ${working_path})
# STEP 1: Discovery phase - list all valid kernel configurations
execute_process(
COMMAND ${Python3_EXECUTABLE} -u ${TILE_ENGINE_GEMM_DIR}/gemm_streamk_instance_builder.py
--working_path ${working_path}
--datatype ${datatype}
--layout ${layout}
--config_json ${json_blob}
--list_kernels
--gpu_targets "${SUPPORTED_GPU_TARGETS}"
WORKING_DIRECTORY ${TILE_ENGINE_GEMM_DIR}
RESULT_VARIABLE ret
OUTPUT_VARIABLE list_output
ERROR_VARIABLE list_error
)
if(NOT ret EQUAL 0)
message(WARNING "Failed to list kernels for ${datatype}_${layout}_${config_name}: ${list_error}")
return()
endif()
# Verify kernel list file was generated
if(NOT EXISTS ${working_path}/gemm_kernel_list.txt)
message(DEBUG "No kernels found for ${datatype}_${layout}_${config_name} (validation filtered out all combinations)")
return()
endif()
message(DEBUG "Building tests for ${datatype}_${layout}_${config_name}")
# STEP 2a: Extract test parameters from config
set(test_params_file "${working_path}/test_params.hpp")
execute_process(
COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_SOURCE_DIR}/extract_test_params.py
--config_file ${json_blob}
--output_file ${test_params_file}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
RESULT_VARIABLE extract_ret
OUTPUT_VARIABLE extract_output
ERROR_VARIABLE extract_error
)
if(NOT extract_ret EQUAL 0)
message(WARNING "Failed to extract test parameters for ${datatype}_${layout}: ${extract_error}")
return()
endif()
# STEP 2b: Header generation phase - generate headers using --gen_single
message(STATUS " Generating headers using --gen_single...")
file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines)
set(gen_count 0)
foreach(line IN LISTS kernel_lines)
# Parse kernel specification format: kernel_name|tile_config|trait_combo
string(REPLACE "|" ";" parts "${line}")
list(LENGTH parts parts_len)
if(parts_len EQUAL 3)
list(GET parts 0 kernel_name)
list(GET parts 1 tile_config)
list(GET parts 2 trait_combo)
# Generate header using --gen_single
execute_process(
COMMAND ${Python3_EXECUTABLE} -u ${TILE_ENGINE_GEMM_DIR}/gemm_streamk_instance_builder.py
--working_path ${working_path}
--datatype ${datatype}
--layout ${layout}
--config_json ${json_blob}
--gen_single
--kernel_name "${kernel_name}"
--tile_config "${tile_config}"
--trait_combo "${trait_combo}"
--gpu_targets "${SUPPORTED_GPU_TARGETS}"
WORKING_DIRECTORY ${TILE_ENGINE_GEMM_DIR}
RESULT_VARIABLE gen_ret
OUTPUT_VARIABLE gen_output
ERROR_VARIABLE gen_error
)
if(NOT gen_ret EQUAL 0)
message(WARNING "Failed to generate header for ${kernel_name}: ${gen_error}")
else()
math(EXPR gen_count "${gen_count} + 1")
endif()
endif()
endforeach()
message(STATUS " Generated ${gen_count} headers for ${datatype}_${layout}")
# STEP 3: Target creation phase - create test targets
message(STATUS " Creating test targets...")
file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines)
set(test_count 0)
foreach(line IN LISTS kernel_lines)
# Parse kernel specification format: kernel_name|tile_config|trait_combo
string(REPLACE "|" ";" parts "${line}")
list(LENGTH parts parts_len)
if(parts_len EQUAL 3)
list(GET parts 0 kernel_name)
list(GET parts 1 tile_config)
list(GET parts 2 trait_combo)
# Generate test target for this kernel configuration
create_individual_gemm_test_target("${datatype}" "${layout}" "${config_name}" "${trait_combo}" "${tile_config}" "${json_blob}")
math(EXPR test_count "${test_count} + 1")
endif()
endforeach()
message(STATUS " Created ${test_count} test targets for ${datatype}_${layout}")
endfunction()# ============================================================================
# MAIN EXECUTION - Test Target Generation
# ============================================================================
message(STATUS "=== Starting StreamK GEMM Tile Engine Test Configuration ===")
message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
# GPU architecture filtering - only build tests for supported architectures
set(GEMM_TEST_GPU_TARGETS "")
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx12-generic")
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
if(target IN_LIST DESIRED_TARGETS)
list(APPEND GEMM_TEST_GPU_TARGETS ${target})
message(STATUS " Adding GPU target for tests: ${target}")
endif()
endforeach()
# Early exit if no compatible GPU architectures are available
if(NOT GEMM_TEST_GPU_TARGETS)
message(WARNING "Skipping StreamK GEMM Tile Engine tests: No supported GPU targets (gfx90a, gfx942, gfx950, gfx12-generic) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
return()
endif()
message(STATUS "Building StreamK GEMM tile engine tests for GPU targets: ${GEMM_TEST_GPU_TARGETS}")
# Enable parallel compilation optimizations
# Set up job pools for better parallel compilation control
set_property(GLOBAL PROPERTY JOB_POOLS
compile_heavy=4 # Limit heavy compilations to prevent OOM
compile_normal=16 # Allow more parallel normal compilations
)
# Enable compiler cache if available and explicitly requested
# Disabled by default due to permission issues in CI environments
option(ENABLE_CCACHE_TESTS "Enable ccache for test compilation" OFF)
if(ENABLE_CCACHE_TESTS)
find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM)
set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM})
message(STATUS "Using ccache for faster test compilation")
else()
message(WARNING "ccache requested but not found")
endif()
else()
message(STATUS "ccache disabled for tests (use -DENABLE_CCACHE_TESTS=ON to enable)")
endif()
# ============================================================================
# Test Configuration Matrix - Clean Focused Design
# ============================================================================
# All supported data types and layouts for comprehensive testing
# Note: fp64 not included (no MFMA hardware support)
set(TEST_DATATYPES "fp16;bf16")
# Temporarily only test rcr and crr
# set(TEST_LAYOUTS "rcr;rrr;ccr;crr")
set(TEST_LAYOUTS "rcr;crr")
# ============================================================================
# Test Target Generation - Datatype-Specific Categories
# ============================================================================
# 1. SMOKE TESTS: Test for basic functionality with data types (fp8, bf8, fp16, bf16)
# Temporarily only consider fp16
# set(SMALL_DATATYPES "fp16;bf16;fp8;bf8")
set(SMALL_DATATYPES "fp16")
set(SIXTEEN_BIT_DATATYPES "fp16;bf16")
set(EIGHT_BIT_DATATYPES "fp8;bf8")
set(LARGE_TILES "256,256,32")
set(SMALL_TILES "128,128,32")
set(CONFIG_LIST "")
set(GENERATED_CONFIG_PATH ${CMAKE_CURRENT_BINARY_DIR}/configs)
get_cu_count(CU_COUNT)
message(STATUS "Generating and processing configs for Stream-K tests")
foreach(datatype IN LISTS SMALL_DATATYPES)
if(datatype IN_LIST SIXTEEN_BIT_DATATYPES)
generate_test_configs(${CU_COUNT} ${LARGE_TILES} ${datatype} CONFIG_LIST ${GENERATED_CONFIG_PATH})
else()
generate_test_configs(${CU_COUNT} ${SMALL_TILES} ${datatype} CONFIG_LIST ${GENERATED_CONFIG_PATH})
endif()
foreach(config IN LISTS CONFIG_LIST)
# testing all layouts (rcr, rrr, ccr, crr)
foreach(layout IN LISTS TEST_LAYOUTS)
build_gemm_test_targets("${datatype}" "${layout}" "${config}" "${GENERATED_CONFIG_PATH}")
endforeach()
endforeach()
endforeach()
# ============================================================================
message(STATUS "StreamK GEMM tile engine tests configured with datatype-specific design:")
message(STATUS " - Smoke tests: fp16/bf16/fp8/bf8 (all layouts)")

View File

@@ -1,64 +0,0 @@
# Stream-K GEMM Tile Engine Unit Tests
## How It Works
This unit test system integrates **tile_engine's kernel generation** into automated testing:
1. **Uses tile_engine scripts directly**: Same Python scripts that generate tile_engine kernels
2. **JSON-based configuration**: Define test parameters in JSON files (like tile_engine)
3. **Build-time generation**: CMake calls tile_engine scripts to generate kernel headers
4. **Individual test executables**: Each kernel configuration becomes a separate test
5. **Tile_engine verification**: Uses exact same error thresholds and validation as tile_engine
## Tile Engine Integration
```
JSON Config → tile_engine Python scripts → Generated Headers → Test Executables
```
- **`--list_kernels`**: Get available kernel configurations from JSON
- **`--gen_individual`**: Generate all kernel headers in parallel during CMake configuration
- **`--gen_single`**: Generate individual kernel header for each configuration
- **Same verification**: Uses tile_engine's adaptive error thresholds and reference calculations
- **Same patterns**: Follows tile_engine's tensor initialization, stride calculation, and kernel launching
### Config-Specific Test Parameters
Each test configuration can specify optimized problem sizes in its JSON file:
- **`test_params.problem_sizes`**: Array of `{m, n, k, split_k}` configurations
- **CMake extraction**: `extract_test_params.py` generates config-specific test parameter files
- **Build integration**: Each test target uses parameters appropriate for its kernel configuration
- **Optimized testing**: Different configs test different problem sizes that showcase their strengths
The key idea: **Unit tests that use tile_engine's exact kernel generation and verification methodology** instead of creating separate test infrastructure.
## Test Configurations
Test configs are generated during the Generation Phase. They are stored under the build directory at test/ck_tile/gemm_streamk_tile_engine/configs. The Compute Unit (CU) count of the device is required to generate the configs. If the Generation Phase occurs on a machine without a GPU or does not contain same GPU architecture on which you will run the tests, you can manually set the CU count using the `CU_COUNT` option:
```bash
# Assuming you are at the root of the repo
cd build
../script/cmake-ck-dev.sh .. gfx90a -G Ninja -DCU_COUNT=100
```
You can reference the public whitepaper for your specific GPU to get the appropriate CU count.
If no `CU_COUNT` option is given and no HIP device is found, then the default value of 100 CUs will be used to determine the problem sizes tested.
### 1. **Smoke Tests**
- **Purpose**: Basic functionality validation for fp16/bf16/fp8/bf8 data types
- **Config**: 256x256x32 (for bf16/fp16) or 128x128x32 (for bf8/fp8), warp 2x2x1, warp_tile 32x32x16
- **Traits**: compv3 pipeline only
- **Coverage**: All 4 layouts (rcr, rrr, ccr, crr)
## Data Type Support
-**fp16, bf16, fp8, bf8**: Fully supported - all layouts (rcr, rrr, ccr, crr)
-**fp64**: Not supported (hardware MFMA limitation)
-**fp32, pk-int4-t**: Not yet supported by gemm_instance_builder (will be added later)
## Test Result Behavior
Tests automatically handle unsupported configurations through runtime validation:
- **PASSED**: Kernel executed correctly with results within error thresholds ✅
- **SKIPPED**: Kernel validation returned "Arguments not supported" (expected for certain problem sizes/configurations) ⚠️
- **FAILED**: Actual error or incorrect computation results ❌
When a kernel's `IsSupportedArgument()` check fails (e.g., due to vector alignment requirements, dimension constraints, or padding limitations), the test is automatically skipped rather than failed. This allows comprehensive testing across various problem sizes while gracefully handling configurations that don't meet specific kernel requirements.

View File

@@ -1,50 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <hip/hip_runtime.h>
#include <iostream>
/**
* @brief Determines whether a `hipError` is present in the given `error_status`
* @return true if the `error_status` has an error, otherwise false.
*/
bool has_error(const hipError_t& error_status)
{
if(error_status != hipSuccess)
{
std::cerr << hipGetErrorString(error_status);
return true;
}
return false;
}
/**
* @brief Returns the number of Compute Units (CUs) on the given device.
* @return The number of CUs on the device. If an error occurs while querying the device, zero is
* returned.
*/
int get_cu_count()
{
hipDevice_t dev;
hipDeviceProp_t dev_prop;
const hipError_t device_status = hipGetDevice(&dev);
if(has_error(device_status))
return 0;
const hipError_t prop_status = hipGetDeviceProperties(&dev_prop, dev);
if(has_error(prop_status))
return 0;
return dev_prop.multiProcessorCount;
}
int main()
{
std::cout << get_cu_count();
return 0;
}

View File

@@ -1,74 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
import json
import argparse
import os
from pathlib import Path
def extract_test_params(config_file, output_file):
"""Extract test parameters from config JSON and write to output file"""
# Read config file
with open(config_file, "r") as f:
config = json.load(f)
# Extract test parameters
test_params = []
if "test_params" in config and "problem_sizes" in config["test_params"]:
test_params = config["test_params"]["problem_sizes"]
else:
# Default test parameters if none specified
test_params = [
{"m": 256, "n": 256, "k": 128, "split_k": 1},
{"m": 256, "n": 256, "k": 1024, "split_k": 1},
{"m": 256, "n": 512, "k": 512, "split_k": 1},
{"m": 512, "n": 256, "k": 512, "split_k": 1},
]
# Write to output file in C++ format
output_dir = Path(output_file).parent
output_dir.mkdir(parents=True, exist_ok=True)
with open(output_file, "w") as f:
f.write("// Generated test parameters for this configuration\n")
f.write("// This file is auto-generated during CMake configuration\n\n")
f.write("static const std::vector<GemmTestParams> CONFIG_TEST_PARAMS = {\n")
for i, params in enumerate(test_params):
comma = "," if i < len(test_params) - 1 else ""
f.write(
f" {{{params['m']}, {params['n']}, {params['k']}, {params['split_k']}}}{comma}\n"
)
f.write("};\n")
print(
f"Extracted {len(test_params)} test parameters from {config_file} -> {output_file}"
)
def main():
parser = argparse.ArgumentParser(
description="Extract test parameters from config JSON"
)
parser.add_argument("--config_file", required=True, help="Input config JSON file")
parser.add_argument(
"--output_file", required=True, help="Output test parameters file"
)
args = parser.parse_args()
if not os.path.exists(args.config_file):
print(f"Error: Config file not found: {args.config_file}")
return 1
extract_test_params(args.config_file, args.output_file)
return 0
if __name__ == "__main__":
exit(main())

View File

@@ -1,121 +0,0 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
set(CU_COUNT 0 CACHE STRING "Number of Compute Units on the device")
# ============================================================================
# get_cu_count
#
# Returns the CU count for the device. If the given cu_count_arg is a positive
# integer, then the nothing happens. Otherwise, we attempt to query the CU
# count from the device. If the query is unsucessful, the default value of 100
# is returned.
#
# Parameters:
# cu_count_arg - The starting CU count
# ============================================================================
function(get_cu_count cu_count_arg)
message(STATUS "Starting query for CU count needed for Stream-K test config generation")
if(NOT "${${cu_count_arg}}" MATCHES "^[0-9]+$")
message(FATAL_ERROR "The CU count must be a non-negative integer. \
The given value of ${${cu_count_arg}} is invalid.")
endif()
if("${${cu_count_arg}}" STREQUAL "0")
set(CPP_FILE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cu_count.cpp)
set(CPP_EXE_PATH ${CMAKE_CURRENT_BINARY_DIR}/cu_count)
execute_process(
COMMAND ${CMAKE_HIP_COMPILER} -x hip ${CPP_FILE_PATH} -o ${CPP_EXE_PATH}
RESULT_VARIABLE compile_exit_code
)
if (NOT compile_exit_code EQUAL 0)
message(FATAL_ERROR "Compilation of ${CPP_FILE_PATH} failed.\n")
endif()
# Get the HIP library directory
get_filename_component(HIP_COMPILER_DIR ${CMAKE_HIP_COMPILER} DIRECTORY)
get_filename_component(HIP_ROOT_DIR ${HIP_COMPILER_DIR} DIRECTORY)
set(HIP_LIB_DIR "${HIP_ROOT_DIR}/lib")
# Set library path for runtime execution
if(WIN32)
set(ENV{PATH} "${HIP_LIB_DIR};$ENV{PATH}")
else()
set(ENV{LD_LIBRARY_PATH} "${HIP_LIB_DIR}:$ENV{LD_LIBRARY_PATH}")
endif()
execute_process(
COMMAND ${CPP_EXE_PATH}
OUTPUT_STRIP_TRAILING_WHITESPACE
ERROR_VARIABLE standard_error
OUTPUT_VARIABLE queried_cu_count
RESULT_VARIABLE queried_cu_count_exit_code
)
if (standard_error)
message(STATUS "Error information from attempting to query HIP device and properties:\n"
"${standard_error}")
endif()
if (NOT queried_cu_count_exit_code EQUAL 0)
message(STATUS "Failed to run ${CPP_EXE_PATH} to query the device's CU count")
endif()
# Delete the generated cu_count executable
file(REMOVE "${CPP_EXE_PATH}")
if((queried_cu_count STREQUAL "0") OR (NOT queried_cu_count_exit_code EQUAL 0))
message(WARNING "Unable to query the number of Compute Units. \
Please use the CU_COUNT CLI option to pass in the \
number of Compute Units for your target device; otherwise, \
the default value of 100 will be used.")
set(${cu_count_arg} 100 PARENT_SCOPE)
else()
set(${cu_count_arg} ${queried_cu_count} PARENT_SCOPE)
endif()
endif()
endfunction()
# ============================================================================
# generate_test_configs
#
# Generate config json files for Stream-K tests
#
# Parameters:
# cu_count_arg - The number of CUs on the device
# tile_sizes - A list of block tile sizes: tile_m,tile_n,tile_k
# datatype - The datatype for which the config is being generated
# config_list - The variable to which the list of config file names are written
# configs_path - Path to the configs directory to which config files are written
# ============================================================================
function(generate_test_configs cu_count_arg tile_sizes datatype config_list configs_path)
message(STATUS "Generating Stream-K test config files for ${datatype}")
file(MAKE_DIRECTORY ${configs_path})
execute_process(
COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_SOURCE_DIR}/generate_configs.py
--cu_count ${cu_count_arg}
--configs_dir_path ${configs_path}
--tiles ${tile_sizes}
--datatype ${datatype}
OUTPUT_VARIABLE CONFIG_LIST
OUTPUT_STRIP_TRAILING_WHITESPACE
RESULT_VARIABLE script_ret_val
)
if (NOT script_ret_val EQUAL 0)
message(FATAL_ERROR "Eror occured during execution of ${CMAKE_CURRENT_SOURCE_DIR}/generate_configs.py")
endif()
set(${config_list} ${CONFIG_LIST} PARENT_SCOPE)
endfunction()

View File

@@ -1,287 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
from enum import Enum
from typing import Dict, Tuple, List
import argparse
import json
import os
import sys
from dataclasses import dataclass, field, asdict
@dataclass
class TileConfig:
"""Represents the Tile Config section of a Tile Engine config"""
tile_m: List[int] = field(default_factory=list)
tile_n: List[int] = field(default_factory=list)
tile_k: List[int] = field(default_factory=list)
warp_m: List[int] = field(default_factory=lambda: [2])
warp_n: List[int] = field(default_factory=lambda: [2])
warp_k: List[int] = field(default_factory=lambda: [1])
warp_tile_m: List[int] = field(default_factory=lambda: [16, 32])
warp_tile_n: List[int] = field(default_factory=lambda: [16, 32])
# Temporarily only consider 16 for warp_tile_k
# warp_tile_k: List[int] = field(default_factory=lambda: [8, 16, 32])
warp_tile_k: List[int] = field(default_factory=lambda: [16])
def to_dict(self) -> Dict:
return {k: {"values": v} for k, v in asdict(self).items()}
@dataclass
class TraitConfig:
"""Represents the Trait Config section of a Tile Engine config"""
# Temporarily only consider compv3
# pipeline: List[str] = field(default_factory=lambda: ["compv3", "mem"])
pipeline: List[str] = field(default_factory=lambda: ["compv3"])
epilogue: List[str] = field(default_factory=lambda: ["cshuffle"])
scheduler: List[str] = field(default_factory=lambda: ["intrawave"])
pad_m: List[bool] = field(default_factory=lambda: [False])
pad_n: List[bool] = field(default_factory=lambda: [False])
pad_k: List[bool] = field(default_factory=lambda: [False])
persistent: List[bool] = field(default_factory=lambda: [True, False])
reduction_strategy: List[str] = field(default_factory=list)
def to_dict(self) -> Dict:
return {k: {"values": v} for k, v in asdict(self).items()}
class TestVariant(Enum):
"""Represents a Stream-K test variant"""
def __init__(
self,
val: int,
reduction_strategy: List[str],
persistent: List[bool],
datatypes: List[str],
description: str,
):
self._value_ = val
self.reduction_strategy = reduction_strategy
self.persistent = persistent
self.datatypes = datatypes
self.description = description
ATOMIC_SMOKE = (
0,
["atomic"],
[True, False],
# Temporarily only run fp16 tests
# ["fp16", "bf16", "fp8", "bf8"],
["fp16"],
"Stream-K atomic smoke tests",
)
REDUCTION_SMOKE = (
2,
["linear", "tree"],
[True, False],
# Temporarily only run fp16 tests
# ["fp16", "bf16", "fp8", "bf8"],
["fp16"],
"Stream-K reduction smoke tests",
)
EXTENDED = (
3,
["atomic"],
[True, False],
# Temporarily only run fp16 tests
# ["fp16", "bf16", "fp8", "bf8"],
["fp16"],
"Stream-K extended smoke tests",
)
def apply(self, trait_config: TraitConfig) -> None:
"""Applies the current test variant's persistent and reduction strategy setting to the given trait_config"""
trait_config.persistent = self.persistent
trait_config.reduction_strategy = self.reduction_strategy
@dataclass
class ProblemSize:
"""Represents a problem size in a Tile Engine config"""
m: int
n: int
k: int
variant: TestVariant
split_k: int = 1
def to_dict(self) -> Dict:
return {"m": self.m, "n": self.n, "k": self.k, "split_k": self.split_k}
@dataclass
class Config:
"""Represents a Tile Engine config"""
description: str
problem_sizes: list[ProblemSize] = field(default_factory=list)
tile_config: TileConfig = field(default_factory=TileConfig)
trait_config: TraitConfig = field(default_factory=TraitConfig)
k_block_per_cu: int = 1
permute_n: bool = False
def add_problem_size(self, problem: ProblemSize) -> None:
"""Adds the given problem to this config's problem_sizes"""
self.problem_sizes.append(problem)
def to_dict(self) -> Dict:
config_dict = {
"problem": {"description": f"{self.description}"},
"test_params": {
"problem_sizes": [ps.to_dict() for ps in self.problem_sizes]
},
"tile_config": self.tile_config.to_dict(),
"trait_config": self.trait_config.to_dict(),
"k_block_per_cu": self.k_block_per_cu,
"permute_n": self.permute_n,
}
return config_dict
def write_to_file(self, output_file: str) -> None:
"""Writes this configs to the given output_file in a json format"""
with open(output_file, "w") as config_file:
json.dump(self.to_dict(), config_file, indent=4)
config_file.write("\n")
def create_problem_sizes(
tile_m: int, tile_n: int, tile_k: int, cu_count: int
) -> List[ProblemSize]:
"""Creates and returns a list of problem sizes using the given arguments"""
problem_sizes = [
ProblemSize(256, 256, 256, TestVariant.ATOMIC_SMOKE),
ProblemSize(tile_m * cu_count, tile_n, tile_k, TestVariant.ATOMIC_SMOKE),
ProblemSize(
tile_m * 2, tile_n * 2, cu_count * tile_k, TestVariant.ATOMIC_SMOKE
),
ProblemSize(tile_m, tile_n, cu_count * tile_k, TestVariant.REDUCTION_SMOKE),
ProblemSize(
tile_m * 4,
tile_n,
tile_k * cu_count + (25 * tile_k),
TestVariant.REDUCTION_SMOKE,
),
ProblemSize(
tile_m * 3,
tile_n * 7,
tile_k * cu_count + (30 * tile_k),
TestVariant.REDUCTION_SMOKE,
),
# TODO: Add this test once we determine how to label tests as regresion with tile engine
# ProblemSize((tile_m * cu_count * 2) + (tile_m * 2), tile_n, 2048, TestVariant.EXTENDED)
]
return problem_sizes
def write_config_files(
problem_sizes: List[ProblemSize],
configs_dir_path: str,
datatype: str,
tile_sizes: Tuple[int, int, int],
) -> str:
"""Writes the given problem_sizes to a config file and returns the names of the config files written to"""
config_names = []
tile_m, tile_n, tile_k = tile_sizes
tile_config = TileConfig([tile_m], [tile_n], [tile_k])
# Create a config for each test variant
for variant in TestVariant:
problem_sizes_filtered = [ps for ps in problem_sizes if ps.variant == variant]
if (datatype not in variant.datatypes) or len(problem_sizes_filtered) == 0:
continue
trait_config = TraitConfig()
variant.apply(trait_config)
config_name = f"streamk_{variant.name.lower()}_tests_config_{datatype}"
config_names.append(config_name)
file_path = os.path.join(configs_dir_path, config_name + ".json")
config = Config(
variant.description, problem_sizes_filtered, tile_config, trait_config
)
config.write_to_file(file_path)
return config_names
def print_config_names(config_file_names: List[str]) -> None:
"""Prints given config file names as a single semi-colon separated string"""
print(";".join(config_file_names))
def create_config_files(
cu_count: int, configs_dir_path: str, tile_sizes: int, datatype: str
) -> None:
"""Creates Stream-K test config files and prints the file names in a semi-colon-separated list"""
tile_m, tile_n, tile_k = tile_sizes
problem_sizes = create_problem_sizes(tile_m, tile_n, tile_k, cu_count)
config_names = write_config_files(
problem_sizes, configs_dir_path, datatype, tile_sizes
)
print_config_names(config_names)
def get_args() -> Tuple[int, str, Tuple[int, int, int], str]:
"""Returns user provided arguments"""
def tile_sizes_type(val: str):
sizes = None
parts = val.split(",")
if len(parts) != 3:
raise argparse.ArgumentTypeError(
"--tiles must contain exactly three comma-separated values (m,n,k), e.g. --tiles 256,256,32"
)
try:
sizes = tuple(int(size) for size in parts)
except ValueError:
raise argparse.ArgumentTypeError(
"--tiles must contain exactly three comma-separated integers (m,n,k), e.g. --tiles 256,256,32"
)
return sizes
parser = argparse.ArgumentParser(description="Create Stream-K test configs")
parser.add_argument(
"--cu_count", required=True, help="Number of Compute Units on the device"
)
parser.add_argument(
"--configs_dir_path",
required=True,
help="Full path configs directory where config files will be written to",
)
parser.add_argument(
"--tiles",
required=True,
type=tile_sizes_type,
help="Block tile sizes for m, n, and k, respectively. Ex: --tiles 256,256,32",
)
parser.add_argument(
"--datatype",
choices=["fp16", "bf16", "fp8", "bf8"],
required=True,
help="The datatype for which the config is generated.",
)
args = parser.parse_args()
return (int(args.cu_count), args.configs_dir_path, args.tiles, args.datatype)
def main():
cu_count, configs_dir_path, tile_sizes, datatype = get_args()
create_config_files(cu_count, configs_dir_path, tile_sizes, datatype)
sys.exit(0)
if __name__ == "__main__":
main()

View File

@@ -1,258 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/**
* @file test_gemm_simple.cpp
* @brief Unit tests for GEMM kernels generated by gemm_instance_builder
*
* This test includes kernels generated during CMake configuration by
* gemm_instance_builder.py and tests them with problem sizes extracted
* from the corresponding JSON configuration files.
*/
#include <gtest/gtest.h>
#include <iostream>
#include <tuple>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp"
// The kernel header is included via compile command line with -include flag
// It defines SelectedKernel struct, KERNEL_NAME, and tensor data types
// Adaptive error threshold calculation matching tile_engine's implementation
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
/// @brief Function to compare the results of the device and host computations (from tile_engine)
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
bool compare_results(std::string instanceName,
ck_tile::index_t K,
ck_tile::index_t kbatch,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
ck_tile::HostTensor<CDataType>& c_m_n_host_result)
{
const float max_accumulated_value =
std::abs(static_cast<float>(*std::max_element(c_m_n_host_result.mData.begin(),
c_m_n_host_result.mData.end(),
[](CDataType a, CDataType b) {
return std::abs(static_cast<float>(a)) <
std::abs(static_cast<float>(b));
})));
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
bool pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_result,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "For " << instanceName << " Relative error threshold is "
<< rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is "
<< rtol_atol.at(ck_tile::number<1>{}) << std::endl;
std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
return pass;
}
// Test parameter structure for matrix dimensions and split_k values
struct GemmTestParams
{
int m, n, k, split_k;
};
// Include config-specific test parameters (after GemmTestParams struct is defined)
#ifdef GEMM_TEST_PARAMS_HPP
#include GEMM_TEST_PARAMS_HPP
#endif
class StreamKGemmTileEngineTest : public ::testing::TestWithParam<GemmTestParams>
{
protected:
void SetUp() override
{
auto params = GetParam();
m_ = params.m;
n_ = params.n;
k_ = params.k;
split_k_ = params.split_k;
// Calculate strides (following tile_engine pattern)
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
stride_a_ = k_;
}
else
{
stride_a_ = m_;
}
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
stride_b_ = n_;
}
else
{
stride_b_ = k_;
}
if constexpr(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
stride_c_ = n_;
}
else
{
stride_c_ = m_;
}
}
// Test dimensions
int m_, n_, k_, split_k_;
int stride_a_, stride_b_, stride_c_;
};
TEST_P(StreamKGemmTileEngineTest, BasicFunctionality)
{
// Check that kernel information is available
EXPECT_TRUE(strlen(KERNEL_NAME) > 0) << "Kernel name should not be empty";
std::cout << "Testing kernel: " << KERNEL_NAME << std::endl;
std::cout << "Problem size: " << m_ << "x" << n_ << "x" << k_ << std::endl;
// Get tensor layouts from generated kernel
const ALayout layout_a = ALayout{};
const BLayout layout_b = BLayout{};
const CLayout layout_c = CLayout{};
// Calculate tensor strides
int stride_a_calc = ck_tile::get_default_stride(m_, k_, 0, is_row_major(layout_a));
int stride_b_calc = ck_tile::get_default_stride(k_, n_, 0, is_row_major(layout_b));
int stride_c_calc = ck_tile::get_default_stride(m_, n_, 0, is_row_major(layout_c));
// Create host tensors with proper descriptors
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(m_, k_, stride_a_calc, is_row_major(layout_a)));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(k_, n_, stride_b_calc, is_row_major(layout_b)));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(m_, n_, stride_c_calc, is_row_major(layout_c)));
ck_tile::HostTensor<CDataType> c_m_n_dev_ref(
ck_tile::host_tensor_descriptor(m_, n_, stride_c_calc, is_row_major(layout_c)));
// Initialize input tensors with uniform random distribution [-1.0, 1.0] (matches tile_engine)
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
c_m_n_dev_ref.SetZero();
// Allocate GPU device memory
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
ck_tile::DeviceMem ref_c_m_n_dev_buf(c_m_n_dev_ref.get_element_space_size_in_bytes());
// Copy data to device and zero output buffer
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();
ref_c_m_n_dev_buf.SetZero();
// Calculate reference result on device for verification
ADataType* a_m_k_dev_ref_ptr = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
BDataType* b_k_n_dev_ref_ptr = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
CDataType* c_m_n_dev_ref_ptr = static_cast<CDataType*>(ref_c_m_n_dev_buf.GetDeviceBuffer());
ck_tile::
reference_gemm_gpu<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
a_m_k_dev_ref_ptr,
b_k_n_dev_ref_ptr,
c_m_n_dev_ref_ptr,
m_,
n_,
k_,
stride_a_calc,
stride_b_calc,
stride_c_calc);
ref_c_m_n_dev_buf.FromDevice(c_m_n_dev_ref.data());
// Create GEMM kernel arguments
ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
m_,
n_,
k_,
stride_a_calc,
stride_b_calc,
stride_c_calc};
// Configure kernel execution for maximum speed (no timing, no debug output)
ck_tile::stream_config stream_config{nullptr, // stream
false, // time_kernel (disable timing for speed)
0, // log_level (disable debug output)
0, // n_warmup
1, // n_repeat
false, // is_gpu_timer (unused when time_kernel=false)
false, // flush_cache
1}; // rotating_count
// Launch the generated kernel (no timing overhead for fastest execution)
std::tuple<float, ck_tile::index_t> launch_result;
try
{
launch_result = SelectedKernel::launch(args, stream_config);
// Kernel launched successfully if no exception thrown
}
catch(const std::exception& e)
{
std::string error_msg(e.what());
// If arguments not supported, skip the test (configuration validation failure, not a bug)
if(error_msg.find("Arguments not supported") != std::string::npos)
{
GTEST_SKIP() << "Configuration not supported: " << e.what();
}
else
{
FAIL() << "Kernel launch failed: " << e.what();
}
}
// Copy result back from device
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
// Verify results using tile_engine's adaptive error thresholds
const ck_tile::index_t num_wgs_per_tile = get<1>(launch_result);
bool verification_passed = compare_results<ADataType, BDataType, AccDataType, CDataType>(
KERNEL_NAME, k_, num_wgs_per_tile, c_m_n_dev_result, c_m_n_dev_ref);
EXPECT_TRUE(verification_passed) << "GEMM result verification failed";
}
// Use config-specific test parameters (included via compile flags)
// CONFIG_TEST_PARAMS is defined in the auto-generated test_params.hpp file
INSTANTIATE_TEST_SUITE_P(GemmVerification,
StreamKGemmTileEngineTest,
::testing::ValuesIn(CONFIG_TEST_PARAMS),
[](const ::testing::TestParamInfo<GemmTestParams>& param_info) {
return std::to_string(param_info.param.m) + "x" +
std::to_string(param_info.param.n) + "x" +
std::to_string(param_info.param.k) + "_splitk" +
std::to_string(param_info.param.split_k);
});