CkProfiler StreamK GemmUniversal Fix and Split Gemm_universal Test (#2044)

* fix and split gemm_universal test

* clang

* Update test_gemm_universal_ut_cases_bf16.inc

* Update test_gemm_universal_xdl_bf16.cpp

* Update test_gemm_universal_ut_cases_fp16.inc
This commit is contained in:
Muhammed Emin Ozturk
2025-04-03 14:22:43 -07:00
committed by GitHub
parent fed0709121
commit 7142d8003c
9 changed files with 409 additions and 79 deletions

View File

@@ -166,7 +166,7 @@ bool profile_gemm_universal_streamk_impl(int do_verification,
0, 1, 2, 3, 4}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile Stream-K+ DP,
// 2:2-tile Stream-K + DP
if(Grid_size != -1)
if(Grid_size == -1)
{
grid_size_list = {Grid_size};
}

15
test/gemm_universal/CMakeLists.txt Normal file → Executable file
View File

@@ -1,4 +1,15 @@
add_gtest_executable(test_gemm_universal test_gemm_universal_xdl.cpp)
add_gtest_executable(test_gemm_universal_fp16 test_gemm_universal_xdl_fp16.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_universal PRIVATE utility device_gemm_universal_instance)
target_link_libraries(test_gemm_universal_fp16 PRIVATE utility device_gemm_universal_instance)
endif()
add_gtest_executable(test_gemm_universal_fp8 test_gemm_universal_xdl_fp8.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_universal_fp8 PRIVATE utility device_gemm_universal_instance)
endif()
add_gtest_executable(test_gemm_universal_bf16 test_gemm_universal_xdl_bf16.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_universal_bf16 PRIVATE utility device_gemm_universal_instance)
endif()

View File

@@ -1,6 +1,6 @@
#pragma once
TYPED_TEST(TestGemmUniversal_MK_KN, SmallM)
TYPED_TEST(TestGemmUniversal_BF16_MK_KN, SmallM)
{
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
constexpr int N = 512;
@@ -14,7 +14,7 @@ TYPED_TEST(TestGemmUniversal_MK_KN, SmallM)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_MK_NK, SmallM)
TYPED_TEST(TestGemmUniversal_BF16_MK_NK, SmallM)
{
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
constexpr int N = 512;
@@ -28,7 +28,7 @@ TYPED_TEST(TestGemmUniversal_MK_NK, SmallM)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_KM_KN, SmallM)
TYPED_TEST(TestGemmUniversal_BF16_KM_KN, SmallM)
{
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
constexpr int N = 512;
@@ -44,7 +44,7 @@ TYPED_TEST(TestGemmUniversal_KM_KN, SmallM)
}
}
TYPED_TEST(TestGemmUniversal_KM_NK, SmallM)
TYPED_TEST(TestGemmUniversal_BF16_KM_NK, SmallM)
{
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
constexpr int N = 512;
@@ -60,7 +60,7 @@ TYPED_TEST(TestGemmUniversal_KM_NK, SmallM)
}
}
TYPED_TEST(TestGemmUniversal_MK_KN, MidLargeM)
TYPED_TEST(TestGemmUniversal_BF16_MK_KN, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 512;
@@ -74,7 +74,7 @@ TYPED_TEST(TestGemmUniversal_MK_KN, MidLargeM)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_MK_NK, MidLargeM)
TYPED_TEST(TestGemmUniversal_BF16_MK_NK, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 512;
@@ -88,7 +88,7 @@ TYPED_TEST(TestGemmUniversal_MK_NK, MidLargeM)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_KM_KN, MidLargeM)
TYPED_TEST(TestGemmUniversal_BF16_KM_KN, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 512;
@@ -104,7 +104,7 @@ TYPED_TEST(TestGemmUniversal_KM_KN, MidLargeM)
}
}
TYPED_TEST(TestGemmUniversal_KM_NK, MidLargeM)
TYPED_TEST(TestGemmUniversal_BF16_KM_NK, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 512;
@@ -120,7 +120,7 @@ TYPED_TEST(TestGemmUniversal_KM_NK, MidLargeM)
}
}
TYPED_TEST(TestGemmUniversal_MK_KN, PaddK)
TYPED_TEST(TestGemmUniversal_BF16_MK_KN, PaddK)
{
std::vector<int> Ms{127};
constexpr int N = 512;
@@ -134,7 +134,7 @@ TYPED_TEST(TestGemmUniversal_MK_KN, PaddK)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_MK_NK, PaddK)
TYPED_TEST(TestGemmUniversal_BF16_MK_NK, PaddK)
{
std::vector<int> Ms{127};
constexpr int N = 512;
@@ -148,7 +148,7 @@ TYPED_TEST(TestGemmUniversal_MK_NK, PaddK)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_KM_KN, PaddK)
TYPED_TEST(TestGemmUniversal_BF16_KM_KN, PaddK)
{
std::vector<int> Ms{127};
constexpr int N = 512;
@@ -164,7 +164,7 @@ TYPED_TEST(TestGemmUniversal_KM_KN, PaddK)
}
}
TYPED_TEST(TestGemmUniversal_KM_NK, PaddK)
TYPED_TEST(TestGemmUniversal_BF16_KM_NK, PaddK)
{
std::vector<int> Ms{127};
constexpr int N = 512;
@@ -180,7 +180,7 @@ TYPED_TEST(TestGemmUniversal_KM_NK, PaddK)
}
}
TYPED_TEST(TestGemmUniversal_MK_KN, Regular)
TYPED_TEST(TestGemmUniversal_BF16_MK_KN, Regular)
{
std::vector<int> Ms{512};
constexpr int N = 512;
@@ -194,7 +194,7 @@ TYPED_TEST(TestGemmUniversal_MK_KN, Regular)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_MK_NK, Regular)
TYPED_TEST(TestGemmUniversal_BF16_MK_NK, Regular)
{
std::vector<int> Ms{512};
constexpr int N = 512;
@@ -207,35 +207,3 @@ TYPED_TEST(TestGemmUniversal_MK_NK, Regular)
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_KM_KN, Regular)
{
std::vector<int> Ms{512};
constexpr int N = 512;
constexpr int K = 512;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
{
int StrideA = M;
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
}
TYPED_TEST(TestGemmUniversal_KM_NK, Regular)
{
std::vector<int> Ms{512};
constexpr int N = 512;
constexpr int K = 512;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
{
int StrideA = M;
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
}

View File

@@ -0,0 +1,99 @@
#pragma once
TYPED_TEST(TestGemmUniversal_FP16_MK_KN, SmallM)
{
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
constexpr int N = 512;
constexpr int K = 320;
constexpr int StrideA = K;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_FP16_MK_NK, SmallM)
{
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
constexpr int N = 512;
constexpr int K = 320;
constexpr int StrideA = K;
constexpr int StrideB = K;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_FP16_MK_NK, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 512;
constexpr int K = 320;
constexpr int StrideA = K;
constexpr int StrideB = K;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_FP16_MK_KN, PaddK)
{
std::vector<int> Ms{127};
constexpr int N = 512;
constexpr int K = 437;
constexpr int StrideA = K;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_FP16_MK_NK, PaddK)
{
std::vector<int> Ms{127};
constexpr int N = 512;
constexpr int K = 437;
constexpr int StrideA = K;
constexpr int StrideB = K;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_FP16_MK_KN, Regular)
{
std::vector<int> Ms{512};
constexpr int N = 512;
constexpr int K = 512;
constexpr int StrideA = K;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_FP16_MK_NK, Regular)
{
std::vector<int> Ms{512};
constexpr int N = 512;
constexpr int K = 512;
constexpr int StrideA = K;
constexpr int StrideB = K;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}

View File

@@ -0,0 +1,113 @@
#pragma once
TYPED_TEST(TestGemmUniversal_FP8_MK_KN, SmallM)
{
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
constexpr int N = 512;
constexpr int K = 320;
constexpr int StrideA = K;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_FP8_MK_NK, SmallM)
{
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
constexpr int N = 512;
constexpr int K = 320;
constexpr int StrideA = K;
constexpr int StrideB = K;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_FP8_MK_KN, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 512;
constexpr int K = 320;
constexpr int StrideA = K;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_FP8_MK_NK, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 512;
constexpr int K = 320;
constexpr int StrideA = K;
constexpr int StrideB = K;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_FP8_MK_KN, PaddK)
{
std::vector<int> Ms{127};
constexpr int N = 512;
constexpr int K = 437;
constexpr int StrideA = K;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_FP8_MK_NK, PaddK)
{
std::vector<int> Ms{127};
constexpr int N = 512;
constexpr int K = 437;
constexpr int StrideA = K;
constexpr int StrideB = K;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_FP8_MK_KN, Regular)
{
std::vector<int> Ms{512};
constexpr int N = 512;
constexpr int K = 512;
constexpr int StrideA = K;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_FP8_MK_NK, Regular)
{
std::vector<int> Ms{512};
constexpr int N = 512;
constexpr int K = 512;
constexpr int StrideA = K;
constexpr int StrideB = K;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}

View File

@@ -7,8 +7,6 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "test_gemm_universal_util.hpp"
using F8 = ck::f8_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
@@ -29,25 +27,25 @@ struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
} // namespace
template <typename Tuple>
class TestGemmUniversal_MK_KN
class TestGemmUniversal_BF16_MK_KN
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Row>, Tuple>::type>
{
};
template <typename Tuple>
class TestGemmUniversal_MK_NK
class TestGemmUniversal_BF16_MK_NK
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
{
};
template <typename Tuple>
class TestGemmUniversal_KM_KN
class TestGemmUniversal_BF16_KM_KN
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Col, Row>, Tuple>::type>
{
};
template <typename Tuple>
class TestGemmUniversal_KM_NK
class TestGemmUniversal_BF16_KM_NK
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Col, Col>, Tuple>::type>
{
};
@@ -55,22 +53,12 @@ class TestGemmUniversal_KM_NK
// clang-format off
using KernelTypes_MK_KN = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
std::tuple< F16, F16, F16, F16>,
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
std::tuple< F16, F8, F16, F16>,
std::tuple< F8, F16, F16, F16>,
std::tuple< F8, F8, F8, BF16>,
#endif
std::tuple< BF16, BF16, BF16, BF16>
>;
using KernelTypes_MK_NK = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
std::tuple< F16, F16, F16, F16>,
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
std::tuple< F16, F8, F16, F16>,
std::tuple< F8, F16, F16, F16>,
std::tuple< F8, F8, F8, BF16>,
#endif
std::tuple< BF16, BF16, BF16, BF16>
>;
@@ -86,9 +74,9 @@ using KernelTypes_KM_KN = ::testing::Types<
// clang-format on
TYPED_TEST_SUITE(TestGemmUniversal_MK_KN, KernelTypes_MK_KN);
TYPED_TEST_SUITE(TestGemmUniversal_MK_NK, KernelTypes_MK_NK);
TYPED_TEST_SUITE(TestGemmUniversal_KM_KN, KernelTypes_KM_KN);
TYPED_TEST_SUITE(TestGemmUniversal_KM_NK, KernelTypes_KM_NK);
TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_KN, KernelTypes_MK_KN);
TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_NK, KernelTypes_MK_NK);
TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_KN, KernelTypes_KM_KN);
TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_NK, KernelTypes_KM_NK);
#include "test_gemm_universal_ut_cases.inc"
#include "test_gemm_universal_ut_cases_bf16.inc"

View File

@@ -0,0 +1,82 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "gtest/gtest.h"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "test_gemm_universal_util.hpp"
using F8 = ck::f8_t;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
namespace {
template <typename X, typename Y>
struct tuple_concat;
template <typename... Xs, typename... Ys>
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
{
using type = std::tuple<Xs..., Ys...>;
};
} // namespace
template <typename Tuple>
class TestGemmUniversal_FP16_MK_KN
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Row>, Tuple>::type>
{
};
template <typename Tuple>
class TestGemmUniversal_FP16_MK_NK
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
{
};
template <typename Tuple>
class TestGemmUniversal_FP16_KM_KN
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Col, Row>, Tuple>::type>
{
};
template <typename Tuple>
class TestGemmUniversal_FP16_KM_NK
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Col, Col>, Tuple>::type>
{
};
// clang-format off
using KernelTypes_MK_KN = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
std::tuple< F16, F8, F16, F16>,
std::tuple< F8, F16, F16, F16>,
#endif
std::tuple< F16, F16, F16, F16>
>;
using KernelTypes_MK_NK = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
std::tuple< F16, F8, F16, F16>,
std::tuple< F8, F16, F16, F16>,
#endif
std::tuple< F16, F16, F16, F16>
>;
// clang-format on
TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_KN, KernelTypes_MK_KN);
TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_NK, KernelTypes_MK_NK);
#include "test_gemm_universal_ut_cases_fp16.inc"

View File

@@ -0,0 +1,71 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "gtest/gtest.h"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "test_gemm_universal_util.hpp"
using F8 = ck::f8_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
namespace {
template <typename X, typename Y>
struct tuple_concat;
template <typename... Xs, typename... Ys>
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
{
using type = std::tuple<Xs..., Ys...>;
};
} // namespace
template <typename Tuple>
class TestGemmUniversal_FP8_MK_KN
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Row>, Tuple>::type>
{
};
template <typename Tuple>
class TestGemmUniversal_FP8_MK_NK
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
{
};
// clang-format off
using KernelTypes_MK_KN = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
std::tuple< F16, F8, F16, F16>,
std::tuple< F8, F16, F16, F16>,
std::tuple< F8, F8, F8, BF16>,
#endif
// Fallback test type when FP8 is not enabled
std::tuple< F16, F16, F16, F16>
>;
using KernelTypes_MK_NK = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
std::tuple< F16, F8, F16, F16>,
std::tuple< F8, F16, F16, F16>,
std::tuple< F8, F8, F8, BF16>,
#endif
// Fallback test type when FP8 is not enabled
std::tuple< F16, F16, F16, F16>
>;
TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_KN, KernelTypes_MK_KN);
TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_NK, KernelTypes_MK_NK);
#include "test_gemm_universal_ut_cases_fp8.inc"

View File

@@ -44,9 +44,8 @@ class TestGemmUniversal_Streamk : public testing::Test
void SetUp() override
{
grid_size_list = {38, 114, 228}; // {38, 76, 114, 152, 190, 228, 266, 304, 342, 380};
streamk_sel_list = {0, 1, 2}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile
// Stream-K+ DP, // {0, 1, 2, 3, 4}
streamk_sel_list = {0, 1, 2}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile
// Stream-K+ DP, // {0, 1, 2, 3, 4}
// 2:2-tile Stream-K + DP
}
@@ -58,10 +57,9 @@ class TestGemmUniversal_Streamk : public testing::Test
const int StrideC)
{
for(auto streamk_sel : streamk_sel_list)
for(auto grid_size : grid_size_list)
{
RunSingle(M, N, K, StrideA, StrideB, StrideC, streamk_sel, grid_size);
}
{
RunSingle(M, N, K, StrideA, StrideB, StrideC, streamk_sel, -1);
}
}
void RunSingle(const int M,