mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-03-14 20:27:24 +00:00
Merge pull request #54 from allisonvacanti/progress_display
Print progress in markdown log.
This commit is contained in:
@@ -64,11 +64,6 @@ struct benchmark final : public benchmark_base
|
||||
: benchmark_base(type_axes{})
|
||||
{}
|
||||
|
||||
// Note that this inline virtual dtor may cause vtable issues if linking
|
||||
// benchmark TUs together. That's not a likely scenario, so we'll deal with
|
||||
// that if it comes up.
|
||||
~benchmark() override = default;
|
||||
|
||||
private:
|
||||
std::unique_ptr<benchmark_base> do_clone() const final
|
||||
{
|
||||
|
||||
@@ -145,6 +145,11 @@ struct benchmark_base
|
||||
return m_axes;
|
||||
}
|
||||
|
||||
// Computes the number of configs in the benchmark.
|
||||
// Unlike get_states().size(), this method may be used prior to calling run().
|
||||
[[nodiscard]] std::size_t get_config_count() const;
|
||||
|
||||
// Is empty until run() is called.
|
||||
[[nodiscard]] const std::vector<nvbench::state> &get_states() const
|
||||
{
|
||||
return m_states;
|
||||
|
||||
@@ -18,6 +18,8 @@
|
||||
|
||||
#include <nvbench/benchmark_base.cuh>
|
||||
|
||||
#include <nvbench/detail/transform_reduce.cuh>
|
||||
|
||||
namespace nvbench
|
||||
{
|
||||
|
||||
@@ -58,4 +60,15 @@ benchmark_base &benchmark_base::add_device(int device_id)
|
||||
return this->add_device(device_info{device_id});
|
||||
}
|
||||
|
||||
std::size_t benchmark_base::get_config_count() const
|
||||
{
|
||||
return nvbench::detail::transform_reduce(m_axes.get_axes().cbegin(),
|
||||
m_axes.get_axes().cend(),
|
||||
std::size_t{1},
|
||||
std::multiplies<>{},
|
||||
[](const auto &axis_ptr) {
|
||||
return axis_ptr->get_size();
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace nvbench
|
||||
|
||||
@@ -30,7 +30,7 @@ struct csv_printer : nvbench::printer_base
|
||||
{
|
||||
using printer_base::printer_base;
|
||||
|
||||
private:
|
||||
protected:
|
||||
// Virtual API from printer_base:
|
||||
void do_print_benchmark_results(const benchmark_vector &benches) override;
|
||||
};
|
||||
|
||||
@@ -30,7 +30,7 @@ struct json_printer : nvbench::printer_base
|
||||
{
|
||||
using printer_base::printer_base;
|
||||
|
||||
private:
|
||||
protected:
|
||||
// Virtual API from printer_base:
|
||||
void do_print_benchmark_results(const benchmark_vector &benches) override;
|
||||
};
|
||||
|
||||
@@ -63,6 +63,15 @@
|
||||
printer.print_device_info(); \
|
||||
printer.print_log_preamble(); \
|
||||
auto &benchmarks = parser.get_benchmarks(); \
|
||||
\
|
||||
std::size_t total_states = 0; \
|
||||
for (auto &bench_ptr : benchmarks) \
|
||||
{ \
|
||||
total_states += bench_ptr->get_config_count(); \
|
||||
} \
|
||||
printer.set_total_state_count(total_states); \
|
||||
\
|
||||
printer.set_completed_state_count(0); \
|
||||
for (auto &bench_ptr : benchmarks) \
|
||||
{ \
|
||||
bench_ptr->set_printer(printer); \
|
||||
|
||||
@@ -23,8 +23,6 @@
|
||||
#include <nvbench/state.cuh>
|
||||
#include <nvbench/summary.cuh>
|
||||
|
||||
#include <nvbench/detail/transform_reduce.cuh>
|
||||
|
||||
#include <nvbench/internal/markdown_table.cuh>
|
||||
|
||||
#include <fmt/color.h>
|
||||
@@ -147,7 +145,19 @@ void markdown_printer::do_log(nvbench::log_level level, const std::string &msg)
|
||||
|
||||
void markdown_printer::do_log_run_state(const nvbench::state &exec_state)
|
||||
{
|
||||
this->log(nvbench::log_level::run, exec_state.get_short_description(m_color));
|
||||
if (m_total_state_count == 0)
|
||||
{ // No progress info
|
||||
this->log(nvbench::log_level::run,
|
||||
exec_state.get_short_description(m_color));
|
||||
}
|
||||
else
|
||||
{ // Add progress
|
||||
this->log(nvbench::log_level::run,
|
||||
fmt::format("[{}/{}] {}",
|
||||
m_completed_state_count + 1,
|
||||
m_total_state_count,
|
||||
exec_state.get_short_description(m_color)));
|
||||
}
|
||||
}
|
||||
|
||||
void markdown_printer::do_print_benchmark_list(
|
||||
@@ -159,12 +169,7 @@ void markdown_printer::do_print_benchmark_list(
|
||||
for (const auto &bench_ptr : benches)
|
||||
{
|
||||
const auto &axes = bench_ptr->get_axes().get_axes();
|
||||
const std::size_t num_configs = nvbench::detail::transform_reduce(
|
||||
axes.cbegin(),
|
||||
axes.cend(),
|
||||
std::size_t{1},
|
||||
std::multiplies<>{},
|
||||
[](const auto &axis_ptr) { return axis_ptr->get_size(); });
|
||||
const std::size_t num_configs = bench_ptr->get_config_count();
|
||||
|
||||
fmt::format_to(buffer,
|
||||
"## [{}] `{}` ({} configurations)\n\n",
|
||||
|
||||
@@ -50,7 +50,7 @@ struct markdown_printer : nvbench::printer_base
|
||||
[[nodiscard]] bool get_color() const { return m_color; }
|
||||
/*!@}*/
|
||||
|
||||
private:
|
||||
protected:
|
||||
// Virtual API from printer_base:
|
||||
void do_print_device_info() override;
|
||||
void do_print_log_preamble() override;
|
||||
|
||||
@@ -56,7 +56,7 @@ struct printer_base
|
||||
* Construct a new printer_base that will write to ostream.
|
||||
*/
|
||||
explicit printer_base(std::ostream &ostream);
|
||||
~printer_base();
|
||||
virtual ~printer_base();
|
||||
|
||||
// move-only
|
||||
printer_base(const printer_base &) = delete;
|
||||
@@ -113,10 +113,34 @@ struct printer_base
|
||||
this->do_print_benchmark_results(benches);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::ostream &m_ostream;
|
||||
/*!
|
||||
* Used to track progress for interactive progress display:
|
||||
*
|
||||
* - `completed_state_count`: Number of states with completed measurements.
|
||||
* - `total_state_count`: Total number of states.
|
||||
* @{
|
||||
*/
|
||||
virtual void set_completed_state_count(std::size_t states)
|
||||
{
|
||||
this->do_set_completed_state_count(states);
|
||||
}
|
||||
virtual void add_completed_state() { this->do_add_completed_state(); }
|
||||
[[nodiscard]] virtual std::size_t get_completed_state_count() const
|
||||
{
|
||||
return this->do_get_completed_state_count();
|
||||
}
|
||||
|
||||
private:
|
||||
virtual void set_total_state_count(std::size_t states)
|
||||
{
|
||||
this->do_set_total_state_count(states);
|
||||
}
|
||||
[[nodiscard]] virtual std::size_t get_total_state_count() const
|
||||
{
|
||||
return this->do_get_total_state_count();
|
||||
}
|
||||
/*!@}*/
|
||||
|
||||
protected:
|
||||
// Implementation hooks for subclasses:
|
||||
virtual void do_print_device_info() {}
|
||||
virtual void do_print_log_preamble() {}
|
||||
@@ -125,6 +149,18 @@ private:
|
||||
virtual void do_log_run_state(const nvbench::state &) {}
|
||||
virtual void do_print_benchmark_list(const benchmark_vector &) {}
|
||||
virtual void do_print_benchmark_results(const benchmark_vector &) {}
|
||||
|
||||
virtual void do_set_completed_state_count(std::size_t states);
|
||||
virtual void do_add_completed_state();
|
||||
[[nodiscard]] virtual std::size_t do_get_completed_state_count() const;
|
||||
|
||||
virtual void do_set_total_state_count(std::size_t states);
|
||||
[[nodiscard]] virtual std::size_t do_get_total_state_count() const;
|
||||
|
||||
std::ostream &m_ostream;
|
||||
|
||||
std::size_t m_completed_state_count{};
|
||||
std::size_t m_total_state_count{};
|
||||
};
|
||||
|
||||
} // namespace nvbench
|
||||
|
||||
@@ -30,4 +30,26 @@ printer_base::printer_base(std::ostream &ostream)
|
||||
// Defined here to keep <ostream> out of the header
|
||||
printer_base::~printer_base() = default;
|
||||
|
||||
void printer_base::do_set_completed_state_count(std::size_t states)
|
||||
{
|
||||
m_completed_state_count = states;
|
||||
}
|
||||
|
||||
void printer_base::do_add_completed_state() { ++m_completed_state_count; }
|
||||
|
||||
std::size_t printer_base::do_get_completed_state_count() const
|
||||
{
|
||||
return m_completed_state_count;
|
||||
}
|
||||
|
||||
void printer_base::do_set_total_state_count(std::size_t states)
|
||||
{
|
||||
m_total_state_count = states;
|
||||
}
|
||||
|
||||
std::size_t printer_base::do_get_total_state_count() const
|
||||
{
|
||||
return m_total_state_count;
|
||||
}
|
||||
|
||||
} // namespace nvbench
|
||||
|
||||
@@ -31,7 +31,6 @@ namespace nvbench
|
||||
*/
|
||||
struct printer_multiplex : nvbench::printer_base
|
||||
{
|
||||
|
||||
printer_multiplex();
|
||||
|
||||
template <typename Format, typename... Ts>
|
||||
@@ -46,7 +45,7 @@ struct printer_multiplex : nvbench::printer_base
|
||||
return m_printers.size();
|
||||
}
|
||||
|
||||
private:
|
||||
protected:
|
||||
void do_print_device_info() override;
|
||||
void do_print_log_preamble() override;
|
||||
void do_print_log_epilogue() override;
|
||||
@@ -54,6 +53,9 @@ private:
|
||||
void do_log_run_state(const nvbench::state &) override;
|
||||
void do_print_benchmark_list(const benchmark_vector &benches) override;
|
||||
void do_print_benchmark_results(const benchmark_vector &benches) override;
|
||||
void do_set_completed_state_count(std::size_t states) override;
|
||||
void do_add_completed_state() override;
|
||||
void do_set_total_state_count(std::size_t states) override;
|
||||
|
||||
std::vector<std::unique_ptr<nvbench::printer_base>> m_printers;
|
||||
};
|
||||
|
||||
@@ -83,5 +83,31 @@ void printer_multiplex::do_print_benchmark_results(
|
||||
format_ptr->print_benchmark_results(benches);
|
||||
}
|
||||
}
|
||||
void printer_multiplex::do_set_completed_state_count(std::size_t states)
|
||||
{
|
||||
printer_base::do_set_completed_state_count(states);
|
||||
for (auto &format_ptr : m_printers)
|
||||
{
|
||||
format_ptr->set_completed_state_count(states);
|
||||
}
|
||||
}
|
||||
|
||||
void printer_multiplex::do_add_completed_state()
|
||||
{
|
||||
printer_base::do_add_completed_state();
|
||||
for (auto &format_ptr : m_printers)
|
||||
{
|
||||
format_ptr->add_completed_state();
|
||||
}
|
||||
}
|
||||
|
||||
void printer_multiplex::do_set_total_state_count(std::size_t states)
|
||||
{
|
||||
printer_base::do_set_total_state_count(states);
|
||||
for (auto &format_ptr : m_printers)
|
||||
{
|
||||
format_ptr->set_total_state_count(states);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace nvbench
|
||||
|
||||
@@ -40,7 +40,8 @@ struct runner_base
|
||||
void handle_sampling_exception(const std::exception &e,
|
||||
nvbench::state &exec_state) const;
|
||||
|
||||
void announce_state(state &exec_state) const;
|
||||
void run_state_prologue(state &exec_state) const;
|
||||
void run_state_epilogue(state &exec_state) const;
|
||||
|
||||
void print_skip_notification(nvbench::state &exec_state) const;
|
||||
|
||||
@@ -98,7 +99,7 @@ private:
|
||||
if (cur_state.get_device() == device &&
|
||||
cur_state.get_type_config_index() == type_config_index)
|
||||
{
|
||||
self.announce_state(cur_state);
|
||||
self.run_state_prologue(cur_state);
|
||||
try
|
||||
{
|
||||
kernel_generator{}(cur_state, type_config{});
|
||||
@@ -111,6 +112,7 @@ private:
|
||||
{
|
||||
self.handle_sampling_exception(e, cur_state);
|
||||
}
|
||||
self.run_state_epilogue(cur_state);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ void runner_base::handle_sampling_exception(const std::exception &e,
|
||||
}
|
||||
}
|
||||
|
||||
void runner_base::announce_state(nvbench::state &exec_state) const
|
||||
void runner_base::run_state_prologue(nvbench::state &exec_state) const
|
||||
{
|
||||
// Log if a printer exists:
|
||||
if (auto printer_opt_ref = exec_state.get_benchmark().get_printer();
|
||||
@@ -70,6 +70,18 @@ void runner_base::announce_state(nvbench::state &exec_state) const
|
||||
}
|
||||
}
|
||||
|
||||
void runner_base::run_state_epilogue(state &exec_state) const
|
||||
{
|
||||
// Notify the printer that the state has completed::
|
||||
if (auto printer_opt_ref = exec_state.get_benchmark().get_printer();
|
||||
printer_opt_ref.has_value())
|
||||
{
|
||||
auto &printer = printer_opt_ref.value().get();
|
||||
printer.add_completed_state();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void runner_base::print_skip_notification(state &exec_state) const
|
||||
{
|
||||
if (auto printer_opt_ref = exec_state.get_benchmark().get_printer();
|
||||
|
||||
Reference in New Issue
Block a user