mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-04-20 06:19:10 +00:00
Generator: Periodically defragment paged cache
This commit is contained in:
11
exllamav3/cache/cache.py
vendored
11
exllamav3/cache/cache.py
vendored
@@ -49,6 +49,10 @@ class CacheLayer(ABC):
|
||||
def copy_page(self, source: CacheLayer, from_page: int, to_page: int, num_tokens: int):
|
||||
pass
|
||||
|
||||
@abstractmethod
def get_tensors(self):
    """
    Return the tensors that back this cache layer.

    Abstract hook: each concrete layer type supplies its own implementation. The base version
    does nothing and yields None.
    """
    return None
|
||||
|
||||
|
||||
class Cache:
|
||||
|
||||
@@ -154,3 +158,10 @@ class Cache:
|
||||
for src, dst in zip(target.layers, self.layers):
|
||||
assert type(src) is type(dst)
|
||||
dst.copy_page(src, from_page, to_page, num_tokens)
|
||||
|
||||
|
||||
def get_all_tensors(self):
    """
    Collect the tensors of every layer in this cache into one flat list.

    Layers are visited in the order they appear in self.layers, and each layer's own tensors
    keep the order in which its get_tensors() reports them.
    """
    collected = []
    for cache_layer in self.layers:
        collected.extend(cache_layer.get_tensors())
    return collected
|
||||
|
||||
6
exllamav3/cache/fp16.py
vendored
6
exllamav3/cache/fp16.py
vendored
@@ -67,4 +67,8 @@ class CacheLayer_fp16(CacheLayer):
|
||||
def copy_page(self, source: CacheLayer_fp16, from_page: int, to_page: int, num_tokens: int):
|
||||
assert self.shape == source.shape
|
||||
self.k[to_page, :num_tokens, :, :].copy_(source.k[from_page, :num_tokens, :, :], non_blocking = True)
|
||||
self.v[to_page, :num_tokens, :, :].copy_(source.v[from_page, :num_tokens, :, :], non_blocking = True)
|
||||
self.v[to_page, :num_tokens, :, :].copy_(source.v[from_page, :num_tokens, :, :], non_blocking = True)
|
||||
|
||||
@override
def get_tensors(self):
    """
    Return this layer's key and value tensors, in that order, as a new list.
    """
    keys, values = self.k, self.v
    return [keys, values]
|
||||
|
||||
6
exllamav3/cache/quant.py
vendored
6
exllamav3/cache/quant.py
vendored
@@ -97,4 +97,8 @@ class CacheLayer_quant(CacheLayer):
|
||||
self.qk[to_page, :num_tokens, :].copy_(source.qk[from_page, :num_tokens, :], non_blocking = True)
|
||||
self.qv[to_page, :num_tokens, :].copy_(source.qv[from_page, :num_tokens, :], non_blocking = True)
|
||||
self.sk[to_page, :num_tokens, :].copy_(source.sk[from_page, :num_tokens, :], non_blocking = True)
|
||||
self.sv[to_page, :num_tokens, :].copy_(source.sv[from_page, :num_tokens, :], non_blocking = True)
|
||||
self.sv[to_page, :num_tokens, :].copy_(source.sv[from_page, :num_tokens, :], non_blocking = True)
|
||||
|
||||
@override
def get_tensors(self):
    """
    Return this layer's four backing tensors as a new list, ordered qk, qv, sk, sv.
    """
    names = ("qk", "qv", "sk", "sv")
    return [getattr(self, name) for name in names]
|
||||
|
||||
@@ -12,7 +12,7 @@ __global__ __launch_bounds__(NUM_THREADS)
|
||||
void cache_rotate_kernel
|
||||
(
|
||||
uint8_t* __restrict__ cache,
|
||||
const uint32_t* __restrict__ order,
|
||||
const int32_t* __restrict__ order,
|
||||
uint8_t* __restrict__ temp,
|
||||
size_t page_size,
|
||||
size_t rotate_len
|
||||
@@ -23,34 +23,33 @@ void cache_rotate_kernel
|
||||
size_t block_beg = blockIdx.x * block_size;
|
||||
size_t block_end = MIN(block_beg + block_size, page_size);
|
||||
block_size = block_end - block_beg;
|
||||
if (!block_size) return;
|
||||
if (block_size <= 0) return;
|
||||
|
||||
// Rotate pages
|
||||
auto copy = [&](uint8_t* dst, uint8_t* src)
|
||||
for (int i = 0; i < rotate_len; ++i)
|
||||
{
|
||||
int64_t a = (int64_t) order[2 * i];
|
||||
int64_t b = (int64_t) order[2 * i + 1];
|
||||
uint8_t* dst = (a >= 0 ? cache + page_size * a : temp) + block_beg;
|
||||
uint8_t* src = (b >= 0 ? cache + page_size * b : temp) + block_beg;
|
||||
for (int offset = threadIdx.x * 16; offset < block_size; offset += NUM_THREADS * 16)
|
||||
*((uint4*) (dst + offset)) = *((uint4*) (src + offset));
|
||||
};
|
||||
|
||||
int i;
|
||||
copy(temp + block_beg, cache + page_size * (uint64_t) order[0] + block_beg);
|
||||
for (i = 0; i < rotate_len - 1; ++i)
|
||||
copy(cache + page_size * (uint64_t) order[i] + block_beg, cache + page_size * (uint64_t) order[i + 1] + block_beg);
|
||||
copy(cache + page_size * (uint64_t) order[i] + block_beg, temp + block_beg);
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
Reorder cache pages
|
||||
- cache, paged cache, shape (num_pages, ...), any dtype, contiguous
|
||||
- order, sequence to rotate, shape (n,), dtype long
|
||||
- order, sequence to rotate, shape (2*n,), dtype int
|
||||
- temp, temp storage, sized as one cache page
|
||||
|
||||
Performs:
|
||||
|
||||
temp <- page[order[0]]
|
||||
for a, b in pairwise(order):
|
||||
page[a] <- page[b]
|
||||
page[order[-1]] <- temp
|
||||
for i in range(n):
|
||||
a = order[2*i]
|
||||
b = order[2*i+1]
|
||||
copy: (page[a] if a >= 0 else temp) <- (page[b] if b >= 0 else temp)
|
||||
*/
|
||||
|
||||
void cache_rotate
|
||||
@@ -63,20 +62,20 @@ void cache_rotate
|
||||
const at::cuda::OptionalCUDAGuard device_guard(cache.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
TORCH_CHECK(cache.dim() > 1, "cache argument must have dim >= 2")
|
||||
TORCH_CHECK(cache.dim() >= 2, "cache argument must have dim >= 2")
|
||||
TORCH_CHECK(order.dim() == 1, "order argument must have dim == 1")
|
||||
TORCH_CHECK_DTYPE(order, kInt);
|
||||
|
||||
size_t num_pages = cache.size(0);
|
||||
size_t page_size = cache.nbytes() / num_pages;
|
||||
size_t rotate_len = order.size(0);
|
||||
size_t rotate_len = order.size(0) / 2;
|
||||
|
||||
TORCH_CHECK(temp.nbytes() == page_size, "temp tensor incorrect size");
|
||||
|
||||
cache_rotate_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>
|
||||
(
|
||||
(uint8_t*) cache.data_ptr(),
|
||||
(const uint32_t*) order.data_ptr(),
|
||||
(const int32_t*) order.data_ptr(),
|
||||
(uint8_t*) temp.data_ptr(),
|
||||
page_size,
|
||||
rotate_len
|
||||
|
||||
@@ -30,6 +30,7 @@ class Generator:
|
||||
draft_cache: Cache | None = None,
|
||||
num_draft_tokens: int = 4,
|
||||
show_visualizer: bool = False,
|
||||
enable_defrag: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
@@ -70,6 +71,9 @@ class Generator:
|
||||
:param show_visualizer:
|
||||
Open window to render visualization of cache (for debug/demonstration purposes)
|
||||
|
||||
:param enable_defrag:
|
||||
Defragment cache periodically
|
||||
|
||||
:param kwargs:
|
||||
"""
|
||||
|
||||
@@ -129,7 +133,8 @@ class Generator:
|
||||
else:
|
||||
self.visualizer = None
|
||||
|
||||
# TODO: (defrag)
|
||||
# Defrag
|
||||
self.enable_defrag = enable_defrag
|
||||
|
||||
|
||||
def num_remaining_jobs(self):
|
||||
@@ -195,10 +200,8 @@ class Generator:
|
||||
elif job in self.active_jobs:
|
||||
job.deallocate_pages()
|
||||
self.active_jobs.remove(job)
|
||||
|
||||
# TODO: (defrag)
|
||||
# if num_jobs and not self.num_remaining_jobs():
|
||||
# self.pagetable.defrag()
|
||||
if num_jobs and not self.num_remaining_jobs():
|
||||
self.pagetable.defrag()
|
||||
|
||||
|
||||
@torch.inference_mode
|
||||
@@ -297,9 +300,9 @@ class Generator:
|
||||
idx = job.serial_number
|
||||
chain = [page.page_index for page in seq.allocated_pages]
|
||||
chains.append((idx, chain))
|
||||
usage = []
|
||||
usage = [0] * self.pagetable.max_pages
|
||||
for page in self.pagetable.all_pages:
|
||||
usage.append(page.kv_position / PAGE_SIZE)
|
||||
usage[page.page_index] = page.kv_position / PAGE_SIZE
|
||||
self.visualizer.update(chains, usage)
|
||||
|
||||
def iterate_draftmodel_gen(self, results: list):
|
||||
@@ -503,15 +506,12 @@ class Generator:
|
||||
# mt_sample = False
|
||||
|
||||
# Release pages for completed jobs
|
||||
# num_jobs = self.num_remaining_jobs()
|
||||
num_jobs = self.num_remaining_jobs()
|
||||
for job in completed_jobs:
|
||||
job.deallocate_pages()
|
||||
self.active_jobs.remove(job)
|
||||
|
||||
# Defrag
|
||||
# TODO: (defrag)
|
||||
# if num_jobs and not self.num_remaining_jobs():
|
||||
# self.pagetable.defrag()
|
||||
if num_jobs and not self.num_remaining_jobs():
|
||||
self.pagetable.defrag()
|
||||
|
||||
|
||||
def iterate_start_jobs(self, results: list):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from functools import lru_cache
|
||||
import torch
|
||||
import hashlib
|
||||
from dataclasses import dataclass
|
||||
@@ -7,9 +8,12 @@ from ..cache.cache import Cache
|
||||
if TYPE_CHECKING:
|
||||
from .generator import Generator
|
||||
from ..constants import PAGE_SIZE
|
||||
from collections import deque
|
||||
from collections import deque, defaultdict
|
||||
from itertools import pairwise
|
||||
from ..util.tensor import SeqTensor
|
||||
from exllamav3.ext import exllamav3_ext as ext
|
||||
import time
|
||||
from ..util import profile_opt
|
||||
|
||||
|
||||
def _tensor_blake2b_checksum(tensor: torch.Tensor, prev_hash: bytes | None) -> bytes:
|
||||
@@ -60,6 +64,8 @@ class CachePage:
|
||||
|
||||
# Used by defragmenter
|
||||
new_page_index: int
|
||||
children: list[CachePage]
|
||||
longest_chain: int
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
@@ -68,7 +74,7 @@ class CachePage:
|
||||
f"kvp {self.kv_position}"
|
||||
)
|
||||
|
||||
# Copy page state so page can be reverted even
|
||||
# Copy page state so page can be reverted
|
||||
def backup(self):
|
||||
self.phash_revert = self.phash
|
||||
self.prev_hash_revert = self.prev_hash
|
||||
@@ -221,6 +227,7 @@ class PageTable:
|
||||
self.unreferenced_pages = {}
|
||||
self.all_pages = []
|
||||
self.reset_page_table()
|
||||
self.last_defrag_serial = self.max_pages
|
||||
|
||||
|
||||
def reset_page_table(self):
|
||||
@@ -242,17 +249,18 @@ class PageTable:
|
||||
sequence = torch.empty((1, PAGE_SIZE), dtype = torch.long),
|
||||
ref_count = 0,
|
||||
access_serial = idx,
|
||||
access_serial_revert = 0,
|
||||
access_serial_revert = idx,
|
||||
kv_position = 0,
|
||||
kv_position_revert = 0,
|
||||
can_revert = False,
|
||||
new_page_index = 0
|
||||
new_page_index = 0,
|
||||
children = [],
|
||||
longest_chain = 1,
|
||||
)
|
||||
self.all_pages.append(cp)
|
||||
self.unreferenced_pages[h] = cp
|
||||
self.access_serial = self.max_pages
|
||||
# TODO: (defrag)
|
||||
# self.last_defrag_serial = self.access_serial
|
||||
self.last_defrag_serial = self.access_serial
|
||||
|
||||
|
||||
def print_page_list(self, short: bool = True):
|
||||
@@ -270,11 +278,6 @@ class PageTable:
|
||||
print()
|
||||
|
||||
|
||||
def defrag(self):
|
||||
# TODO: (defrag)
|
||||
pass
|
||||
|
||||
|
||||
def allocate_pages(
|
||||
self,
|
||||
page_hashes: list,
|
||||
@@ -358,4 +361,205 @@ class PageTable:
|
||||
|
||||
|
||||
def num_unreferenced_pages(self):
|
||||
return len(self.unreferenced_pages)
|
||||
return len(self.unreferenced_pages)
|
||||
|
||||
|
||||
def defrag(self, debug = False):
    """
    Reorder the pages of the cache so that chains of cached pages become contiguous.

    The call is a no-op unless all of the following hold:
    - self.generator.enable_defrag is set
    - self.access_serial has advanced by at least self.max_pages since the last defrag
    - more than max(self.max_pages // 10, 2) pages would actually have to move

    Pages are linked into chains through prev_hash -> phash. Where a page has several children,
    only the child with the longest chain below it stays attached; the other branches become
    chains of their own. Chains are laid out back to back, followed by "orphans" (pages whose
    prev_hash matches no existing page). The resulting permutation is decomposed into cycles,
    and each cycle is applied to every tensor returned by self.cache.get_all_tensors() using
    ext.cache_rotate, after which page.page_index is updated for every page.

    Chain traversal is iterative, so the length of a cached sequence is not limited by the
    Python recursion limit.

    :param debug:
        Print the planned page shifts, the rotations, the resulting sequences and the measured
        latency. Synchronizes CUDA before and after in order to time the operation.

    :raises AssertionError:
        If any page is still referenced, or if the new layout does not cover exactly
        self.max_pages pages.
    """

    if not self.generator.enable_defrag:
        return

    # Defragment once job queue is empty and all pages have been touched at least once
    if self.access_serial < self.last_defrag_serial + self.max_pages:
        return
    self.last_defrag_serial = self.access_serial

    assert not self.referenced_pages

    if debug:
        torch.cuda.synchronize()
        time_begin = time.time()

    # Build page index: phash -> page, resetting the per-page defrag scratch fields
    page_index = {}
    def build_page_index():
        nonlocal page_index
        page_index = {}
        for page in self.all_pages:
            page_index[page.phash] = page
            page.children = []
            page.longest_chain = 1
    build_page_index()

    # Find cached sequences that can be recovered. Pages without a prev_hash start a chain, pages
    # whose prev_hash resolves are attached to their parent, and all others are left unattached
    root_pages = []
    def build_root_pages():
        nonlocal root_pages
        root_pages = []
        for page in self.all_pages:
            if page.prev_hash is None:
                root_pages.append(page)
            else:
                parent = page_index.get(page.prev_hash)
                if parent is not None:
                    parent.children.append(page)
    build_root_pages()

    # List every page reachable from a root so that each parent precedes all of its descendants.
    # Uses an explicit stack since a chain can be far deeper than the recursion limit allows
    def parents_first(root):
        ordered = []
        stack = [root]
        while stack:
            p = stack.pop()
            ordered.append(p)
            stack.extend(p.children)
        return ordered

    # Measure recoverable sequence length, then sort branches by length, longest first. Walking
    # the parents-first list backwards guarantees children are measured before their parent
    for root in root_pages:
        for p in reversed(parents_first(root)):
            p.longest_chain = 1 + max((pc.longest_chain for pc in p.children), default = 0)
            if len(p.children) > 1:
                p.children = sorted(p.children, key = lambda x: x.longest_chain, reverse = True)

    # Process roots in ascending order of access_serial
    root_pages = sorted(root_pages, key = lambda x: x.access_serial)

    # Maintain the longest sequence for each tree and create new root nodes from trimmed branches.
    # root_pages grows while it is being scanned, so trimmed branches are themselves trimmed later
    index = 0
    while index < len(root_pages):
        page = root_pages[index]
        while page.children:
            root_pages += page.children[1:]
            page.children = page.children[:1]
            page = page.children[0]
        index += 1

    # Reorder partial sequences into the longest possible contiguous strings. Every page placed
    # here is removed from the index, so whatever remains in it afterwards is an orphan
    new_page_index = 0
    shift_counts = defaultdict(int)
    orphans = page_index
    for page in root_pages:
        while True:
            del orphans[page.phash]
            page.new_page_index = new_page_index
            shift = page.new_page_index - page.page_index
            shift_counts[shift] += 1
            new_page_index += 1
            if not page.children:
                break
            page = page.children[0]

    # Move orphans to end of cache, ordered by last access. Orphans keep their relative position
    # and their access serials are redistributed among them in ascending order
    if orphans:
        orphans = list(orphans.values())
        orphans = sorted(orphans, key = lambda x: x.page_index)
        access_serials = [page.access_serial for page in orphans]
        access_serials = sorted(access_serials)
        for page, access_serial in zip(orphans, access_serials):
            page.access_serial = access_serial
            page.new_page_index = new_page_index
            shift = page.new_page_index - page.page_index
            shift_counts[shift] += 1
            new_page_index += 1

    assert new_page_index == self.max_pages

    # Adjust overall shift to minimize page copies: the most common displacement is subtracted
    # from every target index (modulo the cache size) so pages with that displacement stay put
    shift_adjust = max(shift_counts, key = shift_counts.get)

    # Order of operations
    if debug:
        print("Page shifts")

    # Maps destination page index -> source page index, for pages that have to move
    defrag_map = {}
    for page in self.all_pages:
        page.new_page_index = (page.new_page_index - shift_adjust + self.max_pages) % self.max_pages
        if page.page_index != page.new_page_index:
            defrag_map[page.new_page_index] = page.page_index
            if debug:
                print(f"{page.new_page_index:2} ← {page.page_index:2}")

    # Don't bother if less than 10% of cache is fragmented. Nothing has been copied yet and
    # page_index is untouched, so the cache contents and the page table still agree
    if len(defrag_map) <= max(self.max_pages // 10, 2):
        return

    # Get all tensors to reshuffle
    cache_tensors = self.cache.get_all_tensors()

    if debug:
        print("Page rotations")

    # Find page rotations. The map is a permutation of the moving pages, so following
    # dst <- src links from any entry eventually leads back to the starting page
    all_rotations = []
    while defrag_map:

        # Get first dst,src pair in new loop
        dst = next(iter(defrag_map))
        src = defrag_map[dst]
        del defrag_map[dst]
        rotation = [dst, src]

        # Walk around loop
        while True:
            if src == rotation[0]:
                # Loop closed: save the first page to the scratch buffer (index -1) before it is
                # overwritten, and source the final copy from the scratch buffer instead
                rotation = [-1, src] + rotation[:-1] + [-1]
                all_rotations += rotation
                break
            dst = src
            src = defrag_map[dst]
            del defrag_map[dst]
            rotation += [dst, src]

        if debug:
            print(" ← ".join([".."] + [f"{rotation[i + 1]:2}" for i in range(0, len(rotation) - 2, 2)] + [".."]))

    # Rotate pages. The rotation list and the one-page scratch buffer are created once per
    # distinct device (and shape/dtype) and reused for all tensors that share them
    all_rotations_cpu = torch.tensor(all_rotations, dtype = torch.int)

    @lru_cache
    def get_all_rotations(device):
        return all_rotations_cpu.to(device)

    @lru_cache
    def get_buffer(shape, device, dtype):
        return torch.empty(shape, device = device, dtype = dtype)

    for cache in cache_tensors:
        buffer = get_buffer(cache[0].shape, cache.device, cache.dtype)
        device_rotations = get_all_rotations(cache.device)
        ext.cache_rotate(cache, device_rotations, buffer)

    # Write new page indices
    for page in self.all_pages:
        page.page_index = page.new_page_index

    # Debug stuff
    if debug:
        build_page_index()
        build_root_pages()

        print("Cache seqs")
        for page in root_pages:

            # Collect every root-to-leaf path, visiting children in list order, without recursion
            walks = []
            pending = [[page]]
            while pending:
                path = pending.pop()
                tail = path[-1]
                if not tail.children:
                    walks.append(path)
                else:
                    pending.extend(path + [pc] for pc in reversed(tail.children))

            for pp in walks:
                print(" → ".join([f"{p.page_index:2}" for p in pp]))

        torch.cuda.synchronize()
        elapsed = time.time() - time_begin
        print(f"Defrag latency: {elapsed:.5f} s")
|
||||
@@ -1,6 +1,6 @@
|
||||
import sys, os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from exllamav3 import Config, Model, Cache, Tokenizer, Generator, Job, ArgmaxSampler
|
||||
from exllamav3 import Config, Model, Cache, Tokenizer, Generator, Job, ArgmaxSampler, CacheLayer_quant
|
||||
import random
|
||||
|
||||
"""
|
||||
@@ -9,12 +9,22 @@ the paging and caching logic is sound. Each prompt is an increasing sequence of
|
||||
is verified as a correct continuation of the sequence.
|
||||
"""
|
||||
|
||||
model_dir = "/mnt/str/eval_models/llama3.1-8b-instruct/exl3/4.0bpw/"
|
||||
# ANSI codes
|
||||
ESC = "\u001b"
|
||||
col_default = "\u001b[0m"
|
||||
col_yellow = "\u001b[33;1m"
|
||||
col_blue = "\u001b[34;1m"
|
||||
col_green = "\u001b[32;1m"
|
||||
col_red = "\u001b[31;1m"
|
||||
col_gray = "\u001b[37;1m"
|
||||
|
||||
model_dir = "/mnt/str/models/llama3.2-1b-instruct/exl3/5.0bpw/"
|
||||
cache_size = 16384
|
||||
draft_model_dir = None
|
||||
prompt_len = (50, 8192)
|
||||
completion_len = (50, 2048)
|
||||
target_q_depth = (0, 20)
|
||||
prompt_len = (50, 4096)
|
||||
completion_len = (50, 768)
|
||||
target_q_depth = (0, 25)
|
||||
force_depth_0_interval = 3
|
||||
prefixes = ["All the numbers: ", "It never ends: ", "Counting forever: "]
|
||||
suffix = ", ".join([str(i) for i in range(prompt_len[1])])
|
||||
random.seed(0)
|
||||
@@ -29,8 +39,14 @@ else:
|
||||
|
||||
config = Config.from_directory(model_dir)
|
||||
model = Model.from_config(config)
|
||||
cache = Cache(model, max_num_tokens = cache_size)
|
||||
model.load("cuda:1")
|
||||
cache = Cache(
|
||||
model,
|
||||
max_num_tokens = cache_size,
|
||||
# layer_type = CacheLayer_quant,
|
||||
# k_bits = 5,
|
||||
# v_bits = 3,
|
||||
)
|
||||
model.load("cuda:2")
|
||||
|
||||
tokenizer = Tokenizer.from_config(config)
|
||||
|
||||
@@ -40,6 +56,7 @@ generator = Generator(
|
||||
draft_model = draft_model,
|
||||
draft_cache = draft_cache,
|
||||
tokenizer = tokenizer,
|
||||
show_visualizer = True, # Slows down the test but makes it less boring
|
||||
)
|
||||
|
||||
def start_new_job():
|
||||
@@ -87,11 +104,13 @@ def iterate():
|
||||
else:
|
||||
print("Sus!")
|
||||
print("--------")
|
||||
print(full)
|
||||
pr = result["identifier"]
|
||||
print(col_green + pr + col_red + full[len(pr):] + col_default)
|
||||
print("--------")
|
||||
|
||||
# Main loop
|
||||
next_target_q_depth = 0
|
||||
depth_0_interval = force_depth_0_interval
|
||||
while True:
|
||||
|
||||
# Iterate until target q depth is reached
|
||||
@@ -99,13 +118,19 @@ while True:
|
||||
print(f" - Generating, target depth {next_target_q_depth}")
|
||||
while generator.num_remaining_jobs() > next_target_q_depth:
|
||||
iterate()
|
||||
# print ("iter:", generator.num_remaining_jobs())
|
||||
next_target_q_depth = random.randint(target_q_depth[0], target_q_depth[1])
|
||||
|
||||
next_target_q_depth = random.randint(target_q_depth[0] + 1, target_q_depth[1])
|
||||
|
||||
# Start new jobs until target queue depth is achieved
|
||||
if generator.num_remaining_jobs() < next_target_q_depth:
|
||||
print(f" - Creating jobs, target depth {next_target_q_depth}")
|
||||
while generator.num_remaining_jobs() < next_target_q_depth:
|
||||
start_new_job()
|
||||
# print ("add:", generator.num_remaining_jobs())
|
||||
next_target_q_depth = random.randint(target_q_depth[0], generator.num_remaining_jobs() - 1)
|
||||
|
||||
# Force the queue to reach zero depth to trigger more defragmentation steps
|
||||
depth_0_interval -= 1
|
||||
if depth_0_interval == 0:
|
||||
next_target_q_depth = 0
|
||||
depth_0_interval = force_depth_0_interval
|
||||
else:
|
||||
next_target_q_depth = random.randint(target_q_depth[0], generator.num_remaining_jobs() - 1)
|
||||
|
||||
@@ -37,15 +37,23 @@ def test_rope(cache_dim, cache_dtype, full):
|
||||
if not full:
|
||||
order = order[:num_pages // 4]
|
||||
|
||||
order = order.repeat_interleave(2)
|
||||
m1 = torch.tensor([-1], device = device, dtype = torch.int)
|
||||
order = torch.cat([m1, order, m1], dim = -1)
|
||||
if not full:
|
||||
order = torch.cat([order, order], dim = -1)
|
||||
|
||||
ref_cache = cache.clone()
|
||||
ref_order = order.tolist()
|
||||
|
||||
for _ in range(3):
|
||||
temp = torch.empty_like(ref_cache[0])
|
||||
temp.copy_(ref_cache[ref_order[0], ...])
|
||||
for a, b in pairwise(ref_order):
|
||||
ref_cache[a, ...].copy_(ref_cache[b, ...])
|
||||
ref_cache[ref_order[-1], ...].copy_(temp)
|
||||
for i in range(0, len(ref_order), 2):
|
||||
a = ref_order[i]
|
||||
b = ref_order[i + 1]
|
||||
dst = ref_cache[a, ...] if a >= 0 else temp
|
||||
src = ref_cache[b, ...] if b >= 0 else temp
|
||||
dst.copy_(src)
|
||||
|
||||
temp = torch.empty_like(cache[0])
|
||||
ext.cache_rotate(cache, order, temp)
|
||||
|
||||
Reference in New Issue
Block a user