From 06ff47e2b4c863675711c4f3289e6cc7f0dd8ecc Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 23 May 2024 21:37:50 -0400 Subject: [PATCH] Model: Use true async jobs and add logprobs The new async dynamic job allows for native async support without the need of threading. Also add logprobs and metrics back to responses. Signed-off-by: kingbri --- backends/exllamav2/model.py | 277 ++++++++----------------- common/gen_logging.py | 27 ++- endpoints/OAI/utils/chat_completion.py | 9 +- endpoints/OAI/utils/completion.py | 6 +- 4 files changed, 102 insertions(+), 217 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 3f302f8..e8e3cff 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -19,8 +19,8 @@ from exllamav2 import ( ) from exllamav2.generator import ( ExLlamaV2Sampler, - ExLlamaV2DynamicGenerator, - ExLlamaV2DynamicJob, + ExLlamaV2DynamicGeneratorAsync, + ExLlamaV2DynamicJobAsync, ) from itertools import zip_longest from loguru import logger @@ -54,7 +54,7 @@ class ExllamaV2Container: cache: Optional[ExLlamaV2Cache] = None draft_cache: Optional[ExLlamaV2Cache] = None tokenizer: Optional[ExLlamaV2Tokenizer] = None - generator: Optional[ExLlamaV2DynamicGenerator] = None + generator: Optional[ExLlamaV2DynamicGeneratorAsync] = None prompt_template: Optional[PromptTemplate] = None active_loras: List[ExLlamaV2Lora] = [] @@ -410,14 +410,44 @@ class ExllamaV2Container: return {"success": success, "failure": failure} async def load_gen(self, progress_callback=None): - """Basic async wrapper around the loading generator""" + """Loads a model and streams progress via a generator.""" - load_generator = self.load_gen_sync(progress_callback) - async for value in iterate_in_threadpool(load_generator): + # Indicate that model load has started + self.model_is_loading = True + + # Streaming gen for model load progress + model_load_generator = self.load_model_sync(progress_callback) + async for value in iterate_in_threadpool(model_load_generator): yield value + # TODO: Change these! + # Set the max batch size and check if paged support is available + max_batch_size = 1 if self.config.no_flash_attn else 20 + paged = not self.config.no_flash_attn + + # Create async generator + self.generator = ExLlamaV2DynamicGeneratorAsync( + model=self.model, + cache=self.cache, + draft_model=self.draft_model, + draft_cache=self.draft_cache, + tokenizer=self.tokenizer, + max_batch_size=max_batch_size, + paged=paged, + ) + + # Clean up any extra vram usage from torch and cuda + # (Helps reduce VRAM bottlenecking on Windows) + gc.collect() + torch.cuda.empty_cache() + + # Cleanup and update model load state + self.model_is_loading = False + self.model_loaded = True + logger.info("Model successfully loaded.") + @torch.inference_mode() - def load_gen_sync(self, progress_callback=None): + def load_model_sync(self, progress_callback=None): """ Load model, generator function @@ -429,9 +459,6 @@ class ExllamaV2Container: Runs under a shared inference mode context. """ - # Notify that the model is being loaded - self.model_is_loading = True - # Reset tokenizer namespace vars and create a tokenizer ExLlamaV2Tokenizer.unspecial_piece_to_id = {} ExLlamaV2Tokenizer.unspecial_id_to_piece = {} @@ -511,38 +538,8 @@ class ExllamaV2Container: yield value # Test VRAM allocation with a full-length forward pass - """ input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long) self.model.forward(input_ids, cache=self.cache, preprocess_only=True) - """ - - # TODO: Change these! - max_batch_size = 1 if self.config.no_flash_attn else 20 - paged = not self.config.no_flash_attn - - # Create generator - self.generator = ExLlamaV2DynamicGenerator( - model=self.model, - cache=self.cache, - draft_model=self.draft_model, - draft_cache=self.draft_cache, - tokenizer=self.tokenizer, - max_batch_size=max_batch_size, - paged=paged, - ) - - # Warmup the generator - self.generator.warmup() - - # Clean up any extra vram usage from torch and cuda - # (Helps reduce VRAM bottlenecking on Windows) - gc.collect() - torch.cuda.empty_cache() - - # Update model load state - self.model_is_loading = False - self.model_loaded = True - logger.info("Model successfully loaded.") def unload(self, loras_only: bool = False): """ @@ -682,19 +679,7 @@ class ExllamaV2Container: return kwargs - async def generate_gen( - self, prompt: str, abort_event: Optional[threading.Event] = None, **kwargs - ): - """Basic async wrapper for completion generator""" - - sync_generator = self.generate_gen_sync(prompt, abort_event, **kwargs) - async for value in iterate_in_threadpool(sync_generator): - yield value - - @torch.inference_mode() - def generate_gen_sync( - self, prompt: str, abort_event: Optional[threading.Event] = None, **kwargs - ): + async def generate_gen(self, prompt: str, **kwargs): """ Create generator function for prompt completion. @@ -702,7 +687,6 @@ class ExllamaV2Container: """ token_healing = unwrap(kwargs.get("token_healing"), False) - stream_interval = unwrap(kwargs.get("stream_interval"), 0) generate_window = max( unwrap(kwargs.get("generate_window"), 512), self.config.max_seq_len // 8 ) @@ -926,25 +910,6 @@ class ExllamaV2Container: # This is an inverse of skip_special_tokens decode_special_tokens = unwrap(not kwargs.get("skip_special_tokens"), False) - begin_stream_args = { - "token_healing": token_healing, - "loras": self.active_loras, - "return_probabilities": request_logprobs > 0, - "return_top_tokens": request_logprobs, - "return_logits": request_logprobs > 0, - "abort_event": abort_event, - "banned_strings": banned_strings, - "decode_special_tokens": decode_special_tokens, - } - - if self.use_cfg: - begin_stream_args.update( - { - "input_mask": mask, - "position_offsets": offsets, - } - ) - # Log generation options to console # Some options are too large, so log the args instead log_generation_params( @@ -972,19 +937,10 @@ class ExllamaV2Container: # Log prompt to console log_prompt(prompt, negative_prompt) - # Begin - # generated_tokens = 0 - # full_response = "" - # start_time = time.time() - # last_chunk_time = start_time - - # save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool) - # chunk_buffer = "" - # chunk_tokens = 0 - # Create and add a new job job_id = uuid.uuid4().hex - job = ExLlamaV2DynamicJob( + job = ExLlamaV2DynamicJobAsync( + self.generator, input_ids=ids, max_new_tokens=max_tokens, gen_settings=gen_settings, @@ -996,108 +952,30 @@ class ExllamaV2Container: return_top_tokens=request_logprobs, return_logits=request_logprobs > 0, banned_strings=banned_strings, + token_healing=token_healing, identifier=job_id, ) - self.generator.enqueue(job) - - # Save generated tokens + # Save generated tokens and full response + # Full response is required for offset calculation generated_tokens = 0 + full_response = "" - # Grab the next job and iterate through the results - while self.generator.num_remaining_jobs(): - results = self.generator.iterate() - for raw_generation in results: - if ( - raw_generation["stage"] == "streaming" - and raw_generation["identifier"] == job_id - ): - chunk = unwrap(raw_generation.get("text"), "") - eos = raw_generation.get("eos") + # Get the generation status once it's ready + async for result in job: + stage = result.get("stage") + result_id = result.get("identifier") - chunk_tokens = raw_generation.get("token_ids") - if chunk_tokens is not None: - generated_tokens += chunk_tokens.size(dim=0) + if stage == "streaming" and result_id == job_id: + chunk = unwrap(result.get("text"), "") + full_response += chunk - generation = { - "text": chunk, - "prompt_tokens": prompt_tokens, - "generated_tokens": generated_tokens, - # "offset": len(full_response), - } + chunk_tokens = result.get("token_ids") + if chunk_tokens is not None: + generated_tokens += chunk_tokens.size(dim=0) - yield generation - - # Second yield if eos is true - if eos: - log_response(raw_generation.get("full_completion")) - - eos_reason = raw_generation.get("eos_reason") - finish_reason = ( - "length" if eos_reason == "max_new_tokens" else "stop" - ) - - # Remove the token text - generation["text"] = None - generation["finish_reason"] = finish_reason - - yield generation - - """ - while True: - # Ingest prompt - if chunk_tokens == 0: - ids = torch.cat((ids, save_tokens), dim=-1) - save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool) - overflow = ids.shape[-1] + generate_window - self.config.max_seq_len - active_ids = ids[:, max(0, overflow) :] - chunk_tokens = self.config.max_seq_len - active_ids.shape[-1] - - # Kick off the streaming generation - self.generator.begin_stream_ex( - active_ids, gen_settings, **begin_stream_args - ) - - # Reset offsets for subsequent passes if the context is truncated - offsets = None - - if auto_scale_penalty_range: - gen_settings.token_repetition_range = generated_tokens - - # Run dict generation - # Guarantees return of chunk, eos, and chunk_token_ids - if generated_tokens < min_tokens: - raw_generation = self.generator.stream_ex(ban_tokens=eos_tokens) - else: - raw_generation = self.generator.stream_ex() - - if token_healing: - # Extract healed token - ids[:, -1] = self.generator.sequence_ids[:, -2] - token_healing = False - - # Get parameters that will always exist - chunk = raw_generation["chunk"] - eos = raw_generation["eos"] - tokens = raw_generation["chunk_token_ids"] - - save_tokens = torch.cat( - (save_tokens, tokens.expand(save_tokens.shape[0], -1)), dim=-1 - ) - chunk_buffer += chunk - - generated_tokens += 1 - chunk_tokens -= 1 - - # Yield output - now = time.time() - elapsed = now - last_chunk_time - - if chunk_buffer != "" and ( - elapsed > stream_interval or eos or generated_tokens == max_tokens - ): generation = { - "text": chunk_buffer, + "text": chunk, "prompt_tokens": prompt_tokens, "generated_tokens": generated_tokens, "offset": len(full_response), @@ -1106,12 +984,12 @@ class ExllamaV2Container: if request_logprobs > 0: # Get top tokens and probs top_tokens = unwrap( - raw_generation.get("top_tokens"), + result.get("top_k_tokens"), torch.empty((1, 0, 1), dtype=torch.long), ) top_probs = unwrap( - raw_generation.get("top_probs"), + result.get("top_k_probs"), torch.empty((1, 0, 1), dtype=torch.float), ) @@ -1126,25 +1004,32 @@ class ExllamaV2Container: } yield generation - full_response += chunk_buffer - chunk_buffer = "" - last_chunk_time = now - if eos or generated_tokens == max_tokens: - # Print response - log_response(full_response) + # Second yield if eos is true + if result.get("eos"): + log_response(full_response) - # Print metrics - elapsed_time = last_chunk_time - start_time - context_len = None if ids is None else context_len + eos_reason = result.get("eos_reason") + finish_reason = ( + "length" if eos_reason == "max_new_tokens" else "stop" + ) - log_metrics( - generated_tokens, elapsed_time, context_len, self.config.max_seq_len - ) + log_metrics( + result.get("time_enqueued"), + result.get("prompt_tokens"), + result.get("time_prefill"), + result.get("new_tokens"), + result.get("time_generate"), + context_len, + self.config.max_seq_len, + ) - finish_reason = "length" if generated_tokens == max_tokens else "stop" - generation = {"finish_reason": finish_reason} - yield generation + # Remove the token text + generation = { + "prompt_tokens": generation.get("prompt_tokens"), + "generated_tokens": generation.get("generated_tokens"), + "finish_reason": finish_reason, + } - break - """ + yield generation + break diff --git a/common/gen_logging.py b/common/gen_logging.py index bcfd042..bfc6c2e 100644 --- a/common/gen_logging.py +++ b/common/gen_logging.py @@ -70,29 +70,38 @@ def log_response(response: str): def log_metrics( + queue_time: float, + prompt_tokens: int, + prompt_time: float, generated_tokens: int, - elapsed_time: float, + generate_time: float, context_len: Optional[int], max_seq_len: int, ): initial_response = ( f"Metrics: {generated_tokens} tokens generated in " - f"{round(elapsed_time, 2)} seconds" + f"{round(queue_time + prompt_time + generate_time, 2)} seconds" ) itemization = [] extra_parts = [] - # Add tokens per second - tokens_per_second = ( - "Indeterminate" - if elapsed_time == 0 - else round(generated_tokens / elapsed_time, 2) + itemization.append(f"Queue: {round(queue_time, 2)} s") + + prompt_ts = ( + "Indeterminate" if prompt_time == 0 else round(prompt_tokens / prompt_time, 2) ) - itemization.append(f"{tokens_per_second} T/s") + itemization.append(f"Process: {prompt_ts} T/s") + + generate_ts = ( + "Indeterminate" + if generate_time == 0 + else round(generated_tokens / generate_time, 2) + ) + itemization.append(f"Generate: {generate_ts} T/s") # Add context (original token count) if context_len: - itemization.append(f"context {context_len} tokens") + itemization.append(f"Context: {context_len} tokens") if context_len > max_seq_len: extra_parts.append("<-- Not accurate (truncated)") diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index 11872b7..833d7be 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -1,8 +1,7 @@ """Chat completion utilities for OAI server.""" -from asyncio import CancelledError import pathlib -import threading +from asyncio import CancelledError from typing import Optional from uuid import uuid4 @@ -198,11 +197,8 @@ async def stream_generate_chat_completion( """Generator for the generation process.""" try: const_id = f"chatcmpl-{uuid4().hex}" - abort_event = threading.Event() - new_generation = model.container.generate_gen( - prompt, abort_event, **data.to_gen_params() - ) + new_generation = model.container.generate_gen(prompt, **data.to_gen_params()) async for generation in new_generation: response = _create_stream_chunk(const_id, generation, model_path.name) @@ -214,7 +210,6 @@ async def stream_generate_chat_completion( except CancelledError: # Get out if the request gets disconnected - abort_event.set() handle_request_disconnect("Chat completion generation cancelled by user.") except Exception: yield get_generator_error( diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index 24a3d12..31e1533 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -2,7 +2,6 @@ import pathlib from asyncio import CancelledError -import threading from fastapi import HTTPException from typing import Optional @@ -65,10 +64,8 @@ async def stream_generate_completion(data: CompletionRequest, model_path: pathli """Streaming generation for completions.""" try: - abort_event = threading.Event() - new_generation = model.container.generate_gen( - data.prompt, abort_event, **data.to_gen_params() + data.prompt, **data.to_gen_params() ) async for generation in new_generation: response = _create_response(generation, model_path.name) @@ -81,7 +78,6 @@ async def stream_generate_completion(data: CompletionRequest, model_path: pathli except CancelledError: # Get out if the request gets disconnected - abort_event.set() handle_request_disconnect("Completion generation cancelled by user.") except Exception: yield get_generator_error(