mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 06:19:00 +00:00
Asynchronous filter evaluation
This commit is contained in:
@@ -232,6 +232,8 @@ class ExLlamaV2DynamicGenerator:
|
||||
max_sampling_threads: int
|
||||
min_sampling_threads: int
|
||||
sampling_pool: ThreadPoolExecutor
|
||||
filter_pool: ThreadPoolExecutor
|
||||
filter_queue: list
|
||||
|
||||
|
||||
def __init__(
|
||||
@@ -443,6 +445,11 @@ class ExLlamaV2DynamicGenerator:
|
||||
if max_sampling_threads > 1:
|
||||
self.sampling_pool = ThreadPoolExecutor(max_workers = max_sampling_threads)
|
||||
|
||||
# Filter threads
|
||||
|
||||
self.filter_pool = ThreadPoolExecutor(max_workers = 16)
|
||||
self.filter_queue = []
|
||||
|
||||
# Temp buffers for defrag
|
||||
|
||||
if self.paged:
|
||||
@@ -1130,6 +1137,14 @@ class ExLlamaV2DynamicGenerator:
|
||||
loras = self.current_loras,
|
||||
)["logits"]
|
||||
|
||||
# GPU workload is scheduled here, so launch any sampling filters that can run while waiting for CUDA
|
||||
|
||||
if self.filter_queue:
|
||||
for f in self.filter_queue:
|
||||
f.background_next(self.filter_pool)
|
||||
time.sleep(0)
|
||||
self.filter_queue.clear()
|
||||
|
||||
# Pass logits to jobs for sampling
|
||||
|
||||
batch_logits = self.logits_pinned[:device_logits.shape[0], :device_logits.shape[1], :]
|
||||
@@ -1729,10 +1744,10 @@ class ExLlamaV2DynamicJob:
|
||||
|
||||
# Start filters
|
||||
|
||||
# TODO: Try to move filter evaluation to the end of the forward pass, before sampling so it can potentially
|
||||
# occur while waiting for the CUDA queue
|
||||
if self.new_tokens == 0:
|
||||
for f in self.filters: f.begin("")
|
||||
for f in self.filters:
|
||||
f.background_drop()
|
||||
f.begin("")
|
||||
|
||||
# Sample
|
||||
|
||||
@@ -1780,7 +1795,11 @@ class ExLlamaV2DynamicJob:
|
||||
# Feed filters
|
||||
|
||||
if self.new_tokens >= 0:
|
||||
for f in self.filters: f.feed(next_token)
|
||||
for f in self.filters:
|
||||
f.feed(next_token)
|
||||
# Evaluate filter in background when possible
|
||||
if f.use_background_worker():
|
||||
self.generator.filter_queue.append(f)
|
||||
|
||||
# Accept token
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from exllamav2 import (
|
||||
ExLlamaV2,
|
||||
ExLlamaV2Tokenizer,
|
||||
)
|
||||
from __future__ import annotations
|
||||
from threading import Lock
|
||||
from concurrent.futures import ThreadPoolExecutor, Future
|
||||
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
|
||||
from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
|
||||
import torch
|
||||
|
||||
class ExLlamaV2Filter:
|
||||
|
||||
@@ -11,6 +13,10 @@ class ExLlamaV2Filter:
|
||||
tokenizer: ExLlamaV2Tokenizer
|
||||
sequence_str: str
|
||||
|
||||
background_result: Future | None = None
|
||||
|
||||
# For compatibility
|
||||
allow_return_type_list: bool = True
|
||||
|
||||
def __init__(self,
|
||||
model: ExLlamaV2,
|
||||
@@ -31,13 +37,51 @@ class ExLlamaV2Filter:
|
||||
|
||||
|
||||
def begin(self, prefix_str):
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def feed(self, token):
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def next(self):
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def use_background_worker(self) -> bool:
|
||||
"""
|
||||
To indicate whether filter can/should run as a background thread. If True, next() will be called
|
||||
asynchronously after the CUDA workload has been scheduled for the following forward pass, instead of right
|
||||
before sampling. Should be True for any CPU-intensive filter such as a grammar constraint.
|
||||
"""
|
||||
return False
|
||||
|
||||
|
||||
def background_next(self, pool: ThreadPoolExecutor):
|
||||
"""
|
||||
Schedule next() via the provided thread pool executor
|
||||
"""
|
||||
assert self.background_result is None
|
||||
self.background_result = pool.submit(self.next)
|
||||
|
||||
|
||||
def background_drop(self):
|
||||
"""
|
||||
Clear the result of an asynchronous filter pass. Used when a complex filter reaches an end state and forces
|
||||
the selection of eos_token_id. next() could still be scheduled after this selection, leaving a pending result
|
||||
that would break subsequent generations with the same filter.
|
||||
"""
|
||||
if self.background_result is not None:
|
||||
self.background_result.result()
|
||||
self.background_result = None
|
||||
|
||||
|
||||
def get_next(self) -> tuple:
|
||||
"""
|
||||
Return either next() or the result of any scheduled call to next()
|
||||
"""
|
||||
if self.background_result is None:
|
||||
return self.next()
|
||||
r = self.background_result.result()
|
||||
self.background_result = None
|
||||
return r
|
||||
@@ -272,7 +272,7 @@ class ExLlamaV2Sampler:
|
||||
end_tokens = None
|
||||
for f in filters:
|
||||
|
||||
pt, et = f.next()
|
||||
pt, et = f.get_next()
|
||||
if len(filters) > 1 and not isinstance(pt, set):
|
||||
pt, et = set(pt), set(et)
|
||||
|
||||
@@ -280,7 +280,7 @@ class ExLlamaV2Sampler:
|
||||
if et is not None: end_tokens = et if end_tokens is None else end_tokens | et
|
||||
|
||||
if pass_tokens is not None:
|
||||
assert pass_tokens, "Filter excluded all tokens"
|
||||
assert len(pass_tokens), "Filter excluded all tokens"
|
||||
|
||||
# Special case if a single token passes
|
||||
if len(pass_tokens) == 1 and return_top_tokens == 0 and prefix_token is None:
|
||||
|
||||
Reference in New Issue
Block a user