mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
[rocm-libraries] ROCm/rocm-libraries#5722 (commit 55febd2)
[CK Tile] Stream-K gtest Code Gen
## Motivation
Stream-K was using the tile engine infrastructure for smoke tests.
However, tile engine creates a different target per kernel instance,
which has resulted in scalability issues when used in the context of
unit tests. To avoid burdens on cmake configuration and build time, we
have opted to remove our Stream-K tile engine tests. Instead, we use
pure gtests with code gen to generate repetitive .cpp files.
**Note: This appears to change a lot of files because many files are
removed since they are now generated at build time.**
## Technical Details
We originally used Tile Engine to facilitate code gen for unit tests
since we found that pure gtests required the addition of many repetitive
.cpp files of the following form:
```cpp
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf8 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf8
TYPED_TEST_SUITE(TestCkTileStreamKBf8, KernelTypesStreamKBf8);
#include "test_gemm_streamk_atomic_cases.inc"
#undef TEST_SUITE_NAME
```
Due to issues encountered with tile engine, we instead use pure gtests
and a code generator that emits the repetitive .cpp files. The code generator parses
`KernelTypesStreamK*` type aliases from the types header using a
two-phase approach:
1. At **configure time**, CMake runs the Python script with
`--list_files` to extract the type alias names from the header
(test_gemm_streamk_types.hpp) and compute the list of .cpp file paths
that will be generated. This lets CMake know the exact set of source
files for each target.
2. At **build time**, `add_custom_command` runs the script again with
`--gen_files` to actually emit the .cpp files into the build directory,
triggered only when the types header or generator script changes.
With these changes, we've removed all Stream-K tile engine tests. There
are now 5 targets for Stream-K GEMM tests:
1. test_ck_tile_streamk_atomic_smoke: smoke tests for Atomic reduction
strategy (pipeline: compv3)
2. test_ck_tile_streamk_linear_smoke: smoke tests for Linear reduction
strategy (pipeline: compv3)
3. test_ck_tile_streamk_tree_smoke: smoke tests for Tree reduction
strategy (pipeline: compv3)
4. test_ck_tile_streamk_pipelines_smoke: smoke tests (smaller set) for
pipelines other than compv3
- Since Stream-K can be thought of as a wrapper around universal GEMM,
we don't need to extensively test each pipeline. So, we opt to run a few
tests for different pipelines. Currently, this just consists of the mem
pipeline, but compv4 is coming soon.
5. test_ck_tile_streamk_extended: extended tests
## Test Plan
I have tested the gtests locally on gfx90a, gfx942, and gfx950.
## Test Result
All local tests pass.
## Submission Checklist
- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
6d77edc3bd
commit
7cc9bae9d2
@@ -19,55 +19,93 @@ set(EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4
|
||||
if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")
|
||||
|
||||
include_directories(BEFORE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
|
||||
|
||||
#TODO: support all arches
|
||||
#TODO: current c-shuffle only supports C layout as R
|
||||
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
|
||||
set(STREAMK_EXTENDED_SOURCES
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_mem.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_mem.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_mem.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_mem.cpp
|
||||
test_gemm_streamk_util.cpp)
|
||||
|
||||
# We only test fp8 and bf8 on gfx942 and gfx950 since these types are not natively supported on gfx90a
|
||||
if(GPU_TARGETS MATCHES "gfx942|gfx950")
|
||||
list(APPEND STREAMK_EXTENDED_SOURCES
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_mem.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_mem.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_mem.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_mem.cpp)
|
||||
endif()
|
||||
# ---- Code-generate test .cpp files from types header ----
|
||||
set(STREAMK_TYPES_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/test_gemm_streamk_types.hpp)
|
||||
set(STREAMK_GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/generate_test_files.py)
|
||||
|
||||
add_gtest_executable(test_ck_tile_streamk_extended ${STREAMK_EXTENDED_SOURCES})
|
||||
target_compile_options(test_ck_tile_streamk_extended PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
# Re-run configure automatically if the types header changes (e.g. types added/removed)
|
||||
# or if the generation script changes
|
||||
set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS ${STREAMK_TYPES_HEADER} ${STREAMK_GEN_SCRIPT})
|
||||
|
||||
# Define the targets and their corresponding executable names
|
||||
set(STREAMK_GEN_TARGETS extended atomic_smoke linear_smoke tree_smoke pipelines_smoke)
|
||||
set(STREAMK_GEN_EXEC_EXTENDED test_ck_tile_streamk_extended)
|
||||
set(STREAMK_GEN_EXEC_ATOMIC_SMOKE test_ck_tile_streamk_atomic_smoke)
|
||||
set(STREAMK_GEN_EXEC_LINEAR_SMOKE test_ck_tile_streamk_linear_smoke)
|
||||
set(STREAMK_GEN_EXEC_TREE_SMOKE test_ck_tile_streamk_tree_smoke)
|
||||
set(STREAMK_GEN_EXEC_PIPELINES_SMOKE test_ck_tile_streamk_pipelines_smoke)
|
||||
|
||||
# Collect all test targets for umbrella label
|
||||
set(CK_TILE_GEMM_STREAMK_TEST_TARGETS
|
||||
test_ck_tile_streamk_tile_partitioner
|
||||
test_ck_tile_streamk_extended
|
||||
test_ck_tile_streamk_tile_partitioner)
|
||||
|
||||
foreach(target IN LISTS STREAMK_GEN_TARGETS)
|
||||
string(TOUPPER ${target} TARGET_UPPER)
|
||||
set(GEN_DIR ${CMAKE_CURRENT_BINARY_DIR}/${target})
|
||||
set(EXEC_NAME ${STREAMK_GEN_EXEC_${TARGET_UPPER}})
|
||||
set(LIST_FILE ${CMAKE_CURRENT_BINARY_DIR}/${target}_files.txt)
|
||||
|
||||
# Phase 1 (configure time): discover the list of files that will be generated
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${STREAMK_GEN_SCRIPT}
|
||||
--types_header ${STREAMK_TYPES_HEADER}
|
||||
--output_dir ${GEN_DIR}
|
||||
--target ${target}
|
||||
--list_files ${LIST_FILE}
|
||||
RESULT_VARIABLE ret
|
||||
ERROR_VARIABLE list_files_stderr)
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message(FATAL_ERROR
|
||||
"Failed to list ${target} test files via Python: ${ret}\n"
|
||||
"stderr: ${list_files_stderr}"
|
||||
)
|
||||
endif()
|
||||
file(STRINGS ${LIST_FILE} ALL_SOURCES_${target})
|
||||
|
||||
# Phase 2 (build time): generate the .cpp files when the types header or the generator script changes
|
||||
add_custom_command(
|
||||
OUTPUT ${ALL_SOURCES_${target}}
|
||||
COMMAND ${Python3_EXECUTABLE} ${STREAMK_GEN_SCRIPT}
|
||||
--types_header ${STREAMK_TYPES_HEADER}
|
||||
--output_dir ${GEN_DIR}
|
||||
--target ${target}
|
||||
--gen_files
|
||||
DEPENDS ${STREAMK_TYPES_HEADER} ${STREAMK_GEN_SCRIPT}
|
||||
COMMENT "Generating StreamK ${target} test sources from types header")
|
||||
|
||||
# Filter out fp8/bf8 sources on gfx90a since those types are not natively supported
|
||||
set(FILTERED_SOURCES)
|
||||
foreach(src IN LISTS ALL_SOURCES_${target})
|
||||
if(NOT src MATCHES "_(fp8|bf8)_" OR GPU_TARGETS MATCHES "gfx942|gfx950")
|
||||
list(APPEND FILTERED_SOURCES ${src})
|
||||
endif()
|
||||
endforeach()
|
||||
list(APPEND FILTERED_SOURCES test_gemm_streamk_util.cpp)
|
||||
|
||||
add_gtest_executable(${EXEC_NAME} ${FILTERED_SOURCES})
|
||||
target_compile_options(${EXEC_NAME} PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
list(APPEND CK_TILE_GEMM_STREAMK_TEST_TARGETS ${EXEC_NAME})
|
||||
endforeach()
|
||||
|
||||
# Add python unit tests to validate the code gen logic in generate_test_files.py
|
||||
add_test(
|
||||
NAME test_ck_tile_streamk_generate_test_files
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/test_generate_test_files.py -v
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..
|
||||
)
|
||||
|
||||
# Label all ck_tile gemm_streamk tests with CK_TILE_GEMM_STREAMK_TESTS for selective execution
|
||||
foreach(test_target ${CK_TILE_GEMM_STREAMK_TEST_TARGETS})
|
||||
set_tests_properties(${test_target} PROPERTIES LABELS "CK_TILE_GEMM_STREAMK_TESTS")
|
||||
endforeach()
|
||||
# Also label the Python test
|
||||
set_tests_properties(test_ck_tile_streamk_generate_test_files PROPERTIES LABELS "CK_TILE_GEMM_STREAMK_TESTS")
|
||||
|
||||
# Umbrella target to build and run all ck_tile gemm_streamk tests
|
||||
# Usage: ninja ck_tile_gemm_streamk_tests
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf16NonPersistentCompV3 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistentCompV3
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistentCompV3,
|
||||
KernelTypesStreamKBf16NonPersistentCompV3);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,18 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf16NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistentCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistentCompV4,
|
||||
KernelTypesStreamKBf16NonPersistentCompV4);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf16NonPersistentMem : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistentMem
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistentMem, KernelTypesStreamKBf16NonPersistentMem);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf16PersistentCompV3 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf16PersistentCompV3
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf16PersistentCompV3, KernelTypesStreamKBf16PersistentCompV3);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf16PersistentCompV4 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf16PersistentCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf16PersistentCompV4, KernelTypesStreamKBf16PersistentCompV4);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf16PersistentMem : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf16PersistentMem
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf16PersistentMem, KernelTypesStreamKBf16PersistentMem);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf8NonPersistentCompV3 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistentCompV3
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistentCompV3, KernelTypesStreamKBf8NonPersistentCompV3);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf8NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistentCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistentCompV4, KernelTypesStreamKBf8NonPersistentCompV4);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf8NonPersistentMem : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistentMem
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistentMem, KernelTypesStreamKBf8NonPersistentMem);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf8PersistentCompV3 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf8PersistentCompV3
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf8PersistentCompV3, KernelTypesStreamKBf8PersistentCompV3);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf8PersistentCompV4 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf8PersistentCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf8PersistentCompV4, KernelTypesStreamKBf8PersistentCompV4);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf8PersistentMem : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf8PersistentMem
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf8PersistentMem, KernelTypesStreamKBf8PersistentMem);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,18 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp16NonPersistentCompV3 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp16NonPersistentCompV3
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp16NonPersistentCompV3,
|
||||
KernelTypesStreamKFp16NonPersistentCompV3);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,18 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp16NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp16NonPersistentCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp16NonPersistentCompV4,
|
||||
KernelTypesStreamKFp16NonPersistentCompV4);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp16NonPersistentMem : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp16NonPersistentMem
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp16NonPersistentMem, KernelTypesStreamKFp16NonPersistentMem);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp16PersistentCompV3 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp16PersistentCompV3
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp16PersistentCompV3, KernelTypesStreamKFp16PersistentCompV3);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp16PersistentCompV4 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp16PersistentCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp16PersistentCompV4, KernelTypesStreamKFp16PersistentCompV4);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp16PersistentMem : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp16PersistentMem
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp16PersistentMem, KernelTypesStreamKFp16PersistentMem);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp8NonPersistentCompV3 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp8NonPersistentCompV3
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp8NonPersistentCompV3, KernelTypesStreamKFp8NonPersistentCompV3);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp8NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp8NonPersistentCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp8NonPersistentCompV4, KernelTypesStreamKFp8NonPersistentCompV4);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp8NonPersistentMem : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp8NonPersistentMem
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp8NonPersistentMem, KernelTypesStreamKFp8NonPersistentMem);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp8PersistentCompV3 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp8PersistentCompV3
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp8PersistentCompV3, KernelTypesStreamKFp8PersistentCompV3);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp8PersistentCompV4 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp8PersistentCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp8PersistentCompV4, KernelTypesStreamKFp8PersistentCompV4);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp8PersistentMem : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp8PersistentMem
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp8PersistentMem, KernelTypesStreamKFp8PersistentMem);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
215
test/ck_tile/gemm_streamk/generate_test_files.py
Normal file
215
test/ck_tile/gemm_streamk/generate_test_files.py
Normal file
@@ -0,0 +1,215 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""Generate test .cpp files from KernelTypes definitions in
|
||||
test_gemm_streamk_types.hpp.
|
||||
|
||||
Two modes:
|
||||
--list_files FILE Write the list of output file paths to FILE (one per line)
|
||||
without generating the files. Used at CMake configure time.
|
||||
--gen_files Actually emit the .cpp files into --output_dir.
|
||||
Used at build time via add_custom_command.
|
||||
|
||||
Target selection (--target):
|
||||
extended Kernel types containing 'Atomic' or 'Pipelines'
|
||||
-> includes test_gemm_streamk_extended_cases.inc
|
||||
atomic_smoke Kernel types containing 'Atomic' (not 'Pipelines')
|
||||
-> includes test_gemm_streamk_atomic_cases.inc
|
||||
linear_smoke Kernel types containing 'Linear' (not 'Pipelines')
|
||||
-> includes test_gemm_streamk_reduction_cases.inc
|
||||
tree_smoke Kernel types containing 'Tree' (not 'Pipelines')
|
||||
-> includes test_gemm_streamk_reduction_cases.inc
|
||||
pipelines_smoke Kernel types matching 'Pipelines'
|
||||
-> includes test_gemm_streamk_reduction_cases.inc
|
||||
and test_gemm_streamk_atomic_cases.inc
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Template for every generated .cpp file
|
||||
# --------------------------------------------------------------------------- #
|
||||
CPP_TEMPLATE = """\
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class {class_name} : public TestCkTileStreamK<Tuple>
|
||||
{{
|
||||
}};
|
||||
|
||||
#define TEST_SUITE_NAME {class_name}
|
||||
|
||||
TYPED_TEST_SUITE({class_name}, {type_alias});
|
||||
|
||||
{inc_includes}
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
"""
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Target definitions: filter predicate and .inc files
|
||||
# --------------------------------------------------------------------------- #
|
||||
def _mentions(fragment):
    """Build a suffix predicate that is true when `fragment` occurs in the
    suffix and the suffix is not the bare 'Pipelines' alias."""
    return lambda suffix: fragment in suffix and suffix != "Pipelines"


def _is_pipelines(suffix):
    """True only for the suffix that is exactly 'Pipelines'."""
    return suffix == "Pipelines"


# One entry per --target value. Each entry holds:
#   "filter"    - predicate applied to the suffix that follows
#                 'KernelTypesStreamK' in a type alias name
#   "inc_files" - the .inc files that every generated .cpp for the target includes
# Insertion order is significant: it is the order argparse lists the choices in.
TARGETS = dict(
    extended=dict(
        filter=lambda suffix: _is_pipelines(suffix) or "Atomic" in suffix,
        inc_files=["test_gemm_streamk_extended_cases.inc"],
    ),
    atomic_smoke=dict(
        filter=_mentions("Atomic"),
        inc_files=["test_gemm_streamk_atomic_cases.inc"],
    ),
    linear_smoke=dict(
        filter=_mentions("Linear"),
        inc_files=["test_gemm_streamk_reduction_cases.inc"],
    ),
    tree_smoke=dict(
        filter=_mentions("Tree"),
        inc_files=["test_gemm_streamk_reduction_cases.inc"],
    ),
    pipelines_smoke=dict(
        filter=_is_pipelines,
        inc_files=[
            "test_gemm_streamk_reduction_cases.inc",
            "test_gemm_streamk_atomic_cases.inc",
        ],
    ),
)
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Mapping from CamelCase suffix fragments to file-name fragments
|
||||
# --------------------------------------------------------------------------- #
|
||||
# (CamelCase fragment, file-name fragment) pairs, tried in this order.
# "NonPersistent" is listed ahead of "Persistent" and matching is done with
# startswith, so the longer spelling is always tried first.
KNOWN_TOKENS = [
    ("Fp16", "fp16"),
    ("Bf16", "bf16"),
    ("Fp8", "fp8"),
    ("Bf8", "bf8"),
    ("NonPersistent", "nonpersistent"),
    ("Persistent", "persistent"),
    ("Atomic", "atomic"),
    ("Linear", "linear"),
    ("Tree", "tree"),
    ("CompV3", "compv3"),
    ("Pipelines", "pipelines"),
]


def suffix_to_file_tag(suffix: str) -> str:
    """Translate a CamelCase suffix into an underscore-joined lowercase tag.

    Example: 'Fp16PersistentAtomicCompV3' -> 'fp16_persistent_atomic_compv3'.

    The suffix is consumed left to right, one KNOWN_TOKENS entry at a time.
    An empty suffix yields an empty tag.

    Raises:
        ValueError: the unconsumed remainder does not start with any known token.
    """
    pieces: list[str] = []
    rest = suffix
    while rest:
        # First table entry whose CamelCase spelling prefixes the remainder.
        hit = next(
            ((camel, lowered) for camel, lowered in KNOWN_TOKENS if rest.startswith(camel)),
            None,
        )
        if hit is None:
            raise ValueError(
                f"Unrecognised token in KernelTypes suffix: '{rest}' "
                f"(from '{suffix}')"
            )
        camel, lowered = hit
        pieces.append(lowered)
        rest = rest[len(camel):]
    return "_".join(pieces)
|
||||
|
||||
|
||||
def parse_types_header(header_path: str, target: str) -> list[dict]:
    """Collect the KernelTypesStreamK* aliases in a header that belong to `target`.

    The header is scanned line by line for `using KernelTypesStreamK<Suffix> =`.
    Text after a `//` on a line is ignored, so a commented-out alias does not
    produce a generated test file that refers to a type which no longer exists.

    Args:
        header_path: Path of the C++ types header to scan.
        target: Key into TARGETS; that entry's "filter" predicate decides which
            alias suffixes are kept.

    Returns:
        One dict per kept alias, in the order the aliases appear in the header,
        with keys: type_alias, class_name, file_tag.

    Raises:
        KeyError: `target` is not a key of TARGETS.
        ValueError: raised by suffix_to_file_tag for a kept alias whose suffix
            contains an unknown token.
        OSError: the header cannot be opened.
    """
    accepts = TARGETS[target]["filter"]
    # Matches lines like: using KernelTypesStreamKFp16PersistentAtomicCompV3 = ...
    #   group(1) -> KernelTypesStreamKFp16PersistentAtomicCompV3 (the type alias)
    #   group(2) -> Fp16PersistentAtomicCompV3                   (the suffix)
    pattern = re.compile(r"using\s+(KernelTypesStreamK(\w+))\s*=")
    entries: list[dict] = []
    # Read as UTF-8 explicitly so the result does not depend on the locale of
    # the machine running the configure or build step.
    with open(header_path, encoding="utf-8") as f:
        for line in f:
            # Only look at the code part of the line, not a trailing // comment.
            code = line.split("//", 1)[0]
            match = pattern.search(code)
            if match is None:
                continue
            type_alias = match.group(1)
            suffix = match.group(2)
            if not accepts(suffix):
                continue
            entries.append(
                {
                    "type_alias": type_alias,
                    "class_name": f"TestCkTileStreamK{suffix}",
                    "file_tag": suffix_to_file_tag(suffix),
                }
            )
    return entries
|
||||
|
||||
|
||||
def output_path(output_dir: str, entry: dict) -> str:
    """Return where the generated .cpp for `entry` lives inside `output_dir`.

    The file name is 'test_gemm_streamk_<file_tag>.cpp', with <file_tag> taken
    from entry['file_tag'].
    """
    file_name = "test_gemm_streamk_{}.cpp".format(entry["file_tag"])
    return os.path.join(output_dir, file_name)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Parse the command line of the generator script.

    Three options are always required: --types_header, --output_dir and
    --target (restricted to the keys of TARGETS). In addition exactly one of
    the two modes must be chosen: --list_files FILE or --gen_files.
    """
    cli = argparse.ArgumentParser(description=__doc__)

    # Registered in this order so the generated help text keeps its layout.
    required_options = (
        ("--types_header", {"help": "Path to test_gemm_streamk_types.hpp"}),
        ("--output_dir", {"help": "Directory for generated .cpp files"}),
        (
            "--target",
            {
                "choices": list(TARGETS.keys()),
                "help": "Which target to generate files for",
            },
        ),
    )
    for flag, extra in required_options:
        cli.add_argument(flag, required=True, **extra)

    # The two operating modes cannot be combined, and one of them must be given.
    mode = cli.add_mutually_exclusive_group(required=True)
    mode.add_argument(
        "--list_files",
        metavar="FILE",
        help="Write output file paths to FILE then exit",
    )
    mode.add_argument(
        "--gen_files",
        action="store_true",
        help="Generate the .cpp files",
    )

    return cli.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
    """Entry point: list or generate the per-type-alias .cpp files.

    The types header is parsed for the selected target first. If no matching
    alias is found an error is printed to stderr and the process exits with
    status 1. Otherwise:

    * with --list_files FILE, the path of every file that would be generated
      is written to FILE, one per line (FILE's directory is created if needed);
    * with --gen_files, the output directory is created if needed and one
      .cpp file per alias is rendered from CPP_TEMPLATE.
    """
    args = parse_args()

    entries = parse_types_header(args.types_header, args.target)
    if not entries:
        message = (
            f"ERROR: no KernelTypesStreamK* definitions found for target "
            f"'{args.target}' in {args.types_header}"
        )
        print(message, file=sys.stderr)
        sys.exit(1)

    # One #include line per .inc file of the target, joined into a single
    # string. This is built before branching on the mode.
    include_lines = [f'#include "{f}"' for f in TARGETS[args.target]["inc_files"]]
    inc_includes = "\n".join(include_lines)

    if args.list_files:
        # A bare file name has an empty dirname; fall back to the current directory.
        list_dir = os.path.dirname(args.list_files) or "."
        os.makedirs(list_dir, exist_ok=True)
        with open(args.list_files, "w") as listing:
            listing.writelines(
                output_path(args.output_dir, entry) + "\n" for entry in entries
            )
        return

    os.makedirs(args.output_dir, exist_ok=True)
    for entry in entries:
        rendered = CPP_TEMPLATE.format(
            class_name=entry["class_name"],
            type_alias=entry["type_alias"],
            inc_includes=inc_includes,
        )
        with open(output_path(args.output_dir, entry), "w") as cpp_file:
            cpp_file.write(rendered)


if __name__ == "__main__":
    main()
|
||||
47
test/ck_tile/gemm_streamk/test_gemm_streamk_atomic_cases.inc
Normal file
47
test/ck_tile/gemm_streamk/test_gemm_streamk_atomic_cases.inc
Normal file
@@ -0,0 +1,47 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
// Runs the fixture on a single fixed-size problem: M = N = K = 256. The sizes do not depend on the
// macro-tile sizes of the type parameter or on the CU count.
TYPED_TEST(TEST_SUITE_NAME, StreamK_EdgeCase)
{
    using index_t = ck_tile::index_t;

    index_t M = 256;
    index_t N = 256;
    index_t K = 256;

    this->Run(M, N, K);
}
|
||||
|
||||
// Data-parallel-only case: the problem is sized so that the number of macro tiles in C is a
// multiple of the number of CUs.
TYPED_TEST(TEST_SUITE_NAME, StreamK_DPOnly)
{
    using index_t = ck_tile::index_t;

    // Macro-tile sizes are read from elements 7, 8 and 9 of the type-parameter tuple.
    constexpr index_t tile_m = std::tuple_element_t<7, TypeParam>::value;
    constexpr index_t tile_n = std::tuple_element_t<8, TypeParam>::value;
    constexpr index_t tile_k = std::tuple_element_t<9, TypeParam>::value;

    const index_t cu_count = get_cu_count();

    // num_cu tiles along M and one tile along N give num_cu macro tiles in C, with a single
    // macro tile along K. This assumes tile sizes are large enough such that occupancy is 1.
    index_t M = tile_m * cu_count;
    index_t N = tile_n;
    index_t K = tile_k;

    this->Run(M, N, K);
}
|
||||
|
||||
// Stream-K-only case: few macro tiles in C, with a long K dimension.
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly)
{
    using index_t = ck_tile::index_t;

    // Macro-tile sizes are read from elements 7, 8 and 9 of the type-parameter tuple.
    constexpr index_t tile_m = std::tuple_element_t<7, TypeParam>::value;
    constexpr index_t tile_n = std::tuple_element_t<8, TypeParam>::value;
    constexpr index_t tile_k = std::tuple_element_t<9, TypeParam>::value;

    const index_t cu_count = get_cu_count();

    // C holds 2 x 2 = 4 macro tiles. To avoid falling into the edge case there must be enough
    // work along K, so K always spans num_cu macro tiles. This assumes tile sizes are large
    // enough such that occupancy is 1.
    index_t M = tile_m * 2;
    index_t N = tile_n * 2;
    index_t K = tile_k * cu_count;

    this->Run(M, N, K);
}
|
||||
@@ -0,0 +1,46 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
// Stream-K-only case with a single macro tile in C and num_cu macro tiles along K.
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_OneTile)
{
    using index_t = ck_tile::index_t;

    // Macro-tile sizes are read from elements 7, 8 and 9 of the type-parameter tuple.
    constexpr index_t tile_m = std::tuple_element_t<7, TypeParam>::value;
    constexpr index_t tile_n = std::tuple_element_t<8, TypeParam>::value;
    constexpr index_t tile_k = std::tuple_element_t<9, TypeParam>::value;

    const index_t cu_count = get_cu_count();

    index_t M = tile_m;
    index_t N = tile_n;
    index_t K = tile_k * cu_count;

    this->Run(M, N, K);
}
|
||||
|
||||
// Stream-K-only case with 4 macro tiles in C (4 along M, 1 along N) and (num_cu + 25) macro
// tiles along K.
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_4Tiles_Reduction)
{
    using index_t = ck_tile::index_t;

    // Macro-tile sizes are read from elements 7, 8 and 9 of the type-parameter tuple.
    constexpr index_t tile_m = std::tuple_element_t<7, TypeParam>::value;
    constexpr index_t tile_n = std::tuple_element_t<8, TypeParam>::value;
    constexpr index_t tile_k = std::tuple_element_t<9, TypeParam>::value;

    const index_t cu_count = get_cu_count();

    index_t M = tile_m * 4;
    index_t N = tile_n;
    index_t K = tile_k * cu_count + (25 * tile_k);

    this->Run(M, N, K);
}
|
||||
|
||||
// Stream-K-only case with 3 x 7 = 21 macro tiles in C and (num_cu + 30) macro tiles along K.
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_21Tiles)
{
    using index_t = ck_tile::index_t;

    // Macro-tile sizes are read from elements 7, 8 and 9 of the type-parameter tuple.
    constexpr index_t tile_m = std::tuple_element_t<7, TypeParam>::value;
    constexpr index_t tile_n = std::tuple_element_t<8, TypeParam>::value;
    constexpr index_t tile_k = std::tuple_element_t<9, TypeParam>::value;

    const index_t cu_count = get_cu_count();

    index_t M = tile_m * 3;
    index_t N = tile_n * 7;
    index_t K = tile_k * cu_count + (30 * tile_k);

    this->Run(M, N, K);
}
|
||||
@@ -14,16 +14,28 @@ using BF16 = ck_tile::bf16_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using F32 = float;
|
||||
|
||||
// Layouts
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
// Persistence
|
||||
using Persistent = std::true_type;
|
||||
using NonPersistent = std::false_type;
|
||||
|
||||
// Pipelines
|
||||
using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
|
||||
using CompV3 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV3>;
|
||||
using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV4>;
|
||||
|
||||
using Persistent = std::true_type;
|
||||
using NonPersistent = std::false_type;
|
||||
// Reduction Strategies
|
||||
using Atomic = ck_tile::integral_constant<ck_tile::StreamKReductionStrategy,
|
||||
ck_tile::StreamKReductionStrategy::Atomic>;
|
||||
using Linear = ck_tile::integral_constant<ck_tile::StreamKReductionStrategy,
|
||||
ck_tile::StreamKReductionStrategy::Linear>;
|
||||
using Tree = ck_tile::integral_constant<ck_tile::StreamKReductionStrategy,
|
||||
ck_tile::StreamKReductionStrategy::Tree>;
|
||||
|
||||
using I16 = ck_tile::number<16>;
|
||||
using I32 = ck_tile::number<32>;
|
||||
using I128 = ck_tile::number<128>;
|
||||
using I256 = ck_tile::number<256>;
|
||||
@@ -32,180 +44,157 @@ using I256 = ck_tile::number<256>;
|
||||
|
||||
// ========================== CompV3 Pipeline ==========================
|
||||
|
||||
using KernelTypesStreamKFp16PersistentCompV3 = ::testing::Types<
|
||||
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent Pipeline
|
||||
// Atomics
|
||||
using KernelTypesStreamKFp16PersistentAtomicCompV3 = ::testing::Types<
|
||||
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile M_WaveTile N_WaveTile K_WaveTile Persistent Pipeline ReductionStrategy
|
||||
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV3>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV3>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV3>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV3>
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Atomic>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Atomic>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Atomic>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Atomic>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf16PersistentCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV3>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV3>,
|
||||
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV3>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV3>
|
||||
using KernelTypesStreamKBf16PersistentAtomicCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Atomic>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf8PersistentCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV3>,
|
||||
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV3>,
|
||||
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV3>,
|
||||
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV3>
|
||||
using KernelTypesStreamKBf8PersistentAtomicCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Atomic>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKFp8PersistentCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV3>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV3>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV3>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV3>
|
||||
using KernelTypesStreamKFp8PersistentAtomicCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Atomic>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Atomic>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Atomic>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Atomic>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKFp16NonPersistentCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV3>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV3>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV3>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV3>
|
||||
using KernelTypesStreamKFp16NonPersistentAtomicCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf16NonPersistentCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV3>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV3>,
|
||||
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV3>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV3>
|
||||
using KernelTypesStreamKBf16NonPersistentAtomicCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf8NonPersistentCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV3>,
|
||||
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV3>,
|
||||
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV3>,
|
||||
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV3>
|
||||
using KernelTypesStreamKBf8NonPersistentAtomicCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKFp8NonPersistentCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3>
|
||||
using KernelTypesStreamKFp8NonPersistentAtomicCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>
|
||||
>;
|
||||
|
||||
// ========================== CompV4 Pipeline ==========================
|
||||
// Linear
|
||||
using KernelTypesStreamKFp16PersistentLinearCompV3 = ::testing::Types<
|
||||
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile M_WaveTile N_WaveTile K_WaveTile Persistent Pipeline ReductionStrategy
|
||||
|
||||
using KernelTypesStreamKFp16PersistentCompV4 = ::testing::Types<
|
||||
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent Pipeline
|
||||
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Linear>,
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I16, I16, I16, Persistent, CompV3, Linear>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Linear>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Linear>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Linear>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf16PersistentCompV4 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>,
|
||||
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>
|
||||
using KernelTypesStreamKBf16PersistentLinearCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Linear>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf8PersistentCompV4 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>,
|
||||
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>,
|
||||
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>,
|
||||
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>
|
||||
using KernelTypesStreamKBf8PersistentLinearCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Linear>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKFp8PersistentCompV4 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>
|
||||
using KernelTypesStreamKFp8PersistentLinearCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Linear>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Linear>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Linear>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Linear>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKFp16NonPersistentCompV4 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>
|
||||
using KernelTypesStreamKFp16NonPersistentLinearCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Linear>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Linear>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Linear>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Linear>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf16NonPersistentCompV4 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>
|
||||
using KernelTypesStreamKBf16NonPersistentLinearCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Linear>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf8NonPersistentCompV4 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>
|
||||
using KernelTypesStreamKBf8NonPersistentLinearCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Linear>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKFp8NonPersistentCompV4 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>
|
||||
using KernelTypesStreamKFp8NonPersistentLinearCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Linear>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Linear>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Linear>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Linear>
|
||||
>;
|
||||
|
||||
// ============================= Mem Pipeline =============================
|
||||
// Tree
|
||||
using KernelTypesStreamKFp16PersistentTreeCompV3 = ::testing::Types<
|
||||
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile M_WaveTile N_WaveTile K_WaveTile Persistent Pipeline ReductionStrategy
|
||||
|
||||
using KernelTypesStreamKFp16PersistentMem = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, Mem>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, Mem>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, Mem>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, Mem>
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Tree>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Tree>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I16, I16, I16, Persistent, CompV3, Tree>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Tree>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Tree>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf16PersistentMem = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, Mem>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, Mem>,
|
||||
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, Mem>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, Mem>
|
||||
using KernelTypesStreamKBf16PersistentTreeCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Tree>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf8PersistentMem = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, Mem>,
|
||||
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, Mem>,
|
||||
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, Mem>,
|
||||
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, Mem>
|
||||
using KernelTypesStreamKBf8PersistentTreeCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Tree>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKFp8PersistentMem = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, Mem>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, Mem>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, Mem>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, Mem>
|
||||
using KernelTypesStreamKFp8PersistentTreeCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Tree>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Tree>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Tree>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Tree>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKFp16NonPersistentMem = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, Mem>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, Mem>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, Mem>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, Mem>
|
||||
using KernelTypesStreamKFp16NonPersistentTreeCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Tree>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Tree>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Tree>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Tree>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf16NonPersistentMem = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, Mem>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, Mem>,
|
||||
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, Mem>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, Mem>
|
||||
using KernelTypesStreamKBf16NonPersistentTreeCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Tree>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf8NonPersistentMem = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, Mem>,
|
||||
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, Mem>,
|
||||
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, Mem>,
|
||||
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, Mem>
|
||||
using KernelTypesStreamKBf8NonPersistentTreeCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Tree>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKFp8NonPersistentMem = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, Mem>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, Mem>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, Mem>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, Mem>
|
||||
using KernelTypesStreamKFp8NonPersistentTreeCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Tree>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Tree>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Tree>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Tree>
|
||||
>;
|
||||
|
||||
// ============================= Other Pipelines =============================
|
||||
|
||||
using KernelTypesStreamKPipelines = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, Mem, Atomic>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, Mem, Tree>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, Mem, Linear>,
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV4, Atomic>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV4, Tree>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV4, Linear>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
@@ -71,23 +71,27 @@ template <typename Tuple>
|
||||
class TestCkTileStreamK : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = std::tuple_element_t<2, Tuple>;
|
||||
using ADataType = std::tuple_element_t<3, Tuple>;
|
||||
using BDataType = std::tuple_element_t<4, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<5, Tuple>;
|
||||
using CDataType = std::tuple_element_t<6, Tuple>;
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using DsDataType = ck_tile::tuple<>;
|
||||
static constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, Tuple>::value;
|
||||
static constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, Tuple>::value;
|
||||
static constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, Tuple>::value;
|
||||
static constexpr bool Persistent = std::tuple_element_t<10, Tuple>::value;
|
||||
static constexpr auto PipelineType = std::tuple_element_t<11, Tuple>::value;
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = std::tuple_element_t<2, Tuple>;
|
||||
using ADataType = std::tuple_element_t<3, Tuple>;
|
||||
using BDataType = std::tuple_element_t<4, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<5, Tuple>;
|
||||
using CDataType = std::tuple_element_t<6, Tuple>;
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using DsDataType = ck_tile::tuple<>;
|
||||
static constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, Tuple>::value;
|
||||
static constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, Tuple>::value;
|
||||
static constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, Tuple>::value;
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = std::tuple_element_t<10, Tuple>::value;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = std::tuple_element_t<11, Tuple>::value;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = std::tuple_element_t<12, Tuple>::value;
|
||||
|
||||
template <ck_tile::StreamKReductionStrategy ReductionStrategy,
|
||||
bool PadM = true,
|
||||
static constexpr bool Persistent = std::tuple_element_t<13, Tuple>::value;
|
||||
static constexpr auto PipelineType = std::tuple_element_t<14, Tuple>::value;
|
||||
static constexpr auto ReductionStrategy = std::tuple_element_t<15, Tuple>::value;
|
||||
|
||||
template <bool PadM = true,
|
||||
bool PadN = true,
|
||||
bool PadK = true,
|
||||
bool Preshuffle = false,
|
||||
@@ -99,10 +103,6 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr bool kPadM = PadM;
|
||||
constexpr bool kPadN = PadN;
|
||||
constexpr bool kPadK = PadK;
|
||||
@@ -269,8 +269,7 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
stride_C};
|
||||
|
||||
ck_tile::index_t num_accumulations_per_tile =
|
||||
invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
|
||||
args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
|
||||
invoke_streamk<>(args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
|
||||
220
test/ck_tile/gemm_streamk/test_generate_test_files.py
Normal file
220
test/ck_tile/gemm_streamk/test_generate_test_files.py
Normal file
@@ -0,0 +1,220 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import unittest
|
||||
from unittest.mock import mock_open, patch
|
||||
from generate_test_files import suffix_to_file_tag, parse_types_header, output_path
|
||||
|
||||
# ------------------------------------------------------------ #
|
||||
# Unit tests for helper functions in generate_test_files.py
|
||||
# ------------------------------------------------------------ #
|
||||
|
||||
|
||||
class TestSuffixToFileTag(unittest.TestCase):
    """Exercises ``suffix_to_file_tag`` on single tokens, token runs, and bad input.

    The cases below expect each recognised CamelCase token to come back in
    lower case, several tokens to be joined with underscores, and any
    unrecognised token to raise ``ValueError``.
    """

    def _expect_tag(self, suffix, tag):
        """Assert that ``suffix`` is translated to exactly ``tag``."""
        self.assertEqual(suffix_to_file_tag(suffix), tag)

    def _expect_rejected(self, suffix):
        """Assert that ``suffix`` is refused with a ``ValueError``."""
        with self.assertRaises(ValueError):
            suffix_to_file_tag(suffix)

    # --- data type tokens ---

    def test_fp16_token(self):
        self._expect_tag("Fp16", "fp16")

    def test_bf16_token(self):
        self._expect_tag("Bf16", "bf16")

    def test_fp8_token(self):
        self._expect_tag("Fp8", "fp8")

    def test_bf8_token(self):
        self._expect_tag("Bf8", "bf8")

    # --- persistence tokens ---

    def test_nonpersistent_token(self):
        self._expect_tag("NonPersistent", "nonpersistent")

    def test_persistent_token(self):
        self._expect_tag("Persistent", "persistent")

    # --- reduction strategy tokens ---

    def test_atomic_token(self):
        self._expect_tag("Atomic", "atomic")

    def test_linear_token(self):
        self._expect_tag("Linear", "linear")

    def test_tree_token(self):
        self._expect_tag("Tree", "tree")

    # --- pipeline tokens ---

    def test_compv3_token(self):
        self._expect_tag("CompV3", "compv3")

    def test_pipelines_token(self):
        self._expect_tag("Pipelines", "pipelines")

    # --- invalid and combined input ---

    def test_unknown_token(self):
        self._expect_rejected("unknown")

    def test_multiple_valid_tokens(self):
        self._expect_tag(
            "Fp16PersistentAtomicCompV3", "fp16_persistent_atomic_compv3"
        )

    def test_multiple_tokens_with_unknown(self):
        # One unrecognised token in the middle is expected to fail the whole suffix.
        self._expect_rejected("Fp16PersistentUnknownCompV3")
|
||||
|
||||
|
||||
class TestParseTypesHeader(unittest.TestCase):
    """Exercises ``parse_types_header`` against in-memory header contents.

    ``builtins.open`` is patched with ``mock_open`` in every case, so no real
    header file is read; the path passed to the parser is a placeholder.
    """

    def _parse(self, header_text, target):
        """Run ``parse_types_header`` for ``target`` with ``header_text`` as the file body."""
        with patch("builtins.open", mock_open(read_data=header_text)):
            return parse_types_header("fake_path.hpp", target)

    def validate_entries(self, entries, expected_entries):
        """Assert both lists have the same length and equal dicts position by position."""
        self.assertEqual(len(entries), len(expected_entries))
        # Lengths are known to match here, so zip visits every position.
        for actual, wanted in zip(entries, expected_entries):
            self.assertDictEqual(actual, wanted)

    def test_empty_entry(self):
        """An empty header yields an empty entry list."""
        entries = self._parse("", "atomic_smoke")
        self.assertEqual(len(entries), 0)

    def test_pipelines_smoke(self):
        """pipelines_smoke target: keeps only the alias whose suffix is 'Pipelines'.

        Includes: Pipelines
        Excludes: Fp8NonPersistentTreeCompV3
        """
        entries = self._parse(
            "using KernelTypesStreamKPipelines = ...\n"
            "using KernelTypesStreamKFp8NonPersistentTreeCompV3 = ...\n",
            "pipelines_smoke",
        )
        self.validate_entries(
            entries,
            [
                {
                    "type_alias": "KernelTypesStreamKPipelines",
                    "class_name": "TestCkTileStreamKPipelines",
                    "file_tag": "pipelines",
                }
            ],
        )

    def test_extended(self):
        """extended target: keeps suffixes containing 'Atomic' plus the 'Pipelines' alias.

        Includes: Fp16PersistentAtomic, Pipelines
        Excludes: Bf16Linear
        """
        entries = self._parse(
            "using KernelTypesStreamKFp16PersistentAtomic = ...\n"
            "using KernelTypesStreamKPipelines = ...\n"
            "using KernelTypesStreamKBf16Linear = ...\n",
            "extended",
        )
        self.validate_entries(
            entries,
            [
                {
                    "type_alias": "KernelTypesStreamKFp16PersistentAtomic",
                    "class_name": "TestCkTileStreamKFp16PersistentAtomic",
                    "file_tag": "fp16_persistent_atomic",
                },
                {
                    "type_alias": "KernelTypesStreamKPipelines",
                    "class_name": "TestCkTileStreamKPipelines",
                    "file_tag": "pipelines",
                },
            ],
        )

    def test_atomic_smoke(self):
        """atomic_smoke target: keeps suffixes containing 'Atomic', never 'Pipelines'.

        Includes: Fp16PersistentAtomic
        Excludes: Bf16Linear, Pipelines
        """
        entries = self._parse(
            "using KernelTypesStreamKFp16PersistentAtomic = ...\n"
            "using KernelTypesStreamKBf16Linear = ...\n"
            "using KernelTypesStreamKPipelines = ...\n",
            "atomic_smoke",
        )
        self.validate_entries(
            entries,
            [
                {
                    "type_alias": "KernelTypesStreamKFp16PersistentAtomic",
                    "class_name": "TestCkTileStreamKFp16PersistentAtomic",
                    "file_tag": "fp16_persistent_atomic",
                }
            ],
        )

    def test_linear_smoke(self):
        """linear_smoke target: keeps suffixes containing 'Linear', never 'Pipelines'.

        Includes: Fp8NonPersistentLinear
        Excludes: Bf16PersistentAtomic, Pipelines
        """
        entries = self._parse(
            "using KernelTypesStreamKFp8NonPersistentLinear = ...\n"
            "using KernelTypesStreamKBf16PersistentAtomic = ...\n"
            "using KernelTypesStreamKPipelines = ...\n",
            "linear_smoke",
        )
        self.validate_entries(
            entries,
            [
                {
                    "type_alias": "KernelTypesStreamKFp8NonPersistentLinear",
                    "class_name": "TestCkTileStreamKFp8NonPersistentLinear",
                    "file_tag": "fp8_nonpersistent_linear",
                }
            ],
        )

    def test_tree_smoke(self):
        """tree_smoke target: keeps suffixes containing 'Tree', never 'Pipelines'.

        Includes: Bf8PersistentTreeCompV3
        Excludes: Fp16Linear, Pipelines
        """
        entries = self._parse(
            "using KernelTypesStreamKBf8PersistentTreeCompV3 = ...\n"
            "using KernelTypesStreamKFp16Linear = ...\n"
            "using KernelTypesStreamKPipelines = ...\n",
            "tree_smoke",
        )
        self.validate_entries(
            entries,
            [
                {
                    "type_alias": "KernelTypesStreamKBf8PersistentTreeCompV3",
                    "class_name": "TestCkTileStreamKBf8PersistentTreeCompV3",
                    "file_tag": "bf8_persistent_tree_compv3",
                }
            ],
        )
|
||||
|
||||
|
||||
class TestOutputPath(unittest.TestCase):
    """Exercises ``output_path``, which maps an entry to its generated .cpp path."""

    def test_output_path(self):
        """The entry's file tag is embedded in a test_gemm_streamk_*.cpp name under the output dir."""
        generated = output_path(
            "/some/output/dir", {"file_tag": "fp16_persistent_atomic"}
        )
        self.assertEqual(
            generated,
            "/some/output/dir/test_gemm_streamk_fp16_persistent_atomic.cpp",
        )
|
||||
|
||||
|
||||
# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user