Merge pull request #63 from allisonvacanti/fix_progress_display

Fix progress display for inactive type axis values.
This commit is contained in:
Allison Vacanti
2021-12-21 20:42:05 -05:00
committed by GitHub
5 changed files with 40 additions and 8 deletions

View File

@@ -62,13 +62,20 @@ benchmark_base &benchmark_base::add_device(int device_id)
std::size_t benchmark_base::get_config_count() const
{
return nvbench::detail::transform_reduce(m_axes.get_axes().cbegin(),
m_axes.get_axes().cend(),
std::size_t{1},
std::multiplies<>{},
[](const auto &axis_ptr) {
return axis_ptr->get_size();
});
return nvbench::detail::transform_reduce(
m_axes.get_axes().cbegin(),
m_axes.get_axes().cend(),
std::size_t{1},
std::multiplies<>{},
[](const auto &axis_ptr) {
if (const auto *type_axis_ptr =
dynamic_cast<const nvbench::type_axis *>(axis_ptr.get());
type_axis_ptr != nullptr)
{
return type_axis_ptr->get_active_count();
}
return axis_ptr->get_size();
});
}
} // namespace nvbench

View File

@@ -175,7 +175,8 @@ std::vector<std::string> parse_range_values(std::string_view range_spec,
nvbench::wrapped_type<std::string>)
{
NVBENCH_THROW(std::runtime_error,
"Cannot use range syntax for string axis specification: `{}`.",
"Cannot use range syntax for string or type axis "
"specification: `{}`.",
range_spec);
}

View File

@@ -47,6 +47,7 @@ struct type_axis final : public axis_base
[[nodiscard]] bool get_is_active(const std::string &input) const;
[[nodiscard]] bool get_is_active(std::size_t index) const;
[[nodiscard]] std::size_t get_active_count() const;
/**
* The index of this axis in the `benchmark`'s `type_axes` type list.

View File

@@ -52,6 +52,11 @@ bool type_axis::get_is_active(std::size_t idx) const
return m_mask.at(idx);
}
std::size_t type_axis::get_active_count() const
{
return std::count(m_mask.cbegin(), m_mask.cend(), true);
}
std::size_t type_axis::get_type_index(const std::string &input_string) const
{
auto it =

View File

@@ -286,6 +286,23 @@ void test_clone()
ASSERT(clone->get_states().empty());
}
void test_get_config_count()
{
lots_of_types_bench bench;
bench.set_type_axes_names({"Integer", "Float", "Other"});
bench.get_axes().get_type_axis(0).set_active_inputs({"I16", "I32"}); // 2, 2
bench.get_axes().get_type_axis(1).set_active_inputs({"F32", "F64"}); // 2, 4
bench.get_axes().get_type_axis(2).set_active_inputs({"bool"}); // 1, 4
bench.add_float64_axis("foo", {0.4, 2.3, 4.3}); // 3, 12
bench.add_int64_axis("bar", {4, 6, 15}); // 3, 36
bench.add_string_axis("baz", {"str", "ing"}); // 2, 72
bench.add_string_axis("baz", {"single"}); // 1, 72
ASSERT_MSG(bench.get_config_count() == 72,
"Got {}",
bench.get_config_count());
}
int main()
{
test_type_axes();
@@ -296,4 +313,5 @@ int main()
test_string_axes();
test_run();
test_clone();
test_get_config_count();
}