mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Revert "CkProfiler StreamK GemmUniversal Fix and Split Gemm_universal Test (…" (#2054)
This reverts commit 7142d8003c.
This commit is contained in:
15
test/gemm_universal/CMakeLists.txt
Executable file → Normal file
15
test/gemm_universal/CMakeLists.txt
Executable file → Normal file
@@ -1,15 +1,4 @@
|
||||
add_gtest_executable(test_gemm_universal_fp16 test_gemm_universal_xdl_fp16.cpp)
|
||||
add_gtest_executable(test_gemm_universal test_gemm_universal_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_universal_fp16 PRIVATE utility device_gemm_universal_instance)
|
||||
target_link_libraries(test_gemm_universal PRIVATE utility device_gemm_universal_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_universal_fp8 test_gemm_universal_xdl_fp8.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_universal_fp8 PRIVATE utility device_gemm_universal_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_universal_bf16 test_gemm_universal_xdl_bf16.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_universal_bf16 PRIVATE utility device_gemm_universal_instance)
|
||||
endif()
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_BF16_MK_KN, SmallM)
|
||||
TYPED_TEST(TestGemmUniversal_MK_KN, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
@@ -14,7 +14,7 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_KN, SmallM)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_BF16_MK_NK, SmallM)
|
||||
TYPED_TEST(TestGemmUniversal_MK_NK, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
@@ -28,7 +28,7 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_NK, SmallM)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_BF16_KM_KN, SmallM)
|
||||
TYPED_TEST(TestGemmUniversal_KM_KN, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
@@ -44,7 +44,7 @@ TYPED_TEST(TestGemmUniversal_BF16_KM_KN, SmallM)
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_BF16_KM_NK, SmallM)
|
||||
TYPED_TEST(TestGemmUniversal_KM_NK, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
@@ -60,7 +60,7 @@ TYPED_TEST(TestGemmUniversal_BF16_KM_NK, SmallM)
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_BF16_MK_KN, MidLargeM)
|
||||
TYPED_TEST(TestGemmUniversal_MK_KN, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
@@ -74,7 +74,7 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_KN, MidLargeM)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_BF16_MK_NK, MidLargeM)
|
||||
TYPED_TEST(TestGemmUniversal_MK_NK, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
@@ -88,7 +88,7 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_NK, MidLargeM)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_BF16_KM_KN, MidLargeM)
|
||||
TYPED_TEST(TestGemmUniversal_KM_KN, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
@@ -104,7 +104,7 @@ TYPED_TEST(TestGemmUniversal_BF16_KM_KN, MidLargeM)
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_BF16_KM_NK, MidLargeM)
|
||||
TYPED_TEST(TestGemmUniversal_KM_NK, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
@@ -120,7 +120,7 @@ TYPED_TEST(TestGemmUniversal_BF16_KM_NK, MidLargeM)
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_BF16_MK_KN, PaddK)
|
||||
TYPED_TEST(TestGemmUniversal_MK_KN, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
@@ -134,7 +134,7 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_KN, PaddK)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_BF16_MK_NK, PaddK)
|
||||
TYPED_TEST(TestGemmUniversal_MK_NK, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
@@ -148,7 +148,7 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_NK, PaddK)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_BF16_KM_KN, PaddK)
|
||||
TYPED_TEST(TestGemmUniversal_KM_KN, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
@@ -164,7 +164,7 @@ TYPED_TEST(TestGemmUniversal_BF16_KM_KN, PaddK)
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_BF16_KM_NK, PaddK)
|
||||
TYPED_TEST(TestGemmUniversal_KM_NK, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
@@ -180,7 +180,7 @@ TYPED_TEST(TestGemmUniversal_BF16_KM_NK, PaddK)
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_BF16_MK_KN, Regular)
|
||||
TYPED_TEST(TestGemmUniversal_MK_KN, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
@@ -194,7 +194,7 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_KN, Regular)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_BF16_MK_NK, Regular)
|
||||
TYPED_TEST(TestGemmUniversal_MK_NK, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
@@ -207,3 +207,35 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_NK, Regular)
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_KM_KN, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
{
|
||||
int StrideA = M;
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_KM_NK, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
{
|
||||
int StrideA = M;
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
}
|
||||
@@ -1,99 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP16_MK_KN, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP16_MK_NK, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP16_MK_NK, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP16_MK_KN, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 437;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP16_MK_NK, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 437;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP16_MK_KN, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP16_MK_NK, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
@@ -1,113 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP8_MK_KN, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP8_MK_NK, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP8_MK_KN, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP8_MK_NK, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP8_MK_KN, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 437;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP8_MK_NK, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 437;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP8_MK_KN, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP8_MK_NK, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
@@ -7,6 +7,8 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "test_gemm_universal_util.hpp"
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
@@ -27,25 +29,25 @@ struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
|
||||
} // namespace
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_BF16_MK_KN
|
||||
class TestGemmUniversal_MK_KN
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_BF16_MK_NK
|
||||
class TestGemmUniversal_MK_NK
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_BF16_KM_KN
|
||||
class TestGemmUniversal_KM_KN
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Col, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_BF16_KM_NK
|
||||
class TestGemmUniversal_KM_NK
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Col, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
@@ -53,12 +55,22 @@ class TestGemmUniversal_BF16_KM_NK
|
||||
// clang-format off
|
||||
using KernelTypes_MK_KN = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
|
||||
std::tuple< F16, F16, F16, F16>,
|
||||
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
|
||||
std::tuple< F16, F8, F16, F16>,
|
||||
std::tuple< F8, F16, F16, F16>,
|
||||
std::tuple< F8, F8, F8, BF16>,
|
||||
#endif
|
||||
std::tuple< BF16, BF16, BF16, BF16>
|
||||
>;
|
||||
using KernelTypes_MK_NK = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
|
||||
std::tuple< F16, F16, F16, F16>,
|
||||
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
|
||||
std::tuple< F16, F8, F16, F16>,
|
||||
std::tuple< F8, F16, F16, F16>,
|
||||
std::tuple< F8, F8, F8, BF16>,
|
||||
#endif
|
||||
std::tuple< BF16, BF16, BF16, BF16>
|
||||
>;
|
||||
|
||||
@@ -74,9 +86,9 @@ using KernelTypes_KM_KN = ::testing::Types<
|
||||
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_KN, KernelTypes_MK_KN);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_NK, KernelTypes_MK_NK);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_KN, KernelTypes_KM_KN);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_NK, KernelTypes_KM_NK);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_MK_KN, KernelTypes_MK_KN);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_MK_NK, KernelTypes_MK_NK);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_KM_KN, KernelTypes_KM_KN);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_KM_NK, KernelTypes_KM_NK);
|
||||
|
||||
#include "test_gemm_universal_ut_cases_bf16.inc"
|
||||
#include "test_gemm_universal_ut_cases.inc"
|
||||
@@ -1,82 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "test_gemm_universal_util.hpp"
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename X, typename Y>
|
||||
struct tuple_concat;
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
|
||||
{
|
||||
using type = std::tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_FP16_MK_KN
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_FP16_MK_NK
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_FP16_KM_KN
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Col, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_FP16_KM_NK
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Col, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_MK_KN = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
|
||||
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
|
||||
std::tuple< F16, F8, F16, F16>,
|
||||
std::tuple< F8, F16, F16, F16>,
|
||||
|
||||
#endif
|
||||
std::tuple< F16, F16, F16, F16>
|
||||
>;
|
||||
using KernelTypes_MK_NK = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
|
||||
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
|
||||
std::tuple< F16, F8, F16, F16>,
|
||||
std::tuple< F8, F16, F16, F16>,
|
||||
|
||||
#endif
|
||||
std::tuple< F16, F16, F16, F16>
|
||||
>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_KN, KernelTypes_MK_KN);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_NK, KernelTypes_MK_NK);
|
||||
|
||||
#include "test_gemm_universal_ut_cases_fp16.inc"
|
||||
@@ -1,71 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "test_gemm_universal_util.hpp"
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename X, typename Y>
|
||||
struct tuple_concat;
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
|
||||
{
|
||||
using type = std::tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_FP8_MK_KN
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_FP8_MK_NK
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_MK_KN = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
|
||||
std::tuple< F16, F8, F16, F16>,
|
||||
std::tuple< F8, F16, F16, F16>,
|
||||
std::tuple< F8, F8, F8, BF16>,
|
||||
#endif
|
||||
// Fallback test type when FP8 is not enabled
|
||||
std::tuple< F16, F16, F16, F16>
|
||||
>;
|
||||
using KernelTypes_MK_NK = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
|
||||
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
|
||||
std::tuple< F16, F8, F16, F16>,
|
||||
std::tuple< F8, F16, F16, F16>,
|
||||
std::tuple< F8, F8, F8, BF16>,
|
||||
#endif
|
||||
// Fallback test type when FP8 is not enabled
|
||||
std::tuple< F16, F16, F16, F16>
|
||||
>;
|
||||
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_KN, KernelTypes_MK_KN);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_NK, KernelTypes_MK_NK);
|
||||
|
||||
|
||||
#include "test_gemm_universal_ut_cases_fp8.inc"
|
||||
Reference in New Issue
Block a user