Model: Use true async jobs and add logprobs

The new async dynamic job allows for native async support without the
need for threading. Also add logprobs and metrics back to responses.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-05-23 21:37:50 -04:00
committed by Brian Dashore
parent 32ae62feac
commit 06ff47e2b4
4 changed files with 102 additions and 217 deletions

View File

@@ -19,8 +19,8 @@ from exllamav2 import (
) )
from exllamav2.generator import ( from exllamav2.generator import (
ExLlamaV2Sampler, ExLlamaV2Sampler,
ExLlamaV2DynamicGenerator, ExLlamaV2DynamicGeneratorAsync,
ExLlamaV2DynamicJob, ExLlamaV2DynamicJobAsync,
) )
from itertools import zip_longest from itertools import zip_longest
from loguru import logger from loguru import logger
@@ -54,7 +54,7 @@ class ExllamaV2Container:
cache: Optional[ExLlamaV2Cache] = None cache: Optional[ExLlamaV2Cache] = None
draft_cache: Optional[ExLlamaV2Cache] = None draft_cache: Optional[ExLlamaV2Cache] = None
tokenizer: Optional[ExLlamaV2Tokenizer] = None tokenizer: Optional[ExLlamaV2Tokenizer] = None
generator: Optional[ExLlamaV2DynamicGenerator] = None generator: Optional[ExLlamaV2DynamicGeneratorAsync] = None
prompt_template: Optional[PromptTemplate] = None prompt_template: Optional[PromptTemplate] = None
active_loras: List[ExLlamaV2Lora] = [] active_loras: List[ExLlamaV2Lora] = []
@@ -410,14 +410,44 @@ class ExllamaV2Container:
return {"success": success, "failure": failure} return {"success": success, "failure": failure}
async def load_gen(self, progress_callback=None): async def load_gen(self, progress_callback=None):
"""Basic async wrapper around the loading generator""" """Loads a model and streams progress via a generator."""
load_generator = self.load_gen_sync(progress_callback) # Indicate that model load has started
async for value in iterate_in_threadpool(load_generator): self.model_is_loading = True
# Streaming gen for model load progress
model_load_generator = self.load_model_sync(progress_callback)
async for value in iterate_in_threadpool(model_load_generator):
yield value yield value
# TODO: Change these!
# Set the max batch size and check if paged support is available
max_batch_size = 1 if self.config.no_flash_attn else 20
paged = not self.config.no_flash_attn
# Create async generator
self.generator = ExLlamaV2DynamicGeneratorAsync(
model=self.model,
cache=self.cache,
draft_model=self.draft_model,
draft_cache=self.draft_cache,
tokenizer=self.tokenizer,
max_batch_size=max_batch_size,
paged=paged,
)
# Clean up any extra vram usage from torch and cuda
# (Helps reduce VRAM bottlenecking on Windows)
gc.collect()
torch.cuda.empty_cache()
# Cleanup and update model load state
self.model_is_loading = False
self.model_loaded = True
logger.info("Model successfully loaded.")
@torch.inference_mode() @torch.inference_mode()
def load_gen_sync(self, progress_callback=None): def load_model_sync(self, progress_callback=None):
""" """
Load model, generator function Load model, generator function
@@ -429,9 +459,6 @@ class ExllamaV2Container:
Runs under a shared inference mode context. Runs under a shared inference mode context.
""" """
# Notify that the model is being loaded
self.model_is_loading = True
# Reset tokenizer namespace vars and create a tokenizer # Reset tokenizer namespace vars and create a tokenizer
ExLlamaV2Tokenizer.unspecial_piece_to_id = {} ExLlamaV2Tokenizer.unspecial_piece_to_id = {}
ExLlamaV2Tokenizer.unspecial_id_to_piece = {} ExLlamaV2Tokenizer.unspecial_id_to_piece = {}
@@ -511,38 +538,8 @@ class ExllamaV2Container:
yield value yield value
# Test VRAM allocation with a full-length forward pass # Test VRAM allocation with a full-length forward pass
"""
input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long) input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
self.model.forward(input_ids, cache=self.cache, preprocess_only=True) self.model.forward(input_ids, cache=self.cache, preprocess_only=True)
"""
# TODO: Change these!
max_batch_size = 1 if self.config.no_flash_attn else 20
paged = not self.config.no_flash_attn
# Create generator
self.generator = ExLlamaV2DynamicGenerator(
model=self.model,
cache=self.cache,
draft_model=self.draft_model,
draft_cache=self.draft_cache,
tokenizer=self.tokenizer,
max_batch_size=max_batch_size,
paged=paged,
)
# Warmup the generator
self.generator.warmup()
# Clean up any extra vram usage from torch and cuda
# (Helps reduce VRAM bottlenecking on Windows)
gc.collect()
torch.cuda.empty_cache()
# Update model load state
self.model_is_loading = False
self.model_loaded = True
logger.info("Model successfully loaded.")
def unload(self, loras_only: bool = False): def unload(self, loras_only: bool = False):
""" """
@@ -682,19 +679,7 @@ class ExllamaV2Container:
return kwargs return kwargs
async def generate_gen( async def generate_gen(self, prompt: str, **kwargs):
self, prompt: str, abort_event: Optional[threading.Event] = None, **kwargs
):
"""Basic async wrapper for completion generator"""
sync_generator = self.generate_gen_sync(prompt, abort_event, **kwargs)
async for value in iterate_in_threadpool(sync_generator):
yield value
@torch.inference_mode()
def generate_gen_sync(
self, prompt: str, abort_event: Optional[threading.Event] = None, **kwargs
):
""" """
Create generator function for prompt completion. Create generator function for prompt completion.
@@ -702,7 +687,6 @@ class ExllamaV2Container:
""" """
token_healing = unwrap(kwargs.get("token_healing"), False) token_healing = unwrap(kwargs.get("token_healing"), False)
stream_interval = unwrap(kwargs.get("stream_interval"), 0)
generate_window = max( generate_window = max(
unwrap(kwargs.get("generate_window"), 512), self.config.max_seq_len // 8 unwrap(kwargs.get("generate_window"), 512), self.config.max_seq_len // 8
) )
@@ -926,25 +910,6 @@ class ExllamaV2Container:
# This is an inverse of skip_special_tokens # This is an inverse of skip_special_tokens
decode_special_tokens = unwrap(not kwargs.get("skip_special_tokens"), False) decode_special_tokens = unwrap(not kwargs.get("skip_special_tokens"), False)
begin_stream_args = {
"token_healing": token_healing,
"loras": self.active_loras,
"return_probabilities": request_logprobs > 0,
"return_top_tokens": request_logprobs,
"return_logits": request_logprobs > 0,
"abort_event": abort_event,
"banned_strings": banned_strings,
"decode_special_tokens": decode_special_tokens,
}
if self.use_cfg:
begin_stream_args.update(
{
"input_mask": mask,
"position_offsets": offsets,
}
)
# Log generation options to console # Log generation options to console
# Some options are too large, so log the args instead # Some options are too large, so log the args instead
log_generation_params( log_generation_params(
@@ -972,19 +937,10 @@ class ExllamaV2Container:
# Log prompt to console # Log prompt to console
log_prompt(prompt, negative_prompt) log_prompt(prompt, negative_prompt)
# Begin
# generated_tokens = 0
# full_response = ""
# start_time = time.time()
# last_chunk_time = start_time
# save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool)
# chunk_buffer = ""
# chunk_tokens = 0
# Create and add a new job # Create and add a new job
job_id = uuid.uuid4().hex job_id = uuid.uuid4().hex
job = ExLlamaV2DynamicJob( job = ExLlamaV2DynamicJobAsync(
self.generator,
input_ids=ids, input_ids=ids,
max_new_tokens=max_tokens, max_new_tokens=max_tokens,
gen_settings=gen_settings, gen_settings=gen_settings,
@@ -996,108 +952,30 @@ class ExllamaV2Container:
return_top_tokens=request_logprobs, return_top_tokens=request_logprobs,
return_logits=request_logprobs > 0, return_logits=request_logprobs > 0,
banned_strings=banned_strings, banned_strings=banned_strings,
token_healing=token_healing,
identifier=job_id, identifier=job_id,
) )
self.generator.enqueue(job) # Save generated tokens and full response
# Full response is required for offset calculation
# Save generated tokens
generated_tokens = 0 generated_tokens = 0
full_response = ""
# Grab the next job and iterate through the results # Get the generation status once it's ready
while self.generator.num_remaining_jobs(): async for result in job:
results = self.generator.iterate() stage = result.get("stage")
for raw_generation in results: result_id = result.get("identifier")
if (
raw_generation["stage"] == "streaming"
and raw_generation["identifier"] == job_id
):
chunk = unwrap(raw_generation.get("text"), "")
eos = raw_generation.get("eos")
chunk_tokens = raw_generation.get("token_ids") if stage == "streaming" and result_id == job_id:
if chunk_tokens is not None: chunk = unwrap(result.get("text"), "")
generated_tokens += chunk_tokens.size(dim=0) full_response += chunk
generation = { chunk_tokens = result.get("token_ids")
"text": chunk, if chunk_tokens is not None:
"prompt_tokens": prompt_tokens, generated_tokens += chunk_tokens.size(dim=0)
"generated_tokens": generated_tokens,
# "offset": len(full_response),
}
yield generation
# Second yield if eos is true
if eos:
log_response(raw_generation.get("full_completion"))
eos_reason = raw_generation.get("eos_reason")
finish_reason = (
"length" if eos_reason == "max_new_tokens" else "stop"
)
# Remove the token text
generation["text"] = None
generation["finish_reason"] = finish_reason
yield generation
"""
while True:
# Ingest prompt
if chunk_tokens == 0:
ids = torch.cat((ids, save_tokens), dim=-1)
save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool)
overflow = ids.shape[-1] + generate_window - self.config.max_seq_len
active_ids = ids[:, max(0, overflow) :]
chunk_tokens = self.config.max_seq_len - active_ids.shape[-1]
# Kick off the streaming generation
self.generator.begin_stream_ex(
active_ids, gen_settings, **begin_stream_args
)
# Reset offsets for subsequent passes if the context is truncated
offsets = None
if auto_scale_penalty_range:
gen_settings.token_repetition_range = generated_tokens
# Run dict generation
# Guarantees return of chunk, eos, and chunk_token_ids
if generated_tokens < min_tokens:
raw_generation = self.generator.stream_ex(ban_tokens=eos_tokens)
else:
raw_generation = self.generator.stream_ex()
if token_healing:
# Extract healed token
ids[:, -1] = self.generator.sequence_ids[:, -2]
token_healing = False
# Get parameters that will always exist
chunk = raw_generation["chunk"]
eos = raw_generation["eos"]
tokens = raw_generation["chunk_token_ids"]
save_tokens = torch.cat(
(save_tokens, tokens.expand(save_tokens.shape[0], -1)), dim=-1
)
chunk_buffer += chunk
generated_tokens += 1
chunk_tokens -= 1
# Yield output
now = time.time()
elapsed = now - last_chunk_time
if chunk_buffer != "" and (
elapsed > stream_interval or eos or generated_tokens == max_tokens
):
generation = { generation = {
"text": chunk_buffer, "text": chunk,
"prompt_tokens": prompt_tokens, "prompt_tokens": prompt_tokens,
"generated_tokens": generated_tokens, "generated_tokens": generated_tokens,
"offset": len(full_response), "offset": len(full_response),
@@ -1106,12 +984,12 @@ class ExllamaV2Container:
if request_logprobs > 0: if request_logprobs > 0:
# Get top tokens and probs # Get top tokens and probs
top_tokens = unwrap( top_tokens = unwrap(
raw_generation.get("top_tokens"), result.get("top_k_tokens"),
torch.empty((1, 0, 1), dtype=torch.long), torch.empty((1, 0, 1), dtype=torch.long),
) )
top_probs = unwrap( top_probs = unwrap(
raw_generation.get("top_probs"), result.get("top_k_probs"),
torch.empty((1, 0, 1), dtype=torch.float), torch.empty((1, 0, 1), dtype=torch.float),
) )
@@ -1126,25 +1004,32 @@ class ExllamaV2Container:
} }
yield generation yield generation
full_response += chunk_buffer
chunk_buffer = ""
last_chunk_time = now
if eos or generated_tokens == max_tokens: # Second yield if eos is true
# Print response if result.get("eos"):
log_response(full_response) log_response(full_response)
# Print metrics eos_reason = result.get("eos_reason")
elapsed_time = last_chunk_time - start_time finish_reason = (
context_len = None if ids is None else context_len "length" if eos_reason == "max_new_tokens" else "stop"
)
log_metrics( log_metrics(
generated_tokens, elapsed_time, context_len, self.config.max_seq_len result.get("time_enqueued"),
) result.get("prompt_tokens"),
result.get("time_prefill"),
result.get("new_tokens"),
result.get("time_generate"),
context_len,
self.config.max_seq_len,
)
finish_reason = "length" if generated_tokens == max_tokens else "stop" # Remove the token text
generation = {"finish_reason": finish_reason} generation = {
yield generation "prompt_tokens": generation.get("prompt_tokens"),
"generated_tokens": generation.get("generated_tokens"),
"finish_reason": finish_reason,
}
break yield generation
""" break

View File

@@ -70,29 +70,38 @@ def log_response(response: str):
def log_metrics( def log_metrics(
queue_time: float,
prompt_tokens: int,
prompt_time: float,
generated_tokens: int, generated_tokens: int,
elapsed_time: float, generate_time: float,
context_len: Optional[int], context_len: Optional[int],
max_seq_len: int, max_seq_len: int,
): ):
initial_response = ( initial_response = (
f"Metrics: {generated_tokens} tokens generated in " f"Metrics: {generated_tokens} tokens generated in "
f"{round(elapsed_time, 2)} seconds" f"{round(queue_time + prompt_time + generate_time, 2)} seconds"
) )
itemization = [] itemization = []
extra_parts = [] extra_parts = []
# Add tokens per second itemization.append(f"Queue: {round(queue_time, 2)} s")
tokens_per_second = (
"Indeterminate" prompt_ts = (
if elapsed_time == 0 "Indeterminate" if prompt_time == 0 else round(prompt_tokens / prompt_time, 2)
else round(generated_tokens / elapsed_time, 2)
) )
itemization.append(f"{tokens_per_second} T/s") itemization.append(f"Process: {prompt_ts} T/s")
generate_ts = (
"Indeterminate"
if generate_time == 0
else round(generated_tokens / generate_time, 2)
)
itemization.append(f"Generate: {generate_ts} T/s")
# Add context (original token count) # Add context (original token count)
if context_len: if context_len:
itemization.append(f"context {context_len} tokens") itemization.append(f"Context: {context_len} tokens")
if context_len > max_seq_len: if context_len > max_seq_len:
extra_parts.append("<-- Not accurate (truncated)") extra_parts.append("<-- Not accurate (truncated)")

View File

@@ -1,8 +1,7 @@
"""Chat completion utilities for OAI server.""" """Chat completion utilities for OAI server."""
from asyncio import CancelledError
import pathlib import pathlib
import threading from asyncio import CancelledError
from typing import Optional from typing import Optional
from uuid import uuid4 from uuid import uuid4
@@ -198,11 +197,8 @@ async def stream_generate_chat_completion(
"""Generator for the generation process.""" """Generator for the generation process."""
try: try:
const_id = f"chatcmpl-{uuid4().hex}" const_id = f"chatcmpl-{uuid4().hex}"
abort_event = threading.Event()
new_generation = model.container.generate_gen( new_generation = model.container.generate_gen(prompt, **data.to_gen_params())
prompt, abort_event, **data.to_gen_params()
)
async for generation in new_generation: async for generation in new_generation:
response = _create_stream_chunk(const_id, generation, model_path.name) response = _create_stream_chunk(const_id, generation, model_path.name)
@@ -214,7 +210,6 @@ async def stream_generate_chat_completion(
except CancelledError: except CancelledError:
# Get out if the request gets disconnected # Get out if the request gets disconnected
abort_event.set()
handle_request_disconnect("Chat completion generation cancelled by user.") handle_request_disconnect("Chat completion generation cancelled by user.")
except Exception: except Exception:
yield get_generator_error( yield get_generator_error(

View File

@@ -2,7 +2,6 @@
import pathlib import pathlib
from asyncio import CancelledError from asyncio import CancelledError
import threading
from fastapi import HTTPException from fastapi import HTTPException
from typing import Optional from typing import Optional
@@ -65,10 +64,8 @@ async def stream_generate_completion(data: CompletionRequest, model_path: pathli
"""Streaming generation for completions.""" """Streaming generation for completions."""
try: try:
abort_event = threading.Event()
new_generation = model.container.generate_gen( new_generation = model.container.generate_gen(
data.prompt, abort_event, **data.to_gen_params() data.prompt, **data.to_gen_params()
) )
async for generation in new_generation: async for generation in new_generation:
response = _create_response(generation, model_path.name) response = _create_response(generation, model_path.name)
@@ -81,7 +78,6 @@ async def stream_generate_completion(data: CompletionRequest, model_path: pathli
except CancelledError: except CancelledError:
# Get out if the request gets disconnected # Get out if the request gets disconnected
abort_event.set()
handle_request_disconnect("Completion generation cancelled by user.") handle_request_disconnect("Completion generation cancelled by user.")
except Exception: except Exception:
yield get_generator_error( yield get_generator_error(