From cfd274a1844e60cabd5c1f11ee1c0dfd1776abc1 Mon Sep 17 00:00:00 2001 From: arai713 <67439843+arai713@users.noreply.github.com> Date: Fri, 27 Mar 2026 01:12:09 -0700 Subject: [PATCH] [CK_TILE] Support for CompV4 pipeline in Stream-K GEMM (#5445) ## Motivation This PR is extending the pipeline support for Stream-K GEMM by adding the CompV4 pipeline. Additional pipelines will be added in subsequent PRs. ## Technical Details - Enable the CompV4 pipeline by adding an option to set DoubleSMemBuffer to true if the CompV4 pipeline has been selected as it requires double buffered shared memory - Addition of CompV4 pipeline into the extended tests: kernel instances mirror the existing CompV3/Mem configurations (same layout permutations, data types, and tile sizes) with the pipeline type set to CompV4. - Addition of CompV4 pipeline into smoke tests (generated using Tile Engine) ## Test Plan These were tested using the existing smoke and extended tests. ## Test Result All tests passed ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --------- Co-authored-by: Christopher Millette <63608002+cgmillette@users.noreply.github.com> --- test/ck_tile/gemm_streamk/CMakeLists.txt | 8 +++ ...gemm_streamk_bf16_nonpersistent_compv4.cpp | 18 +++++ ...st_gemm_streamk_bf16_persistent_compv4.cpp | 17 +++++ ..._gemm_streamk_bf8_nonpersistent_compv4.cpp | 17 +++++ ...est_gemm_streamk_bf8_persistent_compv4.cpp | 17 +++++ ...gemm_streamk_fp16_nonpersistent_compv4.cpp | 18 +++++ ...st_gemm_streamk_fp16_persistent_compv4.cpp | 17 +++++ ..._gemm_streamk_fp8_nonpersistent_compv4.cpp | 17 +++++ ...est_gemm_streamk_fp8_persistent_compv4.cpp | 17 +++++ .../gemm_streamk/test_gemm_streamk_types.hpp | 67 ++++++++++++++++++- .../gemm_streamk/test_gemm_streamk_util.hpp | 13 +++- 11 files changed, 220 insertions(+), 6 deletions(-) create mode 100644 test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv4.cpp create mode 100644 test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_persistent_compv4.cpp create mode 100644 test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv4.cpp create mode 100644 test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_persistent_compv4.cpp create mode 100644 test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv4.cpp create mode 100644 test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_persistent_compv4.cpp create mode 100644 test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv4.cpp create mode 100644 test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_persistent_compv4.cpp diff --git a/test/ck_tile/gemm_streamk/CMakeLists.txt b/test/ck_tile/gemm_streamk/CMakeLists.txt index f6eb33bf76..636900db8e 100644 --- a/test/ck_tile/gemm_streamk/CMakeLists.txt +++ b/test/ck_tile/gemm_streamk/CMakeLists.txt @@ -25,12 +25,16 @@ if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950") add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp) set(STREAMK_EXTENDED_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_compv3.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_compv4.cpp ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_mem.cpp ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_compv3.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_compv4.cpp ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_mem.cpp ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv3.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv4.cpp ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_mem.cpp ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv3.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv4.cpp ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_mem.cpp test_gemm_streamk_util.cpp) @@ -38,12 +42,16 @@ if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950") if(GPU_TARGETS MATCHES "gfx942|gfx950") list(APPEND STREAMK_EXTENDED_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_compv3.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_compv4.cpp ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_mem.cpp ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_compv3.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_compv4.cpp ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_mem.cpp ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv3.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv4.cpp ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_mem.cpp ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv3.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv4.cpp ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_mem.cpp) endif() diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv4.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv4.cpp new file mode 100644 index 0000000000..e0e1b30065 --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv4.cpp @@ -0,0 +1,18 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +template +class TestCkTileStreamKBf16NonPersistentCompV4 : public TestCkTileStreamK +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistentCompV4 + +TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistentCompV4, + KernelTypesStreamKBf16NonPersistentCompV4); + +#include "test_gemm_streamk_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_persistent_compv4.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_persistent_compv4.cpp new file mode 100644 index 0000000000..2c7a40cea9 --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf16_persistent_compv4.cpp @@ -0,0 +1,17 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +template +class TestCkTileStreamKBf16PersistentCompV4 : public TestCkTileStreamK +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKBf16PersistentCompV4 + +TYPED_TEST_SUITE(TestCkTileStreamKBf16PersistentCompV4, KernelTypesStreamKBf16PersistentCompV4); + +#include "test_gemm_streamk_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv4.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv4.cpp new file mode 100644 index 0000000000..5fada00248 --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv4.cpp @@ -0,0 +1,17 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +template +class TestCkTileStreamKBf8NonPersistentCompV4 : public TestCkTileStreamK +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistentCompV4 + +TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistentCompV4, KernelTypesStreamKBf8NonPersistentCompV4); + +#include "test_gemm_streamk_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_persistent_compv4.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_persistent_compv4.cpp new file mode 100644 index 0000000000..cd48886f84 --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_bf8_persistent_compv4.cpp @@ -0,0 +1,17 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +template +class TestCkTileStreamKBf8PersistentCompV4 : public TestCkTileStreamK +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKBf8PersistentCompV4 + +TYPED_TEST_SUITE(TestCkTileStreamKBf8PersistentCompV4, KernelTypesStreamKBf8PersistentCompV4); + +#include "test_gemm_streamk_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv4.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv4.cpp new file mode 100644 index 0000000000..e6b632b0b1 --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv4.cpp @@ -0,0 +1,18 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +template +class TestCkTileStreamKFp16NonPersistentCompV4 : public TestCkTileStreamK +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKFp16NonPersistentCompV4 + +TYPED_TEST_SUITE(TestCkTileStreamKFp16NonPersistentCompV4, + KernelTypesStreamKFp16NonPersistentCompV4); + +#include "test_gemm_streamk_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_persistent_compv4.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_persistent_compv4.cpp new file mode 100644 index 0000000000..8117a7ce96 --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp16_persistent_compv4.cpp @@ -0,0 +1,17 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +template +class TestCkTileStreamKFp16PersistentCompV4 : public TestCkTileStreamK +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKFp16PersistentCompV4 + +TYPED_TEST_SUITE(TestCkTileStreamKFp16PersistentCompV4, KernelTypesStreamKFp16PersistentCompV4); + +#include "test_gemm_streamk_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv4.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv4.cpp new file mode 100644 index 0000000000..bf4dfc30f8 --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv4.cpp @@ -0,0 +1,17 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +template +class TestCkTileStreamKFp8NonPersistentCompV4 : public TestCkTileStreamK +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKFp8NonPersistentCompV4 + +TYPED_TEST_SUITE(TestCkTileStreamKFp8NonPersistentCompV4, KernelTypesStreamKFp8NonPersistentCompV4); + +#include "test_gemm_streamk_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_persistent_compv4.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_persistent_compv4.cpp new file mode 100644 index 0000000000..8cbab5c8f8 --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_fp8_persistent_compv4.cpp @@ -0,0 +1,17 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +template +class TestCkTileStreamKFp8PersistentCompV4 : public TestCkTileStreamK +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKFp8PersistentCompV4 + +TYPED_TEST_SUITE(TestCkTileStreamKFp8PersistentCompV4, KernelTypesStreamKFp8PersistentCompV4); + +#include "test_gemm_streamk_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp index ca8ffee219..bfe236b37f 100644 --- a/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp @@ -17,11 +17,12 @@ using F32 = float; using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; -using Persistent = std::true_type; -using NonPersistent = std::false_type; - using Mem = ck_tile::integral_constant; using CompV3 = ck_tile::integral_constant; +using CompV4 = ck_tile::integral_constant; + +using Persistent = std::true_type; +using NonPersistent = std::false_type; using I32 = ck_tile::number<32>; using I128 = ck_tile::number<128>; @@ -89,6 +90,66 @@ using KernelTypesStreamKFp8NonPersistentCompV3 = ::testing::Types< std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3> >; +// ========================== CompV4 Pipeline ========================== + +using KernelTypesStreamKFp16PersistentCompV4 = ::testing::Types< +// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent Pipeline + + std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4> +>; + +using KernelTypesStreamKBf16PersistentCompV4 = ::testing::Types< + std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>, + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>, + std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>, + std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4> +>; + +using KernelTypesStreamKBf8PersistentCompV4 = ::testing::Types< + std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>, + std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>, + std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>, + std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4> +>; + +using KernelTypesStreamKFp8PersistentCompV4 = ::testing::Types< + std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>, + std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>, + std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>, + std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4> +>; + +using KernelTypesStreamKFp16NonPersistentCompV4 = ::testing::Types< + std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4> +>; + +using KernelTypesStreamKBf16NonPersistentCompV4 = ::testing::Types< + std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>, + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>, + std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>, + std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4> +>; + +using KernelTypesStreamKBf8NonPersistentCompV4 = ::testing::Types< + std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>, + std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>, + std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>, + std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4> +>; + +using KernelTypesStreamKFp8NonPersistentCompV4 = ::testing::Types< + std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>, + std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>, + std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>, + std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4> +>; + // ============================= Mem Pipeline ============================= using KernelTypesStreamKFp16PersistentMem = ::testing::Types< diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp index af1bab34bf..8ae1f27e5c 100644 --- a/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp @@ -14,7 +14,8 @@ enum struct GemmPipelineType { Mem, - CompV3 + CompV3, + CompV4 }; template @@ -32,6 +33,12 @@ struct GemmPipelineTypeSelector using pipeline = ck_tile::GemmPipelineAgBgCrCompV3; }; +template +struct GemmPipelineTypeSelector +{ + using pipeline = ck_tile::GemmPipelineAgBgCrCompV4; +}; + template auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, @@ -101,8 +108,8 @@ class TestCkTileStreamK : public ::testing::Test constexpr bool kPadK = PadK; constexpr bool preshuffle = Preshuffle; - constexpr bool DoubleSmemBuffer = false; - constexpr int kBlockPerCu = 1; + constexpr bool DoubleSmemBuffer = (PipelineType == GemmPipelineType::CompV4) ? true : false; + constexpr int kBlockPerCu = 1; constexpr bool StructuredSparsity = false; constexpr bool NumWaveGroup = 1;