From c9ab8e2eb3f7eb637cb00393fcd705420aadeaa7 Mon Sep 17 00:00:00 2001 From: Allison Vacanti Date: Tue, 21 Dec 2021 20:36:52 -0500 Subject: [PATCH] Fix progress display for inactive type axis values. When type axis values were disabled they were still counted towards a benchmark's total number of configs. --- nvbench/benchmark_base.cxx | 21 ++++++++++++++------- nvbench/type_axis.cuh | 1 + nvbench/type_axis.cxx | 5 +++++ testing/benchmark.cu | 18 ++++++++++++++++++ 4 files changed, 38 insertions(+), 7 deletions(-) diff --git a/nvbench/benchmark_base.cxx b/nvbench/benchmark_base.cxx index b3cd649..9c749f5 100644 --- a/nvbench/benchmark_base.cxx +++ b/nvbench/benchmark_base.cxx @@ -62,13 +62,20 @@ benchmark_base &benchmark_base::add_device(int device_id) std::size_t benchmark_base::get_config_count() const { - return nvbench::detail::transform_reduce(m_axes.get_axes().cbegin(), - m_axes.get_axes().cend(), - std::size_t{1}, - std::multiplies<>{}, - [](const auto &axis_ptr) { - return axis_ptr->get_size(); - }); + return nvbench::detail::transform_reduce( + m_axes.get_axes().cbegin(), + m_axes.get_axes().cend(), + std::size_t{1}, + std::multiplies<>{}, + [](const auto &axis_ptr) { + if (const auto *type_axis_ptr = + dynamic_cast(axis_ptr.get()); + type_axis_ptr != nullptr) + { + return type_axis_ptr->get_active_count(); + } + return axis_ptr->get_size(); + }); } } // namespace nvbench diff --git a/nvbench/type_axis.cuh b/nvbench/type_axis.cuh index a71a9ea..2ee9144 100644 --- a/nvbench/type_axis.cuh +++ b/nvbench/type_axis.cuh @@ -47,6 +47,7 @@ struct type_axis final : public axis_base [[nodiscard]] bool get_is_active(const std::string &input) const; [[nodiscard]] bool get_is_active(std::size_t index) const; + [[nodiscard]] std::size_t get_active_count() const; /** * The index of this axis in the `benchmark`'s `type_axes` type list. diff --git a/nvbench/type_axis.cxx b/nvbench/type_axis.cxx index bca3533..2a4e628 100644 --- a/nvbench/type_axis.cxx +++ b/nvbench/type_axis.cxx @@ -52,6 +52,11 @@ bool type_axis::get_is_active(std::size_t idx) const return m_mask.at(idx); } +std::size_t type_axis::get_active_count() const +{ + return std::count(m_mask.cbegin(), m_mask.cend(), true); +} + std::size_t type_axis::get_type_index(const std::string &input_string) const { auto it = diff --git a/testing/benchmark.cu b/testing/benchmark.cu index 5685d16..947664d 100644 --- a/testing/benchmark.cu +++ b/testing/benchmark.cu @@ -286,6 +286,23 @@ void test_clone() ASSERT(clone->get_states().empty()); } +void test_get_config_count() +{ + lots_of_types_bench bench; + bench.set_type_axes_names({"Integer", "Float", "Other"}); + bench.get_axes().get_type_axis(0).set_active_inputs({"I16", "I32"}); // 2, 2 + bench.get_axes().get_type_axis(1).set_active_inputs({"F32", "F64"}); // 2, 4 + bench.get_axes().get_type_axis(2).set_active_inputs({"bool"}); // 1, 4 + bench.add_float64_axis("foo", {0.4, 2.3, 4.3}); // 3, 12 + bench.add_int64_axis("bar", {4, 6, 15}); // 3, 36 + bench.add_string_axis("baz", {"str", "ing"}); // 2, 72 + bench.add_string_axis("baz", {"single"}); // 1, 72 + + ASSERT_MSG(bench.get_config_count() == 72, + "Got {}", + bench.get_config_count()); +} + int main() { test_type_axes(); @@ -296,4 +313,5 @@ int main() test_string_axes(); test_run(); test_clone(); + test_get_config_count(); }