Commit: Generator: Add requeue option
Repository: https://github.com/turboderp-org/exllamav3.git
@@ -205,7 +205,6 @@ class Generator:
         job.prepare_for_queue(self, self.job_serial)
         self.job_serial += 1
         self.pending_jobs.append(job)
-        job.time_enqueue = time.time()
         return job.serial_number


@@ -521,6 +520,7 @@ class Generator:

         # Pass to jobs to sample
         completed_jobs = []
+        requeuing_jobs = []
         j = 0
         for job, a, b in zip(self.active_jobs, logit_mapping[:-1], logit_mapping[1:]):
             if a == b: continue
@@ -531,7 +531,7 @@ class Generator:
             next_token, next_k_tokens, next_k_probs, next_prob = job.receive_logits(
                 token_logits,
             )
-            eos, sampled_token = job.receive_sample(
+            eos, sampled_token, rq = job.receive_sample(
                 token_logits,
                 next_token,
                 next_k_tokens,
@@ -540,6 +540,11 @@ class Generator:
                 results,
             )

+            # Requeue
+            if len(job.sequences) == 1 and rq:
+                requeuing_jobs.append(job)
+                break
+
             # EOS
             if eos:
                 completed_jobs.append(job)
@@ -565,9 +570,16 @@ class Generator:

         # Release pages for completed jobs
         num_jobs = self.num_remaining_jobs()
-        for job in completed_jobs:
+        for job in completed_jobs + requeuing_jobs:
             job.deallocate_pages()
             self.active_jobs.remove(job)

+        # Requeue jobs
+        for job in requeuing_jobs:
+            rq_job = job.prepare_for_requeue()
+            self.pending_jobs.append(rq_job)
+
         # Defrag
         if num_jobs and not self.num_remaining_jobs():
             self.pagetable.defrag()

@@ -636,6 +648,7 @@ class Generator:
         filters: list[list[Filter]] | list[Filter] | None = None,
         return_last_results: bool = False,
         embeddings: list[MMEmbedding] | list[list[MMEmbedding]] | None = None,
+        max_rq_tokens: int | None = None,
         **kwargs
     ):
         """
@@ -696,6 +709,11 @@ class Generator:
         :param embeddings:
             Optional list of MMEmbeddings to use for, or list of lists for batched generation

+        :param max_rq_tokens:
+            Maximum number of tokens before job is requeued. Rounded to nearest page boundary. This limits how
+            many new pages are allocated in the cache for the job in any one round and allows a single job to use
+            the full cache size without limiting concurrency for other jobs.
+
         :return:
             Completion(s): (str or list[str] depending on the type of the input prompt argument)
             Optionally, last results: (dict or list[dict] depending on the type of the input prompt argument)
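A minimal usage sketch for the new argument (assuming the generate() signature exactly as shown in this diff; the model, cache and tokenizer setup are the usual exllamav3 boilerplate and are elided, and the keyword name "prompt" is assumed from context):

    # Allow a very long completion without pinning cache pages for the job's
    # whole lifetime: each time the job crosses the page-aligned limit, its
    # pages are released and it re-enters the pending queue.
    output = generator.generate(
        prompt = "Write a novella about a lighthouse keeper.",
        max_new_tokens = 65536,
        max_rq_tokens = 2048,
    )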
@@ -760,7 +778,8 @@ class Generator:
                 filters = filters[idx] or [],
                 token_healing = token_healing,
                 decode_special_tokens = decode_special_tokens,
-                embeddings = embeddings[idx] or []
+                embeddings = embeddings[idx] or [],
+                max_rq_tokens = max_rq_tokens
             )

             if seed is not None: seed += 1

@@ -60,6 +60,8 @@ class Job:
         identifier: object | None = None,
         banned_strings: list[str] | None = None,
         embeddings: list[MMEmbedding] | None = None,
+        max_rq_tokens: int | None = None,
+        rq_state: dict | None = None,
         **kwargs
     ):
         """
@@ -123,12 +125,26 @@ class Job:
         :param embeddings:
             Optional list of MMEmbeddings to use, or list of lists for batched generation

+        :param max_rq_tokens:
+            Maximum number of tokens before job is requeued. Rounded to nearest page boundary. This limits how
+            many new pages are allocated in the cache for the job in any one round and allows a single job to use
+            the full cache size without limiting concurrency for other jobs.
+
+        :param rq_state:
+            Internal, passed when job requeues itself
+
         :param kwargs:
         """

         assert all(ids.device.type == "cpu" for ids in input_ids), \
             "input_ids must reside in system memory"

+        if rq_state is None:
+            rq_state = {}
+            self.is_requeued = False
+        else:
+            self.is_requeued = True
+
         self.generator = None
         self.pagetable = None
         self.serial_number = None
@@ -143,13 +159,14 @@ class Job:
             sampler = DefaultSampler()

         # Sampling state
-        self.held_text = ""
-        self.held_tokens = None
-        self.held_k_tokens = None
-        self.held_k_probs = None
-        self.held_probs = None
-        self.held_logits = None
-        self.full_completion = ""
+        self.held_text = rq_state.get("held_text", "")
+        self.held_tokens = rq_state.get("held_tokens")
+        self.held_k_tokens = rq_state.get("held_k_tokens")
+        self.held_k_probs = rq_state.get("held_k_probs")
+        self.held_probs = rq_state.get("held_probs")
+        self.held_logits = rq_state.get("held_logits")
+        self.full_completion = rq_state.get("full_completion", "")
         self.seed = seed

         # Prepare sequences
         if not isinstance(input_ids, list):
@@ -176,7 +193,11 @@ class Job:
         self.min_new_tokens = min_new_tokens
         self.new_tokens = 0 if self.prefix_token is None else -1
         self.sampler = sampler
-        self.rng = random.Random() if seed is None else random.Random(seed)
+        self.rng = rq_state.get("rng")
+        if self.rng is None:
+            self.rng = random.Random() if seed is None else random.Random(seed)
+        self.orig_max_rq_tokens = max_rq_tokens
+        self.max_rq_tokens = max_rq_tokens

         # Output options
         self.decode_special_tokens = decode_special_tokens
@@ -185,8 +206,8 @@ class Job:
         self.return_probs = return_probs

         # Stop conditions
-        self.stop_strings = set()
-        self.stop_tokens = set()
+        self.stop_strings = rq_state.get("stop_strings", set())
+        self.stop_tokens = rq_state.get("stop_tokens", set())
         if stop_conditions is not None:
             for t in stop_conditions:
                 if isinstance(t, int):
@@ -198,7 +219,8 @@ class Job:
             self.stop_strings_utf32_buffer, self.stop_strings_utf32_offsets = \
                 _strings_to_utf32(tuple(list(self.stop_strings)))
         else:
-            self.stop_strings_utf32_buffer, self.stop_strings_utf32_offsets = None, None
+            self.stop_strings_utf32_buffer = rq_state.get("stop_strings_utf32_buffer")
+            self.stop_strings_utf32_offsets = rq_state.get("stop_strings_utf32_offsets")

         self.stop_tokens_list = list(self.stop_tokens)
         self.stop_strings_list = list(self.stop_strings)
@@ -212,7 +234,8 @@ class Job:
                 _strings_to_utf32(tuple(self.banned_strings))
         else:
             self.banned_strings = []
-            self.banned_strings_utf32_buffer, self.banned_strings_utf32_offsets = None, None
+            self.banned_strings_utf32_buffer = None
+            self.banned_strings_utf32_offsets = None

         self.checkpoint = None

@@ -221,6 +244,9 @@ class Job:
         self.time_first_prefill = None
         self.time_first_token = None
         self.time_last_token = None
+        self.time_enqueued = rq_state.get("time_enqueued", 0.0)
+        self.time_prefill = rq_state.get("time_prefill", 0.0)
+        self.time_generate = rq_state.get("time_generate", 0.0)
         self.accepted_draft_tokens = 0
         self.rejected_draft_tokens = 0
         self.cached_pages = 0
@@ -432,6 +458,7 @@ class Job:

         # Accept token
         self.new_tokens += 1
+        requeue_now = self.new_tokens > self.max_rq_tokens - self.generator.num_draft_tokens

         for seq in self.sequences:

@@ -470,7 +497,6 @@ class Job:
                    page.add_ref(new_serial)
-
                else:

                    # If an unreferenced page has the same hash, clear that page
                    if new_hash in self.pagetable.unreferenced_pages:
                        up = self.pagetable.unreferenced_pages[new_hash]
@@ -479,9 +505,11 @@ class Job:
                    # Update the hash
                    page.update_hash(new_hash)

-            page = seq.allocated_pages[page_after]
-            page.prev_hash = new_hash
-            page.can_revert = False
+            # Allow completing the final page without starting a new one (for requeue)
+            if page_after < len(seq.allocated_pages):
+                page = seq.allocated_pages[page_after]
+                page.prev_hash = new_hash
+                page.can_revert = False

         # Stream output

@@ -496,12 +524,7 @@ class Job:
            stop_string: str = None,
            rem_held_text: str = None
        ):
-            r = {
-                "job": self,
-                "stage": "streaming",
-                "eos": emit_eos,
-                "serial": self.serial_number,
-            }
+            nonlocal requeue_now

+            r = {
+                "job": self,
@@ -522,6 +545,15 @@ class Job:
            if eos_reason == "stop_string":
                r.update({ "eos_triggering_string": stop_string })

+            # Requeue if we reach max_rq_tokens
+            if requeue_now:
+                r.update({ "requeue": True })
+                requeue_now = True
+
+                # Can't revert to checkpoint after requeuing, so just emit whatever was being held
+                if self.checkpoint is not None:
+                    emit_held = True
+
            if emit_held:
                if self.held_text != "":
                    self.full_completion += self.held_text
@@ -546,16 +578,21 @@ class Job:
                r.update({ "suppressed_text": suppressed_text })
                r.update({ "suppressed_tokens": suppressed_tokens.torch() })

+            if emit_eos or requeue_now:
+                self.time_last_token = time.time()
+                self.time_enqueued += self.time_first_prefill - self.time_enqueue
+                self.time_prefill += self.time_first_token - self.time_first_prefill
+                self.time_generate += self.time_last_token - self.time_first_token
+
            if emit_eos:
                self.is_finished = True
-                self.time_last_token = time.time()
                r.update({
                    "full_completion": self.full_completion,
                    "new_tokens": self.new_tokens,
                    "prompt_tokens": len(self.sequences[0].input_ids),
-                    "time_enqueued": self.time_first_prefill - self.time_enqueue,
-                    "time_prefill": self.time_first_token - self.time_first_prefill,
-                    "time_generate": self.time_last_token - self.time_first_token,
+                    "time_enqueued": self.time_enqueued,
+                    "time_prefill": self.time_prefill,
+                    "time_generate": self.time_generate,
                    "cached_pages": self.cached_pages // len(self.sequences),
                    "cached_tokens": (self.cached_pages * PAGE_SIZE + self.cached_tokens) // len(self.sequences),
                })
@@ -585,7 +622,7 @@ class Job:
                r.update({ "identifier": self.identifier })

            results_.append(r)
-            return emit_eos, next_token
+            return emit_eos, next_token, requeue_now

        # Decode and buffer output
        id_to_piece = self.generator.tokenizer.get_id_to_piece_list(self.decode_special_tokens)
@@ -681,7 +718,10 @@ class Job:
                self.checkpoint["offset"] = 0
                return off_tokens, off_text

-        if self.banned_strings_utf32_offsets is not None and self.new_tokens > 0:
+        if requeue_now:
+            unset_checkpoint()
+
+        elif self.banned_strings_utf32_offsets is not None and self.new_tokens > 0:
            match = ext.partial_strings_match(
                np.frombuffer(self.held_text.lower().encode("utf-32-le"), dtype = np.uint8),
                self.banned_strings_utf32_offsets,
@@ -732,7 +772,54 @@ class Job:
            return emit(results, emit_held = True)


-    def prepare_for_queue(self, generator, serial_number: int):
+    def prepare_for_requeue(self):
+        assert len(self.sequences) == 1
+
+        seq = self.sequences[0]
+        last_completed_tokens = len(seq.sequence_ids) - len(seq.input_ids)
+        new_input = seq.sequence_ids.torch()
+
+        rq_state = {
+            "rng": self.rng,
+            "stop_strings": self.stop_strings,
+            "stop_tokens": self.stop_tokens,
+            "stop_strings_utf32_buffer": self.stop_strings_utf32_buffer,
+            "stop_strings_utf32_offsets": self.stop_strings_utf32_offsets,
+            "held_text": self.held_text,
+            "held_tokens": self.held_tokens,
+            "held_probs": self.held_probs,
+            "held_k_tokens": self.held_k_tokens,
+            "held_k_probs": self.held_k_probs,
+            "held_logits": self.held_logits,
+            "full_completion": self.full_completion,
+            "time_enqueued": self.time_enqueued,
+            "time_prefill": self.time_prefill,
+            "time_generate": self.time_generate,
+        }
+
+        rq_job = Job(
+            input_ids = new_input,
+            max_new_tokens = self.max_new_tokens - last_completed_tokens,
+            min_new_tokens = max(self.min_new_tokens - last_completed_tokens, 0),
+            sampler = self.sampler,
+            decode_special_tokens = self.decode_special_tokens,
+            return_top_tokens = self.return_top_tokens,
+            return_logits = self.return_logits,
+            return_probs = self.return_probs,
+            filters = self.filters,  # Carries over state
+            token_healing = False,  # Token healed on first round
+            identifier = self.identifier,
+            banned_strings = self.banned_strings,
+            embeddings = self.embeddings,
+            max_rq_tokens = self.orig_max_rq_tokens,
+            rq_state = rq_state,
+        )
+
+        rq_job.prepare_for_queue(self.generator, self.serial_number, rq = True)
+        return rq_job
+
+
+    def prepare_for_queue(self, generator, serial_number: int, rq: bool = False):

        # Attach to generator
        self.serial_number = serial_number
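The requeue round-trip, summarized as a sketch (this only restates the control flow already visible in the hunks above; it is not additional API):

    # in Generator.iterate(), once receive_sample() reports rq for a
    # single-sequence job:
    job.deallocate_pages()                # cache pages are freed immediately
    rq_job = job.prepare_for_requeue()    # snapshots stream/sampler state into
                                          # rq_state and builds a fresh Job
    self.pending_jobs.append(rq_job)      # re-enters the queue like any
                                          # newly enqueued job, keeping its
                                          # original serial number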
@@ -740,6 +827,17 @@ class Job:
        self.pagetable = generator.pagetable
        self.skips = 0

+        # Align max_rq_tokens to page boundary or recurrent checkpoint
+        if self.max_rq_tokens is not None:
+            if len(self.sequences) == 1:
+                boundary = self.generator.recurrent_checkpoint_interval \
+                    if self.generator.recurrent_cache is not None else PAGE_SIZE
+                x = len(self.sequences[0].input_ids)
+                y = (x - 1 + self.max_rq_tokens + boundary - 1) // boundary * boundary
+                self.max_rq_tokens = y - x
+        else:
+            self.max_rq_tokens = self.max_new_tokens + 1
+
        # Compatibility checks
        assert not self.banned_strings or self.generator.recurrent_cache is None, \
            "Cannot use banned strings on recurrent model"
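A worked example of the alignment above (PAGE_SIZE = 256 is assumed here; use whatever the build defines): with a 300-token prompt, max_rq_tokens = 512 and boundary = 256, we get x = 300 and y = (300 - 1 + 512 + 255) // 256 * 256 = 1024, so max_rq_tokens is rewritten to 1024 - 300 = 724. The sequence then lands exactly on a page boundary when the limit is reached, which is what lets receive_sample() complete the final page without allocating a new one. The margin in requeue_now (new_tokens > max_rq_tokens - num_draft_tokens) triggers the requeue a few tokens early so a speculative-decoding batch cannot overshoot the aligned limit.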
@@ -748,7 +846,7 @@ class Job:
        all_unique_hashes = set()
        all_unique_pages = 0
        for seq in self.sequences:
-            unique_hashes, unique_pages = seq.prepare(self.prefix_token is not None, self.max_new_tokens)
+            unique_hashes, unique_pages = seq.prepare(self.prefix_token is not None, self.max_rq_tokens)
            all_unique_hashes |= unique_hashes
            all_unique_pages += unique_pages
        self.all_unique_hashes = list(all_unique_hashes)
@@ -765,13 +863,16 @@ class Job:
            f"generator is {self.generator.max_batch_size}."

        # Initial conditions
-        self.held_text = ""
-        self.held_tokens = SeqTensor((1, 0), dtype = torch.long, seq_dim = -1)
-        self.held_k_tokens = SeqTensor((1, 0, self.return_top_tokens), dtype = torch.long, seq_dim = 1)
-        self.held_k_probs = SeqTensor((1, 0, self.return_top_tokens), dtype = torch.float, seq_dim = 1)
-        self.held_probs = SeqTensor((1, 0), dtype = torch.float, seq_dim = -1)
-        self.held_logits = SeqTensor((1, 0, self.generator.padded_vocab_size), dtype = torch.float, seq_dim = 1)
-        self.full_completion = ""
+        if not rq:
+            self.held_text = ""
+            self.held_tokens = SeqTensor((1, 0), dtype = torch.long, seq_dim = -1)
+            self.held_k_tokens = SeqTensor((1, 0, self.return_top_tokens), dtype = torch.long, seq_dim = 1)
+            self.held_k_probs = SeqTensor((1, 0, self.return_top_tokens), dtype = torch.float, seq_dim = 1)
+            self.held_probs = SeqTensor((1, 0), dtype = torch.float, seq_dim = -1)
+            self.held_logits = SeqTensor((1, 0, self.generator.padded_vocab_size), dtype = torch.float, seq_dim = 1)
+            self.full_completion = ""

+        self.time_enqueue = time.time()
+
        # Prepare MRoPE embeddings
        # TODO: (embeddings)
@@ -980,10 +1081,11 @@ class Job:

    def activate(self):
        self.logits_device = self.generator.model.output_device
-        for f in self.filters:
-            f.attach(self)
-            f.reset()
-            f.is_active = f.trigger_token is None
+        if not self.is_requeued:
+            for f in self.filters:
+                f.attach(self)
+                f.reset()
+                f.is_active = f.trigger_token is None

    def maybe_stash_recurrent(self, cache, interval):
        seq = self.sequences[0]

@@ -70,7 +70,8 @@ def start_new_job():
        input_ids = input_ids,
        max_new_tokens = random.randint(completion_len[0], completion_len[1]),
        sampler = ArgmaxSampler(),
-        identifier = prompt
+        identifier = prompt,
+        max_rq_tokens = 512,
    )
    generator.enqueue(job)

@@ -91,6 +92,9 @@ def iterate():
    num_pending = generator.num_pending_jobs()
    results = generator.iterate()
    for result in results:
+        if result.get("requeue"):
+            print(f"{str(result['job'])} requeued")
+
        if result["eos"]:
            cached_tokens = result["cached_tokens"]
            cached_pages = result["cached_pages"]
