Track actively used devices in device_manager

This commit is contained in:
Allison Vacanti
2021-10-07 17:11:34 -04:00
parent c530fce382
commit fcb9277b1d
3 changed files with 48 additions and 0 deletions

View File

@@ -45,6 +45,15 @@ struct device_manager
return static_cast<int>(m_devices.size());
}
/**
* @return The number of devices actually used by all benchmarks.
* @note This is only valid after nvbench::option_parser::parse executes.
*/
[[nodiscard]] int get_number_of_used_devices() const
{
return static_cast<int>(m_used_devices.size());
}
/**
* @return The device_info object corresponding to `id`.
*/
@@ -62,10 +71,28 @@ struct device_manager
return m_devices;
}
/**
* @return A vector containing device_info objects for devices that are
* actively used by all benchmarks.
* @note This is only valid after nvbench::option_parser::parse executes.
*/
[[nodiscard]] const device_info_vector &get_used_devices() const
{
return m_used_devices;
}
private:
device_manager();
friend struct option_parser;
void set_used_devices(device_info_vector devices)
{
m_used_devices = std::move(devices);
}
device_info_vector m_devices;
device_info_vector m_used_devices;
};
} // namespace nvbench

View File

@@ -114,6 +114,8 @@ private:
void update_float64_prop(const std::string &prop_arg,
const std::string &prop_val);
void update_used_device_state() const;
// less gross argv:
std::vector<std::string> m_args;

View File

@@ -34,6 +34,7 @@
#include <fmt/format.h>
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <fstream>
@@ -343,6 +344,8 @@ void option_parser::parse_impl()
{
this->add_markdown_printer("stdout");
}
this->update_used_device_state();
}
void option_parser::parse_range(option_parser::arg_iterator_t first,
@@ -850,6 +853,22 @@ catch (std::exception &e)
e.what());
}
void option_parser::update_used_device_state() const
{
device_manager::device_info_vector devices;
for (const auto &bench : m_benchmarks)
{
const auto &bench_devs = bench->get_devices();
devices.insert(devices.end(), bench_devs.cbegin(), bench_devs.cend());
}
std::sort(devices.begin(), devices.end());
auto last = std::unique(devices.begin(), devices.end());
devices.erase(last, devices.end());
device_manager::get().set_used_devices(devices);
}
nvbench::printer_base &option_parser::get_printer() { return m_printer; }
} // namespace nvbench