mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
Revert "CkProfiler StreamK GemmUniversal Fix and Split Gemm_universal Test (…" (#2054)
This reverts commit 7142d8003c.
This commit is contained in:
2
profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp
Executable file → Normal file
2
profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp
Executable file → Normal file
@@ -166,7 +166,7 @@ bool profile_gemm_universal_streamk_impl(int do_verification,
|
||||
0, 1, 2, 3, 4}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile Stream-K+ DP,
|
||||
// 2:2-tile Stream-K + DP
|
||||
|
||||
if(Grid_size == -1)
|
||||
if(Grid_size != -1)
|
||||
{
|
||||
grid_size_list = {Grid_size};
|
||||
}
|
||||
|
||||
15
test/gemm_universal/CMakeLists.txt
Executable file → Normal file
15
test/gemm_universal/CMakeLists.txt
Executable file → Normal file
@@ -1,15 +1,4 @@
|
||||
add_gtest_executable(test_gemm_universal_fp16 test_gemm_universal_xdl_fp16.cpp)
|
||||
add_gtest_executable(test_gemm_universal test_gemm_universal_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_universal_fp16 PRIVATE utility device_gemm_universal_instance)
|
||||
target_link_libraries(test_gemm_universal PRIVATE utility device_gemm_universal_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_universal_fp8 test_gemm_universal_xdl_fp8.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_universal_fp8 PRIVATE utility device_gemm_universal_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_universal_bf16 test_gemm_universal_xdl_bf16.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_universal_bf16 PRIVATE utility device_gemm_universal_instance)
|
||||
endif()
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_BF16_MK_KN, SmallM)
|
||||
TYPED_TEST(TestGemmUniversal_MK_KN, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
@@ -14,7 +14,7 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_KN, SmallM)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_BF16_MK_NK, SmallM)
|
||||
TYPED_TEST(TestGemmUniversal_MK_NK, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
@@ -28,7 +28,7 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_NK, SmallM)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_BF16_KM_KN, SmallM)
|
||||
TYPED_TEST(TestGemmUniversal_KM_KN, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
@@ -44,7 +44,7 @@ TYPED_TEST(TestGemmUniversal_BF16_KM_KN, SmallM)
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_BF16_KM_NK, SmallM)
|
||||
TYPED_TEST(TestGemmUniversal_KM_NK, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
@@ -60,7 +60,7 @@ TYPED_TEST(TestGemmUniversal_BF16_KM_NK, SmallM)
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_BF16_MK_KN, MidLargeM)
|
||||
TYPED_TEST(TestGemmUniversal_MK_KN, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
@@ -74,7 +74,7 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_KN, MidLargeM)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_BF16_MK_NK, MidLargeM)
|
||||
TYPED_TEST(TestGemmUniversal_MK_NK, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
@@ -88,7 +88,7 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_NK, MidLargeM)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_BF16_KM_KN, MidLargeM)
|
||||
TYPED_TEST(TestGemmUniversal_KM_KN, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
@@ -104,7 +104,7 @@ TYPED_TEST(TestGemmUniversal_BF16_KM_KN, MidLargeM)
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_BF16_KM_NK, MidLargeM)
|
||||
TYPED_TEST(TestGemmUniversal_KM_NK, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
@@ -120,7 +120,7 @@ TYPED_TEST(TestGemmUniversal_BF16_KM_NK, MidLargeM)
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_BF16_MK_KN, PaddK)
|
||||
TYPED_TEST(TestGemmUniversal_MK_KN, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
@@ -134,7 +134,7 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_KN, PaddK)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_BF16_MK_NK, PaddK)
|
||||
TYPED_TEST(TestGemmUniversal_MK_NK, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
@@ -148,7 +148,7 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_NK, PaddK)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_BF16_KM_KN, PaddK)
|
||||
TYPED_TEST(TestGemmUniversal_KM_KN, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
@@ -164,7 +164,7 @@ TYPED_TEST(TestGemmUniversal_BF16_KM_KN, PaddK)
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_BF16_KM_NK, PaddK)
|
||||
TYPED_TEST(TestGemmUniversal_KM_NK, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
@@ -180,7 +180,7 @@ TYPED_TEST(TestGemmUniversal_BF16_KM_NK, PaddK)
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_BF16_MK_KN, Regular)
|
||||
TYPED_TEST(TestGemmUniversal_MK_KN, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
@@ -194,7 +194,7 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_KN, Regular)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_BF16_MK_NK, Regular)
|
||||
TYPED_TEST(TestGemmUniversal_MK_NK, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
@@ -207,3 +207,35 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_NK, Regular)
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_KM_KN, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
{
|
||||
int StrideA = M;
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_KM_NK, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
{
|
||||
int StrideA = M;
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
}
|
||||
@@ -1,99 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP16_MK_KN, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP16_MK_NK, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP16_MK_NK, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP16_MK_KN, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 437;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP16_MK_NK, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 437;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP16_MK_KN, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP16_MK_NK, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
@@ -1,113 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP8_MK_KN, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP8_MK_NK, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP8_MK_KN, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP8_MK_NK, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP8_MK_KN, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 437;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP8_MK_NK, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 437;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP8_MK_KN, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP8_MK_NK, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
@@ -7,6 +7,8 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "test_gemm_universal_util.hpp"
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
@@ -27,25 +29,25 @@ struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
|
||||
} // namespace
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_BF16_MK_KN
|
||||
class TestGemmUniversal_MK_KN
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_BF16_MK_NK
|
||||
class TestGemmUniversal_MK_NK
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_BF16_KM_KN
|
||||
class TestGemmUniversal_KM_KN
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Col, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_BF16_KM_NK
|
||||
class TestGemmUniversal_KM_NK
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Col, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
@@ -53,12 +55,22 @@ class TestGemmUniversal_BF16_KM_NK
|
||||
// clang-format off
|
||||
using KernelTypes_MK_KN = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
|
||||
std::tuple< F16, F16, F16, F16>,
|
||||
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
|
||||
std::tuple< F16, F8, F16, F16>,
|
||||
std::tuple< F8, F16, F16, F16>,
|
||||
std::tuple< F8, F8, F8, BF16>,
|
||||
#endif
|
||||
std::tuple< BF16, BF16, BF16, BF16>
|
||||
>;
|
||||
using KernelTypes_MK_NK = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
|
||||
std::tuple< F16, F16, F16, F16>,
|
||||
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
|
||||
std::tuple< F16, F8, F16, F16>,
|
||||
std::tuple< F8, F16, F16, F16>,
|
||||
std::tuple< F8, F8, F8, BF16>,
|
||||
#endif
|
||||
std::tuple< BF16, BF16, BF16, BF16>
|
||||
>;
|
||||
|
||||
@@ -74,9 +86,9 @@ using KernelTypes_KM_KN = ::testing::Types<
|
||||
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_KN, KernelTypes_MK_KN);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_NK, KernelTypes_MK_NK);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_KN, KernelTypes_KM_KN);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_NK, KernelTypes_KM_NK);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_MK_KN, KernelTypes_MK_KN);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_MK_NK, KernelTypes_MK_NK);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_KM_KN, KernelTypes_KM_KN);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_KM_NK, KernelTypes_KM_NK);
|
||||
|
||||
#include "test_gemm_universal_ut_cases_bf16.inc"
|
||||
#include "test_gemm_universal_ut_cases.inc"
|
||||
@@ -1,82 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "test_gemm_universal_util.hpp"
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename X, typename Y>
|
||||
struct tuple_concat;
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
|
||||
{
|
||||
using type = std::tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_FP16_MK_KN
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_FP16_MK_NK
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_FP16_KM_KN
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Col, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_FP16_KM_NK
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Col, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_MK_KN = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
|
||||
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
|
||||
std::tuple< F16, F8, F16, F16>,
|
||||
std::tuple< F8, F16, F16, F16>,
|
||||
|
||||
#endif
|
||||
std::tuple< F16, F16, F16, F16>
|
||||
>;
|
||||
using KernelTypes_MK_NK = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
|
||||
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
|
||||
std::tuple< F16, F8, F16, F16>,
|
||||
std::tuple< F8, F16, F16, F16>,
|
||||
|
||||
#endif
|
||||
std::tuple< F16, F16, F16, F16>
|
||||
>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_KN, KernelTypes_MK_KN);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_NK, KernelTypes_MK_NK);
|
||||
|
||||
#include "test_gemm_universal_ut_cases_fp16.inc"
|
||||
@@ -1,71 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "test_gemm_universal_util.hpp"
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename X, typename Y>
|
||||
struct tuple_concat;
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
|
||||
{
|
||||
using type = std::tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_FP8_MK_KN
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_FP8_MK_NK
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_MK_KN = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
|
||||
std::tuple< F16, F8, F16, F16>,
|
||||
std::tuple< F8, F16, F16, F16>,
|
||||
std::tuple< F8, F8, F8, BF16>,
|
||||
#endif
|
||||
// Fallback test type when FP8 is not enabled
|
||||
std::tuple< F16, F16, F16, F16>
|
||||
>;
|
||||
using KernelTypes_MK_NK = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
|
||||
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
|
||||
std::tuple< F16, F8, F16, F16>,
|
||||
std::tuple< F8, F16, F16, F16>,
|
||||
std::tuple< F8, F8, F8, BF16>,
|
||||
#endif
|
||||
// Fallback test type when FP8 is not enabled
|
||||
std::tuple< F16, F16, F16, F16>
|
||||
>;
|
||||
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_KN, KernelTypes_MK_KN);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_NK, KernelTypes_MK_NK);
|
||||
|
||||
|
||||
#include "test_gemm_universal_ut_cases_fp8.inc"
|
||||
@@ -44,8 +44,9 @@ class TestGemmUniversal_Streamk : public testing::Test
|
||||
|
||||
void SetUp() override
|
||||
{
|
||||
streamk_sel_list = {0, 1, 2}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile
|
||||
// Stream-K+ DP, // {0, 1, 2, 3, 4}
|
||||
grid_size_list = {38, 114, 228}; // {38, 76, 114, 152, 190, 228, 266, 304, 342, 380};
|
||||
streamk_sel_list = {0, 1, 2}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile
|
||||
// Stream-K+ DP, // {0, 1, 2, 3, 4}
|
||||
// 2:2-tile Stream-K + DP
|
||||
}
|
||||
|
||||
@@ -57,9 +58,10 @@ class TestGemmUniversal_Streamk : public testing::Test
|
||||
const int StrideC)
|
||||
{
|
||||
for(auto streamk_sel : streamk_sel_list)
|
||||
{
|
||||
RunSingle(M, N, K, StrideA, StrideB, StrideC, streamk_sel, -1);
|
||||
}
|
||||
for(auto grid_size : grid_size_list)
|
||||
{
|
||||
RunSingle(M, N, K, StrideA, StrideB, StrideC, streamk_sel, grid_size);
|
||||
}
|
||||
}
|
||||
|
||||
void RunSingle(const int M,
|
||||
|
||||
Reference in New Issue
Block a user