[rocm-libraries] ROCm/rocm-libraries#5242 (commit ea9a066)

[CK_TILE] Add the GEMM Memory pipeline to Stream-K tests
 (#5242)

## Motivation

We want to extend our Stream-K coverage to include other GEMM pipeline
since our current tests only test the CompV3 pipeline.

## Technical Details

All Stream-K unit tests currently only tests one pipeline: CompV3. These
changes extend the test support to also test the Memory pipeline. Future
work will add support for additional GEMM pipelines.

The major changes are as follows:
- **Remove of fp8 and bf8 extended tests for gfx90a**: gfx90a does not
have native support for fp8 and bf8 and emulate the behavior with fp32
mfma instruction sizes. We've observed extremely long compile times for
fp8 and bf8 on gfx90a (exceeding 15 minutes), hence we've opted to
disable these tests.
- **Add the memory pipeline to the Stream-K tile engine tests**: Now our
smoke tests covers compv3 and memory pipelines.
- **Add the memory pipeline to the Stream-K extended tests**: These
changes modify the test kernel types to include the appropriate
pipeline. Each pipeline is contained within a separate kernel type to
help avoid large increases in build time.

## Test Plan

- Ran existing and added tests on all architectures.

## Test Result

- All local tests pass.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Emily Martins
2026-03-11 21:09:56 +00:00
committed by assistant-librarian[bot]
parent 56e1d5da08
commit c1f2d8166d
22 changed files with 351 additions and 128 deletions

View File

@@ -23,16 +23,31 @@ if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")
#TODO: support all arches
#TODO: current c-shuffle only supports C layout as R
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
add_gtest_executable(test_ck_tile_streamk_extended
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent.cpp
test_gemm_streamk_util.cpp)
set(STREAMK_EXTENDED_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_mem.cpp
test_gemm_streamk_util.cpp)
# We only test fp8 and bf8 on gfx942 and gfx950 since these types are not natively supported on gfx90a
if(GPU_TARGETS MATCHES "gfx942|gfx950")
list(APPEND STREAMK_EXTENDED_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_mem.cpp)
endif()
add_gtest_executable(test_ck_tile_streamk_extended ${STREAMK_EXTENDED_SOURCES})
target_compile_options(test_ck_tile_streamk_extended PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
# Collect all test targets for umbrella label

View File

@@ -0,0 +1,18 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf16NonPersistentCompV3 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistentCompV3
TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistentCompV3,
KernelTypesStreamKBf16NonPersistentCompV3);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -0,0 +1,17 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf16NonPersistentMem : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistentMem
TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistentMem, KernelTypesStreamKBf16NonPersistentMem);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -0,0 +1,17 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf16PersistentCompV3 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf16PersistentCompV3
TYPED_TEST_SUITE(TestCkTileStreamKBf16PersistentCompV3, KernelTypesStreamKBf16PersistentCompV3);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -4,13 +4,13 @@
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf16NonPersistent : public TestCkTileStreamK<Tuple>
class TestCkTileStreamKBf16PersistentMem : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistent
#define TEST_SUITE_NAME TestCkTileStreamKBf16PersistentMem
TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistent, KernelTypesStreamKBf16NonPersistent);
TYPED_TEST_SUITE(TestCkTileStreamKBf16PersistentMem, KernelTypesStreamKBf16PersistentMem);
#include "test_gemm_streamk_extended_cases.inc"

View File

@@ -0,0 +1,17 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf8NonPersistentCompV3 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistentCompV3
TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistentCompV3, KernelTypesStreamKBf8NonPersistentCompV3);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -4,13 +4,13 @@
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf16Persistent : public TestCkTileStreamK<Tuple>
class TestCkTileStreamKBf8NonPersistentMem : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf16Persistent
#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistentMem
TYPED_TEST_SUITE(TestCkTileStreamKBf16Persistent, KernelTypesStreamKBf16Persistent);
TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistentMem, KernelTypesStreamKBf8NonPersistentMem);
#include "test_gemm_streamk_extended_cases.inc"

View File

@@ -4,13 +4,13 @@
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf8Persistent : public TestCkTileStreamK<Tuple>
class TestCkTileStreamKBf8PersistentCompV3 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf8Persistent
#define TEST_SUITE_NAME TestCkTileStreamKBf8PersistentCompV3
TYPED_TEST_SUITE(TestCkTileStreamKBf8Persistent, KernelTypesStreamKBf8Persistent);
TYPED_TEST_SUITE(TestCkTileStreamKBf8PersistentCompV3, KernelTypesStreamKBf8PersistentCompV3);
#include "test_gemm_streamk_extended_cases.inc"

View File

@@ -4,13 +4,13 @@
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp8NonPersistent : public TestCkTileStreamK<Tuple>
class TestCkTileStreamKBf8PersistentMem : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp8NonPersistent
#define TEST_SUITE_NAME TestCkTileStreamKBf8PersistentMem
TYPED_TEST_SUITE(TestCkTileStreamKFp8NonPersistent, KernelTypesStreamKFp8NonPersistent);
TYPED_TEST_SUITE(TestCkTileStreamKBf8PersistentMem, KernelTypesStreamKBf8PersistentMem);
#include "test_gemm_streamk_extended_cases.inc"

View File

@@ -0,0 +1,18 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp16NonPersistentCompV3 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp16NonPersistentCompV3
TYPED_TEST_SUITE(TestCkTileStreamKFp16NonPersistentCompV3,
KernelTypesStreamKFp16NonPersistentCompV3);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -0,0 +1,17 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp16NonPersistentMem : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp16NonPersistentMem
TYPED_TEST_SUITE(TestCkTileStreamKFp16NonPersistentMem, KernelTypesStreamKFp16NonPersistentMem);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp16Persistent : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp16Persistent
TYPED_TEST_SUITE(TestCkTileStreamKFp16Persistent, KernelTypesStreamKFp16Persistent);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -0,0 +1,17 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp16PersistentCompV3 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp16PersistentCompV3
TYPED_TEST_SUITE(TestCkTileStreamKFp16PersistentCompV3, KernelTypesStreamKFp16PersistentCompV3);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -4,13 +4,13 @@
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp16NonPersistent : public TestCkTileStreamK<Tuple>
class TestCkTileStreamKFp16PersistentMem : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp16NonPersistent
#define TEST_SUITE_NAME TestCkTileStreamKFp16PersistentMem
TYPED_TEST_SUITE(TestCkTileStreamKFp16NonPersistent, KernelTypesStreamKFp16NonPersistent);
TYPED_TEST_SUITE(TestCkTileStreamKFp16PersistentMem, KernelTypesStreamKFp16PersistentMem);
#include "test_gemm_streamk_extended_cases.inc"

View File

@@ -0,0 +1,17 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp8NonPersistentCompV3 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp8NonPersistentCompV3
TYPED_TEST_SUITE(TestCkTileStreamKFp8NonPersistentCompV3, KernelTypesStreamKFp8NonPersistentCompV3);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -0,0 +1,17 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp8NonPersistentMem : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp8NonPersistentMem
TYPED_TEST_SUITE(TestCkTileStreamKFp8NonPersistentMem, KernelTypesStreamKFp8NonPersistentMem);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp8Persistent : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp8Persistent
TYPED_TEST_SUITE(TestCkTileStreamKFp8Persistent, KernelTypesStreamKFp8Persistent);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -0,0 +1,17 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKFp8PersistentCompV3 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKFp8PersistentCompV3
TYPED_TEST_SUITE(TestCkTileStreamKFp8PersistentCompV3, KernelTypesStreamKFp8PersistentCompV3);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -4,13 +4,13 @@
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf8NonPersistent : public TestCkTileStreamK<Tuple>
class TestCkTileStreamKFp8PersistentMem : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistent
#define TEST_SUITE_NAME TestCkTileStreamKFp8PersistentMem
TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistent, KernelTypesStreamKBf8NonPersistent);
TYPED_TEST_SUITE(TestCkTileStreamKFp8PersistentMem, KernelTypesStreamKFp8PersistentMem);
#include "test_gemm_streamk_extended_cases.inc"

View File

@@ -6,6 +6,7 @@
#include <type_traits>
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
#include "test_gemm_streamk_util.hpp"
using F8 = ck_tile::fp8_t;
using F16 = ck_tile::half_t;
@@ -19,69 +20,131 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Persistent = std::true_type;
using NonPersistent = std::false_type;
using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
using CompV3 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV3>;
using I32 = ck_tile::number<32>;
using I128 = ck_tile::number<128>;
using I256 = ck_tile::number<256>;
// clang-format off
using KernelTypesStreamKFp16Persistent = ::testing::Types<
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>
// ========================== CompV3 Pipeline ==========================
using KernelTypesStreamKFp16PersistentCompV3 = ::testing::Types<
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent Pipeline
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV3>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV3>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV3>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV3>
>;
using KernelTypesStreamKBf16Persistent = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>
using KernelTypesStreamKBf16PersistentCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV3>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV3>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV3>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV3>
>;
using KernelTypesStreamKBf8Persistent = ::testing::Types<
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent>,
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent>,
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent>,
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent>
using KernelTypesStreamKBf8PersistentCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV3>,
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV3>,
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV3>,
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV3>
>;
using KernelTypesStreamKFp8Persistent = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent>
using KernelTypesStreamKFp8PersistentCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV3>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV3>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV3>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV3>
>;
using KernelTypesStreamKFp16NonPersistent = ::testing::Types<
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>
using KernelTypesStreamKFp16NonPersistentCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV3>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV3>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV3>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV3>
>;
using KernelTypesStreamKBf16NonPersistent = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>
using KernelTypesStreamKBf16NonPersistentCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV3>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV3>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV3>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV3>
>;
using KernelTypesStreamKBf8NonPersistent = ::testing::Types<
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent>,
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent>,
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent>,
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent>
using KernelTypesStreamKBf8NonPersistentCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV3>,
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV3>,
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV3>,
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV3>
>;
using KernelTypesStreamKFp8NonPersistent = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent>
using KernelTypesStreamKFp8NonPersistentCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3>
>;
// ============================= Mem Pipeline =============================
using KernelTypesStreamKFp16PersistentMem = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, Mem>
>;
using KernelTypesStreamKBf16PersistentMem = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, Mem>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, Mem>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, Mem>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, Mem>
>;
using KernelTypesStreamKBf8PersistentMem = ::testing::Types<
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, Mem>,
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, Mem>,
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, Mem>,
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, Mem>
>;
using KernelTypesStreamKFp8PersistentMem = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, Mem>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, Mem>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, Mem>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, Mem>
>;
using KernelTypesStreamKFp16NonPersistentMem = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, Mem>
>;
using KernelTypesStreamKBf16NonPersistentMem = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, Mem>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, Mem>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, Mem>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, Mem>
>;
using KernelTypesStreamKBf8NonPersistentMem = ::testing::Types<
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, Mem>,
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, Mem>,
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, Mem>,
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, Mem>
>;
using KernelTypesStreamKFp8NonPersistentMem = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, Mem>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, Mem>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, Mem>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, Mem>
>;
// clang-format on

View File

@@ -11,6 +11,27 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
enum struct GemmPipelineType
{
Mem,
CompV3
};
template <GemmPipelineType PT, typename Problem>
struct GemmPipelineTypeSelector;
template <typename Problem>
struct GemmPipelineTypeSelector<GemmPipelineType::Mem, Problem>
{
using pipeline = ck_tile::GemmPipelineAgBgCrMem<Problem>;
};
template <typename Problem>
struct GemmPipelineTypeSelector<GemmPipelineType::CompV3, Problem>
{
using pipeline = ck_tile::GemmPipelineAgBgCrCompV3<Problem>;
};
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
@@ -56,6 +77,7 @@ class TestCkTileStreamK : public ::testing::Test
static constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, Tuple>::value;
static constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, Tuple>::value;
static constexpr bool Persistent = std::tuple_element_t<10, Tuple>::value;
static constexpr auto PipelineType = std::tuple_element_t<11, Tuple>::value;
template <ck_tile::StreamKReductionStrategy ReductionStrategy,
bool PadM = true,
@@ -117,9 +139,8 @@ class TestCkTileStreamK : public ::testing::Test
GemmShape,
GemmUniversalTraits,
scheduler>;
// For initial testing, we will just test with one pipeline.
// More extensive testing is coming later and will test other pipelines.
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
using GemmPipeline = GemmPipelineTypeSelector<PipelineType, UniversalGemmProblem>::pipeline;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
@@ -149,7 +170,9 @@ class TestCkTileStreamK : public ::testing::Test
if(!Kernel::IsSupportedArgument(kargs))
{
EXPECT_TRUE(false);
// Since IsSupportedArgument only logs with an enviroment variable set, it's best to
// throw when we hit an unsupported case.
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
dim3 grid_dims = Kernel::GridSize(kargs.tile_partitioner);
@@ -165,8 +188,6 @@ class TestCkTileStreamK : public ::testing::Test
void Run(ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::StreamKReductionStrategy reduction_strategy =
ck_tile::StreamKReductionStrategy::Atomic,
ck_tile::index_t stride_A = 0,
ck_tile::index_t stride_B = 0,
ck_tile::index_t stride_C = 0)
@@ -240,23 +261,9 @@ class TestCkTileStreamK : public ::testing::Test
stride_B,
stride_C};
ck_tile::index_t num_accumulations_per_tile;
if(reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic)
{
num_accumulations_per_tile = invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
ck_tile::index_t num_accumulations_per_tile =
invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
}
else if(reduction_strategy == ck_tile::StreamKReductionStrategy::Linear)
{
num_accumulations_per_tile = invoke_streamk<ck_tile::StreamKReductionStrategy::Linear>(
args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
}
else
{
num_accumulations_per_tile = invoke_streamk<ck_tile::StreamKReductionStrategy::Tree>(
args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
}
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());

View File

@@ -33,7 +33,7 @@ class TileConfig:
class TraitConfig:
"""Represents the Trait Config section of a Tile Engine config"""
pipeline: List[str] = field(default_factory=lambda: ["compv3"])
pipeline: List[str] = field(default_factory=lambda: ["compv3", "mem"])
epilogue: List[str] = field(default_factory=lambda: ["cshuffle"])
scheduler: List[str] = field(default_factory=lambda: ["intrawave"])
pad_m: List[bool] = field(default_factory=lambda: [False])