mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-04-20 14:58:54 +00:00
Apply black formatting.
This commit is contained in:
committed by
Allison Vacanti
parent
dae1f16426
commit
1002082817
@@ -53,14 +53,20 @@ def get_row(cmp_benches, ref_benches):
|
||||
cmp_time_summary = cmp_summaries.get("Average GPU Time (Cold)")
|
||||
ref_time_summary = ref_summaries.get("Average GPU Time (Cold)")
|
||||
cmp_noise_summary = cmp_summaries.get(
|
||||
"GPU Relative Standard Deviation (Cold)")
|
||||
"GPU Relative Standard Deviation (Cold)"
|
||||
)
|
||||
ref_noise_summary = ref_summaries.get(
|
||||
"GPU Relative Standard Deviation (Cold)")
|
||||
"GPU Relative Standard Deviation (Cold)"
|
||||
)
|
||||
|
||||
# TODO: Determine whether empty outputs could be present based on
|
||||
# user requests not to perform certain timings.
|
||||
if cmp_time_summary is None or ref_time_summary is None or \
|
||||
cmp_noise_summary is None or ref_noise_summary is None:
|
||||
if (
|
||||
cmp_time_summary is None
|
||||
or ref_time_summary is None
|
||||
or cmp_noise_summary is None
|
||||
or ref_noise_summary is None
|
||||
):
|
||||
continue
|
||||
|
||||
# TODO Ugly. The JSON needs to be changed to let us look up names
|
||||
@@ -76,31 +82,52 @@ def get_row(cmp_benches, ref_benches):
|
||||
# and sample distributions. Ideally we would use something like
|
||||
# KL divergence to capture the differences, but that's out of scope
|
||||
# at this stage.
|
||||
cmp_abs_std = ((cmp_noise / 100.) * cmp_time)
|
||||
ref_abs_std = ((ref_noise / 100.) * ref_time)
|
||||
cmp_abs_std = (cmp_noise / 100.0) * cmp_time
|
||||
ref_abs_std = (ref_noise / 100.0) * ref_time
|
||||
num_stds_fail = 2
|
||||
failed = (cmp_noise - ref_noise) > (num_stds_fail * (cmp_abs_std + ref_abs_std))
|
||||
failed = (cmp_noise - ref_noise) > (
|
||||
num_stds_fail * (cmp_abs_std + ref_abs_std)
|
||||
)
|
||||
status = (Fore.RED + "FAIL" if failed else Fore.GREEN + "PASS") + Fore.RESET
|
||||
|
||||
# Relative time comparison
|
||||
yield ([cmp_bench['name'], cmp_state_description, cmp_time -
|
||||
ref_time, cmp_time, ref_time, f"{cmp_noise:0.6f}%",
|
||||
f"{ref_noise:0.6f}%", status],
|
||||
failed)
|
||||
yield (
|
||||
[
|
||||
cmp_bench["name"],
|
||||
cmp_state_description,
|
||||
cmp_time - ref_time,
|
||||
cmp_time,
|
||||
ref_time,
|
||||
f"{cmp_noise:0.6f}%",
|
||||
f"{ref_noise:0.6f}%",
|
||||
status,
|
||||
],
|
||||
failed,
|
||||
)
|
||||
|
||||
|
||||
rows, faileds = zip(*get_row(cmp_benches, ref_benches))
|
||||
|
||||
print(tabulate.tabulate(
|
||||
rows,
|
||||
# TODO: Reduce precision once we have really different
|
||||
# numbers for comparison.
|
||||
floatfmt="0.12f",
|
||||
headers=("Name", "Parameters", "Old - New", "New Time", "Old Time",
|
||||
"New Std", "Old Std", "Status"),
|
||||
# TODO: Choose appropriate format (or expose a
|
||||
# command-line argument to let the user choose)
|
||||
tablefmt="github",
|
||||
))
|
||||
print(
|
||||
tabulate.tabulate(
|
||||
rows,
|
||||
# TODO: Reduce precision once we have really different
|
||||
# numbers for comparison.
|
||||
floatfmt="0.12f",
|
||||
headers=(
|
||||
"Name",
|
||||
"Parameters",
|
||||
"Old - New",
|
||||
"New Time",
|
||||
"Old Time",
|
||||
"New Std",
|
||||
"Old Std",
|
||||
"Status",
|
||||
),
|
||||
# TODO: Choose appropriate format (or expose a
|
||||
# command-line argument to let the user choose)
|
||||
tablefmt="github",
|
||||
)
|
||||
)
|
||||
|
||||
sys.exit(any(faileds))
|
||||
|
||||
Reference in New Issue
Block a user