Plot comparison results (#90)

This commit is contained in:
Georgii Evtushenko
2024-11-13 11:28:04 -08:00
committed by GitHub
parent 92286e1d4a
commit 0ce45af043

View File

@@ -99,7 +99,13 @@ def format_percentage(percentage):
return "%0.2f%%" % (percentage * 100.0)
def compare_benches(ref_benches, cmp_benches, threshold):
def compare_benches(ref_benches, cmp_benches, threshold, plot):
if plot:
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
for cmp_bench in cmp_benches:
ref_bench = find_matching_bench(cmp_bench, ref_benches)
if not ref_bench:
@@ -135,6 +141,8 @@ def compare_benches(ref_benches, cmp_benches, threshold):
for device_id in device_ids:
rows = []
plot_data = {'cmp': {}, 'ref': {}, 'cmp_noise': {}, 'ref_noise': {}}
for cmp_state in cmp_states:
cmp_state_name = cmp_state["name"]
ref_state = next(filter(lambda st: st["name"] == cmp_state_name,
@@ -207,6 +215,27 @@ def compare_benches(ref_benches, cmp_benches, threshold):
else:
min_noise = None # Noise is inf
if plot:
axis_name = []
axis_value = "--"
for aid in range(len(axis_values)):
if axis_values[aid]["name"] != plot:
axis_name.append("{} = {}".format(axis_values[aid]["name"], axis_values[aid]["value"]))
else:
axis_value = float(axis_values[aid]["value"])
axis_name = ', '.join(axis_name)
if axis_name not in plot_data['cmp']:
plot_data['cmp'][axis_name] = {}
plot_data['ref'][axis_name] = {}
plot_data['cmp_noise'][axis_name] = {}
plot_data['ref_noise'][axis_name] = {}
plot_data['cmp'][axis_name][axis_value] = cmp_time
plot_data['ref'][axis_name][axis_value] = ref_time
plot_data['cmp_noise'][axis_name][axis_value] = cmp_noise
plot_data['ref_noise'][axis_name][axis_value] = ref_noise
global config_count
global unknown_count
global pass_count
@@ -252,12 +281,41 @@ def compare_benches(ref_benches, cmp_benches, threshold):
print("")
if plot:
plt.xscale("log")
plt.yscale("log")
plt.xlabel(plot)
plt.ylabel("time [s]")
plt.title(device["name"])
def plot_line(key, shape, label):
    """Plot one timing curve with a shaded noise band around it.

    Reads the series for the current ``axis`` group from the enclosing
    scope's ``plot_data`` and draws it on the active pyplot figure.

    Args:
        key: Series to draw; ``plot_data[key]`` holds the times and
            ``plot_data[key + '_noise']`` the matching noise values.
        shape: Matplotlib line-style string passed to ``plt.plot``.
        label: Legend label for the curve.
    """
    times = plot_data[key][axis]
    noises = plot_data[key + '_noise'][axis]

    points = []
    for raw_x, y in times.items():
        try:
            x = float(raw_x)
        except (TypeError, ValueError):
            # States that lack the plotted axis are stored under the "--"
            # placeholder key; they have no position on the x axis.
            continue
        # Both dicts are filled with the same keys, so look the noise up
        # by key instead of relying on matching iteration order.
        points.append((x, y, noises[raw_x]))

    if not points:
        return

    # Sort by x so the line and the band never fold back on themselves
    # when states were recorded out of order.
    points.sort(key=lambda point: point[0])

    xs = [x for x, _, _ in points]
    ys = [y for _, y, _ in points]
    # Noise is applied as a fraction of the measured time: y +/- y * noise.
    top = [y + y * noise for _, y, noise in points]
    bottom = [y - y * noise for _, y, noise in points]

    p = plt.plot(xs, ys, shape, marker='o', label=label)
    # Reuse the line's colour so the band is visually tied to its curve.
    plt.fill_between(xs, bottom, top, color=p[0].get_color(), alpha=0.1)
for axis in plot_data['cmp'].keys():
plot_line('cmp', '-', axis)
plot_line('ref', '--', axis + ' ref')
plt.legend()
plt.show()
def main():
help_text = "%(prog)s [reference.json compare.json | reference_dir/ compare_dir/]"
parser = argparse.ArgumentParser(prog='nvbench_compare', usage=help_text)
parser.add_argument('--threshold-diff', type=float, dest='threshold', default=0.0,
help='only show benchmarks where percentage diff is >= THRESHOLD')
parser.add_argument('--plot-along', type=str, dest='plot', default=None,
help='plot results')
args, files_or_dirs = parser.parse_known_args()
print(files_or_dirs)
@@ -294,7 +352,7 @@ def main():
print("Device sections do not match.")
sys.exit(1)
compare_benches(ref_root["benchmarks"], cmp_root["benchmarks"], args.threshold)
compare_benches(ref_root["benchmarks"], cmp_root["benchmarks"], args.threshold, args.plot)
print("# Summary\n")
print("- Total Matches: %d" % config_count)