mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-05-13 17:55:39 +00:00
Improve Python BenchResult parsing and container APIs
Add arbitrary BenchResult metadata and explicit parse control, replacing
the previous code/elapsed fields. Make BenchResult subscriptable by
subbenchmark name and make SubBenchResult list-like over its states.
Extend SubBenchState parsing to expose summaries by tag, read paired
sample frequency data, return None for unavailable sample/frequency
files, and validate matching sample/frequency lengths.
Harden parsing for NVBench JSON output with no-axis benchmarks, null
axis_values, skipped states with null summaries, float axis input_string
lookups, and recorded sidecar binary paths.
Expand BenchResult tests to cover metadata, parse=False, sequence-style
access, frequency-aware centers, missing binary data, skipped states,
and mismatched sample/frequency counts.
Example usage:
```
import array, numpy as np, cuda.bench
r = cuda.bench.BenchResult("perf_data/axes_run1.json")
r["copy_sweep_grid_shape"].centers_with_frequencies(
lambda t, f: np.median(np.asarray(t)*np.asarray(f)))
```
This commit is contained in:
@@ -26,8 +26,21 @@
|
||||
# with definitions given here.
|
||||
|
||||
from array import array
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Optional, Self, SupportsFloat, SupportsInt, Union
|
||||
from collections.abc import Callable, Iterator, Sequence
|
||||
from typing import (
|
||||
Any,
|
||||
Optional,
|
||||
Self,
|
||||
SupportsFloat,
|
||||
SupportsInt,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
ResultT = TypeVar("ResultT")
|
||||
_SummaryValue = int | float | str
|
||||
_SummaryData = _SummaryValue | dict[str, _SummaryValue]
|
||||
|
||||
class CudaStream:
|
||||
def __cuda_stream__(self) -> tuple[int, int]: ...
|
||||
@@ -119,25 +132,47 @@ def run_all_benchmarks(argv: Sequence[str]) -> None: ...
|
||||
class NVBenchRuntimeError(RuntimeError): ...
|
||||
|
||||
class SubBenchState:
|
||||
samples: array
|
||||
state_name: str
|
||||
summaries: dict[str, _SummaryData]
|
||||
samples: array | None
|
||||
frequencies: array | None
|
||||
bw: float | None
|
||||
point: dict[str, str]
|
||||
def name(self) -> str: ...
|
||||
def center(self, estimator: Callable[[array], SupportsFloat]) -> SupportsFloat: ...
|
||||
def center(self, estimator: Callable[[array], ResultT]) -> ResultT | None: ...
|
||||
def center_with_frequencies(
|
||||
self, estimator: Callable[[array, array], ResultT]
|
||||
) -> ResultT | None: ...
|
||||
|
||||
class SubBenchResult:
|
||||
states: list[SubBenchState]
|
||||
def __len__(self) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self, state_index: int) -> SubBenchState: ...
|
||||
@overload
|
||||
def __getitem__(self, state_index: slice) -> list[SubBenchState]: ...
|
||||
def __iter__(self) -> Iterator[SubBenchState]: ...
|
||||
def centers(
|
||||
self, estimator: Callable[[array], SupportsFloat]
|
||||
) -> dict[str, SupportsFloat]: ...
|
||||
self, estimator: Callable[[array], ResultT]
|
||||
) -> dict[str, ResultT | None]: ...
|
||||
def centers_with_frequencies(
|
||||
self, estimator: Callable[[array, array], ResultT]
|
||||
) -> dict[str, ResultT | None]: ...
|
||||
|
||||
class BenchResult:
|
||||
code: int
|
||||
elapsed: float
|
||||
metadata: Any
|
||||
subbenches: dict[str, SubBenchResult]
|
||||
def __init__(
|
||||
self, json_fn: str, *, code: int = 0, elapsed: float = 0.0
|
||||
self,
|
||||
json_fn: str | None = None,
|
||||
*,
|
||||
metadata: Any = None,
|
||||
parse: bool = True,
|
||||
) -> None: ...
|
||||
def __getitem__(self, subbench_name: str) -> SubBenchResult: ...
|
||||
def centers(
|
||||
self, estimator: Callable[[array], SupportsFloat]
|
||||
) -> dict[str, dict[str, SupportsFloat]]: ...
|
||||
self, estimator: Callable[[array], ResultT]
|
||||
) -> dict[str, dict[str, ResultT | None]]: ...
|
||||
def centers_with_frequencies(
|
||||
self, estimator: Callable[[array, array], ResultT]
|
||||
) -> dict[str, dict[str, ResultT | None]]: ...
|
||||
|
||||
@@ -18,10 +18,15 @@ import array
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Callable, SupportsFloat
|
||||
from collections.abc import Iterator
|
||||
from typing import Any, Callable, TypeVar
|
||||
|
||||
__all__ = ["BenchResult", "SubBenchResult", "SubBenchState"]
|
||||
|
||||
ResultT = TypeVar("ResultT")
|
||||
_SummaryValue = int | float | str
|
||||
_SummaryData = _SummaryValue | dict[str, _SummaryValue]
|
||||
|
||||
|
||||
def read_json(filename: str) -> dict:
|
||||
with open(filename, "r", encoding="utf-8") as f:
|
||||
@@ -50,13 +55,41 @@ def extract_bw(summary: dict) -> float:
|
||||
return float(value_data["value"])
|
||||
|
||||
|
||||
def parse_samples_meta(state: dict) -> tuple[int | None, str | None]:
|
||||
def parse_summary_value(value_data: dict) -> _SummaryValue:
|
||||
value_type = value_data["type"]
|
||||
value = value_data["value"]
|
||||
if value_type == "int64":
|
||||
return int(value)
|
||||
if value_type == "float64":
|
||||
return float(value)
|
||||
if value_type == "string":
|
||||
return value
|
||||
raise ValueError(f"unsupported summary value type: {value_type}")
|
||||
|
||||
|
||||
def parse_summary_data(summary: dict) -> _SummaryData:
    """Parse the ``"data"`` entries of one summary.

    Each entry is converted with ``parse_summary_value`` and stored under its
    ``"name"``. A summary whose only entry is named ``"value"`` collapses to
    that bare value; every other summary is returned as a name -> value dict.
    """
    parsed = {}
    for entry in summary["data"]:
        entry_name = entry["name"]
        parsed[entry_name] = parse_summary_value(entry)

    # A lone "value" entry is the common scalar case; unwrap it.
    if list(parsed) == ["value"]:
        return parsed["value"]
    return parsed
|
||||
|
||||
|
||||
def parse_summaries(state: dict) -> dict[str, _SummaryData]:
    """Map every summary tag of a state to its parsed data.

    ``state["summaries"]`` may be ``None`` (NVBench writes a null summary
    list for states that produced no measurements, e.g. skipped states) or
    empty; both yield an empty dict instead of raising ``TypeError``.

    Raises:
        KeyError: if the state has no ``"summaries"`` key, or a summary has
            no ``"tag"``.
    """
    # Treat a null summary list like an empty one, matching how the binary
    # metadata lookup and the axis_values handling already behave.
    summaries = state["summaries"] or []
    return {summary["tag"]: parse_summary_data(summary) for summary in summaries}
|
||||
|
||||
|
||||
def parse_binary_meta(state: dict, tag: str) -> tuple[int | None, str | None]:
|
||||
summaries = state["summaries"]
|
||||
if not summaries:
|
||||
return None, None
|
||||
|
||||
summary = next(
|
||||
filter(lambda s: s["tag"] == "nv/json/bin:nv/cold/sample_times", summaries),
|
||||
filter(lambda s: s["tag"] == tag, summaries),
|
||||
None,
|
||||
)
|
||||
if not summary:
|
||||
@@ -67,40 +100,72 @@ def parse_samples_meta(state: dict) -> tuple[int | None, str | None]:
|
||||
return sample_count, sample_filename
|
||||
|
||||
|
||||
def resolve_sample_filename(json_dir: str, samples_filename: str) -> str:
|
||||
if os.path.isabs(samples_filename):
|
||||
return samples_filename
|
||||
return os.path.join(json_dir, samples_filename)
|
||||
def parse_samples_meta(state: dict) -> tuple[int | None, str | None]:
    """Return the ``(count, filename)`` pair recorded for the sample-times binary."""
    sample_times_tag = "nv/json/bin:nv/cold/sample_times"
    return parse_binary_meta(state, sample_times_tag)
|
||||
|
||||
|
||||
def parse_samples(state: dict, json_dir: str) -> array.array:
|
||||
"""Return the state's sample times as an array of float32 values."""
|
||||
sample_count, samples_filename = parse_samples_meta(state)
|
||||
if sample_count is None or samples_filename is None:
|
||||
return array.array("f", [])
|
||||
def parse_frequencies_meta(state: dict) -> tuple[int | None, str | None]:
    """Return the ``(count, filename)`` pair recorded for the sample-frequency binary."""
    sample_freqs_tag = "nv/json/freqs-bin:nv/cold/sample_freqs"
    return parse_binary_meta(state, sample_freqs_tag)
|
||||
|
||||
samples = array.array("f")
|
||||
if samples.itemsize != 4:
|
||||
|
||||
def resolve_binary_filename(json_dir: str, binary_filename: str) -> str:
    """Locate a sidecar binary file referenced from an NVBench JSON file.

    Absolute paths are returned unchanged. A relative path is tried, in
    order, against ``json_dir``, against the parent of ``json_dir``, and as
    given (i.e. relative to the current working directory); the first
    candidate that exists is returned. When none exists, the
    ``json_dir``-relative path is returned so the caller's open() reports
    that location.
    """
    if os.path.isabs(binary_filename):
        return binary_filename

    beside_json = os.path.join(json_dir, binary_filename)
    candidates = (
        beside_json,
        os.path.join(os.path.dirname(json_dir), binary_filename),
        binary_filename,
    )
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    return beside_json
|
||||
|
||||
|
||||
def parse_float32_binary(
|
||||
count: int | None, filename: str | None, json_dir: str
|
||||
) -> array.array | None:
|
||||
if count is None or filename is None:
|
||||
return None
|
||||
|
||||
values = array.array("f")
|
||||
if values.itemsize != 4:
|
||||
raise RuntimeError("array('f') is not a 32-bit float on this platform")
|
||||
|
||||
samples_filename = resolve_sample_filename(json_dir, samples_filename)
|
||||
with open(samples_filename, "rb") as f:
|
||||
size = os.fstat(f.fileno()).st_size
|
||||
if size % samples.itemsize:
|
||||
raise ValueError("file size is not a multiple of float size")
|
||||
filename = resolve_binary_filename(json_dir, filename)
|
||||
try:
|
||||
with open(filename, "rb") as f:
|
||||
size = os.fstat(f.fileno()).st_size
|
||||
if size % values.itemsize:
|
||||
raise ValueError("file size is not a multiple of float size")
|
||||
|
||||
samples.fromfile(f, size // samples.itemsize)
|
||||
values.fromfile(f, size // values.itemsize)
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
# Match np.fromfile(fn, "<f4"): little-endian float32.
|
||||
if sys.byteorder != "little":
|
||||
samples.byteswap()
|
||||
values.byteswap()
|
||||
|
||||
if sample_count != len(samples):
|
||||
raise ValueError(
|
||||
f"expected {sample_count} samples in {samples_filename}, "
|
||||
f"found {len(samples)}"
|
||||
)
|
||||
return samples
|
||||
if count != len(values):
|
||||
raise ValueError(f"expected {count} values in {filename}, found {len(values)}")
|
||||
return values
|
||||
|
||||
|
||||
def parse_samples(state: dict, json_dir: str) -> array.array | None:
    """Load the state's sample times; None when no sample data is available."""
    count, filename = parse_samples_meta(state)
    samples = parse_float32_binary(count, filename, json_dir)
    return samples
|
||||
|
||||
|
||||
def parse_frequencies(state: dict, json_dir: str) -> array.array | None:
    """Load the state's sample frequencies; None when no data is available."""
    count, filename = parse_frequencies_meta(state)
    frequencies = parse_float32_binary(count, filename, json_dir)
    return frequencies
|
||||
|
||||
|
||||
def parse_bw(state: dict) -> float | None:
|
||||
@@ -125,11 +190,23 @@ def get_axis_name(axis: dict) -> str:
|
||||
|
||||
class SubBenchState:
|
||||
def __init__(self, state: dict, axes_names: dict, axes_values: dict, json_dir: str):
|
||||
self.state_name = state["name"]
|
||||
self.summaries = parse_summaries(state)
|
||||
self.samples = parse_samples(state, json_dir)
|
||||
self.frequencies = parse_frequencies(state, json_dir)
|
||||
if (
|
||||
self.samples is not None
|
||||
and self.frequencies is not None
|
||||
and len(self.samples) != len(self.frequencies)
|
||||
):
|
||||
raise ValueError(
|
||||
f"sample count ({len(self.samples)}) does not match "
|
||||
f"frequency count ({len(self.frequencies)})"
|
||||
)
|
||||
self.bw = parse_bw(state)
|
||||
|
||||
self.point = {}
|
||||
for axis in state["axis_values"]:
|
||||
for axis in state["axis_values"] or []:
|
||||
axis_name = axis["name"]
|
||||
name = axes_names[axis_name]
|
||||
value = axes_values[axis_name][axis["value"]]
|
||||
@@ -139,27 +216,36 @@ class SubBenchState:
|
||||
return str(self.__dict__)
|
||||
|
||||
def name(self) -> str:
    """Return a display name for this state.

    States with axis values are named ``"axis=value"`` pairs joined by
    spaces, in ``self.point`` order; states without any axis values fall
    back to the recorded ``state_name``.
    """
    if self.point:
        pairs = (f"{k}={v}" for k, v in self.point.items())
        return " ".join(pairs)
    return self.state_name
|
||||
|
||||
def center(
|
||||
self, estimator: Callable[[array.array], SupportsFloat]
|
||||
) -> SupportsFloat:
|
||||
def center(self, estimator: Callable[[array.array], ResultT]) -> ResultT | None:
    """Apply ``estimator`` to the sample times.

    Returns whatever ``estimator`` returns, or None (without calling it)
    when this state has no samples.
    """
    samples = self.samples
    return None if samples is None else estimator(samples)
|
||||
|
||||
def center_with_frequencies(
    self, estimator: Callable[[array.array, array.array], ResultT]
) -> ResultT | None:
    """Apply ``estimator`` to the paired sample and frequency arrays.

    Returns ``estimator(samples, frequencies)``, or None (without calling
    it) when either array is unavailable.
    """
    has_pair = self.samples is not None and self.frequencies is not None
    if not has_pair:
        return None
    return estimator(self.samples, self.frequencies)
|
||||
|
||||
|
||||
class SubBenchResult:
|
||||
def __init__(self, bench: dict, json_dir: str):
|
||||
axes_names = {}
|
||||
axes_values = {}
|
||||
for axis in bench["axes"]:
|
||||
for axis in bench["axes"] or []:
|
||||
short_name = axis["name"]
|
||||
full_name = get_axis_name(axis)
|
||||
this_axis_values = {}
|
||||
for value in axis["values"]:
|
||||
input_string = value["input_string"]
|
||||
this_axis_values[input_string] = input_string
|
||||
if "value" in value:
|
||||
this_axis_values[str(value["value"])] = value["input_string"]
|
||||
else:
|
||||
this_axis_values[value["input_string"]] = value["input_string"]
|
||||
this_axis_values[str(value["value"])] = input_string
|
||||
axes_names[short_name] = full_name
|
||||
axes_values[short_name] = this_axis_values
|
||||
|
||||
@@ -173,37 +259,73 @@ class SubBenchResult:
|
||||
def __repr__(self) -> str:
|
||||
return str(self.__dict__)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.states)
|
||||
|
||||
def __getitem__(
|
||||
self, state_index: int | slice
|
||||
) -> SubBenchState | list[SubBenchState]:
|
||||
return self.states[state_index]
|
||||
|
||||
def __iter__(self) -> Iterator[SubBenchState]:
|
||||
return iter(self.states)
|
||||
|
||||
def centers(
|
||||
self, estimator: Callable[[array.array], SupportsFloat]
|
||||
) -> dict[str, SupportsFloat]:
|
||||
self, estimator: Callable[[array.array], ResultT]
|
||||
) -> dict[str, ResultT | None]:
|
||||
result = {}
|
||||
for state in self.states:
|
||||
result[state.name()] = state.center(estimator)
|
||||
return result
|
||||
|
||||
def centers_with_frequencies(
    self, estimator: Callable[[array.array, array.array], ResultT]
) -> dict[str, ResultT | None]:
    """Evaluate ``estimator`` on every state's samples and frequencies.

    Returns a dict keyed by each state's ``name()`` holding the result of
    that state's ``center_with_frequencies(estimator)``. States that share
    a name overwrite one another; the last one wins.
    """
    centers = {}
    for st in self.states:
        # Compute the value before the key, as the subscript assignment
        # in the original loop does.
        value = st.center_with_frequencies(estimator)
        centers[st.name()] = value
    return centers
|
||||
|
||||
|
||||
class BenchResult:
|
||||
"""Parsed result data from an NVBench JSON output file."""
|
||||
|
||||
def __init__(self, json_fn: str, *, code: int = 0, elapsed: float = 0.0):
|
||||
self.code = code
|
||||
self.elapsed = elapsed
|
||||
def __init__(
|
||||
self,
|
||||
json_fn: str | None = None,
|
||||
*,
|
||||
metadata: Any = None,
|
||||
parse: bool = True,
|
||||
):
|
||||
self.metadata = metadata
|
||||
self.subbenches: dict[str, SubBenchResult] = {}
|
||||
|
||||
if json_fn:
|
||||
if code == 0:
|
||||
json_dir = os.path.dirname(os.path.abspath(json_fn))
|
||||
for bench in read_json(json_fn)["benchmarks"]:
|
||||
bench_name: str = bench["name"]
|
||||
self.subbenches[bench_name] = SubBenchResult(bench, json_dir)
|
||||
if json_fn and parse:
|
||||
json_dir = os.path.dirname(os.path.abspath(json_fn))
|
||||
for bench in read_json(json_fn)["benchmarks"]:
|
||||
bench_name: str = bench["name"]
|
||||
self.subbenches[bench_name] = SubBenchResult(bench, json_dir)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return str(self.__dict__)
|
||||
|
||||
def __getitem__(self, subbench_name: str) -> SubBenchResult:
|
||||
return self.subbenches[subbench_name]
|
||||
|
||||
def centers(
|
||||
self, estimator: Callable[[array.array], SupportsFloat]
|
||||
) -> dict[str, dict[str, SupportsFloat]]:
|
||||
self, estimator: Callable[[array.array], ResultT]
|
||||
) -> dict[str, dict[str, ResultT | None]]:
|
||||
result = {}
|
||||
for subbench in self.subbenches:
|
||||
result[subbench] = self.subbenches[subbench].centers(estimator)
|
||||
return result
|
||||
|
||||
def centers_with_frequencies(
    self, estimator: Callable[[array.array, array.array], ResultT]
) -> dict[str, dict[str, ResultT | None]]:
    """Evaluate ``estimator`` across every subbenchmark.

    Returns a dict keyed by subbenchmark name; each value is whatever that
    subbenchmark's ``centers_with_frequencies(estimator)`` returns.
    """
    return {
        bench_name: subbench.centers_with_frequencies(estimator)
        for bench_name, subbench in self.subbenches.items()
    }
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import struct
|
||||
from dataclasses import dataclass
|
||||
|
||||
import cuda.bench as bench
|
||||
import pytest
|
||||
@@ -9,6 +10,9 @@ def test_bench_result_reads_jsonbin_relative_to_json_path(tmp_path):
|
||||
bin_dir = tmp_path / "result.json-bin"
|
||||
bin_dir.mkdir()
|
||||
(bin_dir / "0.bin").write_bytes(struct.pack("<3f", 1.0, 2.0, 4.0))
|
||||
freq_bin_dir = tmp_path / "result.json-freqs-bin"
|
||||
freq_bin_dir.mkdir()
|
||||
(freq_bin_dir / "0.bin").write_bytes(struct.pack("<3f", 100.0, 200.0, 400.0))
|
||||
|
||||
json_fn = tmp_path / "result.json"
|
||||
json_fn.write_text(
|
||||
@@ -67,6 +71,21 @@ def test_bench_result_reads_jsonbin_relative_to_json_path(tmp_path):
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"tag": "nv/json/freqs-bin:nv/cold/sample_freqs",
|
||||
"data": [
|
||||
{
|
||||
"name": "filename",
|
||||
"type": "string",
|
||||
"value": "result.json-freqs-bin/0.bin",
|
||||
},
|
||||
{
|
||||
"name": "size",
|
||||
"type": "int64",
|
||||
"value": "3",
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
"is_skipped": False,
|
||||
}
|
||||
@@ -78,24 +97,452 @@ def test_bench_result_reads_jsonbin_relative_to_json_path(tmp_path):
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
metadata = {"returncode": 0, "elapsed_seconds": 0.25}
|
||||
default_result = bench.BenchResult(str(json_fn))
|
||||
result = bench.BenchResult(str(json_fn), elapsed=0.25)
|
||||
result = bench.BenchResult(str(json_fn), metadata=metadata)
|
||||
|
||||
assert bench.BenchResult.__module__ == bench.__name__
|
||||
assert default_result.code == 0
|
||||
assert default_result.elapsed == 0.0
|
||||
assert result.code == 0
|
||||
assert result.elapsed == 0.25
|
||||
state = result.subbenches["copy"].states[0]
|
||||
assert default_result.metadata is None
|
||||
assert result.metadata is metadata
|
||||
subbench = result["copy"]
|
||||
state = subbench[0]
|
||||
assert len(subbench) == 1
|
||||
assert subbench[-1] is state
|
||||
assert subbench[:] == subbench.states
|
||||
assert list(subbench) == subbench.states
|
||||
with pytest.raises(IndexError):
|
||||
subbench[1]
|
||||
assert state.name() == "BlockSize[pow2]=8"
|
||||
assert state.bw == 0.75
|
||||
assert state.summaries["nv/cold/bw/global/utilization"] == pytest.approx(0.75)
|
||||
assert state.summaries["nv/json/bin:nv/cold/sample_times"] == {
|
||||
"filename": "result.json-bin/0.bin",
|
||||
"size": 3,
|
||||
}
|
||||
assert state.summaries["nv/json/freqs-bin:nv/cold/sample_freqs"] == {
|
||||
"filename": "result.json-freqs-bin/0.bin",
|
||||
"size": 3,
|
||||
}
|
||||
assert state.samples is not None
|
||||
assert list(state.samples) == pytest.approx([1.0, 2.0, 4.0])
|
||||
assert state.frequencies is not None
|
||||
assert list(state.frequencies) == pytest.approx([100.0, 200.0, 400.0])
|
||||
centers = result.centers(lambda samples: sum(samples) / len(samples))
|
||||
assert set(centers) == {"copy"}
|
||||
assert set(centers["copy"]) == {"BlockSize[pow2]=8"}
|
||||
assert centers["copy"]["BlockSize[pow2]=8"] == pytest.approx(7.0 / 3.0)
|
||||
|
||||
def weighted_mean(samples, frequencies):
|
||||
return sum(
|
||||
sample * frequency for sample, frequency in zip(samples, frequencies)
|
||||
) / sum(frequencies)
|
||||
|
||||
def test_bench_result_code_and_elapsed_are_keyword_only():
|
||||
weighted_centers = result.centers_with_frequencies(weighted_mean)
|
||||
assert set(weighted_centers) == {"copy"}
|
||||
assert set(weighted_centers["copy"]) == {"BlockSize[pow2]=8"}
|
||||
assert weighted_centers["copy"]["BlockSize[pow2]=8"] == pytest.approx(3.0)
|
||||
assert subbench is result.subbenches["copy"]
|
||||
assert subbench.centers_with_frequencies(weighted_mean) == weighted_centers["copy"]
|
||||
with pytest.raises(KeyError):
|
||||
result["missing"]
|
||||
|
||||
|
||||
def test_bench_result_metadata_and_parse_are_keyword_only():
|
||||
with pytest.raises(TypeError):
|
||||
bench.BenchResult("", 0, 0.0)
|
||||
bench.BenchResult("", None)
|
||||
with pytest.raises(TypeError):
|
||||
bench.BenchResult("", None, False)
|
||||
|
||||
|
||||
def test_bench_result_parse_false_does_not_read_json(tmp_path):
|
||||
@dataclass
|
||||
class RunMetadata:
|
||||
returncode: int
|
||||
elapsed_seconds: float
|
||||
|
||||
metadata = RunMetadata(returncode=1, elapsed_seconds=0.25)
|
||||
missing_json = tmp_path / "missing.json"
|
||||
|
||||
result = bench.BenchResult(str(missing_json), metadata=metadata, parse=False)
|
||||
|
||||
assert result.metadata is metadata
|
||||
assert result.subbenches == {}
|
||||
|
||||
with pytest.raises(FileNotFoundError):
|
||||
bench.BenchResult(str(missing_json), metadata=metadata)
|
||||
|
||||
|
||||
def test_bench_result_accepts_no_axis_benchmark_with_recorded_binary_path(
|
||||
tmp_path, monkeypatch
|
||||
):
|
||||
data_dir = tmp_path / "temp_data"
|
||||
data_dir.mkdir()
|
||||
bin_dir = data_dir / "axes_run1.json-bin"
|
||||
bin_dir.mkdir()
|
||||
(bin_dir / "0.bin").write_bytes(struct.pack("<2f", 1.0, 4.0))
|
||||
freq_bin_dir = data_dir / "axes_run1.json-freqs-bin"
|
||||
freq_bin_dir.mkdir()
|
||||
(freq_bin_dir / "0.bin").write_bytes(struct.pack("<2f", 100.0, 400.0))
|
||||
|
||||
json_fn = data_dir / "axes_run1.json"
|
||||
json_fn.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"benchmarks": [
|
||||
{
|
||||
"name": "simple",
|
||||
"axes": None,
|
||||
"states": [
|
||||
{
|
||||
"name": "Device=0",
|
||||
"axis_values": None,
|
||||
"summaries": [
|
||||
{
|
||||
"tag": "nv/json/bin:nv/cold/sample_times",
|
||||
"data": [
|
||||
{
|
||||
"name": "filename",
|
||||
"type": "string",
|
||||
"value": "temp_data/axes_run1.json-bin/0.bin",
|
||||
},
|
||||
{
|
||||
"name": "size",
|
||||
"type": "int64",
|
||||
"value": "2",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"tag": "nv/json/freqs-bin:nv/cold/sample_freqs",
|
||||
"data": [
|
||||
{
|
||||
"name": "filename",
|
||||
"type": "string",
|
||||
"value": "temp_data/axes_run1.json-freqs-bin/0.bin",
|
||||
},
|
||||
{
|
||||
"name": "size",
|
||||
"type": "int64",
|
||||
"value": "2",
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
"is_skipped": False,
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
result = bench.BenchResult("temp_data/axes_run1.json")
|
||||
|
||||
state = result.subbenches["simple"].states[0]
|
||||
assert state.name() == "Device=0"
|
||||
assert state.point == {}
|
||||
assert state.samples is not None
|
||||
assert list(state.samples) == pytest.approx([1.0, 4.0])
|
||||
assert state.frequencies is not None
|
||||
assert list(state.frequencies) == pytest.approx([100.0, 400.0])
|
||||
|
||||
|
||||
def test_bench_result_accepts_axis_value_input_string():
|
||||
result = bench.SubBenchResult(
|
||||
{
|
||||
"name": "single_float64_axis",
|
||||
"axes": [
|
||||
{
|
||||
"name": "Duration",
|
||||
"type": "float64",
|
||||
"flags": "",
|
||||
"values": [
|
||||
{
|
||||
"input_string": "0",
|
||||
"description": "",
|
||||
"value": 0.0,
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
"states": [
|
||||
{
|
||||
"name": "Device=0 Duration=0",
|
||||
"axis_values": [
|
||||
{
|
||||
"name": "Duration",
|
||||
"type": "float64",
|
||||
"value": "0",
|
||||
}
|
||||
],
|
||||
"summaries": [],
|
||||
"is_skipped": False,
|
||||
}
|
||||
],
|
||||
},
|
||||
"",
|
||||
)
|
||||
|
||||
state = result.states[0]
|
||||
assert state.name() == "Duration=0"
|
||||
assert state.point == {"Duration": "0"}
|
||||
|
||||
|
||||
def test_bench_result_ignores_skipped_state_with_no_summaries():
|
||||
result = bench.SubBenchResult(
|
||||
{
|
||||
"name": "copy_sweep_grid_shape",
|
||||
"axes": [
|
||||
{
|
||||
"name": "BlockSize",
|
||||
"type": "int64",
|
||||
"flags": "pow2",
|
||||
"values": [
|
||||
{
|
||||
"input_string": "6",
|
||||
"description": "2^6 = 64",
|
||||
"value": 64,
|
||||
},
|
||||
{
|
||||
"input_string": "8",
|
||||
"description": "2^8 = 256",
|
||||
"value": 256,
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"states": [
|
||||
{
|
||||
"name": "Device=0 BlockSize=2^8",
|
||||
"axis_values": [
|
||||
{
|
||||
"name": "BlockSize",
|
||||
"type": "int64",
|
||||
"value": "256",
|
||||
}
|
||||
],
|
||||
"summaries": None,
|
||||
"is_skipped": True,
|
||||
},
|
||||
{
|
||||
"name": "Device=0 BlockSize=2^6",
|
||||
"axis_values": [
|
||||
{
|
||||
"name": "BlockSize",
|
||||
"type": "int64",
|
||||
"value": "64",
|
||||
}
|
||||
],
|
||||
"summaries": [],
|
||||
"is_skipped": False,
|
||||
},
|
||||
],
|
||||
},
|
||||
"",
|
||||
)
|
||||
|
||||
assert len(result.states) == 1
|
||||
assert result.states[0].name() == "BlockSize[pow2]=6"
|
||||
|
||||
|
||||
def test_bench_result_uses_none_for_unavailable_samples(tmp_path):
|
||||
json_fn = tmp_path / "result.json"
|
||||
json_fn.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"benchmarks": [
|
||||
{
|
||||
"name": "copy",
|
||||
"axes": [
|
||||
{
|
||||
"name": "BlockSize",
|
||||
"type": "int64",
|
||||
"flags": "pow2",
|
||||
"values": [
|
||||
{
|
||||
"input_string": "8",
|
||||
"description": "2^8 = 256",
|
||||
"value": 256,
|
||||
},
|
||||
{
|
||||
"input_string": "9",
|
||||
"description": "2^9 = 512",
|
||||
"value": 512,
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"states": [
|
||||
{
|
||||
"name": "Device=0 BlockSize=2^8",
|
||||
"axis_values": [
|
||||
{
|
||||
"name": "BlockSize",
|
||||
"type": "int64",
|
||||
"value": "256",
|
||||
}
|
||||
],
|
||||
"summaries": [],
|
||||
"is_skipped": False,
|
||||
},
|
||||
{
|
||||
"name": "Device=0 BlockSize=2^9",
|
||||
"axis_values": [
|
||||
{
|
||||
"name": "BlockSize",
|
||||
"type": "int64",
|
||||
"value": "512",
|
||||
}
|
||||
],
|
||||
"summaries": [
|
||||
{
|
||||
"tag": "nv/json/bin:nv/cold/sample_times",
|
||||
"data": [
|
||||
{
|
||||
"name": "filename",
|
||||
"type": "string",
|
||||
"value": "result.json-bin/missing.bin",
|
||||
},
|
||||
{
|
||||
"name": "size",
|
||||
"type": "int64",
|
||||
"value": "3",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"tag": "nv/json/freqs-bin:nv/cold/sample_freqs",
|
||||
"data": [
|
||||
{
|
||||
"name": "filename",
|
||||
"type": "string",
|
||||
"value": "result.json-freqs-bin/missing.bin",
|
||||
},
|
||||
{
|
||||
"name": "size",
|
||||
"type": "int64",
|
||||
"value": "3",
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
"is_skipped": False,
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
result = bench.BenchResult(str(json_fn))
|
||||
|
||||
states = result.subbenches["copy"].states
|
||||
assert states[0].samples is None
|
||||
assert states[1].samples is None
|
||||
assert states[0].frequencies is None
|
||||
assert states[1].frequencies is None
|
||||
assert result.centers(lambda samples: pytest.fail("estimator should not run")) == {
|
||||
"copy": {
|
||||
"BlockSize[pow2]=8": None,
|
||||
"BlockSize[pow2]=9": None,
|
||||
}
|
||||
}
|
||||
assert result.centers_with_frequencies(
|
||||
lambda samples, frequencies: pytest.fail("estimator should not run")
|
||||
) == {
|
||||
"copy": {
|
||||
"BlockSize[pow2]=8": None,
|
||||
"BlockSize[pow2]=9": None,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def test_bench_result_rejects_mismatched_sample_and_frequency_counts(tmp_path):
|
||||
bin_dir = tmp_path / "result.json-bin"
|
||||
bin_dir.mkdir()
|
||||
(bin_dir / "0.bin").write_bytes(struct.pack("<3f", 1.0, 2.0, 4.0))
|
||||
freq_bin_dir = tmp_path / "result.json-freqs-bin"
|
||||
freq_bin_dir.mkdir()
|
||||
(freq_bin_dir / "0.bin").write_bytes(struct.pack("<2f", 100.0, 200.0))
|
||||
|
||||
json_fn = tmp_path / "result.json"
|
||||
json_fn.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"benchmarks": [
|
||||
{
|
||||
"name": "copy",
|
||||
"axes": [
|
||||
{
|
||||
"name": "BlockSize",
|
||||
"type": "int64",
|
||||
"flags": "pow2",
|
||||
"values": [
|
||||
{
|
||||
"input_string": "8",
|
||||
"description": "2^8 = 256",
|
||||
"value": 256,
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
"states": [
|
||||
{
|
||||
"name": "Device=0 BlockSize=2^8",
|
||||
"axis_values": [
|
||||
{
|
||||
"name": "BlockSize",
|
||||
"type": "int64",
|
||||
"value": "256",
|
||||
}
|
||||
],
|
||||
"summaries": [
|
||||
{
|
||||
"tag": "nv/json/bin:nv/cold/sample_times",
|
||||
"data": [
|
||||
{
|
||||
"name": "filename",
|
||||
"type": "string",
|
||||
"value": "result.json-bin/0.bin",
|
||||
},
|
||||
{
|
||||
"name": "size",
|
||||
"type": "int64",
|
||||
"value": "3",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"tag": "nv/json/freqs-bin:nv/cold/sample_freqs",
|
||||
"data": [
|
||||
{
|
||||
"name": "filename",
|
||||
"type": "string",
|
||||
"value": "result.json-freqs-bin/0.bin",
|
||||
},
|
||||
{
|
||||
"name": "size",
|
||||
"type": "int64",
|
||||
"value": "2",
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
"is_skipped": False,
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="sample count .* frequency count"):
|
||||
bench.BenchResult(str(json_fn))
|
||||
|
||||
Reference in New Issue
Block a user