mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Lwpck 3548 gemm test cleanups (#2717)
* Remove some unnecessary calls to create_args in basic and universal GEMM tests
* Remove unnecessary include statements in universal GEMM tests
* Improve compilation time of basic GEMM tests by only compiling the precision variants that we need
* Universal GEMM PrecType should be the same as CDataType
* Improve compilation time of universal GEMM tests by only compiling the precision variants that we need
* Revert to constexpr when defining some constants
[ROCm/composable_kernel commit: 5e85c38d7d]
This commit is contained in:
@@ -2,4 +2,4 @@
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include "test_gemm_pipeline_basic_run_test.inc"
|
||||
|
||||
int main() { return run_gemm_combinations("bf16"); }
|
||||
int main() { return run_gemm_combinations<ck_tile::bf16_t>(); }
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include "test_gemm_pipeline_basic_run_test.inc"
|
||||
|
||||
int main() { return run_gemm_combinations("bf8"); }
|
||||
int main() { return run_gemm_combinations<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(); }
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include "test_gemm_pipeline_basic_run_test.inc"
|
||||
|
||||
int main() { return run_gemm_combinations("fp16"); }
|
||||
int main() { return run_gemm_combinations<ck_tile::half_t>(); }
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include "test_gemm_pipeline_basic_run_test.inc"
|
||||
|
||||
int main() { return run_gemm_combinations("bf8"); }
|
||||
int main() { return run_gemm_combinations<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>(); }
|
||||
|
||||
@@ -131,7 +131,9 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
}
|
||||
|
||||
template <typename APrecType, typename BPrecType = APrecType, typename CPrecType = APrecType>
|
||||
bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
bool run_gemm_test_prec_type(std::string a_layout,
|
||||
std::string b_layout,
|
||||
ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
@@ -141,12 +143,12 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
arg_parser, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Col{}, Row{});
|
||||
arg_parser, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -159,22 +161,22 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
arg_parser, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Row{}, Row{});
|
||||
arg_parser, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Row{}, Row{});
|
||||
arg_parser, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Col{}, Row{});
|
||||
arg_parser, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -183,60 +185,26 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg
|
||||
}
|
||||
}
|
||||
|
||||
template <typename APrecType, typename BPrecType, typename CPrecType>
|
||||
bool run_gemm_test(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return false;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_test_prec_type<ck_tile::half_t>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_test_prec_type<ck_tile::bf16_t>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_test_prec_type<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_gemm_test_prec_type<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "pk_int4_t")
|
||||
{
|
||||
// TODO: Add support for bhalf_t ADataType
|
||||
if constexpr(GemmConfigBase::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_test_prec_type<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
return run_gemm_test_prec_type<APrecType, BPrecType, CPrecType>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
|
||||
int run_gemm_combinations(std::string const& data_type)
|
||||
template <typename APrecType, typename BPrecType = APrecType, typename CPrecType = APrecType>
|
||||
int run_gemm_combinations()
|
||||
{
|
||||
// Define possible values for each parameter
|
||||
std::vector<std::string> m_values = {"128", "1024"};
|
||||
std::vector<std::string> n_values = {"128", "2048"};
|
||||
std::vector<std::string> k_values = {"64", "128"};
|
||||
std::vector<std::string> prec_values = {data_type};
|
||||
std::vector<std::string> m_values = {"128", "1024"};
|
||||
std::vector<std::string> n_values = {"128", "2048"};
|
||||
std::vector<std::string> k_values = {"64", "128"};
|
||||
|
||||
// We'll store all our arguments as strings first
|
||||
std::vector<std::string> arg_strings = {"./bin/tile_example_gemm_basic",
|
||||
@@ -246,13 +214,12 @@ int run_gemm_combinations(std::string const& data_type)
|
||||
"-stride_a=0",
|
||||
"-stride_b=0",
|
||||
"-stride_c=0",
|
||||
"", // prec placeholder
|
||||
"-v=2",
|
||||
"-warmup=0",
|
||||
"-repeat=1"};
|
||||
|
||||
// Create an array of const char pointers for argv
|
||||
constexpr size_t ARG_COUNT = 11;
|
||||
constexpr size_t ARG_COUNT = 10;
|
||||
constexpr size_t ARG_MAX_LEN = 64;
|
||||
char args[ARG_COUNT][ARG_MAX_LEN];
|
||||
char* argv[ARG_COUNT];
|
||||
@@ -271,39 +238,35 @@ int run_gemm_combinations(std::string const& data_type)
|
||||
{
|
||||
arg_strings[3] = "-k=" + k;
|
||||
|
||||
for(const auto& prec : prec_values)
|
||||
// Set up the argv array with pointers to the string data
|
||||
for(size_t i = 0; i < ARG_COUNT; i++)
|
||||
{
|
||||
arg_strings[7] = "-prec=" + prec;
|
||||
strncpy(args[i], arg_strings[i].c_str(), ARG_MAX_LEN);
|
||||
argv[i] = args[i];
|
||||
}
|
||||
|
||||
// Set up the argv array with pointers to the string data
|
||||
for(size_t i = 0; i < ARG_COUNT; i++)
|
||||
{
|
||||
strncpy(args[i], arg_strings[i].c_str(), ARG_MAX_LEN);
|
||||
argv[i] = args[i];
|
||||
}
|
||||
std::cout << "Arguments received: ";
|
||||
for(size_t i = 1; i < ARG_COUNT; ++i)
|
||||
{
|
||||
std::cout << argv[i] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
|
||||
std::cout << "Arguments received: ";
|
||||
for(size_t i = 1; i < ARG_COUNT; ++i)
|
||||
{
|
||||
std::cout << argv[i] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
|
||||
// Call the function with the current configuration
|
||||
try
|
||||
{
|
||||
is_success = run_gemm_test(ARG_COUNT, argv) && is_success;
|
||||
}
|
||||
catch(const ArgumentsNotSupportedException& e)
|
||||
{
|
||||
std::cerr << "Caught ArgumentsNotSupportedException: " << e.what() << '\n';
|
||||
// ArgumentsNotSupportedException is not an error. Do not change is_success
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Caught runtime error: " << e.what() << '\n';
|
||||
is_success = false;
|
||||
}
|
||||
// Call the function with the current configuration
|
||||
try
|
||||
{
|
||||
is_success = run_gemm_test<APrecType, BPrecType, CPrecType>(ARG_COUNT, argv) &&
|
||||
is_success;
|
||||
}
|
||||
catch(const ArgumentsNotSupportedException& e)
|
||||
{
|
||||
std::cerr << "Caught ArgumentsNotSupportedException: " << e.what() << '\n';
|
||||
// ArgumentsNotSupportedException is not an error. Do not change is_success
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Caught runtime error: " << e.what() << '\n';
|
||||
is_success = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -256,16 +256,11 @@ template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
bool run_gemm_test_with_layouts(int argc,
|
||||
char* argv[],
|
||||
bool run_gemm_test_with_layouts(ck_tile::ArgParser& arg_parser,
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return false;
|
||||
|
||||
using AccDataType = typename GemmTypeConfig<ADataType, BDataType, CDataType>::AccDataType;
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
|
||||
@@ -1,16 +1,9 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstddef>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_gemm_pipeline_smoke_util.hpp"
|
||||
#include "test_gemm_pipeline_smoke_run_test.inc"
|
||||
#include "test_gemm_pipeline_universal_run_test.inc"
|
||||
|
||||
int main() { return run_gemm_combinations("bf16"); }
|
||||
int main() { return run_gemm_combinations<ck_tile::bf16_t>(); }
|
||||
|
||||
@@ -1,16 +1,9 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstddef>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_gemm_pipeline_smoke_util.hpp"
|
||||
#include "test_gemm_pipeline_smoke_run_test.inc"
|
||||
#include "test_gemm_pipeline_universal_run_test.inc"
|
||||
|
||||
int main() { return run_gemm_combinations("bf8"); }
|
||||
int main() { return run_gemm_combinations<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(); }
|
||||
|
||||
@@ -1,16 +1,9 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstddef>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_gemm_pipeline_smoke_util.hpp"
|
||||
#include "test_gemm_pipeline_smoke_run_test.inc"
|
||||
#include "test_gemm_pipeline_universal_run_test.inc"
|
||||
|
||||
int main() { return run_gemm_combinations("fp16"); }
|
||||
int main() { return run_gemm_combinations<ck_tile::half_t>(); }
|
||||
|
||||
@@ -1,16 +1,9 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstddef>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_gemm_pipeline_smoke_util.hpp"
|
||||
#include "test_gemm_pipeline_smoke_run_test.inc"
|
||||
#include "test_gemm_pipeline_universal_run_test.inc"
|
||||
|
||||
int main() { return run_gemm_combinations("fp8"); }
|
||||
int main() { return run_gemm_combinations<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>(); }
|
||||
|
||||
@@ -200,7 +200,9 @@ template <typename GemmConfig,
|
||||
typename APrecType,
|
||||
typename BPrecType = APrecType,
|
||||
typename CPrecType = APrecType>
|
||||
bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
bool run_gemm_test_prec_type(std::string a_layout,
|
||||
std::string b_layout,
|
||||
ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
@@ -210,12 +212,12 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
arg_parser, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Col{}, Row{});
|
||||
arg_parser, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -228,22 +230,22 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg
|
||||
if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Row{}, Row{});
|
||||
arg_parser, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
arg_parser, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Row{}, Row{});
|
||||
arg_parser, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Col{}, Row{});
|
||||
arg_parser, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -252,69 +254,27 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg
|
||||
}
|
||||
}
|
||||
|
||||
template <template <typename PreType> typename GemmConfig>
|
||||
template <typename GemmConfig, typename APrecType, typename BPrecType, typename CPrecType>
|
||||
bool run_gemm_test(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return false;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_test_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_test_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_test_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_gemm_test_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "pk_int4_t")
|
||||
{
|
||||
// TODO: Add support for bhalf_t ADataType
|
||||
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_test_prec_type<GemmConfig<ck_tile::half_t>,
|
||||
ck_tile::half_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported pipeline for this operation !!!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
return run_gemm_test_prec_type<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
|
||||
int run_gemm_combinations(std::string const& data_type)
|
||||
template <typename APrecType, typename BPrecType = APrecType, typename CPrecType = APrecType>
|
||||
int run_gemm_combinations()
|
||||
{
|
||||
// Define possible values for each parameter
|
||||
std::vector<std::string> m_values = {"512", "1024"};
|
||||
std::vector<std::string> n_values = {"512", "2048"};
|
||||
std::vector<std::string> k_values = {"512", "1024"};
|
||||
std::vector<std::string> prec_values = {data_type};
|
||||
std::vector<std::string> m_values = {"512", "1024"};
|
||||
std::vector<std::string> n_values = {"512", "2048"};
|
||||
std::vector<std::string> k_values = {"512", "1024"};
|
||||
|
||||
// We'll store all our arguments as strings first
|
||||
std::vector<std::string> arg_strings = {"./bin/tile_example_gemm_universal",
|
||||
@@ -324,13 +284,12 @@ int run_gemm_combinations(std::string const& data_type)
|
||||
"-stride_a=0",
|
||||
"-stride_b=0",
|
||||
"-stride_c=0",
|
||||
"", // prec placeholder
|
||||
"-v=2",
|
||||
"-warmup=0",
|
||||
"-repeat=1"};
|
||||
|
||||
// Create an array of const char pointers for argv
|
||||
constexpr size_t ARG_COUNT = 11;
|
||||
constexpr size_t ARG_COUNT = 10;
|
||||
constexpr size_t ARG_MAX_LEN = 64;
|
||||
char args[ARG_COUNT][ARG_MAX_LEN];
|
||||
char* argv[ARG_COUNT];
|
||||
@@ -349,42 +308,43 @@ int run_gemm_combinations(std::string const& data_type)
|
||||
{
|
||||
arg_strings[3] = "-k=" + k;
|
||||
|
||||
for(const auto& prec : prec_values)
|
||||
// Set up the argv array with pointers to the string data
|
||||
for(size_t i = 0; i < ARG_COUNT; i++)
|
||||
{
|
||||
arg_strings[7] = "-prec=" + prec;
|
||||
strncpy(args[i], arg_strings[i].c_str(), ARG_MAX_LEN);
|
||||
argv[i] = args[i];
|
||||
}
|
||||
|
||||
// Set up the argv array with pointers to the string data
|
||||
for(size_t i = 0; i < ARG_COUNT; i++)
|
||||
{
|
||||
strncpy(args[i], arg_strings[i].c_str(), ARG_MAX_LEN);
|
||||
argv[i] = args[i];
|
||||
}
|
||||
std::cout << "Arguments received: ";
|
||||
for(size_t i = 1; i < ARG_COUNT; ++i)
|
||||
{
|
||||
std::cout << argv[i] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
|
||||
std::cout << "Arguments received: ";
|
||||
for(size_t i = 1; i < ARG_COUNT; ++i)
|
||||
{
|
||||
std::cout << argv[i] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
|
||||
// Call the function with the current configuration
|
||||
try
|
||||
{
|
||||
is_success =
|
||||
run_gemm_test<GemmConfigComputeV3>(ARG_COUNT, argv) && is_success;
|
||||
is_success =
|
||||
run_gemm_test<GemmConfigComputeV3_2>(ARG_COUNT, argv) && is_success;
|
||||
}
|
||||
catch(const ArgumentsNotSupportedException& e)
|
||||
{
|
||||
std::cerr << "Caught ArgumentsNotSupportedException: " << e.what() << '\n';
|
||||
// ArgumentsNotSupportedException is not an error. Do not change is_success
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Caught runtime error: " << e.what() << '\n';
|
||||
is_success = false;
|
||||
}
|
||||
// Call the function with the current configuration
|
||||
try
|
||||
{
|
||||
is_success = run_gemm_test<GemmConfigComputeV3<CPrecType>,
|
||||
APrecType,
|
||||
BPrecType,
|
||||
CPrecType>(ARG_COUNT, argv) &&
|
||||
is_success;
|
||||
is_success = run_gemm_test<GemmConfigComputeV3_2<CPrecType>,
|
||||
APrecType,
|
||||
BPrecType,
|
||||
CPrecType>(ARG_COUNT, argv) &&
|
||||
is_success;
|
||||
}
|
||||
catch(const ArgumentsNotSupportedException& e)
|
||||
{
|
||||
std::cerr << "Caught ArgumentsNotSupportedException: " << e.what() << '\n';
|
||||
// ArgumentsNotSupportedException is not an error. Do not change is_success
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Caught runtime error: " << e.what() << '\n';
|
||||
is_success = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user