mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 06:01:23 +00:00
This reverts commit ffb52783d0.
This commit is contained in:
@@ -23,8 +23,3 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
|
||||
else()
|
||||
message("Skipping ck_tile_gemm tests for current target")
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95" OR GPU_TARGETS MATCHES "gfx90a")
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_persistent test_gemm_pipeline_persistent.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_persistent PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
@@ -22,9 +21,6 @@ using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType:
|
||||
using CompV3 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV3>;
|
||||
using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV4>;
|
||||
|
||||
using Persistent = std::true_type;
|
||||
using NonPersistent = std::false_type;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypesMem = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
|
||||
@@ -63,9 +59,4 @@ using KernelTypesCompV4 = ::testing::Types<
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>
|
||||
>;
|
||||
|
||||
using KernelTypesPersistent = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3, Persistent>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3, NonPersistent>
|
||||
>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
#include "test_gemm_pipeline_kernel_types.hpp"
|
||||
#include "test_gemm_pipeline_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileGemmPipelinePersistent : public TestCkTileGemmPipeline<T>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileGemmPipelinePersistent
|
||||
|
||||
TYPED_TEST_SUITE(TEST_SUITE_NAME, KernelTypesPersistent);
|
||||
|
||||
#include "test_gemm_pipeline_ut_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -76,8 +76,6 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
using CDataType = std::tuple_element_t<6, Tuple>;
|
||||
static constexpr auto Scheduler = std::tuple_element_t<7, Tuple>::value;
|
||||
static constexpr auto PipelineType = std::tuple_element_t<8, Tuple>::value;
|
||||
static constexpr bool Persistent =
|
||||
ck_tile::tuple_element_or_default_t<Tuple, 9, std::false_type>::value;
|
||||
// TODO: expose tile size through test t-param ?
|
||||
|
||||
template <bool PadM, bool PadN, bool PadK>
|
||||
@@ -119,17 +117,14 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
static constexpr bool StructuredSparsity = false;
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
TransposeC,
|
||||
StructuredSparsity,
|
||||
Persistent>;
|
||||
TransposeC>;
|
||||
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
@@ -182,15 +177,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
dim3 grids;
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
grids = Kernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
else
|
||||
{
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
}
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
@@ -359,6 +346,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user