Mirror of https://github.com/theroyallab/tabbyAPI.git, synced 2026-03-14 15:57:27 +00:00
OAI: Implement completion API endpoint
Add support for /v1/completions with the option to use streaming if needed. Also rewrite API endpoints to use async where possible, since that improves request performance. Model container parameter names were also rewritten, and fallback cases were set to their disabled values.

Signed-off-by: kingbri <bdashore3@proton.me>
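For orientation, the sketch below shows roughly what an async /v1/completions route with optional streaming can look like. It is an illustration only, not the route code from this commit: the FastAPI app, the CompletionRequest model, the module-level container, and the generate_gen method are all assumed names, with generate_gen standing in for whatever generator method ModelContainer exposes (the diff below only shows fragments of its docstring and body).

# Hypothetical sketch only: the app, request model, container, and generate_gen
# are assumptions for illustration, not code from this commit.
import json

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

app = FastAPI()
container = None  # assumed: a loaded ModelContainer, created at startup


class CompletionRequest(BaseModel):
    prompt: str
    stream: bool = False
    max_tokens: int = 150
    temperature: float = 1.0
    top_p: float = 1.0


@app.post("/v1/completions")
async def completions(request: CompletionRequest):
    # OpenAI-style names pass straight through as generation kwargs
    params = {
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
        "top_p": request.top_p,
    }

    if request.stream:
        async def event_stream():
            # Emit each generated chunk as a server-sent event
            for chunk in container.generate_gen(request.prompt, **params):
                yield f"data: {json.dumps({'choices': [{'text': chunk}]})}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(event_stream(), media_type="text/event-stream")

    # Non-streaming: collect every chunk into one completion
    text = "".join(container.generate_gen(request.prompt, **params))
    return {"choices": [{"text": text, "finish_reason": "stop"}]}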
model.py (65 lines changed)
@@ -1,6 +1,5 @@
import gc, time
import torch

from exllamav2 import(
ExLlamaV2,
ExLlamaV2Config,
@@ -8,25 +7,26 @@ from exllamav2 import(
ExLlamaV2Cache_8bit,
ExLlamaV2Tokenizer,
)

from exllamav2.generator import(
ExLlamaV2StreamingGenerator,
ExLlamaV2Sampler
)
from os import path
from typing import Optional

# Bytes to reserve on first device when loading with auto split
auto_split_reserve_bytes = 96 * 1024**2

class ModelContainer:

config: ExLlamaV2Config or None = None
draft_config: ExLlamaV2Config or None = None
model: ExLlamaV2 or None = None
draft_model: ExLlamaV2 or None = None
cache: ExLlamaV2Cache or None = None
draft_cache: ExLlamaV2Cache or None = None
tokenizer: ExLlamaV2Tokenizer or None = None
generator: ExLlamaV2StreamingGenerator or None = None
config: Optional[ExLlamaV2Config] = None
draft_config: Optional[ExLlamaV2Config] = None
model: Optional[ExLlamaV2] = None
draft_model: Optional[ExLlamaV2] = None
cache: Optional[ExLlamaV2Cache] = None
draft_cache: Optional[ExLlamaV2Cache] = None
tokenizer: Optional[ExLlamaV2Tokenizer] = None
generator: Optional[ExLlamaV2StreamingGenerator] = None

cache_fp8: bool = False
draft_enabled: bool = False
@@ -102,6 +102,11 @@ class ModelContainer:
self.draft_config.max_input_len = kwargs["chunk_size"]
self.draft_config.max_attn_size = kwargs["chunk_size"] ** 2

def get_model_name(self):
if self.draft_enabled:
return path.basename(path.normpath(self.draft_config.model_dir))
else:
return path.basename(path.normpath(self.config.model_dir))

def load(self, progress_callback = None):
"""
@@ -201,20 +206,20 @@ class ModelContainer:
prompt (str): Input prompt
**kwargs:
'token_healing' (bool): Use token healing (default: False)
'temperature' (float): Sampling temperature (default: 0.8)
'top_k' (int): Sampling top-K (default: 100)
'top_p' (float): Sampling top-P (default: 0.8)
'temperature' (float): Sampling temperature (default: 1.0)
'top_k' (int): Sampling top-K (default: 0)
'top_p' (float): Sampling top-P (default: 1.0)
'min_p' (float): Sampling min-P (default: 0.0)
'tfs' (float): Tail-free sampling (default: 0.0)
'typical' (float): Sampling typical (default: 0.0)
'mirostat' (bool): Use Mirostat (default: False)
'mirostat_tau' (float) Mirostat tau parameter (default: 1.5)
'mirostat_eta' (float) Mirostat eta parameter (default: 0.1)
'token_repetition_penalty' (float): Token repetition/presence penalty (default: 1.15)
'token_repetition_range' (int): Repetition penalty range (default: whole context)
'token_repetition_decay' (int): Repetition penalty range (default: same as range)
'stop_conditions' (list): List of stop strings/tokens to end response (default: [EOS])
'max_new_tokens' (int): Max no. tokens in response (default: 150)
'repetition_penalty' (float): Token repetition/presence penalty (default: 1.15)
'repetition_range' (int): Repetition penalty range (default: whole context)
'repetition_decay' (int): Repetition penalty range (default: same as range)
'stop' (list): List of stop strings/tokens to end response (default: [EOS])
'max_tokens' (int): Max no. tokens in response (default: 150)
'stream_interval' (float): Interval in seconds between each output chunk (default: immediate)
'generate_window' (int): Space to reserve at the end of the model's context when generating.
Rolls context window by the same amount if context length is exceeded to allow generating past
@@ -223,25 +228,27 @@ class ModelContainer:
"""

token_healing = kwargs.get("token_healing", False)
max_new_tokens = kwargs.get("max_new_tokens", 150)
max_tokens = kwargs.get("max_tokens", 150)
stream_interval = kwargs.get("stream_interval", 0)
generate_window = min(kwargs.get("generate_window", 512), max_new_tokens)
generate_window = min(kwargs.get("generate_window", 512), max_tokens)

# Sampler settings

gen_settings = ExLlamaV2Sampler.Settings()
gen_settings.temperature = kwargs.get("temperature", 0.8)
gen_settings.top_k = kwargs.get("top_k", 100)
gen_settings.top_p = kwargs.get("top_p", 0.8)
gen_settings.temperature = kwargs.get("temperature", 1.0)
gen_settings.top_k = kwargs.get("top_k", 1)
gen_settings.top_p = kwargs.get("top_p", 1.0)
gen_settings.min_p = kwargs.get("min_p", 0.0)
gen_settings.tfs = kwargs.get("tfs", 0.0)
gen_settings.typical = kwargs.get("typical", 0.0)
gen_settings.mirostat = kwargs.get("mirostat", False)

# Default tau and eta fallbacks don't matter if mirostat is off
gen_settings.mirostat_tau = kwargs.get("mirostat_tau", 1.5)
gen_settings.mirostat_eta = kwargs.get("mirostat_eta", 0.1)
gen_settings.token_repetition_penalty = kwargs.get("token_repetition_penalty", 1.15)
gen_settings.token_repetition_range = kwargs.get("token_repetition_range", self.config.max_seq_len)
gen_settings.token_repetition_decay = kwargs.get("token_repetition_decay", gen_settings.token_repetition_range)
gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty", 1.0)
gen_settings.token_repetition_range = kwargs.get("repetition_range", self.config.max_seq_len)
gen_settings.token_repetition_decay = kwargs.get("repetition_decay", gen_settings.token_repetition_range)

# Override sampler settings for temp = 0

@@ -253,7 +260,7 @@ class ModelContainer:

# Stop conditions

self.generator.set_stop_conditions(kwargs.get("stop_conditions", [self.tokenizer.eos_token_id]))
self.generator.set_stop_conditions(kwargs.get("stop", [self.tokenizer.eos_token_id]))

# Tokenized context

@@ -302,10 +309,10 @@ class ModelContainer:
now = time.time()
elapsed = now - last_chunk_time

if chunk_buffer != "" and (elapsed > stream_interval or eos or generated_tokens == max_new_tokens):
if chunk_buffer != "" and (elapsed > stream_interval or eos or generated_tokens == max_tokens):
yield chunk_buffer
full_response += chunk_buffer
chunk_buffer = ""
last_chunk_time = now

if eos or generated_tokens == max_new_tokens: break
if eos or generated_tokens == max_tokens: break
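As a usage note, the kwargs documented in the generate docstring above were renamed to their OpenAI-style equivalents and their fallbacks moved to disabled values. A hypothetical call using the new names might look like the snippet below; container and generate_gen are the same assumed names as in the sketch after the commit message, not identifiers taken from this diff.

# Hypothetical usage of the renamed generation kwargs; 'container' and
# 'generate_gen' are assumed names, the parameter names come from the docstring above.
for chunk in container.generate_gen(
    "Once upon a time",
    max_tokens=200,          # renamed from 'max_new_tokens'
    temperature=1.0,         # fallback now 1.0 (was 0.8)
    top_k=0,                 # docstring default now 0 (was 100)
    top_p=1.0,               # fallback now 1.0 (was 0.8)
    repetition_penalty=1.0,  # renamed from 'token_repetition_penalty'; fallback now 1.0 (was 1.15)
    stop=["\n\n"],           # renamed from 'stop_conditions'
    stream_interval=0.2,     # seconds between yielded chunks
):
    print(chunk, end="", flush=True)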