mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
[rocm-libraries] ROCm/rocm-libraries#5445 (commit 2cdbf8b)
[CK_TILE] Support for CompV4 pipeline in Stream-K GEMM (#5445) ## Motivation This PR is extending the pipeline support for Stream-K GEMM by adding the CompV4 pipeline. Additional pipelines will be added in subsequent PRs. ## Technical Details - Enable the CompV4 pipeline by adding an option to set DoubleSMemBuffer to true if the CompV4 pipeline has been selected as it requires double buffered shared memory - Addition of CompV4 pipeline into the extended tests: kernel instances mirror the existing CompV3/Mem configurations (same layout permutations, data types, and tile sizes) with the pipeline type set to CompV4. - Addition of CompV4 pipeline into smoke tests (generated using Tile Engine) ## Test Plan These were tested using the existing smoke and extended tests. ## Test Result All tests passed ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
47a04fda08
commit
36f2ec23f5
@@ -25,12 +25,16 @@ if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")
|
||||
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
|
||||
set(STREAMK_EXTENDED_SOURCES
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_mem.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_mem.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_mem.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_mem.cpp
|
||||
test_gemm_streamk_util.cpp)
|
||||
|
||||
@@ -38,12 +42,16 @@ if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")
|
||||
if(GPU_TARGETS MATCHES "gfx942|gfx950")
|
||||
list(APPEND STREAMK_EXTENDED_SOURCES
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_mem.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_mem.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_mem.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_mem.cpp)
|
||||
endif()
|
||||
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf16NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistentCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistentCompV4,
|
||||
KernelTypesStreamKBf16NonPersistentCompV4);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,17 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf16PersistentCompV4 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf16PersistentCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf16PersistentCompV4, KernelTypesStreamKBf16PersistentCompV4);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,17 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf8NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistentCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistentCompV4, KernelTypesStreamKBf8NonPersistentCompV4);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,17 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf8PersistentCompV4 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf8PersistentCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf8PersistentCompV4, KernelTypesStreamKBf8PersistentCompV4);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,18 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp16NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp16NonPersistentCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp16NonPersistentCompV4,
|
||||
KernelTypesStreamKFp16NonPersistentCompV4);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,17 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp16PersistentCompV4 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp16PersistentCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp16PersistentCompV4, KernelTypesStreamKFp16PersistentCompV4);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,17 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp8NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp8NonPersistentCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp8NonPersistentCompV4, KernelTypesStreamKFp8NonPersistentCompV4);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,17 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp8PersistentCompV4 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp8PersistentCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp8PersistentCompV4, KernelTypesStreamKFp8PersistentCompV4);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -17,11 +17,12 @@ using F32 = float;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using Persistent = std::true_type;
|
||||
using NonPersistent = std::false_type;
|
||||
|
||||
using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
|
||||
using CompV3 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV3>;
|
||||
using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV4>;
|
||||
|
||||
using Persistent = std::true_type;
|
||||
using NonPersistent = std::false_type;
|
||||
|
||||
using I32 = ck_tile::number<32>;
|
||||
using I128 = ck_tile::number<128>;
|
||||
@@ -89,6 +90,66 @@ using KernelTypesStreamKFp8NonPersistentCompV3 = ::testing::Types<
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3>
|
||||
>;
|
||||
|
||||
// ========================== CompV4 Pipeline ==========================
|
||||
|
||||
using KernelTypesStreamKFp16PersistentCompV4 = ::testing::Types<
|
||||
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent Pipeline
|
||||
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf16PersistentCompV4 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>,
|
||||
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf8PersistentCompV4 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>,
|
||||
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>,
|
||||
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>,
|
||||
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKFp8PersistentCompV4 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKFp16NonPersistentCompV4 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf16NonPersistentCompV4 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf8NonPersistentCompV4 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKFp8NonPersistentCompV4 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>
|
||||
>;
|
||||
|
||||
// ============================= Mem Pipeline =============================
|
||||
|
||||
using KernelTypesStreamKFp16PersistentMem = ::testing::Types<
|
||||
|
||||
@@ -14,7 +14,8 @@
|
||||
enum struct GemmPipelineType
|
||||
{
|
||||
Mem,
|
||||
CompV3
|
||||
CompV3,
|
||||
CompV4
|
||||
};
|
||||
|
||||
template <GemmPipelineType PT, typename Problem>
|
||||
@@ -32,6 +33,12 @@ struct GemmPipelineTypeSelector<GemmPipelineType::CompV3, Problem>
|
||||
using pipeline = ck_tile::GemmPipelineAgBgCrCompV3<Problem>;
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
struct GemmPipelineTypeSelector<GemmPipelineType::CompV4, Problem>
|
||||
{
|
||||
using pipeline = ck_tile::GemmPipelineAgBgCrCompV4<Problem>;
|
||||
};
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
@@ -101,8 +108,8 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
constexpr bool kPadK = PadK;
|
||||
constexpr bool preshuffle = Preshuffle;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr bool DoubleSmemBuffer = (PipelineType == GemmPipelineType::CompV4) ? true : false;
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr bool StructuredSparsity = false;
|
||||
constexpr bool NumWaveGroup = 1;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user