[CK][Examples] Extending support for rdna3/4 part 2:

-example_batched_gemm_xdl_int8
-example_batched_gemm_xdl_fp8_rowwise_v3
-example_batched_gemm_xdl_fp32
-example_batched_gemm_xdl_bf16
-example_batched_gemm_xdl_bf16_v3
-example_batched_gemm_xdl_fp16
-example_splitk_gemm_bias_e_permute_xdl_fp32
*fixing return value to return 0 as success in above examples.

Fixing cmdline parameters in:
-example_sparse_embedding3_forward_layernorm
-example_elementwise_binary_4D_fp16
-elementwise_scale_permute_amax_2D_fp16_fp8

Signed-off-by: Michal Kulikowski <Michal.Kulikowski@amd.com>
This commit is contained in:
Michal Kulikowski
2025-10-01 16:04:25 +02:00
committed by Michał Kulikowski
parent 1d4db30af9
commit 7259b9c4db
16 changed files with 85 additions and 73 deletions

View File

@@ -1,3 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
@@ -51,9 +53,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
// clang-format on
#include "run_batched_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); }
int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); }

View File

@@ -68,10 +68,10 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXDL
32, // NPerXDL
4, // MXdlPerWave
2, // NXdlPerWave
16, // MPerXDL
16, // NPerXDL
8, // MXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
@@ -89,11 +89,11 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S<8>, // CDEShuffleBlockTransferScalarPerVectors
S<4>, // CDEShuffleBlockTransferScalarPerVectors
ck::BlockGemmPipelineScheduler::Intrawave, // BlockGemmPipelineScheduler
ck::BlockGemmPipelineVersion::v3 // BlockGemmPipelineVersion
>;
#include "run_batched_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); }
int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); }

View File

@@ -1,3 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
@@ -51,9 +53,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
// clang-format on
#include "run_batched_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); }
int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); }

View File

@@ -1,3 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
#include <iostream>

View File

@@ -1,3 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
@@ -50,9 +52,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>;
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 2>;
// clang-format on
#include "run_batched_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); }
int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); }

View File

@@ -74,10 +74,10 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
64, // KPerBlock
16, // AK1
16, // BK1
32, // MPerXDL
32, // NPerXDL
4, // MXdlPerWave
2, // NXdlPerWave
16, // MPerXDL
16, // NPerXDL
8, // MXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
@@ -95,7 +95,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S<8, 8, 1>, // CDEShuffleBlockTransferScalarPerVectors
S<4, 4, 1>, // CDEShuffleBlockTransferScalarPerVectors
ck::BlockGemmPipelineScheduler::Interwave, // BlockGemmPipelineScheduler
ck::BlockGemmPipelineVersion::v1, // BlockGemmPipelineVersion
F8 // ComputeTypeA
@@ -103,4 +103,4 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
#include "run_batched_gemm_example_rowwise.inc"
int main(int argc, char* argv[]) { return !run_batched_gemm_rowwise_example(argc, argv); }
int main(int argc, char* argv[]) { return run_batched_gemm_rowwise_example(argc, argv); }

View File

@@ -1,3 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
@@ -96,4 +98,4 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
#define BUILD_INT4_EXAMPLE
#include "run_batched_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); }
int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); }

View File

@@ -1,3 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
@@ -48,9 +50,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>;
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 4>;
// clang-format on
#include "run_batched_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); }
int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); }

View File

@@ -1,3 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <random>
#pragma once
@@ -243,7 +245,7 @@ bool run_batched_gemm_example(int argc, char* argv[])
problem_size.N,
problem_size.K,
problem_size.batch_count);
exit(0);
exit(1);
}
problem_size.stride_A = problem_size.K;

View File

@@ -346,7 +346,7 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true;
return false;
}
bool pass = true;
@@ -566,7 +566,7 @@ bool run_batched_gemm_fp16_int4_b_scale_example(int argc, char* argv[])
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
exit(0);
exit(1);
}
problem_size.stride_A = problem_size.K;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <random>
#pragma once