Extend XDL kernel to Support RDNA3/4 - Part 4 (#2724)

* Fix example

* fix build error

* update pk_i4 & moe test case

* fix all instance build (examples)

* fix batched_gemm_gemm (example)

* disable example_gemm_bias_softmax_gemm_permute on gfx11

* remove unnecessary disable gfx11

* update tests

* update tests2

[ROCm/composable_kernel commit: 321627aec5]
This commit is contained in:
linqunAMD
2025-09-12 23:17:07 +08:00
committed by GitHub
parent adc66b9b0e
commit 07def6b13d
123 changed files with 848 additions and 574 deletions

View File

@@ -60,11 +60,11 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
16, // MPerXDL
16, // NPerXDL
2, // MXdlPerWave
8, // NXdlPerWave
8, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
@@ -89,7 +89,7 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
4>; // CShuffleBlockTransferScalarPerVector_NPerBlock
bool IsSupported(int M, int N, int K, int O)
{

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
@@ -133,11 +133,11 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
16, // MPerXDL
16, // NPerXDL
2, // MXdlPerWave
8, // NXdlPerWave
8, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
@@ -162,7 +162,7 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
4, // CShuffleBlockTransferScalarPerVector_NPerBlock
false>;
bool IsSupported(int M, int N, int K, int O)

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
@@ -293,11 +293,11 @@ struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
16, // MPerXDL
16, // NPerXDL
2, // MXdlPerWave
8, // NXdlPerWave
8, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
@@ -322,7 +322,7 @@ struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
4, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpecialization::MaskOutUpperTriangle>; // MaskOutUpperTriangle
bool IsSupported(int M, int N, int K, int O)

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
@@ -144,11 +144,11 @@ struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
16, // MPerXDL
16, // NPerXDL
2, // MXdlPerWave
8, // NXdlPerWave
8, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
@@ -173,7 +173,7 @@ struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
4, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpecialization::MaskOutUpperTriangle>; // MaskOutUpperTriangle
bool IsSupported(int M, int N, int K, int O)

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <stdexcept>
#include <vector>
@@ -181,3 +181,14 @@ TEST(TestContractionSupportedArgs, DEMemoryAccess)
wrapper.isSupported(Dims, Dims, Dims, Dims, Strides, Strides, Strides, InvalidStrides));
EXPECT_TRUE(wrapper.isSupported(Dims, Dims, Dims, Dims, Strides, Strides, Strides, Strides));
}
int main(int argc, char** argv)
{
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
std::cout << "FP32/64 are not supported on gfx11 and gfx12." << std::endl;
return 0;
}
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@@ -1,7 +1,9 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/ck.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "profiler/profile_gemm_add_relu_add_layernorm_impl.hpp"
using Row = ck::tensor_layout::gemm::RowMajor;
@@ -75,3 +77,13 @@ using KernelTypes = ::testing::Types<
TYPED_TEST_SUITE(TestGemmAddReluAddLayernorm, KernelTypes);
TYPED_TEST(TestGemmAddReluAddLayernorm, Test_FP16) { this->Run(); }
int main(int argc, char** argv)
{
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
std::cout << "No available instance for gfx11 & gfx12." << std::endl;
return 0;
}
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
@@ -47,7 +47,7 @@ class TestGroupedConvndBwdData : public ::testing::Test
// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< NDimSpatial, OutLayout, WeiLayout, ck::Tuple<>, InLayout, DataType, DataType, AccDataType, DataType, ck::Tuple<>, DataType, Pass, Pass, Pass, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>;
< NDimSpatial, OutLayout, WeiLayout, ck::Tuple<>, InLayout, DataType, DataType, AccDataType, DataType, ck::Tuple<>, DataType, Pass, Pass, Pass, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 16, 16, 4, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>;
// clang-format on
ck::utils::conv::ConvParam conv_param;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
@@ -48,7 +48,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
//##########| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
//##########| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
< NDimSpatial, InLayout, WeiLayout,OutLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>;
< NDimSpatial, InLayout, WeiLayout,OutLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 16, 16, 2, 4, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 8>, 4>;
// clang-format on
ck::utils::conv::ConvParam conv_param;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
@@ -70,10 +70,10 @@ class TestGroupedConvndFwdMultiABInterfaceBase : public ::testing::Test
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave
8, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
@@ -91,7 +91,7 @@ class TestGroupedConvndFwdMultiABInterfaceBase : public ::testing::Test
1,
1,
S<1, 32, 1, 8>,
8>;
4>;
const ck::utils::conv::ConvParam conv_param{
3, 1, 16, 16, 8, {3, 3, 3}, {17, 17, 17}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}};
@@ -172,8 +172,8 @@ class TestGroupedConvndFwdMultiABInterfaceBase : public ::testing::Test
class TestGroupedConvndFwdMultiAInterface
: public TestGroupedConvndFwdMultiABInterfaceBase<float,
ck::Tuple<float, float>,
float,
ck::Tuple<ck::half_t, ck::half_t>,
ck::half_t,
ScaleAdd,
PassThrough>
{
@@ -181,8 +181,8 @@ class TestGroupedConvndFwdMultiAInterface
class TestGroupedConvndFwdMultiBInterface
: public TestGroupedConvndFwdMultiABInterfaceBase<float,
float,
ck::Tuple<float, float>,
ck::half_t,
ck::Tuple<ck::half_t, ck::half_t>,
PassThrough,
ScaleAdd>
{
@@ -190,15 +190,18 @@ class TestGroupedConvndFwdMultiBInterface
class TestGroupedConvndFwdMultiABInterface
: public TestGroupedConvndFwdMultiABInterfaceBase<float,
ck::Tuple<float, float>,
ck::Tuple<float, float>,
ck::Tuple<ck::half_t, ck::half_t>,
ck::Tuple<ck::half_t, ck::half_t>,
ScaleAdd,
ScaleAdd>
{
};
class TestGroupedConvndFwdInterface
: public TestGroupedConvndFwdMultiABInterfaceBase<float, float, float, PassThrough, PassThrough>
class TestGroupedConvndFwdInterface : public TestGroupedConvndFwdMultiABInterfaceBase<float,
ck::half_t,
ck::half_t,
PassThrough,
PassThrough>
{
};

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <stdexcept>
#include <vector>
@@ -39,7 +39,7 @@ class TestGGemmSplitKInterface_MKNKMN : public ::testing::Test
BBlockTransferSrcScalarPerVector,
CDEBlockTransferScalarPerVector_NPerBlock>;
using DefaultGGemmInstance = GGemmInstance<GemmDefault, 32, 8, 4, 8, 8>;
using DefaultGGemmInstance = GGemmInstance<GemmDefault, 64, 16, 4, 8, 4>;
};
TEST_F(TestGGemmSplitKInterface_MKNKMN, TileSize)
@@ -67,7 +67,7 @@ TEST_F(TestGGemmSplitKInterface_MKNKMN, VectorLoadWidth)
{
static constexpr auto GemmMNKPadding =
ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using PaddedGGemmInstance = GGemmInstance<GemmMNKPadding, 32, 8, 4, 8, 8>;
using PaddedGGemmInstance = GGemmInstance<GemmMNKPadding, 64, 16, 4, 8, 4>;
std::vector<int> Ms{128, 256, 256, 512};
constexpr int N = 256;
@@ -111,14 +111,17 @@ TEST_F(TestGGemmSplitKInterface_MKNKMN, KLoops)
EXPECT_FALSE(
DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch));
Ks = std::vector<int>{256, 512, 384, 768};
EXPECT_TRUE(
DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch));
if(!ck::is_gfx11_supported())
{
Ks = std::vector<int>{256, 512, 768, 1536};
EXPECT_TRUE(
DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch));
// Not all gemms have same value for main_k0_block_loop!
Ks = std::vector<int>{256, 512, 512, 512};
EXPECT_THROW(DefaultGGemmInstance{}.Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch),
std::runtime_error);
// Not all gemms have same value for main_k0_block_loop!
Ks = std::vector<int>{256, 512, 512, 512};
EXPECT_THROW(DefaultGGemmInstance{}.Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch),
std::runtime_error);
}
}
class TestGGemmSplitKInterface_KMKNNM : public ::testing::Test
@@ -150,7 +153,7 @@ class TestGGemmSplitKInterface_KMKNNM : public ::testing::Test
BBlockTransferSrcScalarPerVector,
CDEBlockTransferScalarPerVector_NPerBlock>;
using DefaultGGemmInstance = GGemmInstance<GemmDefault, 32, 8, 4, 8, 4>;
using DefaultGGemmInstance = GGemmInstance<GemmDefault, 64, 16, 4, 8, 4>;
};
TEST_F(TestGGemmSplitKInterface_KMKNNM, TileSize)
@@ -178,7 +181,7 @@ TEST_F(TestGGemmSplitKInterface_KMKNNM, VectorLoadWidth)
{
static constexpr auto GemmMNKPadding =
ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using PaddedGGemmInstance = GGemmInstance<GemmMNKPadding, 32, 8, 2, 8, 4>;
using PaddedGGemmInstance = GGemmInstance<GemmMNKPadding, 64, 16, 2, 8, 4>;
std::vector<int> Ms{128, 256, 256, 512};
constexpr int N = 256;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -210,10 +210,10 @@ struct DeviceGroupedGemmSplitkInstanceWrapper
KPerBlock,
K1,
K1,
32,
32,
16,
16,
8,
4,
2,
S<1, 4, 16, 1>,
ABlockTransferThreadClusterArrageOrder,
ABlockTransferSrcAccessOrder,
@@ -303,12 +303,19 @@ struct DeviceGroupedGemmSplitkInstanceWrapper
{
ggemm_instance.SetKBatchSize(&argument, kbatch);
}
EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument));
auto invoker = ggemm_instance.MakeInvoker();
DeviceMem dev_gemm_kargs(ggemm_instance.GetDeviceKernelArgSize(&argument));
ggemm_instance.SetDeviceKernelArgs(&argument, dev_gemm_kargs.GetDeviceBuffer());
return invoker.Run(argument, StreamConfig{nullptr, false});
if(kbatch > 1 && ck::is_gfx11_supported())
{
EXPECT_FALSE(ggemm_instance.IsSupportedArgument(argument));
return 0;
}
else
{
EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument));
auto invoker = ggemm_instance.MakeInvoker();
DeviceMem dev_gemm_kargs(ggemm_instance.GetDeviceKernelArgSize(&argument));
ggemm_instance.SetDeviceKernelArgs(&argument, dev_gemm_kargs.GetDeviceBuffer());
return invoker.Run(argument, StreamConfig{nullptr, false});
}
}
};

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <numeric>
#include <cstdlib>
@@ -81,6 +81,7 @@ __global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a,
const BlockShape tile_shape,
const ThreadLayout thread_layout)
{
#if defined(__gfx9__)
constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape);
constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape);
constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape);
@@ -256,6 +257,16 @@ __global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a,
a_lds_tensor, b_lds_tensor, c_vgpr_reg);
ck::wrapper::copy(c_vgpr_reg, c_global_local_partition);
#else
ck::ignore = p_a;
ck::ignore = p_b;
ck::ignore = p_c;
ck::ignore = M;
ck::ignore = N;
ck::ignore = K;
ck::ignore = tile_shape;
ck::ignore = thread_layout;
#endif
}
template <typename DataType,
@@ -374,3 +385,14 @@ TEST(TestGemm, Float_2x4_4x2_XdlPerWave)
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1, 4, false>(
512, 512, 128, tile_shape, thread_layout);
}
int main(int argc, char** argv)
{
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
std::cout << "This test support gfx9 only" << std::endl;
return 0;
}
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}