From b35c48da377810187594eab92891d7646b5926a0 Mon Sep 17 00:00:00 2001
From: randoentity <137087500+randoentity@users.noreply.github.com>
Date: Wed, 30 Apr 2025 11:56:24 +0200
Subject: [PATCH] fixup: some metrics

---
 backends/exllamav3/model.py | 118 +++++++++++++++++++++++++-----------
 1 file changed, 84 insertions(+), 34 deletions(-)

diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py
index 54c3547..82f3045 100644
--- a/backends/exllamav3/model.py
+++ b/backends/exllamav3/model.py
@@ -18,6 +18,7 @@ from common.concurrency import iterate_in_threadpool
 from common.gen_logging import (
     log_metrics,
 )
+from common.health import HealthManager
 from common.multimodal import MultimodalEmbeddingWrapper
 from common.sampling import BaseSamplerRequest
 from common.templating import PromptTemplate, find_prompt_template
@@ -436,6 +437,37 @@ class ExllamaV3Container(BaseModelContainer):
 
         return finish_chunk
 
+    async def create_generator(self):
+        """Create and save an Exllama generator class."""
+
+        try:
+            # Don't acquire locks unless a model is loaded
+            if self.loaded:
+                await self.load_lock.acquire()
+
+                # Immediately cancel all jobs
+                await self.wait_for_jobs(skip_wait=True)
+
+            # Create new generator
+            self.generator = AsyncGenerator(
+                model=self.model,
+                cache=self.cache,
+                tokenizer=self.tokenizer,
+                max_batch_size=self.max_batch_size,
+            )
+
+            # Update the state of the container var
+            if self.max_batch_size is None:
+                self.max_batch_size = self.generator.generator.max_batch_size
+        finally:
+            # This means the generator is being recreated
+            # The load lock is already released in the load function
+            if self.loaded:
+                self.load_lock.release()
+
+                async with self.load_condition:
+                    self.load_condition.notify_all()
+
     async def generate_gen(
         self,
         request_id: str,
@@ -516,42 +548,60 @@ class ExllamaV3Container(BaseModelContainer):
         full_response = ""
         metrics_result = {}
 
-        async for result in job:
-            chunk = unwrap(result.get("text"), "")
-            if chunk:
-                chunk_tokens = result.get("token_ids", self.tokenizer.encode(chunk))
-                full_response += chunk
-                if isinstance(chunk_tokens, torch.Tensor):
-                    generated_tokens += chunk_tokens.size(dim=0)
-                generation = {
-                    "text": chunk,
-                    "prompt_tokens": context_len,
-                    "generated_tokens": generated_tokens,
-                    "offset": len(full_response),
-                }
-                yield generation
+        # Get the generation status once it's ready
+        try:
+            async for result in job:
+                chunk = unwrap(result.get("text"), "")
+                if chunk:
+                    chunk_tokens = result.get("token_ids", self.tokenizer.encode(chunk))
+                    full_response += chunk
+                    if isinstance(chunk_tokens, torch.Tensor):
+                        generated_tokens += chunk_tokens.size(dim=0)
+                    generation = {
+                        "text": chunk,
+                        "prompt_tokens": context_len,
+                        "generated_tokens": generated_tokens,
+                        "offset": len(full_response),
+                    }
+                    yield generation
 
-            if result.get("eos"):
-                generation = self.handle_finish_chunk(result, generation)
+                if result.get("eos"):
+                    generation = self.handle_finish_chunk(result, generation)
 
-                # Save the final result for metrics logging
-                metrics_result = result
+                    # Save the final result for metrics logging
+                    metrics_result = result
 
-                yield generation
-                break
-        # Assign the active job to the request ID
-        self.active_job_ids[request_id] = job
+                    yield generation
+                    break
+            # Assign the active job to the request ID
+            self.active_job_ids[request_id] = job
 
-        # Log the metrics if present
-        if metrics_result:
-            log_metrics(
-                request_id,
-                metrics_result.get("time_enqueued"),
-                metrics_result.get("prompt_tokens"),
-                metrics_result.get("cached_tokens"),
-                metrics_result.get("time_prefill"),
-                metrics_result.get("new_tokens"),
-                metrics_result.get("time_generate"),
-                context_len,
-                self.max_seq_len,
+        except asyncio.CancelledError:
+            await job.cancel()
+        except Exception as ex:
+            # Create a new generator since the current state is broken
+            # No need to wait for this to finish
+            logger.error(
+                "FATAL ERROR with generation. "
+                "Attempting to recreate the generator. "
+                "If this fails, please restart the server.\n"
             )
+            asyncio.ensure_future(self.create_generator())
+
+            await HealthManager.add_unhealthy_event(ex)
+
+            raise ex
+        finally:
+            # Log the metrics if present
+            if metrics_result:
+                log_metrics(
+                    request_id,
+                    metrics_result.get("time_enqueued"),
+                    metrics_result.get("prompt_tokens"),
+                    metrics_result.get("cached_tokens"),
+                    metrics_result.get("time_prefill"),
+                    metrics_result.get("new_tokens"),
+                    metrics_result.get("time_generate"),
+                    context_len,
+                    self.max_seq_len,
+                )