mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-03-15 00:07:26 +00:00
Merge branch 'dev'
This commit is contained in:
@@ -1,4 +1,7 @@
|
||||
#include "cache.cuh"
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include "quant/qdq_util.cuh"
|
||||
#include "util.cuh"
|
||||
@@ -492,3 +495,82 @@ void array_q_to_fp16_kv_cuda
|
||||
dim, offset, stride
|
||||
);
|
||||
}
|
||||
|
||||
#define NUM_THREADS 512
#define NUM_BLOCKS 128
#define CEIL_DIVIDE(x, size) (((x) + (size) - 1) / (size))

// Rotate one cycle of cache pages in place, using temp as scratch for the first page:
//
//   temp <- page[order[0]]
//   page[order[i]] <- page[order[i + 1]]   for i in [0, rotate_len - 1)
//   page[order[rotate_len - 1]] <- temp
//
// Each CTA owns one contiguous byte range of a page's layout and performs the entire
// rotation for that range, so no inter-block synchronization is required.
//
// Preconditions (not checked here):
// - copies are done as 16-byte uint4 loads/stores, so page_size is assumed to be a
//   multiple of 16 * gridDim.x. NOTE(review): holds for typical paged-KV page sizes;
//   confirm for unusual cache shapes
// - order holds rotate_len valid, in-range page indices
__global__ __launch_bounds__(NUM_THREADS)
void cache_rotate_kernel
(
    uint8_t* __restrict__ cache,
    const uint32_t* __restrict__ order,
    uint8_t* __restrict__ temp,
    size_t page_size,
    size_t rotate_len
)
{
    // Empty rotation is a no-op. Also guards the unsigned arithmetic below
    // (rotate_len - 1 would wrap) and the unconditional read of order[0].
    if (!rotate_len) return;

    // Chunk for current CTA. Compare before subtracting: when page_size is small,
    // block_beg can exceed page_size, and block_end - block_beg would wrap.
    size_t block_size = CEIL_DIVIDE(page_size, gridDim.x);
    size_t block_beg = blockIdx.x * block_size;
    size_t block_end = min(block_beg + block_size, page_size);
    if (block_beg >= block_end) return;
    block_size = block_end - block_beg;

    // 16 bytes per thread per iteration, strided across the CTA's chunk
    auto copy = [&] (uint8_t* dst, const uint8_t* src)
    {
        for (size_t offset = (size_t) threadIdx.x * 16; offset < block_size; offset += (size_t) NUM_THREADS * 16)
            *((uint4*) (dst + offset)) = *((const uint4*) (src + offset));
    };

    // Rotate pages
    size_t i;
    copy(temp + block_beg, cache + page_size * (uint64_t) order[0] + block_beg);
    for (i = 0; i + 1 < rotate_len; ++i)
        copy(cache + page_size * (uint64_t) order[i] + block_beg, cache + page_size * (uint64_t) order[i + 1] + block_beg);
    copy(cache + page_size * (uint64_t) order[i] + block_beg, temp + block_beg);
}
|
||||
|
||||
/*
|
||||
Reorder cache pages
|
||||
- cache, paged cache, shape (num_pages, ...), any dtype, contiguous
|
||||
 - order, sequence to rotate, shape (n,), dtype int32 (kernel reads it as uint32)
|
||||
- temp, temp storage, sized as one cache page
|
||||
|
||||
Performs:
|
||||
|
||||
temp <- page[order[0]]
|
||||
for a, b in pairwise(order):
|
||||
page[a] <- page[b]
|
||||
page[order[-1]] <- temp
|
||||
*/
|
||||
|
||||
// Host wrapper: validate arguments and launch cache_rotate_kernel on the current
// CUDA stream for cache's device.
//
// - cache: paged cache, shape (num_pages, ...), any dtype, contiguous
// - order: page indices forming the rotation cycle, shape (n,), dtype int32,
//   values non-negative (reinterpreted as uint32 by the kernel)
// - temp: scratch buffer sized as exactly one cache page
void cache_rotate
(
    const at::Tensor& cache,
    const at::Tensor& order,
    const at::Tensor& temp
)
{
    const at::cuda::OptionalCUDAGuard device_guard(cache.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    TORCH_CHECK(cache.dim() > 1, "cache argument must have dim >= 2")
    TORCH_CHECK(cache.is_contiguous(), "cache argument must be contiguous")
    TORCH_CHECK(order.dim() == 1, "order argument must have dim == 1")
    TORCH_CHECK(order.is_contiguous(), "order argument must be contiguous")
    // The kernel hard-casts order's data to uint32_t*, so anything but int32
    // (e.g. the default long of torch.tensor) would silently misread indices
    TORCH_CHECK(order.scalar_type() == at::kInt, "order argument must have dtype int32")

    size_t num_pages = cache.size(0);
    TORCH_CHECK(num_pages > 0, "cache argument must have at least one page")
    size_t page_size = cache.nbytes() / num_pages;
    size_t rotate_len = order.size(0);

    TORCH_CHECK(temp.nbytes() == page_size, "temp tensor incorrect size");

    // Empty rotation is a no-op; skip the launch entirely
    if (!rotate_len) return;

    cache_rotate_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>
    (
        (uint8_t*) cache.data_ptr(),
        (const uint32_t*) order.data_ptr(),
        (uint8_t*) temp.data_ptr(),
        page_size,
        rotate_len
    );
}
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
|
||||
#include <ATen/Tensor.h>
|
||||
|
||||
void array_fp16_to_fp8_cuda
|
||||
(
|
||||
cudaStream_t stream,
|
||||
@@ -100,4 +102,11 @@ void array_q_to_fp16_kv_paged_cuda
|
||||
// void array_fp16_to_fp8_ref_cuda(const half* pIn, unsigned char *pOut, int size);
|
||||
// void array_fp8_to_fp16_ref_cuda(const unsigned char* pIn, half* pOut, int size);
|
||||
|
||||
void cache_rotate
|
||||
(
|
||||
const at::Tensor& cache,
|
||||
const at::Tensor& order,
|
||||
const at::Tensor& temp
|
||||
);
|
||||
|
||||
#endif
|
||||
|
||||
@@ -22,6 +22,8 @@
|
||||
#include "ext_element.h"
|
||||
#include "ext_tp.h"
|
||||
|
||||
#include "cuda/cache.cuh"
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
// quant
|
||||
@@ -95,6 +97,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
m.def("count_match", &count_match, "count_match");
|
||||
// m.def("array_fp16_to_fp8_ref", &array_fp16_to_fp8_ref, "array_fp16_to_fp8_ref");
|
||||
// m.def("array_fp8_to_fp16_ref", &array_fp8_to_fp16_ref, "array_fp8_to_fp16_ref");
|
||||
m.def("cache_rotate", &cache_rotate, "cache_rotate");
|
||||
|
||||
// hadamard
|
||||
|
||||
|
||||
@@ -1352,27 +1352,34 @@ class ExLlamaV2DynamicGenerator:
|
||||
if not self.paged:
|
||||
return
|
||||
|
||||
# Defragment once job queue is empty after touching all the cache pages
|
||||
if self.access_serial < self.last_defrag_serial + self.max_pages:
|
||||
return
|
||||
self.last_defrag_serial = self.access_serial
|
||||
|
||||
assert not self.referenced_pages
|
||||
|
||||
@dataclass
|
||||
class CacheNode:
|
||||
page: CachePage | None
|
||||
parent: CachePage | None = None
|
||||
children: set[CacheNode] = None
|
||||
left_page: int = len(self.all_pages)
|
||||
parent: CacheNode | None
|
||||
children: set[CacheNode] | None
|
||||
children_sorted: deque[CacheNode] | None
|
||||
left_page: int = 0
|
||||
def __init__(self, page_):
|
||||
self.page = page_
|
||||
if self.page:
|
||||
self.left_page = page_.access_serial
|
||||
self.parent = None
|
||||
self.children = set()
|
||||
self.children_sorted = None
|
||||
self.left_page = page_.access_serial if page_ else 0
|
||||
def __hash__(self):
|
||||
return id(self)
|
||||
def __eq__(self, other):
|
||||
return self is other
|
||||
def presort(self, recursive = True):
|
||||
self.children_sorted = deque(sorted(self.children, key = lambda x: x.left_page))
|
||||
if recursive:
|
||||
for c in self.children:
|
||||
c.presort()
|
||||
|
||||
# Build a tree of the current cache
|
||||
|
||||
@@ -1393,28 +1400,50 @@ class ExLlamaV2DynamicGenerator:
|
||||
|
||||
# Remove oldest branch until tree is empty
|
||||
|
||||
root_node.presort()
|
||||
shift_counts = {}
|
||||
|
||||
new_page_index = 0
|
||||
while root_node.children:
|
||||
oldest = min(root_node.children, key = lambda x: x.left_page)
|
||||
oldest = root_node.children_sorted[0]
|
||||
node = oldest
|
||||
skipped_nodes = set()
|
||||
while True:
|
||||
node.page.new_page_index = new_page_index
|
||||
shift = node.page.new_page_index - node.page.page_index
|
||||
if shift in shift_counts:
|
||||
shift_counts[shift] += 1
|
||||
else:
|
||||
shift_counts[shift] = 1
|
||||
new_page_index += 1
|
||||
if not node.children: break
|
||||
next_node = min(node.children, key = lambda x: x.left_page)
|
||||
skipped_nodes |= set([n for n in node.children if n != next_node])
|
||||
next_node = node.children_sorted[0]
|
||||
if len(node.children_sorted) > 1:
|
||||
skipped_nodes |= set([n for n in node.children if n != next_node])
|
||||
node = next_node
|
||||
root_node.children.remove(oldest)
|
||||
root_node.children_sorted.popleft()
|
||||
root_node.children |= skipped_nodes
|
||||
if len(skipped_nodes):
|
||||
root_node.presort(False)
|
||||
|
||||
# Adjust overall shift to minimize page copies
|
||||
|
||||
shift_adjust = max(shift_counts, key = shift_counts.get)
|
||||
|
||||
# Order of operations
|
||||
|
||||
defrag_map = {}
|
||||
for page in self.all_pages:
|
||||
page.new_page_index = (page.new_page_index - shift_adjust + self.max_pages) % self.max_pages
|
||||
if page.page_index != page.new_page_index:
|
||||
defrag_map[page.new_page_index] = page.page_index
|
||||
|
||||
# Don't bother if less than 10% of cache is fragmented
|
||||
|
||||
if len(defrag_map) <= self.max_pages // 10:
|
||||
return
|
||||
|
||||
# Shuffle pages
|
||||
|
||||
cache_tensors = self.cache.all_tensors()
|
||||
@@ -1435,12 +1464,11 @@ class ExLlamaV2DynamicGenerator:
|
||||
source = defrag_map[target]
|
||||
del defrag_map[target]
|
||||
|
||||
rotation = [r * self.page_size for r in rotation]
|
||||
rotation = torch.tensor(rotation, dtype = torch.int)
|
||||
for cache, buffer in zip(cache_tensors, defrag_buffers):
|
||||
buffer[:, :, :, :].copy_(cache[:, rotation[0] : rotation[0] + self.page_size, :, :])
|
||||
for a, b in pairwise(rotation):
|
||||
cache[:, a : a + self.page_size, :, :].copy_(cache[:, b : b + self.page_size, :, :])
|
||||
cache[:, rotation[-1] : rotation[-1] + self.page_size, :, :].copy_(buffer[:, :, :, :])
|
||||
rotation = rotation.to(cache.device)
|
||||
cache = cache.view(cache.shape[1] // self.page_size, -1)
|
||||
ext_c.cache_rotate(cache, rotation, buffer)
|
||||
|
||||
# Update page table
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from exllamav2.version import __version__
|
||||
|
||||
from exllamav2.tokenizer.base import ExLlamaV2TokenizerBase
|
||||
from exllamav2.tokenizer.spm import ExLlamaV2TokenizerSPM
|
||||
from exllamav2.tokenizer.hf import ExLlamaV2TokenizerHF
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
from __future__ import annotations
|
||||
from typing import List, Union
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
from exllamav2.tokenizer.base import ExLlamaV2TokenizerBase
|
||||
|
||||
# Wrapper for SentencePiece
|
||||
|
||||
class ExLlamaV2TokenizerSPM(ExLlamaV2TokenizerBase):
    """
    SentencePiece-backed tokenizer model, wrapping a SentencePieceProcessor loaded
    from a tokenizer.model file. Implements the ExLlamaV2TokenizerBase interface.
    """

    # Lazily-built decoded vocabulary, populated on first enumerate_tokens() call
    vocab: list[str] | None

    def __init__(self, tokenizer_model: str):
        super().__init__()
        self.vocab = None
        self.spm = SentencePieceProcessor(model_file = tokenizer_model)

    # Special token IDs, delegated to the underlying SPM model
    def unk_id(self) -> int or None: return self.spm.unk_id()
    def pad_id(self) -> int or None: return self.spm.pad_id()
    def bos_id(self) -> int or None: return self.spm.bos_id()
    def eos_id(self) -> int or None: return self.spm.eos_id()

    # SPM does not expose special token strings directly
    def unk_token(self) -> str or None: return None
    def pad_token(self) -> str or None: return None
    def bos_token(self) -> str or None: return None
    def eos_token(self) -> str or None: return None

    def space_char(self): return "▁"
    def newline_char(self): return "\n"

    def _piece_text(self, idx: int) -> str:
        # Decoded text for a single piece. Pieces made entirely of the SPM space
        # marker decode to the equivalent run of spaces; otherwise decode normally
        # and restore a leading space the decoder may have dropped.
        piece = self.spm.id_to_piece(idx)
        space = self.space_char()
        if all(c == space for c in piece):
            return " " * len(piece)
        text = self.spm.decode(idx)
        if piece.startswith(space) and not text.startswith(" "):
            text = " " + text
        return text

    def enumerate_tokens(self):
        # Build and cache the decoded vocabulary on first use
        if self.vocab is None:
            self.vocab = [self._piece_text(i) for i in range(self.vocab_size())]
        return enumerate(self.vocab)

    def id_to_piece(self, idx: int) -> str:
        return self.spm.id_to_piece(idx)

    def piece_to_id(self, text: str) -> int:
        return self.spm.piece_to_id(text)

    def vocab_size(self) -> int:
        return self.spm.vocab_size()

    def decode(self, ids: List[int]) -> str:
        return self.spm.decode(ids)

    def encode(self, text: list or str) -> list:
        return self.spm.EncodeAsIds(text)
|
||||
@@ -5,7 +5,6 @@ import torch
|
||||
import os, json, re
|
||||
from exllamav2.tokenizer import (
|
||||
ExLlamaV2TokenizerBase,
|
||||
ExLlamaV2TokenizerSPM,
|
||||
ExLlamaV2TokenizerHF
|
||||
)
|
||||
import threading
|
||||
@@ -93,13 +92,12 @@ class ExLlamaV2Tokenizer:
|
||||
Defer initialization of some data structures to speed up loading
|
||||
|
||||
:param force_json:
|
||||
No effect from v0.2.3. tokenizer.json is now preferred over tokenizer.model by default.
|
||||
If True and no tokenizer.json is present in the model directory, will emit a warning before
|
||||
falling back to SPM
|
||||
No effect from v0.2.3. tokenizer.json is now preferred over tokenizer.model by default. From v0.3.1
|
||||
tokenizer.model is not used at all
|
||||
|
||||
:param force_spm:
|
||||
        Use only tokenizer.model (SentencePiece) even if tokenizer.json (HF Tokenizers)
|
||||
is available
|
||||
Deprecated, Sentencepiece is abandoned and no longer supported. All SPM tokenizers should
|
||||
still load correctly via the Tokenizers library
|
||||
"""
|
||||
|
||||
self.config = config
|
||||
@@ -123,33 +121,31 @@ class ExLlamaV2Tokenizer:
|
||||
|
||||
# Detect tokenizer model type and initialize
|
||||
|
||||
path_spm = os.path.join(self.config.model_dir, "tokenizer.model")
|
||||
assert not force_spm, "tokenizer.py: force_spm is deprecated. Sentencepiece is no longer supported."
|
||||
path_hf = os.path.join(self.config.model_dir, "tokenizer.json")
|
||||
|
||||
if os.path.exists(path_hf) and not force_spm:
|
||||
self.tokenizer_model = ExLlamaV2TokenizerHF(path_hf)
|
||||
elif os.path.exists(path_spm):
|
||||
if force_json:
|
||||
print(" !! Warning: Tokenizer loading with force_json = True but no tokenizer.json found, falling back to tokenizer.model")
|
||||
self.tokenizer_model = ExLlamaV2TokenizerSPM(path_spm)
|
||||
else:
|
||||
if not os.path.exists(path_hf):
|
||||
raise FileNotFoundError("No supported tokenizer found.")
|
||||
|
||||
self.tokenizer_model = ExLlamaV2TokenizerHF(path_hf)
|
||||
|
||||
# Attempt to load added tokens from tokenizer.json
|
||||
|
||||
self.extended_piece_to_id = {}
|
||||
self.unspecial_piece_to_id = {}
|
||||
|
||||
tokenizer_json_path = os.path.join(self.config.model_dir, "tokenizer.json")
|
||||
if os.path.exists(tokenizer_json_path):
|
||||
with open(tokenizer_json_path, encoding = "utf8") as f:
|
||||
tokenizer_json = json.load(f)
|
||||
if "added_tokens" in tokenizer_json:
|
||||
for v in tokenizer_json["added_tokens"]:
|
||||
if v["special"]:
|
||||
self.extended_piece_to_id[v["content"]] = v["id"]
|
||||
else:
|
||||
self.unspecial_piece_to_id[v["content"]] = v["id"]
|
||||
if not os.path.exists(tokenizer_json_path):
|
||||
raise ValueError(" ## Model does not include a tokenizer.json file. SentencePiece-only tokenizers are no longer supported")
|
||||
|
||||
with open(tokenizer_json_path, encoding = "utf8") as f:
|
||||
tokenizer_json = json.load(f)
|
||||
if "added_tokens" in tokenizer_json:
|
||||
for v in tokenizer_json["added_tokens"]:
|
||||
if v["special"]:
|
||||
self.extended_piece_to_id[v["content"]] = v["id"]
|
||||
else:
|
||||
self.unspecial_piece_to_id[v["content"]] = v["id"]
|
||||
|
||||
# Attempt to load tokenizer_config.json
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "0.3.0"
|
||||
__version__ = "0.3.1"
|
||||
@@ -5,7 +5,6 @@ setuptools
|
||||
fastparquet
|
||||
torch>=2.2.0
|
||||
safetensors>=0.4.3
|
||||
sentencepiece>=0.1.97
|
||||
pygments
|
||||
websockets
|
||||
regex
|
||||
|
||||
Reference in New Issue
Block a user