From 9dae46114288dc1feb6a3f70fce5e5d1c117eeab Mon Sep 17 00:00:00 2001
From: kingbri
Date: Mon, 15 Jul 2024 01:09:49 -0400
Subject: [PATCH] Model: Attempt to recreate generator on a fatal error

If a job causes the generator to error, tabby stops working until a
relaunch. It's better to try establishing a system of redundancy and
remake the generator in the event that it fails.

May replace this with an exit signal for a fatal error instead, but not
sure.

Signed-off-by: kingbri
---
 backends/exllamav2/model.py | 49 ++++++++++++++++++++++++++++++-------
 1 file changed, 40 insertions(+), 9 deletions(-)

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 04c6d08..0c65b38 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -488,15 +488,7 @@ class ExllamaV2Container:
             yield value
 
         # Create async generator
-        self.generator = ExLlamaV2DynamicGeneratorAsync(
-            model=self.model,
-            cache=self.cache,
-            draft_model=self.draft_model,
-            draft_cache=self.draft_cache,
-            tokenizer=self.tokenizer,
-            max_batch_size=self.max_batch_size,
-            paged=self.paged,
-        )
+        await self.create_generator()
 
         # Clean up any extra vram usage from torch and cuda
         # (Helps reduce VRAM bottlenecking on Windows)
@@ -645,6 +637,34 @@ class ExllamaV2Container:
         input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
         self.model.forward(input_ids, cache=self.cache, preprocess_only=True)
 
+    async def create_generator(self):
+        try:
+            # Don't acquire locks unless a model is loaded
+            if self.model_loaded:
+                await self.load_lock.acquire()
+
+                # Immediately cancel all jobs
+                await self.wait_for_jobs(skip_wait=True)
+
+            # Create new generator
+            self.generator = ExLlamaV2DynamicGeneratorAsync(
+                model=self.model,
+                cache=self.cache,
+                draft_model=self.draft_model,
+                draft_cache=self.draft_cache,
+                tokenizer=self.tokenizer,
+                max_batch_size=self.max_batch_size,
+                paged=self.paged,
+            )
+        finally:
+            # This means the generator is being recreated
+            # The load lock is already released in the load function
+            if self.model_loaded:
+                self.load_lock.release()
+
+                async with self.load_condition:
+                    self.load_condition.notify_all()
+
     def get_loras(self):
         """Convenience function to get all loras."""
 
@@ -1223,3 +1243,14 @@ class ExllamaV2Container:
                     break
         except asyncio.CancelledError:
             await job.cancel()
+        except Exception as ex:
+            # Create a new generator since the current state is broken
+            # No need to wait for this to finish
+            logger.error(
+                "FATAL ERROR with generation. "
+                "Attempting to recreate the generator. "
+                "If this fails, please restart the server.\n"
+            )
+            asyncio.ensure_future(self.create_generator())
+
+            raise ex