mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-03-14 20:27:24 +00:00
More
This commit is contained in:
@@ -30,6 +30,11 @@ def parse_files():
|
||||
default=None,
|
||||
help="Optional plot title",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dark",
|
||||
action="store_true",
|
||||
help="Use dark theme (black background, white text)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-a",
|
||||
"--axis",
|
||||
@@ -96,6 +101,7 @@ def parse_axis_filters(axis_args):
|
||||
value = value.strip()
|
||||
if not name or not value:
|
||||
raise ValueError("Axis filter must be NAME=VALUE: {}".format(axis_arg))
|
||||
display_value = value
|
||||
if name.endswith("[pow2]"):
|
||||
name = name[: -len("[pow2]")].strip()
|
||||
if not name:
|
||||
@@ -109,7 +115,14 @@ def parse_axis_filters(axis_args):
|
||||
"Axis filter [pow2] value must be integer: {}".format(axis_arg)
|
||||
) from exc
|
||||
value = str(2**exponent)
|
||||
filters.append((name, value))
|
||||
display_value = "2^{}".format(exponent)
|
||||
filters.append(
|
||||
{
|
||||
"name": name,
|
||||
"value": value,
|
||||
"display": "{}={}".format(name, display_value),
|
||||
}
|
||||
)
|
||||
return filters
|
||||
|
||||
|
||||
@@ -118,7 +131,9 @@ def matches_axis_filters(state, axis_filters):
|
||||
return True
|
||||
|
||||
axis_values = state.get("axis_values") or []
|
||||
for filter_name, filter_value in axis_filters:
|
||||
for axis_filter in axis_filters:
|
||||
filter_name = axis_filter["name"]
|
||||
filter_value = axis_filter["value"]
|
||||
matched = False
|
||||
for axis_value in axis_values:
|
||||
if axis_value.get("name") != filter_name:
|
||||
@@ -134,6 +149,20 @@ def matches_axis_filters(state, axis_filters):
|
||||
return True
|
||||
|
||||
|
||||
def strip_axis_filters_from_state_name(state_name, axis_filters):
|
||||
if not axis_filters:
|
||||
return state_name
|
||||
|
||||
tokens = state_name.split()
|
||||
tokens_to_remove = set(
|
||||
axis_filter["display"]
|
||||
for axis_filter in axis_filters
|
||||
if " " not in axis_filter["display"]
|
||||
)
|
||||
tokens = [token for token in tokens if token not in tokens_to_remove]
|
||||
return " ".join(tokens)
|
||||
|
||||
|
||||
def collect_entries(filename, axis_filters, benchmark_filters):
|
||||
json_root = reader.read_file(filename)
|
||||
entries = []
|
||||
@@ -155,6 +184,7 @@ def collect_entries(filename, axis_filters, benchmark_filters):
|
||||
parts = state_name.split(" ", 1)
|
||||
if len(parts) == 2:
|
||||
state_name = parts[1]
|
||||
state_name = strip_axis_filters_from_state_name(state_name, axis_filters)
|
||||
label = "{} | {}".format(bench_name, state_name)
|
||||
device_name = devices.get(state.get("device"))
|
||||
if device_name:
|
||||
@@ -164,7 +194,7 @@ def collect_entries(filename, axis_filters, benchmark_filters):
|
||||
return entries, device_names
|
||||
|
||||
|
||||
def plot_entries(entries, title=None, output=None):
|
||||
def plot_entries(entries, title=None, output=None, dark=False):
|
||||
if not entries:
|
||||
print("No utilization data found.")
|
||||
return 1
|
||||
@@ -184,6 +214,15 @@ def plot_entries(entries, title=None, output=None):
|
||||
|
||||
fig_height = max(4.0, 0.3 * len(entries) + 1.5)
|
||||
fig, ax = plt.subplots(figsize=(10, fig_height))
|
||||
if dark:
|
||||
fig.patch.set_facecolor("black")
|
||||
ax.set_facecolor("black")
|
||||
ax.tick_params(colors="white")
|
||||
ax.xaxis.label.set_color("white")
|
||||
ax.yaxis.label.set_color("white")
|
||||
ax.title.set_color("white")
|
||||
for spine in ax.spines.values():
|
||||
spine.set_color("white")
|
||||
|
||||
y_pos = range(len(labels))
|
||||
ax.barh(y_pos, values, color=colors)
|
||||
@@ -192,7 +231,6 @@ def plot_entries(entries, title=None, output=None):
|
||||
ax.invert_yaxis()
|
||||
ax.set_ylim(len(labels) - 0.5, -0.5)
|
||||
|
||||
ax.set_xlabel("Global BW Utilization")
|
||||
ax.xaxis.set_major_formatter(PercentFormatter(1.0))
|
||||
|
||||
if title:
|
||||
@@ -231,8 +269,11 @@ def main():
|
||||
title = "%SOL Bandwidth"
|
||||
if len(device_names) == 1:
|
||||
title = "{} - {}".format(title, next(iter(device_names)))
|
||||
if axis_filters:
|
||||
axis_label = ", ".join(axis_filter["display"] for axis_filter in axis_filters)
|
||||
title = "{} ({})".format(title, axis_label)
|
||||
|
||||
return plot_entries(entries, title=title, output=args.output)
|
||||
return plot_entries(entries, title=title, output=args.output, dark=args.dark)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user