Generator: Add requeue option

turboderp
2025-09-22 03:34:31 +02:00
parent 1ff09ee3c4
commit 12eadfe114
3 changed files with 171 additions and 46 deletions

View File

@@ -205,7 +205,6 @@ class Generator:
job.prepare_for_queue(self, self.job_serial)
self.job_serial += 1
self.pending_jobs.append(job)
job.time_enqueue = time.time()
return job.serial_number
@@ -521,6 +520,7 @@ class Generator:
# Pass to jobs to sample
completed_jobs = []
requeuing_jobs = []
j = 0
for job, a, b in zip(self.active_jobs, logit_mapping[:-1], logit_mapping[1:]):
if a == b: continue
@@ -531,7 +531,7 @@ class Generator:
next_token, next_k_tokens, next_k_probs, next_prob = job.receive_logits(
token_logits,
)
eos, sampled_token = job.receive_sample(
eos, sampled_token, rq = job.receive_sample(
token_logits,
next_token,
next_k_tokens,
@@ -540,6 +540,11 @@ class Generator:
results,
)
# Requeue
if len(job.sequences) == 1 and rq:
requeuing_jobs.append(job)
break
# EOS
if eos:
completed_jobs.append(job)
@@ -565,9 +570,16 @@ class Generator:
# Release pages for completed jobs
num_jobs = self.num_remaining_jobs()
for job in completed_jobs:
for job in completed_jobs + requeuing_jobs:
job.deallocate_pages()
self.active_jobs.remove(job)
# Requeue jobs
for job in requeuing_jobs:
rq_job = job.prepare_for_requeue()
self.pending_jobs.append(rq_job)
# Defrag
if num_jobs and not self.num_remaining_jobs():
self.pagetable.defrag()
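
In outline, a sampling round now has three outcomes per job instead of two. A condensed paraphrase of the flow above (names follow the diff; logit bookkeeping and the defrag step are elided, and jobs_with_new_logits stands in for the zip over active jobs):

completed_jobs, requeuing_jobs = [], []
for job in jobs_with_new_logits:
    eos, sampled_token, rq = job.receive_sample(...)
    if len(job.sequences) == 1 and rq:
        requeuing_jobs.append(job)      # hit its max_rq_tokens budget
        break
    if eos:
        completed_jobs.append(job)

# Both groups release their cache pages and leave the active batch...
for job in completed_jobs + requeuing_jobs:
    job.deallocate_pages()
    active_jobs.remove(job)

# ...but a requeuing job re-enters the pending queue as a fresh Job that
# carries its sampled tokens and held output state forward.
for job in requeuing_jobs:
    pending_jobs.append(job.prepare_for_requeue())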
@@ -636,6 +648,7 @@ class Generator:
filters: list[list[Filter]] | list[Filter] | None = None,
return_last_results: bool = False,
embeddings: list[MMEmbedding] | list[list[MMEmbedding]] | None = None,
max_rq_tokens: int | None = None,
**kwargs
):
"""
@@ -696,6 +709,11 @@ class Generator:
:param embeddings:
Optional list of MMEmbeddings to use, or list of lists for batched generation
:param max_rq_tokens:
Maximum number of tokens to generate before the job is requeued, rounded up to the nearest page
boundary. This limits how many new pages are allocated in the cache for the job in any one round,
allowing a single job to use the full cache size without limiting concurrency for other jobs.
:return:
Completion(s): (str or list[str] depending on the type of the input prompt argument)
Optionally, last results: (dict or list[dict] depending on the type of the input prompt argument)
@@ -760,7 +778,8 @@ class Generator:
filters = filters[idx] or [],
token_healing = token_healing,
decode_special_tokens = decode_special_tokens,
embeddings = embeddings[idx] or []
embeddings = embeddings[idx] or [],
max_rq_tokens = max_rq_tokens
)
if seed is not None: seed += 1
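
Caller-side, the option simply threads through generate(). A minimal usage sketch, assuming a generator is already constructed (model loading elided), that prompt and max_new_tokens are the generator's usual arguments, and that 512 is a reasonable round size:

# One very long completion, bounded to roughly 512 new tokens per round so
# it cannot monopolize cache pages while shorter jobs run alongside it.
output = generator.generate(
    prompt = long_prompt,
    max_new_tokens = 16384,
    max_rq_tokens = 512,    # rounded up to a page boundary per job
)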

View File

@@ -60,6 +60,8 @@ class Job:
identifier: object | None = None,
banned_strings: list[str] | None = None,
embeddings: list[MMEmbedding] | None = None,
max_rq_tokens: int | None = None,
rq_state: dict | None = None,
**kwargs
):
"""
@@ -123,12 +125,26 @@ class Job:
:param embeddings:
Optional list of MMEmbeddings to use
:param max_rq_tokens:
Maximum number of tokens to generate before the job is requeued, rounded up to the nearest page
boundary. This limits how many new pages are allocated in the cache for the job in any one round,
allowing a single job to use the full cache size without limiting concurrency for other jobs.
:param rq_state:
Internal use; passed when a job requeues itself
:param kwargs:
"""
assert all(ids.device.type == "cpu" for ids in input_ids), \
"input_ids must reside in system memory"
if rq_state is None:
    rq_state = {}
    self.is_requeued = False
else:
    self.is_requeued = True
self.generator = None
self.pagetable = None
self.serial_number = None
@@ -143,13 +159,14 @@ class Job:
sampler = DefaultSampler()
# Sampling state
self.held_text = ""
self.held_tokens = None
self.held_k_tokens = None
self.held_k_probs = None
self.held_probs = None
self.held_logits = None
self.full_completion = ""
self.held_text = rq_state.get("held_text", "")
self.held_tokens = rq_state.get("held_tokens")
self.held_k_tokens = rq_state.get("held_k_tokens")
self.held_k_probs = rq_state.get("held_k_probs")
self.held_probs = rq_state.get("held_probs")
self.held_logits = rq_state.get("held_logits")
self.full_completion = rq_state.get("full_completion", "")
self.seed = seed
# Prepare sequences
if not isinstance(input_ids, list):
@@ -176,7 +193,11 @@ class Job:
self.min_new_tokens = min_new_tokens
self.new_tokens = 0 if self.prefix_token is None else -1
self.sampler = sampler
self.rng = random.Random() if seed is None else random.Random(seed)
self.rng = rq_state.get("rng")
if self.rng is None:
self.rng = random.Random() if seed is None else random.Random(seed)
self.orig_max_rq_tokens = max_rq_tokens
self.max_rq_tokens = max_rq_tokens
# Output options
self.decode_special_tokens = decode_special_tokens
@@ -185,8 +206,8 @@ class Job:
self.return_probs = return_probs
# Stop conditions
self.stop_strings = set()
self.stop_tokens = set()
self.stop_strings = rq_state.get("stop_strings", set())
self.stop_tokens = rq_state.get("stop_tokens", set())
if stop_conditions is not None:
for t in stop_conditions:
if isinstance(t, int):
@@ -198,7 +219,8 @@ class Job:
self.stop_strings_utf32_buffer, self.stop_strings_utf32_offsets = \
_strings_to_utf32(tuple(list(self.stop_strings)))
else:
self.stop_strings_utf32_buffer, self.stop_strings_utf32_offsets = None, None
self.stop_strings_utf32_buffer = rq_state.get("stop_strings_utf32_buffer")
self.stop_strings_utf32_offsets = rq_state.get("stop_strings_utf32_offsets")
self.stop_tokens_list = list(self.stop_tokens)
self.stop_strings_list = list(self.stop_strings)
@@ -212,7 +234,8 @@ class Job:
_strings_to_utf32(tuple(self.banned_strings))
else:
self.banned_strings = []
self.banned_strings_utf32_buffer, self.banned_strings_utf32_offsets = None, None
self.banned_strings_utf32_buffer = None
self.banned_strings_utf32_offsets = None
self.checkpoint = None
@@ -221,6 +244,9 @@ class Job:
self.time_first_prefill = None
self.time_first_token = None
self.time_last_token = None
self.time_enqueued = rq_state.get("time_enqueued", 0.0)
self.time_prefill = rq_state.get("time_prefill", 0.0)
self.time_generate = rq_state.get("time_generate", 0.0)
self.accepted_draft_tokens = 0
self.rejected_draft_tokens = 0
self.cached_pages = 0
@@ -432,6 +458,7 @@ class Job:
# Accept token
self.new_tokens += 1
requeue_now = self.new_tokens > self.max_rq_tokens - self.generator.num_draft_tokens  # headroom so a draft batch can't overshoot the budget
for seq in self.sequences:
@@ -470,7 +497,6 @@ class Job:
page.add_ref(new_serial)
else:
# If an unreferenced page has the same hash, clear that page
if new_hash in self.pagetable.unreferenced_pages:
up = self.pagetable.unreferenced_pages[new_hash]
@@ -479,9 +505,11 @@ class Job:
# Update the hash
page.update_hash(new_hash)
page = seq.allocated_pages[page_after]
page.prev_hash = new_hash
page.can_revert = False
# Allow completing the final page without starting a new one (for requeue)
if page_after < len(seq.allocated_pages):
page = seq.allocated_pages[page_after]
page.prev_hash = new_hash
page.can_revert = False
# Stream output
@@ -496,12 +524,7 @@ class Job:
stop_string: str = None,
rem_held_text: str = None
):
r = {
"job": self,
"stage": "streaming",
"eos": emit_eos,
"serial": self.serial_number,
}
nonlocal requeue_now
r = {
"job": self,
@@ -522,6 +545,15 @@ class Job:
if eos_reason == "stop_string":
r.update({ "eos_triggering_string": stop_string })
# Requeue if we reach max_rq_tokens
if requeue_now:
    r.update({ "requeue": True })
# Can't revert to checkpoint after requeuing, so just emit whatever was being held
if self.checkpoint is not None:
emit_held = True
if emit_held:
if self.held_text != "":
self.full_completion += self.held_text
@@ -546,16 +578,21 @@ class Job:
r.update({ "suppressed_text": suppressed_text })
r.update({ "suppressed_tokens": suppressed_tokens.torch() })
if emit_eos or requeue_now:
self.time_last_token = time.time()
self.time_enqueued += self.time_first_prefill - self.time_enqueue
self.time_prefill += self.time_first_token - self.time_first_prefill
self.time_generate += self.time_last_token - self.time_first_token
if emit_eos:
self.is_finished = True
self.time_last_token = time.time()
r.update({
"full_completion": self.full_completion,
"new_tokens": self.new_tokens,
"prompt_tokens": len(self.sequences[0].input_ids),
"time_enqueued": self.time_first_prefill - self.time_enqueue,
"time_prefill": self.time_first_token - self.time_first_prefill,
"time_generate": self.time_last_token - self.time_first_token,
"time_enqueued": self.time_enqueued,
"time_prefill": self.time_prefill,
"time_generate": self.time_generate,
"cached_pages": self.cached_pages // len(self.sequences),
"cached_tokens": (self.cached_pages * PAGE_SIZE + self.cached_tokens) // len(self.sequences),
})
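
Accumulating the timings on the job, rather than deriving them at EOS, means a completion that spans several requeue rounds still reports one total per phase. Schematically, each round (EOS or requeue) contributes its own spans:

# Per round, added to the running totals carried through rq_state:
time_enqueued += time_first_prefill - time_enqueue        # waiting in queue
time_prefill  += time_first_token   - time_first_prefill  # prompt processing
time_generate += time_last_token    - time_first_token    # token generation
# The final EOS result reports the sums over all rounds.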
@@ -585,7 +622,7 @@ class Job:
r.update({ "identifier": self.identifier })
results_.append(r)
return emit_eos, next_token
return emit_eos, next_token, requeue_now
# Decode and buffer output
id_to_piece = self.generator.tokenizer.get_id_to_piece_list(self.decode_special_tokens)
@@ -681,7 +718,10 @@ class Job:
self.checkpoint["offset"] = 0
return off_tokens, off_text
if self.banned_strings_utf32_offsets is not None and self.new_tokens > 0:
if requeue_now:
unset_checkpoint()
elif self.banned_strings_utf32_offsets is not None and self.new_tokens > 0:
match = ext.partial_strings_match(
np.frombuffer(self.held_text.lower().encode("utf-32-le"), dtype = np.uint8),
self.banned_strings_utf32_offsets,
@@ -732,7 +772,54 @@ class Job:
return emit(results, emit_held = True)
def prepare_for_queue(self, generator, serial_number: int):
def prepare_for_requeue(self):
assert len(self.sequences) == 1
seq = self.sequences[0]
last_completed_tokens = len(seq.sequence_ids) - len(seq.input_ids)
new_input = seq.sequence_ids.torch()
rq_state = {
"rng": self.rng,
"stop_strings": self.stop_strings,
"stop_tokens": self.stop_tokens,
"stop_strings_utf32_buffer": self.stop_strings_utf32_buffer,
"stop_strings_utf32_offsets": self.stop_strings_utf32_offsets,
"held_text": self.held_text,
"held_tokens": self.held_tokens,
"held_probs": self.held_probs,
"held_k_tokens": self.held_k_tokens,
"held_k_probs": self.held_k_probs,
"held_logits": self.held_logits,
"full_completion": self.full_completion,
"time_enqueued": self.time_enqueued,
"time_prefill": self.time_prefill,
"time_generate": self.time_generate,
}
rq_job = Job(
input_ids = new_input,
max_new_tokens = self.max_new_tokens - last_completed_tokens,
min_new_tokens = max(self.min_new_tokens - last_completed_tokens, 0),
sampler = self.sampler,
decode_special_tokens = self.decode_special_tokens,
return_top_tokens = self.return_top_tokens,
return_logits = self.return_logits,
return_probs = self.return_probs,
filters = self.filters, # Carries over state
token_healing = False, # Token healed on first round
identifier = self.identifier,
banned_strings = self.banned_strings,
embeddings = self.embeddings,
max_rq_tokens = self.orig_max_rq_tokens,
rq_state = rq_state,
)
rq_job.prepare_for_queue(self.generator, self.serial_number, rq = True)
return rq_job
def prepare_for_queue(self, generator, serial_number: int, rq: bool = False):
# Attach to generator
self.serial_number = serial_number
@@ -740,6 +827,17 @@ class Job:
self.pagetable = generator.pagetable
self.skips = 0
# Align max_rq_tokens to page boundary or recurrent checkpoint
if self.max_rq_tokens is not None and len(self.sequences) == 1:
    boundary = self.generator.recurrent_checkpoint_interval \
        if self.generator.recurrent_cache is not None else PAGE_SIZE
    x = len(self.sequences[0].input_ids)
    y = (x - 1 + self.max_rq_tokens + boundary - 1) // boundary * boundary
    self.max_rq_tokens = y - x
else:
    # No budget given (or batched job): disable requeuing for this job
    self.max_rq_tokens = self.max_new_tokens + 1
# Compatibility checks
assert not self.banned_strings or self.generator.recurrent_cache is None, \
    "Cannot use banned strings on recurrent model"
@@ -748,7 +846,7 @@ class Job:
all_unique_hashes = set()
all_unique_pages = 0
for seq in self.sequences:
unique_hashes, unique_pages = seq.prepare(self.prefix_token is not None, self.max_new_tokens)
unique_hashes, unique_pages = seq.prepare(self.prefix_token is not None, self.max_rq_tokens)
all_unique_hashes |= unique_hashes
all_unique_pages += unique_pages
self.all_unique_hashes = list(all_unique_hashes)
@@ -765,13 +863,16 @@ class Job:
f"generator is {self.generator.max_batch_size}."
# Initial conditions
self.held_text = ""
self.held_tokens = SeqTensor((1, 0), dtype = torch.long, seq_dim = -1)
self.held_k_tokens = SeqTensor((1, 0, self.return_top_tokens), dtype = torch.long, seq_dim = 1)
self.held_k_probs = SeqTensor((1, 0, self.return_top_tokens), dtype = torch.float, seq_dim = 1)
self.held_probs = SeqTensor((1, 0), dtype = torch.float, seq_dim = -1)
self.held_logits = SeqTensor((1, 0, self.generator.padded_vocab_size), dtype = torch.float, seq_dim = 1)
self.full_completion = ""
if not rq:
self.held_text = ""
self.held_tokens = SeqTensor((1, 0), dtype = torch.long, seq_dim = -1)
self.held_k_tokens = SeqTensor((1, 0, self.return_top_tokens), dtype = torch.long, seq_dim = 1)
self.held_k_probs = SeqTensor((1, 0, self.return_top_tokens), dtype = torch.float, seq_dim = 1)
self.held_probs = SeqTensor((1, 0), dtype = torch.float, seq_dim = -1)
self.held_logits = SeqTensor((1, 0, self.generator.padded_vocab_size), dtype = torch.float, seq_dim = 1)
self.full_completion = ""
self.time_enqueue = time.time()
# Prepare MRoPE embeddings
# TODO: (embeddings)
@@ -980,10 +1081,11 @@ class Job:
def activate(self):
self.logits_device = self.generator.model.output_device
for f in self.filters:
f.attach(self)
f.reset()
f.is_active = f.trigger_token is None
if not self.is_requeued:
for f in self.filters:
f.attach(self)
f.reset()
f.is_active = f.trigger_token is None
def maybe_stash_recurrent(self, cache, interval):
seq = self.sequences[0]

View File

@@ -70,7 +70,8 @@ def start_new_job():
input_ids = input_ids,
max_new_tokens = random.randint(completion_len[0], completion_len[1]),
sampler = ArgmaxSampler(),
identifier = prompt
identifier = prompt,
max_rq_tokens = 512,
)
generator.enqueue(job)
@@ -91,6 +92,9 @@ def iterate():
num_pending = generator.num_pending_jobs()
results = generator.iterate()
for result in results:
if result.get("requeue"):
print(f"{str(result['job'])} requeued")
if result["eos"]:
cached_tokens = result["cached_tokens"]
cached_pages = result["cached_pages"]
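
Since a requeued job keeps its serial number and identifier (see prepare_for_requeue above), a consumer can aggregate results across rounds. A sketch, with collected as a hypothetical accumulator:

collected = {}   # hypothetical: identifier -> finished completion
for result in generator.iterate():
    if result.get("requeue"):
        # Job left the active batch but is not finished; it resumes
        # from the pending queue with its held state intact.
        continue
    if result["eos"]:
        collected[result.get("identifier")] = result["full_completion"]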