Merge pull request #54 from allisonvacanti/progress_display

Print progress in markdown log.
This commit is contained in:
Allison Vacanti
2021-12-20 11:44:50 -05:00
committed by GitHub
14 changed files with 153 additions and 26 deletions

View File

@@ -64,11 +64,6 @@ struct benchmark final : public benchmark_base
: benchmark_base(type_axes{})
{}
// Note that this inline virtual dtor may cause vtable issues if linking
// benchmark TUs together. That's not a likely scenario, so we'll deal with
// that if it comes up.
~benchmark() override = default;
private:
std::unique_ptr<benchmark_base> do_clone() const final
{

View File

@@ -145,6 +145,11 @@ struct benchmark_base
return m_axes;
}
// Computes the number of configs in the benchmark.
// Unlike get_states().size(), this method may be used prior to calling run().
[[nodiscard]] std::size_t get_config_count() const;
// Is empty until run() is called.
[[nodiscard]] const std::vector<nvbench::state> &get_states() const
{
return m_states;

View File

@@ -18,6 +18,8 @@
#include <nvbench/benchmark_base.cuh>
#include <nvbench/detail/transform_reduce.cuh>
namespace nvbench
{
@@ -58,4 +60,15 @@ benchmark_base &benchmark_base::add_device(int device_id)
return this->add_device(device_info{device_id});
}
std::size_t benchmark_base::get_config_count() const
{
return nvbench::detail::transform_reduce(m_axes.get_axes().cbegin(),
m_axes.get_axes().cend(),
std::size_t{1},
std::multiplies<>{},
[](const auto &axis_ptr) {
return axis_ptr->get_size();
});
}
} // namespace nvbench

View File

@@ -30,7 +30,7 @@ struct csv_printer : nvbench::printer_base
{
using printer_base::printer_base;
private:
protected:
// Virtual API from printer_base:
void do_print_benchmark_results(const benchmark_vector &benches) override;
};

View File

@@ -30,7 +30,7 @@ struct json_printer : nvbench::printer_base
{
using printer_base::printer_base;
private:
protected:
// Virtual API from printer_base:
void do_print_benchmark_results(const benchmark_vector &benches) override;
};

View File

@@ -63,6 +63,15 @@
printer.print_device_info(); \
printer.print_log_preamble(); \
auto &benchmarks = parser.get_benchmarks(); \
\
std::size_t total_states = 0; \
for (auto &bench_ptr : benchmarks) \
{ \
total_states += bench_ptr->get_config_count(); \
} \
printer.set_total_state_count(total_states); \
\
printer.set_completed_state_count(0); \
for (auto &bench_ptr : benchmarks) \
{ \
bench_ptr->set_printer(printer); \

View File

@@ -23,8 +23,6 @@
#include <nvbench/state.cuh>
#include <nvbench/summary.cuh>
#include <nvbench/detail/transform_reduce.cuh>
#include <nvbench/internal/markdown_table.cuh>
#include <fmt/color.h>
@@ -147,7 +145,19 @@ void markdown_printer::do_log(nvbench::log_level level, const std::string &msg)
void markdown_printer::do_log_run_state(const nvbench::state &exec_state)
{
this->log(nvbench::log_level::run, exec_state.get_short_description(m_color));
if (m_total_state_count == 0)
{ // No progress info
this->log(nvbench::log_level::run,
exec_state.get_short_description(m_color));
}
else
{ // Add progress
this->log(nvbench::log_level::run,
fmt::format("[{}/{}] {}",
m_completed_state_count + 1,
m_total_state_count,
exec_state.get_short_description(m_color)));
}
}
void markdown_printer::do_print_benchmark_list(
@@ -159,12 +169,7 @@ void markdown_printer::do_print_benchmark_list(
for (const auto &bench_ptr : benches)
{
const auto &axes = bench_ptr->get_axes().get_axes();
const std::size_t num_configs = nvbench::detail::transform_reduce(
axes.cbegin(),
axes.cend(),
std::size_t{1},
std::multiplies<>{},
[](const auto &axis_ptr) { return axis_ptr->get_size(); });
const std::size_t num_configs = bench_ptr->get_config_count();
fmt::format_to(buffer,
"## [{}] `{}` ({} configurations)\n\n",

View File

@@ -50,7 +50,7 @@ struct markdown_printer : nvbench::printer_base
[[nodiscard]] bool get_color() const { return m_color; }
/*!@}*/
private:
protected:
// Virtual API from printer_base:
void do_print_device_info() override;
void do_print_log_preamble() override;

View File

@@ -56,7 +56,7 @@ struct printer_base
* Construct a new printer_base that will write to ostream.
*/
explicit printer_base(std::ostream &ostream);
~printer_base();
virtual ~printer_base();
// move-only
printer_base(const printer_base &) = delete;
@@ -113,10 +113,34 @@ struct printer_base
this->do_print_benchmark_results(benches);
}
protected:
std::ostream &m_ostream;
/*!
* Used to track progress for interactive progress display:
*
* - `completed_state_count`: Number of states with completed measurements.
* - `total_state_count`: Total number of states.
* @{
*/
virtual void set_completed_state_count(std::size_t states)
{
this->do_set_completed_state_count(states);
}
virtual void add_completed_state() { this->do_add_completed_state(); }
[[nodiscard]] virtual std::size_t get_completed_state_count() const
{
return this->do_get_completed_state_count();
}
private:
virtual void set_total_state_count(std::size_t states)
{
this->do_set_total_state_count(states);
}
[[nodiscard]] virtual std::size_t get_total_state_count() const
{
return this->do_get_total_state_count();
}
/*!@}*/
protected:
// Implementation hooks for subclasses:
virtual void do_print_device_info() {}
virtual void do_print_log_preamble() {}
@@ -125,6 +149,18 @@ private:
virtual void do_log_run_state(const nvbench::state &) {}
virtual void do_print_benchmark_list(const benchmark_vector &) {}
virtual void do_print_benchmark_results(const benchmark_vector &) {}
virtual void do_set_completed_state_count(std::size_t states);
virtual void do_add_completed_state();
[[nodiscard]] virtual std::size_t do_get_completed_state_count() const;
virtual void do_set_total_state_count(std::size_t states);
[[nodiscard]] virtual std::size_t do_get_total_state_count() const;
std::ostream &m_ostream;
std::size_t m_completed_state_count{};
std::size_t m_total_state_count{};
};
} // namespace nvbench

View File

@@ -30,4 +30,26 @@ printer_base::printer_base(std::ostream &ostream)
// Defined here to keep <ostream> out of the header
printer_base::~printer_base() = default;
void printer_base::do_set_completed_state_count(std::size_t states)
{
m_completed_state_count = states;
}
void printer_base::do_add_completed_state() { ++m_completed_state_count; }
std::size_t printer_base::do_get_completed_state_count() const
{
return m_completed_state_count;
}
void printer_base::do_set_total_state_count(std::size_t states)
{
m_total_state_count = states;
}
std::size_t printer_base::do_get_total_state_count() const
{
return m_total_state_count;
}
} // namespace nvbench

View File

@@ -31,7 +31,6 @@ namespace nvbench
*/
struct printer_multiplex : nvbench::printer_base
{
printer_multiplex();
template <typename Format, typename... Ts>
@@ -46,7 +45,7 @@ struct printer_multiplex : nvbench::printer_base
return m_printers.size();
}
private:
protected:
void do_print_device_info() override;
void do_print_log_preamble() override;
void do_print_log_epilogue() override;
@@ -54,6 +53,9 @@ private:
void do_log_run_state(const nvbench::state &) override;
void do_print_benchmark_list(const benchmark_vector &benches) override;
void do_print_benchmark_results(const benchmark_vector &benches) override;
void do_set_completed_state_count(std::size_t states) override;
void do_add_completed_state() override;
void do_set_total_state_count(std::size_t states) override;
std::vector<std::unique_ptr<nvbench::printer_base>> m_printers;
};

View File

@@ -83,5 +83,31 @@ void printer_multiplex::do_print_benchmark_results(
format_ptr->print_benchmark_results(benches);
}
}
void printer_multiplex::do_set_completed_state_count(std::size_t states)
{
printer_base::do_set_completed_state_count(states);
for (auto &format_ptr : m_printers)
{
format_ptr->set_completed_state_count(states);
}
}
void printer_multiplex::do_add_completed_state()
{
printer_base::do_add_completed_state();
for (auto &format_ptr : m_printers)
{
format_ptr->add_completed_state();
}
}
void printer_multiplex::do_set_total_state_count(std::size_t states)
{
printer_base::do_set_total_state_count(states);
for (auto &format_ptr : m_printers)
{
format_ptr->set_total_state_count(states);
}
}
} // namespace nvbench

View File

@@ -40,7 +40,8 @@ struct runner_base
void handle_sampling_exception(const std::exception &e,
nvbench::state &exec_state) const;
void announce_state(state &exec_state) const;
void run_state_prologue(state &exec_state) const;
void run_state_epilogue(state &exec_state) const;
void print_skip_notification(nvbench::state &exec_state) const;
@@ -98,7 +99,7 @@ private:
if (cur_state.get_device() == device &&
cur_state.get_type_config_index() == type_config_index)
{
self.announce_state(cur_state);
self.run_state_prologue(cur_state);
try
{
kernel_generator{}(cur_state, type_config{});
@@ -111,6 +112,7 @@ private:
{
self.handle_sampling_exception(e, cur_state);
}
self.run_state_epilogue(cur_state);
}
}

View File

@@ -59,7 +59,7 @@ void runner_base::handle_sampling_exception(const std::exception &e,
}
}
void runner_base::announce_state(nvbench::state &exec_state) const
void runner_base::run_state_prologue(nvbench::state &exec_state) const
{
// Log if a printer exists:
if (auto printer_opt_ref = exec_state.get_benchmark().get_printer();
@@ -70,6 +70,18 @@ void runner_base::announce_state(nvbench::state &exec_state) const
}
}
void runner_base::run_state_epilogue(state &exec_state) const
{
// Notify the printer that the state has completed::
if (auto printer_opt_ref = exec_state.get_benchmark().get_printer();
printer_opt_ref.has_value())
{
auto &printer = printer_opt_ref.value().get();
printer.add_completed_state();
}
}
void runner_base::print_skip_notification(state &exec_state) const
{
if (auto printer_opt_ref = exec_state.get_benchmark().get_printer();