Merge branch 'dev'

This commit is contained in:
kingbri
2025-05-27 11:13:29 -04:00
9 changed files with 156 additions and 97 deletions

View File

@@ -1,4 +1,7 @@
#include "cache.cuh"
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_fp16.h>
#include "quant/qdq_util.cuh"
#include "util.cuh"
@@ -492,3 +495,82 @@ void array_q_to_fp16_kv_cuda
dim, offset, stride
);
}
#define NUM_THREADS 512
#define NUM_BLOCKS 128
#define CEIL_DIVIDE(x, size) (((x) + (size) - 1) / (size))

// Rotate one cycle of cache pages in place, using temp as scratch for the first page.
//
// Launch: <<<NUM_BLOCKS, NUM_THREADS>>> on the caller's stream. Each CTA owns one
// contiguous byte range [block_beg, block_end) of *every* page, so CTAs never touch
// each other's bytes and no inter-block synchronization is required.
//
// Preconditions (enforced by the host wrapper):
// - page_size is a multiple of 16 bytes (copies are done as uint4)
// - cache and temp are 16-byte aligned (true for torch allocations)
// - order holds rotate_len valid page indices
__global__ __launch_bounds__(NUM_THREADS)
void cache_rotate_kernel
(
    uint8_t* __restrict__ cache,
    const uint32_t* __restrict__ order,
    uint8_t* __restrict__ temp,
    size_t page_size,
    size_t rotate_len
)
{
    if (!rotate_len) return;

    // Chunk for current CTA. Round the chunk size up to a multiple of 16 so every
    // chunk starts on a 16-byte boundary: the uint4 loads/stores below require
    // 16-byte alignment, and rounding also keeps chunks from splitting a vector.
    size_t block_size = CEIL_DIVIDE(page_size, gridDim.x);
    block_size = CEIL_DIVIDE(block_size, 16) * 16;
    size_t block_beg = (size_t) blockIdx.x * block_size;
    size_t block_end = min(block_beg + block_size, page_size);
    if (block_beg >= block_end) return;
    block_size = block_end - block_beg;

    // Copy this CTA's chunk, 16 bytes per thread per iteration. Every thread owns
    // the same set of offsets in every call, so consecutive copies by the same
    // thread are ordered and no barrier is needed between them.
    auto copy = [&](uint8_t* dst, const uint8_t* src)
    {
        for (size_t offset = (size_t) threadIdx.x * 16; offset < block_size; offset += (size_t) NUM_THREADS * 16)
            *((uint4*) (dst + offset)) = *((const uint4*) (src + offset));
    };

    // temp <- page[order[0]]; page[order[i]] <- page[order[i+1]]; page[order[-1]] <- temp
    copy(temp + block_beg, cache + page_size * (uint64_t) order[0] + block_beg);
    size_t i = 0;
    // "i + 1 < rotate_len" instead of "i < rotate_len - 1": the latter underflows
    // to SIZE_MAX when rotate_len == 0 (size_t arithmetic)
    for (; i + 1 < rotate_len; ++i)
        copy(cache + page_size * (uint64_t) order[i] + block_beg, cache + page_size * (uint64_t) order[i + 1] + block_beg);
    copy(cache + page_size * (uint64_t) order[i] + block_beg, temp + block_beg);
}
/*
Reorder cache pages by rotating one cycle

- cache, paged cache, shape (num_pages, ...), any dtype, contiguous
- order, sequence of page indices to rotate, shape (n,), dtype int (reinterpreted as uint32 on device)
- temp, temp storage, sized as one cache page

Performs:

    temp <- page[order[0]]
    for a, b in pairwise(order):
        page[a] <- page[b]
    page[order[-1]] <- temp

Launches asynchronously on the current CUDA stream.
*/
void cache_rotate
(
    const at::Tensor& cache,
    const at::Tensor& order,
    const at::Tensor& temp
)
{
    const at::cuda::OptionalCUDAGuard device_guard(cache.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    TORCH_CHECK(cache.dim() > 1, "cache argument must have dim >= 2");
    TORCH_CHECK(order.dim() == 1, "order argument must have dim == 1");
    TORCH_CHECK(cache.is_contiguous(), "cache argument must be contiguous");
    TORCH_CHECK(order.is_contiguous(), "order argument must be contiguous");
    TORCH_CHECK(order.scalar_type() == at::kInt, "order argument must have dtype int");

    size_t num_pages = cache.size(0);
    size_t page_size = cache.nbytes() / num_pages;
    size_t rotate_len = order.size(0);
    TORCH_CHECK(temp.nbytes() == page_size, "temp tensor incorrect size");
    // Kernel copies pages in 16-byte (uint4) units; a misaligned page size would
    // fault or overrun the temp buffer
    TORCH_CHECK(page_size % 16 == 0, "cache page size must be a multiple of 16 bytes");

    // Empty rotation is a no-op; also protects the kernel from reading order[0]
    // out of bounds
    if (rotate_len == 0) return;

    cache_rotate_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>
    (
        (uint8_t*) cache.data_ptr(),
        (const uint32_t*) order.data_ptr(),
        (uint8_t*) temp.data_ptr(),
        page_size,
        rotate_len
    );
}

View File

@@ -6,6 +6,8 @@
#include <cstdint>
#include <cstdio>
#include <ATen/Tensor.h>
void array_fp16_to_fp8_cuda
(
cudaStream_t stream,
@@ -100,4 +102,11 @@ void array_q_to_fp16_kv_paged_cuda
// void array_fp16_to_fp8_ref_cuda(const half* pIn, unsigned char *pOut, int size);
// void array_fp8_to_fp16_ref_cuda(const unsigned char* pIn, half* pOut, int size);
void cache_rotate
(
const at::Tensor& cache,
const at::Tensor& order,
const at::Tensor& temp
);
#endif

View File

@@ -22,6 +22,8 @@
#include "ext_element.h"
#include "ext_tp.h"
#include "cuda/cache.cuh"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
// quant
@@ -95,6 +97,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("count_match", &count_match, "count_match");
// m.def("array_fp16_to_fp8_ref", &array_fp16_to_fp8_ref, "array_fp16_to_fp8_ref");
// m.def("array_fp8_to_fp16_ref", &array_fp8_to_fp16_ref, "array_fp8_to_fp16_ref");
m.def("cache_rotate", &cache_rotate, "cache_rotate");
// hadamard

View File

@@ -1352,27 +1352,34 @@ class ExLlamaV2DynamicGenerator:
if not self.paged:
return
# Defragment once job queue is empty after touching all the cache pages
if self.access_serial < self.last_defrag_serial + self.max_pages:
return
self.last_defrag_serial = self.access_serial
assert not self.referenced_pages
@dataclass
class CacheNode:
page: CachePage | None
parent: CachePage | None = None
children: set[CacheNode] = None
left_page: int = len(self.all_pages)
parent: CacheNode | None
children: set[CacheNode] | None
children_sorted: deque[CacheNode] | None
left_page: int = 0
def __init__(self, page_):
self.page = page_
if self.page:
self.left_page = page_.access_serial
self.parent = None
self.children = set()
self.children_sorted = None
self.left_page = page_.access_serial if page_ else 0
def __hash__(self):
return id(self)
def __eq__(self, other):
return self is other
def presort(self, recursive = True):
self.children_sorted = deque(sorted(self.children, key = lambda x: x.left_page))
if recursive:
for c in self.children:
c.presort()
# Build a tree of the current cache
@@ -1393,28 +1400,50 @@ class ExLlamaV2DynamicGenerator:
# Remove oldest branch until tree is empty
root_node.presort()
shift_counts = {}
new_page_index = 0
while root_node.children:
oldest = min(root_node.children, key = lambda x: x.left_page)
oldest = root_node.children_sorted[0]
node = oldest
skipped_nodes = set()
while True:
node.page.new_page_index = new_page_index
shift = node.page.new_page_index - node.page.page_index
if shift in shift_counts:
shift_counts[shift] += 1
else:
shift_counts[shift] = 1
new_page_index += 1
if not node.children: break
next_node = min(node.children, key = lambda x: x.left_page)
skipped_nodes |= set([n for n in node.children if n != next_node])
next_node = node.children_sorted[0]
if len(node.children_sorted) > 1:
skipped_nodes |= set([n for n in node.children if n != next_node])
node = next_node
root_node.children.remove(oldest)
root_node.children_sorted.popleft()
root_node.children |= skipped_nodes
if len(skipped_nodes):
root_node.presort(False)
# Adjust overall shift to minimize page copies
shift_adjust = max(shift_counts, key = shift_counts.get)
# Order of operations
defrag_map = {}
for page in self.all_pages:
page.new_page_index = (page.new_page_index - shift_adjust + self.max_pages) % self.max_pages
if page.page_index != page.new_page_index:
defrag_map[page.new_page_index] = page.page_index
# Don't bother if less than 10% of cache is fragmented
if len(defrag_map) <= self.max_pages // 10:
return
# Shuffle pages
cache_tensors = self.cache.all_tensors()
@@ -1435,12 +1464,11 @@ class ExLlamaV2DynamicGenerator:
source = defrag_map[target]
del defrag_map[target]
rotation = [r * self.page_size for r in rotation]
rotation = torch.tensor(rotation, dtype = torch.int)
for cache, buffer in zip(cache_tensors, defrag_buffers):
buffer[:, :, :, :].copy_(cache[:, rotation[0] : rotation[0] + self.page_size, :, :])
for a, b in pairwise(rotation):
cache[:, a : a + self.page_size, :, :].copy_(cache[:, b : b + self.page_size, :, :])
cache[:, rotation[-1] : rotation[-1] + self.page_size, :, :].copy_(buffer[:, :, :, :])
rotation = rotation.to(cache.device)
cache = cache.view(cache.shape[1] // self.page_size, -1)
ext_c.cache_rotate(cache, rotation, buffer)
# Update page table

View File

@@ -1,5 +1,4 @@
from exllamav2.version import __version__
from exllamav2.tokenizer.base import ExLlamaV2TokenizerBase
from exllamav2.tokenizer.spm import ExLlamaV2TokenizerSPM
from exllamav2.tokenizer.hf import ExLlamaV2TokenizerHF

View File

@@ -1,57 +0,0 @@
from __future__ import annotations
from typing import List, Union
from sentencepiece import SentencePieceProcessor
from exllamav2.tokenizer.base import ExLlamaV2TokenizerBase
# Wrapper for SentencePiece
class ExLlamaV2TokenizerSPM(ExLlamaV2TokenizerBase):
    """
    Wrapper for SentencePiece tokenizers (tokenizer.model files).

    Note: `int | None` (PEP 604) replaces the original `int or None` annotations,
    which evaluate to just `int` at runtime and silently drop the None half.
    The `from __future__ import annotations` at the top of this file keeps the
    syntax valid on older interpreters.
    """

    # Lazily-built decoded vocabulary; populated on first enumerate_tokens() call
    vocab: list[str] | None

    def __init__(self, tokenizer_model: str):
        """
        :param tokenizer_model: path to a tokenizer.model file
        """
        super().__init__()
        self.vocab = None
        self.spm = SentencePieceProcessor(model_file = tokenizer_model)

    # Special token IDs, forwarded from the SPM model (-1 convention applies per SPM)
    def unk_id(self) -> int | None: return self.spm.unk_id()
    def pad_id(self) -> int | None: return self.spm.pad_id()
    def bos_id(self) -> int | None: return self.spm.bos_id()
    def eos_id(self) -> int | None: return self.spm.eos_id()

    # Special token strings are not exposed by SPM
    def unk_token(self) -> str | None: return None
    def pad_token(self) -> str | None: return None
    def bos_token(self) -> str | None: return None
    def eos_token(self) -> str | None: return None

    def space_char(self): return ""
    def newline_char(self): return "\n"

    def enumerate_tokens(self):
        """
        Yield (id, decoded_piece) for the whole vocabulary, building and caching
        the decoded vocab on first call. Leading-space pieces are re-prefixed
        with a literal space since SPM decode strips it.
        """
        if self.vocab is not None: return enumerate(self.vocab)
        self.vocab = []
        for i in range(self.vocab_size()):
            p = self.spm.id_to_piece(i)
            if all(c == self.space_char() for c in p):
                d = " " * len(p)
            else:
                d = self.spm.decode(i)
                if p.startswith(self.space_char()) and not d.startswith(" "): d = " " + d
            self.vocab.append(d)
        return enumerate(self.vocab)

    def id_to_piece(self, idx: int) -> str:
        return self.spm.id_to_piece(idx)

    def piece_to_id(self, text: str) -> int:
        return self.spm.piece_to_id(text)

    def vocab_size(self) -> int:
        return self.spm.vocab_size()

    def decode(self, ids: List[int]) -> str:
        text = self.spm.decode(ids)
        return text

    def encode(self, text: list | str) -> list:
        encoding = self.spm.EncodeAsIds(text)
        return encoding

View File

@@ -5,7 +5,6 @@ import torch
import os, json, re
from exllamav2.tokenizer import (
ExLlamaV2TokenizerBase,
ExLlamaV2TokenizerSPM,
ExLlamaV2TokenizerHF
)
import threading
@@ -93,13 +92,12 @@ class ExLlamaV2Tokenizer:
Defer initialization of some data structures to speed up loading
:param force_json:
No effect from v0.2.3. tokenizer.json is now preferred over tokenizer.model by default.
If True and no tokenizer.json is present in the model directory, will emit a warning before
falling back to SPM
No effect from v0.2.3. tokenizer.json is now preferred over tokenizer.model by default. From v0.3.1
tokenizer.model is not used at all
:param force_spm:
Use only tokenizer.model (SentencePiece) even if tokenizer.model (HF Tokenizers)
is available
Deprecated, Sentencepiece is abandoned and no longer supported. All SPM tokenizers should
still load correctly via the Tokenizers library
"""
self.config = config
@@ -123,33 +121,31 @@ class ExLlamaV2Tokenizer:
# Detect tokenizer model type and initialize
path_spm = os.path.join(self.config.model_dir, "tokenizer.model")
assert not force_spm, "tokenizer.py: force_spm is deprecated. Sentencepiece is no longer supported."
path_hf = os.path.join(self.config.model_dir, "tokenizer.json")
if os.path.exists(path_hf) and not force_spm:
self.tokenizer_model = ExLlamaV2TokenizerHF(path_hf)
elif os.path.exists(path_spm):
if force_json:
print(" !! Warning: Tokenizer loading with force_json = True but no tokenizer.json found, falling back to tokenizer.model")
self.tokenizer_model = ExLlamaV2TokenizerSPM(path_spm)
else:
if not os.path.exists(path_hf):
raise FileNotFoundError("No supported tokenizer found.")
self.tokenizer_model = ExLlamaV2TokenizerHF(path_hf)
# Attempt to load added tokens from tokenizer.json
self.extended_piece_to_id = {}
self.unspecial_piece_to_id = {}
tokenizer_json_path = os.path.join(self.config.model_dir, "tokenizer.json")
if os.path.exists(tokenizer_json_path):
with open(tokenizer_json_path, encoding = "utf8") as f:
tokenizer_json = json.load(f)
if "added_tokens" in tokenizer_json:
for v in tokenizer_json["added_tokens"]:
if v["special"]:
self.extended_piece_to_id[v["content"]] = v["id"]
else:
self.unspecial_piece_to_id[v["content"]] = v["id"]
if not os.path.exists(tokenizer_json_path):
raise ValueError(" ## Model does not include a tokenizer.json file. SentencePiece-only tokenizers are no longer supported")
with open(tokenizer_json_path, encoding = "utf8") as f:
tokenizer_json = json.load(f)
if "added_tokens" in tokenizer_json:
for v in tokenizer_json["added_tokens"]:
if v["special"]:
self.extended_piece_to_id[v["content"]] = v["id"]
else:
self.unspecial_piece_to_id[v["content"]] = v["id"]
# Attempt to load tokenizer_config.json

View File

@@ -1 +1 @@
__version__ = "0.3.0"
__version__ = "0.3.1"

View File

@@ -5,7 +5,6 @@ setuptools
fastparquet
torch>=2.2.0
safetensors>=0.4.3
sentencepiece>=0.1.97
pygments
websockets
regex