diff --git a/python/src/py_nvbench.cpp b/python/src/py_nvbench.cpp index e1b86ea..159936f 100644 --- a/python/src/py_nvbench.cpp +++ b/python/src/py_nvbench.cpp @@ -119,6 +119,29 @@ struct nvbench_run_error : std::runtime_error }; py::handle benchmark_exc{}; +void run_interruptible(nvbench::option_parser &parser) +{ + auto &printer = parser.get_printer(); + auto &benchmarks = parser.get_benchmarks(); + + std::size_t total_states = 0; + for (auto &bench_ptr : benchmarks) + { + total_states += bench_ptr->get_config_count(); + } + + printer.set_completed_state_count(0); + printer.set_total_state_count(total_states); + + bool skip_remaining_flag = false; + for (auto &bench_ptr : benchmarks) + { + bench_ptr->set_printer(printer); + bench_ptr->run_or_skip(skip_remaining_flag); + bench_ptr->clear_printer(); + } +} + class GlobalBenchmarkRegistry { bool m_finalized; @@ -185,7 +208,7 @@ public: parser.parse(argv); NVBENCH_MAIN_PRINT_PREAMBLE(parser); - NVBENCH_MAIN_RUN_BENCHMARKS(parser); + run_interruptible(parser); NVBENCH_MAIN_PRINT_EPILOGUE(parser); NVBENCH_MAIN_PRINT_RESULTS(parser);