Extend XDL kernel to Support RDNA3/4 - Part 4 (#2724)

* Fix example

* fix build error

* update pk_i4 & moe test case

* fix all instance build (examples)

* fix batched_gemm_gemm (example)

* disable example_gemm_bias_softmax_gemm_permute on gfx11

* remove unnecessary disable gfx11

* update tests

* update tests2
This commit is contained in:
linqunAMD
2025-09-12 23:17:07 +08:00
committed by GitHub
parent bca99a499d
commit 321627aec5
123 changed files with 848 additions and 574 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <stdexcept>
#include <vector>
@@ -39,7 +39,7 @@ class TestGGemmSplitKInterface_MKNKMN : public ::testing::Test
BBlockTransferSrcScalarPerVector,
CDEBlockTransferScalarPerVector_NPerBlock>;
-using DefaultGGemmInstance = GGemmInstance<GemmDefault, 32, 8, 4, 8, 8>;
+using DefaultGGemmInstance = GGemmInstance<GemmDefault, 64, 16, 4, 8, 4>;
};
TEST_F(TestGGemmSplitKInterface_MKNKMN, TileSize)
@@ -67,7 +67,7 @@ TEST_F(TestGGemmSplitKInterface_MKNKMN, VectorLoadWidth)
{
static constexpr auto GemmMNKPadding =
ck::tensor_operation::device::GemmSpecialization::MNKPadding;
-using PaddedGGemmInstance = GGemmInstance<GemmMNKPadding, 32, 8, 4, 8, 8>;
+using PaddedGGemmInstance = GGemmInstance<GemmMNKPadding, 64, 16, 4, 8, 4>;
std::vector<int> Ms{128, 256, 256, 512};
constexpr int N = 256;
@@ -111,14 +111,17 @@ TEST_F(TestGGemmSplitKInterface_MKNKMN, KLoops)
EXPECT_FALSE(
DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch));
Ks = std::vector<int>{256, 512, 384, 768};
EXPECT_TRUE(
DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch));
if(!ck::is_gfx11_supported())
{
Ks = std::vector<int>{256, 512, 768, 1536};
EXPECT_TRUE(
DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch));
// Not all gemms have same value for main_k0_block_loop!
Ks = std::vector<int>{256, 512, 512, 512};
EXPECT_THROW(DefaultGGemmInstance{}.Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch),
std::runtime_error);
// Not all gemms have same value for main_k0_block_loop!
Ks = std::vector<int>{256, 512, 512, 512};
EXPECT_THROW(DefaultGGemmInstance{}.Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch),
std::runtime_error);
}
}
class TestGGemmSplitKInterface_KMKNNM : public ::testing::Test
@@ -150,7 +153,7 @@ class TestGGemmSplitKInterface_KMKNNM : public ::testing::Test
BBlockTransferSrcScalarPerVector,
CDEBlockTransferScalarPerVector_NPerBlock>;
-using DefaultGGemmInstance = GGemmInstance<GemmDefault, 32, 8, 4, 8, 4>;
+using DefaultGGemmInstance = GGemmInstance<GemmDefault, 64, 16, 4, 8, 4>;
};
TEST_F(TestGGemmSplitKInterface_KMKNNM, TileSize)
@@ -178,7 +181,7 @@ TEST_F(TestGGemmSplitKInterface_KMKNNM, VectorLoadWidth)
{
static constexpr auto GemmMNKPadding =
ck::tensor_operation::device::GemmSpecialization::MNKPadding;
-using PaddedGGemmInstance = GGemmInstance<GemmMNKPadding, 32, 8, 2, 8, 4>;
+using PaddedGGemmInstance = GGemmInstance<GemmMNKPadding, 64, 16, 2, 8, 4>;
std::vector<int> Ms{128, 256, 256, 512};
constexpr int N = 256;