From f85778eab4ea8d3e81cca51e9f669dba8baf2982 Mon Sep 17 00:00:00 2001 From: Michal Kulikowski Date: Wed, 1 Oct 2025 16:04:25 +0200 Subject: [PATCH] [CK][Examples] Extending support for rdna3/4 part 2: -example_batched_gemm_xdl_int8 -example_batched_gemm_xdl_fp8_rowwise_v3 -example_batched_gemm_xdl_fp32 -example_batched_gemm_xdl_bf16 -example_batched_gemm_xdl_bf16_v3 -example_batched_gemm_xdl_fp16 -example_splitk_gemm_bias_e_permute_xdl_fp32 *fixing return value to return 0 as success in above examples. Fixing cmdline parameters in: -example_sparse_embedding3_forward_layernorm -example_elementwise_binary_4D_fp16 -elementwise_scale_permute_amax_2D_fp16_fp8 Signed-off-by: Michal Kulikowski [ROCm/composable_kernel commit: 7259b9c4dba0a955633b768527de214bc2cd03ac] --- example/01_gemm/gemm_xdl_bf16.cpp | 4 +-- .../24_batched_gemm/batched_gemm_xdl_bf16.cpp | 6 ++-- .../batched_gemm_xdl_bf16_v3.cpp | 12 +++---- .../24_batched_gemm/batched_gemm_xdl_fp16.cpp | 6 ++-- .../batched_gemm_xdl_fp16int4_b_scale_v3.cpp | 2 ++ .../24_batched_gemm/batched_gemm_xdl_fp32.cpp | 6 ++-- .../batched_gemm_xdl_fp8_rowwise_v3.cpp | 12 +++---- .../24_batched_gemm/batched_gemm_xdl_int4.cpp | 4 ++- .../24_batched_gemm/batched_gemm_xdl_int8.cpp | 6 ++-- .../run_batched_gemm_example.inc | 4 ++- ..._batched_gemm_example_fp16int4_b_scale.inc | 4 +-- .../run_batched_gemm_example_rowwise.inc | 2 +- .../sparse_embedding3_forward_layernorm.cpp | 23 ++++++------ .../splitk_gemm_bias_e_permute_xdl_fp16.cpp | 4 +-- .../elementwise_binary_4D_fp16.cpp | 35 ++++++++----------- ...entwise_scale_permute_amax_2D_fp16_fp8.cpp | 28 ++++++++------- 16 files changed, 85 insertions(+), 73 deletions(-) diff --git a/example/01_gemm/gemm_xdl_bf16.cpp b/example/01_gemm/gemm_xdl_bf16.cpp index 6cfff30dbd..0440e8246b 100644 --- a/example/01_gemm/gemm_xdl_bf16.cpp +++ b/example/01_gemm/gemm_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -27,7 +27,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/example/24_batched_gemm/batched_gemm_xdl_bf16.cpp b/example/24_batched_gemm/batched_gemm_xdl_bf16.cpp index c684c13d0d..d155b2c411 100644 --- a/example/24_batched_gemm/batched_gemm_xdl_bf16.cpp +++ b/example/24_batched_gemm/batched_gemm_xdl_bf16.cpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include #include @@ -51,9 +53,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on #include "run_batched_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } +int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); } diff --git a/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp b/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp index 548500518f..bd4be91103 100644 --- a/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp +++ b/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp @@ -68,10 +68,10 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD 32, // KPerBlock 8, // AK1 8, // BK1 - 32, // MPerXDL - 32, // NPerXDL - 4, // MXdlPerWave - 2, // NXdlPerWave + 16, // MPerXDL + 16, // NPerXDL + 8, // MXdlPerWave + 4, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -89,11 +89,11 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD 1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - S<8>, // CDEShuffleBlockTransferScalarPerVectors + S<4>, // CDEShuffleBlockTransferScalarPerVectors ck::BlockGemmPipelineScheduler::Intrawave, // BlockGemmPipelineScheduler ck::BlockGemmPipelineVersion::v3 // BlockGemmPipelineVersion >; #include "run_batched_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } +int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); } diff --git a/example/24_batched_gemm/batched_gemm_xdl_fp16.cpp b/example/24_batched_gemm/batched_gemm_xdl_fp16.cpp index d1985f9af5..80f6b1d663 100644 --- a/example/24_batched_gemm/batched_gemm_xdl_fp16.cpp +++ b/example/24_batched_gemm/batched_gemm_xdl_fp16.cpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include #include @@ -51,9 +53,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on #include "run_batched_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } +int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); } diff --git a/example/24_batched_gemm/batched_gemm_xdl_fp16int4_b_scale_v3.cpp b/example/24_batched_gemm/batched_gemm_xdl_fp16int4_b_scale_v3.cpp index 42171bcdb7..46e452005d 100644 --- a/example/24_batched_gemm/batched_gemm_xdl_fp16int4_b_scale_v3.cpp +++ b/example/24_batched_gemm/batched_gemm_xdl_fp16int4_b_scale_v3.cpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include #include diff --git a/example/24_batched_gemm/batched_gemm_xdl_fp32.cpp b/example/24_batched_gemm/batched_gemm_xdl_fp32.cpp index a92a04dbe6..8e3bf3813c 100644 --- a/example/24_batched_gemm/batched_gemm_xdl_fp32.cpp +++ b/example/24_batched_gemm/batched_gemm_xdl_fp32.cpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include #include @@ -50,9 +52,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 2>; // clang-format on #include "run_batched_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } +int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); } diff --git a/example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp b/example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp index 84f92eba8e..2a4610bf90 100644 --- a/example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp +++ b/example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp @@ -74,10 +74,10 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD 64, // KPerBlock 16, // AK1 16, // BK1 - 32, // MPerXDL - 32, // NPerXDL - 4, // MXdlPerWave - 2, // NXdlPerWave + 16, // MPerXDL + 16, // NPerXDL + 8, // MXdlPerWave + 4, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -95,7 +95,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD 1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - S<8, 8, 1>, // CDEShuffleBlockTransferScalarPerVectors + S<4, 4, 1>, // CDEShuffleBlockTransferScalarPerVectors ck::BlockGemmPipelineScheduler::Interwave, // BlockGemmPipelineScheduler ck::BlockGemmPipelineVersion::v1, // BlockGemmPipelineVersion F8 // ComputeTypeA @@ -103,4 +103,4 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD #include "run_batched_gemm_example_rowwise.inc" -int main(int argc, char* argv[]) { return !run_batched_gemm_rowwise_example(argc, argv); } +int main(int argc, char* argv[]) { return run_batched_gemm_rowwise_example(argc, argv); } diff --git a/example/24_batched_gemm/batched_gemm_xdl_int4.cpp b/example/24_batched_gemm/batched_gemm_xdl_int4.cpp index 5e82cfe324..4adb79ed29 100644 --- a/example/24_batched_gemm/batched_gemm_xdl_int4.cpp +++ b/example/24_batched_gemm/batched_gemm_xdl_int4.cpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include #include @@ -96,4 +98,4 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD #define BUILD_INT4_EXAMPLE #include "run_batched_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } +int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); } diff --git a/example/24_batched_gemm/batched_gemm_xdl_int8.cpp b/example/24_batched_gemm/batched_gemm_xdl_int8.cpp index ad22227af5..06b8c3c0cd 100644 --- a/example/24_batched_gemm/batched_gemm_xdl_int8.cpp +++ b/example/24_batched_gemm/batched_gemm_xdl_int8.cpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include #include @@ -48,9 +50,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on #include "run_batched_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } +int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); } diff --git a/example/24_batched_gemm/run_batched_gemm_example.inc b/example/24_batched_gemm/run_batched_gemm_example.inc index c93a2051d2..1641265b00 100644 --- a/example/24_batched_gemm/run_batched_gemm_example.inc +++ b/example/24_batched_gemm/run_batched_gemm_example.inc @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #pragma once @@ -243,7 +245,7 @@ bool run_batched_gemm_example(int argc, char* argv[]) problem_size.N, problem_size.K, problem_size.batch_count); - exit(0); + exit(1); } problem_size.stride_A = problem_size.K; diff --git a/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc index ac34ed5b8a..4647f14037 100644 --- a/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc +++ b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc @@ -346,7 +346,7 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co { std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; - return true; + return false; } bool pass = true; @@ -566,7 +566,7 @@ bool run_batched_gemm_fp16_int4_b_scale_example(int argc, char* argv[]) printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n"); - exit(0); + exit(1); } problem_size.stride_A = problem_size.K; diff --git a/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc b/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc index 9939429a08..6ed0b23407 100644 --- a/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc +++ b/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #pragma once diff --git a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp index a1b952259f..ff7acea48d 100644 --- a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp +++ b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -77,17 +77,21 @@ int main(int argc, char* argv[]) { // Use default value } - else if(argc == 4) + else if(argc == 5) { - num_rows = atoi(argv[1]); - dim_mask = strtol(argv[2], nullptr, 0); - index_length = atoi(argv[3]); + time_kernel = std::stoi(argv[1]); + num_rows = std::stoi(argv[2]); + dim_mask = strtol(argv[3], nullptr, 0); + index_length = std::stoi(argv[4]); } else { std::cout << "Usage of " << argv[0] << std::endl; - std::cout << "Arg1-3: num_rows dim_mask index_length" << std::endl; + std::cout << "arg1: time kernel (0=no, 1=yes)" << std::endl; + std::cout << "arg2-4: num_rows dim_mask index_length" << std::endl; + return 1; } + ck::static_for<0, dims.Size(), 1>{}([&](auto I) { if(dim_mask & (1 << I.value)) { @@ -160,11 +164,10 @@ int main(int argc, char* argv[]) << std::endl << std::flush; - bool is_supported = device_instance.IsSupportedArgument(argument_ptr.get()); - - if(!is_supported) + if(!device_instance.IsSupportedArgument(argument_ptr.get())) { - std::cout << "Runtime parameters are not supported" << std::endl; + std::cerr << device_instance.GetTypeString() << " does not support this problem" + << std::endl; return; } diff --git a/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp b/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp index b5e9686260..9273918bfa 100644 --- a/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp +++ b/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -56,7 +56,7 @@ using DeviceOpInstanceKKNN = ck::tensor_operation::device:: //############################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Spacialization| Spacialization| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceSplitKContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>; + DeviceSplitKContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 4>; // clang-format on using DeviceOpInstance = DeviceOpInstanceKKNN; diff --git a/example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp index 8ddd432c11..bd9a8c151b 100644 --- a/example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp +++ b/example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -51,6 +51,8 @@ int main(int argc, char* argv[]) bool do_verification = true; bool time_kernel = true; + std::vector nchw = {16, 128, 32, 64}; + if(argc == 1) { // use default @@ -60,30 +62,21 @@ int main(int argc, char* argv[]) do_verification = std::stoi(argv[1]); time_kernel = std::stoi(argv[2]); } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + nchw[0] = std::stoi(argv[3]); + nchw[1] = std::stoi(argv[4]); + nchw[2] = std::stoi(argv[5]); + nchw[3] = std::stoi(argv[6]); + } else { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: time kernel (0=no, 1=yes)\n"); - exit(0); - } - - std::vector nchw = {16, 128, 32, 64}; - if(argc == 1) - { - // use default case - } - else if(argc == 5) - { - nchw[0] = std::stoi(argv[1]); - nchw[1] = std::stoi(argv[2]); - nchw[2] = std::stoi(argv[3]); - nchw[3] = std::stoi(argv[4]); - } - else - { - std::cerr << "arg1 to 4: N, C, H, W" << std::endl; - - return 1; + printf("arg3-6: N, C, H, W (default 16, 128, 32, 64)\n"); + exit(1); } std::array ab_lengths; diff --git a/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp b/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp index 0619cc7139..1ce797e4dd 100644 --- a/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp +++ b/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -119,6 +119,11 @@ int main(int argc, char* argv[]) bool do_verification = true; bool time_kernel = true; + const float scale = 2.f; + + ck::index_t M = 1024; + ck::index_t K = 1024; + if(argc == 1) { // use default @@ -128,22 +133,19 @@ int main(int argc, char* argv[]) do_verification = std::stoi(argv[1]); time_kernel = std::stoi(argv[2]); } + else if(argc == 5) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + M = std::stoi(argv[3]); + K = std::stoi(argv[4]); + } else { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: time kernel (0=no, 1=yes)\n"); - exit(0); - } - - const float scale = 2.f; - - ck::index_t M = 1024; - ck::index_t K = 1024; - - if(argc == 3) - { - M = std::stoi(argv[1]); - K = std::stoi(argv[2]); + printf("arg3-4: M(default=1024), K(default=1024)\n"); + exit(1); } std::array dims = {M, K};