Generator: Periodically defragment paged cache

This commit is contained in:
turboderp
2025-05-28 22:23:08 +02:00
parent a39fdfc4d5
commit 9be91d644d
8 changed files with 316 additions and 61 deletions

View File

@@ -49,6 +49,10 @@ class CacheLayer(ABC):
def copy_page(self, source: CacheLayer, from_page: int, to_page: int, num_tokens: int):
pass
@abstractmethod
def get_tensors(self):
    """Return this layer's raw backing tensors as a list (interface only)."""
    pass
class Cache:
@@ -154,3 +158,10 @@ class Cache:
for src, dst in zip(target.layers, self.layers):
assert type(src) is type(dst)
dst.copy_page(src, from_page, to_page, num_tokens)
def get_all_tensors(self):
    """Gather the backing tensors of every cache layer into one flat list."""
    return [tensor for layer in self.layers for tensor in layer.get_tensors()]

View File

@@ -67,4 +67,8 @@ class CacheLayer_fp16(CacheLayer):
def copy_page(self, source: CacheLayer_fp16, from_page: int, to_page: int, num_tokens: int):
assert self.shape == source.shape
self.k[to_page, :num_tokens, :, :].copy_(source.k[from_page, :num_tokens, :, :], non_blocking = True)
self.v[to_page, :num_tokens, :, :].copy_(source.v[from_page, :num_tokens, :, :], non_blocking = True)
self.v[to_page, :num_tokens, :, :].copy_(source.v[from_page, :num_tokens, :, :], non_blocking = True)
@override
def get_tensors(self):
return [self.k, self.v]

View File

@@ -97,4 +97,8 @@ class CacheLayer_quant(CacheLayer):
self.qk[to_page, :num_tokens, :].copy_(source.qk[from_page, :num_tokens, :], non_blocking = True)
self.qv[to_page, :num_tokens, :].copy_(source.qv[from_page, :num_tokens, :], non_blocking = True)
self.sk[to_page, :num_tokens, :].copy_(source.sk[from_page, :num_tokens, :], non_blocking = True)
self.sv[to_page, :num_tokens, :].copy_(source.sv[from_page, :num_tokens, :], non_blocking = True)
self.sv[to_page, :num_tokens, :].copy_(source.sv[from_page, :num_tokens, :], non_blocking = True)
@override
def get_tensors(self):
return [self.qk, self.qv, self.sk, self.sv]

View File

@@ -12,7 +12,7 @@ __global__ __launch_bounds__(NUM_THREADS)
void cache_rotate_kernel
(
uint8_t* __restrict__ cache,
const uint32_t* __restrict__ order,
const int32_t* __restrict__ order,
uint8_t* __restrict__ temp,
size_t page_size,
size_t rotate_len
@@ -23,34 +23,33 @@ void cache_rotate_kernel
size_t block_beg = blockIdx.x * block_size;
size_t block_end = MIN(block_beg + block_size, page_size);
block_size = block_end - block_beg;
if (!block_size) return;
if (block_size <= 0) return;
// Rotate pages
auto copy = [&](uint8_t* dst, uint8_t* src)
for (int i = 0; i < rotate_len; ++i)
{
int64_t a = (int64_t) order[2 * i];
int64_t b = (int64_t) order[2 * i + 1];
uint8_t* dst = (a >= 0 ? cache + page_size * a : temp) + block_beg;
uint8_t* src = (b >= 0 ? cache + page_size * b : temp) + block_beg;
for (int offset = threadIdx.x * 16; offset < block_size; offset += NUM_THREADS * 16)
*((uint4*) (dst + offset)) = *((uint4*) (src + offset));
};
int i;
copy(temp + block_beg, cache + page_size * (uint64_t) order[0] + block_beg);
for (i = 0; i < rotate_len - 1; ++i)
copy(cache + page_size * (uint64_t) order[i] + block_beg, cache + page_size * (uint64_t) order[i + 1] + block_beg);
copy(cache + page_size * (uint64_t) order[i] + block_beg, temp + block_beg);
__syncthreads();
}
}
/*
Reorder cache pages
- cache, paged cache, shape (num_pages, ...), any dtype, contiguous
- order, sequence to rotate, shape (n,), dtype long
- order, sequence to rotate, shape (2*n,), dtype int
- temp, temp storage, sized as one cache page
Performs:
temp <- page[order[0]]
for a, b in pairwise(order):
page[a] <- page[b]
page[order[-1]] <- temp
for i in range(n):
a = order[2*i]
b = order[2*i+1]
copy: (page[a] if a >= 0 else temp) <- (page[b] if b >= 0 else temp)
*/
void cache_rotate
@@ -63,20 +62,20 @@ void cache_rotate
const at::cuda::OptionalCUDAGuard device_guard(cache.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(cache.dim() > 1, "cache argument must have dim >= 2")
TORCH_CHECK(cache.dim() >= 2, "cache argument must have dim >= 2")
TORCH_CHECK(order.dim() == 1, "order argument must have dim == 1")
TORCH_CHECK_DTYPE(order, kInt);
size_t num_pages = cache.size(0);
size_t page_size = cache.nbytes() / num_pages;
size_t rotate_len = order.size(0);
size_t rotate_len = order.size(0) / 2;
TORCH_CHECK(temp.nbytes() == page_size, "temp tensor incorrect size");
cache_rotate_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>
(
(uint8_t*) cache.data_ptr(),
(const uint32_t*) order.data_ptr(),
(const int32_t*) order.data_ptr(),
(uint8_t*) temp.data_ptr(),
page_size,
rotate_len

View File

@@ -30,6 +30,7 @@ class Generator:
draft_cache: Cache | None = None,
num_draft_tokens: int = 4,
show_visualizer: bool = False,
enable_defrag: bool = True,
**kwargs
):
"""
@@ -70,6 +71,9 @@ class Generator:
:param show_visualizer:
Open window to render visualization of cache (for debug/demonstration purposes)
:param enable_defrag:
Defragment cache periodically
:param kwargs:
"""
@@ -129,7 +133,8 @@ class Generator:
else:
self.visualizer = None
# TODO: (defrag)
# Defrag
self.enable_defrag = enable_defrag
def num_remaining_jobs(self):
@@ -195,10 +200,8 @@ class Generator:
elif job in self.active_jobs:
job.deallocate_pages()
self.active_jobs.remove(job)
# TODO: (defrag)
# if num_jobs and not self.num_remaining_jobs():
# self.pagetable.defrag()
if num_jobs and not self.num_remaining_jobs():
self.pagetable.defrag()
@torch.inference_mode
@@ -297,9 +300,9 @@ class Generator:
idx = job.serial_number
chain = [page.page_index for page in seq.allocated_pages]
chains.append((idx, chain))
usage = []
usage = [0] * self.pagetable.max_pages
for page in self.pagetable.all_pages:
usage.append(page.kv_position / PAGE_SIZE)
usage[page.page_index] = page.kv_position / PAGE_SIZE
self.visualizer.update(chains, usage)
def iterate_draftmodel_gen(self, results: list):
@@ -503,15 +506,12 @@ class Generator:
# mt_sample = False
# Release pages for completed jobs
# num_jobs = self.num_remaining_jobs()
num_jobs = self.num_remaining_jobs()
for job in completed_jobs:
job.deallocate_pages()
self.active_jobs.remove(job)
# Defrag
# TODO: (defrag)
# if num_jobs and not self.num_remaining_jobs():
# self.pagetable.defrag()
if num_jobs and not self.num_remaining_jobs():
self.pagetable.defrag()
def iterate_start_jobs(self, results: list):

View File

@@ -1,4 +1,5 @@
from __future__ import annotations
from functools import lru_cache
import torch
import hashlib
from dataclasses import dataclass
@@ -7,9 +8,12 @@ from ..cache.cache import Cache
if TYPE_CHECKING:
from .generator import Generator
from ..constants import PAGE_SIZE
from collections import deque
from collections import deque, defaultdict
from itertools import pairwise
from ..util.tensor import SeqTensor
from exllamav3.ext import exllamav3_ext as ext
import time
from ..util import profile_opt
def _tensor_blake2b_checksum(tensor: torch.Tensor, prev_hash: bytes | None) -> bytes:
@@ -60,6 +64,8 @@ class CachePage:
# Used by defragmenter
new_page_index: int
children: list[CachePage]
longest_chain: int
def __repr__(self):
return (
@@ -68,7 +74,7 @@ class CachePage:
f"kvp {self.kv_position}"
)
# Copy page state so page can be reverted even
# Copy page state so page can be reverted
def backup(self):
self.phash_revert = self.phash
self.prev_hash_revert = self.prev_hash
@@ -221,6 +227,7 @@ class PageTable:
self.unreferenced_pages = {}
self.all_pages = []
self.reset_page_table()
self.last_defrag_serial = self.max_pages
def reset_page_table(self):
@@ -242,17 +249,18 @@ class PageTable:
sequence = torch.empty((1, PAGE_SIZE), dtype = torch.long),
ref_count = 0,
access_serial = idx,
access_serial_revert = 0,
access_serial_revert = idx,
kv_position = 0,
kv_position_revert = 0,
can_revert = False,
new_page_index = 0
new_page_index = 0,
children = [],
longest_chain = 1,
)
self.all_pages.append(cp)
self.unreferenced_pages[h] = cp
self.access_serial = self.max_pages
# TODO: (defrag)
# self.last_defrag_serial = self.access_serial
self.last_defrag_serial = self.access_serial
def print_page_list(self, short: bool = True):
@@ -270,11 +278,6 @@ class PageTable:
print()
def defrag(self):
# TODO: (defrag)
pass
def allocate_pages(
self,
page_hashes: list,
@@ -358,4 +361,205 @@ class PageTable:
def num_unreferenced_pages(self):
return len(self.unreferenced_pages)
return len(self.unreferenced_pages)
def defrag(self, debug = False):
    """
    Compact the paged cache in place.

    Rebuilds the tree of cached page sequences from the page hashes
    (phash / prev_hash), lays the longest recoverable chains out
    contiguously from page 0, appends unreachable (orphan) pages at the
    end, then physically moves cache pages on-device via
    ext.cache_rotate. Each page's page_index is updated to its new
    physical slot afterwards.

    Only runs when no jobs hold page references and the cache has seen
    at least max_pages accesses since the previous defrag.

    :param debug:
        Print computed shifts, rotation loops and the recovered page
        sequences, and measure latency (synchronizes CUDA).
    """
    if not self.generator.enable_defrag:
        return

    # Defragment once job queue is empty and all pages have been touched at least once
    if self.access_serial < self.last_defrag_serial + self.max_pages:
        return
    self.last_defrag_serial = self.access_serial
    assert not self.referenced_pages

    if debug:
        torch.cuda.synchronize()
        time_begin = time.time()

    # Build page index
    # Maps phash -> CachePage and resets the per-page tree fields used below.
    page_index = {}
    def build_page_index():
        nonlocal page_index
        page_index = {}
        for page in self.all_pages:
            page_index[page.phash] = page
            page.children = []
            page.longest_chain = 1
    build_page_index()

    # Find cached sequences that can be recovered
    # A page whose prev_hash has no live page becomes a root of its own tree.
    root_pages = []
    def build_root_pages():
        nonlocal root_pages
        root_pages = []
        for page in self.all_pages:
            if page.prev_hash is None:
                root_pages.append(page)
            else:
                parent = page_index.get(page.prev_hash)
                if parent is not None:
                    parent.children.append(page)
    build_root_pages()

    # Measure recoverable sequence length
    def measure(p):
        p.longest_chain = 1
        if p.children:
            p.longest_chain += max([measure(pc) for pc in p.children])
        return p.longest_chain
    for page in root_pages:
        measure(page)

    # Recursively sort branches by length
    # Longest subtree first, so children[0] always continues the longest chain.
    def sort_seq(p):
        if len(p.children) > 1:
            p.children = sorted(p.children, key = lambda x: x.longest_chain, reverse = True)
        for pc in p.children:
            sort_seq(pc)
    for page in root_pages:
        sort_seq(page)

    # Process roots in order of increasing age
    root_pages = sorted(root_pages, key = lambda x: x.access_serial)

    # Maintain the longest sequence for each tree and create new root nodes from trimmed branches
    # Trimmed siblings are re-queued as roots so every page ends up on exactly one chain.
    index = 0
    while index < len(root_pages):
        page = root_pages[index]
        while page.children:
            root_pages += page.children[1:]
            page.children = page.children[:1]
            page = page.children[0]
        index += 1

    # Reorder partial sequences into the longest possible contiguous strings
    # shift_counts histograms (new - old) index deltas to pick the cheapest global shift later.
    new_page_index = 0
    shift_counts = defaultdict(int)
    non_orphaned_pages = []
    orphans = page_index
    for page in root_pages:
        while True:
            non_orphaned_pages.append(page)
            del orphans[page.phash]
            page.new_page_index = new_page_index
            shift = page.new_page_index - page.page_index
            shift_counts[shift] += 1
            new_page_index += 1
            if not page.children:
                break
            page = page.children[0]

    # Move orphans to end of cache, ordered by last access
    if orphans:
        orphans = list(orphans.values())
        orphans = sorted(orphans, key = lambda x: x.page_index)
        access_serials = [page.access_serial for page in orphans]
        access_serials = sorted(access_serials)
        for page, access_serial in zip(orphans, access_serials):
            page.access_serial = access_serial
            page.new_page_index = new_page_index
            shift = page.new_page_index - page.page_index
            shift_counts[shift] += 1
            new_page_index += 1
    assert new_page_index == self.max_pages

    # Adjust overall shift to minimize page copies
    shift_adjust = max(shift_counts, key = shift_counts.get)

    # Order of operations
    # defrag_map: new physical slot -> current physical slot, for pages that must move.
    if debug:
        print("Page shifts")
    defrag_map = {}
    for page in self.all_pages:
        page.new_page_index = (page.new_page_index - shift_adjust + self.max_pages) % self.max_pages
        if page.page_index != page.new_page_index:
            defrag_map[page.new_page_index] = page.page_index
            if debug:
                # NOTE(review): a separator glyph between the two fields may have been
                # lost in transcription of this debug format — verify against upstream.
                print(f"{page.new_page_index:2}{page.page_index:2}")

    # Don't bother if less than 10% of cache is fragmented
    if len(defrag_map) <= max(self.max_pages // 10, 2):
        return

    # Get all tensors to reshuffle
    cache_tensors = self.cache.get_all_tensors()

    if debug:
        print("Page rotations")

    # Find page rotations
    # Each permutation cycle becomes a flat [dst, src, ...] list; -1 marks the
    # temp page used to break the cycle (matches the cache_rotate kernel contract).
    all_rotations = []
    while defrag_map:
        # Get first dst,src pair in new loop
        dst = next(iter(defrag_map))
        src = defrag_map[dst]
        del defrag_map[dst]
        rotation = [dst, src]
        # Walk around loop
        while True:
            if src == rotation[0]:
                rotation = [-1, src] + rotation[:-1] + [-1]
                all_rotations += rotation
                break
            dst = src
            src = defrag_map[dst]
            del defrag_map[dst]
            rotation += [dst, src]
        if debug:
            print("".join([".."] + [f"{rotation[i + 1]:2}" for i in range(0, len(rotation) - 2, 2)] + [".."]))

    # Rotate pages
    # Per-device copies of the rotation list and scratch page are memoized,
    # since layers may live on different devices.
    all_rotations_cpu = torch.tensor(all_rotations, dtype = torch.int)
    @lru_cache
    def get_all_rotations(device):
        nonlocal all_rotations_cpu
        return all_rotations_cpu.to(device)
    @lru_cache
    def get_buffer(shape, device, dtype):
        return torch.empty(shape, device = device, dtype = dtype)
    for cache in cache_tensors:
        buffer = get_buffer(cache[0].shape, cache.device, cache.dtype)
        all_rotations = get_all_rotations(cache.device)
        ext.cache_rotate(cache, all_rotations, buffer)

    # Write new page indices
    for page in self.all_pages:
        page.page_index = page.new_page_index

    # Debug stuff
    if debug:
        build_page_index()
        build_root_pages()
        def dbg_walk(l, p):
            nonlocal walks
            l = l + [p]
            if not p.children:
                walks.append(l)
            else:
                for p in p.children:
                    dbg_walk(l, p)
        print("Cache seqs")
        for page in root_pages:
            walks = []
            dbg_walk([], page)
            for pp in walks:
                print("".join([f"{p.page_index:2}" for p in pp]))
        torch.cuda.synchronize()
        elapsed = time.time() - time_begin
        print(f"Defrag latency: {elapsed:.5f} s")

View File

@@ -1,6 +1,6 @@
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from exllamav3 import Config, Model, Cache, Tokenizer, Generator, Job, ArgmaxSampler
from exllamav3 import Config, Model, Cache, Tokenizer, Generator, Job, ArgmaxSampler, CacheLayer_quant
import random
"""
@@ -9,12 +9,22 @@ the paging and caching logic is sound. Each prompt is an increasing sequence of
is verified as a correct continuation of the sequence.
"""
model_dir = "/mnt/str/eval_models/llama3.1-8b-instruct/exl3/4.0bpw/"
# ANSI codes
ESC = "\u001b"
col_default = "\u001b[0m"
col_yellow = "\u001b[33;1m"
col_blue = "\u001b[34;1m"
col_green = "\u001b[32;1m"
col_red = "\u001b[31;1m"
col_gray = "\u001b[37;1m"
model_dir = "/mnt/str/models/llama3.2-1b-instruct/exl3/5.0bpw/"
cache_size = 16384
draft_model_dir = None
prompt_len = (50, 8192)
completion_len = (50, 2048)
target_q_depth = (0, 20)
prompt_len = (50, 4096)
completion_len = (50, 768)
target_q_depth = (0, 25)
force_depth_0_interval = 3
prefixes = ["All the numbers: ", "It never ends: ", "Counting forever: "]
suffix = ", ".join([str(i) for i in range(prompt_len[1])])
random.seed(0)
@@ -29,8 +39,14 @@ else:
config = Config.from_directory(model_dir)
model = Model.from_config(config)
cache = Cache(model, max_num_tokens = cache_size)
model.load("cuda:1")
cache = Cache(
model,
max_num_tokens = cache_size,
# layer_type = CacheLayer_quant,
# k_bits = 5,
# v_bits = 3,
)
model.load("cuda:2")
tokenizer = Tokenizer.from_config(config)
@@ -40,6 +56,7 @@ generator = Generator(
draft_model = draft_model,
draft_cache = draft_cache,
tokenizer = tokenizer,
show_visualizer = True, # Slows down the test but makes it less boring
)
def start_new_job():
@@ -87,11 +104,13 @@ def iterate():
else:
print("Sus!")
print("--------")
print(full)
pr = result["identifier"]
print(col_green + pr + col_red + full[len(pr):] + col_default)
print("--------")
# Main loop
next_target_q_depth = 0
depth_0_interval = force_depth_0_interval
while True:
# Iterate until target q depth is reached
@@ -99,13 +118,19 @@ while True:
print(f" - Generating, target depth {next_target_q_depth}")
while generator.num_remaining_jobs() > next_target_q_depth:
iterate()
# print ("iter:", generator.num_remaining_jobs())
next_target_q_depth = random.randint(target_q_depth[0], target_q_depth[1])
next_target_q_depth = random.randint(target_q_depth[0] + 1, target_q_depth[1])
# Start new jobs until target queue depth is achieved
if generator.num_remaining_jobs() < next_target_q_depth:
print(f" - Creating jobs, target depth {next_target_q_depth}")
while generator.num_remaining_jobs() < next_target_q_depth:
start_new_job()
# print ("add:", generator.num_remaining_jobs())
next_target_q_depth = random.randint(target_q_depth[0], generator.num_remaining_jobs() - 1)
# Force the queue to reach zero depth to trigger more defragmentation steps
depth_0_interval -= 1
if depth_0_interval == 0:
next_target_q_depth = 0
depth_0_interval = force_depth_0_interval
else:
next_target_q_depth = random.randint(target_q_depth[0], generator.num_remaining_jobs() - 1)

View File

@@ -37,15 +37,23 @@ def test_rope(cache_dim, cache_dtype, full):
if not full:
order = order[:num_pages // 4]
order = order.repeat_interleave(2)
m1 = torch.tensor([-1], device = device, dtype = torch.int)
order = torch.cat([m1, order, m1], dim = -1)
if not full:
order = torch.cat([order, order], dim = -1)
ref_cache = cache.clone()
ref_order = order.tolist()
for _ in range(3):
temp = torch.empty_like(ref_cache[0])
temp.copy_(ref_cache[ref_order[0], ...])
for a, b in pairwise(ref_order):
ref_cache[a, ...].copy_(ref_cache[b, ...])
ref_cache[ref_order[-1], ...].copy_(temp)
for i in range(0, len(ref_order), 2):
a = ref_order[i]
b = ref_order[i + 1]
dst = ref_cache[a, ...] if a >= 0 else temp
src = ref_cache[b, ...] if b >= 0 else temp
dst.copy_(src)
temp = torch.empty_like(cache[0])
ext.cache_rotate(cache, order, temp)