mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-06-30 19:27:34 +00:00
Lazy-load nvbench-compare bulk timing data
Store JSON-bin sample time and frequency metadata in GpuTimingData instead of reading the binary files during summary extraction. Add Float32BinarySource and lazy cached accessors for samples and frequencies. Use np.fromfile by default, but allow tests and alternate callers to inject a float32 reader returning any buffer-compatible object convertible to the "<f4" data type. Treat optional bulk-data failures as unavailable evidence instead of aborting comparison: unreadable files, invalid buffers, count mismatches, and mismatched sample/frequency metadata now emit RuntimeWarning and return None. Update nvbench_compare tests to verify lazy loading, cache reuse, injected reader behavior, warning-based degradation, and count mismatch handling.
This commit is contained in:
@@ -7,10 +7,12 @@ import argparse
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Mapping
|
||||
from functools import cached_property
|
||||
from typing import Any, Callable, Mapping
|
||||
|
||||
import jsondiff
|
||||
import numpy as np
|
||||
@@ -41,10 +43,34 @@ GPU_TIME_IR_RELATIVE_TAG = "nv/cold/time/gpu/ir/relative"
|
||||
SAMPLE_TIMES_TAG = "nv/json/bin:nv/cold/sample_times"
|
||||
SAMPLE_FREQUENCIES_TAG = "nv/json/freqs-bin:nv/cold/sample_freqs"
|
||||
|
||||
# The reader returns an object supporting the buffer protocol. Python 3.10 does
|
||||
# not provide a standard Buffer type annotation.
|
||||
Float32Reader = Callable[[str], object]
|
||||
|
||||
|
||||
def read_float32_file(filename: str) -> np.ndarray:
    """Load the whole file *filename* as little-endian float32 values.

    This is the default float32 reader. The returned NumPy array supports
    the buffer protocol, so it can be used wherever a reader result is
    expected.

    Args:
        filename: Path of the binary file to read.

    Returns:
        A one-dimensional array with dtype ``"<f4"`` holding every value
        found in the file.

    Raises:
        OSError: Propagated from ``np.fromfile`` when the file cannot be
            opened (for example ``FileNotFoundError``).
    """
    return np.fromfile(filename, dtype="<f4")
|
||||
|
||||
|
||||
# These dataclasses are treated as parsed value objects. frozen=True prevents
|
||||
# accidental field reassignment but does not imply deep immutability.
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Float32BinarySource:
    """Description of one float32 bulk-data file whose contents load on demand.

    Constructing an instance performs no file access. The file is read the
    first time ``values`` is accessed, and ``cached_property`` keeps that
    result on the instance so later accesses do not read again.
    """

    # Number of float32 values the file is expected to contain.
    count: int
    # Binary file name; it is resolved against ``json_dir`` when read.
    filename: str
    # Directory used together with ``filename`` to locate the file.
    json_dir: str
    # Short label inserted into warning messages about this data.
    description: str
    # Callable that turns a resolved file name into a float32 buffer.
    reader: Float32Reader = read_float32_file

    @cached_property
    def values(self) -> np.ndarray | None:
        """Return the loaded values, or ``None`` when they are unavailable."""
        loaded = read_float32_binary(
            self.count,
            self.filename,
            self.json_dir,
            self.description,
            self.reader,
        )
        return loaded
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GpuTimingData:
|
||||
minimum: float | None
|
||||
@@ -55,8 +81,20 @@ class GpuTimingData:
|
||||
median: float | None
|
||||
interquartile_range: float | None
|
||||
interquartile_range_relative: float | None
|
||||
samples: np.ndarray | None = None
|
||||
frequencies: np.ndarray | None = None
|
||||
sample_source: Float32BinarySource | None = None
|
||||
frequency_source: Float32BinarySource | None = None
|
||||
|
||||
@cached_property
|
||||
def samples(self) -> np.ndarray | None:
|
||||
if self.sample_source is None:
|
||||
return None
|
||||
return self.sample_source.values
|
||||
|
||||
@cached_property
|
||||
def frequencies(self) -> np.ndarray | None:
|
||||
if self.frequency_source is None:
|
||||
return None
|
||||
return self.frequency_source.values
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -342,45 +380,76 @@ def resolve_binary_filename(json_dir, binary_filename):
|
||||
return json_relative_filename
|
||||
|
||||
|
||||
def read_float32_binary(count, filename, json_dir):
|
||||
if count is None or filename is None or json_dir is None:
|
||||
return None
|
||||
def warn_unavailable_bulk_data(description, message, *, stacklevel=3):
    """Emit a ``RuntimeWarning`` saying optional bulk data will be ignored.

    Args:
        description: Short label for the affected data; it is inserted
            into the warning text before the word "data".
        message: Explanation of why the data cannot be used.
        stacklevel: Passed straight to ``warnings.warn``. Level 1 is this
            helper and level 2 is its direct caller. The default of 3
            attributes the warning to the caller's caller. Callers at a
            different depth can pass another value.
    """
    warnings.warn(
        f"Could not use NVBench {description} data: {message}; treating it as unavailable",
        RuntimeWarning,
        stacklevel=stacklevel,
    )
|
||||
|
||||
|
||||
def read_float32_binary(count, filename, json_dir, description, reader):
|
||||
filename = resolve_binary_filename(json_dir, filename)
|
||||
try:
|
||||
values = np.fromfile(filename, dtype="<f4")
|
||||
except FileNotFoundError:
|
||||
values = np.frombuffer(reader(filename), dtype="<f4")
|
||||
except (BufferError, OSError, TypeError, ValueError) as exc:
|
||||
warn_unavailable_bulk_data(description, f"failed to read {filename!r}: {exc}")
|
||||
return None
|
||||
|
||||
if count != len(values):
|
||||
raise ValueError(f"expected {count} values in {filename}, found {len(values)}")
|
||||
warn_unavailable_bulk_data(
|
||||
description,
|
||||
f"expected {count} values in {filename!r}, found {len(values)}",
|
||||
)
|
||||
return None
|
||||
return values
|
||||
|
||||
|
||||
def extract_sample_times(summaries, json_dir):
|
||||
sample_count, samples_filename = extract_binary_meta(summaries, SAMPLE_TIMES_TAG)
|
||||
return read_float32_binary(sample_count, samples_filename, json_dir)
|
||||
|
||||
|
||||
def extract_sample_frequencies(summaries, json_dir):
|
||||
frequency_count, frequencies_filename = extract_binary_meta(
|
||||
summaries, SAMPLE_FREQUENCIES_TAG
|
||||
def extract_float32_binary_source(summaries, tag, json_dir, description, reader):
    """Build a lazy ``Float32BinarySource`` from the summary tagged *tag*.

    No file is opened here; only the metadata is captured.

    Args:
        summaries: Summary collection handed to ``extract_binary_meta``.
        tag: Summary tag identifying the binary metadata entry.
        json_dir: Directory stored on the source for later file resolution.
        description: Label used in warnings about this data.
        reader: Callable stored on the source for loading the file.

    Returns:
        A ``Float32BinarySource``, or ``None`` when the count, the file
        name or *json_dir* is missing, or when the count is negative (a
        ``RuntimeWarning`` is emitted in the negative case).
    """
    count, filename = extract_binary_meta(summaries, tag)

    # Any missing piece means there is simply no bulk data to describe.
    if any(item is None for item in (count, filename, json_dir)):
        return None

    if count < 0:
        warn_unavailable_bulk_data(description, f"negative value count {count}")
        return None

    return Float32BinarySource(
        count=count,
        filename=filename,
        json_dir=json_dir,
        description=description,
        reader=reader,
    )
|
||||
return read_float32_binary(frequency_count, frequencies_filename, json_dir)
|
||||
|
||||
|
||||
def extract_gpu_timing_data(summaries, json_dir=None):
|
||||
samples = extract_sample_times(summaries, json_dir)
|
||||
frequencies = extract_sample_frequencies(summaries, json_dir)
|
||||
def extract_sample_time_source(summaries, json_dir, reader):
    """Return the lazy source for sample times, or ``None`` if unavailable.

    Delegates to ``extract_float32_binary_source`` with the sample-times
    tag and the warning label "sample time".
    """
    return extract_float32_binary_source(
        summaries,
        tag=SAMPLE_TIMES_TAG,
        json_dir=json_dir,
        description="sample time",
        reader=reader,
    )
|
||||
|
||||
|
||||
def extract_sample_frequency_source(summaries, json_dir, reader):
    """Return the lazy source for sample frequencies, or ``None`` if unavailable.

    Delegates to ``extract_float32_binary_source`` with the
    sample-frequencies tag and the warning label "sample frequency".
    """
    return extract_float32_binary_source(
        summaries,
        tag=SAMPLE_FREQUENCIES_TAG,
        json_dir=json_dir,
        description="sample frequency",
        reader=reader,
    )
|
||||
|
||||
|
||||
def extract_gpu_timing_data(summaries, json_dir=None, float32_reader=read_float32_file):
|
||||
sample_source = extract_sample_time_source(summaries, json_dir, float32_reader)
|
||||
frequency_source = extract_sample_frequency_source(
|
||||
summaries, json_dir, float32_reader
|
||||
)
|
||||
if (
|
||||
samples is not None
|
||||
and frequencies is not None
|
||||
and len(samples) != len(frequencies)
|
||||
sample_source is not None
|
||||
and frequency_source is not None
|
||||
and sample_source.count != frequency_source.count
|
||||
):
|
||||
raise ValueError(
|
||||
f"sample count ({len(samples)}) does not match "
|
||||
f"frequency count ({len(frequencies)})"
|
||||
warn_unavailable_bulk_data(
|
||||
"paired sample time and frequency",
|
||||
f"sample count ({sample_source.count}) does not match "
|
||||
f"frequency count ({frequency_source.count})",
|
||||
)
|
||||
sample_source = None
|
||||
frequency_source = None
|
||||
|
||||
return GpuTimingData(
|
||||
minimum=extract_summary_float(summaries, GPU_TIME_MIN_TAG),
|
||||
@@ -397,8 +466,8 @@ def extract_gpu_timing_data(summaries, json_dir=None):
|
||||
interquartile_range_relative=extract_summary_float(
|
||||
summaries, GPU_TIME_IR_RELATIVE_TAG, null_value=math.inf
|
||||
),
|
||||
samples=samples,
|
||||
frequencies=frequencies,
|
||||
sample_source=sample_source,
|
||||
frequency_source=frequency_source,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user