From 29c4f868ef8203fb7897c5427b43cc483f2076e9 Mon Sep 17 00:00:00 2001 From: Michal Kulikowski Date: Tue, 10 Mar 2026 14:38:58 +0100 Subject: [PATCH] [CK][Examples] Adding parameters for a couple of CK examples: -gemm_add_add_mean_meansquare_xdl_fp16 -gemm_dl_quantization_int8 -gemm_xdl_bias_relu_quantization_int8 -gemm_xdl_quantization_int8 Signed-off-by: Michal Kulikowski --- .../gemm_dl_quantization_int8.cpp | 30 ++++++++++++++- .../gemm_xdl_bias_relu_quantization_int8.cpp | 30 ++++++++++++++- .../gemm_xdl_quantization_int8.cpp | 30 ++++++++++++++- .../gemm_add_add_mean_meansquare_xdl_fp16.cpp | 38 +++++++++++++++++-- 4 files changed, 121 insertions(+), 7 deletions(-) diff --git a/example/14_gemm_quantization/gemm_dl_quantization_int8.cpp b/example/14_gemm_quantization/gemm_dl_quantization_int8.cpp index 6a8153c75f..5444e3dbd6 100644 --- a/example/14_gemm_quantization/gemm_dl_quantization_int8.cpp +++ b/example/14_gemm_quantization/gemm_dl_quantization_int8.cpp @@ -96,7 +96,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Dl< using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; -int main() +int main(int argc, char* argv[]) { bool do_verification = true; bool time_kernel = false; @@ -112,6 +112,34 @@ int main() float requant_scale = 0.03; + if(argc == 1) + { + // do nothing + } + else if(argc == 3 || argc == 9) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + if(argc == 9) + { + M = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + + StrideA = std::stoi(argv[6]); + StrideB = std::stoi(argv[7]); + StrideE = std::stoi(argv[8]); + } + } + else + { + std::cout << "arg1: verification (0=no, 1=yes)\n" + << " arg2: Measure kernel execution time (1=ON, 0=Off)\n" + << " arg3 to 8: M (128x), N(128x), K(16x), StrideA, StrideB, StrideE\n" + << std::endl; + exit(1); + } + auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { using namespace ck::literals; diff --git a/example/14_gemm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp b/example/14_gemm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp index ee2002bf0b..1a936f6af3 100644 --- a/example/14_gemm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp +++ b/example/14_gemm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp @@ -106,7 +106,7 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; -int main() +int main(int argc, char* argv[]) { bool do_verification = true; bool time_kernel = false; @@ -123,6 +123,34 @@ int main() float requant_scale = 0.03; + if(argc == 1) + { + // do nothing + } + else if(argc == 3 || argc == 9) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + if(argc == 9) + { + M = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + + StrideA = std::stoi(argv[6]); + StrideB = std::stoi(argv[7]); + StrideE = std::stoi(argv[8]); + } + } + else + { + std::cout << "arg1: verification (0=no, 1=yes)\n" + << " arg2: Measure kernel execution time (1=ON, 0=Off)\n" + << " arg3 to 8: M (256x), N(128x), K(64x), StrideA, StrideB, StrideE\n" + << std::endl; + exit(1); + } + auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { using namespace ck::literals; diff --git a/example/14_gemm_quantization/gemm_xdl_quantization_int8.cpp b/example/14_gemm_quantization/gemm_xdl_quantization_int8.cpp index a62bd50c56..baa2fd4353 100644 --- a/example/14_gemm_quantization/gemm_xdl_quantization_int8.cpp +++ b/example/14_gemm_quantization/gemm_xdl_quantization_int8.cpp @@ -99,7 +99,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; -int main() +int main(int argc, char* argv[]) { bool do_verification = true; bool time_kernel = false; @@ -115,6 +115,34 @@ int main() float requant_scale = 0.03; + if(argc == 1) + { + // do nothing + } + else if(argc == 3 || argc == 9) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + if(argc == 9) + { + M = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + + StrideA = std::stoi(argv[6]); + StrideB = std::stoi(argv[7]); + StrideE = std::stoi(argv[8]); + } + } + else + { + std::cout << "arg1: verification (0=no, 1=yes)\n" + << " arg2: Measure kernel execution time (1=ON, 0=Off)\n" + << " arg3 to 8: M (256x), N(128x), K(64x), StrideA, StrideB, StrideE\n" + << std::endl; + exit(1); + } + auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { using namespace ck::literals; diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp index a30bedf282..0f436d4d90 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp @@ -131,7 +131,7 @@ auto f_host_tensor_descriptor2d = } }; -int main() +int main(int argc, char* argv[]) { ck::index_t M = 1024; ck::index_t N = 1024; @@ -143,6 +143,38 @@ int main() ck::index_t StrideD1 = 1024; ck::index_t StrideE = 1024; + bool do_verification = true; + bool time_kernel = false; + + if(argc == 1) + { + // do nothing + } + else if(argc == 3 || argc == 10) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + if(argc == 10) + { + M = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + + StrideA = std::stoi(argv[6]); + StrideB = std::stoi(argv[7]); + StrideD1 = std::stoi(argv[8]); + StrideE = std::stoi(argv[9]); + } + } + else + { + std::cout << "arg1: verification (0=no, 1=yes)\n" + << " arg2: Measure kernel execution time (1=ON, 0=Off)\n" + << " arg3 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD1, StrideE\n" + << std::endl; + exit(1); + } + Tensor a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); Tensor d0_n(f_host_tensor_descriptor1d(N, 1)); @@ -208,8 +240,7 @@ int main() invoker.Run(argument, StreamConfig{nullptr, false}); - bool do_verification = true; - bool pass = true; + bool pass = true; if(do_verification) { @@ -268,7 +299,6 @@ int main() pass &= ck::utils::check_err(r1_m, r1_m_host, "Error: Incorrect results d1", 1e-2, 1e-2); } - bool time_kernel = false; if(time_kernel) { float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});