mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
CkProfiler StreamK GemmUniversal Fix and Split Gemm_universal Test (#2044)
* fix and split gemm_universal test * clang * Update test_gemm_universal_ut_cases_bf16.inc * Update test_gemm_universal_xdl_bf16.cpp * Update test_gemm_universal_ut_cases_fp16.inc
This commit is contained in:
committed by
GitHub
parent
fed0709121
commit
7142d8003c
2
profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp
Normal file → Executable file
2
profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp
Normal file → Executable file
@@ -166,7 +166,7 @@ bool profile_gemm_universal_streamk_impl(int do_verification,
|
||||
0, 1, 2, 3, 4}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile Stream-K+ DP,
|
||||
// 2:2-tile Stream-K + DP
|
||||
|
||||
if(Grid_size != -1)
|
||||
if(Grid_size == -1)
|
||||
{
|
||||
grid_size_list = {Grid_size};
|
||||
}
|
||||
|
||||
15
test/gemm_universal/CMakeLists.txt
Normal file → Executable file
15
test/gemm_universal/CMakeLists.txt
Normal file → Executable file
@@ -1,4 +1,15 @@
|
||||
add_gtest_executable(test_gemm_universal test_gemm_universal_xdl.cpp)
|
||||
add_gtest_executable(test_gemm_universal_fp16 test_gemm_universal_xdl_fp16.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_universal PRIVATE utility device_gemm_universal_instance)
|
||||
target_link_libraries(test_gemm_universal_fp16 PRIVATE utility device_gemm_universal_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_universal_fp8 test_gemm_universal_xdl_fp8.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_universal_fp8 PRIVATE utility device_gemm_universal_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_universal_bf16 test_gemm_universal_xdl_bf16.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_universal_bf16 PRIVATE utility device_gemm_universal_instance)
|
||||
endif()
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_MK_KN, SmallM)
|
||||
TYPED_TEST(TestGemmUniversal_BF16_MK_KN, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
@@ -14,7 +14,7 @@ TYPED_TEST(TestGemmUniversal_MK_KN, SmallM)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_MK_NK, SmallM)
|
||||
TYPED_TEST(TestGemmUniversal_BF16_MK_NK, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
@@ -28,7 +28,7 @@ TYPED_TEST(TestGemmUniversal_MK_NK, SmallM)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_KM_KN, SmallM)
|
||||
TYPED_TEST(TestGemmUniversal_BF16_KM_KN, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
@@ -44,7 +44,7 @@ TYPED_TEST(TestGemmUniversal_KM_KN, SmallM)
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_KM_NK, SmallM)
|
||||
TYPED_TEST(TestGemmUniversal_BF16_KM_NK, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
@@ -60,7 +60,7 @@ TYPED_TEST(TestGemmUniversal_KM_NK, SmallM)
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_MK_KN, MidLargeM)
|
||||
TYPED_TEST(TestGemmUniversal_BF16_MK_KN, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
@@ -74,7 +74,7 @@ TYPED_TEST(TestGemmUniversal_MK_KN, MidLargeM)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_MK_NK, MidLargeM)
|
||||
TYPED_TEST(TestGemmUniversal_BF16_MK_NK, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
@@ -88,7 +88,7 @@ TYPED_TEST(TestGemmUniversal_MK_NK, MidLargeM)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_KM_KN, MidLargeM)
|
||||
TYPED_TEST(TestGemmUniversal_BF16_KM_KN, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
@@ -104,7 +104,7 @@ TYPED_TEST(TestGemmUniversal_KM_KN, MidLargeM)
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_KM_NK, MidLargeM)
|
||||
TYPED_TEST(TestGemmUniversal_BF16_KM_NK, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
@@ -120,7 +120,7 @@ TYPED_TEST(TestGemmUniversal_KM_NK, MidLargeM)
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_MK_KN, PaddK)
|
||||
TYPED_TEST(TestGemmUniversal_BF16_MK_KN, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
@@ -134,7 +134,7 @@ TYPED_TEST(TestGemmUniversal_MK_KN, PaddK)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_MK_NK, PaddK)
|
||||
TYPED_TEST(TestGemmUniversal_BF16_MK_NK, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
@@ -148,7 +148,7 @@ TYPED_TEST(TestGemmUniversal_MK_NK, PaddK)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_KM_KN, PaddK)
|
||||
TYPED_TEST(TestGemmUniversal_BF16_KM_KN, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
@@ -164,7 +164,7 @@ TYPED_TEST(TestGemmUniversal_KM_KN, PaddK)
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_KM_NK, PaddK)
|
||||
TYPED_TEST(TestGemmUniversal_BF16_KM_NK, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
@@ -180,7 +180,7 @@ TYPED_TEST(TestGemmUniversal_KM_NK, PaddK)
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_MK_KN, Regular)
|
||||
TYPED_TEST(TestGemmUniversal_BF16_MK_KN, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
@@ -194,7 +194,7 @@ TYPED_TEST(TestGemmUniversal_MK_KN, Regular)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_MK_NK, Regular)
|
||||
TYPED_TEST(TestGemmUniversal_BF16_MK_NK, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
@@ -207,35 +207,3 @@ TYPED_TEST(TestGemmUniversal_MK_NK, Regular)
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_KM_KN, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
{
|
||||
int StrideA = M;
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_KM_NK, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
{
|
||||
int StrideA = M;
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
}
|
||||
99
test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc
Normal file
99
test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc
Normal file
@@ -0,0 +1,99 @@
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP16_MK_KN, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP16_MK_NK, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP16_MK_NK, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP16_MK_KN, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 437;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP16_MK_NK, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 437;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP16_MK_KN, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP16_MK_NK, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
113
test/gemm_universal/test_gemm_universal_ut_cases_fp8.inc
Normal file
113
test/gemm_universal/test_gemm_universal_ut_cases_fp8.inc
Normal file
@@ -0,0 +1,113 @@
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP8_MK_KN, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP8_MK_NK, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP8_MK_KN, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP8_MK_NK, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP8_MK_KN, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 437;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP8_MK_NK, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 437;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP8_MK_KN, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversal_FP8_MK_NK, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
@@ -7,8 +7,6 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "test_gemm_universal_util.hpp"
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
@@ -29,25 +27,25 @@ struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
|
||||
} // namespace
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_MK_KN
|
||||
class TestGemmUniversal_BF16_MK_KN
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_MK_NK
|
||||
class TestGemmUniversal_BF16_MK_NK
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_KM_KN
|
||||
class TestGemmUniversal_BF16_KM_KN
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Col, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_KM_NK
|
||||
class TestGemmUniversal_BF16_KM_NK
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Col, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
@@ -55,22 +53,12 @@ class TestGemmUniversal_KM_NK
|
||||
// clang-format off
|
||||
using KernelTypes_MK_KN = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
std::tuple< F16, F16, F16, F16>,
|
||||
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
|
||||
std::tuple< F16, F8, F16, F16>,
|
||||
std::tuple< F8, F16, F16, F16>,
|
||||
std::tuple< F8, F8, F8, BF16>,
|
||||
#endif
|
||||
|
||||
std::tuple< BF16, BF16, BF16, BF16>
|
||||
>;
|
||||
using KernelTypes_MK_NK = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
std::tuple< F16, F16, F16, F16>,
|
||||
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
|
||||
std::tuple< F16, F8, F16, F16>,
|
||||
std::tuple< F8, F16, F16, F16>,
|
||||
std::tuple< F8, F8, F8, BF16>,
|
||||
#endif
|
||||
|
||||
std::tuple< BF16, BF16, BF16, BF16>
|
||||
>;
|
||||
|
||||
@@ -86,9 +74,9 @@ using KernelTypes_KM_KN = ::testing::Types<
|
||||
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_MK_KN, KernelTypes_MK_KN);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_MK_NK, KernelTypes_MK_NK);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_KM_KN, KernelTypes_KM_KN);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_KM_NK, KernelTypes_KM_NK);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_KN, KernelTypes_MK_KN);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_NK, KernelTypes_MK_NK);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_KN, KernelTypes_KM_KN);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_NK, KernelTypes_KM_NK);
|
||||
|
||||
#include "test_gemm_universal_ut_cases.inc"
|
||||
#include "test_gemm_universal_ut_cases_bf16.inc"
|
||||
82
test/gemm_universal/test_gemm_universal_xdl_fp16.cpp
Normal file
82
test/gemm_universal/test_gemm_universal_xdl_fp16.cpp
Normal file
@@ -0,0 +1,82 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "test_gemm_universal_util.hpp"
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename X, typename Y>
|
||||
struct tuple_concat;
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
|
||||
{
|
||||
using type = std::tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_FP16_MK_KN
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_FP16_MK_NK
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_FP16_KM_KN
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Col, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_FP16_KM_NK
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Col, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_MK_KN = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
|
||||
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
|
||||
std::tuple< F16, F8, F16, F16>,
|
||||
std::tuple< F8, F16, F16, F16>,
|
||||
|
||||
#endif
|
||||
std::tuple< F16, F16, F16, F16>
|
||||
>;
|
||||
using KernelTypes_MK_NK = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
|
||||
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
|
||||
std::tuple< F16, F8, F16, F16>,
|
||||
std::tuple< F8, F16, F16, F16>,
|
||||
|
||||
#endif
|
||||
std::tuple< F16, F16, F16, F16>
|
||||
>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_KN, KernelTypes_MK_KN);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_NK, KernelTypes_MK_NK);
|
||||
|
||||
#include "test_gemm_universal_ut_cases_fp16.inc"
|
||||
71
test/gemm_universal/test_gemm_universal_xdl_fp8.cpp
Normal file
71
test/gemm_universal/test_gemm_universal_xdl_fp8.cpp
Normal file
@@ -0,0 +1,71 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "test_gemm_universal_util.hpp"
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename X, typename Y>
|
||||
struct tuple_concat;
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
|
||||
{
|
||||
using type = std::tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_FP8_MK_KN
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_FP8_MK_NK
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_MK_KN = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
|
||||
std::tuple< F16, F8, F16, F16>,
|
||||
std::tuple< F8, F16, F16, F16>,
|
||||
std::tuple< F8, F8, F8, BF16>,
|
||||
#endif
|
||||
// Fallback test type when FP8 is not enabled
|
||||
std::tuple< F16, F16, F16, F16>
|
||||
>;
|
||||
using KernelTypes_MK_NK = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
|
||||
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
|
||||
std::tuple< F16, F8, F16, F16>,
|
||||
std::tuple< F8, F16, F16, F16>,
|
||||
std::tuple< F8, F8, F8, BF16>,
|
||||
#endif
|
||||
// Fallback test type when FP8 is not enabled
|
||||
std::tuple< F16, F16, F16, F16>
|
||||
>;
|
||||
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_KN, KernelTypes_MK_KN);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_NK, KernelTypes_MK_NK);
|
||||
|
||||
|
||||
#include "test_gemm_universal_ut_cases_fp8.inc"
|
||||
@@ -44,9 +44,8 @@ class TestGemmUniversal_Streamk : public testing::Test
|
||||
|
||||
void SetUp() override
|
||||
{
|
||||
grid_size_list = {38, 114, 228}; // {38, 76, 114, 152, 190, 228, 266, 304, 342, 380};
|
||||
streamk_sel_list = {0, 1, 2}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile
|
||||
// Stream-K+ DP, // {0, 1, 2, 3, 4}
|
||||
streamk_sel_list = {0, 1, 2}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile
|
||||
// Stream-K+ DP, // {0, 1, 2, 3, 4}
|
||||
// 2:2-tile Stream-K + DP
|
||||
}
|
||||
|
||||
@@ -58,10 +57,9 @@ class TestGemmUniversal_Streamk : public testing::Test
|
||||
const int StrideC)
|
||||
{
|
||||
for(auto streamk_sel : streamk_sel_list)
|
||||
for(auto grid_size : grid_size_list)
|
||||
{
|
||||
RunSingle(M, N, K, StrideA, StrideB, StrideC, streamk_sel, grid_size);
|
||||
}
|
||||
{
|
||||
RunSingle(M, N, K, StrideA, StrideB, StrideC, streamk_sel, -1);
|
||||
}
|
||||
}
|
||||
|
||||
void RunSingle(const int M,
|
||||
|
||||
Reference in New Issue
Block a user