diff --git a/python/cuda/bench/__init__.py b/python/cuda/bench/__init__.py index c334384..7eb4fb0 100644 --- a/python/cuda/bench/__init__.py +++ b/python/cuda/bench/__init__.py @@ -18,7 +18,7 @@ import importlib import importlib.metadata import warnings -from ._bench_result import BenchResult, SubBenchResult, SubBenchState +from ._bench_result import BenchmarkResult, SubBenchResult, SubBenchState try: __version__ = importlib.metadata.version("cuda-bench") @@ -31,7 +31,7 @@ except Exception as e: ) -BenchResult.__module__ = __name__ +BenchmarkResult.__module__ = __name__ SubBenchResult.__module__ = __name__ SubBenchState.__module__ = __name__ @@ -51,7 +51,7 @@ _NVBENCH_TEST_EXPORTS = ( ) __all__ = [ - "BenchResult", + "BenchmarkResult", "SubBenchResult", "SubBenchState", *_NVBENCH_EXPORTS, diff --git a/python/cuda/bench/__init__.pyi b/python/cuda/bench/__init__.pyi index 55a3d77..8773d1a 100644 --- a/python/cuda/bench/__init__.pyi +++ b/python/cuda/bench/__init__.pyi @@ -34,6 +34,7 @@ from collections.abc import ( Sequence, ValuesView, ) +from os import PathLike from typing import ( Any, Optional, @@ -166,16 +167,21 @@ class SubBenchResult: self, estimator: Callable[[array, array], ResultT] ) -> dict[str, ResultT | None]: ... -class BenchResult: +class BenchmarkResult: metadata: Any subbenches: dict[str, SubBenchResult] def __init__( self, - json_fn: str | None = None, *, + json_path: str | PathLike[str], metadata: Any = None, - parse: bool = True, ) -> None: ... + @classmethod + def empty(cls, *, metadata: Any = None) -> Self: ... + @classmethod + def from_json( + cls, json_path: str | PathLike[str], *, metadata: Any = None + ) -> Self: ... def __len__(self) -> int: ... def __iter__(self) -> Iterator[str]: ... def __contains__(self, subbench_name: object) -> bool: ... diff --git a/python/cuda/bench/_bench_result.py b/python/cuda/bench/_bench_result.py index b83152f..6072041 100644 --- a/python/cuda/bench/_bench_result.py +++ b/python/cuda/bench/_bench_result.py @@ -21,14 +21,15 @@ import sys from collections.abc import ItemsView, Iterator, KeysView, ValuesView from typing import Any, Callable, TypeVar -__all__ = ["BenchResult", "SubBenchResult", "SubBenchState"] +__all__ = ["BenchmarkResult", "SubBenchResult", "SubBenchState"] ResultT = TypeVar("ResultT") +BenchmarkResultT = TypeVar("BenchmarkResultT", bound="BenchmarkResult") _SummaryValue = int | float | str _SummaryData = _SummaryValue | dict[str, _SummaryValue] -def read_json(filename: str) -> dict: +def read_json(filename: str | os.PathLike[str]) -> dict: with open(filename, "r", encoding="utf-8") as f: file_root = json.load(f) return file_root @@ -287,24 +288,41 @@ class SubBenchResult: return result -class BenchResult: +class BenchmarkResult: """Parsed result data from an NVBench JSON output file.""" def __init__( self, - json_fn: str | None = None, *, + json_path: str | os.PathLike[str], metadata: Any = None, - parse: bool = True, ): self.metadata = metadata self.subbenches: dict[str, SubBenchResult] = {} + self._parse_json(json_path) - if json_fn and parse: - json_dir = os.path.dirname(os.path.abspath(json_fn)) - for bench in read_json(json_fn)["benchmarks"]: - bench_name: str = bench["name"] - self.subbenches[bench_name] = SubBenchResult(bench, json_dir) + @classmethod + def empty(cls: type[BenchmarkResultT], *, metadata: Any = None) -> BenchmarkResultT: + result = cls.__new__(cls) + result.metadata = metadata + result.subbenches = {} + return result + + @classmethod + def from_json( + cls: type[BenchmarkResultT], + json_path: str | os.PathLike[str], + *, + metadata: Any = None, + ) -> BenchmarkResultT: + return cls(json_path=json_path, metadata=metadata) + + def _parse_json(self, json_path: str | os.PathLike[str]) -> None: + json_path = os.fspath(json_path) + json_dir = os.path.dirname(os.path.abspath(json_path)) + for bench in read_json(json_path)["benchmarks"]: + bench_name: str = bench["name"] + self.subbenches[bench_name] = SubBenchResult(bench, json_dir) def __repr__(self) -> str: return str(self.__dict__) diff --git a/python/test/test_bench_result.py b/python/test/test_bench_result.py index 91ffc3f..a01b413 100644 --- a/python/test/test_bench_result.py +++ b/python/test/test_bench_result.py @@ -6,7 +6,7 @@ import cuda.bench as bench import pytest -def test_bench_result_reads_jsonbin_relative_to_json_path(tmp_path): +def test_benchmark_result_reads_jsonbin_relative_to_json_path(tmp_path): bin_dir = tmp_path / "result.json-bin" bin_dir.mkdir() (bin_dir / "0.bin").write_bytes(struct.pack("<3f", 1.0, 2.0, 4.0)) @@ -98,10 +98,10 @@ def test_bench_result_reads_jsonbin_relative_to_json_path(tmp_path): ) metadata = {"returncode": 0, "elapsed_seconds": 0.25} - default_result = bench.BenchResult(str(json_fn)) - result = bench.BenchResult(str(json_fn), metadata=metadata) + default_result = bench.BenchmarkResult.from_json(json_fn) + result = bench.BenchmarkResult(json_path=json_fn, metadata=metadata) - assert bench.BenchResult.__module__ == bench.__name__ + assert bench.BenchmarkResult.__module__ == bench.__name__ assert default_result.metadata is None assert result.metadata is metadata subbench = result["copy"] @@ -154,14 +154,16 @@ def test_bench_result_reads_jsonbin_relative_to_json_path(tmp_path): result["missing"] -def test_bench_result_metadata_and_parse_are_keyword_only(): +def test_benchmark_result_json_path_is_required_keyword(): with pytest.raises(TypeError): - bench.BenchResult("", None) + bench.BenchmarkResult("result.json") with pytest.raises(TypeError): - bench.BenchResult("", None, False) + bench.BenchmarkResult(metadata=None) + with pytest.raises(TypeError): + bench.BenchmarkResult(json_path="result.json", parse=False) -def test_bench_result_parse_false_does_not_read_json(tmp_path): +def test_benchmark_result_empty_does_not_read_json(tmp_path): @dataclass class RunMetadata: returncode: int @@ -170,16 +172,18 @@ def test_bench_result_parse_false_does_not_read_json(tmp_path): metadata = RunMetadata(returncode=1, elapsed_seconds=0.25) missing_json = tmp_path / "missing.json" - result = bench.BenchResult(str(missing_json), metadata=metadata, parse=False) + result = bench.BenchmarkResult.empty(metadata=metadata) assert result.metadata is metadata assert result.subbenches == {} with pytest.raises(FileNotFoundError): - bench.BenchResult(str(missing_json), metadata=metadata) + bench.BenchmarkResult(json_path=missing_json, metadata=metadata) + with pytest.raises(FileNotFoundError): + bench.BenchmarkResult.from_json(json_path=missing_json, metadata=metadata) -def test_bench_result_accepts_no_axis_benchmark_with_recorded_binary_path( +def test_benchmark_result_accepts_no_axis_benchmark_with_recorded_binary_path( tmp_path, monkeypatch ): data_dir = tmp_path / "temp_data" @@ -247,7 +251,7 @@ def test_bench_result_accepts_no_axis_benchmark_with_recorded_binary_path( monkeypatch.chdir(tmp_path) - result = bench.BenchResult("temp_data/axes_run1.json") + result = bench.BenchmarkResult(json_path="temp_data/axes_run1.json") state = result.subbenches["simple"].states[0] assert state.name() == "Device=0" @@ -258,7 +262,7 @@ def test_bench_result_accepts_no_axis_benchmark_with_recorded_binary_path( assert list(state.frequencies) == pytest.approx([100.0, 400.0]) -def test_bench_result_accepts_axis_value_input_string(): +def test_benchmark_result_accepts_axis_value_input_string(): result = bench.SubBenchResult( { "name": "single_float64_axis", @@ -299,7 +303,7 @@ def test_bench_result_accepts_axis_value_input_string(): assert state.point == {"Duration": "0"} -def test_bench_result_ignores_skipped_state_with_no_summaries(): +def test_benchmark_result_ignores_skipped_state_with_no_summaries(): result = bench.SubBenchResult( { "name": "copy_sweep_grid_shape", @@ -356,7 +360,7 @@ def test_bench_result_ignores_skipped_state_with_no_summaries(): assert result.states[0].name() == "BlockSize[pow2]=6" -def test_bench_result_uses_none_for_unavailable_samples(tmp_path): +def test_benchmark_result_uses_none_for_unavailable_samples(tmp_path): json_fn = tmp_path / "result.json" json_fn.write_text( json.dumps( @@ -447,7 +451,7 @@ def test_bench_result_uses_none_for_unavailable_samples(tmp_path): encoding="utf-8", ) - result = bench.BenchResult(str(json_fn)) + result = bench.BenchmarkResult(json_path=json_fn) states = result.subbenches["copy"].states assert states[0].samples is None @@ -470,7 +474,7 @@ def test_bench_result_uses_none_for_unavailable_samples(tmp_path): } -def test_bench_result_rejects_mismatched_sample_and_frequency_counts(tmp_path): +def test_benchmark_result_rejects_mismatched_sample_and_frequency_counts(tmp_path): bin_dir = tmp_path / "result.json-bin" bin_dir.mkdir() (bin_dir / "0.bin").write_bytes(struct.pack("<3f", 1.0, 2.0, 4.0)) @@ -552,4 +556,4 @@ def test_bench_result_rejects_mismatched_sample_and_frequency_counts(tmp_path): ) with pytest.raises(ValueError, match="sample count .* frequency count"): - bench.BenchResult(str(json_fn)) + bench.BenchmarkResult(json_path=json_fn)