compare_q.py: Add dark mode

This commit is contained in:
turboderp
2025-06-12 05:54:57 +02:00
parent 72bdf5a39c
commit 463ebe1841

View File

@@ -228,16 +228,22 @@ def test_ppl(data_spec: dict, spec: dict, logits_file: str):
def plot(results, args):
def col(light, dark):
return dark if args.dark else light
if args.dark:
plt.style.use('dark_background')
def get_color(s):
d = {
"EXL2": "green",
"EXL3": "purple",
"AWQ": "olive",
"imat": "brown",
"GGUF": "red",
"VPTQ": "blue",
"QTIP": "teal",
"****": "black",
"EXL2": col("green", "greenyellow"),
"EXL3": col("purple", "palevioletred"),
"AWQ": col("olive", "tan"),
"imat": col("brown", "darkorange"),
"GGUF": col("red", "tomato"),
"VPTQ": col("blue", "cornflowerblue"),
"QTIP": col("teal", "lightseagreen"),
"****": col("black", "silver"),
}
for k, v in d.items():
if f"[{v}]" in s:
@@ -245,7 +251,7 @@ def plot(results, args):
for k, v in d.items():
if k in s:
return v
return "black"
return col("black", "silver")
plt.rcParams["figure.figsize"] = (14, 11)
plt.subplots_adjust(left = 0.05, right = 0.95, top = 0.95, bottom = 0.05)
@@ -265,7 +271,7 @@ def plot(results, args):
labels.append(r["label"].split("[")[0].strip() + f"\n{y_:.3f}")
color = get_color(r["label"])
colors.append(color)
if color != "black":
if color != col("black", "silver"):
if color not in lpoints:
lpoints[color] = []
lpoints[color].append((x_, y_))
@@ -289,7 +295,7 @@ def plot(results, args):
texts,
x = x,
y = y,
arrowprops = {"arrowstyle": "->", "color": "lightgray"},
arrowprops = {"arrowstyle": "->", "color": col("lightgray", "dimgray")},
expand = (1.35, 2.3),
ensure_inside_axes = True,
min_arrow_len = 0.10,
@@ -306,7 +312,10 @@ def plot(results, args):
plt.xlabel("VRAM // GB (decoder + head)" if args.vram else "bits per weight (decoder only)")
plt.ylabel("Perplexity" if not args.kld else "KL divergence")
plt.title(args.title)
plt.grid(True)
if args.dark:
plt.grid(color = 'dimgray', linestyle = '--', linewidth = 0.5)
else:
plt.grid(True)
plt.show()
@@ -389,6 +398,7 @@ if __name__ == "__main__":
parser.add_argument("-kld", "--kld", action = "store_true", help = "Test KL divergence")
parser.add_argument("-mask", "--mask", type = str, help = "Semicolon-separated list of strings to match against model labels for inclusion")
parser.add_argument("-lf", "--logits_file", type = str, help = "Reference logits file for KLD", required = False)
parser.add_argument("-dark", "--dark", action = "store_true", help = "Dark mode")
_args = parser.parse_args()
main(_args)