OAI: Implement completion API endpoint

Add support for /v1/completions with optional streaming. Also rewrite
API endpoints to be async where possible, since that improves request
performance.
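
A rough shape of the new endpoint (a minimal sketch only; FastAPI usage
and the MODEL_CONTAINER / generate_gen names are illustrative, not taken
from this diff):

    from fastapi import FastAPI, Request
    from fastapi.responses import StreamingResponse

    app = FastAPI()

    @app.post("/v1/completions")
    async def completions(request: Request):
        data = await request.json()
        prompt = data.pop("prompt", "")

        if data.get("stream", False):
            # Relay chunks as the generator produces them
            return StreamingResponse(
                MODEL_CONTAINER.generate_gen(prompt, **data),
                media_type="text/event-stream",
            )

        # Non-streaming: join all chunks into a single completion
        text = "".join(MODEL_CONTAINER.generate_gen(prompt, **data))
        return {"object": "text_completion", "choices": [{"text": text}]}

A real endpoint would wrap each chunk as an SSE event before yielding it;
a sketch of that wrapping follows the diff below.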

Model container generation parameter names were also renamed, and
sampler fallbacks were set to their disabled values.
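
The generation kwargs renamed in this commit:

    max_new_tokens           -> max_tokens
    token_repetition_penalty -> repetition_penalty
    token_repetition_range   -> repetition_range
    token_repetition_decay   -> repetition_decay
    stop_conditions          -> stop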

Signed-off-by: kingbri <bdashore3@proton.me>
Author: kingbri
Date:   2023-11-13 18:24:12 -05:00
Parent: 4fa4386275
Commit: eee8b642bd
6 changed files with 190 additions and 57 deletions


@@ -1,6 +1,5 @@
import gc, time
import torch
from exllamav2 import(
ExLlamaV2,
ExLlamaV2Config,
@@ -8,25 +7,26 @@ from exllamav2 import(
ExLlamaV2Cache_8bit,
ExLlamaV2Tokenizer,
)
from exllamav2.generator import(
ExLlamaV2StreamingGenerator,
ExLlamaV2Sampler
)
from os import path
from typing import Optional
# Bytes to reserve on first device when loading with auto split
auto_split_reserve_bytes = 96 * 1024**2
class ModelContainer:
config: ExLlamaV2Config or None = None
draft_config: ExLlamaV2Config or None = None
model: ExLlamaV2 or None = None
draft_model: ExLlamaV2 or None = None
cache: ExLlamaV2Cache or None = None
draft_cache: ExLlamaV2Cache or None = None
tokenizer: ExLlamaV2Tokenizer or None = None
generator: ExLlamaV2StreamingGenerator or None = None
config: Optional[ExLlamaV2Config] = None
draft_config: Optional[ExLlamaV2Config] = None
model: Optional[ExLlamaV2] = None
draft_model: Optional[ExLlamaV2] = None
cache: Optional[ExLlamaV2Cache] = None
draft_cache: Optional[ExLlamaV2Cache] = None
tokenizer: Optional[ExLlamaV2Tokenizer] = None
generator: Optional[ExLlamaV2StreamingGenerator] = None
cache_fp8: bool = False
draft_enabled: bool = False
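
Background on the annotation change: "X or None" is evaluated eagerly at
class-definition time, so "ExLlamaV2Config or None" collapses to just
ExLlamaV2Config and never admits None. Optional[X] (i.e. Union[X, None])
is the annotation these attributes need, since they hold None until
load() is called. A quick illustration, not part of this diff:

    from typing import Optional, Union
    assert (int or None) is int                # "or" returns the truthy operand, not a union
    assert Optional[int] == Union[int, None]   # Optional correctly allows None
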
@@ -102,6 +102,11 @@ class ModelContainer:
self.draft_config.max_input_len = kwargs["chunk_size"]
self.draft_config.max_attn_size = kwargs["chunk_size"] ** 2
def get_model_name(self):
if self.draft_enabled:
return path.basename(path.normpath(self.draft_config.model_dir))
else:
return path.basename(path.normpath(self.config.model_dir))
def load(self, progress_callback = None):
"""
@@ -201,20 +206,20 @@ class ModelContainer:
prompt (str): Input prompt
**kwargs:
'token_healing' (bool): Use token healing (default: False)
'temperature' (float): Sampling temperature (default: 0.8)
'top_k' (int): Sampling top-K (default: 100)
'top_p' (float): Sampling top-P (default: 0.8)
'temperature' (float): Sampling temperature (default: 1.0)
'top_k' (int): Sampling top-K (default: 0)
'top_p' (float): Sampling top-P (default: 1.0)
'min_p' (float): Sampling min-P (default: 0.0)
'tfs' (float): Tail-free sampling (default: 0.0)
'typical' (float): Sampling typical (default: 0.0)
'mirostat' (bool): Use Mirostat (default: False)
'mirostat_tau' (float): Mirostat tau parameter (default: 1.5)
'mirostat_eta' (float): Mirostat eta parameter (default: 0.1)
'token_repetition_penalty' (float): Token repetition/presence penalty (default: 1.15)
'token_repetition_range' (int): Repetition penalty range (default: whole context)
'token_repetition_decay' (int): Repetition penalty range (default: same as range)
'stop_conditions' (list): List of stop strings/tokens to end response (default: [EOS])
'max_new_tokens' (int): Max no. tokens in response (default: 150)
'repetition_penalty' (float): Token repetition/presence penalty (default: 1.0)
'repetition_range' (int): Repetition penalty range (default: whole context)
'repetition_decay' (int): Repetition penalty decay (default: same as range)
'stop' (list): List of stop strings/tokens to end response (default: [EOS])
'max_tokens' (int): Max no. tokens in response (default: 150)
'stream_interval' (float): Interval in seconds between each output chunk (default: immediate)
'generate_window' (int): Space to reserve at the end of the model's context when generating.
Rolls context window by the same amount if context length is exceeded to allow generating past
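
A hypothetical call using the renamed, OAI-style kwargs (the method name
generate_gen and the values shown are illustrative, not part of this diff):

    for chunk in model_container.generate_gen(
        "Once upon a time",
        max_tokens=200,
        temperature=1.0,        # neutral
        top_k=0,                # disabled
        top_p=1.0,              # disabled
        repetition_penalty=1.0, # disabled
        stop=["###"],
        stream_interval=0.25,
    ):
        print(chunk, end="", flush=True)
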
@@ -223,25 +228,27 @@ class ModelContainer:
"""
token_healing = kwargs.get("token_healing", False)
max_new_tokens = kwargs.get("max_new_tokens", 150)
max_tokens = kwargs.get("max_tokens", 150)
stream_interval = kwargs.get("stream_interval", 0)
generate_window = min(kwargs.get("generate_window", 512), max_new_tokens)
generate_window = min(kwargs.get("generate_window", 512), max_tokens)
# Sampler settings
gen_settings = ExLlamaV2Sampler.Settings()
gen_settings.temperature = kwargs.get("temperature", 0.8)
gen_settings.top_k = kwargs.get("top_k", 100)
gen_settings.top_p = kwargs.get("top_p", 0.8)
gen_settings.temperature = kwargs.get("temperature", 1.0)
gen_settings.top_k = kwargs.get("top_k", 0)
gen_settings.top_p = kwargs.get("top_p", 1.0)
gen_settings.min_p = kwargs.get("min_p", 0.0)
gen_settings.tfs = kwargs.get("tfs", 0.0)
gen_settings.typical = kwargs.get("typical", 0.0)
gen_settings.mirostat = kwargs.get("mirostat", False)
# Default tau and eta fallbacks don't matter if mirostat is off
gen_settings.mirostat_tau = kwargs.get("mirostat_tau", 1.5)
gen_settings.mirostat_eta = kwargs.get("mirostat_eta", 0.1)
gen_settings.token_repetition_penalty = kwargs.get("token_repetition_penalty", 1.15)
gen_settings.token_repetition_range = kwargs.get("token_repetition_range", self.config.max_seq_len)
gen_settings.token_repetition_decay = kwargs.get("token_repetition_decay", gen_settings.token_repetition_range)
gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty", 1.0)
gen_settings.token_repetition_range = kwargs.get("repetition_range", self.config.max_seq_len)
gen_settings.token_repetition_decay = kwargs.get("repetition_decay", gen_settings.token_repetition_range)
# Override sampler settings for temp = 0
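
With no sampler kwargs supplied, the new fallbacks leave the token
distribution untouched; this is the "disabled values" change from the
commit message. A sketch of the resulting neutral settings, mirroring
the kwargs.get defaults above:

    neutral = ExLlamaV2Sampler.Settings()
    neutral.temperature = 1.0               # no rescaling of logits
    neutral.top_k = 0                       # per the docstring above, 0 disables top-K
    neutral.top_p = 1.0                     # nucleus filtering off
    neutral.token_repetition_penalty = 1.0  # no repetition penalty
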
@@ -253,7 +260,7 @@ class ModelContainer:
# Stop conditions
self.generator.set_stop_conditions(kwargs.get("stop_conditions", [self.tokenizer.eos_token_id]))
self.generator.set_stop_conditions(kwargs.get("stop", [self.tokenizer.eos_token_id]))
# Tokenized context
@@ -302,10 +309,10 @@ class ModelContainer:
now = time.time()
elapsed = now - last_chunk_time
if chunk_buffer != "" and (elapsed > stream_interval or eos or generated_tokens == max_new_tokens):
if chunk_buffer != "" and (elapsed > stream_interval or eos or generated_tokens == max_tokens):
yield chunk_buffer
full_response += chunk_buffer
chunk_buffer = ""
last_chunk_time = now
if eos or generated_tokens == max_new_tokens: break
if eos or generated_tokens == max_tokens: break
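
Since the method yields chunks at stream_interval boundaries, the streaming
completions endpoint only has to relay them. A minimal sketch of that
wrapping (the helper below and its names are assumptions, not part of this
diff):

    import json

    def completion_stream(prompt: str, **kwargs):
        # Wrap each generated chunk in an OAI-style SSE event
        for chunk in MODEL_CONTAINER.generate_gen(prompt, **kwargs):
            yield "data: " + json.dumps({"choices": [{"text": chunk}]}) + "\n\n"
        yield "data: [DONE]\n\n"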