mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK Tile] Stream-K gtest Code Gen (#5722)
## Motivation
Stream-K was using the tile engine infrastructure for smoke tests.
However, tile engine creates a different target per kernel instance,
which has resulted in scalability issues when used in the context of
unit tests. To avoid burdens on cmake configuration and build time, we
have opted to remove our Stream-K tile engine tests. Instead, we use
pure gtests with code gen to generate repetitive .cpp files.
**Note: This appears to change a lot of files because many files are
removed since they are now generated at build time.**
## Technical Details
We originally used Tile Engine to facilitate code gen for unit tests
since we found that pure gtests required the addition of many repetitive
.cpp files of the following form:
```cpp
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf8 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf8
TYPED_TEST_SUITE(TestCkTileStreamKBf8, KernelTypesStreamKBf8);
#include "test_gemm_streamk_atomic_cases.inc"
#undef TEST_SUITE_NAME
```
Due to issues encountered with tile engine, we instead use pure gtests
to generate the repetitive .cpp files. The code generator parses
`KernelTypesStreamK*` type aliases from the types header using a
two-phase approach:
1. At **configure time**, CMake runs the Python script with
`--list_files` to extract the type alias names from the header
(test_gemm_streamk_types.hpp) and compute the list of .cpp file paths
that will be generated. This lets CMake know the exact set of source
files for each target.
2. At **build time**, `add_custom_command` runs the script again with
`--gen_files` to actually emit the .cpp files into the build directory,
triggered only when the types header or generator script changes.
With these changes, we've removed all Stream-K tile engine tests. There
are now 5 targets for Stream-K GEMM tests:
1. test_ck_tile_streamk_atomic_smoke: smoke tests for Atomic reduction
strategy (pipeline: compv3)
2. test_ck_tile_streamk_linear_smoke: smoke tests for Linear reduction
strategy (pipeline: compv3)
3. test_ck_tile_streamk_tree_smoke: smoke tests for Tree reduction
strategy (pipeline: compv3)
4. test_ck_tile_streamk_pipelines_smoke: smoke tests (smaller set) for
pipelines other than compv3
- Since Stream-K can be thought of as a wrapper around universal GEMM,
we don't need to extensively test each pipeline. So, we opt to run a few
tests for different pipelines. Currently, this just consists of the mem
pipeline, but compv4 is coming soon.
5. test_ck_tile_streamk_extended: extended tests
## Test Plan
I have tests the gtests locally on gfx90a, gfx942, and gfx950.
## Test Result
All local tests pass.
## Submission Checklist
- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
---------
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -69,5 +69,4 @@ add_subdirectory(fmha)
|
||||
add_subdirectory(gemm_tile_engine)
|
||||
add_subdirectory(pooling)
|
||||
add_subdirectory(grouped_conv)
|
||||
add_subdirectory(gemm_streamk_tile_engine)
|
||||
add_subdirectory(pooling_tile_engine)
|
||||
|
||||
@@ -19,55 +19,93 @@ set(EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4
|
||||
if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")
|
||||
|
||||
include_directories(BEFORE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
|
||||
|
||||
#TODO: support all arches
|
||||
#TODO: current c-shuffle only supports C layout as R
|
||||
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
|
||||
set(STREAMK_EXTENDED_SOURCES
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_mem.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_mem.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_mem.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_mem.cpp
|
||||
test_gemm_streamk_util.cpp)
|
||||
|
||||
# We only test fp8 and bf8 on gfx942 and gfx950 since these types are not natively supported on gfx90a
|
||||
if(GPU_TARGETS MATCHES "gfx942|gfx950")
|
||||
list(APPEND STREAMK_EXTENDED_SOURCES
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_mem.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_mem.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_mem.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_mem.cpp)
|
||||
endif()
|
||||
# ---- Code-generate test .cpp files from types header ----
|
||||
set(STREAMK_TYPES_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/test_gemm_streamk_types.hpp)
|
||||
set(STREAMK_GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/generate_test_files.py)
|
||||
|
||||
add_gtest_executable(test_ck_tile_streamk_extended ${STREAMK_EXTENDED_SOURCES})
|
||||
target_compile_options(test_ck_tile_streamk_extended PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
# Re-run configure automatically if the types header changes (e.g. types added/removed)
|
||||
# or if the generation script changes
|
||||
set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS ${STREAMK_TYPES_HEADER} ${STREAMK_GEN_SCRIPT})
|
||||
|
||||
# Define the targets and their corresponding executable names
|
||||
set(STREAMK_GEN_TARGETS extended atomic_smoke linear_smoke tree_smoke pipelines_smoke)
|
||||
set(STREAMK_GEN_EXEC_EXTENDED test_ck_tile_streamk_extended)
|
||||
set(STREAMK_GEN_EXEC_ATOMIC_SMOKE test_ck_tile_streamk_atomic_smoke)
|
||||
set(STREAMK_GEN_EXEC_LINEAR_SMOKE test_ck_tile_streamk_linear_smoke)
|
||||
set(STREAMK_GEN_EXEC_TREE_SMOKE test_ck_tile_streamk_tree_smoke)
|
||||
set(STREAMK_GEN_EXEC_PIPELINES_SMOKE test_ck_tile_streamk_pipelines_smoke)
|
||||
|
||||
# Collect all test targets for umbrella label
|
||||
set(CK_TILE_GEMM_STREAMK_TEST_TARGETS
|
||||
test_ck_tile_streamk_tile_partitioner
|
||||
test_ck_tile_streamk_extended
|
||||
test_ck_tile_streamk_tile_partitioner)
|
||||
|
||||
foreach(target IN LISTS STREAMK_GEN_TARGETS)
|
||||
string(TOUPPER ${target} TARGET_UPPER)
|
||||
set(GEN_DIR ${CMAKE_CURRENT_BINARY_DIR}/${target})
|
||||
set(EXEC_NAME ${STREAMK_GEN_EXEC_${TARGET_UPPER}})
|
||||
set(LIST_FILE ${CMAKE_CURRENT_BINARY_DIR}/${target}_files.txt)
|
||||
|
||||
# Phase 1 (configure time): discover the list of files that will be generated
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${STREAMK_GEN_SCRIPT}
|
||||
--types_header ${STREAMK_TYPES_HEADER}
|
||||
--output_dir ${GEN_DIR}
|
||||
--target ${target}
|
||||
--list_files ${LIST_FILE}
|
||||
RESULT_VARIABLE ret
|
||||
ERROR_VARIABLE list_files_stderr)
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message(FATAL_ERROR
|
||||
"Failed to list ${target} test files via Python: ${ret}\n"
|
||||
"stderr: ${list_files_stderr}"
|
||||
)
|
||||
endif()
|
||||
file(STRINGS ${LIST_FILE} ALL_SOURCES_${target})
|
||||
|
||||
# Phase 2 (build time): generate the .cpp files when the types header changes
|
||||
add_custom_command(
|
||||
OUTPUT ${ALL_SOURCES_${target}}
|
||||
COMMAND ${Python3_EXECUTABLE} ${STREAMK_GEN_SCRIPT}
|
||||
--types_header ${STREAMK_TYPES_HEADER}
|
||||
--output_dir ${GEN_DIR}
|
||||
--target ${target}
|
||||
--gen_files
|
||||
DEPENDS ${STREAMK_TYPES_HEADER} ${STREAMK_GEN_SCRIPT}
|
||||
COMMENT "Generating StreamK ${target} test sources from types header")
|
||||
|
||||
# Filter out fp8/bf8 sources on gfx90a since those types are not natively supported
|
||||
set(FILTERED_SOURCES)
|
||||
foreach(src IN LISTS ALL_SOURCES_${target})
|
||||
if(NOT src MATCHES "_(fp8|bf8)_" OR GPU_TARGETS MATCHES "gfx942|gfx950")
|
||||
list(APPEND FILTERED_SOURCES ${src})
|
||||
endif()
|
||||
endforeach()
|
||||
list(APPEND FILTERED_SOURCES test_gemm_streamk_util.cpp)
|
||||
|
||||
add_gtest_executable(${EXEC_NAME} ${FILTERED_SOURCES})
|
||||
target_compile_options(${EXEC_NAME} PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
list(APPEND CK_TILE_GEMM_STREAMK_TEST_TARGETS ${EXEC_NAME})
|
||||
endforeach()
|
||||
|
||||
# Add python unit tests to validate the code gen logic in generate_test_files.py
|
||||
add_test(
|
||||
NAME test_ck_tile_streamk_generate_test_files
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/test_generate_test_files.py -v
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..
|
||||
)
|
||||
|
||||
# Label all ck_tile gemm_streamk tests with CK_TILE_GEMM_STREAMK_TESTS for selective execution
|
||||
foreach(test_target ${CK_TILE_GEMM_STREAMK_TEST_TARGETS})
|
||||
set_tests_properties(${test_target} PROPERTIES LABELS "CK_TILE_GEMM_STREAMK_TESTS")
|
||||
endforeach()
|
||||
# Also label the Python test
|
||||
set_tests_properties(test_ck_tile_streamk_generate_test_files PROPERTIES LABELS "CK_TILE_GEMM_STREAMK_TESTS")
|
||||
|
||||
# Umbrella target to build and run all ck_tile gemm_streamk tests
|
||||
# Usage: ninja ck_tile_gemm_streamk_tests
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf16NonPersistentCompV3 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistentCompV3
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistentCompV3,
|
||||
KernelTypesStreamKBf16NonPersistentCompV3);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,18 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf16NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistentCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistentCompV4,
|
||||
KernelTypesStreamKBf16NonPersistentCompV4);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf16NonPersistentMem : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistentMem
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistentMem, KernelTypesStreamKBf16NonPersistentMem);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf16PersistentCompV3 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf16PersistentCompV3
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf16PersistentCompV3, KernelTypesStreamKBf16PersistentCompV3);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf16PersistentCompV4 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf16PersistentCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf16PersistentCompV4, KernelTypesStreamKBf16PersistentCompV4);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf16PersistentMem : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf16PersistentMem
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf16PersistentMem, KernelTypesStreamKBf16PersistentMem);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf8NonPersistentCompV3 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistentCompV3
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistentCompV3, KernelTypesStreamKBf8NonPersistentCompV3);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf8NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistentCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistentCompV4, KernelTypesStreamKBf8NonPersistentCompV4);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf8NonPersistentMem : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistentMem
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistentMem, KernelTypesStreamKBf8NonPersistentMem);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf8PersistentCompV3 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf8PersistentCompV3
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf8PersistentCompV3, KernelTypesStreamKBf8PersistentCompV3);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf8PersistentCompV4 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf8PersistentCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf8PersistentCompV4, KernelTypesStreamKBf8PersistentCompV4);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf8PersistentMem : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf8PersistentMem
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf8PersistentMem, KernelTypesStreamKBf8PersistentMem);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,18 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp16NonPersistentCompV3 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp16NonPersistentCompV3
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp16NonPersistentCompV3,
|
||||
KernelTypesStreamKFp16NonPersistentCompV3);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,18 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp16NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp16NonPersistentCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp16NonPersistentCompV4,
|
||||
KernelTypesStreamKFp16NonPersistentCompV4);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp16NonPersistentMem : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp16NonPersistentMem
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp16NonPersistentMem, KernelTypesStreamKFp16NonPersistentMem);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp16PersistentCompV3 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp16PersistentCompV3
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp16PersistentCompV3, KernelTypesStreamKFp16PersistentCompV3);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp16PersistentCompV4 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp16PersistentCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp16PersistentCompV4, KernelTypesStreamKFp16PersistentCompV4);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp16PersistentMem : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp16PersistentMem
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp16PersistentMem, KernelTypesStreamKFp16PersistentMem);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp8NonPersistentCompV3 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp8NonPersistentCompV3
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp8NonPersistentCompV3, KernelTypesStreamKFp8NonPersistentCompV3);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp8NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp8NonPersistentCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp8NonPersistentCompV4, KernelTypesStreamKFp8NonPersistentCompV4);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp8NonPersistentMem : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp8NonPersistentMem
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp8NonPersistentMem, KernelTypesStreamKFp8NonPersistentMem);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp8PersistentCompV3 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp8PersistentCompV3
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp8PersistentCompV3, KernelTypesStreamKFp8PersistentCompV3);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp8PersistentCompV4 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp8PersistentCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp8PersistentCompV4, KernelTypesStreamKFp8PersistentCompV4);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp8PersistentMem : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp8PersistentMem
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp8PersistentMem, KernelTypesStreamKFp8PersistentMem);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
215
test/ck_tile/gemm_streamk/generate_test_files.py
Normal file
215
test/ck_tile/gemm_streamk/generate_test_files.py
Normal file
@@ -0,0 +1,215 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""Generate test .cpp files from KernelTypes definitions in
|
||||
test_gemm_streamk_types.hpp.
|
||||
|
||||
Two modes:
|
||||
--list_files FILE Write the list of output file paths to FILE (one per line)
|
||||
without generating the files. Used at CMake configure time.
|
||||
--gen_files Actually emit the .cpp files into --output_dir.
|
||||
Used at build time via add_custom_command.
|
||||
|
||||
Target selection (--target):
|
||||
extended Kernel types containing 'Atomic' or 'Pipelines'
|
||||
-> includes test_gemm_streamk_extended_cases.inc
|
||||
atomic_smoke Kernel types containing 'Atomic' (not 'Pipelines')
|
||||
-> includes test_gemm_streamk_atomic_cases.inc
|
||||
linear_smoke Kernel types containing 'Linear' (not 'Pipelines')
|
||||
-> includes test_gemm_streamk_reduction_cases.inc
|
||||
tree_smoke Kernel types containing 'Tree' (not 'Pipelines')
|
||||
-> includes test_gemm_streamk_reduction_cases.inc
|
||||
pipelines_smoke Kernel types matching 'Pipelines'
|
||||
-> includes test_gemm_streamk_reduction_cases.inc
|
||||
and test_gemm_streamk_atomic_cases.inc
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Template for every generated .cpp file
|
||||
# --------------------------------------------------------------------------- #
|
||||
CPP_TEMPLATE = """\
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class {class_name} : public TestCkTileStreamK<Tuple>
|
||||
{{
|
||||
}};
|
||||
|
||||
#define TEST_SUITE_NAME {class_name}
|
||||
|
||||
TYPED_TEST_SUITE({class_name}, {type_alias});
|
||||
|
||||
{inc_includes}
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
"""
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Target definitions: filter predicate and .inc files
|
||||
# --------------------------------------------------------------------------- #
|
||||
TARGETS = {
|
||||
"extended": {
|
||||
"filter": lambda suffix: "Atomic" in suffix or suffix == "Pipelines",
|
||||
"inc_files": ["test_gemm_streamk_extended_cases.inc"],
|
||||
},
|
||||
"atomic_smoke": {
|
||||
"filter": lambda suffix: "Atomic" in suffix and suffix != "Pipelines",
|
||||
"inc_files": ["test_gemm_streamk_atomic_cases.inc"],
|
||||
},
|
||||
"linear_smoke": {
|
||||
"filter": lambda suffix: "Linear" in suffix and suffix != "Pipelines",
|
||||
"inc_files": ["test_gemm_streamk_reduction_cases.inc"],
|
||||
},
|
||||
"tree_smoke": {
|
||||
"filter": lambda suffix: "Tree" in suffix and suffix != "Pipelines",
|
||||
"inc_files": ["test_gemm_streamk_reduction_cases.inc"],
|
||||
},
|
||||
"pipelines_smoke": {
|
||||
"filter": lambda suffix: suffix == "Pipelines",
|
||||
"inc_files": [
|
||||
"test_gemm_streamk_reduction_cases.inc",
|
||||
"test_gemm_streamk_atomic_cases.inc",
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Mapping from CamelCase suffix fragments to file-name fragments
|
||||
# --------------------------------------------------------------------------- #
|
||||
KNOWN_TOKENS = [
|
||||
("Fp16", "fp16"),
|
||||
("Bf16", "bf16"),
|
||||
("Fp8", "fp8"),
|
||||
("Bf8", "bf8"),
|
||||
("NonPersistent", "nonpersistent"),
|
||||
("Persistent", "persistent"),
|
||||
("Atomic", "atomic"),
|
||||
("Linear", "linear"),
|
||||
("Tree", "tree"),
|
||||
("CompV3", "compv3"),
|
||||
("Pipelines", "pipelines"),
|
||||
]
|
||||
|
||||
|
||||
def suffix_to_file_tag(suffix: str) -> str:
|
||||
"""Convert a CamelCase suffix like 'Fp16PersistentAtomicCompV3' to
|
||||
'fp16_persistent_atomic_compv3'."""
|
||||
parts: list[str] = []
|
||||
remaining = suffix
|
||||
while remaining:
|
||||
matched = False
|
||||
for token, replacement in KNOWN_TOKENS:
|
||||
if remaining.startswith(token):
|
||||
parts.append(replacement)
|
||||
remaining = remaining[len(token) :]
|
||||
matched = True
|
||||
break
|
||||
if not matched:
|
||||
raise ValueError(
|
||||
f"Unrecognised token in KernelTypes suffix: '{remaining}' "
|
||||
f"(from '{suffix}')"
|
||||
)
|
||||
return "_".join(parts)
|
||||
|
||||
|
||||
def parse_types_header(header_path: str, target: str) -> list[dict]:
|
||||
"""Return a list of dicts with keys: type_alias, class_name, file_tag, suffix."""
|
||||
target_def = TARGETS[target]
|
||||
# Pattern matches lines like: using KernelTypesStreamKFp16PersistentAtomicCompV3 = ...
|
||||
pattern = re.compile(r"using\s+(KernelTypesStreamK(\w+))\s*=")
|
||||
entries: list[dict] = []
|
||||
with open(header_path) as f:
|
||||
for line in f:
|
||||
match = pattern.search(line)
|
||||
if match:
|
||||
# If the match is: using KernelTypesStreamKFp16PersistentAtomicCompV3 = ...
|
||||
# type_alias is KernelTypesStreamKFp16PersistentAtomicCompV3
|
||||
# suffix is Fp16PersistentAtomicCompV3
|
||||
type_alias = match.group(1)
|
||||
suffix = match.group(2)
|
||||
if not target_def["filter"](suffix):
|
||||
continue
|
||||
entries.append(
|
||||
{
|
||||
"type_alias": type_alias,
|
||||
"class_name": f"TestCkTileStreamK{suffix}",
|
||||
"file_tag": suffix_to_file_tag(suffix),
|
||||
}
|
||||
)
|
||||
return entries
|
||||
|
||||
|
||||
def output_path(output_dir: str, entry: dict) -> str:
|
||||
return os.path.join(output_dir, f"test_gemm_streamk_{entry['file_tag']}.cpp")
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--types_header", required=True, help="Path to test_gemm_streamk_types.hpp"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir", required=True, help="Directory for generated .cpp files"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target",
|
||||
required=True,
|
||||
choices=list(TARGETS.keys()),
|
||||
help="Which target to generate files for",
|
||||
)
|
||||
group = parser.add_mutually_exclusive_group(required=True)
|
||||
group.add_argument(
|
||||
"--list_files",
|
||||
metavar="FILE",
|
||||
help="Write output file paths to FILE then exit",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gen_files", action="store_true", help="Generate the .cpp files"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
entries = parse_types_header(args.types_header, args.target)
|
||||
if not entries:
|
||||
print(
|
||||
f"ERROR: no KernelTypesStreamK* definitions found for target "
|
||||
f"'{args.target}' in {args.types_header}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
inc_files = TARGETS[args.target]["inc_files"]
|
||||
inc_includes = "\n".join(f'#include "{f}"' for f in inc_files)
|
||||
|
||||
if args.list_files:
|
||||
os.makedirs(os.path.dirname(args.list_files) or ".", exist_ok=True)
|
||||
with open(args.list_files, "w") as f:
|
||||
for entry in entries:
|
||||
f.write(output_path(args.output_dir, entry) + "\n")
|
||||
else:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
for entry in entries:
|
||||
path = output_path(args.output_dir, entry)
|
||||
content = CPP_TEMPLATE.format(
|
||||
class_name=entry["class_name"],
|
||||
type_alias=entry["type_alias"],
|
||||
inc_includes=inc_includes,
|
||||
)
|
||||
with open(path, "w") as f:
|
||||
f.write(content)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
47
test/ck_tile/gemm_streamk/test_gemm_streamk_atomic_cases.inc
Normal file
47
test/ck_tile/gemm_streamk/test_gemm_streamk_atomic_cases.inc
Normal file
@@ -0,0 +1,47 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_EdgeCase)
|
||||
{
|
||||
ck_tile::index_t M = 256;
|
||||
ck_tile::index_t N = 256;
|
||||
ck_tile::index_t K = 256;
|
||||
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_DPOnly)
|
||||
{
|
||||
const ck_tile::index_t num_cu = get_cu_count();
|
||||
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
|
||||
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
|
||||
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
|
||||
|
||||
// For DP only, we ensure that the number of tiles is a multiple of the number of CUs. This
|
||||
// assumes tile sizes are large enough such that occupancy is 1.
|
||||
ck_tile::index_t M = M_Tile * num_cu;
|
||||
ck_tile::index_t N = N_Tile;
|
||||
ck_tile::index_t K = K_Tile;
|
||||
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly)
|
||||
{
|
||||
const ck_tile::index_t num_cu = get_cu_count();
|
||||
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
|
||||
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
|
||||
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
|
||||
|
||||
// For SK only, we have 4 macro tiles in C. But, we need to make sure there is enough work along
|
||||
// the K dimension to avoid falling into the edge case. Thus, we always have at least num_cu
|
||||
// macro tiles in the K dimension. This assumes tile sizes are large enough such that occupancy
|
||||
// is 1.
|
||||
ck_tile::index_t M = M_Tile * 2;
|
||||
ck_tile::index_t N = N_Tile * 2;
|
||||
ck_tile::index_t K = K_Tile * num_cu;
|
||||
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_OneTile)
|
||||
{
|
||||
const ck_tile::index_t num_cu = get_cu_count();
|
||||
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
|
||||
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
|
||||
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
|
||||
|
||||
ck_tile::index_t M = M_Tile;
|
||||
ck_tile::index_t N = N_Tile;
|
||||
ck_tile::index_t K = K_Tile * num_cu;
|
||||
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_4Tiles_Reduction)
|
||||
{
|
||||
const ck_tile::index_t num_cu = get_cu_count();
|
||||
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
|
||||
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
|
||||
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
|
||||
|
||||
ck_tile::index_t M = M_Tile * 4;
|
||||
ck_tile::index_t N = N_Tile;
|
||||
ck_tile::index_t K = K_Tile * num_cu + (25 * K_Tile);
|
||||
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_21Tiles)
|
||||
{
|
||||
const ck_tile::index_t num_cu = get_cu_count();
|
||||
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
|
||||
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
|
||||
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
|
||||
|
||||
ck_tile::index_t M = M_Tile * 3;
|
||||
ck_tile::index_t N = N_Tile * 7;
|
||||
ck_tile::index_t K = K_Tile * num_cu + (30 * K_Tile);
|
||||
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
@@ -14,16 +14,28 @@ using BF16 = ck_tile::bf16_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using F32 = float;
|
||||
|
||||
// Layouts
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
// Persistence
|
||||
using Persistent = std::true_type;
|
||||
using NonPersistent = std::false_type;
|
||||
|
||||
// Pipelines
|
||||
using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
|
||||
using CompV3 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV3>;
|
||||
using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV4>;
|
||||
|
||||
using Persistent = std::true_type;
|
||||
using NonPersistent = std::false_type;
|
||||
// Reduction Strategies
|
||||
using Atomic = ck_tile::integral_constant<ck_tile::StreamKReductionStrategy,
|
||||
ck_tile::StreamKReductionStrategy::Atomic>;
|
||||
using Linear = ck_tile::integral_constant<ck_tile::StreamKReductionStrategy,
|
||||
ck_tile::StreamKReductionStrategy::Linear>;
|
||||
using Tree = ck_tile::integral_constant<ck_tile::StreamKReductionStrategy,
|
||||
ck_tile::StreamKReductionStrategy::Tree>;
|
||||
|
||||
using I16 = ck_tile::number<16>;
|
||||
using I32 = ck_tile::number<32>;
|
||||
using I128 = ck_tile::number<128>;
|
||||
using I256 = ck_tile::number<256>;
|
||||
@@ -32,180 +44,157 @@ using I256 = ck_tile::number<256>;
|
||||
|
||||
// ========================== CompV3 Pipeline ==========================
|
||||
|
||||
using KernelTypesStreamKFp16PersistentCompV3 = ::testing::Types<
|
||||
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent Pipeline
|
||||
// Atomics
|
||||
using KernelTypesStreamKFp16PersistentAtomicCompV3 = ::testing::Types<
|
||||
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile M_WaveTile N_WaveTile K_WaveTile Persistent Pipeline ReductionStrategy
|
||||
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV3>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV3>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV3>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV3>
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Atomic>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Atomic>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Atomic>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Atomic>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf16PersistentCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV3>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV3>,
|
||||
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV3>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV3>
|
||||
using KernelTypesStreamKBf16PersistentAtomicCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Atomic>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf8PersistentCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV3>,
|
||||
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV3>,
|
||||
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV3>,
|
||||
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV3>
|
||||
using KernelTypesStreamKBf8PersistentAtomicCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Atomic>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKFp8PersistentCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV3>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV3>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV3>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV3>
|
||||
using KernelTypesStreamKFp8PersistentAtomicCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Atomic>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Atomic>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Atomic>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Atomic>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKFp16NonPersistentCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV3>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV3>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV3>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV3>
|
||||
using KernelTypesStreamKFp16NonPersistentAtomicCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf16NonPersistentCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV3>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV3>,
|
||||
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV3>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV3>
|
||||
using KernelTypesStreamKBf16NonPersistentAtomicCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf8NonPersistentCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV3>,
|
||||
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV3>,
|
||||
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV3>,
|
||||
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV3>
|
||||
using KernelTypesStreamKBf8NonPersistentAtomicCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKFp8NonPersistentCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3>
|
||||
using KernelTypesStreamKFp8NonPersistentAtomicCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Atomic>
|
||||
>;
|
||||
|
||||
// ========================== CompV4 Pipeline ==========================
|
||||
// Linear
|
||||
using KernelTypesStreamKFp16PersistentLinearCompV3 = ::testing::Types<
|
||||
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile M_WaveTile N_WaveTile K_WaveTile Persistent Pipeline ReductionStrategy
|
||||
|
||||
using KernelTypesStreamKFp16PersistentCompV4 = ::testing::Types<
|
||||
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent Pipeline
|
||||
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Linear>,
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I16, I16, I16, Persistent, CompV3, Linear>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Linear>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Linear>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Linear>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf16PersistentCompV4 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>,
|
||||
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>
|
||||
using KernelTypesStreamKBf16PersistentLinearCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Linear>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf8PersistentCompV4 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>,
|
||||
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>,
|
||||
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>,
|
||||
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>
|
||||
using KernelTypesStreamKBf8PersistentLinearCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Linear>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKFp8PersistentCompV4 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>
|
||||
using KernelTypesStreamKFp8PersistentLinearCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Linear>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Linear>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Linear>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Linear>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKFp16NonPersistentCompV4 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>
|
||||
using KernelTypesStreamKFp16NonPersistentLinearCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Linear>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Linear>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Linear>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Linear>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf16NonPersistentCompV4 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>
|
||||
using KernelTypesStreamKBf16NonPersistentLinearCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Linear>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf8NonPersistentCompV4 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>
|
||||
using KernelTypesStreamKBf8NonPersistentLinearCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Linear>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKFp8NonPersistentCompV4 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>
|
||||
using KernelTypesStreamKFp8NonPersistentLinearCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Linear>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Linear>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Linear>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Linear>
|
||||
>;
|
||||
|
||||
// ============================= Mem Pipeline =============================
|
||||
// Tree
|
||||
using KernelTypesStreamKFp16PersistentTreeCompV3 = ::testing::Types<
|
||||
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile M_WaveTile N_WaveTile K_WaveTile Persistent Pipeline ReductionStrategy
|
||||
|
||||
using KernelTypesStreamKFp16PersistentMem = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, Mem>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, Mem>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, Mem>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, Mem>
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Tree>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Tree>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I16, I16, I16, Persistent, CompV3, Tree>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Tree>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Tree>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf16PersistentMem = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, Mem>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, Mem>,
|
||||
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, Mem>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, Mem>
|
||||
using KernelTypesStreamKBf16PersistentTreeCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, I32, I32, I16, Persistent, CompV3, Tree>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf8PersistentMem = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, Mem>,
|
||||
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, Mem>,
|
||||
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, Mem>,
|
||||
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, Mem>
|
||||
using KernelTypesStreamKBf8PersistentTreeCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Tree>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKFp8PersistentMem = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, Mem>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, Mem>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, Mem>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, Mem>
|
||||
using KernelTypesStreamKFp8PersistentTreeCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Tree>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Tree>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Tree>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, Persistent, CompV3, Tree>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKFp16NonPersistentMem = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, Mem>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, Mem>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, Mem>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, Mem>
|
||||
using KernelTypesStreamKFp16NonPersistentTreeCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Tree>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Tree>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Tree>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Tree>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf16NonPersistentMem = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, Mem>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, Mem>,
|
||||
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, Mem>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, Mem>
|
||||
using KernelTypesStreamKBf16NonPersistentTreeCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV3, Tree>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf8NonPersistentMem = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, Mem>,
|
||||
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, Mem>,
|
||||
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, Mem>,
|
||||
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, Mem>
|
||||
using KernelTypesStreamKBf8NonPersistentTreeCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Tree>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKFp8NonPersistentMem = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, Mem>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, Mem>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, Mem>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, Mem>
|
||||
using KernelTypesStreamKFp8NonPersistentTreeCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Tree>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Tree>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Tree>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, I32, I32, I16, NonPersistent, CompV3, Tree>
|
||||
>;
|
||||
|
||||
// ============================= Other Pipelines =============================
|
||||
|
||||
using KernelTypesStreamKPipelines = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, Mem, Atomic>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, Mem, Tree>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, Mem, Linear>,
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV4, Atomic>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, NonPersistent, CompV4, Tree>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Persistent, CompV4, Linear>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
@@ -71,23 +71,27 @@ template <typename Tuple>
|
||||
class TestCkTileStreamK : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = std::tuple_element_t<2, Tuple>;
|
||||
using ADataType = std::tuple_element_t<3, Tuple>;
|
||||
using BDataType = std::tuple_element_t<4, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<5, Tuple>;
|
||||
using CDataType = std::tuple_element_t<6, Tuple>;
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using DsDataType = ck_tile::tuple<>;
|
||||
static constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, Tuple>::value;
|
||||
static constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, Tuple>::value;
|
||||
static constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, Tuple>::value;
|
||||
static constexpr bool Persistent = std::tuple_element_t<10, Tuple>::value;
|
||||
static constexpr auto PipelineType = std::tuple_element_t<11, Tuple>::value;
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = std::tuple_element_t<2, Tuple>;
|
||||
using ADataType = std::tuple_element_t<3, Tuple>;
|
||||
using BDataType = std::tuple_element_t<4, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<5, Tuple>;
|
||||
using CDataType = std::tuple_element_t<6, Tuple>;
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using DsDataType = ck_tile::tuple<>;
|
||||
static constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, Tuple>::value;
|
||||
static constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, Tuple>::value;
|
||||
static constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, Tuple>::value;
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = std::tuple_element_t<10, Tuple>::value;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = std::tuple_element_t<11, Tuple>::value;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = std::tuple_element_t<12, Tuple>::value;
|
||||
|
||||
template <ck_tile::StreamKReductionStrategy ReductionStrategy,
|
||||
bool PadM = true,
|
||||
static constexpr bool Persistent = std::tuple_element_t<13, Tuple>::value;
|
||||
static constexpr auto PipelineType = std::tuple_element_t<14, Tuple>::value;
|
||||
static constexpr auto ReductionStrategy = std::tuple_element_t<15, Tuple>::value;
|
||||
|
||||
template <bool PadM = true,
|
||||
bool PadN = true,
|
||||
bool PadK = true,
|
||||
bool Preshuffle = false,
|
||||
@@ -99,10 +103,6 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr bool kPadM = PadM;
|
||||
constexpr bool kPadN = PadN;
|
||||
constexpr bool kPadK = PadK;
|
||||
@@ -269,8 +269,7 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
stride_C};
|
||||
|
||||
ck_tile::index_t num_accumulations_per_tile =
|
||||
invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
|
||||
args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
|
||||
invoke_streamk<>(args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
|
||||
220
test/ck_tile/gemm_streamk/test_generate_test_files.py
Normal file
220
test/ck_tile/gemm_streamk/test_generate_test_files.py
Normal file
@@ -0,0 +1,220 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import unittest
|
||||
from unittest.mock import mock_open, patch
|
||||
from generate_test_files import suffix_to_file_tag, parse_types_header, output_path
|
||||
|
||||
# ------------------------------------------------------------ #
|
||||
# Unit tests for helper functions in generate_test_files.py
|
||||
# ------------------------------------------------------------ #
|
||||
|
||||
|
||||
class TestSuffixToFileTag(unittest.TestCase):
|
||||
def test_fp16_token(self):
|
||||
suffix = "Fp16"
|
||||
expected_tag = "fp16"
|
||||
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
|
||||
|
||||
def test_bf16_token(self):
|
||||
suffix = "Bf16"
|
||||
expected_tag = "bf16"
|
||||
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
|
||||
|
||||
def test_fp8_token(self):
|
||||
suffix = "Fp8"
|
||||
expected_tag = "fp8"
|
||||
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
|
||||
|
||||
def test_bf8_token(self):
|
||||
suffix = "Bf8"
|
||||
expected_tag = "bf8"
|
||||
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
|
||||
|
||||
def test_nonpersistent_token(self):
|
||||
suffix = "NonPersistent"
|
||||
expected_tag = "nonpersistent"
|
||||
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
|
||||
|
||||
def test_persistent_token(self):
|
||||
suffix = "Persistent"
|
||||
expected_tag = "persistent"
|
||||
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
|
||||
|
||||
def test_atomic_token(self):
|
||||
suffix = "Atomic"
|
||||
expected_tag = "atomic"
|
||||
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
|
||||
|
||||
def test_linear_token(self):
|
||||
suffix = "Linear"
|
||||
expected_tag = "linear"
|
||||
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
|
||||
|
||||
def test_tree_token(self):
|
||||
suffix = "Tree"
|
||||
expected_tag = "tree"
|
||||
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
|
||||
|
||||
def test_compv3_token(self):
|
||||
suffix = "CompV3"
|
||||
expected_tag = "compv3"
|
||||
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
|
||||
|
||||
def test_pipelines_token(self):
|
||||
suffix = "Pipelines"
|
||||
expected_tag = "pipelines"
|
||||
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
|
||||
|
||||
def test_unknown_token(self):
|
||||
suffix = "unknown"
|
||||
with self.assertRaises(ValueError):
|
||||
suffix_to_file_tag(suffix)
|
||||
|
||||
def test_multiple_valid_tokens(self):
|
||||
suffix = "Fp16PersistentAtomicCompV3"
|
||||
expected_tag = "fp16_persistent_atomic_compv3"
|
||||
self.assertEqual(suffix_to_file_tag(suffix), expected_tag)
|
||||
|
||||
def test_multiple_tokens_with_unknown(self):
|
||||
suffix = "Fp16PersistentUnknownCompV3"
|
||||
with self.assertRaises(ValueError):
|
||||
suffix_to_file_tag(suffix)
|
||||
|
||||
|
||||
class TestParseTypesHeader(unittest.TestCase):
|
||||
def validate_entries(self, entries, expected_entries):
|
||||
self.assertEqual(len(entries), len(expected_entries))
|
||||
for idx in range(len(entries)):
|
||||
self.assertDictEqual(entries[idx], expected_entries[idx])
|
||||
|
||||
def test_empty_entry(self):
|
||||
"""Test that an empty file returns no entries."""
|
||||
mock_content = ""
|
||||
with patch("builtins.open", mock_open(read_data=mock_content)):
|
||||
entries = parse_types_header("fake_path.hpp", "atomic_smoke")
|
||||
self.assertEqual(len(entries), 0)
|
||||
|
||||
def test_pipelines_smoke(self):
|
||||
"""Test pipelines_smoke target: matches suffix == 'Pipelines'.
|
||||
Includes: Pipelines
|
||||
Excludes: Fp8NonPersistentTreeCompV3
|
||||
"""
|
||||
mock_content = (
|
||||
"using KernelTypesStreamKPipelines = ...\n"
|
||||
"using KernelTypesStreamKFp8NonPersistentTreeCompV3 = ...\n"
|
||||
)
|
||||
with patch("builtins.open", mock_open(read_data=mock_content)):
|
||||
entries = parse_types_header("fake_path.hpp", "pipelines_smoke")
|
||||
expected = [
|
||||
{
|
||||
"type_alias": "KernelTypesStreamKPipelines",
|
||||
"class_name": "TestCkTileStreamKPipelines",
|
||||
"file_tag": "pipelines",
|
||||
}
|
||||
]
|
||||
self.validate_entries(entries, expected)
|
||||
|
||||
def test_extended(self):
|
||||
"""Test extended target: matches 'Atomic' in suffix OR suffix == 'Pipelines'.
|
||||
Includes: Fp16PersistentAtomic, Pipelines
|
||||
Excludes: Bf16Linear
|
||||
"""
|
||||
mock_content = (
|
||||
"using KernelTypesStreamKFp16PersistentAtomic = ...\n"
|
||||
"using KernelTypesStreamKPipelines = ...\n"
|
||||
"using KernelTypesStreamKBf16Linear = ...\n"
|
||||
)
|
||||
with patch("builtins.open", mock_open(read_data=mock_content)):
|
||||
entries = parse_types_header("fake_path.hpp", "extended")
|
||||
expected = [
|
||||
{
|
||||
"type_alias": "KernelTypesStreamKFp16PersistentAtomic",
|
||||
"class_name": "TestCkTileStreamKFp16PersistentAtomic",
|
||||
"file_tag": "fp16_persistent_atomic",
|
||||
},
|
||||
{
|
||||
"type_alias": "KernelTypesStreamKPipelines",
|
||||
"class_name": "TestCkTileStreamKPipelines",
|
||||
"file_tag": "pipelines",
|
||||
},
|
||||
]
|
||||
self.validate_entries(entries, expected)
|
||||
|
||||
def test_atomic_smoke(self):
|
||||
"""Test atomic_smoke target: matches 'Atomic' in suffix AND suffix != 'Pipelines'.
|
||||
Includes: Fp16PersistentAtomic
|
||||
Excludes: Bf16Linear, Pipelines
|
||||
"""
|
||||
mock_content = (
|
||||
"using KernelTypesStreamKFp16PersistentAtomic = ...\n"
|
||||
"using KernelTypesStreamKBf16Linear = ...\n"
|
||||
"using KernelTypesStreamKPipelines = ...\n"
|
||||
)
|
||||
with patch("builtins.open", mock_open(read_data=mock_content)):
|
||||
entries = parse_types_header("fake_path.hpp", "atomic_smoke")
|
||||
expected = [
|
||||
{
|
||||
"type_alias": "KernelTypesStreamKFp16PersistentAtomic",
|
||||
"class_name": "TestCkTileStreamKFp16PersistentAtomic",
|
||||
"file_tag": "fp16_persistent_atomic",
|
||||
}
|
||||
]
|
||||
self.validate_entries(entries, expected)
|
||||
|
||||
def test_linear_smoke(self):
|
||||
"""Test linear_smoke target: matches 'Linear' in suffix AND suffix != 'Pipelines'.
|
||||
Includes: Fp8NonPersistentLinear
|
||||
Excludes: Bf16PersistentAtomic, Pipelines
|
||||
"""
|
||||
mock_content = (
|
||||
"using KernelTypesStreamKFp8NonPersistentLinear = ...\n"
|
||||
"using KernelTypesStreamKBf16PersistentAtomic = ...\n"
|
||||
"using KernelTypesStreamKPipelines = ...\n"
|
||||
)
|
||||
with patch("builtins.open", mock_open(read_data=mock_content)):
|
||||
entries = parse_types_header("fake_path.hpp", "linear_smoke")
|
||||
expected = [
|
||||
{
|
||||
"type_alias": "KernelTypesStreamKFp8NonPersistentLinear",
|
||||
"class_name": "TestCkTileStreamKFp8NonPersistentLinear",
|
||||
"file_tag": "fp8_nonpersistent_linear",
|
||||
}
|
||||
]
|
||||
self.validate_entries(entries, expected)
|
||||
|
||||
def test_tree_smoke(self):
|
||||
"""Test tree_smoke target: matches 'Tree' in suffix AND suffix != 'Pipelines'.
|
||||
Includes: Bf8PersistentTreeCompV3
|
||||
Excludes: Fp16Linear, Pipelines
|
||||
"""
|
||||
mock_content = (
|
||||
"using KernelTypesStreamKBf8PersistentTreeCompV3 = ...\n"
|
||||
"using KernelTypesStreamKFp16Linear = ...\n"
|
||||
"using KernelTypesStreamKPipelines = ...\n"
|
||||
)
|
||||
with patch("builtins.open", mock_open(read_data=mock_content)):
|
||||
entries = parse_types_header("fake_path.hpp", "tree_smoke")
|
||||
expected = [
|
||||
{
|
||||
"type_alias": "KernelTypesStreamKBf8PersistentTreeCompV3",
|
||||
"class_name": "TestCkTileStreamKBf8PersistentTreeCompV3",
|
||||
"file_tag": "bf8_persistent_tree_compv3",
|
||||
}
|
||||
]
|
||||
self.validate_entries(entries, expected)
|
||||
|
||||
|
||||
class TestOutputPath(unittest.TestCase):
|
||||
def test_output_path(self):
|
||||
"""Test that output_path generates the correct file path."""
|
||||
entry = {"file_tag": "fp16_persistent_atomic"}
|
||||
output_dir = "/some/output/dir"
|
||||
expected = "/some/output/dir/test_gemm_streamk_fp16_persistent_atomic.cpp"
|
||||
self.assertEqual(output_path(output_dir, entry), expected)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,324 +0,0 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
include(generate_configs.cmake)
|
||||
|
||||
# ============================================================================
|
||||
# GEMM Tile Engine Unit Tests
|
||||
#
|
||||
# This CMake file creates unit tests for tile_engine generated GEMM kernels.
|
||||
# It follows the exact same build patterns as tile_engine for consistency
|
||||
# and reliability. Each kernel configuration gets its own test executable.
|
||||
# ============================================================================
|
||||
|
||||
# Locate tile_engine GEMM scripts directory
|
||||
set(TILE_ENGINE_GEMM_DIR "${PROJECT_SOURCE_DIR}/tile_engine/ops/gemm_streamk")
|
||||
|
||||
if(NOT EXISTS ${TILE_ENGINE_GEMM_DIR})
|
||||
message(WARNING "Tile engine directory not found: ${TILE_ENGINE_GEMM_DIR}")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# ============================================================================
|
||||
# create_individual_gemm_test_target
|
||||
#
|
||||
# Creates a single test executable for a specific kernel configuration.
|
||||
# Mirrors tile_engine's create_individual_gemm_target function for consistency.
|
||||
#
|
||||
# Parameters:
|
||||
# datatype - Data type (fp16, bf16, fp32, etc.)
|
||||
# layout - Matrix layout (rcr, rrr, ccr, crr)
|
||||
# config_name - Configuration file name without .json extension
|
||||
# trait - Kernel trait combination string
|
||||
# tile_config - Tile configuration parameters
|
||||
# config_json - Full path to JSON configuration file
|
||||
# ============================================================================
|
||||
function(create_individual_gemm_test_target datatype layout config_name trait tile_config config_json)
|
||||
set(target_name "test_gemm_streamk_tile_engine_${datatype}_${layout}_${config_name}_${trait}_${tile_config}")
|
||||
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}/${config_name}")
|
||||
|
||||
# Generated header path (already created during cmake configuration)
|
||||
set(test_header "${working_path}/gemm_streamk_single_${datatype}_${layout}_${trait}_${tile_config}.hpp")
|
||||
set(test_params_header "${working_path}/test_params.hpp")
|
||||
|
||||
# Verify header exists (should have been generated during cmake configuration)
|
||||
if(NOT EXISTS ${test_header})
|
||||
message(WARNING "Generated header not found: ${test_header}")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# Verify test parameters header exists
|
||||
if(NOT EXISTS ${test_params_header})
|
||||
message(WARNING "Test parameters header not found: ${test_params_header}")
|
||||
return()
|
||||
endif()
|
||||
|
||||
|
||||
# Create GTest executable for this kernel configuration
|
||||
add_gtest_executable(${target_name}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/test_gemm_streamk_simple.cpp
|
||||
)
|
||||
|
||||
# Configure GPU architectures for HIP compilation
|
||||
set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_TEST_GPU_TARGETS})
|
||||
|
||||
# Define preprocessor macros for generated header location and test parameters
|
||||
target_compile_definitions(${target_name} PRIVATE
|
||||
GEMM_SINGLE_INSTANCE_HPP="${test_header}"
|
||||
GEMM_TEST_PARAMS_HPP="${test_params_header}"
|
||||
)
|
||||
|
||||
# Include directories for headers and dependencies
|
||||
target_include_directories(${target_name} PRIVATE
|
||||
${PROJECT_SOURCE_DIR}/include
|
||||
${PROJECT_BINARY_DIR}/include
|
||||
${PROJECT_SOURCE_DIR} # Root directory for tile_engine access
|
||||
${GTEST_INCLUDE_DIRS}
|
||||
)
|
||||
|
||||
# Compiler options matching tile_engine requirements
|
||||
target_compile_options(${target_name} PRIVATE
|
||||
-Wno-undefined-func-template # Suppress template warnings
|
||||
-Wno-float-equal # Allow floating point comparisons
|
||||
--offload-compress # Enable GPU code compression
|
||||
-include ${test_header} # Auto-include generated header
|
||||
)
|
||||
|
||||
# Add FP8 format definitions for proper data type interpretation
|
||||
if(CK_USE_OCP_FP8)
|
||||
target_compile_options(${target_name} PRIVATE -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
message(DEBUG " Created test target: ${target_name}")
|
||||
endfunction()
|
||||
|
||||
# ============================================================================
|
||||
# build_gemm_test_targets
|
||||
#
|
||||
# Builds all test targets for a specific datatype/layout/config combination.
|
||||
# Uses tile_engine's two-step process: list kernels, then generate tests.
|
||||
#
|
||||
# Parameters:
|
||||
# datatype - Data type (fp16, bf16, fp32, etc.)
|
||||
# layout - Matrix layout (rcr, rrr, ccr, crr)
|
||||
# config_name - Configuration file name without .json extension
|
||||
# ============================================================================
|
||||
function(build_gemm_test_targets datatype layout config_name configs_dir_path)
|
||||
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}/${config_name}")
|
||||
|
||||
# Locate and validate configuration file
|
||||
set(config_filename "${config_name}.json")
|
||||
set(json_blob "${configs_dir_path}/${config_filename}")
|
||||
|
||||
if(NOT EXISTS ${json_blob})
|
||||
message(WARNING "Test config file not found: ${json_blob}")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# Prepare build directory for this configuration
|
||||
file(MAKE_DIRECTORY ${working_path})
|
||||
|
||||
# STEP 1: Discovery phase - list all valid kernel configurations
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} -u ${TILE_ENGINE_GEMM_DIR}/gemm_streamk_instance_builder.py
|
||||
--working_path ${working_path}
|
||||
--datatype ${datatype}
|
||||
--layout ${layout}
|
||||
--config_json ${json_blob}
|
||||
--list_kernels
|
||||
--gpu_targets "${SUPPORTED_GPU_TARGETS}"
|
||||
WORKING_DIRECTORY ${TILE_ENGINE_GEMM_DIR}
|
||||
RESULT_VARIABLE ret
|
||||
OUTPUT_VARIABLE list_output
|
||||
ERROR_VARIABLE list_error
|
||||
)
|
||||
|
||||
if(NOT ret EQUAL 0)
|
||||
message(WARNING "Failed to list kernels for ${datatype}_${layout}_${config_name}: ${list_error}")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# Verify kernel list file was generated
|
||||
if(NOT EXISTS ${working_path}/gemm_kernel_list.txt)
|
||||
message(DEBUG "No kernels found for ${datatype}_${layout}_${config_name} (validation filtered out all combinations)")
|
||||
return()
|
||||
endif()
|
||||
|
||||
message(DEBUG "Building tests for ${datatype}_${layout}_${config_name}")
|
||||
|
||||
# STEP 2a: Extract test parameters from config
|
||||
set(test_params_file "${working_path}/test_params.hpp")
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_SOURCE_DIR}/extract_test_params.py
|
||||
--config_file ${json_blob}
|
||||
--output_file ${test_params_file}
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
RESULT_VARIABLE extract_ret
|
||||
OUTPUT_VARIABLE extract_output
|
||||
ERROR_VARIABLE extract_error
|
||||
)
|
||||
|
||||
if(NOT extract_ret EQUAL 0)
|
||||
message(WARNING "Failed to extract test parameters for ${datatype}_${layout}: ${extract_error}")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# STEP 2b: Header generation phase - generate headers using --gen_single
|
||||
message(STATUS " Generating headers using --gen_single...")
|
||||
|
||||
file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines)
|
||||
set(gen_count 0)
|
||||
|
||||
foreach(line IN LISTS kernel_lines)
|
||||
# Parse kernel specification format: kernel_name|tile_config|trait_combo
|
||||
string(REPLACE "|" ";" parts "${line}")
|
||||
list(LENGTH parts parts_len)
|
||||
if(parts_len EQUAL 3)
|
||||
list(GET parts 0 kernel_name)
|
||||
list(GET parts 1 tile_config)
|
||||
list(GET parts 2 trait_combo)
|
||||
|
||||
# Generate header using --gen_single
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} -u ${TILE_ENGINE_GEMM_DIR}/gemm_streamk_instance_builder.py
|
||||
--working_path ${working_path}
|
||||
--datatype ${datatype}
|
||||
--layout ${layout}
|
||||
--config_json ${json_blob}
|
||||
--gen_single
|
||||
--kernel_name "${kernel_name}"
|
||||
--tile_config "${tile_config}"
|
||||
--trait_combo "${trait_combo}"
|
||||
--gpu_targets "${SUPPORTED_GPU_TARGETS}"
|
||||
WORKING_DIRECTORY ${TILE_ENGINE_GEMM_DIR}
|
||||
RESULT_VARIABLE gen_ret
|
||||
OUTPUT_VARIABLE gen_output
|
||||
ERROR_VARIABLE gen_error
|
||||
)
|
||||
|
||||
if(NOT gen_ret EQUAL 0)
|
||||
message(WARNING "Failed to generate header for ${kernel_name}: ${gen_error}")
|
||||
else()
|
||||
math(EXPR gen_count "${gen_count} + 1")
|
||||
endif()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
message(STATUS " Generated ${gen_count} headers for ${datatype}_${layout}")
|
||||
|
||||
# STEP 3: Target creation phase - create test targets
|
||||
message(STATUS " Creating test targets...")
|
||||
file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines)
|
||||
set(test_count 0)
|
||||
foreach(line IN LISTS kernel_lines)
|
||||
# Parse kernel specification format: kernel_name|tile_config|trait_combo
|
||||
string(REPLACE "|" ";" parts "${line}")
|
||||
list(LENGTH parts parts_len)
|
||||
if(parts_len EQUAL 3)
|
||||
list(GET parts 0 kernel_name)
|
||||
list(GET parts 1 tile_config)
|
||||
list(GET parts 2 trait_combo)
|
||||
|
||||
# Generate test target for this kernel configuration
|
||||
create_individual_gemm_test_target("${datatype}" "${layout}" "${config_name}" "${trait_combo}" "${tile_config}" "${json_blob}")
|
||||
math(EXPR test_count "${test_count} + 1")
|
||||
endif()
|
||||
endforeach()
|
||||
message(STATUS " Created ${test_count} test targets for ${datatype}_${layout}")
|
||||
endfunction()# ============================================================================
|
||||
# MAIN EXECUTION - Test Target Generation
|
||||
# ============================================================================
|
||||
|
||||
message(STATUS "=== Starting StreamK GEMM Tile Engine Test Configuration ===")
|
||||
message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
|
||||
# GPU architecture filtering - only build tests for supported architectures
|
||||
set(GEMM_TEST_GPU_TARGETS "")
|
||||
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx12-generic")
|
||||
|
||||
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
|
||||
if(target IN_LIST DESIRED_TARGETS)
|
||||
list(APPEND GEMM_TEST_GPU_TARGETS ${target})
|
||||
message(STATUS " Adding GPU target for tests: ${target}")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
# Early exit if no compatible GPU architectures are available
|
||||
if(NOT GEMM_TEST_GPU_TARGETS)
|
||||
message(WARNING "Skipping StreamK GEMM Tile Engine tests: No supported GPU targets (gfx90a, gfx942, gfx950, gfx12-generic) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
return()
|
||||
endif()
|
||||
|
||||
message(STATUS "Building StreamK GEMM tile engine tests for GPU targets: ${GEMM_TEST_GPU_TARGETS}")
|
||||
|
||||
# Enable parallel compilation optimizations
|
||||
# Set up job pools for better parallel compilation control
|
||||
set_property(GLOBAL PROPERTY JOB_POOLS
|
||||
compile_heavy=4 # Limit heavy compilations to prevent OOM
|
||||
compile_normal=16 # Allow more parallel normal compilations
|
||||
)
|
||||
|
||||
# Enable compiler cache if available and explicitly requested
|
||||
# Disabled by default due to permission issues in CI environments
|
||||
option(ENABLE_CCACHE_TESTS "Enable ccache for test compilation" OFF)
|
||||
if(ENABLE_CCACHE_TESTS)
|
||||
find_program(CCACHE_PROGRAM ccache)
|
||||
if(CCACHE_PROGRAM)
|
||||
set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM})
|
||||
message(STATUS "Using ccache for faster test compilation")
|
||||
else()
|
||||
message(WARNING "ccache requested but not found")
|
||||
endif()
|
||||
else()
|
||||
message(STATUS "ccache disabled for tests (use -DENABLE_CCACHE_TESTS=ON to enable)")
|
||||
endif()
|
||||
|
||||
# ============================================================================
|
||||
# Test Configuration Matrix - Clean Focused Design
|
||||
# ============================================================================
|
||||
|
||||
# All supported data types and layouts for comprehensive testing
|
||||
# Note: fp64 not included (no MFMA hardware support)
|
||||
set(TEST_DATATYPES "fp16;bf16")
|
||||
# Temporarily only test rcr and crr
|
||||
# set(TEST_LAYOUTS "rcr;rrr;ccr;crr")
|
||||
set(TEST_LAYOUTS "rcr;crr")
|
||||
|
||||
# ============================================================================
|
||||
# Test Target Generation - Datatype-Specific Categories
|
||||
# ============================================================================
|
||||
|
||||
# 1. SMOKE TESTS: Test for basic functionality with data types (fp8, bf8, fp16, bf16)
|
||||
# Temporarily only consider fp16
|
||||
# set(SMALL_DATATYPES "fp16;bf16;fp8;bf8")
|
||||
set(SMALL_DATATYPES "fp16")
|
||||
set(SIXTEEN_BIT_DATATYPES "fp16;bf16")
|
||||
set(EIGHT_BIT_DATATYPES "fp8;bf8")
|
||||
set(LARGE_TILES "256,256,32")
|
||||
set(SMALL_TILES "128,128,32")
|
||||
set(CONFIG_LIST "")
|
||||
set(GENERATED_CONFIG_PATH ${CMAKE_CURRENT_BINARY_DIR}/configs)
|
||||
get_cu_count(CU_COUNT)
|
||||
|
||||
message(STATUS "Generating and processing configs for Stream-K tests")
|
||||
foreach(datatype IN LISTS SMALL_DATATYPES)
|
||||
|
||||
if(datatype IN_LIST SIXTEEN_BIT_DATATYPES)
|
||||
generate_test_configs(${CU_COUNT} ${LARGE_TILES} ${datatype} CONFIG_LIST ${GENERATED_CONFIG_PATH})
|
||||
else()
|
||||
generate_test_configs(${CU_COUNT} ${SMALL_TILES} ${datatype} CONFIG_LIST ${GENERATED_CONFIG_PATH})
|
||||
endif()
|
||||
|
||||
foreach(config IN LISTS CONFIG_LIST)
|
||||
# testing all layouts (rcr, rrr, ccr, crr)
|
||||
foreach(layout IN LISTS TEST_LAYOUTS)
|
||||
build_gemm_test_targets("${datatype}" "${layout}" "${config}" "${GENERATED_CONFIG_PATH}")
|
||||
endforeach()
|
||||
endforeach()
|
||||
endforeach()
|
||||
|
||||
# ============================================================================
|
||||
|
||||
|
||||
message(STATUS "StreamK GEMM tile engine tests configured with datatype-specific design:")
|
||||
message(STATUS " - Smoke tests: fp16/bf16/fp8/bf8 (all layouts)")
|
||||
@@ -1,64 +0,0 @@
|
||||
# Stream-K GEMM Tile Engine Unit Tests
|
||||
|
||||
## How It Works
|
||||
|
||||
This unit test system integrates **tile_engine's kernel generation** into automated testing:
|
||||
|
||||
1. **Uses tile_engine scripts directly**: Same Python scripts that generate tile_engine kernels
|
||||
2. **JSON-based configuration**: Define test parameters in JSON files (like tile_engine)
|
||||
3. **Build-time generation**: CMake calls tile_engine scripts to generate kernel headers
|
||||
4. **Individual test executables**: Each kernel configuration becomes a separate test
|
||||
5. **Tile_engine verification**: Uses exact same error thresholds and validation as tile_engine
|
||||
|
||||
## Tile Engine Integration
|
||||
|
||||
```
|
||||
JSON Config → tile_engine Python scripts → Generated Headers → Test Executables
|
||||
```
|
||||
|
||||
- **`--list_kernels`**: Get available kernel configurations from JSON
|
||||
- **`--gen_individual`**: Generate all kernel headers in parallel during CMake configuration
|
||||
- **`--gen_single`**: Generate individual kernel header for each configuration
|
||||
- **Same verification**: Uses tile_engine's adaptive error thresholds and reference calculations
|
||||
- **Same patterns**: Follows tile_engine's tensor initialization, stride calculation, and kernel launching
|
||||
|
||||
### Config-Specific Test Parameters
|
||||
|
||||
Each test configuration can specify optimized problem sizes in its JSON file:
|
||||
- **`test_params.problem_sizes`**: Array of `{m, n, k, split_k}` configurations
|
||||
- **CMake extraction**: `extract_test_params.py` generates config-specific test parameter files
|
||||
- **Build integration**: Each test target uses parameters appropriate for its kernel configuration
|
||||
- **Optimized testing**: Different configs test different problem sizes that showcase their strengths
|
||||
|
||||
|
||||
The key idea: **Unit tests that use tile_engine's exact kernel generation and verification methodology** instead of creating separate test infrastructure.
|
||||
|
||||
## Test Configurations
|
||||
Test configs are generated during the Generation Phase. They are stored under the build directory at test/ck_tile/gemm_streamk_tile_engine/configs. The Compute Unit (CU) count of the device is required to generate the configs. If the Generation Phase occurs on a machine without a GPU or does not contain same GPU architecture on which you will run the tests, you can manually set the CU count using the `CU_COUNT` option:
|
||||
```bash
|
||||
# Assuming you are at the root of the repo
|
||||
cd build
|
||||
../script/cmake-ck-dev.sh .. gfx90a -G Ninja -DCU_COUNT=100
|
||||
```
|
||||
You can reference the public whitepaper for your specific GPU to get the appropriate CU count.
|
||||
If no `CU_COUNT` option is given and no HIP device is found, then the default value of 100 CUs will be used to determine the problem sizes tested.
|
||||
|
||||
### 1. **Smoke Tests**
|
||||
- **Purpose**: Basic functionality validation for fp16/bf16/fp8/bf8 data types
|
||||
- **Config**: 256x256x32 (for bf16/fp16) or 128x128x32 (for bf8/fp8), warp 2x2x1, warp_tile 32x32x16
|
||||
- **Traits**: compv3 pipeline only
|
||||
- **Coverage**: All 4 layouts (rcr, rrr, ccr, crr)
|
||||
|
||||
## Data Type Support
|
||||
- ✅ **fp16, bf16, fp8, bf8**: Fully supported - all layouts (rcr, rrr, ccr, crr)
|
||||
- ❌ **fp64**: Not supported (hardware MFMA limitation)
|
||||
- ⏳ **fp32, pk-int4-t**: Not yet supported by gemm_instance_builder (will be added later)
|
||||
|
||||
## Test Result Behavior
|
||||
|
||||
Tests automatically handle unsupported configurations through runtime validation:
|
||||
- **PASSED**: Kernel executed correctly with results within error thresholds ✅
|
||||
- **SKIPPED**: Kernel validation returned "Arguments not supported" (expected for certain problem sizes/configurations) ⚠️
|
||||
- **FAILED**: Actual error or incorrect computation results ❌
|
||||
|
||||
When a kernel's `IsSupportedArgument()` check fails (e.g., due to vector alignment requirements, dimension constraints, or padding limitations), the test is automatically skipped rather than failed. This allows comprehensive testing across various problem sizes while gracefully handling configurations that don't meet specific kernel requirements.
|
||||
@@ -1,50 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <iostream>
|
||||
|
||||
/**
|
||||
* @brief Determines whether a `hipError` is present in the given `error_status`
|
||||
* @return true if the `error_status` has an error, otherwise false.
|
||||
*/
|
||||
bool has_error(const hipError_t& error_status)
|
||||
{
|
||||
if(error_status != hipSuccess)
|
||||
{
|
||||
std::cerr << hipGetErrorString(error_status);
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Returns the number of Compute Units (CUs) on the given device.
|
||||
* @return The number of CUs on the device. If an error occurs while querying the device, zero is
|
||||
* returned.
|
||||
*/
|
||||
int get_cu_count()
|
||||
{
|
||||
hipDevice_t dev;
|
||||
hipDeviceProp_t dev_prop;
|
||||
|
||||
const hipError_t device_status = hipGetDevice(&dev);
|
||||
|
||||
if(has_error(device_status))
|
||||
return 0;
|
||||
|
||||
const hipError_t prop_status = hipGetDeviceProperties(&dev_prop, dev);
|
||||
if(has_error(prop_status))
|
||||
return 0;
|
||||
|
||||
return dev_prop.multiProcessorCount;
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
|
||||
std::cout << get_cu_count();
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -1,74 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
|
||||
import json
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def extract_test_params(config_file, output_file):
|
||||
"""Extract test parameters from config JSON and write to output file"""
|
||||
|
||||
# Read config file
|
||||
with open(config_file, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
# Extract test parameters
|
||||
test_params = []
|
||||
if "test_params" in config and "problem_sizes" in config["test_params"]:
|
||||
test_params = config["test_params"]["problem_sizes"]
|
||||
else:
|
||||
# Default test parameters if none specified
|
||||
test_params = [
|
||||
{"m": 256, "n": 256, "k": 128, "split_k": 1},
|
||||
{"m": 256, "n": 256, "k": 1024, "split_k": 1},
|
||||
{"m": 256, "n": 512, "k": 512, "split_k": 1},
|
||||
{"m": 512, "n": 256, "k": 512, "split_k": 1},
|
||||
]
|
||||
|
||||
# Write to output file in C++ format
|
||||
output_dir = Path(output_file).parent
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(output_file, "w") as f:
|
||||
f.write("// Generated test parameters for this configuration\n")
|
||||
f.write("// This file is auto-generated during CMake configuration\n\n")
|
||||
f.write("static const std::vector<GemmTestParams> CONFIG_TEST_PARAMS = {\n")
|
||||
|
||||
for i, params in enumerate(test_params):
|
||||
comma = "," if i < len(test_params) - 1 else ""
|
||||
f.write(
|
||||
f" {{{params['m']}, {params['n']}, {params['k']}, {params['split_k']}}}{comma}\n"
|
||||
)
|
||||
|
||||
f.write("};\n")
|
||||
|
||||
print(
|
||||
f"Extracted {len(test_params)} test parameters from {config_file} -> {output_file}"
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Extract test parameters from config JSON"
|
||||
)
|
||||
parser.add_argument("--config_file", required=True, help="Input config JSON file")
|
||||
parser.add_argument(
|
||||
"--output_file", required=True, help="Output test parameters file"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not os.path.exists(args.config_file):
|
||||
print(f"Error: Config file not found: {args.config_file}")
|
||||
return 1
|
||||
|
||||
extract_test_params(args.config_file, args.output_file)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
||||
@@ -1,121 +0,0 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
set(CU_COUNT 0 CACHE STRING "Number of Compute Units on the device")
|
||||
|
||||
# ============================================================================
|
||||
# get_cu_count
|
||||
#
|
||||
# Returns the CU count for the device. If the given cu_count_arg is a positive
|
||||
# integer, then the nothing happens. Otherwise, we attempt to query the CU
|
||||
# count from the device. If the query is unsucessful, the default value of 100
|
||||
# is returned.
|
||||
#
|
||||
# Parameters:
|
||||
# cu_count_arg - The starting CU count
|
||||
# ============================================================================
|
||||
function(get_cu_count cu_count_arg)
|
||||
message(STATUS "Starting query for CU count needed for Stream-K test config generation")
|
||||
|
||||
if(NOT "${${cu_count_arg}}" MATCHES "^[0-9]+$")
|
||||
message(FATAL_ERROR "The CU count must be a non-negative integer. \
|
||||
The given value of ${${cu_count_arg}} is invalid.")
|
||||
endif()
|
||||
|
||||
if("${${cu_count_arg}}" STREQUAL "0")
|
||||
|
||||
set(CPP_FILE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cu_count.cpp)
|
||||
set(CPP_EXE_PATH ${CMAKE_CURRENT_BINARY_DIR}/cu_count)
|
||||
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_HIP_COMPILER} -x hip ${CPP_FILE_PATH} -o ${CPP_EXE_PATH}
|
||||
RESULT_VARIABLE compile_exit_code
|
||||
)
|
||||
|
||||
if (NOT compile_exit_code EQUAL 0)
|
||||
message(FATAL_ERROR "Compilation of ${CPP_FILE_PATH} failed.\n")
|
||||
endif()
|
||||
|
||||
# Get the HIP library directory
|
||||
get_filename_component(HIP_COMPILER_DIR ${CMAKE_HIP_COMPILER} DIRECTORY)
|
||||
get_filename_component(HIP_ROOT_DIR ${HIP_COMPILER_DIR} DIRECTORY)
|
||||
set(HIP_LIB_DIR "${HIP_ROOT_DIR}/lib")
|
||||
|
||||
# Set library path for runtime execution
|
||||
if(WIN32)
|
||||
set(ENV{PATH} "${HIP_LIB_DIR};$ENV{PATH}")
|
||||
else()
|
||||
set(ENV{LD_LIBRARY_PATH} "${HIP_LIB_DIR}:$ENV{LD_LIBRARY_PATH}")
|
||||
endif()
|
||||
|
||||
execute_process(
|
||||
COMMAND ${CPP_EXE_PATH}
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
ERROR_VARIABLE standard_error
|
||||
OUTPUT_VARIABLE queried_cu_count
|
||||
RESULT_VARIABLE queried_cu_count_exit_code
|
||||
)
|
||||
|
||||
if (standard_error)
|
||||
message(STATUS "Error information from attempting to query HIP device and properties:\n"
|
||||
"${standard_error}")
|
||||
endif()
|
||||
|
||||
if (NOT queried_cu_count_exit_code EQUAL 0)
|
||||
message(STATUS "Failed to run ${CPP_EXE_PATH} to query the device's CU count")
|
||||
|
||||
endif()
|
||||
|
||||
|
||||
# Delete the generated cu_count executable
|
||||
file(REMOVE "${CPP_EXE_PATH}")
|
||||
|
||||
if((queried_cu_count STREQUAL "0") OR (NOT queried_cu_count_exit_code EQUAL 0))
|
||||
message(WARNING "Unable to query the number of Compute Units. \
|
||||
Please use the CU_COUNT CLI option to pass in the \
|
||||
number of Compute Units for your target device; otherwise, \
|
||||
the default value of 100 will be used.")
|
||||
set(${cu_count_arg} 100 PARENT_SCOPE)
|
||||
else()
|
||||
set(${cu_count_arg} ${queried_cu_count} PARENT_SCOPE)
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
||||
endfunction()
|
||||
|
||||
# ============================================================================
|
||||
# generate_test_configs
|
||||
#
|
||||
# Generate config json files for Stream-K tests
|
||||
#
|
||||
# Parameters:
|
||||
# cu_count_arg - The number of CUs on the device
|
||||
# tile_sizes - A list of block tile sizes: tile_m,tile_n,tile_k
|
||||
# datatype - The datatype for which the config is being generated
|
||||
# config_list - The variable to which the list of config file names are written
|
||||
# configs_path - Path to the configs directory to which config files are written
|
||||
# ============================================================================
|
||||
function(generate_test_configs cu_count_arg tile_sizes datatype config_list configs_path)
|
||||
message(STATUS "Generating Stream-K test config files for ${datatype}")
|
||||
|
||||
file(MAKE_DIRECTORY ${configs_path})
|
||||
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_SOURCE_DIR}/generate_configs.py
|
||||
--cu_count ${cu_count_arg}
|
||||
--configs_dir_path ${configs_path}
|
||||
--tiles ${tile_sizes}
|
||||
--datatype ${datatype}
|
||||
OUTPUT_VARIABLE CONFIG_LIST
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
RESULT_VARIABLE script_ret_val
|
||||
)
|
||||
|
||||
if (NOT script_ret_val EQUAL 0)
|
||||
message(FATAL_ERROR "Eror occured during execution of ${CMAKE_CURRENT_SOURCE_DIR}/generate_configs.py")
|
||||
endif()
|
||||
|
||||
set(${config_list} ${CONFIG_LIST} PARENT_SCOPE)
|
||||
|
||||
endfunction()
|
||||
@@ -1,287 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from enum import Enum
|
||||
from typing import Dict, Tuple, List
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field, asdict
|
||||
|
||||
|
||||
@dataclass
|
||||
class TileConfig:
|
||||
"""Represents the Tile Config section of a Tile Engine config"""
|
||||
|
||||
tile_m: List[int] = field(default_factory=list)
|
||||
tile_n: List[int] = field(default_factory=list)
|
||||
tile_k: List[int] = field(default_factory=list)
|
||||
warp_m: List[int] = field(default_factory=lambda: [2])
|
||||
warp_n: List[int] = field(default_factory=lambda: [2])
|
||||
warp_k: List[int] = field(default_factory=lambda: [1])
|
||||
warp_tile_m: List[int] = field(default_factory=lambda: [16, 32])
|
||||
warp_tile_n: List[int] = field(default_factory=lambda: [16, 32])
|
||||
# Temporarily only consider 16 for warp_tile_k
|
||||
# warp_tile_k: List[int] = field(default_factory=lambda: [8, 16, 32])
|
||||
warp_tile_k: List[int] = field(default_factory=lambda: [16])
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
return {k: {"values": v} for k, v in asdict(self).items()}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TraitConfig:
|
||||
"""Represents the Trait Config section of a Tile Engine config"""
|
||||
|
||||
# Temporarily only consider compv3
|
||||
# pipeline: List[str] = field(default_factory=lambda: ["compv3", "mem"])
|
||||
pipeline: List[str] = field(default_factory=lambda: ["compv3"])
|
||||
epilogue: List[str] = field(default_factory=lambda: ["cshuffle"])
|
||||
scheduler: List[str] = field(default_factory=lambda: ["intrawave"])
|
||||
pad_m: List[bool] = field(default_factory=lambda: [False])
|
||||
pad_n: List[bool] = field(default_factory=lambda: [False])
|
||||
pad_k: List[bool] = field(default_factory=lambda: [False])
|
||||
persistent: List[bool] = field(default_factory=lambda: [True, False])
|
||||
reduction_strategy: List[str] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
return {k: {"values": v} for k, v in asdict(self).items()}
|
||||
|
||||
|
||||
class TestVariant(Enum):
|
||||
"""Represents a Stream-K test variant"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
val: int,
|
||||
reduction_strategy: List[str],
|
||||
persistent: List[bool],
|
||||
datatypes: List[str],
|
||||
description: str,
|
||||
):
|
||||
self._value_ = val
|
||||
self.reduction_strategy = reduction_strategy
|
||||
self.persistent = persistent
|
||||
self.datatypes = datatypes
|
||||
self.description = description
|
||||
|
||||
ATOMIC_SMOKE = (
|
||||
0,
|
||||
["atomic"],
|
||||
[True, False],
|
||||
# Temporarily only run fp16 tests
|
||||
# ["fp16", "bf16", "fp8", "bf8"],
|
||||
["fp16"],
|
||||
"Stream-K atomic smoke tests",
|
||||
)
|
||||
REDUCTION_SMOKE = (
|
||||
2,
|
||||
["linear", "tree"],
|
||||
[True, False],
|
||||
# Temporarily only run fp16 tests
|
||||
# ["fp16", "bf16", "fp8", "bf8"],
|
||||
["fp16"],
|
||||
"Stream-K reduction smoke tests",
|
||||
)
|
||||
EXTENDED = (
|
||||
3,
|
||||
["atomic"],
|
||||
[True, False],
|
||||
# Temporarily only run fp16 tests
|
||||
# ["fp16", "bf16", "fp8", "bf8"],
|
||||
["fp16"],
|
||||
"Stream-K extended smoke tests",
|
||||
)
|
||||
|
||||
def apply(self, trait_config: TraitConfig) -> None:
|
||||
"""Applies the current test variant's persistent and reduction strategy setting to the given trait_config"""
|
||||
trait_config.persistent = self.persistent
|
||||
trait_config.reduction_strategy = self.reduction_strategy
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProblemSize:
|
||||
"""Represents a problem size in a Tile Engine config"""
|
||||
|
||||
m: int
|
||||
n: int
|
||||
k: int
|
||||
variant: TestVariant
|
||||
split_k: int = 1
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
return {"m": self.m, "n": self.n, "k": self.k, "split_k": self.split_k}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""Represents a Tile Engine config"""
|
||||
|
||||
description: str
|
||||
problem_sizes: list[ProblemSize] = field(default_factory=list)
|
||||
tile_config: TileConfig = field(default_factory=TileConfig)
|
||||
trait_config: TraitConfig = field(default_factory=TraitConfig)
|
||||
k_block_per_cu: int = 1
|
||||
permute_n: bool = False
|
||||
|
||||
def add_problem_size(self, problem: ProblemSize) -> None:
|
||||
"""Adds the given problem to this config's problem_sizes"""
|
||||
self.problem_sizes.append(problem)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
config_dict = {
|
||||
"problem": {"description": f"{self.description}"},
|
||||
"test_params": {
|
||||
"problem_sizes": [ps.to_dict() for ps in self.problem_sizes]
|
||||
},
|
||||
"tile_config": self.tile_config.to_dict(),
|
||||
"trait_config": self.trait_config.to_dict(),
|
||||
"k_block_per_cu": self.k_block_per_cu,
|
||||
"permute_n": self.permute_n,
|
||||
}
|
||||
return config_dict
|
||||
|
||||
def write_to_file(self, output_file: str) -> None:
|
||||
"""Writes this configs to the given output_file in a json format"""
|
||||
with open(output_file, "w") as config_file:
|
||||
json.dump(self.to_dict(), config_file, indent=4)
|
||||
config_file.write("\n")
|
||||
|
||||
|
||||
def create_problem_sizes(
|
||||
tile_m: int, tile_n: int, tile_k: int, cu_count: int
|
||||
) -> List[ProblemSize]:
|
||||
"""Creates and returns a list of problem sizes using the given arguments"""
|
||||
problem_sizes = [
|
||||
ProblemSize(256, 256, 256, TestVariant.ATOMIC_SMOKE),
|
||||
ProblemSize(tile_m * cu_count, tile_n, tile_k, TestVariant.ATOMIC_SMOKE),
|
||||
ProblemSize(
|
||||
tile_m * 2, tile_n * 2, cu_count * tile_k, TestVariant.ATOMIC_SMOKE
|
||||
),
|
||||
ProblemSize(tile_m, tile_n, cu_count * tile_k, TestVariant.REDUCTION_SMOKE),
|
||||
ProblemSize(
|
||||
tile_m * 4,
|
||||
tile_n,
|
||||
tile_k * cu_count + (25 * tile_k),
|
||||
TestVariant.REDUCTION_SMOKE,
|
||||
),
|
||||
ProblemSize(
|
||||
tile_m * 3,
|
||||
tile_n * 7,
|
||||
tile_k * cu_count + (30 * tile_k),
|
||||
TestVariant.REDUCTION_SMOKE,
|
||||
),
|
||||
# TODO: Add this test once we determine how to label tests as regresion with tile engine
|
||||
# ProblemSize((tile_m * cu_count * 2) + (tile_m * 2), tile_n, 2048, TestVariant.EXTENDED)
|
||||
]
|
||||
|
||||
return problem_sizes
|
||||
|
||||
|
||||
def write_config_files(
|
||||
problem_sizes: List[ProblemSize],
|
||||
configs_dir_path: str,
|
||||
datatype: str,
|
||||
tile_sizes: Tuple[int, int, int],
|
||||
) -> str:
|
||||
"""Writes the given problem_sizes to a config file and returns the names of the config files written to"""
|
||||
config_names = []
|
||||
tile_m, tile_n, tile_k = tile_sizes
|
||||
tile_config = TileConfig([tile_m], [tile_n], [tile_k])
|
||||
|
||||
# Create a config for each test variant
|
||||
for variant in TestVariant:
|
||||
problem_sizes_filtered = [ps for ps in problem_sizes if ps.variant == variant]
|
||||
|
||||
if (datatype not in variant.datatypes) or len(problem_sizes_filtered) == 0:
|
||||
continue
|
||||
|
||||
trait_config = TraitConfig()
|
||||
variant.apply(trait_config)
|
||||
config_name = f"streamk_{variant.name.lower()}_tests_config_{datatype}"
|
||||
config_names.append(config_name)
|
||||
file_path = os.path.join(configs_dir_path, config_name + ".json")
|
||||
config = Config(
|
||||
variant.description, problem_sizes_filtered, tile_config, trait_config
|
||||
)
|
||||
config.write_to_file(file_path)
|
||||
|
||||
return config_names
|
||||
|
||||
|
||||
def print_config_names(config_file_names: List[str]) -> None:
|
||||
"""Prints given config file names as a single semi-colon separated string"""
|
||||
print(";".join(config_file_names))
|
||||
|
||||
|
||||
def create_config_files(
|
||||
cu_count: int, configs_dir_path: str, tile_sizes: int, datatype: str
|
||||
) -> None:
|
||||
"""Creates Stream-K test config files and prints the file names in a semi-colon-separated list"""
|
||||
tile_m, tile_n, tile_k = tile_sizes
|
||||
|
||||
problem_sizes = create_problem_sizes(tile_m, tile_n, tile_k, cu_count)
|
||||
config_names = write_config_files(
|
||||
problem_sizes, configs_dir_path, datatype, tile_sizes
|
||||
)
|
||||
print_config_names(config_names)
|
||||
|
||||
|
||||
def get_args() -> Tuple[int, str, Tuple[int, int, int], str]:
|
||||
"""Returns user provided arguments"""
|
||||
|
||||
def tile_sizes_type(val: str):
|
||||
sizes = None
|
||||
parts = val.split(",")
|
||||
if len(parts) != 3:
|
||||
raise argparse.ArgumentTypeError(
|
||||
"--tiles must contain exactly three comma-separated values (m,n,k), e.g. --tiles 256,256,32"
|
||||
)
|
||||
try:
|
||||
sizes = tuple(int(size) for size in parts)
|
||||
except ValueError:
|
||||
raise argparse.ArgumentTypeError(
|
||||
"--tiles must contain exactly three comma-separated integers (m,n,k), e.g. --tiles 256,256,32"
|
||||
)
|
||||
|
||||
return sizes
|
||||
|
||||
parser = argparse.ArgumentParser(description="Create Stream-K test configs")
|
||||
parser.add_argument(
|
||||
"--cu_count", required=True, help="Number of Compute Units on the device"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--configs_dir_path",
|
||||
required=True,
|
||||
help="Full path configs directory where config files will be written to",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tiles",
|
||||
required=True,
|
||||
type=tile_sizes_type,
|
||||
help="Block tile sizes for m, n, and k, respectively. Ex: --tiles 256,256,32",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--datatype",
|
||||
choices=["fp16", "bf16", "fp8", "bf8"],
|
||||
required=True,
|
||||
help="The datatype for which the config is generated.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return (int(args.cu_count), args.configs_dir_path, args.tiles, args.datatype)
|
||||
|
||||
|
||||
def main():
|
||||
cu_count, configs_dir_path, tile_sizes, datatype = get_args()
|
||||
create_config_files(cu_count, configs_dir_path, tile_sizes, datatype)
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,258 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/**
|
||||
* @file test_gemm_simple.cpp
|
||||
* @brief Unit tests for GEMM kernels generated by gemm_instance_builder
|
||||
*
|
||||
* This test includes kernels generated during CMake configuration by
|
||||
* gemm_instance_builder.py and tests them with problem sizes extracted
|
||||
* from the corresponding JSON configuration files.
|
||||
*/
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <iostream>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp"
|
||||
|
||||
// The kernel header is included via compile command line with -include flag
|
||||
// It defines SelectedKernel struct, KERNEL_NAME, and tensor data types
|
||||
|
||||
// Adaptive error threshold calculation matching tile_engine's implementation
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
/// @brief Function to compare the results of the device and host computations (from tile_engine)
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
bool compare_results(std::string instanceName,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t kbatch,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result)
|
||||
{
|
||||
const float max_accumulated_value =
|
||||
std::abs(static_cast<float>(*std::max_element(c_m_n_host_result.mData.begin(),
|
||||
c_m_n_host_result.mData.end(),
|
||||
[](CDataType a, CDataType b) {
|
||||
return std::abs(static_cast<float>(a)) <
|
||||
std::abs(static_cast<float>(b));
|
||||
})));
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "For " << instanceName << " Relative error threshold is "
|
||||
<< rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is "
|
||||
<< rtol_atol.at(ck_tile::number<1>{}) << std::endl;
|
||||
std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
// Test parameter structure for matrix dimensions and split_k values
|
||||
struct GemmTestParams
|
||||
{
|
||||
int m, n, k, split_k;
|
||||
};
|
||||
|
||||
// Include config-specific test parameters (after GemmTestParams struct is defined)
|
||||
#ifdef GEMM_TEST_PARAMS_HPP
|
||||
#include GEMM_TEST_PARAMS_HPP
|
||||
#endif
|
||||
|
||||
class StreamKGemmTileEngineTest : public ::testing::TestWithParam<GemmTestParams>
|
||||
{
|
||||
protected:
|
||||
void SetUp() override
|
||||
{
|
||||
auto params = GetParam();
|
||||
m_ = params.m;
|
||||
n_ = params.n;
|
||||
k_ = params.k;
|
||||
split_k_ = params.split_k;
|
||||
|
||||
// Calculate strides (following tile_engine pattern)
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
stride_a_ = k_;
|
||||
}
|
||||
else
|
||||
{
|
||||
stride_a_ = m_;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
stride_b_ = n_;
|
||||
}
|
||||
else
|
||||
{
|
||||
stride_b_ = k_;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
stride_c_ = n_;
|
||||
}
|
||||
else
|
||||
{
|
||||
stride_c_ = m_;
|
||||
}
|
||||
}
|
||||
|
||||
// Test dimensions
|
||||
int m_, n_, k_, split_k_;
|
||||
int stride_a_, stride_b_, stride_c_;
|
||||
};
|
||||
|
||||
TEST_P(StreamKGemmTileEngineTest, BasicFunctionality)
|
||||
{
|
||||
// Check that kernel information is available
|
||||
EXPECT_TRUE(strlen(KERNEL_NAME) > 0) << "Kernel name should not be empty";
|
||||
|
||||
std::cout << "Testing kernel: " << KERNEL_NAME << std::endl;
|
||||
std::cout << "Problem size: " << m_ << "x" << n_ << "x" << k_ << std::endl;
|
||||
|
||||
// Get tensor layouts from generated kernel
|
||||
const ALayout layout_a = ALayout{};
|
||||
const BLayout layout_b = BLayout{};
|
||||
const CLayout layout_c = CLayout{};
|
||||
|
||||
// Calculate tensor strides
|
||||
int stride_a_calc = ck_tile::get_default_stride(m_, k_, 0, is_row_major(layout_a));
|
||||
int stride_b_calc = ck_tile::get_default_stride(k_, n_, 0, is_row_major(layout_b));
|
||||
int stride_c_calc = ck_tile::get_default_stride(m_, n_, 0, is_row_major(layout_c));
|
||||
|
||||
// Create host tensors with proper descriptors
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(m_, k_, stride_a_calc, is_row_major(layout_a)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(k_, n_, stride_b_calc, is_row_major(layout_b)));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
ck_tile::host_tensor_descriptor(m_, n_, stride_c_calc, is_row_major(layout_c)));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_ref(
|
||||
ck_tile::host_tensor_descriptor(m_, n_, stride_c_calc, is_row_major(layout_c)));
|
||||
|
||||
// Initialize input tensors with uniform random distribution [-1.0, 1.0] (matches tile_engine)
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
|
||||
c_m_n_dev_ref.SetZero();
|
||||
|
||||
// Allocate GPU device memory
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem ref_c_m_n_dev_buf(c_m_n_dev_ref.get_element_space_size_in_bytes());
|
||||
|
||||
// Copy data to device and zero output buffer
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
ref_c_m_n_dev_buf.SetZero();
|
||||
|
||||
// Calculate reference result on device for verification
|
||||
ADataType* a_m_k_dev_ref_ptr = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
|
||||
BDataType* b_k_n_dev_ref_ptr = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
|
||||
CDataType* c_m_n_dev_ref_ptr = static_cast<CDataType*>(ref_c_m_n_dev_buf.GetDeviceBuffer());
|
||||
ck_tile::
|
||||
reference_gemm_gpu<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
a_m_k_dev_ref_ptr,
|
||||
b_k_n_dev_ref_ptr,
|
||||
c_m_n_dev_ref_ptr,
|
||||
m_,
|
||||
n_,
|
||||
k_,
|
||||
stride_a_calc,
|
||||
stride_b_calc,
|
||||
stride_c_calc);
|
||||
ref_c_m_n_dev_buf.FromDevice(c_m_n_dev_ref.data());
|
||||
|
||||
// Create GEMM kernel arguments
|
||||
ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
m_,
|
||||
n_,
|
||||
k_,
|
||||
stride_a_calc,
|
||||
stride_b_calc,
|
||||
stride_c_calc};
|
||||
|
||||
// Configure kernel execution for maximum speed (no timing, no debug output)
|
||||
ck_tile::stream_config stream_config{nullptr, // stream
|
||||
false, // time_kernel (disable timing for speed)
|
||||
0, // log_level (disable debug output)
|
||||
0, // n_warmup
|
||||
1, // n_repeat
|
||||
false, // is_gpu_timer (unused when time_kernel=false)
|
||||
false, // flush_cache
|
||||
1}; // rotating_count
|
||||
|
||||
// Launch the generated kernel (no timing overhead for fastest execution)
|
||||
std::tuple<float, ck_tile::index_t> launch_result;
|
||||
try
|
||||
{
|
||||
launch_result = SelectedKernel::launch(args, stream_config);
|
||||
// Kernel launched successfully if no exception thrown
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::string error_msg(e.what());
|
||||
// If arguments not supported, skip the test (configuration validation failure, not a bug)
|
||||
if(error_msg.find("Arguments not supported") != std::string::npos)
|
||||
{
|
||||
GTEST_SKIP() << "Configuration not supported: " << e.what();
|
||||
}
|
||||
else
|
||||
{
|
||||
FAIL() << "Kernel launch failed: " << e.what();
|
||||
}
|
||||
}
|
||||
|
||||
// Copy result back from device
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
// Verify results using tile_engine's adaptive error thresholds
|
||||
const ck_tile::index_t num_wgs_per_tile = get<1>(launch_result);
|
||||
bool verification_passed = compare_results<ADataType, BDataType, AccDataType, CDataType>(
|
||||
KERNEL_NAME, k_, num_wgs_per_tile, c_m_n_dev_result, c_m_n_dev_ref);
|
||||
|
||||
EXPECT_TRUE(verification_passed) << "GEMM result verification failed";
|
||||
}
|
||||
|
||||
// Use config-specific test parameters (included via compile flags)
|
||||
// CONFIG_TEST_PARAMS is defined in the auto-generated test_params.hpp file
|
||||
INSTANTIATE_TEST_SUITE_P(GemmVerification,
|
||||
StreamKGemmTileEngineTest,
|
||||
::testing::ValuesIn(CONFIG_TEST_PARAMS),
|
||||
[](const ::testing::TestParamInfo<GemmTestParams>& param_info) {
|
||||
return std::to_string(param_info.param.m) + "x" +
|
||||
std::to_string(param_info.param.n) + "x" +
|
||||
std::to_string(param_info.param.k) + "_splitk" +
|
||||
std::to_string(param_info.param.split_k);
|
||||
});
|
||||
Reference in New Issue
Block a user