// SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. int run_gemm_test(int argc, char* argv[]) { using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; using PassThrough = ck::tensor_operation::element_wise::PassThrough; ck::gemm_util::GemmParams params; ck::index_t instance_index = -1; if(argc == 1) { // use default params } else if(argc == 4 || argc == 5) { params.M = atoi(argv[1]); params.N = atoi(argv[2]); params.K = atoi(argv[3]); params.StrideA = params.M; params.StrideB = params.N; params.StrideC = params.K; if(argc == 5) { instance_index = atoi(argv[4]); } } else { std::cout << "Usage of " << argv[0] << std::endl; std::cout << "Arg1-4: M N K instance_index(-1 means all)" << std::endl; } std::cout << "Params (M, N, K, index) " << params.M << " " << params.N << " " << params.K << " " << instance_index << std::endl; auto test = [&](auto a_layout, auto b_layout, auto c_layout) { bool pass = true; using DeviceOp = ck::tensor_operation::device::DeviceGemm; const auto gemmPtrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); ck::index_t num_instance = 0; for(auto& gemmPtr : gemmPtrs) { if(instance_index == -1) { pass &= ck::gemm_util::TestGemm{}(gemmPtr.get(), params); } else { auto test_gemm = ck::gemm_util::TestGemm{}; if(test_gemm.IsSupportedArgument(gemmPtr.get(), params)) { if(num_instance == instance_index) { pass &= test_gemm(gemmPtr.get(), params); } num_instance++; } } } if(instance_index != -1) { std::cout << "TestGemm_instance (" << instance_index << "/" << num_instance << "): " << (pass ? "Passed" : "Failed") << std::endl; } return pass; }; bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) && test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{}); std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl; return pass ? 0 : 1; }