diff --git a/nvbench/benchmark.cuh b/nvbench/benchmark.cuh index cc0d3ed..5e050d1 100644 --- a/nvbench/benchmark.cuh +++ b/nvbench/benchmark.cuh @@ -64,11 +64,6 @@ struct benchmark final : public benchmark_base : benchmark_base(type_axes{}) {} - // Note that this inline virtual dtor may cause vtable issues if linking - // benchmark TUs together. That's not a likely scenario, so we'll deal with - // that if it comes up. - ~benchmark() override = default; - private: std::unique_ptr do_clone() const final { diff --git a/nvbench/benchmark_base.cuh b/nvbench/benchmark_base.cuh index c77c496..588445d 100644 --- a/nvbench/benchmark_base.cuh +++ b/nvbench/benchmark_base.cuh @@ -145,6 +145,11 @@ struct benchmark_base return m_axes; } + // Computes the number of configs in the benchmark. + // Unlike get_states().size(), this method may be used prior to calling run(). + [[nodiscard]] std::size_t get_config_count() const; + + // Is empty until run() is called. [[nodiscard]] const std::vector &get_states() const { return m_states; diff --git a/nvbench/benchmark_base.cxx b/nvbench/benchmark_base.cxx index c981df8..b3cd649 100644 --- a/nvbench/benchmark_base.cxx +++ b/nvbench/benchmark_base.cxx @@ -18,6 +18,8 @@ #include +#include + namespace nvbench { @@ -58,4 +60,15 @@ benchmark_base &benchmark_base::add_device(int device_id) return this->add_device(device_info{device_id}); } +std::size_t benchmark_base::get_config_count() const +{ + return nvbench::detail::transform_reduce(m_axes.get_axes().cbegin(), + m_axes.get_axes().cend(), + std::size_t{1}, + std::multiplies<>{}, + [](const auto &axis_ptr) { + return axis_ptr->get_size(); + }); +} + } // namespace nvbench diff --git a/nvbench/csv_printer.cuh b/nvbench/csv_printer.cuh index a004912..ad90efd 100644 --- a/nvbench/csv_printer.cuh +++ b/nvbench/csv_printer.cuh @@ -30,7 +30,7 @@ struct csv_printer : nvbench::printer_base { using printer_base::printer_base; -private: +protected: // Virtual API from printer_base: void do_print_benchmark_results(const benchmark_vector &benches) override; }; diff --git a/nvbench/json_printer.cuh b/nvbench/json_printer.cuh index 6e5010a..0f011a7 100644 --- a/nvbench/json_printer.cuh +++ b/nvbench/json_printer.cuh @@ -30,7 +30,7 @@ struct json_printer : nvbench::printer_base { using printer_base::printer_base; -private: +protected: // Virtual API from printer_base: void do_print_benchmark_results(const benchmark_vector &benches) override; }; diff --git a/nvbench/main.cuh b/nvbench/main.cuh index a16128c..b81264e 100644 --- a/nvbench/main.cuh +++ b/nvbench/main.cuh @@ -63,6 +63,15 @@ printer.print_device_info(); \ printer.print_log_preamble(); \ auto &benchmarks = parser.get_benchmarks(); \ + \ + std::size_t total_states = 0; \ + for (auto &bench_ptr : benchmarks) \ + { \ + total_states += bench_ptr->get_config_count(); \ + } \ + printer.set_total_state_count(total_states); \ + \ + printer.set_completed_state_count(0); \ for (auto &bench_ptr : benchmarks) \ { \ bench_ptr->set_printer(printer); \ diff --git a/nvbench/markdown_printer.cu b/nvbench/markdown_printer.cu index 5e079ba..26966ca 100644 --- a/nvbench/markdown_printer.cu +++ b/nvbench/markdown_printer.cu @@ -23,8 +23,6 @@ #include #include -#include - #include #include @@ -147,7 +145,19 @@ void markdown_printer::do_log(nvbench::log_level level, const std::string &msg) void markdown_printer::do_log_run_state(const nvbench::state &exec_state) { - this->log(nvbench::log_level::run, exec_state.get_short_description(m_color)); + if (m_total_state_count == 0) + { // No progress info + this->log(nvbench::log_level::run, + exec_state.get_short_description(m_color)); + } + else + { // Add progress + this->log(nvbench::log_level::run, + fmt::format("[{}/{}] {}", + m_completed_state_count + 1, + m_total_state_count, + exec_state.get_short_description(m_color))); + } } void markdown_printer::do_print_benchmark_list( @@ -159,12 +169,7 @@ void markdown_printer::do_print_benchmark_list( for (const auto &bench_ptr : benches) { const auto &axes = bench_ptr->get_axes().get_axes(); - const std::size_t num_configs = nvbench::detail::transform_reduce( - axes.cbegin(), - axes.cend(), - std::size_t{1}, - std::multiplies<>{}, - [](const auto &axis_ptr) { return axis_ptr->get_size(); }); + const std::size_t num_configs = bench_ptr->get_config_count(); fmt::format_to(buffer, "## [{}] `{}` ({} configurations)\n\n", diff --git a/nvbench/markdown_printer.cuh b/nvbench/markdown_printer.cuh index 6f922ac..fa4fedc 100644 --- a/nvbench/markdown_printer.cuh +++ b/nvbench/markdown_printer.cuh @@ -50,7 +50,7 @@ struct markdown_printer : nvbench::printer_base [[nodiscard]] bool get_color() const { return m_color; } /*!@}*/ -private: +protected: // Virtual API from printer_base: void do_print_device_info() override; void do_print_log_preamble() override; diff --git a/nvbench/printer_base.cuh b/nvbench/printer_base.cuh index 19978dd..846b627 100644 --- a/nvbench/printer_base.cuh +++ b/nvbench/printer_base.cuh @@ -56,7 +56,7 @@ struct printer_base * Construct a new printer_base that will write to ostream. */ explicit printer_base(std::ostream &ostream); - ~printer_base(); + virtual ~printer_base(); // move-only printer_base(const printer_base &) = delete; @@ -113,10 +113,34 @@ struct printer_base this->do_print_benchmark_results(benches); } -protected: - std::ostream &m_ostream; + /*! + * Used to track progress for interactive progress display: + * + * - `completed_state_count`: Number of states with completed measurements. + * - `total_state_count`: Total number of states. + * @{ + */ + virtual void set_completed_state_count(std::size_t states) + { + this->do_set_completed_state_count(states); + } + virtual void add_completed_state() { this->do_add_completed_state(); } + [[nodiscard]] virtual std::size_t get_completed_state_count() const + { + return this->do_get_completed_state_count(); + } -private: + virtual void set_total_state_count(std::size_t states) + { + this->do_set_total_state_count(states); + } + [[nodiscard]] virtual std::size_t get_total_state_count() const + { + return this->do_get_total_state_count(); + } + /*!@}*/ + +protected: // Implementation hooks for subclasses: virtual void do_print_device_info() {} virtual void do_print_log_preamble() {} @@ -125,6 +149,18 @@ private: virtual void do_log_run_state(const nvbench::state &) {} virtual void do_print_benchmark_list(const benchmark_vector &) {} virtual void do_print_benchmark_results(const benchmark_vector &) {} + + virtual void do_set_completed_state_count(std::size_t states); + virtual void do_add_completed_state(); + [[nodiscard]] virtual std::size_t do_get_completed_state_count() const; + + virtual void do_set_total_state_count(std::size_t states); + [[nodiscard]] virtual std::size_t do_get_total_state_count() const; + + std::ostream &m_ostream; + + std::size_t m_completed_state_count{}; + std::size_t m_total_state_count{}; }; } // namespace nvbench diff --git a/nvbench/printer_base.cxx b/nvbench/printer_base.cxx index 8a6aa2b..5943bab 100644 --- a/nvbench/printer_base.cxx +++ b/nvbench/printer_base.cxx @@ -30,4 +30,26 @@ printer_base::printer_base(std::ostream &ostream) // Defined here to keep out of the header printer_base::~printer_base() = default; +void printer_base::do_set_completed_state_count(std::size_t states) +{ + m_completed_state_count = states; +} + +void printer_base::do_add_completed_state() { ++m_completed_state_count; } + +std::size_t printer_base::do_get_completed_state_count() const +{ + return m_completed_state_count; +} + +void printer_base::do_set_total_state_count(std::size_t states) +{ + m_total_state_count = states; +} + +std::size_t printer_base::do_get_total_state_count() const +{ + return m_total_state_count; +} + } // namespace nvbench diff --git a/nvbench/printer_multiplex.cuh b/nvbench/printer_multiplex.cuh index d34ceb2..c20151d 100644 --- a/nvbench/printer_multiplex.cuh +++ b/nvbench/printer_multiplex.cuh @@ -31,7 +31,6 @@ namespace nvbench */ struct printer_multiplex : nvbench::printer_base { - printer_multiplex(); template @@ -46,7 +45,7 @@ struct printer_multiplex : nvbench::printer_base return m_printers.size(); } -private: +protected: void do_print_device_info() override; void do_print_log_preamble() override; void do_print_log_epilogue() override; @@ -54,6 +53,9 @@ private: void do_log_run_state(const nvbench::state &) override; void do_print_benchmark_list(const benchmark_vector &benches) override; void do_print_benchmark_results(const benchmark_vector &benches) override; + void do_set_completed_state_count(std::size_t states) override; + void do_add_completed_state() override; + void do_set_total_state_count(std::size_t states) override; std::vector> m_printers; }; diff --git a/nvbench/printer_multiplex.cxx b/nvbench/printer_multiplex.cxx index 3e65a11..7ec2b0e 100644 --- a/nvbench/printer_multiplex.cxx +++ b/nvbench/printer_multiplex.cxx @@ -83,5 +83,31 @@ void printer_multiplex::do_print_benchmark_results( format_ptr->print_benchmark_results(benches); } } +void printer_multiplex::do_set_completed_state_count(std::size_t states) +{ + printer_base::do_set_completed_state_count(states); + for (auto &format_ptr : m_printers) + { + format_ptr->set_completed_state_count(states); + } +} + +void printer_multiplex::do_add_completed_state() +{ + printer_base::do_add_completed_state(); + for (auto &format_ptr : m_printers) + { + format_ptr->add_completed_state(); + } +} + +void printer_multiplex::do_set_total_state_count(std::size_t states) +{ + printer_base::do_set_total_state_count(states); + for (auto &format_ptr : m_printers) + { + format_ptr->set_total_state_count(states); + } +} } // namespace nvbench diff --git a/nvbench/runner.cuh b/nvbench/runner.cuh index 6edbf01..9435906 100644 --- a/nvbench/runner.cuh +++ b/nvbench/runner.cuh @@ -40,7 +40,8 @@ struct runner_base void handle_sampling_exception(const std::exception &e, nvbench::state &exec_state) const; - void announce_state(state &exec_state) const; + void run_state_prologue(state &exec_state) const; + void run_state_epilogue(state &exec_state) const; void print_skip_notification(nvbench::state &exec_state) const; @@ -98,7 +99,7 @@ private: if (cur_state.get_device() == device && cur_state.get_type_config_index() == type_config_index) { - self.announce_state(cur_state); + self.run_state_prologue(cur_state); try { kernel_generator{}(cur_state, type_config{}); @@ -111,6 +112,7 @@ private: { self.handle_sampling_exception(e, cur_state); } + self.run_state_epilogue(cur_state); } } diff --git a/nvbench/runner.cxx b/nvbench/runner.cxx index 5203f77..3aba964 100644 --- a/nvbench/runner.cxx +++ b/nvbench/runner.cxx @@ -59,7 +59,7 @@ void runner_base::handle_sampling_exception(const std::exception &e, } } -void runner_base::announce_state(nvbench::state &exec_state) const +void runner_base::run_state_prologue(nvbench::state &exec_state) const { // Log if a printer exists: if (auto printer_opt_ref = exec_state.get_benchmark().get_printer(); @@ -70,6 +70,18 @@ void runner_base::announce_state(nvbench::state &exec_state) const } } +void runner_base::run_state_epilogue(state &exec_state) const +{ + // Notify the printer that the state has completed:: + if (auto printer_opt_ref = exec_state.get_benchmark().get_printer(); + printer_opt_ref.has_value()) + { + auto &printer = printer_opt_ref.value().get(); + printer.add_completed_state(); + } +} + + void runner_base::print_skip_notification(state &exec_state) const { if (auto printer_opt_ref = exec_state.get_benchmark().get_printer();