WMMA GEMM universal pipeline v1, mixed precision and paddings, examples (#2230)

* Fixed cmake errors related to  gemm_bilinear. Previously, if the above flags are set, cmake build fails: GPU_TARGETS="gfx1100;gfx1201" -D DTYPES="fp16;bf16;fp8"

* Fixed cmake build errors related to test_fp8

* Updates to support mixed precision

* Adding support for RRR, F8xF16xF16 gemm_universal_wmma - wip

* Added support for F8xF16xF16 to gemm_wmma_universal

* Added support for F16xF8xF16 to gemm_wmma_universal

* Added support for BF16xI4xBF16 to gemm_wmma_universal

* Added support for F16xI4xF16 to gemm_wmma_universal

* Fixed IsSupportedArgument to check ComputeTypeA, ComputeTypeB instead of ADataType, BDataType

* Added missing test class for FP16_KM_NK

* Pre-commit hooks fixes

* Added padding instances for f16xf16xf16

* Fixed cmake errors related to  gemm_bilinear. Previously, if the above flags are set, cmake build fails: GPU_TARGETS="gfx1100;gfx1201" -D DTYPES="fp16;bf16;fp8"

* Fixed cmake build errors related to test_fp8

* Ammending changes for adding support for padding instances for f16xf16xf16

* Fixes for padding instances for f16xf16xf16

* Added padding instances for bf16xbf16, f8xf8

* Added packed instances for bf16xi4xbf16

* Added padding instances for f8xf16xf16

* Added padding instances for f16xf8xf16, f16xi4xf16

* Fixed typos for bf16xbf16xbf16 padding instances

* Fixed typos for padded instances

* Added tests for fp16, KM_KN and KM_NK

* Padding not supported for when BDataType is pk_i4_t. Added fix for correct check and removed padding instances.

* Fixed typos

* Updated the set of tests for FP16

* Updated the set of tests for FP16

* Fix typo

* Moved f16xi4 test under the correct data layout group

* example for gemm_universal_bf16

* Adding examples for gemm_wmma instances

* Added the  missing parameters

* Fixed review comments and added executable to cmakeLists

* Fixing clang format

* Fixing build erros

* Fixed compilation failure.

* Modified some code as per gemm_universal_examples

* Fixed the gemm specialization error

* Fixed the build errors.

* Fix strides of a/b_thread_desc

The descriptors are larger than needed (even though the compiler don't alloc registers for unused values).

* Load in M/NRepeat dims with thread copy's slice instead of a loop

* Clone BlockwiseGemmXdlops_pipeline_v1 for WMMA implementation

* Implement Intrawave and Interwave variants of pipeline v1

* Add instances for Interwave and Intrawave v1

* Add instances with ABlockLdsExtraM and BBlockLdsExtraN = 0

* Remove instances that are too slow (mostly because of register spilling)

* Add a workaround for fp8/bf8->f32 packed conversion issue

* Add instances for Interwave and Intrawave v1

* Enable profiling of mixed precision with f8 and int4 on WMMA

* Fix segfault in profiler when B is pk_i4_t

b_device_buf's size in bytes is larger than b_k_n_permute so b_device_buf.ToDevice reads out-of-bounds.

* Remove instances that are too slow (mostly because of register spilling)

* Add missing add_device_gemm_wmma_universal_f8_f8_bf16 declarations

* Add test case for bf16_i4

* Add missing Regular tests

* Add test_gemm_universal_xdl/wmma_fp16 to REGRESSION_TESTS

They take more than 30 seconds

* Fix a bug that fp16_i4 validation passes only with PermuteB

A permutation required by conversion from pk_i4_t to half_t does not
depend on PermuteB, they can be used independently.

* Use PermuteB with f16_i4 in most instances (as xdl)

Some instances use PermuteB = false for checking correctness.
See also the previous commit.

* Fix cache flushing for pk_i4

* Add mixed precision examples

* Disable all tests and instances with f8 on gfx11

Even though f8_f16 and f16_f8 don't require f8 WMMA instructions,
gfx11 still lacks hardware instructions for fast f8->f32 conversion.

* Add FP16 KM_NK and KM_KN test suites for XDL

These tests were added to common .inc for better testing of WMMA instances

* Fix int8 DTYPES check for gemm_bilinear

---------

Co-authored-by: Anca Hamuraru <anca@streamhpc.com>
Co-authored-by: Apoorva Kalyani <apoorva@streamhpc.com>

[ROCm/composable_kernel commit: 52b4860a30]
This commit is contained in:
Anton Gorenko
2025-06-04 12:22:33 +06:00
committed by GitHub
parent c395db8926
commit 780cb29a42
117 changed files with 4953 additions and 271 deletions

View File

@@ -14,7 +14,8 @@ set(REGRESSION_TESTS
test_gemm_fp16
test_gemm_splitk
test_batched_gemm
test_gemm_universal
test_gemm_universal_wmma_fp16
test_gemm_universal_xdl_fp16
test_gemm_universal_streamk_fp16
test_gemm_universal_streamk_bf16
test_gemm_universal_streamk_fp8

View File

@@ -16,15 +16,15 @@ if (CK_USE_OCP_FP8)
add_gtest_executable(test_fp8_ocp test_fp8_ocp.cpp)
if(result EQUAL 0)
target_link_libraries(test_fp8_ocp PRIVATE utility)
add_dependencies(test_fp8 test_fp8_ocp)
endif()
add_gtest_executable(test_bf8_ocp test_bf8_ocp.cpp)
if(result EQUAL 0)
target_link_libraries(test_bf8_ocp PRIVATE utility)
add_dependencies(test_fp8 test_bf8_ocp)
endif()
add_dependencies(test_fp8 test_fp8_ocp)
add_dependencies(test_fp8 test_bf8_ocp)
endif()
if (CK_USE_FNUZ_FP8)

View File

@@ -207,3 +207,35 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_NK, Regular)
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_BF16_KM_KN, Regular)
{
std::vector<int> Ms{512};
constexpr int N = 512;
constexpr int K = 512;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
{
int StrideA = M;
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
}
TYPED_TEST(TestGemmUniversal_BF16_KM_NK, Regular)
{
std::vector<int> Ms{512};
constexpr int N = 512;
constexpr int K = 512;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
{
int StrideA = M;
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
}

View File

@@ -28,6 +28,38 @@ TYPED_TEST(TestGemmUniversal_FP16_MK_NK, SmallM)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_FP16_KM_KN, SmallM)
{
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
constexpr int N = 512;
constexpr int K = 320;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
{
int StrideA = M;
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
}
TYPED_TEST(TestGemmUniversal_FP16_KM_NK, SmallM)
{
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
constexpr int N = 512;
constexpr int K = 320;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
{
int StrideA = M;
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
}
TYPED_TEST(TestGemmUniversal_FP16_MK_KN, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
@@ -56,6 +88,38 @@ TYPED_TEST(TestGemmUniversal_FP16_MK_NK, MidLargeM)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_FP16_KM_KN, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 512;
constexpr int K = 320;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
{
int StrideA = M;
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
}
TYPED_TEST(TestGemmUniversal_FP16_KM_NK, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 512;
constexpr int K = 320;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
{
int StrideA = M;
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
}
TYPED_TEST(TestGemmUniversal_FP16_MK_KN, PaddK)
{
std::vector<int> Ms{127};
@@ -84,6 +148,38 @@ TYPED_TEST(TestGemmUniversal_FP16_MK_NK, PaddK)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_FP16_KM_KN, PaddK)
{
std::vector<int> Ms{127};
constexpr int N = 512;
constexpr int K = 437;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
{
int StrideA = M;
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
}
TYPED_TEST(TestGemmUniversal_FP16_KM_NK, PaddK)
{
std::vector<int> Ms{127};
constexpr int N = 512;
constexpr int K = 437;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
{
int StrideA = M;
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
}
TYPED_TEST(TestGemmUniversal_FP16_MK_KN, Regular)
{
std::vector<int> Ms{512};
@@ -111,3 +207,35 @@ TYPED_TEST(TestGemmUniversal_FP16_MK_NK, Regular)
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_FP16_KM_KN, Regular)
{
std::vector<int> Ms{512};
constexpr int N = 512;
constexpr int K = 512;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
{
int StrideA = M;
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
}
TYPED_TEST(TestGemmUniversal_FP16_KM_NK, Regular)
{
std::vector<int> Ms{512};
constexpr int N = 512;
constexpr int K = 512;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
{
int StrideA = M;
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
}

View File

@@ -7,6 +7,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "test_gemm_universal_util.hpp"
using I4 = ck::pk_i4_t;
using BF16 = ck::bhalf_t;
using F32 = float;
@@ -58,6 +59,9 @@ using KernelTypes_MK_KN = ::testing::Types<
using KernelTypes_MK_NK = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
#if defined(CK_ENABLE_FP8)
std::tuple< BF16, I4, BF16, BF16>,
#endif
std::tuple< BF16, BF16, BF16, BF16>
>;
@@ -68,6 +72,9 @@ using KernelTypes_KM_KN = ::testing::Types<
using KernelTypes_KM_NK = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
#if defined(CK_ENABLE_FP8)
std::tuple< BF16, I4, BF16, BF16>,
#endif
std::tuple< BF16, BF16, BF16, BF16>
>;
// clang-format on

View File

@@ -7,6 +7,8 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "test_gemm_universal_util.hpp"
using I4 = ck::pk_i4_t;
using F8 = ck::f8_t;
using F16 = ck::half_t;
using F32 = float;
@@ -39,19 +41,61 @@ class TestGemmUniversal_FP16_MK_NK
{
};
template <typename Tuple>
class TestGemmUniversal_FP16_KM_KN
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Col, Row>, Tuple>::type>
{
};
template <typename Tuple>
class TestGemmUniversal_FP16_KM_NK
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Col, Col>, Tuple>::type>
{
};
// clang-format off
using KernelTypes_MK_KN = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
#if defined(CK_ENABLE_FP8) && defined(CK_USE_WMMA_FP8)
std::tuple< F8, F16, F16, F16>,
std::tuple< F16, F8, F16, F16>,
#endif
std::tuple< F16, F16, F16, F16>
>;
using KernelTypes_MK_NK = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
#if defined(CK_ENABLE_FP8) && defined(CK_USE_WMMA_FP8)
std::tuple< F8, F16, F16, F16>,
std::tuple< F16, F8, F16, F16>,
std::tuple< F16, I4, F16, F16>,
#endif
std::tuple< F16, F16, F16, F16>
>;
using KernelTypes_KM_NK = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
#if defined(CK_ENABLE_FP8) && defined(CK_USE_WMMA_FP8)
std::tuple< F8, F16, F16, F16>,
std::tuple< F16, F8, F16, F16>,
std::tuple< F16, I4, F16, F16>,
#endif
std::tuple< F16, F16, F16, F16>
>;
using KernelTypes_KM_KN = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
#if defined(CK_ENABLE_FP8) && defined(CK_USE_WMMA_FP8)
std::tuple< F8, F16, F16, F16>,
std::tuple< F16, F8, F16, F16>,
#endif
std::tuple< F16, F16, F16, F16>
>;
// clang-format on
TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_KN, KernelTypes_MK_KN);
TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_NK, KernelTypes_MK_NK);
TYPED_TEST_SUITE(TestGemmUniversal_FP16_KM_NK, KernelTypes_KM_NK);
TYPED_TEST_SUITE(TestGemmUniversal_FP16_KM_KN, KernelTypes_KM_KN);
#include "test_gemm_universal_ut_cases_fp16.inc"

View File

@@ -7,7 +7,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "test_gemm_universal_util.hpp"
#if CK_USE_WMMA_FP8
#if defined(CK_USE_WMMA_FP8)
using F8 = ck::f8_t;
using BF16 = ck::bhalf_t;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
@@ -55,7 +55,7 @@ class TestGemmUniversal_FP16_KM_NK
// clang-format off
using KernelTypes_MK_KN = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
std::tuple< F16, F8, F16, F16>,
std::tuple< F8, F16, F16, F16>,
@@ -63,9 +63,10 @@ using KernelTypes_MK_KN = ::testing::Types<
#endif
std::tuple< F16, F16, F16, F16>
>;
using KernelTypes_MK_NK = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
std::tuple< F16, F8, F16, F16>,
std::tuple< F8, F16, F16, F16>,
@@ -74,9 +75,20 @@ using KernelTypes_MK_NK = ::testing::Types<
std::tuple< F16, F16, F16, F16>
>;
using KernelTypes_KM_NK = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
std::tuple< F16, F16, F16, F16>
>;
using KernelTypes_KM_KN = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
std::tuple< F16, F16, F16, F16>
>;
// clang-format on
TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_KN, KernelTypes_MK_KN);
TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_NK, KernelTypes_MK_NK);
TYPED_TEST_SUITE(TestGemmUniversal_FP16_KM_NK, KernelTypes_KM_NK);
TYPED_TEST_SUITE(TestGemmUniversal_FP16_KM_KN, KernelTypes_KM_KN);
#include "test_gemm_universal_ut_cases_fp16.inc"