diff --git a/nvbench/benchmark_base.cuh b/nvbench/benchmark_base.cuh index 1a143a3..939bbeb 100644 --- a/nvbench/benchmark_base.cuh +++ b/nvbench/benchmark_base.cuh @@ -33,9 +33,7 @@ #include #include -#include // reference_wrapper, ref #include -#include #include namespace nvbench @@ -56,9 +54,6 @@ struct runner; */ struct benchmark_base { - template - using optional_ref = std::optional>; - template explicit benchmark_base(TypeAxes type_axes) : m_axes(type_axes) @@ -157,11 +152,10 @@ struct benchmark_base void run() { this->do_run(); } void run_or_skip(bool &skip_remaining) { this->do_run_or_skip(skip_remaining); } - void set_printer(nvbench::printer_base &printer) { m_printer = std::ref(printer); } + void set_printer(nvbench::printer_base &printer) { m_printer_ptr = &printer; } + void clear_printer() { m_printer_ptr = nullptr; } - void clear_printer() { m_printer = std::nullopt; } - - [[nodiscard]] optional_ref get_printer() const { return m_printer; } + [[nodiscard]] nvbench::printer_base *get_printer() const { return m_printer_ptr; } /// Execute at least this many trials per measurement. @{ [[nodiscard]] nvbench::int64_t get_min_samples() const { return m_min_samples; } @@ -321,8 +315,6 @@ protected: std::vector m_devices; std::vector m_states; - optional_ref m_printer; - bool m_is_cpu_only{false}; bool m_run_once{false}; bool m_disable_blocking_kernel{false}; @@ -340,6 +332,8 @@ protected: std::string m_stopping_criterion{}; private: + nvbench::printer_base *m_printer_ptr{}; + // route these through virtuals so the templated subclass can inject type info virtual std::unique_ptr do_clone() const = 0; virtual void do_set_type_axes_names(std::vector names) = 0; diff --git a/nvbench/benchmark_base.cxx b/nvbench/benchmark_base.cxx index 7d6f06f..a237e5c 100644 --- a/nvbench/benchmark_base.cxx +++ b/nvbench/benchmark_base.cxx @@ -37,7 +37,7 @@ std::unique_ptr benchmark_base::clone() const result->m_axes = m_axes; result->m_devices = m_devices; - result->m_printer = m_printer; + result->m_printer_ptr = m_printer_ptr; result->m_is_cpu_only = m_is_cpu_only; result->m_run_once = m_run_once; diff --git a/nvbench/detail/measure_cold.cu b/nvbench/detail/measure_cold.cu index b6a4c34..6c23548 100644 --- a/nvbench/detail/measure_cold.cu +++ b/nvbench/detail/measure_cold.cu @@ -115,9 +115,9 @@ void measure_cold_base::record_measurements() 0.5f); } - if (auto printer_opt_ref = m_state.get_benchmark().get_printer(); printer_opt_ref.has_value()) + if (auto printer_ptr = m_state.get_benchmark().get_printer()) { - auto &printer = printer_opt_ref.value().get(); + auto &printer = *printer_ptr; printer.log(nvbench::log_level::warn, fmt::format("GPU throttled below threshold ({:0.2f} MHz / {:0.2f} MHz) " "({:0.0f}% < {:0.0f}%) on sample {}. Discarding previous trial " @@ -386,9 +386,9 @@ void measure_cold_base::generate_summaries() } // Log if a printer exists: - if (auto printer_opt_ref = m_state.get_benchmark().get_printer(); printer_opt_ref.has_value()) + if (auto printer_ptr = m_state.get_benchmark().get_printer()) { - auto &printer = printer_opt_ref.value().get(); + auto &printer = *printer_ptr; if (m_max_time_exceeded) { diff --git a/nvbench/detail/measure_cpu_only.cxx b/nvbench/detail/measure_cpu_only.cxx index 4d92a05..5545f20 100644 --- a/nvbench/detail/measure_cpu_only.cxx +++ b/nvbench/detail/measure_cpu_only.cxx @@ -204,9 +204,9 @@ void measure_cpu_only_base::generate_summaries() } // Log if a printer exists: - if (auto printer_opt_ref = m_state.get_benchmark().get_printer(); printer_opt_ref.has_value()) + if (auto printer_ptr = m_state.get_benchmark().get_printer()) { - auto &printer = printer_opt_ref.value().get(); + auto &printer = *printer_ptr; if (m_max_time_exceeded) { diff --git a/nvbench/detail/measure_cupti.cu b/nvbench/detail/measure_cupti.cu index 24028f2..84f2105 100644 --- a/nvbench/detail/measure_cupti.cu +++ b/nvbench/detail/measure_cupti.cu @@ -171,9 +171,9 @@ try // clang-format on catch (const std::exception &ex) { - if (auto printer_opt_ref = exec_state.get_benchmark().get_printer(); printer_opt_ref) + if (auto printer_ptr = exec_state.get_benchmark().get_printer()) { - auto &printer = printer_opt_ref.value().get(); + auto &printer = *printer_ptr; printer.log(nvbench::log_level::warn, fmt::format("CUPTI failed to construct profiler: {}", ex.what())); } @@ -247,9 +247,9 @@ try } // Log if a printer exists: - if (auto printer_opt_ref = m_state.get_benchmark().get_printer(); printer_opt_ref.has_value()) + if (auto printer_ptr = m_state.get_benchmark().get_printer()) { - auto &printer = printer_opt_ref.value().get(); + auto &printer = *printer_ptr; printer.log(nvbench::log_level::pass, fmt::format("CUPTI: {:0.2f}s total wall, {}x", m_walltime_timer.get_duration(), @@ -258,9 +258,9 @@ try } catch (const std::exception &ex) { - if (auto printer_opt_ref = m_state.get_benchmark().get_printer(); printer_opt_ref) + if (auto printer_ptr = m_state.get_benchmark().get_printer()) { - auto &printer = printer_opt_ref.value().get(); + auto &printer = *printer_ptr; printer.log(nvbench::log_level::warn, fmt::format("CUPTI failed to generate the summary: {}", ex.what())); } diff --git a/nvbench/detail/measure_hot.cu b/nvbench/detail/measure_hot.cu index f44bb35..41156f5 100644 --- a/nvbench/detail/measure_hot.cu +++ b/nvbench/detail/measure_hot.cu @@ -112,9 +112,9 @@ void measure_hot_base::generate_summaries() } // Log if a printer exists: - if (auto printer_opt_ref = m_state.get_benchmark().get_printer(); printer_opt_ref.has_value()) + if (auto printer_ptr = m_state.get_benchmark().get_printer()) { - auto &printer = printer_opt_ref.value().get(); + auto &printer = *printer_ptr; // Warn if timed out: if (m_max_time_exceeded) diff --git a/nvbench/json_printer.cu b/nvbench/json_printer.cu index 4e17359..e363a45 100644 --- a/nvbench/json_printer.cu +++ b/nvbench/json_printer.cu @@ -247,9 +247,9 @@ void json_printer::do_process_bulk_data_float64(state &state, } catch (std::exception &e) { - if (auto printer_opt_ref = state.get_benchmark().get_printer(); printer_opt_ref.has_value()) + if (auto printer_ptr = state.get_benchmark().get_printer()) { - auto &printer = printer_opt_ref.value().get(); + auto &printer = *printer_ptr; printer.log( nvbench::log_level::warn, fmt::format("Error writing {} ({}) to {}: {}", tag, hint, result_path.string(), e.what())); @@ -267,9 +267,9 @@ void json_printer::do_process_bulk_data_float64(state &state, summ.set_string("hide", "Not needed in table."); timer.stop(); - if (auto printer_opt_ref = state.get_benchmark().get_printer(); printer_opt_ref.has_value()) + if (auto printer_ptr = state.get_benchmark().get_printer()) { - auto &printer = printer_opt_ref.value().get(); + auto &printer = *printer_ptr; printer.log( nvbench::log_level::info, fmt::format("Wrote '{}' in {:>6.3f}ms", result_path.string(), timer.get_duration() * 1000)); @@ -307,9 +307,9 @@ void json_printer::do_process_bulk_data_float64(state &state, } catch (std::exception &e) { - if (auto printer_opt_ref = state.get_benchmark().get_printer(); printer_opt_ref.has_value()) + if (auto printer_ptr = state.get_benchmark().get_printer()) { - auto &printer = printer_opt_ref.value().get(); + auto &printer = *printer_ptr; printer.log( nvbench::log_level::warn, fmt::format("Error writing {} ({}) to {}: {}", tag, hint, result_path.string(), e.what())); @@ -327,9 +327,9 @@ void json_printer::do_process_bulk_data_float64(state &state, summ.set_string("hide", "Not needed in table."); timer.stop(); - if (auto printer_opt_ref = state.get_benchmark().get_printer(); printer_opt_ref.has_value()) + if (auto printer_ptr = state.get_benchmark().get_printer()) { - auto &printer = printer_opt_ref.value().get(); + auto &printer = *printer_ptr; printer.log( nvbench::log_level::info, fmt::format("Wrote '{}' in {:>6.3f}ms", result_path.string(), timer.get_duration() * 1000)); diff --git a/nvbench/runner.cxx b/nvbench/runner.cxx index 09ddb46..aacd424 100644 --- a/nvbench/runner.cxx +++ b/nvbench/runner.cxx @@ -46,10 +46,9 @@ void runner_base::handle_sampling_exception(const std::exception &e, state &exec { const auto reason = fmt::format("Unexpected error: {}", e.what()); - if (auto printer_opt_ref = exec_state.get_benchmark().get_printer(); - printer_opt_ref.has_value()) + if (auto printer_ptr = exec_state.get_benchmark().get_printer()) { - auto &printer = printer_opt_ref.value().get(); + auto &printer = *printer_ptr; printer.log(nvbench::log_level::fail, reason); } @@ -60,9 +59,9 @@ void runner_base::handle_sampling_exception(const std::exception &e, state &exec void runner_base::run_state_prologue(nvbench::state &exec_state) const { // Log if a printer exists: - if (auto printer_opt_ref = exec_state.get_benchmark().get_printer(); printer_opt_ref.has_value()) + if (auto printer_ptr = exec_state.get_benchmark().get_printer()) { - auto &printer = printer_opt_ref.value().get(); + auto &printer = *printer_ptr; printer.log_run_state(exec_state); } } @@ -70,18 +69,18 @@ void runner_base::run_state_prologue(nvbench::state &exec_state) const void runner_base::run_state_epilogue(state &exec_state) const { // Notify the printer that the state has completed:: - if (auto printer_opt_ref = exec_state.get_benchmark().get_printer(); printer_opt_ref.has_value()) + if (auto printer_ptr = exec_state.get_benchmark().get_printer()) { - auto &printer = printer_opt_ref.value().get(); + auto &printer = *printer_ptr; printer.add_completed_state(); } } void runner_base::print_skip_notification(state &exec_state) const { - if (auto printer_opt_ref = exec_state.get_benchmark().get_printer(); printer_opt_ref.has_value()) + if (auto printer_ptr = exec_state.get_benchmark().get_printer()) { - auto &printer = printer_opt_ref.value().get(); + auto &printer = *printer_ptr; printer.log(nvbench::log_level::skip, exec_state.get_skip_reason()); } } diff --git a/nvbench/state.cuh b/nvbench/state.cuh index f639e27..aad5a2c 100644 --- a/nvbench/state.cuh +++ b/nvbench/state.cuh @@ -36,7 +36,6 @@ #include #include -#include #include #include #include @@ -246,7 +245,7 @@ struct state */ [[nodiscard]] std::string get_axis_values_as_string(bool color = false) const; - [[nodiscard]] const benchmark_base &get_benchmark() const { return m_benchmark; } + [[nodiscard]] const benchmark_base &get_benchmark() const { return *m_benchmark_ptr; }; void collect_l1_hit_rates() { m_collect_l1_hit_rates = true; } void collect_l2_hit_rates() { m_collect_l2_hit_rates = true; } @@ -318,7 +317,8 @@ private: [[nodiscard]] bool skip_hot_measurement() const { return get_run_once() || get_skip_batched(); } - std::reference_wrapper m_benchmark; + const nvbench::benchmark_base *m_benchmark_ptr; + nvbench::named_values m_axis_values; std::optional m_device; std::size_t m_type_config_index{}; diff --git a/nvbench/state.cxx b/nvbench/state.cxx index 2ae3c11..af53502 100644 --- a/nvbench/state.cxx +++ b/nvbench/state.cxx @@ -32,7 +32,7 @@ namespace nvbench { state::state(const benchmark_base &bench) - : m_benchmark{bench} + : m_benchmark_ptr{&bench} , m_is_cpu_only(bench.get_is_cpu_only()) , m_run_once{bench.get_run_once()} , m_disable_blocking_kernel{bench.get_disable_blocking_kernel()} @@ -50,7 +50,7 @@ state::state(const benchmark_base &bench, nvbench::named_values values, std::optional device, std::size_t type_config_index) - : m_benchmark{bench} + : m_benchmark_ptr{&bench} , m_axis_values{std::move(values)} , m_device{std::move(device)} , m_type_config_index{type_config_index} @@ -205,7 +205,7 @@ std::string state::get_axis_values_as_string(bool color) const append_key_value("Device", m_device->get_id()); } - const axes_metadata &axes = m_benchmark.get().get_axes(); + const axes_metadata &axes = m_benchmark_ptr->get_axes(); for (const auto &name : m_axis_values.get_names()) { const auto axis_type = m_axis_values.get_type(name); @@ -242,7 +242,7 @@ std::string state::get_short_description(bool color) const }; return fmt::format("{} [{}]", - fmt::format(style(fmt::emphasis::bold), "{}", m_benchmark.get().get_name()), + fmt::format(style(fmt::emphasis::bold), "{}", m_benchmark_ptr->get_name()), this->get_axis_values_as_string(color)); } diff --git a/python/src/py_nvbench.cpp b/python/src/py_nvbench.cpp index aeadcc2..8ecac4a 100644 --- a/python/src/py_nvbench.cpp +++ b/python/src/py_nvbench.cpp @@ -710,7 +710,7 @@ Returns True if configuration has a device // method State.has_printers auto method_has_printers_impl = [](const nvbench::state &state) -> bool { - return state.get_benchmark().get_printer().has_value(); + return state.get_benchmark().get_printer() != nullptr; }; static constexpr const char *method_has_printers_doc = R"XXXX( Returns True if configuration has a printer"