From ddf49459bae33f19e648831a2c02343672dd22dc Mon Sep 17 00:00:00 2001 From: Emily Martins <65371150+ecamartins@users.noreply.github.com> Date: Wed, 11 Mar 2026 15:09:17 -0600 Subject: [PATCH] [CK_TILE] Add the GEMM Memory pipeline to Stream-K tests (#5242) ## Motivation We want to extend our Stream-K coverage to include other GEMM pipelines, since our current tests only test the CompV3 pipeline. ## Technical Details All Stream-K unit tests currently test only one pipeline: CompV3. These changes extend the test support to also test the Memory pipeline. Future work will add support for additional GEMM pipelines. The major changes are as follows: - **Remove fp8 and bf8 extended tests for gfx90a**: gfx90a does not have native support for fp8 and bf8 and emulates the behavior with fp32 mfma instruction sizes. We've observed extremely long compile times for fp8 and bf8 on gfx90a (exceeding 15 minutes), hence we've opted to disable these tests. - **Add the memory pipeline to the Stream-K tile engine tests**: Our smoke tests now cover the compv3 and memory pipelines. - **Add the memory pipeline to the Stream-K extended tests**: These changes modify the test kernel types to include the appropriate pipeline. Each pipeline is contained within a separate kernel type to help avoid large increases in build time. ## Test Plan - Ran existing and added tests on all architectures. ## Test Result - All local tests pass. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. 
--- test/ck_tile/gemm_streamk/CMakeLists.txt | 35 ++-- ...gemm_streamk_bf16_nonpersistent_compv3.cpp | 18 +++ ...st_gemm_streamk_bf16_nonpersistent_mem.cpp | 17 ++ ...st_gemm_streamk_bf16_persistent_compv3.cpp | 17 ++ ...test_gemm_streamk_bf16_persistent_mem.cpp} | 6 +- ..._gemm_streamk_bf8_nonpersistent_compv3.cpp | 17 ++ ...st_gemm_streamk_bf8_nonpersistent_mem.cpp} | 6 +- ...st_gemm_streamk_bf8_persistent_compv3.cpp} | 6 +- ... test_gemm_streamk_bf8_persistent_mem.cpp} | 6 +- ...gemm_streamk_fp16_nonpersistent_compv3.cpp | 18 +++ ...st_gemm_streamk_fp16_nonpersistent_mem.cpp | 17 ++ .../test_gemm_streamk_fp16_persistent.cpp | 17 -- ...st_gemm_streamk_fp16_persistent_compv3.cpp | 17 ++ ...test_gemm_streamk_fp16_persistent_mem.cpp} | 6 +- ..._gemm_streamk_fp8_nonpersistent_compv3.cpp | 17 ++ ...est_gemm_streamk_fp8_nonpersistent_mem.cpp | 17 ++ .../test_gemm_streamk_fp8_persistent.cpp | 17 -- ...est_gemm_streamk_fp8_persistent_compv3.cpp | 17 ++ ... test_gemm_streamk_fp8_persistent_mem.cpp} | 6 +- .../gemm_streamk/test_gemm_streamk_types.hpp | 149 +++++++++++++----- .../gemm_streamk/test_gemm_streamk_util.hpp | 51 +++--- .../generate_configs.py | 2 +- 22 files changed, 351 insertions(+), 128 deletions(-) create mode 100644 test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv3.cpp create mode 100644 test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_nonpersistent_mem.cpp create mode 100644 test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_persistent_compv3.cpp rename test/ck_tile/gemm_streamk/extended_tests/{test_gemm_streamk_bf16_nonpersistent.cpp => test_gemm_streamk_bf16_persistent_mem.cpp} (55%) create mode 100644 test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv3.cpp rename test/ck_tile/gemm_streamk/extended_tests/{test_gemm_streamk_bf16_persistent.cpp => test_gemm_streamk_bf8_nonpersistent_mem.cpp} (52%) rename 
test/ck_tile/gemm_streamk/extended_tests/{test_gemm_streamk_bf8_persistent.cpp => test_gemm_streamk_bf8_persistent_compv3.cpp} (52%) rename test/ck_tile/gemm_streamk/extended_tests/{test_gemm_streamk_fp8_nonpersistent.cpp => test_gemm_streamk_bf8_persistent_mem.cpp} (55%) create mode 100644 test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv3.cpp create mode 100644 test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_nonpersistent_mem.cpp delete mode 100644 test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_persistent.cpp create mode 100644 test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_persistent_compv3.cpp rename test/ck_tile/gemm_streamk/extended_tests/{test_gemm_streamk_fp16_nonpersistent.cpp => test_gemm_streamk_fp16_persistent_mem.cpp} (55%) create mode 100644 test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv3.cpp create mode 100644 test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_nonpersistent_mem.cpp delete mode 100644 test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_persistent.cpp create mode 100644 test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_persistent_compv3.cpp rename test/ck_tile/gemm_streamk/extended_tests/{test_gemm_streamk_bf8_nonpersistent.cpp => test_gemm_streamk_fp8_persistent_mem.cpp} (55%) diff --git a/test/ck_tile/gemm_streamk/CMakeLists.txt b/test/ck_tile/gemm_streamk/CMakeLists.txt index b38b8a63b9..f6eb33bf76 100644 --- a/test/ck_tile/gemm_streamk/CMakeLists.txt +++ b/test/ck_tile/gemm_streamk/CMakeLists.txt @@ -23,16 +23,31 @@ if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950") #TODO: support all arches #TODO: current c-shuffle only supports C layout as R add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp) - add_gtest_executable(test_ck_tile_streamk_extended - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent.cpp - 
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent.cpp - test_gemm_streamk_util.cpp) + set(STREAMK_EXTENDED_SOURCES + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_compv3.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_mem.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_compv3.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_mem.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv3.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_mem.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv3.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_mem.cpp + test_gemm_streamk_util.cpp) + + # We only test fp8 and bf8 on gfx942 and gfx950 since these types are not natively supported on gfx90a + if(GPU_TARGETS MATCHES "gfx942|gfx950") + list(APPEND STREAMK_EXTENDED_SOURCES + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_compv3.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_mem.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_compv3.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_mem.cpp + 
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv3.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_mem.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv3.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_mem.cpp) + endif() + + add_gtest_executable(test_ck_tile_streamk_extended ${STREAMK_EXTENDED_SOURCES}) target_compile_options(test_ck_tile_streamk_extended PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) # Collect all test targets for umbrella label diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv3.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv3.cpp new file mode 100644 index 0000000000..2e35690b3d --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv3.cpp @@ -0,0 +1,18 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +template +class TestCkTileStreamKBf16NonPersistentCompV3 : public TestCkTileStreamK +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistentCompV3 + +TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistentCompV3, + KernelTypesStreamKBf16NonPersistentCompV3); + +#include "test_gemm_streamk_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_nonpersistent_mem.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_nonpersistent_mem.cpp new file mode 100644 index 0000000000..ab1dbffcdb --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_nonpersistent_mem.cpp @@ -0,0 +1,17 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +template +class TestCkTileStreamKBf16NonPersistentMem : public TestCkTileStreamK +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistentMem + +TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistentMem, KernelTypesStreamKBf16NonPersistentMem); + +#include "test_gemm_streamk_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_persistent_compv3.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_persistent_compv3.cpp new file mode 100644 index 0000000000..24385201a1 --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_persistent_compv3.cpp @@ -0,0 +1,17 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +template +class TestCkTileStreamKBf16PersistentCompV3 : public TestCkTileStreamK +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKBf16PersistentCompV3 + +TYPED_TEST_SUITE(TestCkTileStreamKBf16PersistentCompV3, KernelTypesStreamKBf16PersistentCompV3); + +#include "test_gemm_streamk_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_nonpersistent.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_persistent_mem.cpp similarity index 55% rename from test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_nonpersistent.cpp rename to test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_persistent_mem.cpp index 7c9c2c9657..94f9def529 100644 --- a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_nonpersistent.cpp +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_persistent_mem.cpp @@ -4,13 +4,13 @@ #include "test_gemm_streamk_common_includes.hpp" template -class TestCkTileStreamKBf16NonPersistent : public 
TestCkTileStreamK +class TestCkTileStreamKBf16PersistentMem : public TestCkTileStreamK { }; -#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistent +#define TEST_SUITE_NAME TestCkTileStreamKBf16PersistentMem -TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistent, KernelTypesStreamKBf16NonPersistent); +TYPED_TEST_SUITE(TestCkTileStreamKBf16PersistentMem, KernelTypesStreamKBf16PersistentMem); #include "test_gemm_streamk_extended_cases.inc" diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv3.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv3.cpp new file mode 100644 index 0000000000..a0a04d79e2 --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv3.cpp @@ -0,0 +1,17 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +template +class TestCkTileStreamKBf8NonPersistentCompV3 : public TestCkTileStreamK +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistentCompV3 + +TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistentCompV3, KernelTypesStreamKBf8NonPersistentCompV3); + +#include "test_gemm_streamk_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_persistent.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_nonpersistent_mem.cpp similarity index 52% rename from test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_persistent.cpp rename to test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_nonpersistent_mem.cpp index dd4bbad61b..5a6447416d 100644 --- a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_persistent.cpp +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_nonpersistent_mem.cpp @@ -4,13 +4,13 @@ #include "test_gemm_streamk_common_includes.hpp" template -class 
TestCkTileStreamKBf16Persistent : public TestCkTileStreamK +class TestCkTileStreamKBf8NonPersistentMem : public TestCkTileStreamK { }; -#define TEST_SUITE_NAME TestCkTileStreamKBf16Persistent +#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistentMem -TYPED_TEST_SUITE(TestCkTileStreamKBf16Persistent, KernelTypesStreamKBf16Persistent); +TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistentMem, KernelTypesStreamKBf8NonPersistentMem); #include "test_gemm_streamk_extended_cases.inc" diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_persistent.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_persistent_compv3.cpp similarity index 52% rename from test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_persistent.cpp rename to test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_persistent_compv3.cpp index 5f1bdaca86..0a6c2346d8 100644 --- a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_persistent.cpp +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_persistent_compv3.cpp @@ -4,13 +4,13 @@ #include "test_gemm_streamk_common_includes.hpp" template -class TestCkTileStreamKBf8Persistent : public TestCkTileStreamK +class TestCkTileStreamKBf8PersistentCompV3 : public TestCkTileStreamK { }; -#define TEST_SUITE_NAME TestCkTileStreamKBf8Persistent +#define TEST_SUITE_NAME TestCkTileStreamKBf8PersistentCompV3 -TYPED_TEST_SUITE(TestCkTileStreamKBf8Persistent, KernelTypesStreamKBf8Persistent); +TYPED_TEST_SUITE(TestCkTileStreamKBf8PersistentCompV3, KernelTypesStreamKBf8PersistentCompV3); #include "test_gemm_streamk_extended_cases.inc" diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_nonpersistent.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_persistent_mem.cpp similarity index 55% rename from test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_nonpersistent.cpp rename to 
test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_persistent_mem.cpp index 0cdb4091d1..1eef56c971 100644 --- a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_nonpersistent.cpp +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_persistent_mem.cpp @@ -4,13 +4,13 @@ #include "test_gemm_streamk_common_includes.hpp" template -class TestCkTileStreamKFp8NonPersistent : public TestCkTileStreamK +class TestCkTileStreamKBf8PersistentMem : public TestCkTileStreamK { }; -#define TEST_SUITE_NAME TestCkTileStreamKFp8NonPersistent +#define TEST_SUITE_NAME TestCkTileStreamKBf8PersistentMem -TYPED_TEST_SUITE(TestCkTileStreamKFp8NonPersistent, KernelTypesStreamKFp8NonPersistent); +TYPED_TEST_SUITE(TestCkTileStreamKBf8PersistentMem, KernelTypesStreamKBf8PersistentMem); #include "test_gemm_streamk_extended_cases.inc" diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv3.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv3.cpp new file mode 100644 index 0000000000..3381554d1e --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv3.cpp @@ -0,0 +1,18 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +template +class TestCkTileStreamKFp16NonPersistentCompV3 : public TestCkTileStreamK +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKFp16NonPersistentCompV3 + +TYPED_TEST_SUITE(TestCkTileStreamKFp16NonPersistentCompV3, + KernelTypesStreamKFp16NonPersistentCompV3); + +#include "test_gemm_streamk_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_nonpersistent_mem.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_nonpersistent_mem.cpp new file mode 100644 index 0000000000..2f7dd7be33 --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_nonpersistent_mem.cpp @@ -0,0 +1,17 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +template +class TestCkTileStreamKFp16NonPersistentMem : public TestCkTileStreamK +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKFp16NonPersistentMem + +TYPED_TEST_SUITE(TestCkTileStreamKFp16NonPersistentMem, KernelTypesStreamKFp16NonPersistentMem); + +#include "test_gemm_streamk_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_persistent.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_persistent.cpp deleted file mode 100644 index 33b474526c..0000000000 --- a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_persistent.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#include "test_gemm_streamk_common_includes.hpp" - -template -class TestCkTileStreamKFp16Persistent : public TestCkTileStreamK -{ -}; - -#define TEST_SUITE_NAME TestCkTileStreamKFp16Persistent - -TYPED_TEST_SUITE(TestCkTileStreamKFp16Persistent, KernelTypesStreamKFp16Persistent); - -#include "test_gemm_streamk_extended_cases.inc" - -#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_persistent_compv3.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_persistent_compv3.cpp new file mode 100644 index 0000000000..3c041a3652 --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_persistent_compv3.cpp @@ -0,0 +1,17 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +template +class TestCkTileStreamKFp16PersistentCompV3 : public TestCkTileStreamK +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKFp16PersistentCompV3 + +TYPED_TEST_SUITE(TestCkTileStreamKFp16PersistentCompV3, KernelTypesStreamKFp16PersistentCompV3); + +#include "test_gemm_streamk_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_nonpersistent.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_persistent_mem.cpp similarity index 55% rename from test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_nonpersistent.cpp rename to test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_persistent_mem.cpp index f1a3bad142..c05135943f 100644 --- a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_nonpersistent.cpp +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_persistent_mem.cpp @@ -4,13 +4,13 @@ #include "test_gemm_streamk_common_includes.hpp" template -class TestCkTileStreamKFp16NonPersistent : public TestCkTileStreamK +class 
TestCkTileStreamKFp16PersistentMem : public TestCkTileStreamK { }; -#define TEST_SUITE_NAME TestCkTileStreamKFp16NonPersistent +#define TEST_SUITE_NAME TestCkTileStreamKFp16PersistentMem -TYPED_TEST_SUITE(TestCkTileStreamKFp16NonPersistent, KernelTypesStreamKFp16NonPersistent); +TYPED_TEST_SUITE(TestCkTileStreamKFp16PersistentMem, KernelTypesStreamKFp16PersistentMem); #include "test_gemm_streamk_extended_cases.inc" diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv3.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv3.cpp new file mode 100644 index 0000000000..379702a10a --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv3.cpp @@ -0,0 +1,17 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +template +class TestCkTileStreamKFp8NonPersistentCompV3 : public TestCkTileStreamK +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKFp8NonPersistentCompV3 + +TYPED_TEST_SUITE(TestCkTileStreamKFp8NonPersistentCompV3, KernelTypesStreamKFp8NonPersistentCompV3); + +#include "test_gemm_streamk_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_nonpersistent_mem.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_nonpersistent_mem.cpp new file mode 100644 index 0000000000..3d545a61c6 --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_nonpersistent_mem.cpp @@ -0,0 +1,17 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +template +class TestCkTileStreamKFp8NonPersistentMem : public TestCkTileStreamK +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKFp8NonPersistentMem + +TYPED_TEST_SUITE(TestCkTileStreamKFp8NonPersistentMem, KernelTypesStreamKFp8NonPersistentMem); + +#include "test_gemm_streamk_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_persistent.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_persistent.cpp deleted file mode 100644 index d418c889cd..0000000000 --- a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_persistent.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "test_gemm_streamk_common_includes.hpp" - -template -class TestCkTileStreamKFp8Persistent : public TestCkTileStreamK -{ -}; - -#define TEST_SUITE_NAME TestCkTileStreamKFp8Persistent - -TYPED_TEST_SUITE(TestCkTileStreamKFp8Persistent, KernelTypesStreamKFp8Persistent); - -#include "test_gemm_streamk_extended_cases.inc" - -#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_persistent_compv3.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_persistent_compv3.cpp new file mode 100644 index 0000000000..dccdcaf270 --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_persistent_compv3.cpp @@ -0,0 +1,17 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +template +class TestCkTileStreamKFp8PersistentCompV3 : public TestCkTileStreamK +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKFp8PersistentCompV3 + +TYPED_TEST_SUITE(TestCkTileStreamKFp8PersistentCompV3, KernelTypesStreamKFp8PersistentCompV3); + +#include "test_gemm_streamk_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_nonpersistent.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_persistent_mem.cpp similarity index 55% rename from test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_nonpersistent.cpp rename to test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_persistent_mem.cpp index 9b3b0fccb9..88ebdf1e55 100644 --- a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_nonpersistent.cpp +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_persistent_mem.cpp @@ -4,13 +4,13 @@ #include "test_gemm_streamk_common_includes.hpp" template -class TestCkTileStreamKBf8NonPersistent : public TestCkTileStreamK +class TestCkTileStreamKFp8PersistentMem : public TestCkTileStreamK { }; -#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistent +#define TEST_SUITE_NAME TestCkTileStreamKFp8PersistentMem -TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistent, KernelTypesStreamKBf8NonPersistent); +TYPED_TEST_SUITE(TestCkTileStreamKFp8PersistentMem, KernelTypesStreamKFp8PersistentMem); #include "test_gemm_streamk_extended_cases.inc" diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp index efb7416580..ca8ffee219 100644 --- a/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp @@ -6,6 +6,7 @@ #include #include "gtest/gtest.h" #include "ck_tile/host.hpp" +#include "test_gemm_streamk_util.hpp" using F8 = ck_tile::fp8_t; 
using F16 = ck_tile::half_t; @@ -19,69 +20,131 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Persistent = std::true_type; using NonPersistent = std::false_type; +using Mem = ck_tile::integral_constant; +using CompV3 = ck_tile::integral_constant; + using I32 = ck_tile::number<32>; using I128 = ck_tile::number<128>; using I256 = ck_tile::number<256>; // clang-format off -using KernelTypesStreamKFp16Persistent = ::testing::Types< -// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent - std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>, - std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>, - std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>, - std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent> +// ========================== CompV3 Pipeline ========================== + +using KernelTypesStreamKFp16PersistentCompV3 = ::testing::Types< +// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent Pipeline + + std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV3>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV3>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV3>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV3> >; -using KernelTypesStreamKBf16Persistent = ::testing::Types< - std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>, - std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>, - std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>, - std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent> +using KernelTypesStreamKBf16PersistentCompV3 = ::testing::Types< + std::tuple< Row, Row, Row, BF16, BF16, F32, 
BF16, I256, I256, I32, Persistent, CompV3>, + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV3>, + std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV3>, + std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV3> >; -using KernelTypesStreamKBf8Persistent = ::testing::Types< - std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent>, - std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent>, - std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent>, - std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent> +using KernelTypesStreamKBf8PersistentCompV3 = ::testing::Types< + std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV3>, + std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV3>, + std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV3>, + std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV3> >; -using KernelTypesStreamKFp8Persistent = ::testing::Types< - std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent>, - std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent>, - std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent>, - std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent> +using KernelTypesStreamKFp8PersistentCompV3 = ::testing::Types< + std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV3>, + std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV3>, + std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV3>, + std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV3> >; -using KernelTypesStreamKFp16NonPersistent = ::testing::Types< -// ALayout BLayout CLayout ADataType BDataType 
AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent - - std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>, - std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>, - std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>, - std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent> +using KernelTypesStreamKFp16NonPersistentCompV3 = ::testing::Types< + std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV3>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV3>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV3>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV3> >; -using KernelTypesStreamKBf16NonPersistent = ::testing::Types< - std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>, - std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>, - std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>, - std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent> +using KernelTypesStreamKBf16NonPersistentCompV3 = ::testing::Types< + std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV3>, + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV3>, + std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV3>, + std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV3> >; -using KernelTypesStreamKBf8NonPersistent = ::testing::Types< - std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent>, - std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent>, - std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent>, - std::tuple< 
Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent> +using KernelTypesStreamKBf8NonPersistentCompV3 = ::testing::Types< + std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV3>, + std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV3>, + std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV3>, + std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV3> >; -using KernelTypesStreamKFp8NonPersistent = ::testing::Types< - std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent>, - std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent>, - std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent>, - std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent> +using KernelTypesStreamKFp8NonPersistentCompV3 = ::testing::Types< + std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3>, + std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3>, + std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3>, + std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3> +>; + +// ============================= Mem Pipeline ============================= + +using KernelTypesStreamKFp16PersistentMem = ::testing::Types< + std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, Mem>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, Mem>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, Mem>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, Mem> +>; + +using KernelTypesStreamKBf16PersistentMem = ::testing::Types< + std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, Mem>, + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, 
I256, I32, Persistent, Mem>, + std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, Mem>, + std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, Mem> +>; + +using KernelTypesStreamKBf8PersistentMem = ::testing::Types< + std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, Mem>, + std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, Mem>, + std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, Mem>, + std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, Mem> +>; + +using KernelTypesStreamKFp8PersistentMem = ::testing::Types< + std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, Mem>, + std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, Mem>, + std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, Mem>, + std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, Mem> +>; + +using KernelTypesStreamKFp16NonPersistentMem = ::testing::Types< + std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, Mem>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, Mem>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, Mem>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, Mem> +>; + +using KernelTypesStreamKBf16NonPersistentMem = ::testing::Types< + std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, Mem>, + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, Mem>, + std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, Mem>, + std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, Mem> +>; + +using KernelTypesStreamKBf8NonPersistentMem = ::testing::Types< + std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, 
NonPersistent, Mem>, + std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, Mem>, + std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, Mem>, + std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, Mem> +>; + +using KernelTypesStreamKFp8NonPersistentMem = ::testing::Types< + std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, Mem>, + std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, Mem>, + std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, Mem>, + std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, Mem> >; // clang-format on diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp index 5e3b85c009..af1bab34bf 100644 --- a/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp @@ -11,6 +11,27 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" +enum struct GemmPipelineType +{ + Mem, + CompV3 +}; + +template +struct GemmPipelineTypeSelector; + +template +struct GemmPipelineTypeSelector +{ + using pipeline = ck_tile::GemmPipelineAgBgCrMem; +}; + +template +struct GemmPipelineTypeSelector +{ + using pipeline = ck_tile::GemmPipelineAgBgCrCompV3; +}; + template auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, @@ -56,6 +77,7 @@ class TestCkTileStreamK : public ::testing::Test static constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, Tuple>::value; static constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, Tuple>::value; static constexpr bool Persistent = std::tuple_element_t<10, Tuple>::value; + static constexpr auto PipelineType = std::tuple_element_t<11, Tuple>::value; template ; - // For initial testing, we will just test with one pipeline. 
- // More extensive testing is coming later and will test other pipelines. - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + using GemmPipeline = GemmPipelineTypeSelector::pipeline; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem( + ck_tile::index_t num_accumulations_per_tile = + invoke_streamk( args, ck_tile::stream_config{nullptr, false, 0, 0, 1}); - } - else if(reduction_strategy == ck_tile::StreamKReductionStrategy::Linear) - { - num_accumulations_per_tile = invoke_streamk( - args, ck_tile::stream_config{nullptr, false, 0, 0, 1}); - } - else - { - num_accumulations_per_tile = invoke_streamk( - args, ck_tile::stream_config{nullptr, false, 0, 0, 1}); - } c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); diff --git a/test/ck_tile/gemm_streamk_tile_engine/generate_configs.py b/test/ck_tile/gemm_streamk_tile_engine/generate_configs.py index 166ea940c3..0f2673c6dd 100644 --- a/test/ck_tile/gemm_streamk_tile_engine/generate_configs.py +++ b/test/ck_tile/gemm_streamk_tile_engine/generate_configs.py @@ -33,7 +33,7 @@ class TileConfig: class TraitConfig: """Represents the Trait Config section of a Tile Engine config""" - pipeline: List[str] = field(default_factory=lambda: ["compv3"]) + pipeline: List[str] = field(default_factory=lambda: ["compv3", "mem"]) epilogue: List[str] = field(default_factory=lambda: ["cshuffle"]) scheduler: List[str] = field(default_factory=lambda: ["intrawave"]) pad_m: List[bool] = field(default_factory=lambda: [False])