mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-04-30 03:31:13 +00:00
Merge pull request #54 from allisonvacanti/progress_display
Print progress in markdown log.
This commit is contained in:
@@ -64,11 +64,6 @@ struct benchmark final : public benchmark_base
|
|||||||
: benchmark_base(type_axes{})
|
: benchmark_base(type_axes{})
|
||||||
{}
|
{}
|
||||||
|
|
||||||
// Note that this inline virtual dtor may cause vtable issues if linking
|
|
||||||
// benchmark TUs together. That's not a likely scenario, so we'll deal with
|
|
||||||
// that if it comes up.
|
|
||||||
~benchmark() override = default;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<benchmark_base> do_clone() const final
|
std::unique_ptr<benchmark_base> do_clone() const final
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -145,6 +145,11 @@ struct benchmark_base
|
|||||||
return m_axes;
|
return m_axes;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Computes the number of configs in the benchmark.
|
||||||
|
// Unlike get_states().size(), this method may be used prior to calling run().
|
||||||
|
[[nodiscard]] std::size_t get_config_count() const;
|
||||||
|
|
||||||
|
// Is empty until run() is called.
|
||||||
[[nodiscard]] const std::vector<nvbench::state> &get_states() const
|
[[nodiscard]] const std::vector<nvbench::state> &get_states() const
|
||||||
{
|
{
|
||||||
return m_states;
|
return m_states;
|
||||||
|
|||||||
@@ -18,6 +18,8 @@
|
|||||||
|
|
||||||
#include <nvbench/benchmark_base.cuh>
|
#include <nvbench/benchmark_base.cuh>
|
||||||
|
|
||||||
|
#include <nvbench/detail/transform_reduce.cuh>
|
||||||
|
|
||||||
namespace nvbench
|
namespace nvbench
|
||||||
{
|
{
|
||||||
|
|
||||||
@@ -58,4 +60,15 @@ benchmark_base &benchmark_base::add_device(int device_id)
|
|||||||
return this->add_device(device_info{device_id});
|
return this->add_device(device_info{device_id});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::size_t benchmark_base::get_config_count() const
|
||||||
|
{
|
||||||
|
return nvbench::detail::transform_reduce(m_axes.get_axes().cbegin(),
|
||||||
|
m_axes.get_axes().cend(),
|
||||||
|
std::size_t{1},
|
||||||
|
std::multiplies<>{},
|
||||||
|
[](const auto &axis_ptr) {
|
||||||
|
return axis_ptr->get_size();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace nvbench
|
} // namespace nvbench
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ struct csv_printer : nvbench::printer_base
|
|||||||
{
|
{
|
||||||
using printer_base::printer_base;
|
using printer_base::printer_base;
|
||||||
|
|
||||||
private:
|
protected:
|
||||||
// Virtual API from printer_base:
|
// Virtual API from printer_base:
|
||||||
void do_print_benchmark_results(const benchmark_vector &benches) override;
|
void do_print_benchmark_results(const benchmark_vector &benches) override;
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ struct json_printer : nvbench::printer_base
|
|||||||
{
|
{
|
||||||
using printer_base::printer_base;
|
using printer_base::printer_base;
|
||||||
|
|
||||||
private:
|
protected:
|
||||||
// Virtual API from printer_base:
|
// Virtual API from printer_base:
|
||||||
void do_print_benchmark_results(const benchmark_vector &benches) override;
|
void do_print_benchmark_results(const benchmark_vector &benches) override;
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -63,6 +63,15 @@
|
|||||||
printer.print_device_info(); \
|
printer.print_device_info(); \
|
||||||
printer.print_log_preamble(); \
|
printer.print_log_preamble(); \
|
||||||
auto &benchmarks = parser.get_benchmarks(); \
|
auto &benchmarks = parser.get_benchmarks(); \
|
||||||
|
\
|
||||||
|
std::size_t total_states = 0; \
|
||||||
|
for (auto &bench_ptr : benchmarks) \
|
||||||
|
{ \
|
||||||
|
total_states += bench_ptr->get_config_count(); \
|
||||||
|
} \
|
||||||
|
printer.set_total_state_count(total_states); \
|
||||||
|
\
|
||||||
|
printer.set_completed_state_count(0); \
|
||||||
for (auto &bench_ptr : benchmarks) \
|
for (auto &bench_ptr : benchmarks) \
|
||||||
{ \
|
{ \
|
||||||
bench_ptr->set_printer(printer); \
|
bench_ptr->set_printer(printer); \
|
||||||
|
|||||||
@@ -23,8 +23,6 @@
|
|||||||
#include <nvbench/state.cuh>
|
#include <nvbench/state.cuh>
|
||||||
#include <nvbench/summary.cuh>
|
#include <nvbench/summary.cuh>
|
||||||
|
|
||||||
#include <nvbench/detail/transform_reduce.cuh>
|
|
||||||
|
|
||||||
#include <nvbench/internal/markdown_table.cuh>
|
#include <nvbench/internal/markdown_table.cuh>
|
||||||
|
|
||||||
#include <fmt/color.h>
|
#include <fmt/color.h>
|
||||||
@@ -147,7 +145,19 @@ void markdown_printer::do_log(nvbench::log_level level, const std::string &msg)
|
|||||||
|
|
||||||
void markdown_printer::do_log_run_state(const nvbench::state &exec_state)
|
void markdown_printer::do_log_run_state(const nvbench::state &exec_state)
|
||||||
{
|
{
|
||||||
this->log(nvbench::log_level::run, exec_state.get_short_description(m_color));
|
if (m_total_state_count == 0)
|
||||||
|
{ // No progress info
|
||||||
|
this->log(nvbench::log_level::run,
|
||||||
|
exec_state.get_short_description(m_color));
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{ // Add progress
|
||||||
|
this->log(nvbench::log_level::run,
|
||||||
|
fmt::format("[{}/{}] {}",
|
||||||
|
m_completed_state_count + 1,
|
||||||
|
m_total_state_count,
|
||||||
|
exec_state.get_short_description(m_color)));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void markdown_printer::do_print_benchmark_list(
|
void markdown_printer::do_print_benchmark_list(
|
||||||
@@ -159,12 +169,7 @@ void markdown_printer::do_print_benchmark_list(
|
|||||||
for (const auto &bench_ptr : benches)
|
for (const auto &bench_ptr : benches)
|
||||||
{
|
{
|
||||||
const auto &axes = bench_ptr->get_axes().get_axes();
|
const auto &axes = bench_ptr->get_axes().get_axes();
|
||||||
const std::size_t num_configs = nvbench::detail::transform_reduce(
|
const std::size_t num_configs = bench_ptr->get_config_count();
|
||||||
axes.cbegin(),
|
|
||||||
axes.cend(),
|
|
||||||
std::size_t{1},
|
|
||||||
std::multiplies<>{},
|
|
||||||
[](const auto &axis_ptr) { return axis_ptr->get_size(); });
|
|
||||||
|
|
||||||
fmt::format_to(buffer,
|
fmt::format_to(buffer,
|
||||||
"## [{}] `{}` ({} configurations)\n\n",
|
"## [{}] `{}` ({} configurations)\n\n",
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ struct markdown_printer : nvbench::printer_base
|
|||||||
[[nodiscard]] bool get_color() const { return m_color; }
|
[[nodiscard]] bool get_color() const { return m_color; }
|
||||||
/*!@}*/
|
/*!@}*/
|
||||||
|
|
||||||
private:
|
protected:
|
||||||
// Virtual API from printer_base:
|
// Virtual API from printer_base:
|
||||||
void do_print_device_info() override;
|
void do_print_device_info() override;
|
||||||
void do_print_log_preamble() override;
|
void do_print_log_preamble() override;
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ struct printer_base
|
|||||||
* Construct a new printer_base that will write to ostream.
|
* Construct a new printer_base that will write to ostream.
|
||||||
*/
|
*/
|
||||||
explicit printer_base(std::ostream &ostream);
|
explicit printer_base(std::ostream &ostream);
|
||||||
~printer_base();
|
virtual ~printer_base();
|
||||||
|
|
||||||
// move-only
|
// move-only
|
||||||
printer_base(const printer_base &) = delete;
|
printer_base(const printer_base &) = delete;
|
||||||
@@ -113,10 +113,34 @@ struct printer_base
|
|||||||
this->do_print_benchmark_results(benches);
|
this->do_print_benchmark_results(benches);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
/*!
|
||||||
std::ostream &m_ostream;
|
* Used to track progress for interactive progress display:
|
||||||
|
*
|
||||||
|
* - `completed_state_count`: Number of states with completed measurements.
|
||||||
|
* - `total_state_count`: Total number of states.
|
||||||
|
* @{
|
||||||
|
*/
|
||||||
|
virtual void set_completed_state_count(std::size_t states)
|
||||||
|
{
|
||||||
|
this->do_set_completed_state_count(states);
|
||||||
|
}
|
||||||
|
virtual void add_completed_state() { this->do_add_completed_state(); }
|
||||||
|
[[nodiscard]] virtual std::size_t get_completed_state_count() const
|
||||||
|
{
|
||||||
|
return this->do_get_completed_state_count();
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
virtual void set_total_state_count(std::size_t states)
|
||||||
|
{
|
||||||
|
this->do_set_total_state_count(states);
|
||||||
|
}
|
||||||
|
[[nodiscard]] virtual std::size_t get_total_state_count() const
|
||||||
|
{
|
||||||
|
return this->do_get_total_state_count();
|
||||||
|
}
|
||||||
|
/*!@}*/
|
||||||
|
|
||||||
|
protected:
|
||||||
// Implementation hooks for subclasses:
|
// Implementation hooks for subclasses:
|
||||||
virtual void do_print_device_info() {}
|
virtual void do_print_device_info() {}
|
||||||
virtual void do_print_log_preamble() {}
|
virtual void do_print_log_preamble() {}
|
||||||
@@ -125,6 +149,18 @@ private:
|
|||||||
virtual void do_log_run_state(const nvbench::state &) {}
|
virtual void do_log_run_state(const nvbench::state &) {}
|
||||||
virtual void do_print_benchmark_list(const benchmark_vector &) {}
|
virtual void do_print_benchmark_list(const benchmark_vector &) {}
|
||||||
virtual void do_print_benchmark_results(const benchmark_vector &) {}
|
virtual void do_print_benchmark_results(const benchmark_vector &) {}
|
||||||
|
|
||||||
|
virtual void do_set_completed_state_count(std::size_t states);
|
||||||
|
virtual void do_add_completed_state();
|
||||||
|
[[nodiscard]] virtual std::size_t do_get_completed_state_count() const;
|
||||||
|
|
||||||
|
virtual void do_set_total_state_count(std::size_t states);
|
||||||
|
[[nodiscard]] virtual std::size_t do_get_total_state_count() const;
|
||||||
|
|
||||||
|
std::ostream &m_ostream;
|
||||||
|
|
||||||
|
std::size_t m_completed_state_count{};
|
||||||
|
std::size_t m_total_state_count{};
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace nvbench
|
} // namespace nvbench
|
||||||
|
|||||||
@@ -30,4 +30,26 @@ printer_base::printer_base(std::ostream &ostream)
|
|||||||
// Defined here to keep <ostream> out of the header
|
// Defined here to keep <ostream> out of the header
|
||||||
printer_base::~printer_base() = default;
|
printer_base::~printer_base() = default;
|
||||||
|
|
||||||
|
void printer_base::do_set_completed_state_count(std::size_t states)
|
||||||
|
{
|
||||||
|
m_completed_state_count = states;
|
||||||
|
}
|
||||||
|
|
||||||
|
void printer_base::do_add_completed_state() { ++m_completed_state_count; }
|
||||||
|
|
||||||
|
std::size_t printer_base::do_get_completed_state_count() const
|
||||||
|
{
|
||||||
|
return m_completed_state_count;
|
||||||
|
}
|
||||||
|
|
||||||
|
void printer_base::do_set_total_state_count(std::size_t states)
|
||||||
|
{
|
||||||
|
m_total_state_count = states;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::size_t printer_base::do_get_total_state_count() const
|
||||||
|
{
|
||||||
|
return m_total_state_count;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace nvbench
|
} // namespace nvbench
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ namespace nvbench
|
|||||||
*/
|
*/
|
||||||
struct printer_multiplex : nvbench::printer_base
|
struct printer_multiplex : nvbench::printer_base
|
||||||
{
|
{
|
||||||
|
|
||||||
printer_multiplex();
|
printer_multiplex();
|
||||||
|
|
||||||
template <typename Format, typename... Ts>
|
template <typename Format, typename... Ts>
|
||||||
@@ -46,7 +45,7 @@ struct printer_multiplex : nvbench::printer_base
|
|||||||
return m_printers.size();
|
return m_printers.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
protected:
|
||||||
void do_print_device_info() override;
|
void do_print_device_info() override;
|
||||||
void do_print_log_preamble() override;
|
void do_print_log_preamble() override;
|
||||||
void do_print_log_epilogue() override;
|
void do_print_log_epilogue() override;
|
||||||
@@ -54,6 +53,9 @@ private:
|
|||||||
void do_log_run_state(const nvbench::state &) override;
|
void do_log_run_state(const nvbench::state &) override;
|
||||||
void do_print_benchmark_list(const benchmark_vector &benches) override;
|
void do_print_benchmark_list(const benchmark_vector &benches) override;
|
||||||
void do_print_benchmark_results(const benchmark_vector &benches) override;
|
void do_print_benchmark_results(const benchmark_vector &benches) override;
|
||||||
|
void do_set_completed_state_count(std::size_t states) override;
|
||||||
|
void do_add_completed_state() override;
|
||||||
|
void do_set_total_state_count(std::size_t states) override;
|
||||||
|
|
||||||
std::vector<std::unique_ptr<nvbench::printer_base>> m_printers;
|
std::vector<std::unique_ptr<nvbench::printer_base>> m_printers;
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -83,5 +83,31 @@ void printer_multiplex::do_print_benchmark_results(
|
|||||||
format_ptr->print_benchmark_results(benches);
|
format_ptr->print_benchmark_results(benches);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
void printer_multiplex::do_set_completed_state_count(std::size_t states)
|
||||||
|
{
|
||||||
|
printer_base::do_set_completed_state_count(states);
|
||||||
|
for (auto &format_ptr : m_printers)
|
||||||
|
{
|
||||||
|
format_ptr->set_completed_state_count(states);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void printer_multiplex::do_add_completed_state()
|
||||||
|
{
|
||||||
|
printer_base::do_add_completed_state();
|
||||||
|
for (auto &format_ptr : m_printers)
|
||||||
|
{
|
||||||
|
format_ptr->add_completed_state();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void printer_multiplex::do_set_total_state_count(std::size_t states)
|
||||||
|
{
|
||||||
|
printer_base::do_set_total_state_count(states);
|
||||||
|
for (auto &format_ptr : m_printers)
|
||||||
|
{
|
||||||
|
format_ptr->set_total_state_count(states);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace nvbench
|
} // namespace nvbench
|
||||||
|
|||||||
@@ -40,7 +40,8 @@ struct runner_base
|
|||||||
void handle_sampling_exception(const std::exception &e,
|
void handle_sampling_exception(const std::exception &e,
|
||||||
nvbench::state &exec_state) const;
|
nvbench::state &exec_state) const;
|
||||||
|
|
||||||
void announce_state(state &exec_state) const;
|
void run_state_prologue(state &exec_state) const;
|
||||||
|
void run_state_epilogue(state &exec_state) const;
|
||||||
|
|
||||||
void print_skip_notification(nvbench::state &exec_state) const;
|
void print_skip_notification(nvbench::state &exec_state) const;
|
||||||
|
|
||||||
@@ -98,7 +99,7 @@ private:
|
|||||||
if (cur_state.get_device() == device &&
|
if (cur_state.get_device() == device &&
|
||||||
cur_state.get_type_config_index() == type_config_index)
|
cur_state.get_type_config_index() == type_config_index)
|
||||||
{
|
{
|
||||||
self.announce_state(cur_state);
|
self.run_state_prologue(cur_state);
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
kernel_generator{}(cur_state, type_config{});
|
kernel_generator{}(cur_state, type_config{});
|
||||||
@@ -111,6 +112,7 @@ private:
|
|||||||
{
|
{
|
||||||
self.handle_sampling_exception(e, cur_state);
|
self.handle_sampling_exception(e, cur_state);
|
||||||
}
|
}
|
||||||
|
self.run_state_epilogue(cur_state);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ void runner_base::handle_sampling_exception(const std::exception &e,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void runner_base::announce_state(nvbench::state &exec_state) const
|
void runner_base::run_state_prologue(nvbench::state &exec_state) const
|
||||||
{
|
{
|
||||||
// Log if a printer exists:
|
// Log if a printer exists:
|
||||||
if (auto printer_opt_ref = exec_state.get_benchmark().get_printer();
|
if (auto printer_opt_ref = exec_state.get_benchmark().get_printer();
|
||||||
@@ -70,6 +70,18 @@ void runner_base::announce_state(nvbench::state &exec_state) const
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void runner_base::run_state_epilogue(state &exec_state) const
|
||||||
|
{
|
||||||
|
// Notify the printer that the state has completed::
|
||||||
|
if (auto printer_opt_ref = exec_state.get_benchmark().get_printer();
|
||||||
|
printer_opt_ref.has_value())
|
||||||
|
{
|
||||||
|
auto &printer = printer_opt_ref.value().get();
|
||||||
|
printer.add_completed_state();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
void runner_base::print_skip_notification(state &exec_state) const
|
void runner_base::print_skip_notification(state &exec_state) const
|
||||||
{
|
{
|
||||||
if (auto printer_opt_ref = exec_state.get_benchmark().get_printer();
|
if (auto printer_opt_ref = exec_state.get_benchmark().get_printer();
|
||||||
|
|||||||
Reference in New Issue
Block a user