Lazy-load nvbench-compare bulk timing data

Store JSON-bin sample time and frequency metadata in GpuTimingData instead of
reading the binary files during summary extraction.

Add Float32BinarySource and lazy cached accessors for samples and frequencies.
Use np.fromfile by default, but allow tests and alternate callers to inject a
float32 reader returning any buffer-compatible object convertible to the "<f4"
data type.

Treat optional bulk-data failures as unavailable evidence instead of aborting
comparison: unreadable files, invalid buffers, count mismatches, and mismatched
sample/frequency metadata now emit RuntimeWarning and return None.

Update nvbench_compare tests to verify lazy loading, cache reuse, injected
reader behavior, warning-based degradation, and count mismatch handling.
This commit is contained in:
Oleksandr Pavlyk
2026-06-02 15:55:02 -05:00
parent 6d34618dc5
commit db4db61596
2 changed files with 144 additions and 33 deletions

View File

@@ -7,10 +7,12 @@ import argparse
import math
import os
import sys
import warnings
from collections import Counter
from dataclasses import dataclass
from enum import Enum
from typing import Any, Mapping
from functools import cached_property
from typing import Any, Callable, Mapping
import jsondiff
import numpy as np
@@ -41,10 +43,34 @@ GPU_TIME_IR_RELATIVE_TAG = "nv/cold/time/gpu/ir/relative"
# Summary tags passed to extract_binary_meta to look up the value count and
# file name of the binary sample-time and sample-frequency data.
SAMPLE_TIMES_TAG = "nv/json/bin:nv/cold/sample_times"
SAMPLE_FREQUENCIES_TAG = "nv/json/freqs-bin:nv/cold/sample_freqs"
# The reader returns an object supporting the buffer protocol. Python 3.10 does
# not provide a standard Buffer type annotation.
# The single str argument is the name of the file to read.
Float32Reader = Callable[[str], object]
def read_float32_file(filename: str) -> object:
    """Load ``filename`` as a flat array of little-endian float32 values.

    This is the reader used when callers do not inject their own; it simply
    delegates to ``np.fromfile``.
    """
    little_endian_float32 = "<f4"
    return np.fromfile(filename, dtype=little_endian_float32)
# These dataclasses are treated as parsed value objects. frozen=True prevents
# accidental field reassignment but does not imply deep immutability.
@dataclass(frozen=True)
class Float32BinarySource:
    """Describes one binary float32 file and loads its contents on demand.

    Nothing is read at construction time. The first access to ``values``
    performs the read, and later accesses reuse the cached result.
    """

    # Number of float32 values the file is expected to contain.
    count: int
    # Binary file name; resolved against ``json_dir`` when the data is read.
    filename: str
    # Directory handed to resolve_binary_filename together with ``filename``.
    json_dir: str
    # Label inserted into warning messages, e.g. "sample time".
    description: str
    # Callable that loads the file; defaults to the np.fromfile-based reader.
    reader: Float32Reader = read_float32_file

    @cached_property
    def values(self) -> np.ndarray | None:
        """Return the values from the file, or ``None`` if they are unusable."""
        # cached_property writes its result straight into the instance
        # __dict__, so caching still works on a frozen dataclass.
        return read_float32_binary(
            count=self.count,
            filename=self.filename,
            json_dir=self.json_dir,
            description=self.description,
            reader=self.reader,
        )
@dataclass(frozen=True)
class GpuTimingData:
minimum: float | None
@@ -55,8 +81,20 @@ class GpuTimingData:
median: float | None
interquartile_range: float | None
interquartile_range_relative: float | None
samples: np.ndarray | None = None
frequencies: np.ndarray | None = None
sample_source: Float32BinarySource | None = None
frequency_source: Float32BinarySource | None = None
@cached_property
def samples(self) -> np.ndarray | None:
if self.sample_source is None:
return None
return self.sample_source.values
@cached_property
def frequencies(self) -> np.ndarray | None:
if self.frequency_source is None:
return None
return self.frequency_source.values
@dataclass(frozen=True)
@@ -342,45 +380,76 @@ def resolve_binary_filename(json_dir, binary_filename):
return json_relative_filename
def read_float32_binary(count, filename, json_dir):
if count is None or filename is None or json_dir is None:
return None
def warn_unavailable_bulk_data(description, message):
    """Emit a ``RuntimeWarning`` saying optional bulk data will be ignored.

    ``description`` names the kind of data (e.g. "sample time") and
    ``message`` explains why it cannot be used.
    """
    warning_text = (
        f"Could not use NVBench {description} data: {message}; "
        "treating it as unavailable"
    )
    # stacklevel=3 attributes the warning two frames above this helper, i.e.
    # to the caller of the function that invoked it.
    warnings.warn(warning_text, category=RuntimeWarning, stacklevel=3)
def read_float32_binary(count, filename, json_dir, description, reader):
filename = resolve_binary_filename(json_dir, filename)
try:
values = np.fromfile(filename, dtype="<f4")
except FileNotFoundError:
values = np.frombuffer(reader(filename), dtype="<f4")
except (BufferError, OSError, TypeError, ValueError) as exc:
warn_unavailable_bulk_data(description, f"failed to read {filename!r}: {exc}")
return None
if count != len(values):
raise ValueError(f"expected {count} values in {filename}, found {len(values)}")
warn_unavailable_bulk_data(
description,
f"expected {count} values in {filename!r}, found {len(values)}",
)
return None
return values
def extract_sample_times(summaries, json_dir):
sample_count, samples_filename = extract_binary_meta(summaries, SAMPLE_TIMES_TAG)
return read_float32_binary(sample_count, samples_filename, json_dir)
def extract_sample_frequencies(summaries, json_dir):
frequency_count, frequencies_filename = extract_binary_meta(
summaries, SAMPLE_FREQUENCIES_TAG
def extract_float32_binary_source(summaries, tag, json_dir, description, reader):
    """Build a lazy ``Float32BinarySource`` for the binary data listed under ``tag``.

    Returns ``None`` when the value count, the file name, or ``json_dir`` is
    missing. Also returns ``None``, after emitting a warning, when the
    recorded count is negative. No file is read here.
    """
    value_count, binary_filename = extract_binary_meta(summaries, tag)

    required = (value_count, binary_filename, json_dir)
    if any(item is None for item in required):
        return None

    if value_count < 0:
        warn_unavailable_bulk_data(
            description, f"negative value count {value_count}"
        )
        return None

    return Float32BinarySource(
        value_count,
        binary_filename,
        json_dir,
        description,
        reader,
    )
return read_float32_binary(frequency_count, frequencies_filename, json_dir)
def extract_gpu_timing_data(summaries, json_dir=None):
samples = extract_sample_times(summaries, json_dir)
frequencies = extract_sample_frequencies(summaries, json_dir)
def extract_sample_time_source(summaries, json_dir, reader):
    """Return the lazy source for sample-time data, or ``None`` if unavailable."""
    return extract_float32_binary_source(
        summaries=summaries,
        tag=SAMPLE_TIMES_TAG,
        json_dir=json_dir,
        description="sample time",
        reader=reader,
    )
def extract_sample_frequency_source(summaries, json_dir, reader):
    """Return the lazy source for sample-frequency data, or ``None`` if unavailable."""
    return extract_float32_binary_source(
        summaries=summaries,
        tag=SAMPLE_FREQUENCIES_TAG,
        json_dir=json_dir,
        description="sample frequency",
        reader=reader,
    )
def extract_gpu_timing_data(summaries, json_dir=None, float32_reader=read_float32_file):
sample_source = extract_sample_time_source(summaries, json_dir, float32_reader)
frequency_source = extract_sample_frequency_source(
summaries, json_dir, float32_reader
)
if (
samples is not None
and frequencies is not None
and len(samples) != len(frequencies)
sample_source is not None
and frequency_source is not None
and sample_source.count != frequency_source.count
):
raise ValueError(
f"sample count ({len(samples)}) does not match "
f"frequency count ({len(frequencies)})"
warn_unavailable_bulk_data(
"paired sample time and frequency",
f"sample count ({sample_source.count}) does not match "
f"frequency count ({frequency_source.count})",
)
sample_source = None
frequency_source = None
return GpuTimingData(
minimum=extract_summary_float(summaries, GPU_TIME_MIN_TAG),
@@ -397,8 +466,8 @@ def extract_gpu_timing_data(summaries, json_dir=None):
interquartile_range_relative=extract_summary_float(
summaries, GPU_TIME_IR_RELATIVE_TAG, null_value=math.inf
),
samples=samples,
frequencies=frequencies,
sample_source=sample_source,
frequency_source=frequency_source,
)