Optimize paged cache defrag

This commit is contained in:
turboderp
2025-05-27 00:52:44 +02:00
parent 1adff7d827
commit a811641c3b
4 changed files with 136 additions and 14 deletions

View File

@@ -1,4 +1,7 @@
#include "cache.cuh"
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_fp16.h>
#include "quant/qdq_util.cuh"
#include "util.cuh"
@@ -492,3 +495,82 @@ void array_q_to_fp16_kv_cuda
dim, offset, stride
);
}
#define NUM_THREADS 512
#define NUM_BLOCKS 128
#define CEIL_DIVIDE(x, size) (((x) + (size) - 1) / (size))

// Rotate a cycle of cache pages in place using 16-byte vectorized copies.
// - cache: base pointer of the paged cache (pages of page_size bytes each)
// - order: page indices forming the rotation cycle, length rotate_len
// - temp:  scratch buffer of page_size bytes (holds the first page)
// Preconditions (enforced by the host wrapper): page_size % 16 == 0,
// rotate_len >= 1, all pointers 16-byte aligned.
__global__ __launch_bounds__(NUM_THREADS)
void cache_rotate_kernel
(
    uint8_t* __restrict__ cache,
    const uint32_t* __restrict__ order,
    uint8_t* __restrict__ temp,
    size_t page_size,
    size_t rotate_len
)
{
    // Nothing to rotate; also guards the order[0] read and the i + 1 loop bound below
    if (!rotate_len) return;

    // Chunk for current CTA. Round the chunk size up to a multiple of 16 so
    // every CTA's base offset stays 16-byte aligned for the uint4 copies
    // below; otherwise a page_size not divisible by 16 * gridDim.x would give
    // later CTAs misaligned base addresses and fault.
    size_t block_size = CEIL_DIVIDE(page_size, (size_t) gridDim.x);
    block_size = CEIL_DIVIDE(block_size, (size_t) 16) * 16;
    size_t block_beg = blockIdx.x * block_size;
    size_t block_end = min(block_beg + block_size, page_size);
    if (block_beg >= block_end) return;
    block_size = block_end - block_beg;

    // Copy one chunk of a page. Each thread owns a fixed set of 16-byte slots
    // (same offsets in every page), so a thread only ever reads and writes its
    // own slots and successive copies need no block-wide synchronization.
    auto copy = [&](uint8_t* dst, const uint8_t* src)
    {
        for (size_t offset = (size_t) threadIdx.x * 16; offset < block_size; offset += (size_t) NUM_THREADS * 16)
            *((uint4*) (dst + offset)) = *((const uint4*) (src + offset));
    };

    // temp <- page[order[0]]
    // page[order[i]] <- page[order[i + 1]] for i in [0, rotate_len - 1)
    // page[order[rotate_len - 1]] <- temp
    size_t i;
    copy(temp + block_beg, cache + page_size * (uint64_t) order[0] + block_beg);
    for (i = 0; i + 1 < rotate_len; ++i)
        copy(cache + page_size * (uint64_t) order[i] + block_beg, cache + page_size * (uint64_t) order[i + 1] + block_beg);
    copy(cache + page_size * (uint64_t) order[i] + block_beg, temp + block_beg);
}
/*
Reorder cache pages
- cache, paged cache, shape (num_pages, ...), any dtype, contiguous
    - order, sequence to rotate, shape (n,), dtype int (read by the kernel as uint32)
- temp, temp storage, sized as one cache page
Performs:
temp <- page[order[0]]
for a, b in pairwise(order):
page[a] <- page[b]
page[order[-1]] <- temp
*/
// Host wrapper for cache_rotate_kernel: validates arguments and launches the
// rotation on the current CUDA stream.
// - cache: contiguous paged cache, shape (num_pages, ...), any dtype
// - order: page indices forming the rotation cycle, shape (n,), dtype int32
// - temp:  scratch tensor sized as exactly one cache page
void cache_rotate
(
    const at::Tensor& cache,
    const at::Tensor& order,
    const at::Tensor& temp
)
{
    const at::cuda::OptionalCUDAGuard device_guard(cache.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    TORCH_CHECK(cache.dim() > 1, "cache argument must have dim >= 2");
    // page_size below is derived as nbytes / num_pages, which is only the true
    // per-page stride for a contiguous tensor
    TORCH_CHECK(cache.is_contiguous(), "cache argument must be contiguous");
    TORCH_CHECK(order.dim() == 1, "order argument must have dim == 1");
    // The kernel reinterprets order's storage as uint32_t*, so a 64-bit index
    // tensor would be read incorrectly
    TORCH_CHECK(order.scalar_type() == at::kInt, "order argument must have dtype int32");

    size_t num_pages = cache.size(0);
    size_t page_size = cache.nbytes() / num_pages;
    size_t rotate_len = order.size(0);

    // The kernel reads order[0] unconditionally and copies in 16-byte uint4 units
    TORCH_CHECK(rotate_len > 0, "order argument must not be empty");
    TORCH_CHECK(page_size % 16 == 0, "cache page size must be a multiple of 16 bytes");
    TORCH_CHECK(temp.nbytes() == page_size, "temp tensor incorrect size");

    cache_rotate_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>
    (
        (uint8_t*) cache.data_ptr(),
        (const uint32_t*) order.data_ptr(),
        (uint8_t*) temp.data_ptr(),
        page_size,
        rotate_len
    );
}

View File

@@ -6,6 +6,8 @@
#include <cstdint>
#include <cstdio>
#include <ATen/Tensor.h>
void array_fp16_to_fp8_cuda
(
cudaStream_t stream,
@@ -100,4 +102,11 @@ void array_q_to_fp16_kv_paged_cuda
// void array_fp16_to_fp8_ref_cuda(const half* pIn, unsigned char *pOut, int size);
// void array_fp8_to_fp16_ref_cuda(const unsigned char* pIn, half* pOut, int size);
// Rotate a cycle of paged-cache pages on the GPU:
//   temp <- page[order[0]]; page[order[i]] <- page[order[i+1]]; page[order[-1]] <- temp
// - cache: contiguous paged cache, shape (num_pages, ...), any dtype
// - order: page indices forming the rotation cycle, shape (n,); the kernel
//   reads this as uint32, so pass an int32 tensor
// - temp:  scratch tensor sized as exactly one cache page
void cache_rotate
(
    const at::Tensor& cache,
    const at::Tensor& order,
    const at::Tensor& temp
);
#endif

View File

@@ -22,6 +22,8 @@
#include "ext_element.h"
#include "ext_tp.h"
#include "cuda/cache.cuh"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
// quant
@@ -95,6 +97,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("count_match", &count_match, "count_match");
// m.def("array_fp16_to_fp8_ref", &array_fp16_to_fp8_ref, "array_fp16_to_fp8_ref");
// m.def("array_fp8_to_fp16_ref", &array_fp8_to_fp16_ref, "array_fp8_to_fp16_ref");
m.def("cache_rotate", &cache_rotate, "cache_rotate");
// hadamard

View File

@@ -1352,27 +1352,34 @@ class ExLlamaV2DynamicGenerator:
if not self.paged:
return
# Defragment once job queue is empty after touching all the cache pages
if self.access_serial < self.last_defrag_serial + self.max_pages:
return
self.last_defrag_serial = self.access_serial
assert not self.referenced_pages
@dataclass
class CacheNode:
page: CachePage | None
parent: CachePage | None = None
children: set[CacheNode] = None
left_page: int = len(self.all_pages)
parent: CacheNode | None
children: set[CacheNode] | None
children_sorted: deque[CacheNode] | None
left_page: int = 0
def __init__(self, page_):
self.page = page_
if self.page:
self.left_page = page_.access_serial
self.parent = None
self.children = set()
self.children_sorted = None
self.left_page = page_.access_serial if page_ else 0
def __hash__(self):
return id(self)
def __eq__(self, other):
return self is other
def presort(self, recursive = True):
self.children_sorted = deque(sorted(self.children, key = lambda x: x.left_page))
if recursive:
for c in self.children:
c.presort()
# Build a tree of the current cache
@@ -1393,28 +1400,50 @@ class ExLlamaV2DynamicGenerator:
# Remove oldest branch until tree is empty
root_node.presort()
shift_counts = {}
new_page_index = 0
while root_node.children:
oldest = min(root_node.children, key = lambda x: x.left_page)
oldest = root_node.children_sorted[0]
node = oldest
skipped_nodes = set()
while True:
node.page.new_page_index = new_page_index
shift = node.page.new_page_index - node.page.page_index
if shift in shift_counts:
shift_counts[shift] += 1
else:
shift_counts[shift] = 1
new_page_index += 1
if not node.children: break
next_node = min(node.children, key = lambda x: x.left_page)
skipped_nodes |= set([n for n in node.children if n != next_node])
next_node = node.children_sorted[0]
if len(node.children_sorted) > 1:
skipped_nodes |= set([n for n in node.children if n != next_node])
node = next_node
root_node.children.remove(oldest)
root_node.children_sorted.popleft()
root_node.children |= skipped_nodes
if len(skipped_nodes):
root_node.presort(False)
# Adjust overall shift to minimize page copies
shift_adjust = max(shift_counts, key = shift_counts.get)
# Order of operations
defrag_map = {}
for page in self.all_pages:
page.new_page_index = (page.new_page_index - shift_adjust + self.max_pages) % self.max_pages
if page.page_index != page.new_page_index:
defrag_map[page.new_page_index] = page.page_index
# Don't bother if less than 10% of cache is fragmented
if len(defrag_map) <= self.max_pages // 10:
return
# Shuffle pages
cache_tensors = self.cache.all_tensors()
@@ -1435,12 +1464,11 @@ class ExLlamaV2DynamicGenerator:
source = defrag_map[target]
del defrag_map[target]
rotation = [r * self.page_size for r in rotation]
rotation = torch.tensor(rotation, dtype = torch.int)
for cache, buffer in zip(cache_tensors, defrag_buffers):
buffer[:, :, :, :].copy_(cache[:, rotation[0] : rotation[0] + self.page_size, :, :])
for a, b in pairwise(rotation):
cache[:, a : a + self.page_size, :, :].copy_(cache[:, b : b + self.page_size, :, :])
cache[:, rotation[-1] : rotation[-1] + self.page_size, :, :].copy_(buffer[:, :, :, :])
rotation = rotation.to(cache.device)
cache = cache.view(cache.shape[1] // self.page_size, -1)
ext_c.cache_rotate(cache, rotation, buffer)
# Update page table