mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-03-14 20:27:24 +00:00
I have no idea what I am doing
This commit is contained in:
@@ -101,7 +101,19 @@ def parse_axis_filters(axis_args):
|
||||
value = value.strip()
|
||||
if not name or not value:
|
||||
raise ValueError("Axis filter must be NAME=VALUE: {}".format(axis_arg))
|
||||
display_value = value
|
||||
|
||||
values = []
|
||||
display_values = []
|
||||
if value.startswith("[") and value.endswith("]"):
|
||||
inner = value[1:-1].strip()
|
||||
if inner:
|
||||
values = [item.strip() for item in inner.split(",") if item.strip()]
|
||||
else:
|
||||
values = []
|
||||
else:
|
||||
values = [value]
|
||||
display_values = list(values)
|
||||
|
||||
if name.endswith("[pow2]"):
|
||||
name = name[: -len("[pow2]")].strip()
|
||||
if not name:
|
||||
@@ -109,18 +121,28 @@ def parse_axis_filters(axis_args):
|
||||
"Axis filter missing name before [pow2]: {}".format(axis_arg)
|
||||
)
|
||||
try:
|
||||
exponent = int(value)
|
||||
exponents = [int(v) for v in values]
|
||||
except ValueError as exc:
|
||||
raise ValueError(
|
||||
"Axis filter [pow2] value must be integer: {}".format(axis_arg)
|
||||
) from exc
|
||||
value = str(2**exponent)
|
||||
display_value = "2^{}".format(exponent)
|
||||
values = [str(2**exponent) for exponent in exponents]
|
||||
display_values = ["2^{}".format(exponent) for exponent in exponents]
|
||||
|
||||
if not values:
|
||||
raise ValueError(
|
||||
"Axis filter must specify at least one value: {}".format(axis_arg)
|
||||
)
|
||||
|
||||
if len(display_values) == 1:
|
||||
display = "{}={}".format(name, display_values[0])
|
||||
else:
|
||||
display = "{}=[{}]".format(name, ",".join(display_values))
|
||||
filters.append(
|
||||
{
|
||||
"name": name,
|
||||
"value": value,
|
||||
"display": "{}={}".format(name, display_value),
|
||||
"values": values,
|
||||
"display": display,
|
||||
}
|
||||
)
|
||||
return filters
|
||||
@@ -133,7 +155,7 @@ def matches_axis_filters(state, axis_filters):
|
||||
axis_values = state.get("axis_values") or []
|
||||
for axis_filter in axis_filters:
|
||||
filter_name = axis_filter["name"]
|
||||
filter_value = axis_filter["value"]
|
||||
filter_values = axis_filter["values"]
|
||||
matched = False
|
||||
for axis_value in axis_values:
|
||||
if axis_value.get("name") != filter_name:
|
||||
@@ -141,7 +163,7 @@ def matches_axis_filters(state, axis_filters):
|
||||
value = axis_value.get("value")
|
||||
if value is None:
|
||||
continue
|
||||
if str(value) == filter_value:
|
||||
if str(value) in filter_values:
|
||||
matched = True
|
||||
break
|
||||
if not matched:
|
||||
@@ -154,12 +176,16 @@ def strip_axis_filters_from_state_name(state_name, axis_filters):
|
||||
return state_name
|
||||
|
||||
tokens = state_name.split()
|
||||
tokens_to_remove = set(
|
||||
axis_filter["display"]
|
||||
filter_prefixes = set(
|
||||
"{}=".format(axis_filter["name"])
|
||||
for axis_filter in axis_filters
|
||||
if " " not in axis_filter["display"]
|
||||
if len(axis_filter["values"]) == 1
|
||||
)
|
||||
tokens = [token for token in tokens if token not in tokens_to_remove]
|
||||
tokens = [
|
||||
token
|
||||
for token in tokens
|
||||
if not any(token.startswith(prefix) for prefix in filter_prefixes)
|
||||
]
|
||||
return " ".join(tokens)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user