mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-05-13 17:55:39 +00:00
Address code rabbit review feedback
This commit is contained in:
@@ -117,14 +117,18 @@ def parse_summary(summary: dict) -> BenchmarkResultSummary:
|
||||
)
|
||||
|
||||
|
||||
def get_state_summaries(state: dict) -> list[dict]:
|
||||
return state.get("summaries") or []
|
||||
|
||||
|
||||
def parse_summaries(state: dict) -> dict[str, BenchmarkResultSummary]:
|
||||
return {
|
||||
summary["tag"]: parse_summary(summary) for summary in state["summaries"] or []
|
||||
summary["tag"]: parse_summary(summary) for summary in get_state_summaries(state)
|
||||
}
|
||||
|
||||
|
||||
def parse_binary_meta(state: dict, tag: str) -> tuple[int | None, str | None]:
|
||||
summaries = state["summaries"]
|
||||
summaries = get_state_summaries(state)
|
||||
if not summaries:
|
||||
return None, None
|
||||
|
||||
@@ -251,7 +255,17 @@ class SubBenchmarkState:
|
||||
for axis in self.axis_values:
|
||||
axis_name = axis["name"]
|
||||
name = axes_names[axis_name]
|
||||
value = axes_values[axis_name][axis["value"]]
|
||||
axis_value_map = axes_values[axis_name]
|
||||
if "value" in axis:
|
||||
key = str(axis["value"])
|
||||
value = axis_value_map.get(key, key)
|
||||
else:
|
||||
input_string = axis.get("input_string")
|
||||
value = (
|
||||
axis_value_map.get(input_string, input_string)
|
||||
if input_string is not None
|
||||
else ""
|
||||
)
|
||||
self.point[name] = value
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
||||
@@ -210,7 +210,7 @@ def add_state_row(
|
||||
table.add_cell(row, f"axis:{header}", header, value)
|
||||
|
||||
for summary in state.summaries.values():
|
||||
if summary.hide is not None:
|
||||
if summary.hide:
|
||||
continue
|
||||
header = summary.name if summary.name is not None else summary.tag
|
||||
table.add_cell(row, summary.tag, header, format_summary(summary))
|
||||
|
||||
@@ -375,6 +375,76 @@ def test_benchmark_result_accepts_axis_value_input_string():
|
||||
assert state.point == {"Duration": "0"}
|
||||
|
||||
|
||||
def test_benchmark_result_normalizes_axis_value_lookup_key():
|
||||
result = results.SubBenchmarkResult(
|
||||
{
|
||||
"name": "num_blocks",
|
||||
"axes": [
|
||||
{
|
||||
"name": "NumBlocks",
|
||||
"type": "int64",
|
||||
"flags": "",
|
||||
"values": [
|
||||
{
|
||||
"input_string": "64",
|
||||
"description": "",
|
||||
"value": 64,
|
||||
},
|
||||
{
|
||||
"input_string": "default",
|
||||
"description": "",
|
||||
"value": None,
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"states": [
|
||||
{
|
||||
"name": "Device=0 NumBlocks=64",
|
||||
"axis_values": [
|
||||
{
|
||||
"name": "NumBlocks",
|
||||
"type": "int64",
|
||||
"value": 64,
|
||||
}
|
||||
],
|
||||
"summaries": [],
|
||||
"is_skipped": False,
|
||||
},
|
||||
{
|
||||
"name": "Device=0 NumBlocks=default",
|
||||
"axis_values": [
|
||||
{
|
||||
"name": "NumBlocks",
|
||||
"type": "int64",
|
||||
"value": None,
|
||||
}
|
||||
],
|
||||
"summaries": [],
|
||||
"is_skipped": False,
|
||||
},
|
||||
{
|
||||
"name": "Device=0 NumBlocks=64",
|
||||
"axis_values": [
|
||||
{
|
||||
"name": "NumBlocks",
|
||||
"type": "int64",
|
||||
"input_string": "64",
|
||||
}
|
||||
],
|
||||
"summaries": [],
|
||||
"is_skipped": False,
|
||||
},
|
||||
],
|
||||
},
|
||||
"",
|
||||
)
|
||||
|
||||
assert result.states[0].point == {"NumBlocks": "64"}
|
||||
assert result.states[1].point == {"NumBlocks": "default"}
|
||||
assert result.states[2].point == {"NumBlocks": "64"}
|
||||
|
||||
|
||||
def test_benchmark_result_ignores_skipped_state_with_no_summaries():
|
||||
result = results.SubBenchmarkResult(
|
||||
{
|
||||
@@ -414,6 +484,36 @@ def test_benchmark_result_ignores_skipped_state_with_no_summaries():
|
||||
assert result.states[0].name() == "BlockSize[pow2]=6"
|
||||
|
||||
|
||||
def test_benchmark_result_uses_empty_summaries_when_field_is_missing():
|
||||
result = results.SubBenchmarkResult(
|
||||
{
|
||||
"name": "copy_sweep_grid_shape",
|
||||
"axes": [block_size_axis(8)],
|
||||
"states": [
|
||||
{
|
||||
"name": "Device=0 BlockSize=2^8",
|
||||
"axis_values": [
|
||||
{
|
||||
"name": "BlockSize",
|
||||
"type": "int64",
|
||||
"value": "256",
|
||||
}
|
||||
],
|
||||
"is_skipped": False,
|
||||
},
|
||||
],
|
||||
},
|
||||
"",
|
||||
)
|
||||
|
||||
state = result.states[0]
|
||||
assert state.name() == "BlockSize[pow2]=8"
|
||||
assert state.summaries == {}
|
||||
assert state.samples is None
|
||||
assert state.frequencies is None
|
||||
assert state.bw is None
|
||||
|
||||
|
||||
def test_benchmark_result_uses_none_for_unavailable_samples(tmp_path):
|
||||
json_fn = tmp_path / "result.json"
|
||||
write_json(
|
||||
|
||||
@@ -127,6 +127,7 @@ def write_result_json(path):
|
||||
"tag": "nv/cold/bw/global/utilization",
|
||||
"name": "BWUtil",
|
||||
"hint": "percentage",
|
||||
"hide": False,
|
||||
"data": [
|
||||
{
|
||||
"name": "value",
|
||||
|
||||
Reference in New Issue
Block a user