mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Extend XDL kernel to Support RDNA3/4 - Part 4 (#2724)
* Fix example * fix build error * update pk_i4 & moe test case * fix all instance build (examples) * fix batched_gemm_gemm (example) * disable example_gemm_bias_softmax_gemm_permute on gfx11 * remove unnecessary disable gfx11 * update tests * update tests2
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
@@ -70,10 +70,10 @@ using DeviceGroupedConvNDBwdDataInstance =
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
2, // BK1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
@@ -91,7 +91,7 @@ using DeviceGroupedConvNDBwdDataInstance =
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8>;
|
||||
4>;
|
||||
|
||||
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDBwdDataInstance<OutElementOp>;
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
@@ -63,10 +63,10 @@ using DeviceGroupedConvNDBwdWeightInstance =
|
||||
128, // NPerBlock
|
||||
4, // K0PerBlock
|
||||
8, // K1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
2, // NXdlPerWave
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
|
||||
@@ -84,7 +84,7 @@ using DeviceGroupedConvNDBwdWeightInstance =
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl
|
||||
64 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl
|
||||
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDBwdWeightInstance<WeiElementOp>;
|
||||
|
||||
namespace {
|
||||
@@ -257,4 +257,12 @@ bool run_grouped_conv(bool do_verification,
|
||||
|
||||
#include "../run_convnd_activ_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
// temp disable test on gfx11
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
return !run_convnd_example(argc, argv);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
@@ -71,10 +71,10 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
@@ -92,7 +92,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8>;
|
||||
4>;
|
||||
|
||||
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "convnd_fwd_convinvscale_common.hpp"
|
||||
|
||||
@@ -58,10 +58,10 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
@@ -79,7 +79,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8,
|
||||
4,
|
||||
AComputeDataType,
|
||||
BComputeDataType>;
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
@@ -74,10 +74,10 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
@@ -95,7 +95,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8>;
|
||||
4>;
|
||||
|
||||
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
@@ -71,10 +71,10 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
@@ -92,7 +92,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8>;
|
||||
4>;
|
||||
|
||||
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "convnd_fwd_convscale_common.hpp"
|
||||
|
||||
@@ -58,10 +58,10 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
@@ -79,7 +79,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8,
|
||||
4,
|
||||
AComputeDataType,
|
||||
BComputeDataType>;
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "convnd_fwd_convscale_common.hpp"
|
||||
|
||||
@@ -58,10 +58,10 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
@@ -79,7 +79,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8,
|
||||
4,
|
||||
AComputeDataType,
|
||||
BComputeDataType>;
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "convnd_fwd_convscale_common.hpp"
|
||||
|
||||
@@ -58,10 +58,10 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
@@ -79,7 +79,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8,
|
||||
4,
|
||||
AComputeDataType,
|
||||
BComputeDataType>;
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "convnd_fwd_convscale_common.hpp"
|
||||
|
||||
@@ -58,10 +58,10 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
@@ -79,7 +79,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8,
|
||||
4,
|
||||
AComputeDataType,
|
||||
BComputeDataType>;
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "convnd_fwd_convscale_add_common.hpp"
|
||||
@@ -57,10 +57,10 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
@@ -78,7 +78,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8,
|
||||
4,
|
||||
AComputeDataType,
|
||||
BComputeDataType>;
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "convnd_fwd_convscale_reduce_common.hpp"
|
||||
|
||||
@@ -52,10 +52,10 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
@@ -73,7 +73,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8,
|
||||
4,
|
||||
AComputeDataType,
|
||||
BComputeDataType>;
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "convnd_fwd_convscale_reduce_common.hpp"
|
||||
|
||||
@@ -52,10 +52,10 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
@@ -73,7 +73,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8,
|
||||
4,
|
||||
AComputeDataType,
|
||||
BComputeDataType>;
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "convnd_fwd_convscale_relu_common.hpp"
|
||||
|
||||
@@ -56,10 +56,10 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
@@ -77,7 +77,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8,
|
||||
4,
|
||||
AComputeDataType,
|
||||
BComputeDataType>;
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -71,10 +71,10 @@ using DeviceGroupedConvNDActivInstance =
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
@@ -92,7 +92,7 @@ using DeviceGroupedConvNDActivInstance =
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8>;
|
||||
4>;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "convnd_fwd_activ_multi_ab_common.hpp"
|
||||
|
||||
@@ -23,4 +23,14 @@ using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance<D
|
||||
|
||||
#include "../run_convnd_activ_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
std::cout << "FP32 are not supported on gfx11 and gfx12" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
return !run_convnd_example(argc, argv);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
@@ -68,10 +68,10 @@ using DeviceGroupedConvNDMultiABFwdInstance =
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
@@ -89,7 +89,7 @@ using DeviceGroupedConvNDMultiABFwdInstance =
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8>;
|
||||
4>;
|
||||
|
||||
namespace {
|
||||
template <ck::index_t NDimSpatial,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -71,10 +71,10 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
@@ -92,7 +92,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8>;
|
||||
4>;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
|
||||
Reference in New Issue
Block a user