[rocm-libraries] ROCm/rocm-libraries#5722 (commit 55febd2)

[CK Tile] Stream-K gtest Code Gen

## Motivation

Stream-K was using the tile engine infrastructure for smoke tests.
However, tile engine creates a different target per kernel instance,
which has resulted in scalability issues when used in the context of
unit tests. To avoid burdens on cmake configuration and build time, we
have opted to remove our Stream-K tile engine tests. Instead, we use
pure gtests with code gen to generate repetitive .cpp files.

**Note: This appears to change a lot of files because many files are
removed since they are now generated at build time.**

## Technical Details
We originally used Tile Engine to facilitate code gen for unit tests
since we found that pure gtests required the addition of many repetitive
.cpp files of the following form:
```cpp
#include "test_gemm_streamk_common_includes.hpp"

template <typename Tuple>
class TestCkTileStreamKBf8 : public TestCkTileStreamK<Tuple>
{
};

#define TEST_SUITE_NAME TestCkTileStreamKBf8

TYPED_TEST_SUITE(TestCkTileStreamKBf8, KernelTypesStreamKBf8);

#include "test_gemm_streamk_atomic_cases.inc"

#undef TEST_SUITE_NAME

```
Due to issues encountered with tile engine, we instead use pure gtests
to generate the repetitive .cpp files. The code generator parses
`KernelTypesStreamK*` type aliases from the types header using a
two-phase approach:
1. At **configure time**, CMake runs the Python script with
`--list_files` to extract the type alias names from the header
(test_gemm_streamk_types.hpp) and compute the list of .cpp file paths
that will be generated. This lets CMake know the exact set of source
files for each target.
2. At **build time**, `add_custom_command` runs the script again with
`--gen_files` to actually emit the .cpp files into the build directory,
triggered only when the types header or generator script changes.

With these changes, we've removed all Stream-K tile engine tests. There
are now 5 targets for Stream-K GEMM tests:
1. test_ck_tile_streamk_atomic_smoke: smoke tests for Atomic reduction
strategy (pipeline: compv3)
2. test_ck_tile_streamk_linear_smoke: smoke tests for Linear reduction
strategy (pipeline: compv3)
3. test_ck_tile_streamk_tree_smoke: smoke tests for Tree reduction
strategy (pipeline: compv3)
4. test_ck_tile_streamk_pipelines_smoke: smoke tests (smaller set) for
pipelines other than compv3
- Since Stream-K can be thought of as a wrapper around universal GEMM,
we don't need to extensively test each pipeline. So, we opt to run a few
tests for different pipelines. Currently, this just consists of the mem
pipeline, but compv4 is coming soon.
5. test_ck_tile_streamk_extended: extended tests

## Test Plan

I have tests the gtests locally on gfx90a, gfx942, and gfx950.

## Test Result

All local tests pass.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Emily Martins
2026-04-02 21:07:13 +00:00
committed by assistant-librarian[bot]
parent 6d77edc3bd
commit 7cc9bae9d2
39 changed files with 738 additions and 1775 deletions

View File

@@ -19,55 +19,93 @@ set(EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4
if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")
include_directories(BEFORE ${CMAKE_CURRENT_SOURCE_DIR})
#TODO: support all arches
#TODO: current c-shuffle only supports C layout as R
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
set(STREAMK_EXTENDED_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_mem.cpp
test_gemm_streamk_util.cpp)
# We only test fp8 and bf8 on gfx942 and gfx950 since these types are not natively supported on gfx90a
if(GPU_TARGETS MATCHES "gfx942|gfx950")
list(APPEND STREAMK_EXTENDED_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_mem.cpp)
endif()
# ---- Code-generate test .cpp files from types header ----
set(STREAMK_TYPES_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/test_gemm_streamk_types.hpp)
set(STREAMK_GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/generate_test_files.py)
add_gtest_executable(test_ck_tile_streamk_extended ${STREAMK_EXTENDED_SOURCES})
target_compile_options(test_ck_tile_streamk_extended PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
# Re-run configure automatically if the types header changes (e.g. types added/removed)
# or if the generation script changes
set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS ${STREAMK_TYPES_HEADER} ${STREAMK_GEN_SCRIPT})
# Define the targets and their corresponding executable names
set(STREAMK_GEN_TARGETS extended atomic_smoke linear_smoke tree_smoke pipelines_smoke)
set(STREAMK_GEN_EXEC_EXTENDED test_ck_tile_streamk_extended)
set(STREAMK_GEN_EXEC_ATOMIC_SMOKE test_ck_tile_streamk_atomic_smoke)
set(STREAMK_GEN_EXEC_LINEAR_SMOKE test_ck_tile_streamk_linear_smoke)
set(STREAMK_GEN_EXEC_TREE_SMOKE test_ck_tile_streamk_tree_smoke)
set(STREAMK_GEN_EXEC_PIPELINES_SMOKE test_ck_tile_streamk_pipelines_smoke)
# Collect all test targets for umbrella label
set(CK_TILE_GEMM_STREAMK_TEST_TARGETS
test_ck_tile_streamk_tile_partitioner
test_ck_tile_streamk_extended
test_ck_tile_streamk_tile_partitioner)
foreach(target IN LISTS STREAMK_GEN_TARGETS)
string(TOUPPER ${target} TARGET_UPPER)
set(GEN_DIR ${CMAKE_CURRENT_BINARY_DIR}/${target})
set(EXEC_NAME ${STREAMK_GEN_EXEC_${TARGET_UPPER}})
set(LIST_FILE ${CMAKE_CURRENT_BINARY_DIR}/${target}_files.txt)
# Phase 1 (configure time): discover the list of files that will be generated
execute_process(
COMMAND ${Python3_EXECUTABLE} ${STREAMK_GEN_SCRIPT}
--types_header ${STREAMK_TYPES_HEADER}
--output_dir ${GEN_DIR}
--target ${target}
--list_files ${LIST_FILE}
RESULT_VARIABLE ret
ERROR_VARIABLE list_files_stderr)
if(ret AND NOT ret EQUAL 0)
message(FATAL_ERROR
"Failed to list ${target} test files via Python: ${ret}\n"
"stderr: ${list_files_stderr}"
)
endif()
file(STRINGS ${LIST_FILE} ALL_SOURCES_${target})
# Phase 2 (build time): generate the .cpp files when the types header changes
add_custom_command(
OUTPUT ${ALL_SOURCES_${target}}
COMMAND ${Python3_EXECUTABLE} ${STREAMK_GEN_SCRIPT}
--types_header ${STREAMK_TYPES_HEADER}
--output_dir ${GEN_DIR}
--target ${target}
--gen_files
DEPENDS ${STREAMK_TYPES_HEADER} ${STREAMK_GEN_SCRIPT}
COMMENT "Generating StreamK ${target} test sources from types header")
# Filter out fp8/bf8 sources on gfx90a since those types are not natively supported
set(FILTERED_SOURCES)
foreach(src IN LISTS ALL_SOURCES_${target})
if(NOT src MATCHES "_(fp8|bf8)_" OR GPU_TARGETS MATCHES "gfx942|gfx950")
list(APPEND FILTERED_SOURCES ${src})
endif()
endforeach()
list(APPEND FILTERED_SOURCES test_gemm_streamk_util.cpp)
add_gtest_executable(${EXEC_NAME} ${FILTERED_SOURCES})
target_compile_options(${EXEC_NAME} PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
list(APPEND CK_TILE_GEMM_STREAMK_TEST_TARGETS ${EXEC_NAME})
endforeach()
# Add python unit tests to validate the code gen logic in generate_test_files.py
add_test(
NAME test_ck_tile_streamk_generate_test_files
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/test_generate_test_files.py -v
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..
)
# Label all ck_tile gemm_streamk tests with CK_TILE_GEMM_STREAMK_TESTS for selective execution
foreach(test_target ${CK_TILE_GEMM_STREAMK_TEST_TARGETS})
set_tests_properties(${test_target} PROPERTIES LABELS "CK_TILE_GEMM_STREAMK_TESTS")
endforeach()
# Also label the Python test
set_tests_properties(test_ck_tile_streamk_generate_test_files PROPERTIES LABELS "CK_TILE_GEMM_STREAMK_TESTS")
# Umbrella target to build and run all ck_tile gemm_streamk tests
# Usage: ninja ck_tile_gemm_streamk_tests

View File

@@ -1,18 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf16NonPersistentCompV3 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistentCompV3
TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistentCompV3,
KernelTypesStreamKBf16NonPersistentCompV3);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,18 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf16NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistentCompV4
TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistentCompV4,
KernelTypesStreamKBf16NonPersistentCompV4);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf16NonPersistentMem : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistentMem
TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistentMem, KernelTypesStreamKBf16NonPersistentMem);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf16PersistentCompV3 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf16PersistentCompV3
TYPED_TEST_SUITE(TestCkTileStreamKBf16PersistentCompV3, KernelTypesStreamKBf16PersistentCompV3);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf16PersistentCompV4 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf16PersistentCompV4
TYPED_TEST_SUITE(TestCkTileStreamKBf16PersistentCompV4, KernelTypesStreamKBf16PersistentCompV4);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf16PersistentMem : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf16PersistentMem
TYPED_TEST_SUITE(TestCkTileStreamKBf16PersistentMem, KernelTypesStreamKBf16PersistentMem);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf8NonPersistentCompV3 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistentCompV3
TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistentCompV3, KernelTypesStreamKBf8NonPersistentCompV3);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf8NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistentCompV4
TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistentCompV4, KernelTypesStreamKBf8NonPersistentCompV4);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf8NonPersistentMem : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistentMem
TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistentMem, KernelTypesStreamKBf8NonPersistentMem);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf8PersistentCompV3 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf8PersistentCompV3
TYPED_TEST_SUITE(TestCkTileStreamKBf8PersistentCompV3, KernelTypesStreamKBf8PersistentCompV3);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf8PersistentCompV4 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf8PersistentCompV4
TYPED_TEST_SUITE(TestCkTileStreamKBf8PersistentCompV4, KernelTypesStreamKBf8PersistentCompV4);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf8PersistentMem : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf8PersistentMem
TYPED_TEST_SUITE(TestCkTileStreamKBf8PersistentMem, KernelTypesStreamKBf8PersistentMem);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,18 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp16NonPersistentCompV3 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp16NonPersistentCompV3
TYPED_TEST_SUITE(TestCkTileStreamKFp16NonPersistentCompV3,
KernelTypesStreamKFp16NonPersistentCompV3);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,18 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp16NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp16NonPersistentCompV4
TYPED_TEST_SUITE(TestCkTileStreamKFp16NonPersistentCompV4,
KernelTypesStreamKFp16NonPersistentCompV4);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp16NonPersistentMem : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp16NonPersistentMem
TYPED_TEST_SUITE(TestCkTileStreamKFp16NonPersistentMem, KernelTypesStreamKFp16NonPersistentMem);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp16PersistentCompV3 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp16PersistentCompV3
TYPED_TEST_SUITE(TestCkTileStreamKFp16PersistentCompV3, KernelTypesStreamKFp16PersistentCompV3);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp16PersistentCompV4 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp16PersistentCompV4
TYPED_TEST_SUITE(TestCkTileStreamKFp16PersistentCompV4, KernelTypesStreamKFp16PersistentCompV4);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp16PersistentMem : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp16PersistentMem
TYPED_TEST_SUITE(TestCkTileStreamKFp16PersistentMem, KernelTypesStreamKFp16PersistentMem);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp8NonPersistentCompV3 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp8NonPersistentCompV3
TYPED_TEST_SUITE(TestCkTileStreamKFp8NonPersistentCompV3, KernelTypesStreamKFp8NonPersistentCompV3);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp8NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp8NonPersistentCompV4
TYPED_TEST_SUITE(TestCkTileStreamKFp8NonPersistentCompV4, KernelTypesStreamKFp8NonPersistentCompV4);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp8NonPersistentMem : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp8NonPersistentMem
TYPED_TEST_SUITE(TestCkTileStreamKFp8NonPersistentMem, KernelTypesStreamKFp8NonPersistentMem);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp8PersistentCompV3 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp8PersistentCompV3
TYPED_TEST_SUITE(TestCkTileStreamKFp8PersistentCompV3, KernelTypesStreamKFp8PersistentCompV3);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp8PersistentCompV4 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp8PersistentCompV4
TYPED_TEST_SUITE(TestCkTileStreamKFp8PersistentCompV4, KernelTypesStreamKFp8PersistentCompV4);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp8PersistentMem : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp8PersistentMem
TYPED_TEST_SUITE(TestCkTileStreamKFp8PersistentMem, KernelTypesStreamKFp8PersistentMem);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -0,0 +1,215 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""Generate test .cpp files from KernelTypes definitions in
test_gemm_streamk_types.hpp.
Two modes:
--list_files FILE Write the list of output file paths to FILE (one per line)
without generating the files. Used at CMake configure time.
--gen_files Actually emit the .cpp files into --output_dir.
Used at build time via add_custom_command.
Target selection (--target):
extended Kernel types containing 'Atomic' or 'Pipelines'
-> includes test_gemm_streamk_extended_cases.inc
atomic_smoke Kernel types containing 'Atomic' (not 'Pipelines')
-> includes test_gemm_streamk_atomic_cases.inc
linear_smoke Kernel types containing 'Linear' (not 'Pipelines')
-> includes test_gemm_streamk_reduction_cases.inc
tree_smoke Kernel types containing 'Tree' (not 'Pipelines')
-> includes test_gemm_streamk_reduction_cases.inc
pipelines_smoke Kernel types matching 'Pipelines'
-> includes test_gemm_streamk_reduction_cases.inc
and test_gemm_streamk_atomic_cases.inc
"""
import argparse
import os
import re
import sys
# --------------------------------------------------------------------------- #
# Template for every generated .cpp file
# --------------------------------------------------------------------------- #
CPP_TEMPLATE = """\
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class {class_name} : public TestCkTileStreamK<Tuple>
{{
}};
#define TEST_SUITE_NAME {class_name}
TYPED_TEST_SUITE({class_name}, {type_alias});
{inc_includes}
#undef TEST_SUITE_NAME
"""
# --------------------------------------------------------------------------- #
# Target definitions: filter predicate and .inc files
# --------------------------------------------------------------------------- #
TARGETS = {
"extended": {
"filter": lambda suffix: "Atomic" in suffix or suffix == "Pipelines",
"inc_files": ["test_gemm_streamk_extended_cases.inc"],
},
"atomic_smoke": {
"filter": lambda suffix: "Atomic" in suffix and suffix != "Pipelines",
"inc_files": ["test_gemm_streamk_atomic_cases.inc"],
},
"linear_smoke": {
"filter": lambda suffix: "Linear" in suffix and suffix != "Pipelines",
"inc_files": ["test_gemm_streamk_reduction_cases.inc"],
},
"tree_smoke": {
"filter": lambda suffix: "Tree" in suffix and suffix != "Pipelines",
"inc_files": ["test_gemm_streamk_reduction_cases.inc"],
},
"pipelines_smoke": {
"filter": lambda suffix: suffix == "Pipelines",
"inc_files": [
"test_gemm_streamk_reduction_cases.inc",
"test_gemm_streamk_atomic_cases.inc",
],
},
}
# --------------------------------------------------------------------------- #
# Mapping from CamelCase suffix fragments to file-name fragments
# --------------------------------------------------------------------------- #
KNOWN_TOKENS = [
("Fp16", "fp16"),
("Bf16", "bf16"),
("Fp8", "fp8"),
("Bf8", "bf8"),
("NonPersistent", "nonpersistent"),
("Persistent", "persistent"),
("Atomic", "atomic"),
("Linear", "linear"),
("Tree", "tree"),
("CompV3", "compv3"),
("Pipelines", "pipelines"),
]
def suffix_to_file_tag(suffix: str) -> str:
"""Convert a CamelCase suffix like 'Fp16PersistentAtomicCompV3' to
'fp16_persistent_atomic_compv3'."""
parts: list[str] = []
remaining = suffix
while remaining:
matched = False
for token, replacement in KNOWN_TOKENS:
if remaining.startswith(token):
parts.append(replacement)
remaining = remaining[len(token) :]
matched = True
break
if not matched:
raise ValueError(
f"Unrecognised token in KernelTypes suffix: '{remaining}' "
f"(from '{suffix}')"
)
return "_".join(parts)
def parse_types_header(header_path: str, target: str) -> list[dict]:
"""Return a list of dicts with keys: type_alias, class_name, file_tag, suffix."""
target_def = TARGETS[target]
# Pattern matches lines like: using KernelTypesStreamKFp16PersistentAtomicCompV3 = ...
pattern = re.compile(r"using\s+(KernelTypesStreamK(\w+))\s*=")
entries: list[dict] = []
with open(header_path) as f:
for line in f:
match = pattern.search(line)
if match:
# If the match is: using KernelTypesStreamKFp16PersistentAtomicCompV3 = ...
# type_alias is KernelTypesStreamKFp16PersistentAtomicCompV3
# suffix is Fp16PersistentAtomicCompV3
type_alias = match.group(1)
suffix = match.group(2)
if not target_def["filter"](suffix):
continue
entries.append(
{
"type_alias": type_alias,
"class_name": f"TestCkTileStreamK{suffix}",
"file_tag": suffix_to_file_tag(suffix),
}
)
return entries
def output_path(output_dir: str, entry: dict) -> str:
return os.path.join(output_dir, f"test_gemm_streamk_{entry['file_tag']}.cpp")
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--types_header", required=True, help="Path to test_gemm_streamk_types.hpp"
)
parser.add_argument(
"--output_dir", required=True, help="Directory for generated .cpp files"
)
parser.add_argument(
"--target",
required=True,
choices=list(TARGETS.keys()),
help="Which target to generate files for",
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"--list_files",
metavar="FILE",
help="Write output file paths to FILE then exit",
)
group.add_argument(
"--gen_files", action="store_true", help="Generate the .cpp files"
)
return parser.parse_args()
def main() -> None:
args = parse_args()
entries = parse_types_header(args.types_header, args.target)
if not entries:
print(
f"ERROR: no KernelTypesStreamK* definitions found for target "
f"'{args.target}' in {args.types_header}",
file=sys.stderr,
)
sys.exit(1)
inc_files = TARGETS[args.target]["inc_files"]
inc_includes = "\n".join(f'#include "{f}"' for f in inc_files)
if args.list_files:
os.makedirs(os.path.dirname(args.list_files) or ".", exist_ok=True)
with open(args.list_files, "w") as f:
for entry in entries:
f.write(output_path(args.output_dir, entry) + "\n")
else:
os.makedirs(args.output_dir, exist_ok=True)
for entry in entries:
path = output_path(args.output_dir, entry)
content = CPP_TEMPLATE.format(
class_name=entry["class_name"],
type_alias=entry["type_alias"],
inc_includes=inc_includes,
)
with open(path, "w") as f:
f.write(content)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,47 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
TYPED_TEST(TEST_SUITE_NAME, StreamK_EdgeCase)
{
ck_tile::index_t M = 256;
ck_tile::index_t N = 256;
ck_tile::index_t K = 256;
this->Run(M, N, K);
}
TYPED_TEST(TEST_SUITE_NAME, StreamK_DPOnly)
{
const ck_tile::index_t num_cu = get_cu_count();
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
// For DP only, we ensure that the number of tiles is a multiple of the number of CUs. This
// assumes tile sizes are large enough such that occupancy is 1.
ck_tile::index_t M = M_Tile * num_cu;
ck_tile::index_t N = N_Tile;
ck_tile::index_t K = K_Tile;
this->Run(M, N, K);
}
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly)
{
const ck_tile::index_t num_cu = get_cu_count();
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
// For SK only, we have 4 macro tiles in C. But, we need to make sure there is enough work along
// the K dimension to avoid falling into the edge case. Thus, we always have at least num_cu
// macro tiles in the K dimension. This assumes tile sizes are large enough such that occupancy
// is 1.
ck_tile::index_t M = M_Tile * 2;
ck_tile::index_t N = N_Tile * 2;
ck_tile::index_t K = K_Tile * num_cu;
this->Run(M, N, K);
}

View File

@@ -0,0 +1,46 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_OneTile)
{
const ck_tile::index_t num_cu = get_cu_count();
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
ck_tile::index_t M = M_Tile;
ck_tile::index_t N = N_Tile;
ck_tile::index_t K = K_Tile * num_cu;
this->Run(M, N, K);
}
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_4Tiles_Reduction)
{
const ck_tile::index_t num_cu = get_cu_count();
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
ck_tile::index_t M = M_Tile * 4;
ck_tile::index_t N = N_Tile;
ck_tile::index_t K = K_Tile * num_cu + (25 * K_Tile);
this->Run(M, N, K);
}
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_21Tiles)
{
const ck_tile::index_t num_cu = get_cu_count();
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
ck_tile::index_t M = M_Tile * 3;
ck_tile::index_t N = N_Tile * 7;
ck_tile::index_t K = K_Tile * num_cu + (30 * K_Tile);
this->Run(M, N, K);
}

View File

@@ -14,16 +14,28 @@ using BF16 = ck_tile::bf16_t;
using BF8 = ck_tile::bf8_t;
using F32 = float;
// Layouts
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// Persistence
using Persistent = std::true_type;
using NonPersistent = std::false_type;
// Pipelines
using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
using CompV3 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV3>;
using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV4>;
using Persistent = std::true_type;
using NonPersistent = std::false_type;
// Reduction Strategies
using Atomic = ck_tile::integral_constant<ck_tile::StreamKReductionStrategy,
ck_tile::StreamKReductionStrategy::Atomic>;
using Linear = ck_tile::integral_constant<ck_tile::StreamKReductionStrategy,
ck_tile::StreamKReductionStrategy::Linear>;
using Tree = ck_tile::integral_constant<ck_tile::StreamKReductionStrategy,
ck_tile::StreamKReductionStrategy::Tree>;
using I16 = ck_tile::number<16>;
using I32 = ck_tile::number<32>;
using I128 = ck_tile::number<128>;
using I256 = ck_tile::number<256>;
@@ -32,180 +44,157 @@ using I256 = ck_tile::number<256>;
// ========================== CompV3 Pipeline ==========================
using KernelTypesStreamKFp16PersistentCompV3 = ::testing::Types<
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent Pipeline
// Atomics
using KernelTypesStreamKFp16PersistentAtomicCompV3 = ::testing::Types<
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile M_WaveTile N_WaveTile K_WaveTile Persistent Pipeline ReductionStrategy
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV3>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV3>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV3>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV3>
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Atomic>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Atomic>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Atomic>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Atomic>
>;
using KernelTypesStreamKBf16PersistentCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV3>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV3>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV3>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV3>
using KernelTypesStreamKBf16PersistentAtomicCompV3 = ::testing::Types<
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Atomic>
>;
using KernelTypesStreamKBf8PersistentCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV3>,
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV3>,
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV3>,
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV3>
using KernelTypesStreamKBf8PersistentAtomicCompV3 = ::testing::Types<
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Atomic>
>;
using KernelTypesStreamKFp8PersistentCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV3>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV3>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV3>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV3>
using KernelTypesStreamKFp8PersistentAtomicCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Atomic>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Atomic>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Atomic>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Atomic>
>;
using KernelTypesStreamKFp16NonPersistentCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV3>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV3>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV3>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV3>
using KernelTypesStreamKFp16NonPersistentAtomicCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>
>;
using KernelTypesStreamKBf16NonPersistentCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV3>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV3>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV3>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV3>
using KernelTypesStreamKBf16NonPersistentAtomicCompV3 = ::testing::Types<
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>
>;
using KernelTypesStreamKBf8NonPersistentCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV3>,
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV3>,
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV3>,
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV3>
using KernelTypesStreamKBf8NonPersistentAtomicCompV3 = ::testing::Types<
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>
>;
using KernelTypesStreamKFp8NonPersistentCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3>
using KernelTypesStreamKFp8NonPersistentAtomicCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>
>;
// ========================== CompV4 Pipeline ==========================
// Linear
using KernelTypesStreamKFp16PersistentLinearCompV3 = ::testing::Types<
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile M_WaveTile N_WaveTile K_WaveTile Persistent Pipeline ReductionStrategy
using KernelTypesStreamKFp16PersistentCompV4 = ::testing::Types<
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent Pipeline
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Linear>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I16, I16, I16, Persistent, CompV3, Linear>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Linear>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Linear>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Linear>
>;
using KernelTypesStreamKBf16PersistentCompV4 = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>
using KernelTypesStreamKBf16PersistentLinearCompV3 = ::testing::Types<
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Linear>
>;
using KernelTypesStreamKBf8PersistentCompV4 = ::testing::Types<
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>,
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>,
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>,
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>
using KernelTypesStreamKBf8PersistentLinearCompV3 = ::testing::Types<
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Linear>
>;
using KernelTypesStreamKFp8PersistentCompV4 = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>
using KernelTypesStreamKFp8PersistentLinearCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Linear>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Linear>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Linear>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Linear>
>;
using KernelTypesStreamKFp16NonPersistentCompV4 = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>
using KernelTypesStreamKFp16NonPersistentLinearCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Linear>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Linear>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Linear>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Linear>
>;
using KernelTypesStreamKBf16NonPersistentCompV4 = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>
using KernelTypesStreamKBf16NonPersistentLinearCompV3 = ::testing::Types<
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Linear>
>;
using KernelTypesStreamKBf8NonPersistentCompV4 = ::testing::Types<
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>,
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>,
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>,
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>
using KernelTypesStreamKBf8NonPersistentLinearCompV3 = ::testing::Types<
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Linear>
>;
using KernelTypesStreamKFp8NonPersistentCompV4 = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>
using KernelTypesStreamKFp8NonPersistentLinearCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Linear>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Linear>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Linear>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Linear>
>;
// ============================= Mem Pipeline =============================
// Tree
using KernelTypesStreamKFp16PersistentTreeCompV3 = ::testing::Types<
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile M_WaveTile N_WaveTile K_WaveTile Persistent Pipeline ReductionStrategy
using KernelTypesStreamKFp16PersistentMem = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, Mem>
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Tree>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Tree>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I16, I16, I16, Persistent, CompV3, Tree>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Tree>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Tree>
>;
using KernelTypesStreamKBf16PersistentMem = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, Mem>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, Mem>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, Mem>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, Mem>
using KernelTypesStreamKBf16PersistentTreeCompV3 = ::testing::Types<
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Tree>
>;
using KernelTypesStreamKBf8PersistentMem = ::testing::Types<
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, Mem>,
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, Mem>,
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, Mem>,
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, Mem>
using KernelTypesStreamKBf8PersistentTreeCompV3 = ::testing::Types<
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Tree>
>;
using KernelTypesStreamKFp8PersistentMem = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, Mem>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, Mem>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, Mem>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, Mem>
using KernelTypesStreamKFp8PersistentTreeCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Tree>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Tree>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Tree>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Tree>
>;
using KernelTypesStreamKFp16NonPersistentMem = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, Mem>
using KernelTypesStreamKFp16NonPersistentTreeCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Tree>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Tree>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Tree>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Tree>
>;
using KernelTypesStreamKBf16NonPersistentMem = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, Mem>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, Mem>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, Mem>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, Mem>
using KernelTypesStreamKBf16NonPersistentTreeCompV3 = ::testing::Types<
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Tree>
>;
using KernelTypesStreamKBf8NonPersistentMem = ::testing::Types<
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, Mem>,
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, Mem>,
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, Mem>,
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, Mem>
using KernelTypesStreamKBf8NonPersistentTreeCompV3 = ::testing::Types<
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Tree>
>;
using KernelTypesStreamKFp8NonPersistentMem = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, Mem>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, Mem>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, Mem>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, Mem>
using KernelTypesStreamKFp8NonPersistentTreeCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Tree>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Tree>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Tree>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Tree>
>;
// ============================= Other Pipelines =============================
using KernelTypesStreamKPipelines = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, Mem, Atomic>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, Mem, Tree>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, Mem, Linear>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV4, Atomic>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV4, Tree>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV4, Linear>
>;
// clang-format on

View File

@@ -71,23 +71,27 @@ template <typename Tuple>
class TestCkTileStreamK : public ::testing::Test
{
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>;
using BDataType = std::tuple_element_t<4, Tuple>;
using AccDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>;
using DsLayout = ck_tile::tuple<>;
using DsDataType = ck_tile::tuple<>;
static constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, Tuple>::value;
static constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, Tuple>::value;
static constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, Tuple>::value;
static constexpr bool Persistent = std::tuple_element_t<10, Tuple>::value;
static constexpr auto PipelineType = std::tuple_element_t<11, Tuple>::value;
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>;
using BDataType = std::tuple_element_t<4, Tuple>;
using AccDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>;
using DsLayout = ck_tile::tuple<>;
using DsDataType = ck_tile::tuple<>;
static constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, Tuple>::value;
static constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, Tuple>::value;
static constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, Tuple>::value;
static constexpr ck_tile::index_t M_Warp_Tile = std::tuple_element_t<10, Tuple>::value;
static constexpr ck_tile::index_t N_Warp_Tile = std::tuple_element_t<11, Tuple>::value;
static constexpr ck_tile::index_t K_Warp_Tile = std::tuple_element_t<12, Tuple>::value;
template <ck_tile::StreamKReductionStrategy ReductionStrategy,
bool PadM = true,
static constexpr bool Persistent = std::tuple_element_t<13, Tuple>::value;
static constexpr auto PipelineType = std::tuple_element_t<14, Tuple>::value;
static constexpr auto ReductionStrategy = std::tuple_element_t<15, Tuple>::value;
template <bool PadM = true,
bool PadN = true,
bool PadK = true,
bool Preshuffle = false,
@@ -99,10 +103,6 @@ class TestCkTileStreamK : public ::testing::Test
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr bool kPadM = PadM;
constexpr bool kPadN = PadN;
constexpr bool kPadK = PadK;
@@ -269,8 +269,7 @@ class TestCkTileStreamK : public ::testing::Test
stride_C};
ck_tile::index_t num_accumulations_per_tile =
invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
invoke_streamk<>(args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());

View File

@@ -0,0 +1,220 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
import unittest
from unittest.mock import mock_open, patch
from generate_test_files import suffix_to_file_tag, parse_types_header, output_path
# ------------------------------------------------------------ #
# Unit tests for helper functions in generate_test_files.py
# ------------------------------------------------------------ #
class TestSuffixToFileTag(unittest.TestCase):
def test_fp16_token(self):
suffix = "Fp16"
expected_tag = "fp16"
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
def test_bf16_token(self):
suffix = "Bf16"
expected_tag = "bf16"
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
def test_fp8_token(self):
suffix = "Fp8"
expected_tag = "fp8"
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
def test_bf8_token(self):
suffix = "Bf8"
expected_tag = "bf8"
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
def test_nonpersistent_token(self):
suffix = "NonPersistent"
expected_tag = "nonpersistent"
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
def test_persistent_token(self):
suffix = "Persistent"
expected_tag = "persistent"
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
def test_atomic_token(self):
suffix = "Atomic"
expected_tag = "atomic"
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
def test_linear_token(self):
suffix = "Linear"
expected_tag = "linear"
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
def test_tree_token(self):
suffix = "Tree"
expected_tag = "tree"
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
def test_compv3_token(self):
suffix = "CompV3"
expected_tag = "compv3"
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
def test_pipelines_token(self):
suffix = "Pipelines"
expected_tag = "pipelines"
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
def test_unknown_token(self):
suffix = "unknown"
with self.assertRaises(ValueError):
suffix_to_file_tag(suffix)
def test_multiple_valid_tokens(self):
suffix = "Fp16PersistentAtomicCompV3"
expected_tag = "fp16_persistent_atomic_compv3"
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
def test_multiple_tokens_with_unknown(self):
suffix = "Fp16PersistentUnknownCompV3"
with self.assertRaises(ValueError):
suffix_to_file_tag(suffix)
class TestParseTypesHeader(unittest.TestCase):
def validate_entries(self, entries, expected_entries):
self.assertEqual(len(entries), len(expected_entries))
for idx in range(len(entries)):
self.assertDictEqual(entries[idx], expected_entries[idx])
def test_empty_entry(self):
"""Test that an empty file returns no entries."""
mock_content = ""
with patch("builtins.open", mock_open(read_data=mock_content)):
entries = parse_types_header("fake_path.hpp", "atomic_smoke")
self.assertEqual(len(entries), 0)
def test_pipelines_smoke(self):
"""Test pipelines_smoke target: matches suffix == 'Pipelines'.
Includes: Pipelines
Excludes: Fp8NonPersistentTreeCompV3
"""
mock_content = (
"using KernelTypesStreamKPipelines = ...\n"
"using KernelTypesStreamKFp8NonPersistentTreeCompV3 = ...\n"
)
with patch("builtins.open", mock_open(read_data=mock_content)):
entries = parse_types_header("fake_path.hpp", "pipelines_smoke")
expected = [
{
"type_alias": "KernelTypesStreamKPipelines",
"class_name": "TestCkTileStreamKPipelines",
"file_tag": "pipelines",
}
]
self.validate_entries(entries, expected)
def test_extended(self):
"""Test extended target: matches 'Atomic' in suffix OR suffix == 'Pipelines'.
Includes: Fp16PersistentAtomic, Pipelines
Excludes: Bf16Linear
"""
mock_content = (
"using KernelTypesStreamKFp16PersistentAtomic = ...\n"
"using KernelTypesStreamKPipelines = ...\n"
"using KernelTypesStreamKBf16Linear = ...\n"
)
with patch("builtins.open", mock_open(read_data=mock_content)):
entries = parse_types_header("fake_path.hpp", "extended")
expected = [
{
"type_alias": "KernelTypesStreamKFp16PersistentAtomic",
"class_name": "TestCkTileStreamKFp16PersistentAtomic",
"file_tag": "fp16_persistent_atomic",
},
{
"type_alias": "KernelTypesStreamKPipelines",
"class_name": "TestCkTileStreamKPipelines",
"file_tag": "pipelines",
},
]
self.validate_entries(entries, expected)
def test_atomic_smoke(self):
"""Test atomic_smoke target: matches 'Atomic' in suffix AND suffix != 'Pipelines'.
Includes: Fp16PersistentAtomic
Excludes: Bf16Linear, Pipelines
"""
mock_content = (
"using KernelTypesStreamKFp16PersistentAtomic = ...\n"
"using KernelTypesStreamKBf16Linear = ...\n"
"using KernelTypesStreamKPipelines = ...\n"
)
with patch("builtins.open", mock_open(read_data=mock_content)):
entries = parse_types_header("fake_path.hpp", "atomic_smoke")
expected = [
{
"type_alias": "KernelTypesStreamKFp16PersistentAtomic",
"class_name": "TestCkTileStreamKFp16PersistentAtomic",
"file_tag": "fp16_persistent_atomic",
}
]
self.validate_entries(entries, expected)
def test_linear_smoke(self):
"""Test linear_smoke target: matches 'Linear' in suffix AND suffix != 'Pipelines'.
Includes: Fp8NonPersistentLinear
Excludes: Bf16PersistentAtomic, Pipelines
"""
mock_content = (
"using KernelTypesStreamKFp8NonPersistentLinear = ...\n"
"using KernelTypesStreamKBf16PersistentAtomic = ...\n"
"using KernelTypesStreamKPipelines = ...\n"
)
with patch("builtins.open", mock_open(read_data=mock_content)):
entries = parse_types_header("fake_path.hpp", "linear_smoke")
expected = [
{
"type_alias": "KernelTypesStreamKFp8NonPersistentLinear",
"class_name": "TestCkTileStreamKFp8NonPersistentLinear",
"file_tag": "fp8_nonpersistent_linear",
}
]
self.validate_entries(entries, expected)
def test_tree_smoke(self):
"""Test tree_smoke target: matches 'Tree' in suffix AND suffix != 'Pipelines'.
Includes: Bf8PersistentTreeCompV3
Excludes: Fp16Linear, Pipelines
"""
mock_content = (
"using KernelTypesStreamKBf8PersistentTreeCompV3 = ...\n"
"using KernelTypesStreamKFp16Linear = ...\n"
"using KernelTypesStreamKPipelines = ...\n"
)
with patch("builtins.open", mock_open(read_data=mock_content)):
entries = parse_types_header("fake_path.hpp", "tree_smoke")
expected = [
{
"type_alias": "KernelTypesStreamKBf8PersistentTreeCompV3",
"class_name": "TestCkTileStreamKBf8PersistentTreeCompV3",
"file_tag": "bf8_persistent_tree_compv3",
}
]
self.validate_entries(entries, expected)
class TestOutputPath(unittest.TestCase):
def test_output_path(self):
"""Test that output_path generates the correct file path."""
entry = {"file_tag": "fp16_persistent_atomic"}
output_dir = "/some/output/dir"
expected = "/some/output/dir/test_gemm_streamk_fp16_persistent_atomic.cpp"
self.assertEqual(output_path(output_dir, entry), expected)
if __name__ == "__main__":
unittest.main()