diff --git a/python/scripts/nvbench_plot.py b/python/scripts/nvbench_plot.py index 63330a0..4cdf4bd 100644 --- a/python/scripts/nvbench_plot.py +++ b/python/scripts/nvbench_plot.py @@ -101,7 +101,19 @@ def parse_axis_filters(axis_args): value = value.strip() if not name or not value: raise ValueError("Axis filter must be NAME=VALUE: {}".format(axis_arg)) - display_value = value + + values = [] + display_values = [] + if value.startswith("[") and value.endswith("]"): + inner = value[1:-1].strip() + if inner: + values = [item.strip() for item in inner.split(",") if item.strip()] + else: + values = [] + else: + values = [value] + display_values = list(values) + if name.endswith("[pow2]"): name = name[: -len("[pow2]")].strip() if not name: @@ -109,18 +121,28 @@ def parse_axis_filters(axis_args): "Axis filter missing name before [pow2]: {}".format(axis_arg) ) try: - exponent = int(value) + exponents = [int(v) for v in values] except ValueError as exc: raise ValueError( "Axis filter [pow2] value must be integer: {}".format(axis_arg) ) from exc - value = str(2**exponent) - display_value = "2^{}".format(exponent) + values = [str(2**exponent) for exponent in exponents] + display_values = ["2^{}".format(exponent) for exponent in exponents] + + if not values: + raise ValueError( + "Axis filter must specify at least one value: {}".format(axis_arg) + ) + + if len(display_values) == 1: + display = "{}={}".format(name, display_values[0]) + else: + display = "{}=[{}]".format(name, ",".join(display_values)) filters.append( { "name": name, - "value": value, - "display": "{}={}".format(name, display_value), + "values": values, + "display": display, } ) return filters @@ -133,7 +155,7 @@ def matches_axis_filters(state, axis_filters): axis_values = state.get("axis_values") or [] for axis_filter in axis_filters: filter_name = axis_filter["name"] - filter_value = axis_filter["value"] + filter_values = axis_filter["values"] matched = False for axis_value in axis_values: if axis_value.get("name") != filter_name: @@ -141,7 +163,7 @@ def matches_axis_filters(state, axis_filters): value = axis_value.get("value") if value is None: continue - if str(value) == filter_value: + if str(value) in filter_values: matched = True break if not matched: @@ -154,12 +176,16 @@ def strip_axis_filters_from_state_name(state_name, axis_filters): return state_name tokens = state_name.split() - tokens_to_remove = set( - axis_filter["display"] + filter_prefixes = set( + "{}=".format(axis_filter["name"]) for axis_filter in axis_filters - if " " not in axis_filter["display"] + if len(axis_filter["values"]) == 1 ) - tokens = [token for token in tokens if token not in tokens_to_remove] + tokens = [ + token + for token in tokens + if not any(token.startswith(prefix) for prefix in filter_prefixes) + ] return " ".join(tokens)