mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-04-20 14:29:51 +00:00
compare_q.py: Add dark mode
This commit is contained in:
@@ -228,16 +228,22 @@ def test_ppl(data_spec: dict, spec: dict, logits_file: str):
|
||||
|
||||
def plot(results, args):
|
||||
|
||||
def col(light, dark):
|
||||
return dark if args.dark else light
|
||||
|
||||
if args.dark:
|
||||
plt.style.use('dark_background')
|
||||
|
||||
def get_color(s):
|
||||
d = {
|
||||
"EXL2": "green",
|
||||
"EXL3": "purple",
|
||||
"AWQ": "olive",
|
||||
"imat": "brown",
|
||||
"GGUF": "red",
|
||||
"VPTQ": "blue",
|
||||
"QTIP": "teal",
|
||||
"****": "black",
|
||||
"EXL2": col("green", "greenyellow"),
|
||||
"EXL3": col("purple", "palevioletred"),
|
||||
"AWQ": col("olive", "tan"),
|
||||
"imat": col("brown", "darkorange"),
|
||||
"GGUF": col("red", "tomato"),
|
||||
"VPTQ": col("blue", "cornflowerblue"),
|
||||
"QTIP": col("teal", "lightseagreen"),
|
||||
"****": col("black", "silver"),
|
||||
}
|
||||
for k, v in d.items():
|
||||
if f"[{v}]" in s:
|
||||
@@ -245,7 +251,7 @@ def plot(results, args):
|
||||
for k, v in d.items():
|
||||
if k in s:
|
||||
return v
|
||||
return "black"
|
||||
return col("black", "silver")
|
||||
|
||||
plt.rcParams["figure.figsize"] = (14, 11)
|
||||
plt.subplots_adjust(left = 0.05, right = 0.95, top = 0.95, bottom = 0.05)
|
||||
@@ -265,7 +271,7 @@ def plot(results, args):
|
||||
labels.append(r["label"].split("[")[0].strip() + f"\n{y_:.3f}")
|
||||
color = get_color(r["label"])
|
||||
colors.append(color)
|
||||
if color != "black":
|
||||
if color != col("black", "silver"):
|
||||
if color not in lpoints:
|
||||
lpoints[color] = []
|
||||
lpoints[color].append((x_, y_))
|
||||
@@ -289,7 +295,7 @@ def plot(results, args):
|
||||
texts,
|
||||
x = x,
|
||||
y = y,
|
||||
arrowprops = {"arrowstyle": "->", "color": "lightgray"},
|
||||
arrowprops = {"arrowstyle": "->", "color": col("lightgray", "dimgray")},
|
||||
expand = (1.35, 2.3),
|
||||
ensure_inside_axes = True,
|
||||
min_arrow_len = 0.10,
|
||||
@@ -306,7 +312,10 @@ def plot(results, args):
|
||||
plt.xlabel("VRAM // GB (decoder + head)" if args.vram else "bits per weight (decoder only)")
|
||||
plt.ylabel("Perplexity" if not args.kld else "KL divergence")
|
||||
plt.title(args.title)
|
||||
plt.grid(True)
|
||||
if args.dark:
|
||||
plt.grid(color = 'dimgray', linestyle = '--', linewidth = 0.5)
|
||||
else:
|
||||
plt.grid(True)
|
||||
plt.show()
|
||||
|
||||
|
||||
@@ -389,6 +398,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("-kld", "--kld", action = "store_true", help = "Test KL divergence")
|
||||
parser.add_argument("-mask", "--mask", type = str, help = "Semicolon-separated list of strings to match against model labels for inclusion")
|
||||
parser.add_argument("-lf", "--logits_file", type = str, help = "Reference logits file for KLD", required = False)
|
||||
parser.add_argument("-dark", "--dark", action = "store_true", help = "Dark mode")
|
||||
_args = parser.parse_args()
|
||||
main(_args)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user