Extend XDL kernel to Support RDNA3/4 - Part 4 (#2724)

* Fix example

* fix build error

* update pk_i4 & moe test case

* fix all instance build (examples)

* fix batched_gemm_gemm (example)

* disable example_gemm_bias_softmax_gemm_permute on gfx11

* remove unnecessary disable gfx11

* update tests

* update tests2
This commit is contained in:
linqunAMD
2025-09-12 23:17:07 +08:00
committed by GitHub
parent bca99a499d
commit 321627aec5
123 changed files with 848 additions and 574 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
@@ -70,10 +70,10 @@ using DeviceGroupedConvNDBwdDataInstance =
32, // KPerBlock
8, // AK1
2, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave
8, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
@@ -91,7 +91,7 @@ using DeviceGroupedConvNDBwdDataInstance =
1,
1,
S<1, 32, 1, 8>,
8>;
4>;
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDBwdDataInstance<OutElementOp>;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
@@ -63,10 +63,10 @@ using DeviceGroupedConvNDBwdWeightInstance =
128, // NPerBlock
4, // K0PerBlock
8, // K1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
2, // NXdlPerWave
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave
4, // NXdlPerWave
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
@@ -84,7 +84,7 @@ using DeviceGroupedConvNDBwdWeightInstance =
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl
64 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDBwdWeightInstance<WeiElementOp>;
namespace {
@@ -257,4 +257,12 @@ bool run_grouped_conv(bool do_verification,
#include "../run_convnd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
int main(int argc, char* argv[])
{
// temp disable test on gfx11
if(ck::is_gfx11_supported())
{
return 0;
}
return !run_convnd_example(argc, argv);
}

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
@@ -71,10 +71,10 @@ using DeviceGroupedConvNDFwdInstance =
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave
8, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
@@ -92,7 +92,7 @@ using DeviceGroupedConvNDFwdInstance =
1,
1,
S<1, 32, 1, 8>,
8>;
4>;
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_convinvscale_common.hpp"
@@ -58,10 +58,10 @@ using DeviceGroupedConvNDFwdInstance =
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave
8, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
@@ -79,7 +79,7 @@ using DeviceGroupedConvNDFwdInstance =
1,
1,
S<1, 32, 1, 8>,
8,
4,
AComputeDataType,
BComputeDataType>;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstdlib>
@@ -74,10 +74,10 @@ using DeviceGroupedConvNDFwdInstance =
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave
8, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
@@ -95,7 +95,7 @@ using DeviceGroupedConvNDFwdInstance =
1,
1,
S<1, 32, 1, 8>,
8>;
4>;
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
@@ -71,10 +71,10 @@ using DeviceGroupedConvNDFwdInstance =
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave
8, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
@@ -92,7 +92,7 @@ using DeviceGroupedConvNDFwdInstance =
1,
1,
S<1, 32, 1, 8>,
8>;
4>;
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_convscale_common.hpp"
@@ -58,10 +58,10 @@ using DeviceGroupedConvNDFwdInstance =
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave
8, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
@@ -79,7 +79,7 @@ using DeviceGroupedConvNDFwdInstance =
1,
1,
S<1, 32, 1, 8>,
8,
4,
AComputeDataType,
BComputeDataType>;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_convscale_common.hpp"
@@ -58,10 +58,10 @@ using DeviceGroupedConvNDFwdInstance =
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave
8, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
@@ -79,7 +79,7 @@ using DeviceGroupedConvNDFwdInstance =
1,
1,
S<1, 32, 1, 8>,
8,
4,
AComputeDataType,
BComputeDataType>;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_convscale_common.hpp"
@@ -58,10 +58,10 @@ using DeviceGroupedConvNDFwdInstance =
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave
8, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
@@ -79,7 +79,7 @@ using DeviceGroupedConvNDFwdInstance =
1,
1,
S<1, 32, 1, 8>,
8,
4,
AComputeDataType,
BComputeDataType>;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_convscale_common.hpp"
@@ -58,10 +58,10 @@ using DeviceGroupedConvNDFwdInstance =
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave
8, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
@@ -79,7 +79,7 @@ using DeviceGroupedConvNDFwdInstance =
1,
1,
S<1, 32, 1, 8>,
8,
4,
AComputeDataType,
BComputeDataType>;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/utility/tuple.hpp"
#include "convnd_fwd_convscale_add_common.hpp"
@@ -57,10 +57,10 @@ using DeviceGroupedConvNDFwdInstance =
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave
8, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
@@ -78,7 +78,7 @@ using DeviceGroupedConvNDFwdInstance =
1,
1,
S<1, 32, 1, 8>,
8,
4,
AComputeDataType,
BComputeDataType>;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_convscale_reduce_common.hpp"
@@ -52,10 +52,10 @@ using DeviceGroupedConvNDFwdInstance =
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave
8, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
@@ -73,7 +73,7 @@ using DeviceGroupedConvNDFwdInstance =
1,
1,
S<1, 32, 1, 8>,
8,
4,
AComputeDataType,
BComputeDataType>;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_convscale_reduce_common.hpp"
@@ -52,10 +52,10 @@ using DeviceGroupedConvNDFwdInstance =
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave
8, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
@@ -73,7 +73,7 @@ using DeviceGroupedConvNDFwdInstance =
1,
1,
S<1, 32, 1, 8>,
8,
4,
AComputeDataType,
BComputeDataType>;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_convscale_relu_common.hpp"
@@ -56,10 +56,10 @@ using DeviceGroupedConvNDFwdInstance =
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave
8, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
@@ -77,7 +77,7 @@ using DeviceGroupedConvNDFwdInstance =
1,
1,
S<1, 32, 1, 8>,
8,
4,
AComputeDataType,
BComputeDataType>;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -71,10 +71,10 @@ using DeviceGroupedConvNDActivInstance =
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave
8, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
@@ -92,7 +92,7 @@ using DeviceGroupedConvNDActivInstance =
1,
1,
S<1, 32, 1, 8>,
8>;
4>;
template <ck::index_t NDimSpatial,
typename InDataType,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_multi_ab_common.hpp"
@@ -23,4 +23,14 @@ using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance<D
#include "../run_convnd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
int main(int argc, char* argv[])
{
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
std::cout << "FP32 are not supported on gfx11 and gfx12" << std::endl;
return 0;
}
return !run_convnd_example(argc, argv);
}

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
@@ -68,10 +68,10 @@ using DeviceGroupedConvNDMultiABFwdInstance =
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave
8, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
@@ -89,7 +89,7 @@ using DeviceGroupedConvNDMultiABFwdInstance =
1,
1,
S<1, 32, 1, 8>,
8>;
4>;
namespace {
template <ck::index_t NDimSpatial,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -71,10 +71,10 @@ using DeviceGroupedConvNDFwdInstance =
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave
8, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
@@ -92,7 +92,7 @@ using DeviceGroupedConvNDFwdInstance =
1,
1,
S<1, 32, 1, 8>,
8>;
4>;
template <ck::index_t NDimSpatial,
typename InDataType,