diff --git a/python/cuda/bench/results/_benchmark_result.py b/python/cuda/bench/results/_benchmark_result.py index d0806c7..0bb8e04 100644 --- a/python/cuda/bench/results/_benchmark_result.py +++ b/python/cuda/bench/results/_benchmark_result.py @@ -117,14 +117,18 @@ def parse_summary(summary: dict) -> BenchmarkResultSummary: ) +def get_state_summaries(state: dict) -> list[dict]: + return state.get("summaries") or [] + + def parse_summaries(state: dict) -> dict[str, BenchmarkResultSummary]: return { - summary["tag"]: parse_summary(summary) for summary in state["summaries"] or [] + summary["tag"]: parse_summary(summary) for summary in get_state_summaries(state) } def parse_binary_meta(state: dict, tag: str) -> tuple[int | None, str | None]: - summaries = state["summaries"] + summaries = get_state_summaries(state) if not summaries: return None, None @@ -251,7 +255,17 @@ class SubBenchmarkState: for axis in self.axis_values: axis_name = axis["name"] name = axes_names[axis_name] - value = axes_values[axis_name][axis["value"]] + axis_value_map = axes_values[axis_name] + if "value" in axis: + key = str(axis["value"]) + value = axis_value_map.get(key, key) + else: + input_string = axis.get("input_string") + value = ( + axis_value_map.get(input_string, input_string) + if input_string is not None + else "" + ) self.point[name] = value def __repr__(self) -> str: diff --git a/python/scripts/nvbench_json_summary.py b/python/scripts/nvbench_json_summary.py index cea0292..d582a55 100644 --- a/python/scripts/nvbench_json_summary.py +++ b/python/scripts/nvbench_json_summary.py @@ -210,7 +210,7 @@ def add_state_row( table.add_cell(row, f"axis:{header}", header, value) for summary in state.summaries.values(): - if summary.hide is not None: + if summary.hide: continue header = summary.name if summary.name is not None else summary.tag table.add_cell(row, summary.tag, header, format_summary(summary)) diff --git a/python/test/test_benchmark_result.py b/python/test/test_benchmark_result.py index 95fa8f4..cbdfefb 100644 --- a/python/test/test_benchmark_result.py +++ b/python/test/test_benchmark_result.py @@ -375,6 +375,76 @@ def test_benchmark_result_accepts_axis_value_input_string(): assert state.point == {"Duration": "0"} +def test_benchmark_result_normalizes_axis_value_lookup_key(): + result = results.SubBenchmarkResult( + { + "name": "num_blocks", + "axes": [ + { + "name": "NumBlocks", + "type": "int64", + "flags": "", + "values": [ + { + "input_string": "64", + "description": "", + "value": 64, + }, + { + "input_string": "default", + "description": "", + "value": None, + }, + ], + } + ], + "states": [ + { + "name": "Device=0 NumBlocks=64", + "axis_values": [ + { + "name": "NumBlocks", + "type": "int64", + "value": 64, + } + ], + "summaries": [], + "is_skipped": False, + }, + { + "name": "Device=0 NumBlocks=default", + "axis_values": [ + { + "name": "NumBlocks", + "type": "int64", + "value": None, + } + ], + "summaries": [], + "is_skipped": False, + }, + { + "name": "Device=0 NumBlocks=64", + "axis_values": [ + { + "name": "NumBlocks", + "type": "int64", + "input_string": "64", + } + ], + "summaries": [], + "is_skipped": False, + }, + ], + }, + "", + ) + + assert result.states[0].point == {"NumBlocks": "64"} + assert result.states[1].point == {"NumBlocks": "default"} + assert result.states[2].point == {"NumBlocks": "64"} + + def test_benchmark_result_ignores_skipped_state_with_no_summaries(): result = results.SubBenchmarkResult( { @@ -414,6 +484,36 @@ def test_benchmark_result_ignores_skipped_state_with_no_summaries(): assert result.states[0].name() == "BlockSize[pow2]=6" +def test_benchmark_result_uses_empty_summaries_when_field_is_missing(): + result = results.SubBenchmarkResult( + { + "name": "copy_sweep_grid_shape", + "axes": [block_size_axis(8)], + "states": [ + { + "name": "Device=0 BlockSize=2^8", + "axis_values": [ + { + "name": "BlockSize", + "type": "int64", + "value": "256", + } + ], + "is_skipped": False, + }, + ], + }, + "", + ) + + state = result.states[0] + assert state.name() == "BlockSize[pow2]=8" + assert state.summaries == {} + assert state.samples is None + assert state.frequencies is None + assert state.bw is None + + def test_benchmark_result_uses_none_for_unavailable_samples(tmp_path): json_fn = tmp_path / "result.json" write_json( diff --git a/python/test/test_nvbench_json_summary.py b/python/test/test_nvbench_json_summary.py index d0e6b70..0507ae0 100644 --- a/python/test/test_nvbench_json_summary.py +++ b/python/test/test_nvbench_json_summary.py @@ -127,6 +127,7 @@ def write_result_json(path): "tag": "nv/cold/bw/global/utilization", "name": "BWUtil", "hint": "percentage", + "hide": False, "data": [ { "name": "value",