// NOTE(review): the next line holds the license header, the include directives and the
// complete definition of run_gemm_example_prec_type collapsed onto one physical line.
// Because it begins with `//`, a compiler would treat the whole line as a single comment.
// The include targets and every template parameter/argument list look like they were
// stripped (bare `#include`, bare `template`) -- TODO confirm against the upstream file
// and restore the line breaks and angle-bracket contents before building.
//
// run_gemm_example_prec_type(a_layout, b_layout, arg_parser)
//   Purpose: selects the layout combination for the GEMM example and forwards to
//            run_gemm_example_with_layouts(arg_parser, Row{}, Col{}, Row{}).
//   Params:  a_layout, b_layout -- layout codes compared against the strings "R" and "C".
//            arg_parser         -- ck_tile::ArgParser passed through unchanged.
//   Returns: the int returned by run_gemm_example_with_layouts.
//   Throws:  std::runtime_error when GemmConfig::Preshuffle is true and the layouts are
//            not A == "R" / B == "C" (preshuffle-specific message), and std::runtime_error
//            ("Unsupported memory layout ...") for any other combination than "R"/"C".
//   Notes:   - Row and Col alias ck_tile::tensor_layout::gemm::RowMajor / ColumnMajor.
//            - Only "R"/"C" is accepted by the dispatch below, so the preshuffle check
//              changes only which error message is reported, not which inputs succeed.
//            - Invoker aliases WeightPreshuffleInvoker; presumably it is forwarded as a
//              template argument of run_gemm_example_with_layouts -- not visible here,
//              TODO confirm.
//            - The trailing `template` starts a declaration that continues past this
//              line and is intentionally left undocumented.
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #include #include #include #include #include #include #include "ck_tile/host.hpp" #include "gemm_utils.hpp" #include "run_gemm_example.inc" #include "gemm_weight_preshuffle_invoker.hpp" template int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, ck_tile::ArgParser& arg_parser) { using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; bool preshuffle = GemmConfig::Preshuffle; using Invoker = WeightPreshuffleInvoker; if(preshuffle && (a_layout != "R" || b_layout != "C")) { throw std::runtime_error( "Preshuffle is supported only for A(Row major), B(column major) input matrices!"); } if(a_layout == "R" && b_layout == "C") { return run_gemm_example_with_layouts( arg_parser, Row{}, Col{}, Row{}); } else { throw std::runtime_error("Unsupported memory layout for the input matrices!"); } } template