Add catch blocks in example GEMM apps to enable better error handling (Issue: 1928) (#2234)

* added catch statements to examples

* clang format
This commit is contained in:
Aviral Goel
2025-05-28 00:32:42 -05:00
committed by GitHub
parent 132bd5b874
commit c52649ad57
6 changed files with 69 additions and 15 deletions

View File

@@ -214,4 +214,15 @@ int run_gemm_example(int argc, char* argv[])
}
}
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
int main(int argc, char* argv[])
{
try
{
return !run_gemm_example(argc, argv);
}
catch(const std::runtime_error& e)
{
std::cerr << "Runtime error: " << e.what() << '\n';
return EXIT_FAILURE;
}
}

View File

@@ -345,7 +345,7 @@ int main(int argc, char* argv[])
{
try
{
run_gemm_example(argc, argv);
return !run_gemm_example(argc, argv);
}
catch(const std::runtime_error& e)
{

View File

@@ -334,16 +334,26 @@ bool test_moe_sorting(ck_tile::ArgParser args)
int main(int argc, char** argv)
{
auto [result, args] = create_args(argc, argv);
if(!result)
return -1;
std::string index_prec = args.get_str("pr_i");
std::string weight_prec = args.get_str("pr_w");
bool r = true;
if(weight_prec.compare("fp32") == 0 && index_prec.compare("int32") == 0)
try
{
r &= test_moe_sorting<float, ck_tile::index_t>(args);
auto [result, args] = create_args(argc, argv);
if(!result)
return -1;
std::string index_prec = args.get_str("pr_i");
std::string weight_prec = args.get_str("pr_w");
bool r = true;
if(weight_prec == "fp32" && index_prec == "int32")
{
r &= test_moe_sorting<float, ck_tile::index_t>(args);
}
return r ? 0 : -1;
}
catch(const std::runtime_error& e)
{
std::cerr << "Runtime error: " << e.what() << '\n';
return EXIT_FAILURE;
}
return r ? 0 : -1;
}

View File

@@ -320,4 +320,15 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
#include "run_batched_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); }
int main(int argc, char* argv[])
{
try
{
return !run_batched_gemm_example(argc, argv);
}
catch(const std::runtime_error& e)
{
std::cerr << "Runtime error: " << e.what() << '\n';
return EXIT_FAILURE;
}
}

View File

@@ -319,4 +319,15 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
#include "run_grouped_gemm_example.inc"
constexpr bool Persistent = false;
int main(int argc, char* argv[]) { return !run_grouped_gemm_example<Persistent>(argc, argv); }
int main(int argc, char* argv[])
{
try
{
return !run_grouped_gemm_example<Persistent>(argc, argv);
}
catch(const std::runtime_error& e)
{
std::cerr << "Runtime error: " << e.what() << '\n';
return EXIT_FAILURE;
}
}

View File

@@ -177,4 +177,15 @@ int run_flatmm_example(int argc, char* argv[])
return -1;
}
int main(int argc, char* argv[]) { return !run_flatmm_example(argc, argv); }
int main(int argc, char* argv[])
{
try
{
return !run_flatmm_example(argc, argv);
}
catch(const std::runtime_error& e)
{
std::cerr << "Runtime error: " << e.what() << '\n';
return EXIT_FAILURE;
}
}