// SPDX-License-Identifier: MIT // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_utils.hpp" #include "run_gemm_example.inc" #include "run_gemm_example_common.hpp" #include "gemm_basic_invoker.hpp" int run_gemm_example(ck_tile::ArgParser& arg_parser) { std::string data_type = arg_parser.get_str("prec"); std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); using GemmConfig = GemmConfigBase; using Invoker = BasicInvoker; if(data_type == "fp16") { return run_gemm_example_prec_type( a_layout, b_layout, arg_parser); } else if(data_type == "bf16") { return run_gemm_example_prec_type( a_layout, b_layout, arg_parser); } else if(data_type == "fp8") { return run_gemm_example_prec_type(a_layout, b_layout, arg_parser); } else if(data_type == "bf8") { return run_gemm_example_prec_type(a_layout, b_layout, arg_parser); } else if(data_type == "i8") { return run_gemm_example_prec_type(a_layout, b_layout, arg_parser); } else if(data_type == "pk_int4_t") { // TODO: Add support for bhalf_t ADataType if constexpr(GemmConfig::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) { return run_gemm_example_prec_type(a_layout, b_layout, arg_parser); } else { throw std::runtime_error("Unsupported data type for this operation !!!"); } } else { throw std::runtime_error("Unsupported data type for this operation !!!"); } } int main(int argc, char* argv[]) { auto arg_parser = create_args(); auto result = arg_parser.parse(argc, argv); if(!result) return -1; try { return !run_gemm_example(arg_parser); } catch(const std::runtime_error& e) { std::cerr << "Runtime error: " << e.what() << '\n'; return EXIT_FAILURE; } }