mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-04-20 06:48:53 +00:00
Plot comparison results (#90)
This commit is contained in:
committed by
GitHub
parent
92286e1d4a
commit
0ce45af043
@@ -99,7 +99,13 @@ def format_percentage(percentage):
|
||||
return "%0.2f%%" % (percentage * 100.0)
|
||||
|
||||
|
||||
def compare_benches(ref_benches, cmp_benches, threshold):
|
||||
def compare_benches(ref_benches, cmp_benches, threshold, plot):
|
||||
if plot:
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
|
||||
sns.set()
|
||||
|
||||
for cmp_bench in cmp_benches:
|
||||
ref_bench = find_matching_bench(cmp_bench, ref_benches)
|
||||
if not ref_bench:
|
||||
@@ -135,6 +141,8 @@ def compare_benches(ref_benches, cmp_benches, threshold):
|
||||
for device_id in device_ids:
|
||||
|
||||
rows = []
|
||||
plot_data = {'cmp': {}, 'ref': {}, 'cmp_noise': {}, 'ref_noise': {}}
|
||||
|
||||
for cmp_state in cmp_states:
|
||||
cmp_state_name = cmp_state["name"]
|
||||
ref_state = next(filter(lambda st: st["name"] == cmp_state_name,
|
||||
@@ -207,6 +215,27 @@ def compare_benches(ref_benches, cmp_benches, threshold):
|
||||
else:
|
||||
min_noise = None # Noise is inf
|
||||
|
||||
if plot:
|
||||
axis_name = []
|
||||
axis_value = "--"
|
||||
for aid in range(len(axis_values)):
|
||||
if axis_values[aid]["name"] != plot:
|
||||
axis_name.append("{} = {}".format(axis_values[aid]["name"], axis_values[aid]["value"]))
|
||||
else:
|
||||
axis_value = float(axis_values[aid]["value"])
|
||||
axis_name = ', '.join(axis_name)
|
||||
|
||||
if axis_name not in plot_data['cmp']:
|
||||
plot_data['cmp'][axis_name] = {}
|
||||
plot_data['ref'][axis_name] = {}
|
||||
plot_data['cmp_noise'][axis_name] = {}
|
||||
plot_data['ref_noise'][axis_name] = {}
|
||||
|
||||
plot_data['cmp'][axis_name][axis_value] = cmp_time
|
||||
plot_data['ref'][axis_name][axis_value] = ref_time
|
||||
plot_data['cmp_noise'][axis_name][axis_value] = cmp_noise
|
||||
plot_data['ref_noise'][axis_name][axis_value] = ref_noise
|
||||
|
||||
global config_count
|
||||
global unknown_count
|
||||
global pass_count
|
||||
@@ -252,12 +281,41 @@ def compare_benches(ref_benches, cmp_benches, threshold):
|
||||
|
||||
print("")
|
||||
|
||||
if plot:
|
||||
plt.xscale("log")
|
||||
plt.yscale("log")
|
||||
plt.xlabel(plot)
|
||||
plt.ylabel("time [s]")
|
||||
plt.title(device["name"])
|
||||
|
||||
def plot_line(key, shape, label):
|
||||
x = [float(x) for x in plot_data[key][axis].keys()]
|
||||
y = list(plot_data[key][axis].values())
|
||||
|
||||
noise = list(plot_data[key + '_noise'][axis].values())
|
||||
|
||||
top = [y[i] + y[i] * noise[i] for i in range(len(x))]
|
||||
bottom = [y[i] - y[i] * noise[i] for i in range(len(x))]
|
||||
|
||||
p = plt.plot(x, y, shape, marker='o', label=label)
|
||||
plt.fill_between(x, bottom, top, color=p[0].get_color(), alpha=0.1)
|
||||
|
||||
|
||||
for axis in plot_data['cmp'].keys():
|
||||
plot_line('cmp', '-', axis)
|
||||
plot_line('ref', '--', axis + ' ref')
|
||||
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
|
||||
def main():
|
||||
help_text = "%(prog)s [reference.json compare.json | reference_dir/ compare_dir/]"
|
||||
parser = argparse.ArgumentParser(prog='nvbench_compare', usage=help_text)
|
||||
parser.add_argument('--threshold-diff', type=float, dest='threshold', default=0.0,
|
||||
help='only show benchmarks where percentage diff is >= THRESHOLD')
|
||||
parser.add_argument('--plot-along', type=str, dest='plot', default=None,
|
||||
help='plot results')
|
||||
|
||||
args, files_or_dirs = parser.parse_known_args()
|
||||
print(files_or_dirs)
|
||||
@@ -294,7 +352,7 @@ def main():
|
||||
print("Device sections do not match.")
|
||||
sys.exit(1)
|
||||
|
||||
compare_benches(ref_root["benchmarks"], cmp_root["benchmarks"], args.threshold)
|
||||
compare_benches(ref_root["benchmarks"], cmp_root["benchmarks"], args.threshold, args.plot)
|
||||
|
||||
print("# Summary\n")
|
||||
print("- Total Matches: %d" % config_count)
|
||||
|
||||
Reference in New Issue
Block a user