[CK][Examples] Extending support for rdna3/4 part 2:

-example_batched_gemm_xdl_int8
-example_batched_gemm_xdl_fp8_rowwise_v3
-example_batched_gemm_xdl_fp32
-example_batched_gemm_xdl_bf16
-example_batched_gemm_xdl_bf16_v3
-example_batched_gemm_xdl_fp16
-example_splitk_gemm_bias_e_permute_xdl_fp32
* Fixing the return value so that the above examples return 0 on success.

Fixing command-line parameter handling in:
-example_sparse_embedding3_forward_layernorm
-example_elementwise_binary_4D_fp16
-elementwise_scale_permute_amax_2D_fp16_fp8

Signed-off-by: Michal Kulikowski <Michal.Kulikowski@amd.com>


[ROCm/composable_kernel commit: 7259b9c4db]
This commit is contained in:
Michal Kulikowski
2025-10-01 16:04:25 +02:00
committed by Michał Kulikowski
parent 5912982bda
commit f85778eab4
16 changed files with 85 additions and 73 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <cstdlib>
@@ -51,6 +51,8 @@ int main(int argc, char* argv[])
bool do_verification = true;
bool time_kernel = true;
std::vector<std::size_t> nchw = {16, 128, 32, 64};
if(argc == 1)
{
// use default
@@ -60,30 +62,21 @@ int main(int argc, char* argv[])
do_verification = std::stoi(argv[1]);
time_kernel = std::stoi(argv[2]);
}
else if(argc == 7)
{
do_verification = std::stoi(argv[1]);
time_kernel = std::stoi(argv[2]);
nchw[0] = std::stoi(argv[3]);
nchw[1] = std::stoi(argv[4]);
nchw[2] = std::stoi(argv[5]);
nchw[3] = std::stoi(argv[6]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: time kernel (0=no, 1=yes)\n");
exit(0);
}
std::vector<std::size_t> nchw = {16, 128, 32, 64};
if(argc == 1)
{
// use default case
}
else if(argc == 5)
{
nchw[0] = std::stoi(argv[1]);
nchw[1] = std::stoi(argv[2]);
nchw[2] = std::stoi(argv[3]);
nchw[3] = std::stoi(argv[4]);
}
else
{
std::cerr << "arg1 to 4: N, C, H, W" << std::endl;
return 1;
printf("arg3-6: N, C, H, W (default 16, 128, 32, 64)\n");
exit(1);
}
std::array<ck::index_t, 4> ab_lengths;