Asynchronous filter evaluation

This commit is contained in:
turboderp
2024-08-31 16:31:51 +02:00
parent 12f08dbbd8
commit 0d5c0bcc8d
3 changed files with 76 additions and 13 deletions

View File

@@ -232,6 +232,8 @@ class ExLlamaV2DynamicGenerator:
max_sampling_threads: int
min_sampling_threads: int
sampling_pool: ThreadPoolExecutor
filter_pool: ThreadPoolExecutor
filter_queue: list
def __init__(
@@ -443,6 +445,11 @@ class ExLlamaV2DynamicGenerator:
if max_sampling_threads > 1:
self.sampling_pool = ThreadPoolExecutor(max_workers = max_sampling_threads)
# Filter threads
self.filter_pool = ThreadPoolExecutor(max_workers = 16)
self.filter_queue = []
# Temp buffers for defrag
if self.paged:
@@ -1130,6 +1137,14 @@ class ExLlamaV2DynamicGenerator:
loras = self.current_loras,
)["logits"]
# GPU workload is scheduled here, so launch any sampling filters that can run while waiting for CUDA
if self.filter_queue:
for f in self.filter_queue:
f.background_next(self.filter_pool)
time.sleep(0)
self.filter_queue.clear()
# Pass logits to jobs for sampling
batch_logits = self.logits_pinned[:device_logits.shape[0], :device_logits.shape[1], :]
@@ -1729,10 +1744,10 @@ class ExLlamaV2DynamicJob:
# Start filters
# TODO: Try to move filter evaluation to the end of the forward pass, before sampling so it can potentially
# occur while waiting for the CUDA queue
if self.new_tokens == 0:
for f in self.filters: f.begin("")
for f in self.filters:
f.background_drop()
f.begin("")
# Sample
@@ -1780,7 +1795,11 @@ class ExLlamaV2DynamicJob:
# Feed filters
if self.new_tokens >= 0:
for f in self.filters: f.feed(next_token)
for f in self.filters:
f.feed(next_token)
# Evaluate filter in background when possible
if f.use_background_worker():
self.generator.filter_queue.append(f)
# Accept token

View File

@@ -1,7 +1,9 @@
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Tokenizer,
)
from __future__ import annotations
from threading import Lock
from concurrent.futures import ThreadPoolExecutor, Future
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
import torch
class ExLlamaV2Filter:
@@ -11,6 +13,10 @@ class ExLlamaV2Filter:
tokenizer: ExLlamaV2Tokenizer
sequence_str: str
background_result: Future | None = None
# For compatibility
allow_return_type_list: bool = True
def __init__(self,
model: ExLlamaV2,
@@ -31,13 +37,51 @@ class ExLlamaV2Filter:
def begin(self, prefix_str):
pass
raise NotImplementedError
def feed(self, token):
pass
raise NotImplementedError
def next(self):
pass
raise NotImplementedError
def use_background_worker(self) -> bool:
"""
To indicate whether filter can/should run as a background thread. If True, next() will be called
asynchronously after the CUDA workload has been scheduled for the following forward pass, instead of right
before sampling. Should be True for any CPU-intensive filter such as a grammar constraint.
"""
return False
def background_next(self, pool: ThreadPoolExecutor):
"""
Schedule next() via the provided thread pool executor
"""
assert self.background_result is None
self.background_result = pool.submit(self.next)
def background_drop(self):
"""
Clear the result of an asynchronous filter pass. Used when a complex filter reaches an end state and forces
the selection of eos_token_id. next() could still be scheduled after this selection, leaving a pending result
that would break subsequent generations with the same filter.
"""
if self.background_result is not None:
self.background_result.result()
self.background_result = None
def get_next(self) -> tuple:
"""
Return either next() or the result of any scheduled call to next()
"""
if self.background_result is None:
return self.next()
r = self.background_result.result()
self.background_result = None
return r

View File

@@ -272,7 +272,7 @@ class ExLlamaV2Sampler:
end_tokens = None
for f in filters:
pt, et = f.next()
pt, et = f.get_next()
if len(filters) > 1 and not isinstance(pt, set):
pt, et = set(pt), set(et)
@@ -280,7 +280,7 @@ class ExLlamaV2Sampler:
if et is not None: end_tokens = et if end_tokens is None else end_tokens | et
if pass_tokens is not None:
assert pass_tokens, "Filter excluded all tokens"
assert len(pass_tokens), "Filter excluded all tokens"
# Special case if a single token passes
if len(pass_tokens) == 1 and return_top_tokens == 0 and prefix_token is None: