mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-04-20 14:29:51 +00:00
Add cache map visualization
This commit is contained in:
@@ -17,11 +17,14 @@ Display modes for this demo:
|
||||
"""
|
||||
display_mode = 1
|
||||
|
||||
# Show graphical visualization of the paged cache (adds some overhead)
|
||||
show_visualization = False
|
||||
|
||||
# Where to find our model
|
||||
model_dir = "/mnt/str/eval_models/llama3.1-8b-instruct/exl3/4.0bpw/"
|
||||
|
||||
# Total number of tokens to allocate space for in the cache.
|
||||
total_context = 32768
|
||||
total_context = 16384
|
||||
|
||||
# Max number of batches to run at once, assuming the sequences will fit within total_context.
|
||||
max_batch_size = 16
|
||||
@@ -43,7 +46,6 @@ prompts = [
|
||||
"Can you guess the next number in this sequence: " + ", ".join(str(n) for n in range(200)),
|
||||
"Can you write a C++ quicksort implementation pretty please?",
|
||||
"Hello!",
|
||||
"Hi there!",
|
||||
"What's the difference smoke and vapor?",
|
||||
"What seems out of place in this sequence: " + ", ".join(str(n if n != 123 else 69) for n in range(200)),
|
||||
"What seems out of place in this sequence: " + ", ".join(str(n if n != 42 else 111) for n in range(200)),
|
||||
@@ -52,7 +54,6 @@ prompts = [
|
||||
"What seems out of place in this sequence: " + ", ".join(str(n if n != 42 else 111) for n in range(200)),
|
||||
"Please guess the next 20 numbers in this sequence: " + ", ".join(str(n) for n in range(700)),
|
||||
"Write a short essay about cell membranes.",
|
||||
"What's up?",
|
||||
"How do I open a can of beans?",
|
||||
"How do I open a can of soup?",
|
||||
"How do I open a can of strawberry jam?",
|
||||
@@ -72,19 +73,21 @@ prompts = [
|
||||
"How do I build a time machine?",
|
||||
"What seems out of place in this sequence: " + ", ".join(str(n if n != 123 else 69) for n in range(200)),
|
||||
"Is it legal to grow your own catnip?",
|
||||
"What seems out of place in this sequence: " + ", ".join(str(n if n != 160 else 420) for n in range(400)),
|
||||
"What seems out of place in this sequence: " + ", ".join(str(n if n != 161 else 421) for n in range(400)),
|
||||
"What seems out of place in this sequence: " + ", ".join(str(n if n != 360 else 420) for n in range(400)),
|
||||
"What seems out of place in this sequence: " + ", ".join(str(n if n != 361 else 421) for n in range(400)),
|
||||
"What's inside a black hole?",
|
||||
"What seems out of place in this sequence: " + ", ".join(str(n if n != 360 else 420) for n in range(400)),
|
||||
"What seems out of place in this sequence: " + ", ".join(str(n if n != 363 else 421) for n in range(400)),
|
||||
"What do the numbers 2, 4, 8, 16, 32 and 64 have in common?",
|
||||
"What do the numbers 2, 3, 5, 7, 11 and 13 have in common?",
|
||||
"Is there life on Mars?",
|
||||
"Hello!",
|
||||
"Hi!",
|
||||
"Boop!",
|
||||
"Why are cats better than dogs?",
|
||||
"Why are cats better than dogs?",
|
||||
"Why are cats better than dogs?",
|
||||
"Write a parable about why cats are better than dogs.",
|
||||
"Can you guess the next number in this sequence: " + ", ".join(str(n) for n in range(999)),
|
||||
"Can you guess the next number in this sequence: " + ", ".join(str(n) for n in range(999)),
|
||||
"Can you guess the next number in this sequence: " + ", ".join(str(n) for n in range(999)),
|
||||
"Can you guess the next number in this sequence: " + ", ".join(str(n) for n in range(999)),
|
||||
"Can you guess the next number in this sequence: " + ", ".join(str(n) for n in range(999)),
|
||||
]
|
||||
|
||||
term = Terminal()
|
||||
@@ -92,7 +95,7 @@ term = Terminal()
|
||||
def main():
|
||||
|
||||
# Load the model config
|
||||
config = Config.from_directory("/mnt/str/eval_models/llama3.1-8b-instruct/exl3/2.0bpw/")
|
||||
config = Config.from_directory("/mnt/str/models/llama3.1-8b-instruct/exl3/4.0bpw/")
|
||||
|
||||
# Create the model from the config
|
||||
model = Model.from_config(config)
|
||||
@@ -114,6 +117,7 @@ def main():
|
||||
tokenizer = tokenizer,
|
||||
max_batch_size = max_batch_size,
|
||||
max_chunk_size = max_chunk_size,
|
||||
show_visualizer = show_visualization
|
||||
)
|
||||
|
||||
# Create jobs
|
||||
|
||||
@@ -10,6 +10,7 @@ from .pagetable import PageTable
|
||||
from .job import Job
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from .sampler import Sampler, GumbelSampler
|
||||
from .visualizer import CacheVisualizer
|
||||
import time
|
||||
import threading
|
||||
import numpy as np
|
||||
@@ -27,6 +28,7 @@ class Generator:
|
||||
draft_model: Model | None = None,
|
||||
draft_cache: Cache | None = None,
|
||||
num_draft_tokens: int = 4,
|
||||
show_visualizer: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
@@ -64,6 +66,9 @@ class Generator:
|
||||
:param num_draft_tokens:
|
||||
Number of future tokens to draft.
|
||||
|
||||
:param show_visualizer:
|
||||
Open window to render visualization of cache (for debug/demonstration purposes)
|
||||
|
||||
:param kwargs:
|
||||
"""
|
||||
|
||||
@@ -116,6 +121,12 @@ class Generator:
|
||||
pin_memory = False
|
||||
)
|
||||
|
||||
# Visualizer
|
||||
if show_visualizer:
|
||||
self.visualizer = CacheVisualizer(self.pagetable.max_pages)
|
||||
else:
|
||||
self.visualizer = None
|
||||
|
||||
# TODO: (defrag)
|
||||
|
||||
|
||||
@@ -269,10 +280,26 @@ class Generator:
|
||||
else:
|
||||
self.iterate_gen(results)
|
||||
|
||||
# Visualization
|
||||
if self.visualizer:
|
||||
self.update_visualizer()
|
||||
|
||||
# Finished iteration
|
||||
return results
|
||||
|
||||
|
||||
def update_visualizer(self):
|
||||
chains = []
|
||||
for job in self.active_jobs:
|
||||
for seq in job.sequences:
|
||||
idx = job.serial_number
|
||||
chain = [page.page_index for page in seq.allocated_pages]
|
||||
chains.append((idx, chain))
|
||||
usage = []
|
||||
for page in self.pagetable.all_pages:
|
||||
usage.append(page.kv_position / PAGE_SIZE)
|
||||
self.visualizer.update(chains, usage)
|
||||
|
||||
def iterate_draftmodel_gen(self, results: list):
|
||||
|
||||
# Get shape of active batch
|
||||
|
||||
221
exllamav3/generator/visualizer.py
Normal file
221
exllamav3/generator/visualizer.py
Normal file
@@ -0,0 +1,221 @@
|
||||
from __future__ import annotations
|
||||
import tkinter as tk
|
||||
from collections import deque
|
||||
import math
|
||||
|
||||
"""
|
||||
Quick and dirty visualizer for the paged cache, for debug purposes. Horribly slow and should probably be
|
||||
rewritten to just draw on a bitmap.
|
||||
"""
|
||||
|
||||
job_colors = [
|
||||
"#00DDFF",
|
||||
"#9800FF",
|
||||
"#D8FF00",
|
||||
"#00FFA5",
|
||||
"#FF00E4",
|
||||
"#FF8800",
|
||||
"#057DFF",
|
||||
"#FF008C",
|
||||
"#00FFE1",
|
||||
"#FFFA00",
|
||||
"#B6FF00",
|
||||
"#D400FF",
|
||||
"#FF1900",
|
||||
"#FFCC00",
|
||||
]
|
||||
|
||||
empty_color = "#505050"
|
||||
empty_color_outline = "#707070"
|
||||
|
||||
class CacheVisualizer:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_pages: int,
|
||||
window_size: int = (800, 600),
|
||||
gap: float = 0.75,
|
||||
margin: int = 25
|
||||
):
|
||||
self.num_pages = num_pages
|
||||
self.window_size = window_size
|
||||
self.gap = gap
|
||||
self.margin = margin
|
||||
self.chains = []
|
||||
self.usage = []
|
||||
|
||||
w, h = self.window_size
|
||||
self.root = tk.Tk()
|
||||
self.root.title("Cache Map")
|
||||
self.canvas = tk.Canvas(
|
||||
self.root,
|
||||
width = w,
|
||||
height = h,
|
||||
bg = "#242424",
|
||||
highlightthickness = 0,
|
||||
borderwidth = 0
|
||||
)
|
||||
self.canvas.pack(fill = "both", expand = True)
|
||||
|
||||
self.page_rects = []
|
||||
for i in range(num_pages):
|
||||
rid = self.canvas.create_rectangle(0, 0, 10, 10, fill = empty_color, outline = empty_color_outline)
|
||||
self.page_rects.append(rid)
|
||||
|
||||
self._page_grid_layout()
|
||||
self.canvas.after_idle(lambda: self.root.bind("<Configure>", self._on_resize))
|
||||
|
||||
self.elements = []
|
||||
self.root.update()
|
||||
|
||||
|
||||
def _page_grid_layout(self):
|
||||
w, h = self.window_size
|
||||
w -= 2 * self.margin
|
||||
h -= 2 * self.margin
|
||||
ratio = w / h
|
||||
self.w_pages = min(int(math.ceil(math.sqrt(self.num_pages * ratio))), self.num_pages)
|
||||
self.h_pages = int(math.ceil(self.num_pages / self.w_pages))
|
||||
self.page_bboxes = []
|
||||
cell_size_a = w / (self.w_pages + (self.w_pages - 1) * self.gap)
|
||||
cell_size_b = h / (self.h_pages + (self.h_pages - 1) * self.gap)
|
||||
cell_size = min(cell_size_a, cell_size_b)
|
||||
cell_step = cell_size * (1 + self.gap)
|
||||
self.cell_step = cell_step
|
||||
self.cell_size = cell_size
|
||||
self.gap_size = self.gap * cell_size
|
||||
for y in range(self.h_pages):
|
||||
for x in range(self.w_pages):
|
||||
if len(self.page_bboxes) >= self.num_pages:
|
||||
break
|
||||
x0 = self.margin + x * cell_step
|
||||
y0 = self.margin + y * cell_step
|
||||
x1 = x0 + cell_size
|
||||
y1 = y0 + cell_size
|
||||
rid = self.page_rects[len(self.page_bboxes)]
|
||||
self.page_bboxes.append((x0, y0, x1, y1))
|
||||
self.canvas.coords(rid, x0, y0, x1, y1)
|
||||
|
||||
|
||||
def _on_resize(self, event):
|
||||
if event.width <= 1 or event.height <= 1:
|
||||
return
|
||||
w, h = self.window_size
|
||||
if (w, h) == (event.width, event.height):
|
||||
return
|
||||
self.window_size = (event.width, event.height)
|
||||
self._page_grid_layout()
|
||||
self._update_chains()
|
||||
|
||||
|
||||
def _update_chains(self):
|
||||
for e in self.elements:
|
||||
self.canvas.delete(e)
|
||||
self.elements.clear()
|
||||
|
||||
cols = [list() for _ in range(self.num_pages)]
|
||||
for index, chain in self.chains:
|
||||
col = job_colors[index % len(job_colors)]
|
||||
for page in chain:
|
||||
cols[page] += [col]
|
||||
|
||||
in_handles = []
|
||||
out_handles = []
|
||||
for page, col in enumerate(cols):
|
||||
if not col:
|
||||
in_handles.append([])
|
||||
out_handles.append([])
|
||||
continue
|
||||
bbox = self.page_bboxes[page]
|
||||
inh = deque()
|
||||
outh = deque()
|
||||
for i, c in enumerate(col):
|
||||
c_br = self.root.tk.call("tk::Darken", c, 135)
|
||||
c_dk = self.root.tk.call("tk::Darken", c, 60)
|
||||
a = i / len(col)
|
||||
b = (i + 1) / len(col)
|
||||
x0, y0, x1, y1 = bbox
|
||||
h = y1 - y0
|
||||
y0, y1 = y0 + a * h, y0 + b * h
|
||||
rid = self.canvas.create_rectangle(x0, y0, x1, y1, fill = c, outline = c_br)
|
||||
self.elements.append(rid)
|
||||
u = self.usage[page] or 0.0
|
||||
if u < 1.0:
|
||||
mx = x0 + (x1 - x0) * u
|
||||
rid = self.canvas.create_rectangle(mx, y0, x1, y1, fill = c_dk, outline = c_dk)
|
||||
self.elements.append(rid)
|
||||
inh.append((x0, (y0 + y1) * 0.5))
|
||||
outh.append((x1, (y0 + y1) * 0.5))
|
||||
in_handles.append(inh)
|
||||
out_handles.append(outh)
|
||||
|
||||
x = 0
|
||||
y = 0
|
||||
def start(_x, _y):
|
||||
nonlocal x, y
|
||||
x, y = _x, _y
|
||||
|
||||
def line(_x, _y):
|
||||
nonlocal x, y, col
|
||||
aid = self.canvas.create_line(x, y, _x, _y, fill = col, width = 2.0)
|
||||
self.elements.append(aid)
|
||||
x, y = _x, _y
|
||||
|
||||
def arrow(_x, _y):
|
||||
nonlocal x, y, col
|
||||
aid = self.canvas.create_line(x, y, _x, _y, arrow = 'last', tags = ("arrow",), fill = col, width = 2.0)
|
||||
self.elements.append(aid)
|
||||
x, y = _x, _y
|
||||
|
||||
for l_index, (index, chain) in enumerate(self.chains):
|
||||
bcol = job_colors[index % len(job_colors)]
|
||||
bias = -self.gap_size / 6 + \
|
||||
((self.gap_size / 3) * l_index + (self.gap_size / 3) * (l_index + 1)) * 0.5 / len(self.chains)
|
||||
for page_a, page_b in zip(chain[:-1], chain[1:]):
|
||||
if self.usage[page_b]:
|
||||
col = bcol
|
||||
else:
|
||||
col = self.root.tk.call("tk::Darken", bcol, 60)
|
||||
|
||||
|
||||
x0, y0 = out_handles[page_a].popleft()
|
||||
x1, y1 = in_handles[page_b].popleft()
|
||||
ax0, ay0, ax1, ay1 = self.page_bboxes[page_a]
|
||||
bx0, by0, bx1, by1 = self.page_bboxes[page_b]
|
||||
dy = y1 - y0
|
||||
dx = x1 - x0
|
||||
cs = self.cell_size
|
||||
gs = self.gap_size
|
||||
if 0 < dx < cs and abs(dy) < cs:
|
||||
start(x0, y0)
|
||||
line(x0 + gs / 2 + bias, y0)
|
||||
line(x0 + gs / 2 + bias, y1)
|
||||
arrow(x1, y1)
|
||||
elif 0 < dx and abs(dy) < cs:
|
||||
start(x0, y0)
|
||||
line(x0 + gs / 2 + bias, y0)
|
||||
line(x0 + gs / 2 + bias, ay1 + gs / 2 + bias)
|
||||
line(x1 - gs / 2 + bias, ay1 + gs / 2 + bias)
|
||||
line(x1 - gs / 2 + bias, y1)
|
||||
arrow(x1, y1)
|
||||
elif dy > 0:
|
||||
start(x0, y0)
|
||||
line(x0 + gs / 2 + bias, y0)
|
||||
line(x0 + gs / 2 + bias, by0 - gs / 2 + bias)
|
||||
line(x1 - gs / 2 + bias, by0 - gs / 2 + bias)
|
||||
line(x1 - gs / 2 + bias, y1)
|
||||
arrow(x1, y1)
|
||||
else:
|
||||
start(x0, y0)
|
||||
line(x0 + gs / 2 + bias, y0)
|
||||
line(x0 + gs / 2 + bias, by1 + gs / 2 + bias)
|
||||
line(x1 - gs / 2 + bias, by1 + gs / 2 + bias)
|
||||
line(x1 - gs / 2 + bias, y1)
|
||||
arrow(x1, y1)
|
||||
|
||||
|
||||
def update(self, chains: list[tuple[int, list]], usage: list[float]):
|
||||
self.chains = chains
|
||||
self.usage = usage
|
||||
self._update_chains()
|
||||
self.root.update()
|
||||
Reference in New Issue
Block a user