Add some intermeidate variables and simplify code to make it self-documenting.

This commit is contained in:
Vyas Ramasubramani
2021-05-26 10:59:07 -07:00
committed by Allison Vacanti
parent 053eb493c7
commit dae1f16426

View File

@@ -52,8 +52,10 @@ def get_row(cmp_benches, ref_benches):
cmp_time_summary = cmp_summaries.get("Average GPU Time (Cold)")
ref_time_summary = ref_summaries.get("Average GPU Time (Cold)")
cmp_noise_summary = cmp_summaries.get("GPU Relative Standard Deviation (Cold)")
ref_noise_summary = ref_summaries.get("GPU Relative Standard Deviation (Cold)")
cmp_noise_summary = cmp_summaries.get(
"GPU Relative Standard Deviation (Cold)")
ref_noise_summary = ref_summaries.get(
"GPU Relative Standard Deviation (Cold)")
# TODO: Determine whether empty outputs could be present based on
# user requests not to perform certain timings.
@@ -74,23 +76,31 @@ def get_row(cmp_benches, ref_benches):
# and sample distributions. Ideally we would use something like
# KL divergence to capture the differences, but that's out of scope
# at this stage.
failed = (cmp_noise - ref_noise) > 2 * (((cmp_noise / 100.) * cmp_time) + ((ref_noise / 100.) * ref_time))
cmp_abs_std = ((cmp_noise / 100.) * cmp_time)
ref_abs_std = ((ref_noise / 100.) * ref_time)
num_stds_fail = 2
failed = (cmp_noise - ref_noise) > (num_stds_fail * (cmp_abs_std + ref_abs_std))
status = (Fore.RED + "FAIL" if failed else Fore.GREEN + "PASS") + Fore.RESET
# Relative time comparison
yield ([cmp_bench['name'], cmp_state_description] + f"{cmp_time - ref_time} {cmp_time} {ref_time} {cmp_noise:0.6f}% {ref_noise:0.6f}% {status}\n".split(), failed)
yield ([cmp_bench['name'], cmp_state_description, cmp_time -
ref_time, cmp_time, ref_time, f"{cmp_noise:0.6f}%",
f"{ref_noise:0.6f}%", status],
failed)
rows, faileds = zip(*get_row(cmp_benches, ref_benches))
print(tabulate.tabulate(rows,
# TODO: Reduce precision once we have really different
# numbers for comparison.
floatfmt="0.12f",
headers=("Name", "Parameters", "Old - New", "New Time", "Old Time", "New Std", "Old Std", "Status"),
# TODO: Choose appropriate format (or expose a
# command-line argument to let the user choose)
tablefmt="github",
print(tabulate.tabulate(
rows,
# TODO: Reduce precision once we have really different
# numbers for comparison.
floatfmt="0.12f",
headers=("Name", "Parameters", "Old - New", "New Time", "Old Time",
"New Std", "Old Std", "Status"),
# TODO: Choose appropriate format (or expose a
# command-line argument to let the user choose)
tablefmt="github",
))
exit(any(faileds))
sys.exit(any(faileds))