Address code rabbit review feedback

This commit is contained in:
Oleksandr Pavlyk
2026-05-12 15:22:57 -05:00
parent fae9dfca18
commit 151e463fb9
4 changed files with 119 additions and 4 deletions

View File

@@ -117,14 +117,18 @@ def parse_summary(summary: dict) -> BenchmarkResultSummary:
)
def get_state_summaries(state: dict) -> list[dict]:
return state.get("summaries") or []
def parse_summaries(state: dict) -> dict[str, BenchmarkResultSummary]:
return {
summary["tag"]: parse_summary(summary) for summary in state["summaries"] or []
summary["tag"]: parse_summary(summary) for summary in get_state_summaries(state)
}
def parse_binary_meta(state: dict, tag: str) -> tuple[int | None, str | None]:
summaries = state["summaries"]
summaries = get_state_summaries(state)
if not summaries:
return None, None
@@ -251,7 +255,17 @@ class SubBenchmarkState:
for axis in self.axis_values:
axis_name = axis["name"]
name = axes_names[axis_name]
value = axes_values[axis_name][axis["value"]]
axis_value_map = axes_values[axis_name]
if "value" in axis:
key = str(axis["value"])
value = axis_value_map.get(key, key)
else:
input_string = axis.get("input_string")
value = (
axis_value_map.get(input_string, input_string)
if input_string is not None
else ""
)
self.point[name] = value
def __repr__(self) -> str:

View File

@@ -210,7 +210,7 @@ def add_state_row(
table.add_cell(row, f"axis:{header}", header, value)
for summary in state.summaries.values():
if summary.hide is not None:
if summary.hide:
continue
header = summary.name if summary.name is not None else summary.tag
table.add_cell(row, summary.tag, header, format_summary(summary))

View File

@@ -375,6 +375,76 @@ def test_benchmark_result_accepts_axis_value_input_string():
assert state.point == {"Duration": "0"}
def test_benchmark_result_normalizes_axis_value_lookup_key():
result = results.SubBenchmarkResult(
{
"name": "num_blocks",
"axes": [
{
"name": "NumBlocks",
"type": "int64",
"flags": "",
"values": [
{
"input_string": "64",
"description": "",
"value": 64,
},
{
"input_string": "default",
"description": "",
"value": None,
},
],
}
],
"states": [
{
"name": "Device=0 NumBlocks=64",
"axis_values": [
{
"name": "NumBlocks",
"type": "int64",
"value": 64,
}
],
"summaries": [],
"is_skipped": False,
},
{
"name": "Device=0 NumBlocks=default",
"axis_values": [
{
"name": "NumBlocks",
"type": "int64",
"value": None,
}
],
"summaries": [],
"is_skipped": False,
},
{
"name": "Device=0 NumBlocks=64",
"axis_values": [
{
"name": "NumBlocks",
"type": "int64",
"input_string": "64",
}
],
"summaries": [],
"is_skipped": False,
},
],
},
"",
)
assert result.states[0].point == {"NumBlocks": "64"}
assert result.states[1].point == {"NumBlocks": "default"}
assert result.states[2].point == {"NumBlocks": "64"}
def test_benchmark_result_ignores_skipped_state_with_no_summaries():
result = results.SubBenchmarkResult(
{
@@ -414,6 +484,36 @@ def test_benchmark_result_ignores_skipped_state_with_no_summaries():
assert result.states[0].name() == "BlockSize[pow2]=6"
def test_benchmark_result_uses_empty_summaries_when_field_is_missing():
result = results.SubBenchmarkResult(
{
"name": "copy_sweep_grid_shape",
"axes": [block_size_axis(8)],
"states": [
{
"name": "Device=0 BlockSize=2^8",
"axis_values": [
{
"name": "BlockSize",
"type": "int64",
"value": "256",
}
],
"is_skipped": False,
},
],
},
"",
)
state = result.states[0]
assert state.name() == "BlockSize[pow2]=8"
assert state.summaries == {}
assert state.samples is None
assert state.frequencies is None
assert state.bw is None
def test_benchmark_result_uses_none_for_unavailable_samples(tmp_path):
json_fn = tmp_path / "result.json"
write_json(

View File

@@ -127,6 +127,7 @@ def write_result_json(path):
"tag": "nv/cold/bw/global/utilization",
"name": "BWUtil",
"hint": "percentage",
"hide": False,
"data": [
{
"name": "value",