diff --git a/test/ck_tile/gemm_streamk/CMakeLists.txt b/test/ck_tile/gemm_streamk/CMakeLists.txt index b38b8a63b9..f6eb33bf76 100644 --- a/test/ck_tile/gemm_streamk/CMakeLists.txt +++ b/test/ck_tile/gemm_streamk/CMakeLists.txt @@ -23,16 +23,31 @@ if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950") #TODO: support all arches #TODO: current c-shuffle only supports C layout as R add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp) - add_gtest_executable(test_ck_tile_streamk_extended - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent.cpp - test_gemm_streamk_util.cpp) + set(STREAMK_EXTENDED_SOURCES + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_compv3.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_mem.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_compv3.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_mem.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv3.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_mem.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv3.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_mem.cpp + test_gemm_streamk_util.cpp) + + # We only test fp8 and bf8 on gfx942 and gfx950 since these types are not natively supported on gfx90a + if(GPU_TARGETS MATCHES "gfx942|gfx950") + list(APPEND STREAMK_EXTENDED_SOURCES + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_compv3.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_mem.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_compv3.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_mem.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv3.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_mem.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv3.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_mem.cpp) + endif() + + add_gtest_executable(test_ck_tile_streamk_extended ${STREAMK_EXTENDED_SOURCES}) target_compile_options(test_ck_tile_streamk_extended PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) # Collect all test targets for umbrella label diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv3.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv3.cpp new file mode 100644 index 0000000000..2e35690b3d --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv3.cpp @@ -0,0 +1,18 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +template +class TestCkTileStreamKBf16NonPersistentCompV3 : public TestCkTileStreamK +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistentCompV3 + +TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistentCompV3, + KernelTypesStreamKBf16NonPersistentCompV3); + +#include "test_gemm_streamk_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_nonpersistent_mem.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_nonpersistent_mem.cpp new file mode 100644 index 0000000000..ab1dbffcdb --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_nonpersistent_mem.cpp @@ -0,0 +1,17 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +template +class TestCkTileStreamKBf16NonPersistentMem : public TestCkTileStreamK +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistentMem + +TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistentMem, KernelTypesStreamKBf16NonPersistentMem); + +#include "test_gemm_streamk_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_persistent_compv3.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_persistent_compv3.cpp new file mode 100644 index 0000000000..24385201a1 --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_persistent_compv3.cpp @@ -0,0 +1,17 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +template +class TestCkTileStreamKBf16PersistentCompV3 : public TestCkTileStreamK +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKBf16PersistentCompV3 + +TYPED_TEST_SUITE(TestCkTileStreamKBf16PersistentCompV3, KernelTypesStreamKBf16PersistentCompV3); + +#include "test_gemm_streamk_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_nonpersistent.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_persistent_mem.cpp similarity index 55% rename from test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_nonpersistent.cpp rename to test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_persistent_mem.cpp index 7c9c2c9657..94f9def529 100644 --- a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_nonpersistent.cpp +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_persistent_mem.cpp @@ -4,13 +4,13 @@ #include "test_gemm_streamk_common_includes.hpp" template -class TestCkTileStreamKBf16NonPersistent : public TestCkTileStreamK +class TestCkTileStreamKBf16PersistentMem : public TestCkTileStreamK { }; -#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistent +#define TEST_SUITE_NAME TestCkTileStreamKBf16PersistentMem -TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistent, KernelTypesStreamKBf16NonPersistent); +TYPED_TEST_SUITE(TestCkTileStreamKBf16PersistentMem, KernelTypesStreamKBf16PersistentMem); #include "test_gemm_streamk_extended_cases.inc" diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv3.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv3.cpp new file mode 100644 index 0000000000..a0a04d79e2 --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv3.cpp @@ -0,0 +1,17 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +template +class TestCkTileStreamKBf8NonPersistentCompV3 : public TestCkTileStreamK +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistentCompV3 + +TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistentCompV3, KernelTypesStreamKBf8NonPersistentCompV3); + +#include "test_gemm_streamk_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_persistent.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_nonpersistent_mem.cpp similarity index 52% rename from test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_persistent.cpp rename to test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_nonpersistent_mem.cpp index dd4bbad61b..5a6447416d 100644 --- a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_persistent.cpp +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_nonpersistent_mem.cpp @@ -4,13 +4,13 @@ #include "test_gemm_streamk_common_includes.hpp" template -class TestCkTileStreamKBf16Persistent : public TestCkTileStreamK +class TestCkTileStreamKBf8NonPersistentMem : public TestCkTileStreamK { }; -#define TEST_SUITE_NAME TestCkTileStreamKBf16Persistent +#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistentMem -TYPED_TEST_SUITE(TestCkTileStreamKBf16Persistent, KernelTypesStreamKBf16Persistent); +TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistentMem, KernelTypesStreamKBf8NonPersistentMem); #include "test_gemm_streamk_extended_cases.inc" diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_persistent.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_persistent_compv3.cpp similarity index 52% rename from test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_persistent.cpp rename to test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_persistent_compv3.cpp index 5f1bdaca86..0a6c2346d8 100644 --- a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_persistent.cpp +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_persistent_compv3.cpp @@ -4,13 +4,13 @@ #include "test_gemm_streamk_common_includes.hpp" template -class TestCkTileStreamKBf8Persistent : public TestCkTileStreamK +class TestCkTileStreamKBf8PersistentCompV3 : public TestCkTileStreamK { }; -#define TEST_SUITE_NAME TestCkTileStreamKBf8Persistent +#define TEST_SUITE_NAME TestCkTileStreamKBf8PersistentCompV3 -TYPED_TEST_SUITE(TestCkTileStreamKBf8Persistent, KernelTypesStreamKBf8Persistent); +TYPED_TEST_SUITE(TestCkTileStreamKBf8PersistentCompV3, KernelTypesStreamKBf8PersistentCompV3); #include "test_gemm_streamk_extended_cases.inc" diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_nonpersistent.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_persistent_mem.cpp similarity index 55% rename from test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_nonpersistent.cpp rename to test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_persistent_mem.cpp index 0cdb4091d1..1eef56c971 100644 --- a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_nonpersistent.cpp +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_persistent_mem.cpp @@ -4,13 +4,13 @@ #include "test_gemm_streamk_common_includes.hpp" template -class TestCkTileStreamKFp8NonPersistent : public TestCkTileStreamK +class TestCkTileStreamKBf8PersistentMem : public TestCkTileStreamK { }; -#define TEST_SUITE_NAME TestCkTileStreamKFp8NonPersistent +#define TEST_SUITE_NAME TestCkTileStreamKBf8PersistentMem -TYPED_TEST_SUITE(TestCkTileStreamKFp8NonPersistent, KernelTypesStreamKFp8NonPersistent); +TYPED_TEST_SUITE(TestCkTileStreamKBf8PersistentMem, KernelTypesStreamKBf8PersistentMem); #include "test_gemm_streamk_extended_cases.inc" diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv3.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv3.cpp new file mode 100644 index 0000000000..3381554d1e --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv3.cpp @@ -0,0 +1,18 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +template +class TestCkTileStreamKFp16NonPersistentCompV3 : public TestCkTileStreamK +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKFp16NonPersistentCompV3 + +TYPED_TEST_SUITE(TestCkTileStreamKFp16NonPersistentCompV3, + KernelTypesStreamKFp16NonPersistentCompV3); + +#include "test_gemm_streamk_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_nonpersistent_mem.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_nonpersistent_mem.cpp new file mode 100644 index 0000000000..2f7dd7be33 --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_nonpersistent_mem.cpp @@ -0,0 +1,17 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +template +class TestCkTileStreamKFp16NonPersistentMem : public TestCkTileStreamK +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKFp16NonPersistentMem + +TYPED_TEST_SUITE(TestCkTileStreamKFp16NonPersistentMem, KernelTypesStreamKFp16NonPersistentMem); + +#include "test_gemm_streamk_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_persistent.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_persistent.cpp deleted file mode 100644 index 33b474526c..0000000000 --- a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_persistent.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "test_gemm_streamk_common_includes.hpp" - -template -class TestCkTileStreamKFp16Persistent : public TestCkTileStreamK -{ -}; - -#define TEST_SUITE_NAME TestCkTileStreamKFp16Persistent - -TYPED_TEST_SUITE(TestCkTileStreamKFp16Persistent, KernelTypesStreamKFp16Persistent); - -#include "test_gemm_streamk_extended_cases.inc" - -#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_persistent_compv3.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_persistent_compv3.cpp new file mode 100644 index 0000000000..3c041a3652 --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_persistent_compv3.cpp @@ -0,0 +1,17 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +template +class TestCkTileStreamKFp16PersistentCompV3 : public TestCkTileStreamK +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKFp16PersistentCompV3 + +TYPED_TEST_SUITE(TestCkTileStreamKFp16PersistentCompV3, KernelTypesStreamKFp16PersistentCompV3); + +#include "test_gemm_streamk_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_nonpersistent.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_persistent_mem.cpp similarity index 55% rename from test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_nonpersistent.cpp rename to test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_persistent_mem.cpp index f1a3bad142..c05135943f 100644 --- a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_nonpersistent.cpp +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_persistent_mem.cpp @@ -4,13 +4,13 @@ #include "test_gemm_streamk_common_includes.hpp" template -class TestCkTileStreamKFp16NonPersistent : public TestCkTileStreamK +class TestCkTileStreamKFp16PersistentMem : public TestCkTileStreamK { }; -#define TEST_SUITE_NAME TestCkTileStreamKFp16NonPersistent +#define TEST_SUITE_NAME TestCkTileStreamKFp16PersistentMem -TYPED_TEST_SUITE(TestCkTileStreamKFp16NonPersistent, KernelTypesStreamKFp16NonPersistent); +TYPED_TEST_SUITE(TestCkTileStreamKFp16PersistentMem, KernelTypesStreamKFp16PersistentMem); #include "test_gemm_streamk_extended_cases.inc" diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv3.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv3.cpp new file mode 100644 index 0000000000..379702a10a --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv3.cpp @@ -0,0 +1,17 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +template +class TestCkTileStreamKFp8NonPersistentCompV3 : public TestCkTileStreamK +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKFp8NonPersistentCompV3 + +TYPED_TEST_SUITE(TestCkTileStreamKFp8NonPersistentCompV3, KernelTypesStreamKFp8NonPersistentCompV3); + +#include "test_gemm_streamk_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_nonpersistent_mem.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_nonpersistent_mem.cpp new file mode 100644 index 0000000000..3d545a61c6 --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_nonpersistent_mem.cpp @@ -0,0 +1,17 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +template +class TestCkTileStreamKFp8NonPersistentMem : public TestCkTileStreamK +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKFp8NonPersistentMem + +TYPED_TEST_SUITE(TestCkTileStreamKFp8NonPersistentMem, KernelTypesStreamKFp8NonPersistentMem); + +#include "test_gemm_streamk_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_persistent.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_persistent.cpp deleted file mode 100644 index d418c889cd..0000000000 --- a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_persistent.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "test_gemm_streamk_common_includes.hpp" - -template -class TestCkTileStreamKFp8Persistent : public TestCkTileStreamK -{ -}; - -#define TEST_SUITE_NAME TestCkTileStreamKFp8Persistent - -TYPED_TEST_SUITE(TestCkTileStreamKFp8Persistent, KernelTypesStreamKFp8Persistent); - -#include "test_gemm_streamk_extended_cases.inc" - -#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_persistent_compv3.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_persistent_compv3.cpp new file mode 100644 index 0000000000..dccdcaf270 --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_persistent_compv3.cpp @@ -0,0 +1,17 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +template +class TestCkTileStreamKFp8PersistentCompV3 : public TestCkTileStreamK +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKFp8PersistentCompV3 + +TYPED_TEST_SUITE(TestCkTileStreamKFp8PersistentCompV3, KernelTypesStreamKFp8PersistentCompV3); + +#include "test_gemm_streamk_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_nonpersistent.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_persistent_mem.cpp similarity index 55% rename from test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_nonpersistent.cpp rename to test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_persistent_mem.cpp index 9b3b0fccb9..88ebdf1e55 100644 --- a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_nonpersistent.cpp +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_persistent_mem.cpp @@ -4,13 +4,13 @@ #include "test_gemm_streamk_common_includes.hpp" template -class TestCkTileStreamKBf8NonPersistent : public TestCkTileStreamK +class TestCkTileStreamKFp8PersistentMem : public TestCkTileStreamK { }; -#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistent +#define TEST_SUITE_NAME TestCkTileStreamKFp8PersistentMem -TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistent, KernelTypesStreamKBf8NonPersistent); +TYPED_TEST_SUITE(TestCkTileStreamKFp8PersistentMem, KernelTypesStreamKFp8PersistentMem); #include "test_gemm_streamk_extended_cases.inc" diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp index efb7416580..ca8ffee219 100644 --- a/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp @@ -6,6 +6,7 @@ #include #include "gtest/gtest.h" #include "ck_tile/host.hpp" +#include "test_gemm_streamk_util.hpp" using F8 = ck_tile::fp8_t; using F16 = ck_tile::half_t; @@ -19,69 +20,131 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Persistent = std::true_type; using NonPersistent = std::false_type; +using Mem = ck_tile::integral_constant; +using CompV3 = ck_tile::integral_constant; + using I32 = ck_tile::number<32>; using I128 = ck_tile::number<128>; using I256 = ck_tile::number<256>; // clang-format off -using KernelTypesStreamKFp16Persistent = ::testing::Types< -// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent - std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>, - std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>, - std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>, - std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent> +// ========================== CompV3 Pipeline ========================== + +using KernelTypesStreamKFp16PersistentCompV3 = ::testing::Types< +// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent Pipeline + + std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV3>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV3>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV3>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV3> >; -using KernelTypesStreamKBf16Persistent = ::testing::Types< - std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>, - std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>, - std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>, - std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent> +using KernelTypesStreamKBf16PersistentCompV3 = ::testing::Types< + std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV3>, + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV3>, + std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV3>, + std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV3> >; -using KernelTypesStreamKBf8Persistent = ::testing::Types< - std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent>, - std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent>, - std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent>, - std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent> +using KernelTypesStreamKBf8PersistentCompV3 = ::testing::Types< + std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV3>, + std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV3>, + std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV3>, + std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV3> >; -using KernelTypesStreamKFp8Persistent = ::testing::Types< - std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent>, - std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent>, - std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent>, - std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent> +using KernelTypesStreamKFp8PersistentCompV3 = ::testing::Types< + std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV3>, + std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV3>, + std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV3>, + std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV3> >; -using KernelTypesStreamKFp16NonPersistent = ::testing::Types< -// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent - - std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>, - std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>, - std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>, - std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent> +using KernelTypesStreamKFp16NonPersistentCompV3 = ::testing::Types< + std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV3>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV3>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV3>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV3> >; -using KernelTypesStreamKBf16NonPersistent = ::testing::Types< - std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>, - std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>, - std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>, - std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent> +using KernelTypesStreamKBf16NonPersistentCompV3 = ::testing::Types< + std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV3>, + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV3>, + std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV3>, + std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV3> >; -using KernelTypesStreamKBf8NonPersistent = ::testing::Types< - std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent>, - std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent>, - std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent>, - std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent> +using KernelTypesStreamKBf8NonPersistentCompV3 = ::testing::Types< + std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV3>, + std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV3>, + std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV3>, + std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV3> >; -using KernelTypesStreamKFp8NonPersistent = ::testing::Types< - std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent>, - std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent>, - std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent>, - std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent> +using KernelTypesStreamKFp8NonPersistentCompV3 = ::testing::Types< + std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3>, + std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3>, + std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3>, + std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3> +>; + +// ============================= Mem Pipeline ============================= + +using KernelTypesStreamKFp16PersistentMem = ::testing::Types< + std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, Mem>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, Mem>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, Mem>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, Mem> +>; + +using KernelTypesStreamKBf16PersistentMem = ::testing::Types< + std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, Mem>, + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, Mem>, + std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, Mem>, + std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, Mem> +>; + +using KernelTypesStreamKBf8PersistentMem = ::testing::Types< + std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, Mem>, + std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, Mem>, + std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, Mem>, + std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, Mem> +>; + +using KernelTypesStreamKFp8PersistentMem = ::testing::Types< + std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, Mem>, + std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, Mem>, + std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, Mem>, + std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, Mem> +>; + +using KernelTypesStreamKFp16NonPersistentMem = ::testing::Types< + std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, Mem>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, Mem>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, Mem>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, Mem> +>; + +using KernelTypesStreamKBf16NonPersistentMem = ::testing::Types< + std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, Mem>, + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, Mem>, + std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, Mem>, + std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, Mem> +>; + +using KernelTypesStreamKBf8NonPersistentMem = ::testing::Types< + std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, Mem>, + std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, Mem>, + std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, Mem>, + std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, Mem> +>; + +using KernelTypesStreamKFp8NonPersistentMem = ::testing::Types< + std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, Mem>, + std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, Mem>, + std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, Mem>, + std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, Mem> >; // clang-format on diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp index 5e3b85c009..af1bab34bf 100644 --- a/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp @@ -11,6 +11,27 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" +enum struct GemmPipelineType +{ + Mem, + CompV3 +}; + +template +struct GemmPipelineTypeSelector; + +template +struct GemmPipelineTypeSelector +{ + using pipeline = ck_tile::GemmPipelineAgBgCrMem; +}; + +template +struct GemmPipelineTypeSelector +{ + using pipeline = ck_tile::GemmPipelineAgBgCrCompV3; +}; + template auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, @@ -56,6 +77,7 @@ class TestCkTileStreamK : public ::testing::Test static constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, Tuple>::value; static constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, Tuple>::value; static constexpr bool Persistent = std::tuple_element_t<10, Tuple>::value; + static constexpr auto PipelineType = std::tuple_element_t<11, Tuple>::value; template ; - // For initial testing, we will just test with one pipeline. - // More extensive testing is coming later and will test other pipelines. - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + using GemmPipeline = GemmPipelineTypeSelector::pipeline; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem( + ck_tile::index_t num_accumulations_per_tile = + invoke_streamk( args, ck_tile::stream_config{nullptr, false, 0, 0, 1}); - } - else if(reduction_strategy == ck_tile::StreamKReductionStrategy::Linear) - { - num_accumulations_per_tile = invoke_streamk( - args, ck_tile::stream_config{nullptr, false, 0, 0, 1}); - } - else - { - num_accumulations_per_tile = invoke_streamk( - args, ck_tile::stream_config{nullptr, false, 0, 0, 1}); - } c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); diff --git a/test/ck_tile/gemm_streamk_tile_engine/generate_configs.py b/test/ck_tile/gemm_streamk_tile_engine/generate_configs.py index 166ea940c3..0f2673c6dd 100644 --- a/test/ck_tile/gemm_streamk_tile_engine/generate_configs.py +++ b/test/ck_tile/gemm_streamk_tile_engine/generate_configs.py @@ -33,7 +33,7 @@ class TileConfig: class TraitConfig: """Represents the Trait Config section of a Tile Engine config""" - pipeline: List[str] = field(default_factory=lambda: ["compv3"]) + pipeline: List[str] = field(default_factory=lambda: ["compv3", "mem"]) epilogue: List[str] = field(default_factory=lambda: ["cshuffle"]) scheduler: List[str] = field(default_factory=lambda: ["intrawave"]) pad_m: List[bool] = field(default_factory=lambda: [False])