#define NUM_THREADS 512
#define NUM_BLOCKS 128
#define CEIL_DIVIDE(x, size) (((x) + (size) - 1) / (size))

/*
Rotate a sequence of cache pages in place, using a single launch for the whole sequence.

- cache, base pointer of the paged cache, addressed as raw bytes
- order, page indices to rotate, 32-bit, rotate_len entries
- temp, scratch storage of page_size bytes
- page_size, size of one page in bytes, must be a multiple of 16
- rotate_len, number of entries in order

Launch layout: 1D grid of any size, 1D block of at most NUM_THREADS threads, no shared memory.
All three pointers must be 16-byte aligned because data is moved as uint4. Indices in order are
not range checked.

Each CTA owns one contiguous byte range of the page and each thread owns a fixed set of 16-byte
slots inside that range. A thread touches the same slots of every page in the sequence, so the
dependencies between successive copies stay within a single thread and no barrier is needed.
*/

__global__ __launch_bounds__(NUM_THREADS)
void cache_rotate_kernel
(
    uint8_t* __restrict__ cache,
    const uint32_t* __restrict__ order,
    uint8_t* __restrict__ temp,
    size_t page_size,
    size_t rotate_len
)
{
    // Nothing to rotate; also keeps order[0] from being read on an empty sequence
    if (!rotate_len) return;

    constexpr size_t vec_bytes = sizeof(uint4);

    // Chunk for current CTA, measured in whole 16-byte vectors so every chunk starts and ends on
    // a 16-byte boundary
    size_t page_vecs = page_size / vec_bytes;
    size_t block_vecs = CEIL_DIVIDE(page_vecs, (size_t) gridDim.x);
    size_t vec_beg = (size_t) blockIdx.x * block_vecs;

    // Trailing CTAs may have no work. Test before subtracting, the sizes are unsigned
    if (vec_beg >= page_vecs) return;

    size_t vec_end = min(vec_beg + block_vecs, page_vecs);
    size_t block_beg = vec_beg * vec_bytes;
    size_t block_size = (vec_end - vec_beg) * vec_bytes;

    // Copy this CTA's chunk, 16 bytes per thread per step
    auto copy = [&](uint8_t* dst, const uint8_t* src)
    {
        size_t step = (size_t) blockDim.x * vec_bytes;
        for (size_t offset = (size_t) threadIdx.x * vec_bytes; offset < block_size; offset += step)
            *((uint4*) (dst + offset)) = *((const uint4*) (src + offset));
    };

    // Start of this CTA's chunk within the page named by order[idx]
    auto page_chunk = [&](size_t idx) -> uint8_t*
    {
        return cache + page_size * (uint64_t) order[idx] + block_beg;
    };

    uint8_t* temp_chunk = temp + block_beg;

    // Rotate pages
    copy(temp_chunk, page_chunk(0));
    for (size_t i = 0; i + 1 < rotate_len; ++i)
        copy(page_chunk(i), page_chunk(i + 1));
    copy(page_chunk(rotate_len - 1), temp_chunk);
}

/*
Reorder cache pages
- cache, paged cache, shape (num_pages, ...), any dtype, contiguous, CUDA
- order, sequence to rotate, shape (n,), dtype int32, contiguous, same device as cache
- temp, temp storage, sized as one cache page, contiguous, same device as cache

The size of one page in bytes must be a multiple of 16. An empty order is a no-op. Page indices
in order are not validated against num_pages.

Performs:

temp <- page[order[0]]
for a, b in pairwise(order):
    page[a] <- page[b]
page[order[-1]] <- temp

Declared in cache.cuh and bound to Python as "cache_rotate" in ext_bindings.cpp.
*/

void cache_rotate
(
    const at::Tensor& cache,
    const at::Tensor& order,
    const at::Tensor& temp
)
{
    TORCH_CHECK(cache.is_cuda(), "cache argument must be a CUDA tensor");
    TORCH_CHECK(cache.dim() > 1, "cache argument must have dim >= 2");
    TORCH_CHECK(order.dim() == 1, "order argument must have dim == 1");

    // The kernel reads 32-bit indices; any other dtype would be silently misinterpreted
    TORCH_CHECK(order.scalar_type() == at::kInt, "order argument must have dtype int32");

    TORCH_CHECK(order.device() == cache.device(), "order argument must be on the same device as cache");
    TORCH_CHECK(temp.device() == cache.device(), "temp argument must be on the same device as cache");
    TORCH_CHECK(cache.is_contiguous(), "cache argument must be contiguous");
    TORCH_CHECK(order.is_contiguous(), "order argument must be contiguous");
    TORCH_CHECK(temp.is_contiguous(), "temp argument must be contiguous");
    TORCH_CHECK(cache.size(0) > 0, "cache argument must contain at least one page");

    const at::cuda::OptionalCUDAGuard device_guard(cache.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    size_t num_pages = cache.size(0);
    size_t page_size = cache.nbytes() / num_pages;
    size_t rotate_len = order.size(0);

    TORCH_CHECK(temp.nbytes() == page_size, "temp tensor incorrect size");

    // Pages are moved as uint4, so both the page size and the base addresses must be 16-byte aligned
    TORCH_CHECK(page_size % sizeof(uint4) == 0, "cache page size must be a multiple of 16 bytes");
    auto is_aligned = [](const void* ptr)
    {
        return reinterpret_cast<uintptr_t>(ptr) % sizeof(uint4) == 0;
    };
    TORCH_CHECK(is_aligned(cache.data_ptr()), "cache argument must be 16-byte aligned");
    TORCH_CHECK(is_aligned(temp.data_ptr()), "temp argument must be 16-byte aligned");

    if (rotate_len == 0) return;

    // NOTE(review): launch configuration reconstructed from NUM_BLOCKS / NUM_THREADS and the
    // current stream obtained above - confirm against the upstream file
    cache_rotate_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>
    (
        (uint8_t*) cache.data_ptr(),
        (const uint32_t*) order.data_ptr(),
        (uint8_t*) temp.data_ptr(),
        page_size,
        rotate_len
    );

    // Launches do not return a status; pick up configuration errors here
    cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess, "cache_rotate_kernel launch failed: ", cudaGetErrorString(launch_err));
}
@@ -22,6 +22,8 @@ #include "ext_element.h" #include "ext_tp.h" +#include "cuda/cache.cuh" + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // quant @@ -95,6 +97,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) m.def("count_match", &count_match, "count_match"); // m.def("array_fp16_to_fp8_ref", &array_fp16_to_fp8_ref, "array_fp16_to_fp8_ref"); // m.def("array_fp8_to_fp16_ref", &array_fp8_to_fp16_ref, "array_fp8_to_fp16_ref"); + m.def("cache_rotate", &cache_rotate, "cache_rotate"); // hadamard diff --git a/exllamav2/generator/dynamic.py b/exllamav2/generator/dynamic.py index 07a121f..c292eb4 100644 --- a/exllamav2/generator/dynamic.py +++ b/exllamav2/generator/dynamic.py @@ -1352,27 +1352,34 @@ class ExLlamaV2DynamicGenerator: if not self.paged: return + # Defragment once job queue is empty after touching all the cache pages if self.access_serial < self.last_defrag_serial + self.max_pages: return self.last_defrag_serial = self.access_serial assert not self.referenced_pages - @dataclass class CacheNode: page: CachePage | None - parent: CachePage | None = None - children: set[CacheNode] = None - left_page: int = len(self.all_pages) + parent: CacheNode | None + children: set[CacheNode] | None + children_sorted: deque[CacheNode] | None + left_page: int = 0 def __init__(self, page_): self.page = page_ - if self.page: - self.left_page = page_.access_serial + self.parent = None self.children = set() + self.children_sorted = None + self.left_page = page_.access_serial if page_ else 0 def __hash__(self): return id(self) def __eq__(self, other): return self is other + def presort(self, recursive = True): + self.children_sorted = deque(sorted(self.children, key = lambda x: x.left_page)) + if recursive: + for c in self.children: + c.presort() # Build a tree of the current cache @@ -1393,28 +1400,50 @@ class ExLlamaV2DynamicGenerator: # Remove oldest branch until tree is empty + root_node.presort() + shift_counts = {} + new_page_index = 0 while root_node.children: - oldest = 
min(root_node.children, key = lambda x: x.left_page) + oldest = root_node.children_sorted[0] node = oldest skipped_nodes = set() while True: node.page.new_page_index = new_page_index + shift = node.page.new_page_index - node.page.page_index + if shift in shift_counts: + shift_counts[shift] += 1 + else: + shift_counts[shift] = 1 new_page_index += 1 if not node.children: break - next_node = min(node.children, key = lambda x: x.left_page) - skipped_nodes |= set([n for n in node.children if n != next_node]) + next_node = node.children_sorted[0] + if len(node.children_sorted) > 1: + skipped_nodes |= set([n for n in node.children if n != next_node]) node = next_node root_node.children.remove(oldest) + root_node.children_sorted.popleft() root_node.children |= skipped_nodes + if len(skipped_nodes): + root_node.presort(False) + + # Adjust overall shift to minimize page copies + + shift_adjust = max(shift_counts, key = shift_counts.get) # Order of operations defrag_map = {} for page in self.all_pages: + page.new_page_index = (page.new_page_index - shift_adjust + self.max_pages) % self.max_pages if page.page_index != page.new_page_index: defrag_map[page.new_page_index] = page.page_index + # Don't bother if less than 10% of cache is fragmented + + if len(defrag_map) <= self.max_pages // 10: + return + # Shuffle pages cache_tensors = self.cache.all_tensors() @@ -1435,12 +1464,11 @@ class ExLlamaV2DynamicGenerator: source = defrag_map[target] del defrag_map[target] - rotation = [r * self.page_size for r in rotation] + rotation = torch.tensor(rotation, dtype = torch.int) for cache, buffer in zip(cache_tensors, defrag_buffers): - buffer[:, :, :, :].copy_(cache[:, rotation[0] : rotation[0] + self.page_size, :, :]) - for a, b in pairwise(rotation): - cache[:, a : a + self.page_size, :, :].copy_(cache[:, b : b + self.page_size, :, :]) - cache[:, rotation[-1] : rotation[-1] + self.page_size, :, :].copy_(buffer[:, :, :, :]) + rotation = rotation.to(cache.device) + cache = 
cache.view(cache.shape[1] // self.page_size, -1) + ext_c.cache_rotate(cache, rotation, buffer) # Update page table