mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK][Examples] Extending support for rdna3/4 part 3:
-example_gemm_xdl_int8 -example_gemm_xdl_fp8 -example_gemm_xdl_fp8_bf8 -example_gemm_xdl_fp16_fp8 -example_gemm_add_add_fastgelu_xdl_int8 -example_grouped_gemm_xdl_int8 -example_grouped_conv_bwd_weight_xdl_bf16 -example_cgemm_xdl_fp32 -example_cgemm_xdl_int8 fixing cmdlines for: -example_22_cgemm -example_24_batched_gemm -example_batched_gemm_xdl_fp16int4_b_scale_v3 Signed-off-by: Michal Kulikowski <Michal.Kulikowski@amd.com>
This commit is contained in:
committed by
Michał Kulikowski
parent
7259b9c4db
commit
2444c44895
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
@@ -30,7 +30,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| |
|
||||
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | |
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopSched, PipelineVer, ComputeType>;
|
||||
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopSched, PipelineVer, ComputeType>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
@@ -31,7 +31,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| TypeA| TypeB|
|
||||
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | |
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>;
|
||||
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>;
|
||||
// this instance has been tested working on gfx950
|
||||
// < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 128, 32, 32, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>;
|
||||
// clang-format on
|
||||
@@ -55,4 +55,12 @@ using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALa
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
return !run_gemm_example(argc, argv);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
@@ -31,7 +31,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| TypeA| TypeB|
|
||||
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | |
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>;
|
||||
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
@@ -57,4 +57,12 @@ using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALa
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
return !run_gemm_example(argc, argv);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
@@ -27,7 +27,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 16>;
|
||||
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
@@ -32,7 +32,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 16>;
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -51,7 +51,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>;
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_grouped_gemm_example.inc"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
@@ -44,10 +44,10 @@ using DeviceConvBwdWeightInstance =
|
||||
128, // NPerBlock
|
||||
4, // K0PerBlock
|
||||
8, // K1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
2, // NXdlPerWave
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
|
||||
@@ -80,6 +80,11 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
ExecutionConfig config;
|
||||
ck::utils::conv::ConvParam conv_param = DefaultConvParam;
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -48,10 +48,10 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_
|
||||
16, // index_t KPerBlock
|
||||
4, // index_t AK1
|
||||
4, // index_t BK1
|
||||
32, // index_t MPerXDL
|
||||
32, // index_t NPerXDL
|
||||
4, // index_t MXdlPerWave
|
||||
2, // index_t NXdlPerWave
|
||||
16, // index_t MPerXDL
|
||||
16, // index_t NPerXDL
|
||||
8, // index_t MXdlPerWave
|
||||
4, // index_t NXdlPerWave
|
||||
S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder
|
||||
@@ -69,11 +69,16 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_
|
||||
1, // index_t CShuffleMXdlPerWavePerShuffle
|
||||
1, // index_t CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 16, 1, 16>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
4>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
2>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
@@ -87,25 +92,25 @@ int main(int argc, char* argv[])
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
|
||||
if(argc == 4)
|
||||
if(argc == 1)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 10)
|
||||
else if(argc == 4 || argc == 10)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
if(argc == 10)
|
||||
{
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
StrideA = std::stoi(argv[7]);
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideC = std::stoi(argv[9]);
|
||||
StrideA = std::stoi(argv[7]);
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideC = std::stoi(argv[9]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -114,7 +119,7 @@ int main(int argc, char* argv[])
|
||||
<< "arg3: run kernel # of times (>1)\n"
|
||||
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"
|
||||
<< std::endl;
|
||||
exit(0);
|
||||
exit(1);
|
||||
}
|
||||
|
||||
return !run_cgemm_xdl<ADataType,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -48,10 +48,10 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_
|
||||
64, // index_t KPerBlock
|
||||
16, // index_t AK1
|
||||
16, // index_t BK1
|
||||
32, // index_t MPerXDL
|
||||
32, // index_t NPerXDL
|
||||
4, // index_t MXdlPerWave
|
||||
2, // index_t NXdlPerWave
|
||||
16, // index_t MPerXDL
|
||||
16, // index_t NPerXDL
|
||||
8, // index_t MXdlPerWave
|
||||
4, // index_t NXdlPerWave
|
||||
S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder
|
||||
@@ -68,8 +68,8 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_
|
||||
1, // index_t BBlockLdsExtraN
|
||||
1, // index_t CShuffleMXdlPerWavePerShuffle
|
||||
1, // index_t CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
S<1, 32, 1, 8>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
4>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
@@ -87,25 +87,25 @@ int main(int argc, char* argv[])
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
|
||||
if(argc == 4)
|
||||
if(argc == 1)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 10)
|
||||
else if(argc == 4 || argc == 10)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
if(argc == 10)
|
||||
{
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
StrideA = std::stoi(argv[7]);
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideC = std::stoi(argv[9]);
|
||||
StrideA = std::stoi(argv[7]);
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideC = std::stoi(argv[9]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -114,7 +114,7 @@ int main(int argc, char* argv[])
|
||||
<< "arg3: run kernel # of times (>1)\n"
|
||||
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"
|
||||
<< std::endl;
|
||||
exit(0);
|
||||
exit(1);
|
||||
}
|
||||
|
||||
return !run_cgemm_xdl<ADataType,
|
||||
|
||||
@@ -57,4 +57,12 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
|
||||
|
||||
#include "run_batched_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
return run_batched_gemm_example(argc, argv);
|
||||
}
|
||||
|
||||
@@ -218,35 +218,37 @@ bool run_batched_gemm_example(int argc, char* argv[])
|
||||
|
||||
problem_size.batch_count = 2;
|
||||
|
||||
if(argc == 4)
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4 || argc == 8)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 8)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
problem_size.M = std::stoi(argv[4]);
|
||||
problem_size.N = std::stoi(argv[5]);
|
||||
problem_size.K = std::stoi(argv[6]);
|
||||
problem_size.batch_count = std::stoi(argv[7]);
|
||||
if(argc == 8)
|
||||
{
|
||||
problem_size.M = std::stoi(argv[4]);
|
||||
problem_size.N = std::stoi(argv[5]);
|
||||
problem_size.K = std::stoi(argv[6]);
|
||||
problem_size.batch_count = std::stoi(argv[7]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("optinal\n");
|
||||
printf("arg4-7: M = %d N = %d K = %d Batch = %d\n",
|
||||
problem_size.M,
|
||||
problem_size.N,
|
||||
problem_size.K,
|
||||
problem_size.batch_count);
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("optional\n");
|
||||
printf("arg4-7: M, N, K, Batch\n");
|
||||
exit(1);
|
||||
}
|
||||
printf("M = %d N = %d K = %d Batch = %d\n",
|
||||
problem_size.M,
|
||||
problem_size.N,
|
||||
problem_size.K,
|
||||
problem_size.batch_count);
|
||||
|
||||
problem_size.stride_A = problem_size.K;
|
||||
problem_size.stride_B = problem_size.K;
|
||||
|
||||
@@ -523,6 +523,11 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
|
||||
bool run_batched_gemm_fp16_int4_b_scale_example(int argc, char* argv[])
|
||||
{
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
ProblemSize problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
@@ -535,30 +540,30 @@ bool run_batched_gemm_fp16_int4_b_scale_example(int argc, char* argv[])
|
||||
|
||||
problem_size.batch_count = 2;
|
||||
|
||||
if(argc == 4)
|
||||
if(argc == 1)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
// use default case
|
||||
}
|
||||
else if(argc >= 7)
|
||||
else if(argc == 4 || argc >= 7)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
|
||||
problem_size.M = std::stoi(argv[4]);
|
||||
problem_size.N = std::stoi(argv[5]);
|
||||
problem_size.K = std::stoi(argv[6]);
|
||||
|
||||
if(argc >= 8)
|
||||
if(argc >= 7)
|
||||
{
|
||||
problem_size.batch_count = std::stoi(argv[7]);
|
||||
}
|
||||
problem_size.M = std::stoi(argv[4]);
|
||||
problem_size.N = std::stoi(argv[5]);
|
||||
problem_size.K = std::stoi(argv[6]);
|
||||
|
||||
if(argc >= 9)
|
||||
{
|
||||
problem_size.KBatch = std::stoi(argv[8]);
|
||||
if(argc >= 8)
|
||||
{
|
||||
problem_size.batch_count = std::stoi(argv[7]);
|
||||
}
|
||||
|
||||
if(argc >= 9)
|
||||
{
|
||||
problem_size.KBatch = std::stoi(argv[8]);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -566,6 +571,9 @@ bool run_batched_gemm_fp16_int4_b_scale_example(int argc, char* argv[])
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4-6: problem size (M, N, K)\n");
|
||||
printf("arg7: batch count\n");
|
||||
printf("arg8: KBatch\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ using DeviceOpInstanceKKNN = ck::tensor_operation::device::
|
||||
//############################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Spacialization| Spacialization| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceSplitKContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 4>;
|
||||
DeviceSplitKContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>;
|
||||
// clang-format on
|
||||
|
||||
using DeviceOpInstance = DeviceOpInstanceKKNN;
|
||||
|
||||
Reference in New Issue
Block a user