Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-05-05 14:11:29 +00:00
Extend XDL kernel to Support RDNA3/4 - Part 4 (#2724)
* Fix example
* Fix build error
* Update pk_i4 & MoE test cases
* Fix all instance builds (examples)
* Fix batched_gemm_gemm (example)
* Disable example_gemm_bias_softmax_gemm_permute on gfx11
* Remove unnecessary gfx11 disables
* Update tests
* Update tests (part 2)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
@@ -39,7 +39,7 @@ class TestGGemmSplitKInterface_MKNKMN : public ::testing::Test
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock>;
|
||||
|
||||
using DefaultGGemmInstance = GGemmInstance<GemmDefault, 32, 8, 4, 8, 8>;
|
||||
using DefaultGGemmInstance = GGemmInstance<GemmDefault, 64, 16, 4, 8, 4>;
|
||||
};
|
||||
|
||||
TEST_F(TestGGemmSplitKInterface_MKNKMN, TileSize)
|
||||
@@ -67,7 +67,7 @@ TEST_F(TestGGemmSplitKInterface_MKNKMN, VectorLoadWidth)
|
||||
{
|
||||
static constexpr auto GemmMNKPadding =
|
||||
ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
using PaddedGGemmInstance = GGemmInstance<GemmMNKPadding, 32, 8, 4, 8, 8>;
|
||||
using PaddedGGemmInstance = GGemmInstance<GemmMNKPadding, 64, 16, 4, 8, 4>;
|
||||
|
||||
std::vector<int> Ms{128, 256, 256, 512};
|
||||
constexpr int N = 256;
|
||||
@@ -111,14 +111,17 @@ TEST_F(TestGGemmSplitKInterface_MKNKMN, KLoops)
|
||||
EXPECT_FALSE(
|
||||
DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch));
|
||||
|
||||
Ks = std::vector<int>{256, 512, 384, 768};
|
||||
EXPECT_TRUE(
|
||||
DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch));
|
||||
if(!ck::is_gfx11_supported())
|
||||
{
|
||||
Ks = std::vector<int>{256, 512, 768, 1536};
|
||||
EXPECT_TRUE(
|
||||
DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch));
|
||||
|
||||
// Not all gemms have same value for main_k0_block_loop!
|
||||
Ks = std::vector<int>{256, 512, 512, 512};
|
||||
EXPECT_THROW(DefaultGGemmInstance{}.Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch),
|
||||
std::runtime_error);
|
||||
// Not all gemms have same value for main_k0_block_loop!
|
||||
Ks = std::vector<int>{256, 512, 512, 512};
|
||||
EXPECT_THROW(DefaultGGemmInstance{}.Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch),
|
||||
std::runtime_error);
|
||||
}
|
||||
}
|
||||
|
||||
class TestGGemmSplitKInterface_KMKNNM : public ::testing::Test
|
||||
@@ -150,7 +153,7 @@ class TestGGemmSplitKInterface_KMKNNM : public ::testing::Test
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock>;
|
||||
|
||||
using DefaultGGemmInstance = GGemmInstance<GemmDefault, 32, 8, 4, 8, 4>;
|
||||
using DefaultGGemmInstance = GGemmInstance<GemmDefault, 64, 16, 4, 8, 4>;
|
||||
};
|
||||
|
||||
TEST_F(TestGGemmSplitKInterface_KMKNNM, TileSize)
|
||||
@@ -178,7 +181,7 @@ TEST_F(TestGGemmSplitKInterface_KMKNNM, VectorLoadWidth)
|
||||
{
|
||||
static constexpr auto GemmMNKPadding =
|
||||
ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
using PaddedGGemmInstance = GGemmInstance<GemmMNKPadding, 32, 8, 2, 8, 4>;
|
||||
using PaddedGGemmInstance = GGemmInstance<GemmMNKPadding, 64, 16, 2, 8, 4>;
|
||||
|
||||
std::vector<int> Ms{128, 256, 256, 512};
|
||||
constexpr int N = 256;
|
||||
|
||||
Reference in New Issue
Block a user