diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 3f302f8..e8e3cff 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -19,8 +19,8 @@ from exllamav2 import (
 )
 from exllamav2.generator import (
     ExLlamaV2Sampler,
-    ExLlamaV2DynamicGenerator,
-    ExLlamaV2DynamicJob,
+    ExLlamaV2DynamicGeneratorAsync,
+    ExLlamaV2DynamicJobAsync,
 )
 from itertools import zip_longest
 from loguru import logger
@@ -54,7 +54,7 @@ class ExllamaV2Container:
     cache: Optional[ExLlamaV2Cache] = None
     draft_cache: Optional[ExLlamaV2Cache] = None
     tokenizer: Optional[ExLlamaV2Tokenizer] = None
-    generator: Optional[ExLlamaV2DynamicGenerator] = None
+    generator: Optional[ExLlamaV2DynamicGeneratorAsync] = None
     prompt_template: Optional[PromptTemplate] = None
     active_loras: List[ExLlamaV2Lora] = []

@@ -410,14 +410,44 @@ class ExllamaV2Container:
         return {"success": success, "failure": failure}

     async def load_gen(self, progress_callback=None):
-        """Basic async wrapper around the loading generator"""
+        """Loads a model and streams progress via a generator."""

-        load_generator = self.load_gen_sync(progress_callback)
-        async for value in iterate_in_threadpool(load_generator):
+        # Indicate that model load has started
+        self.model_is_loading = True
+
+        # Streaming gen for model load progress
+        model_load_generator = self.load_model_sync(progress_callback)
+        async for value in iterate_in_threadpool(model_load_generator):
             yield value

+        # TODO: Change these!
+        # Set the max batch size and check if paged support is available
+        max_batch_size = 1 if self.config.no_flash_attn else 20
+        paged = not self.config.no_flash_attn
+
+        # Create async generator
+        self.generator = ExLlamaV2DynamicGeneratorAsync(
+            model=self.model,
+            cache=self.cache,
+            draft_model=self.draft_model,
+            draft_cache=self.draft_cache,
+            tokenizer=self.tokenizer,
+            max_batch_size=max_batch_size,
+            paged=paged,
+        )
+
+        # Clean up any extra vram usage from torch and cuda
+        # (Helps reduce VRAM bottlenecking on Windows)
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        # Cleanup and update model load state
+        self.model_is_loading = False
+        self.model_loaded = True
+        logger.info("Model successfully loaded.")
+
     @torch.inference_mode()
-    def load_gen_sync(self, progress_callback=None):
+    def load_model_sync(self, progress_callback=None):
         """
         Load model, generator function

@@ -429,9 +459,6 @@ class ExllamaV2Container:
         Runs under a shared inference mode context.
         """

-        # Notify that the model is being loaded
-        self.model_is_loading = True
-
         # Reset tokenizer namespace vars and create a tokenizer
         ExLlamaV2Tokenizer.unspecial_piece_to_id = {}
         ExLlamaV2Tokenizer.unspecial_id_to_piece = {}
@@ -511,38 +538,8 @@ class ExllamaV2Container:
                 yield value

         # Test VRAM allocation with a full-length forward pass
-        """
         input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
         self.model.forward(input_ids, cache=self.cache, preprocess_only=True)
-        """
-
-        # TODO: Change these!
-        max_batch_size = 1 if self.config.no_flash_attn else 20
-        paged = not self.config.no_flash_attn
-
-        # Create generator
-        self.generator = ExLlamaV2DynamicGenerator(
-            model=self.model,
-            cache=self.cache,
-            draft_model=self.draft_model,
-            draft_cache=self.draft_cache,
-            tokenizer=self.tokenizer,
-            max_batch_size=max_batch_size,
-            paged=paged,
-        )
-
-        # Warmup the generator
-        self.generator.warmup()
-
-        # Clean up any extra vram usage from torch and cuda
-        # (Helps reduce VRAM bottlenecking on Windows)
-        gc.collect()
-        torch.cuda.empty_cache()
-
-        # Update model load state
-        self.model_is_loading = False
-        self.model_loaded = True
-        logger.info("Model successfully loaded.")

     def unload(self, loras_only: bool = False):
         """
@@ -682,19 +679,7 @@ class ExllamaV2Container:

         return kwargs

-    async def generate_gen(
-        self, prompt: str, abort_event: Optional[threading.Event] = None, **kwargs
-    ):
-        """Basic async wrapper for completion generator"""
-
-        sync_generator = self.generate_gen_sync(prompt, abort_event, **kwargs)
-        async for value in iterate_in_threadpool(sync_generator):
-            yield value
-
-    @torch.inference_mode()
-    def generate_gen_sync(
-        self, prompt: str, abort_event: Optional[threading.Event] = None, **kwargs
-    ):
+    async def generate_gen(self, prompt: str, **kwargs):
         """
         Create generator function for prompt completion.

@@ -702,7 +687,6 @@ class ExllamaV2Container:
         """

         token_healing = unwrap(kwargs.get("token_healing"), False)
-        stream_interval = unwrap(kwargs.get("stream_interval"), 0)
         generate_window = max(
             unwrap(kwargs.get("generate_window"), 512), self.config.max_seq_len // 8
         )
@@ -926,25 +910,6 @@ class ExllamaV2Container:
         # This is an inverse of skip_special_tokens
         decode_special_tokens = unwrap(not kwargs.get("skip_special_tokens"), False)

-        begin_stream_args = {
-            "token_healing": token_healing,
-            "loras": self.active_loras,
-            "return_probabilities": request_logprobs > 0,
-            "return_top_tokens": request_logprobs,
-            "return_logits": request_logprobs > 0,
-            "abort_event": abort_event,
-            "banned_strings": banned_strings,
-            "decode_special_tokens": decode_special_tokens,
-        }
-
-        if self.use_cfg:
-            begin_stream_args.update(
-                {
-                    "input_mask": mask,
-                    "position_offsets": offsets,
-                }
-            )
-
         # Log generation options to console
         # Some options are too large, so log the args instead
         log_generation_params(
@@ -972,19 +937,10 @@ class ExllamaV2Container:
         # Log prompt to console
         log_prompt(prompt, negative_prompt)

-        # Begin
-        # generated_tokens = 0
-        # full_response = ""
-        # start_time = time.time()
-        # last_chunk_time = start_time
-
-        # save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool)
-        # chunk_buffer = ""
-        # chunk_tokens = 0
-
         # Create and add a new job
         job_id = uuid.uuid4().hex
-        job = ExLlamaV2DynamicJob(
+        job = ExLlamaV2DynamicJobAsync(
+            self.generator,
             input_ids=ids,
             max_new_tokens=max_tokens,
             gen_settings=gen_settings,
@@ -996,108 +952,30 @@ class ExllamaV2Container:
             return_top_tokens=request_logprobs,
             return_logits=request_logprobs > 0,
             banned_strings=banned_strings,
+            token_healing=token_healing,
             identifier=job_id,
         )

-        self.generator.enqueue(job)
-
-        # Save generated tokens
+        # Save generated tokens and full response
+        # Full response is required for offset calculation
         generated_tokens = 0
+        full_response = ""

-        # Grab the next job and iterate through the results
-        while self.generator.num_remaining_jobs():
-            results = self.generator.iterate()
-            for raw_generation in results:
-                if (
-                    raw_generation["stage"] == "streaming"
-                    and raw_generation["identifier"] == job_id
-                ):
-                    chunk = unwrap(raw_generation.get("text"), "")
-                    eos = raw_generation.get("eos")
+        # Get the generation status once it's ready
+        async for result in job:
+            stage = result.get("stage")
+            result_id = result.get("identifier")

-                    chunk_tokens = raw_generation.get("token_ids")
-                    if chunk_tokens is not None:
-                        generated_tokens += chunk_tokens.size(dim=0)
+            if stage == "streaming" and result_id == job_id:
+                chunk = unwrap(result.get("text"), "")
+                full_response += chunk

-                    generation = {
-                        "text": chunk,
-                        "prompt_tokens": prompt_tokens,
-                        "generated_tokens": generated_tokens,
-                        # "offset": len(full_response),
-                    }
+                chunk_tokens = result.get("token_ids")
+                if chunk_tokens is not None:
+                    generated_tokens += chunk_tokens.size(dim=0)

-                    yield generation
-
-                    # Second yield if eos is true
-                    if eos:
-                        log_response(raw_generation.get("full_completion"))
-
-                        eos_reason = raw_generation.get("eos_reason")
-                        finish_reason = (
-                            "length" if eos_reason == "max_new_tokens" else "stop"
-                        )
-
-                        # Remove the token text
-                        generation["text"] = None
-                        generation["finish_reason"] = finish_reason
-
-                        yield generation
-
-        """
-        while True:
-            # Ingest prompt
-            if chunk_tokens == 0:
-                ids = torch.cat((ids, save_tokens), dim=-1)
-                save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool)
-                overflow = ids.shape[-1] + generate_window - self.config.max_seq_len
-                active_ids = ids[:, max(0, overflow) :]
-                chunk_tokens = self.config.max_seq_len - active_ids.shape[-1]
-
-                # Kick off the streaming generation
-                self.generator.begin_stream_ex(
-                    active_ids, gen_settings, **begin_stream_args
-                )
-
-                # Reset offsets for subsequent passes if the context is truncated
-                offsets = None
-
-            if auto_scale_penalty_range:
-                gen_settings.token_repetition_range = generated_tokens
-
-            # Run dict generation
-            # Guarantees return of chunk, eos, and chunk_token_ids
-            if generated_tokens < min_tokens:
-                raw_generation = self.generator.stream_ex(ban_tokens=eos_tokens)
-            else:
-                raw_generation = self.generator.stream_ex()
-
-            if token_healing:
-                # Extract healed token
-                ids[:, -1] = self.generator.sequence_ids[:, -2]
-                token_healing = False
-
-            # Get parameters that will always exist
-            chunk = raw_generation["chunk"]
-            eos = raw_generation["eos"]
-            tokens = raw_generation["chunk_token_ids"]
-
-            save_tokens = torch.cat(
-                (save_tokens, tokens.expand(save_tokens.shape[0], -1)), dim=-1
-            )
-            chunk_buffer += chunk
-
-            generated_tokens += 1
-            chunk_tokens -= 1
-
-            # Yield output
-            now = time.time()
-            elapsed = now - last_chunk_time
-
-            if chunk_buffer != "" and (
-                elapsed > stream_interval or eos or generated_tokens == max_tokens
-            ):
                 generation = {
-                    "text": chunk_buffer,
+                    "text": chunk,
                     "prompt_tokens": prompt_tokens,
                     "generated_tokens": generated_tokens,
                     "offset": len(full_response),
@@ -1106,12 +984,12 @@ class ExllamaV2Container:
                 if request_logprobs > 0:
                     # Get top tokens and probs
                     top_tokens = unwrap(
-                        raw_generation.get("top_tokens"),
+                        result.get("top_k_tokens"),
                         torch.empty((1, 0, 1), dtype=torch.long),
                     )

                     top_probs = unwrap(
-                        raw_generation.get("top_probs"),
+                        result.get("top_k_probs"),
                         torch.empty((1, 0, 1), dtype=torch.float),
                     )

@@ -1126,25 +1004,32 @@ class ExllamaV2Container:
                     }

                 yield generation
-                full_response += chunk_buffer
-                chunk_buffer = ""
-                last_chunk_time = now

-            if eos or generated_tokens == max_tokens:
-                # Print response
-                log_response(full_response)
+                # Second yield if eos is true
+                if result.get("eos"):
+                    log_response(full_response)

-                # Print metrics
-                elapsed_time = last_chunk_time - start_time
-                context_len = None if ids is None else context_len
+                    eos_reason = result.get("eos_reason")
+                    finish_reason = (
+                        "length" if eos_reason == "max_new_tokens" else "stop"
+                    )

-                log_metrics(
-                    generated_tokens, elapsed_time, context_len, self.config.max_seq_len
-                )
+                    log_metrics(
+                        result.get("time_enqueued"),
+                        result.get("prompt_tokens"),
+                        result.get("time_prefill"),
+                        result.get("new_tokens"),
+                        result.get("time_generate"),
+                        context_len,
+                        self.config.max_seq_len,
+                    )

-                finish_reason = "length" if generated_tokens == max_tokens else "stop"
-                generation = {"finish_reason": finish_reason}
-                yield generation
+                    # Remove the token text
+                    generation = {
+                        "prompt_tokens": generation.get("prompt_tokens"),
+                        "generated_tokens": generation.get("generated_tokens"),
+                        "finish_reason": finish_reason,
+                    }

-                break
-        """
+                    yield generation
+                    break
diff --git a/common/gen_logging.py b/common/gen_logging.py
index bcfd042..bfc6c2e 100644
--- a/common/gen_logging.py
+++ b/common/gen_logging.py
@@ -70,29 +70,38 @@ def log_response(response: str):


 def log_metrics(
+    queue_time: float,
+    prompt_tokens: int,
+    prompt_time: float,
     generated_tokens: int,
-    elapsed_time: float,
+    generate_time: float,
     context_len: Optional[int],
     max_seq_len: int,
 ):
     initial_response = (
         f"Metrics: {generated_tokens} tokens generated in "
-        f"{round(elapsed_time, 2)} seconds"
+        f"{round(queue_time + prompt_time + generate_time, 2)} seconds"
     )
     itemization = []
     extra_parts = []

-    # Add tokens per second
-    tokens_per_second = (
-        "Indeterminate"
-        if elapsed_time == 0
-        else round(generated_tokens / elapsed_time, 2)
+    itemization.append(f"Queue: {round(queue_time, 2)} s")
+
+    prompt_ts = (
+        "Indeterminate" if prompt_time == 0 else round(prompt_tokens / prompt_time, 2)
     )
-    itemization.append(f"{tokens_per_second} T/s")
+    itemization.append(f"Process: {prompt_ts} T/s")
+
+    generate_ts = (
+        "Indeterminate"
+        if generate_time == 0
+        else round(generated_tokens / generate_time, 2)
+    )
+    itemization.append(f"Generate: {generate_ts} T/s")

     # Add context (original token count)
     if context_len:
-        itemization.append(f"context {context_len} tokens")
+        itemization.append(f"Context: {context_len} tokens")

         if context_len > max_seq_len:
             extra_parts.append("<-- Not accurate (truncated)")
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index 11872b7..833d7be 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -1,8 +1,7 @@
 """Chat completion utilities for OAI server."""

-from asyncio import CancelledError
 import pathlib
-import threading
+from asyncio import CancelledError
 from typing import Optional
 from uuid import uuid4

@@ -198,11 +197,8 @@ async def stream_generate_chat_completion(
     """Generator for the generation process."""
     try:
         const_id = f"chatcmpl-{uuid4().hex}"
-        abort_event = threading.Event()

-        new_generation = model.container.generate_gen(
-            prompt, abort_event, **data.to_gen_params()
-        )
+        new_generation = model.container.generate_gen(prompt, **data.to_gen_params())
         async for generation in new_generation:
             response = _create_stream_chunk(const_id, generation, model_path.name)

@@ -214,7 +210,6 @@

     except CancelledError:
         # Get out if the request gets disconnected
-        abort_event.set()
         handle_request_disconnect("Chat completion generation cancelled by user.")
     except Exception:
         yield get_generator_error(
diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py
index 24a3d12..31e1533 100644
--- a/endpoints/OAI/utils/completion.py
+++ b/endpoints/OAI/utils/completion.py
@@ -2,7 +2,6 @@

 import pathlib
 from asyncio import CancelledError
-import threading
 from fastapi import HTTPException
 from typing import Optional

@@ -65,10 +64,8 @@ async def stream_generate_completion(data: CompletionRequest, model_path: pathli
     """Streaming generation for completions."""

     try:
-        abort_event = threading.Event()
-
         new_generation = model.container.generate_gen(
-            data.prompt, abort_event, **data.to_gen_params()
+            data.prompt, **data.to_gen_params()
         )
         async for generation in new_generation:
             response = _create_response(generation, model_path.name)

@@ -81,7 +78,6 @@ async def stream_generate_completion(data: CompletionRequest, model_path: pathli

     except CancelledError:
         # Get out if the request gets disconnected
-        abort_event.set()
         handle_request_disconnect("Completion generation cancelled by user.")
     except Exception:
         yield get_generator_error(