diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_bf16.cpp b/test/ck_tile/gemm/test_gemm_pipeline_basic_bf16.cpp index af2cb398f5..4e3033782c 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_basic_bf16.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_bf16.cpp @@ -2,4 +2,4 @@ // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "test_gemm_pipeline_basic_run_test.inc" -int main() { return run_gemm_combinations("bf16"); } +int main() { return run_gemm_combinations(); } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_bf8.cpp b/test/ck_tile/gemm/test_gemm_pipeline_basic_bf8.cpp index fd8c28ef17..61614fc6f5 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_basic_bf8.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_bf8.cpp @@ -2,4 +2,4 @@ // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "test_gemm_pipeline_basic_run_test.inc" -int main() { return run_gemm_combinations("bf8"); } +int main() { return run_gemm_combinations(); } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_fp16.cpp b/test/ck_tile/gemm/test_gemm_pipeline_basic_fp16.cpp index 4a93d6046a..c667c08053 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_basic_fp16.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_fp16.cpp @@ -2,4 +2,4 @@ // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "test_gemm_pipeline_basic_run_test.inc" -int main() { return run_gemm_combinations("fp16"); } +int main() { return run_gemm_combinations(); } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_fp8.cpp b/test/ck_tile/gemm/test_gemm_pipeline_basic_fp8.cpp index fd8c28ef17..9a3498b7ea 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_basic_fp8.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_fp8.cpp @@ -2,4 +2,4 @@ // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "test_gemm_pipeline_basic_run_test.inc" -int main() { return run_gemm_combinations("bf8"); } +int main() { return run_gemm_combinations(); } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc index 53eff9ecc4..1fdf26f01c 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc @@ -131,7 +131,9 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) } template -bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +bool run_gemm_test_prec_type(std::string a_layout, + std::string b_layout, + ck_tile::ArgParser& arg_parser) { using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; @@ -141,12 +143,12 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg if(a_layout == "R" && b_layout == "C") { return run_gemm_test_with_layouts( - argc, argv, Row{}, Col{}, Row{}); + arg_parser, Row{}, Col{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { return run_gemm_test_with_layouts( - argc, argv, Col{}, Col{}, Row{}); + arg_parser, Col{}, Col{}, Row{}); } else { @@ -159,22 +161,22 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg if(a_layout == "R" && b_layout == "C") { return run_gemm_test_with_layouts( - argc, argv, Row{}, Col{}, Row{}); + arg_parser, Row{}, Col{}, Row{}); } else if(a_layout == "R" && b_layout == "R") { return run_gemm_test_with_layouts( - argc, argv, Row{}, Row{}, Row{}); + arg_parser, Row{}, Row{}, Row{}); } else if(a_layout == "C" && b_layout == "R") { return run_gemm_test_with_layouts( - argc, argv, Col{}, Row{}, Row{}); + arg_parser, Col{}, Row{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { return run_gemm_test_with_layouts( - argc, argv, Col{}, Col{}, Row{}); + arg_parser, Col{}, Col{}, Row{}); } else { @@ -183,60 +185,26 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg } } +template bool run_gemm_test(int argc, char* argv[]) { auto [result, arg_parser] = create_args(argc, argv); if(!result) return false; - std::string data_type = arg_parser.get_str("prec"); - std::string a_layout = arg_parser.get_str("a_layout"); - std::string b_layout = arg_parser.get_str("b_layout"); + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); - if(data_type == "fp16") - { - return run_gemm_test_prec_type(a_layout, b_layout, argc, argv); - } - else if(data_type == "bf16") - { - return run_gemm_test_prec_type(a_layout, b_layout, argc, argv); - } - else if(data_type == "fp8") - { - return run_gemm_test_prec_type( - a_layout, b_layout, argc, argv); - } - else if(data_type == "bf8") - { - return run_gemm_test_prec_type( - a_layout, b_layout, argc, argv); - } - else if(data_type == "pk_int4_t") - { - // TODO: Add support for bhalf_t ADataType - if constexpr(GemmConfigBase::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) - { - return run_gemm_test_prec_type( - a_layout, b_layout, argc, argv); - } - else - { - throw std::runtime_error("Unsupported data type for this operation !!!"); - } - } - else - { - throw std::runtime_error("Unsupported data type for this operation !!!"); - } + return run_gemm_test_prec_type(a_layout, b_layout, arg_parser); } -int run_gemm_combinations(std::string const& data_type) +template +int run_gemm_combinations() { // Define possible values for each parameter - std::vector m_values = {"128", "1024"}; - std::vector n_values = {"128", "2048"}; - std::vector k_values = {"64", "128"}; - std::vector prec_values = {data_type}; + std::vector m_values = {"128", "1024"}; + std::vector n_values = {"128", "2048"}; + std::vector k_values = {"64", "128"}; // We'll store all our arguments as strings first std::vector arg_strings = {"./bin/tile_example_gemm_basic", @@ -246,13 +214,12 @@ int run_gemm_combinations(std::string const& data_type) "-stride_a=0", "-stride_b=0", "-stride_c=0", - "", // prec placeholder "-v=2", "-warmup=0", "-repeat=1"}; // Create an array of const char pointers for argv - constexpr size_t ARG_COUNT = 11; + constexpr size_t ARG_COUNT = 10; constexpr size_t ARG_MAX_LEN = 64; char args[ARG_COUNT][ARG_MAX_LEN]; char* argv[ARG_COUNT]; @@ -271,39 +238,35 @@ int run_gemm_combinations(std::string const& data_type) { arg_strings[3] = "-k=" + k; - for(const auto& prec : prec_values) + // Set up the argv array with pointers to the string data + for(size_t i = 0; i < ARG_COUNT; i++) { - arg_strings[7] = "-prec=" + prec; + strncpy(args[i], arg_strings[i].c_str(), ARG_MAX_LEN); + argv[i] = args[i]; + } - // Set up the argv array with pointers to the string data - for(size_t i = 0; i < ARG_COUNT; i++) - { - strncpy(args[i], arg_strings[i].c_str(), ARG_MAX_LEN); - argv[i] = args[i]; - } + std::cout << "Arguments received: "; + for(size_t i = 1; i < ARG_COUNT; ++i) + { + std::cout << argv[i] << " "; + } + std::cout << std::endl; - std::cout << "Arguments received: "; - for(size_t i = 1; i < ARG_COUNT; ++i) - { - std::cout << argv[i] << " "; - } - std::cout << std::endl; - - // Call the function with the current configuration - try - { - is_success = run_gemm_test(ARG_COUNT, argv) && is_success; - } - catch(const ArgumentsNotSupportedException& e) - { - std::cerr << "Caught ArgumentsNotSupportedException: " << e.what() << '\n'; - // ArgumentsNotSupportedException is not an error. Do not change is_success - } - catch(const std::runtime_error& e) - { - std::cerr << "Caught runtime error: " << e.what() << '\n'; - is_success = false; - } + // Call the function with the current configuration + try + { + is_success = run_gemm_test(ARG_COUNT, argv) && + is_success; + } + catch(const ArgumentsNotSupportedException& e) + { + std::cerr << "Caught ArgumentsNotSupportedException: " << e.what() << '\n'; + // ArgumentsNotSupportedException is not an error. Do not change is_success + } + catch(const std::runtime_error& e) + { + std::cerr << "Caught runtime error: " << e.what() << '\n'; + is_success = false; } } } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc index a967b92e7f..ab74e4e7b1 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc @@ -256,16 +256,11 @@ template -bool run_gemm_test_with_layouts(int argc, - char* argv[], +bool run_gemm_test_with_layouts(ck_tile::ArgParser& arg_parser, const ALayout a_layout = ALayout{}, const BLayout b_layout = BLayout{}, [[maybe_unused]] const CLayout c_layout = CLayout{}) { - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - return false; - using AccDataType = typename GemmTypeConfig::AccDataType; ck_tile::index_t M = arg_parser.get_int("m"); diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_bf16.cpp b/test/ck_tile/gemm/test_gemm_pipeline_universal_bf16.cpp index 0673272f5f..1336f6fd70 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_universal_bf16.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_bf16.cpp @@ -1,16 +1,9 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -#include -#include - -#include -#include -#include - #include "ck_tile/host.hpp" #include "test_gemm_pipeline_smoke_util.hpp" #include "test_gemm_pipeline_smoke_run_test.inc" #include "test_gemm_pipeline_universal_run_test.inc" -int main() { return run_gemm_combinations("bf16"); } +int main() { return run_gemm_combinations(); } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_bf8.cpp b/test/ck_tile/gemm/test_gemm_pipeline_universal_bf8.cpp index 70eae12e82..5d55f34b84 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_universal_bf8.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_bf8.cpp @@ -1,16 +1,9 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -#include -#include - -#include -#include -#include - #include "ck_tile/host.hpp" #include "test_gemm_pipeline_smoke_util.hpp" #include "test_gemm_pipeline_smoke_run_test.inc" #include "test_gemm_pipeline_universal_run_test.inc" -int main() { return run_gemm_combinations("bf8"); } +int main() { return run_gemm_combinations(); } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_fp16.cpp b/test/ck_tile/gemm/test_gemm_pipeline_universal_fp16.cpp index 8ea192c7f3..0cebbcc721 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_universal_fp16.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_fp16.cpp @@ -1,16 +1,9 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -#include -#include - -#include -#include -#include - #include "ck_tile/host.hpp" #include "test_gemm_pipeline_smoke_util.hpp" #include "test_gemm_pipeline_smoke_run_test.inc" #include "test_gemm_pipeline_universal_run_test.inc" -int main() { return run_gemm_combinations("fp16"); } +int main() { return run_gemm_combinations(); } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_fp8.cpp b/test/ck_tile/gemm/test_gemm_pipeline_universal_fp8.cpp index 20414b4fec..29fb5f87ce 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_universal_fp8.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_fp8.cpp @@ -1,16 +1,9 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -#include -#include - -#include -#include -#include - #include "ck_tile/host.hpp" #include "test_gemm_pipeline_smoke_util.hpp" #include "test_gemm_pipeline_smoke_run_test.inc" #include "test_gemm_pipeline_universal_run_test.inc" -int main() { return run_gemm_combinations("fp8"); } +int main() { return run_gemm_combinations(); } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc index adae8dcf92..fd50596f2f 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc @@ -200,7 +200,9 @@ template -bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +bool run_gemm_test_prec_type(std::string a_layout, + std::string b_layout, + ck_tile::ArgParser& arg_parser) { using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; @@ -210,12 +212,12 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg if(a_layout == "R" && b_layout == "C") { return run_gemm_test_with_layouts( - argc, argv, Row{}, Col{}, Row{}); + arg_parser, Row{}, Col{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { return run_gemm_test_with_layouts( - argc, argv, Col{}, Col{}, Row{}); + arg_parser, Col{}, Col{}, Row{}); } else { @@ -228,22 +230,22 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg if(a_layout == "R" && b_layout == "R") { return run_gemm_test_with_layouts( - argc, argv, Row{}, Row{}, Row{}); + arg_parser, Row{}, Row{}, Row{}); } else if(a_layout == "R" && b_layout == "C") { return run_gemm_test_with_layouts( - argc, argv, Row{}, Col{}, Row{}); + arg_parser, Row{}, Col{}, Row{}); } else if(a_layout == "C" && b_layout == "R") { return run_gemm_test_with_layouts( - argc, argv, Col{}, Row{}, Row{}); + arg_parser, Col{}, Row{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { return run_gemm_test_with_layouts( - argc, argv, Col{}, Col{}, Row{}); + arg_parser, Col{}, Col{}, Row{}); } else { @@ -252,69 +254,27 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg } } -template