diff --git a/python/scripts/nvbench_compare.py b/python/scripts/nvbench_compare.py index fee4ea3..4b81d03 100644 --- a/python/scripts/nvbench_compare.py +++ b/python/scripts/nvbench_compare.py @@ -79,6 +79,12 @@ def format_axis_value(axis_name, axis_value, axes): return format_string_axis_value(axis_name, axis_value, axes) +def make_display(name: str, display_values: [list[str]]) -> str: + open_bracket, close_bracket = ("[", "]") if len(display_values) > 1 else ("", "") + display_values = ",".join(display_values) + return f"{name}={open_bracket}{display_values}{close_bracket}" + + def parse_axis_filters(axis_args): filters = [] for axis_arg in axis_args: @@ -91,7 +97,6 @@ def parse_axis_filters(axis_args): raise ValueError("Axis filter must be NAME=VALUE: {}".format(axis_arg)) values = [] - display_values = [] if value.startswith("[") and value.endswith("]"): inner = value[1:-1].strip() values = [ @@ -121,10 +126,7 @@ def parse_axis_filters(axis_args): "Axis filter must specify at least one value: {}".format(axis_arg) ) - if len(display_values) == 1: - display = "{}={}".format(name, display_values[0]) - else: - display = "{}=[{}]".format(name, ",".join(display_values)) + display = make_display(name, display_values) filters.append( { "name": name,