mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-14 15:57:27 +00:00
Model: Cleanup and fix fallbacks
Use the standard "dict.get("key") or default" to handle fetching values
from kwargs and get a fallback value without possible errors.
Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
99
model.py
99
model.py
@@ -1,6 +1,5 @@
|
||||
import gc, time, pathlib
|
||||
import torch
|
||||
from datetime import datetime
|
||||
from exllamav2 import(
|
||||
ExLlamaV2,
|
||||
ExLlamaV2Config,
|
||||
@@ -29,7 +28,6 @@ class ModelContainer:
|
||||
generator: Optional[ExLlamaV2StreamingGenerator] = None
|
||||
|
||||
cache_fp8: bool = False
|
||||
draft_enabled: bool = False
|
||||
gpu_split_auto: bool = True
|
||||
gpu_split: list or None = None
|
||||
|
||||
@@ -44,7 +42,7 @@ class ModelContainer:
|
||||
def progress(loaded_modules: int, total_modules: int, loading_draft: bool)
|
||||
**kwargs:
|
||||
`cache_mode` (str): Sets cache mode, "FP16" or "FP8" (defaulf: "FP16")
|
||||
'max_seq_len' (int): Override model's default max sequence length
|
||||
'max_seq_len' (int): Override model's default max sequence length (default: 4096)
|
||||
'rope_scale' (float): Set RoPE scaling factor for model (default: 1.0)
|
||||
'rope_alpha' (float): Set RoPE alpha (NTK) factor for model (default: 1.0)
|
||||
'chunk_size' (int): Sets the maximum chunk size for the model (default: 2048)
|
||||
@@ -52,19 +50,20 @@ class ModelContainer:
|
||||
batches. This limits the size of temporary buffers needed for the hidden state and attention
|
||||
weights.
|
||||
'draft_model_dir' (str): Draft model directory
|
||||
'draft_rope_scale' (float): Set RoPE scaling factor for draft model (default: 1.0)
|
||||
'draft_rope_alpha' (float): RoPE alpha (NTK) factor for draft model.
|
||||
By default, the draft model's alpha value is calculated automatically to scale to the size of the
|
||||
full model.
|
||||
'gpu_split_auto' (bool): Automatically split model across available devices (default: True)
|
||||
'gpu_split' (list[float]): Allocation for weights and (some) tensors, per device
|
||||
'no_flash_attn' (bool): Turns off flash attention (increases vram usage)
|
||||
'no_flash_attn' (bool): Turns off flash attention (increases vram usage) (default: False)
|
||||
"""
|
||||
|
||||
self.quiet = quiet
|
||||
|
||||
self.cache_fp8 = "cache_mode" in kwargs and kwargs["cache_mode"] == "FP8"
|
||||
self.gpu_split = kwargs.get("gpu_split", None)
|
||||
self.gpu_split_auto = kwargs.get("gpu_split_auto", True)
|
||||
self.gpu_split = kwargs.get("gpu_split")
|
||||
self.gpu_split_auto = kwargs.get("gpu_split_auto") or True
|
||||
|
||||
self.config = ExLlamaV2Config()
|
||||
self.config.model_dir = str(model_directory.resolve())
|
||||
@@ -74,13 +73,14 @@ class ModelContainer:
|
||||
base_seq_len = self.config.max_seq_len
|
||||
|
||||
# Then override the max_seq_len if present
|
||||
if "max_seq_len" in kwargs: self.config.max_seq_len = kwargs["max_seq_len"]
|
||||
if "rope_scale" in kwargs: self.config.scale_pos_emb = kwargs["rope_scale"]
|
||||
self.config.max_seq_len = kwargs.get("max_seq_len") or 4096
|
||||
self.config.scale_pos_emb = kwargs.get("rope_scale") or 1.0
|
||||
|
||||
# Automatically calculate rope alpha
|
||||
self.config.scale_alpha_value = kwargs.get("rope_alpha") or self.calculate_rope_alpha(base_seq_len)
|
||||
|
||||
if "no_flash_attn" in kwargs: self.config.no_flash_attn = kwargs["no_flash_attn"]
|
||||
# Turn off flash attention?
|
||||
self.config.no_flash_attn = kwargs.get("no_flash_attn") or False
|
||||
|
||||
# low_mem is currently broken in exllamav2. Don't use it until it's fixed.
|
||||
"""
|
||||
@@ -88,31 +88,30 @@ class ModelContainer:
|
||||
self.config.set_low_mem()
|
||||
"""
|
||||
|
||||
chunk_size = min(kwargs.get("chunk_size", 2048), self.config.max_seq_len)
|
||||
chunk_size = min(kwargs.get("chunk_size") or 2048, self.config.max_seq_len)
|
||||
self.config.max_input_len = chunk_size
|
||||
self.config.max_attn_size = chunk_size ** 2
|
||||
|
||||
draft_config = kwargs.get("draft") or {}
|
||||
draft_model_name = draft_config.get("draft_model_name")
|
||||
enable_draft = bool(draft_config) and draft_model_name is not None
|
||||
draft_args = kwargs.get("draft") or {}
|
||||
draft_model_name = draft_args.get("draft_model_name")
|
||||
enable_draft = draft_args and draft_model_name
|
||||
|
||||
if bool(draft_config) and draft_model_name is None:
|
||||
# Always disable draft if params are incorrectly configured
|
||||
if draft_args and draft_model_name is None:
|
||||
print("A draft config was found but a model name was not given. Please check your config.yml! Skipping draft load.")
|
||||
self.draft_enabled = False
|
||||
else:
|
||||
self.draft_enabled = enable_draft
|
||||
enable_draft = False
|
||||
|
||||
if self.draft_enabled:
|
||||
if enable_draft:
|
||||
|
||||
self.draft_config = ExLlamaV2Config()
|
||||
draft_model_path = pathlib.Path(draft_config.get("draft_model_dir") or "models")
|
||||
draft_model_path = pathlib.Path(draft_args.get("draft_model_dir") or "models")
|
||||
draft_model_path = draft_model_path / draft_model_name
|
||||
|
||||
self.draft_config.model_dir = str(draft_model_path.resolve())
|
||||
self.draft_config.prepare()
|
||||
|
||||
self.draft_config.scale_pos_emb = draft_config.get("draft_rope_scale") or 1.0
|
||||
self.draft_config.scale_alpha_value = draft_config.get("draft_rope_alpha") or self.calculate_rope_alpha(self.draft_config.max_seq_len)
|
||||
self.draft_config.scale_pos_emb = draft_args.get("draft_rope_scale") or 1.0
|
||||
self.draft_config.scale_alpha_value = draft_args.get("draft_rope_alpha") or self.calculate_rope_alpha(self.draft_config.max_seq_len)
|
||||
self.draft_config.max_seq_len = self.config.max_seq_len
|
||||
|
||||
if "chunk_size" in kwargs:
|
||||
@@ -156,9 +155,9 @@ class ModelContainer:
|
||||
|
||||
self.tokenizer = ExLlamaV2Tokenizer(self.config)
|
||||
|
||||
# Load draft model
|
||||
# Load draft model if a config is present
|
||||
|
||||
if self.draft_enabled:
|
||||
if self.draft_config:
|
||||
|
||||
self.draft_model = ExLlamaV2(self.draft_config)
|
||||
if not self.quiet:
|
||||
@@ -228,13 +227,13 @@ class ModelContainer:
|
||||
# Assume token encoding
|
||||
return self.tokenizer.encode(
|
||||
text,
|
||||
add_bos = kwargs.get("add_bos_token", True),
|
||||
encode_special_tokens = kwargs.get("encode_special_tokens", True)
|
||||
add_bos = kwargs.get("add_bos_token") or True,
|
||||
encode_special_tokens = kwargs.get("encode_special_tokens") or True
|
||||
)
|
||||
if ids:
|
||||
# Assume token decoding
|
||||
ids = torch.tensor([ids])
|
||||
return self.tokenizer.decode(ids, decode_special_tokens = kwargs.get("decode_special_tokens", True))[0]
|
||||
return self.tokenizer.decode(ids, decode_special_tokens = kwargs.get("decode_special_tokens") or True)[0]
|
||||
|
||||
|
||||
def generate(self, prompt: str, **kwargs):
|
||||
@@ -274,10 +273,10 @@ class ModelContainer:
|
||||
|
||||
"""
|
||||
|
||||
token_healing = kwargs.get("token_healing", False)
|
||||
max_tokens = kwargs.get("max_tokens", 150)
|
||||
stream_interval = kwargs.get("stream_interval", 0)
|
||||
generate_window = min(kwargs.get("generate_window", 512), max_tokens)
|
||||
token_healing = kwargs.get("token_healing") or False
|
||||
max_tokens = kwargs.get("max_tokens") or 150
|
||||
stream_interval = kwargs.get("stream_interval") or 0
|
||||
generate_window = min(kwargs.get("generate_window") or 512, max_tokens)
|
||||
|
||||
# Sampler settings
|
||||
|
||||
@@ -285,43 +284,43 @@ class ModelContainer:
|
||||
|
||||
# Warn of unsupported settings if the setting is enabled
|
||||
|
||||
if kwargs.get("mirostat", False) and not hasattr(gen_settings, "mirostat"):
|
||||
if kwargs.get("mirostat") or False and not hasattr(gen_settings, "mirostat"):
|
||||
print(" !! Warning: Currently installed ExLlamaV2 does not support Mirostat sampling")
|
||||
|
||||
if kwargs.get("min_p", 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"):
|
||||
if kwargs.get("min_p") or 0.0 not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"):
|
||||
print(" !! Warning: Currently installed ExLlamaV2 does not support min-P sampling")
|
||||
|
||||
if kwargs.get("tfs", 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"):
|
||||
if kwargs.get("tfs") or 0.0 not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"):
|
||||
print(" !! Warning: Currently installed ExLlamaV2 does not support tail-free sampling (TFS)")
|
||||
|
||||
if kwargs.get("temperature_last", False) and not hasattr(gen_settings, "temperature_last"):
|
||||
if kwargs.get("temperature_last") or False and not hasattr(gen_settings, "temperature_last"):
|
||||
print(" !! Warning: Currently installed ExLlamaV2 does not support temperature_last")
|
||||
|
||||
#Apply settings
|
||||
|
||||
gen_settings.temperature = kwargs.get("temperature", 1.0)
|
||||
gen_settings.temperature_last = kwargs.get("temperature_last", False)
|
||||
gen_settings.top_k = kwargs.get("top_k", 1)
|
||||
gen_settings.top_p = kwargs.get("top_p", 1.0)
|
||||
gen_settings.min_p = kwargs.get("min_p", 0.0)
|
||||
gen_settings.tfs = kwargs.get("tfs", 0.0)
|
||||
gen_settings.typical = kwargs.get("typical", 0.0)
|
||||
gen_settings.mirostat = kwargs.get("mirostat", False)
|
||||
gen_settings.temperature = kwargs.get("temperature") or 1.0
|
||||
gen_settings.temperature_last = kwargs.get("temperature_last") or False
|
||||
gen_settings.top_k = kwargs.get("top_k") or 1
|
||||
gen_settings.top_p = kwargs.get("top_p") or 1.0
|
||||
gen_settings.min_p = kwargs.get("min_p") or 0.0
|
||||
gen_settings.tfs = kwargs.get("tfs") or 0.0
|
||||
gen_settings.typical = kwargs.get("typical") or 0.0
|
||||
gen_settings.mirostat = kwargs.get("mirostat") or False
|
||||
|
||||
# Default tau and eta fallbacks don't matter if mirostat is off
|
||||
gen_settings.mirostat_tau = kwargs.get("mirostat_tau", 1.5)
|
||||
gen_settings.mirostat_eta = kwargs.get("mirostat_eta", 0.1)
|
||||
gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty", 1.0)
|
||||
gen_settings.token_repetition_range = kwargs.get("repetition_range", self.config.max_seq_len)
|
||||
gen_settings.mirostat_tau = kwargs.get("mirostat_tau") or 1.5
|
||||
gen_settings.mirostat_eta = kwargs.get("mirostat_eta") or 0.1
|
||||
gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty") or 1.0
|
||||
gen_settings.token_repetition_range = kwargs.get("repetition_range") or self.config.max_seq_len
|
||||
|
||||
|
||||
# Always make sure the fallback is 0 if range < 0
|
||||
# It's technically fine to use -1, but this just validates the passed fallback
|
||||
fallback_decay = 0 if gen_settings.token_repetition_penalty <= 0 else gen_settings.token_repetition_range
|
||||
gen_settings.token_repetition_decay = kwargs.get("repetition_decay", fallback_decay or 0)
|
||||
gen_settings.token_repetition_decay = kwargs.get("repetition_decay") or fallback_decay or 0
|
||||
|
||||
stop_conditions: List[Union[str, int]] = kwargs.get("stop", [])
|
||||
ban_eos_token = kwargs.get("ban_eos_token", False)
|
||||
stop_conditions: List[Union[str, int]] = kwargs.get("stop") or []
|
||||
ban_eos_token = kwargs.get("ban_eos_token") or False
|
||||
|
||||
|
||||
# Ban the EOS token if specified. If not, append to stop conditions as well.
|
||||
@@ -347,7 +346,7 @@ class ModelContainer:
|
||||
|
||||
ids = self.tokenizer.encode(
|
||||
prompt,
|
||||
add_bos=kwargs.get("add_bos_token", True),
|
||||
add_bos = kwargs.get("add_bos_token") or True,
|
||||
encode_special_tokens = True
|
||||
)
|
||||
context_len = len(ids[0])
|
||||
|
||||
Reference in New Issue
Block a user