Mirror of https://github.com/theroyallab/tabbyAPI.git
Synced 2026-03-15 00:07:28 +00:00
Use unwrap and coalesce for optional handling
Python doesn't have proper handling of optionals. The only ways to handle them are checking whether the value is None via an if statement, or unwrapping with the "or" keyword. Previously, I used the "or" method to unwrap, but this caused issues because falsy values also fell back to the default. This was especially a problem with booleans, where an explicit "False" silently changed to "True". Instead, add two new functions, unwrap and coalesce, which properly implement functional "None" coalescing.

Signed-off-by: kingbri <bdashore3@proton.me>
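The helpers themselves live in utils, which is not part of this diff. Going off the description above, a minimal sketch of what unwrap and coalesce would look like (the exact bodies in utils.py may differ):

from typing import Optional, TypeVar

T = TypeVar("T")


def unwrap(wrapped: Optional[T], default: Optional[T] = None) -> Optional[T]:
    """Return wrapped unless it is None, in which case return default.

    Unlike `wrapped or default`, falsy values such as False, 0, 0.0,
    "" and [] are kept rather than being replaced by the default.
    """
    return default if wrapped is None else wrapped


def coalesce(*args):
    """Return the first argument that is not None, or None if all are."""
    return next((arg for arg in args if arg is not None), None)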
model.py | 84 (+43, -41)
@@ -13,6 +13,7 @@ from exllamav2.generator import(
     ExLlamaV2Sampler
 )
 from typing import List, Optional, Union
+from utils import coalesce, unwrap

 # Bytes to reserve on first device when loading with auto split
 auto_split_reserve_bytes = 96 * 1024**2

@@ -30,7 +31,7 @@ class ModelContainer:

     cache_fp8: bool = False
     gpu_split_auto: bool = True
-    gpu_split: list or None = None
+    gpu_split: Optional[list] = None

     active_loras: List[ExLlamaV2Lora] = []

@@ -68,7 +69,7 @@ class ModelContainer:

         self.cache_fp8 = "cache_mode" in kwargs and kwargs["cache_mode"] == "FP8"
         self.gpu_split = kwargs.get("gpu_split")
-        self.gpu_split_auto = kwargs.get("gpu_split_auto") or True
+        self.gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)

         self.config = ExLlamaV2Config()
         self.config.model_dir = str(model_directory.resolve())
@@ -78,14 +79,14 @@ class ModelContainer:
         base_seq_len = self.config.max_seq_len

         # Then override the max_seq_len if present
-        self.config.max_seq_len = kwargs.get("max_seq_len") or 4096
-        self.config.scale_pos_emb = kwargs.get("rope_scale") or 1.0
+        self.config.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
+        self.config.scale_pos_emb = unwrap(kwargs.get("rope_scale"), 1.0)

         # Automatically calculate rope alpha
-        self.config.scale_alpha_value = kwargs.get("rope_alpha") or self.calculate_rope_alpha(base_seq_len)
+        self.config.scale_alpha_value = unwrap(kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len))

         # Turn off flash attention?
-        self.config.no_flash_attn = kwargs.get("no_flash_attn") or False
+        self.config.no_flash_attn = unwrap(kwargs.get("no_flash_attn"), False)

         # low_mem is currently broken in exllamav2. Don't use it until it's fixed.
         """
@@ -93,11 +94,11 @@ class ModelContainer:
         self.config.set_low_mem()
         """

-        chunk_size = min(kwargs.get("chunk_size") or 2048, self.config.max_seq_len)
+        chunk_size = min(unwrap(kwargs.get("chunk_size"), 2048), self.config.max_seq_len)
         self.config.max_input_len = chunk_size
         self.config.max_attn_size = chunk_size ** 2

-        draft_args = kwargs.get("draft") or {}
+        draft_args = unwrap(kwargs.get("draft"), {})
         draft_model_name = draft_args.get("draft_model_name")
         enable_draft = draft_args and draft_model_name

@@ -109,14 +110,14 @@ class ModelContainer:
         if enable_draft:

             self.draft_config = ExLlamaV2Config()
-            draft_model_path = pathlib.Path(draft_args.get("draft_model_dir") or "models")
+            draft_model_path = pathlib.Path(unwrap(draft_args.get("draft_model_dir"), "models"))
             draft_model_path = draft_model_path / draft_model_name

             self.draft_config.model_dir = str(draft_model_path.resolve())
             self.draft_config.prepare()

-            self.draft_config.scale_pos_emb = draft_args.get("draft_rope_scale") or 1.0
-            self.draft_config.scale_alpha_value = draft_args.get("draft_rope_alpha") or self.calculate_rope_alpha(self.draft_config.max_seq_len)
+            self.draft_config.scale_pos_emb = unwrap(draft_args.get("draft_rope_scale"), 1.0)
+            self.draft_config.scale_alpha_value = unwrap(draft_args.get("draft_rope_alpha"), self.calculate_rope_alpha(self.draft_config.max_seq_len))
             self.draft_config.max_seq_len = self.config.max_seq_len

         if "chunk_size" in kwargs:
@@ -151,13 +152,13 @@ class ModelContainer:
         Load loras
         """

-        loras = kwargs.get("loras") or []
+        loras = unwrap(kwargs.get("loras"), [])
         success: List[str] = []
         failure: List[str] = []

         for lora in loras:
-            lora_name = lora.get("name") or None
-            lora_scaling = lora.get("scaling") or 1.0
+            lora_name = lora.get("name")
+            lora_scaling = unwrap(lora.get("scaling"), 1.0)

             if lora_name is None:
                 print("One of your loras does not have a name. Please check your config.yml! Skipping lora load.")
@@ -265,13 +266,13 @@ class ModelContainer:
             # Assume token encoding
             return self.tokenizer.encode(
                 text,
-                add_bos = kwargs.get("add_bos_token") or True,
-                encode_special_tokens = kwargs.get("encode_special_tokens") or True
+                add_bos = unwrap(kwargs.get("add_bos_token"), True),
+                encode_special_tokens = unwrap(kwargs.get("encode_special_tokens"), True)
             )
         if ids:
             # Assume token decoding
             ids = torch.tensor([ids])
-            return self.tokenizer.decode(ids, decode_special_tokens = kwargs.get("decode_special_tokens") or True)[0]
+            return self.tokenizer.decode(ids, decode_special_tokens = unwrap(kwargs.get("decode_special_tokens"), True))[0]


     def generate(self, prompt: str, **kwargs):
@@ -311,10 +312,10 @@ class ModelContainer:

         """

-        token_healing = kwargs.get("token_healing") or False
-        max_tokens = kwargs.get("max_tokens") or 150
-        stream_interval = kwargs.get("stream_interval") or 0
-        generate_window = min(kwargs.get("generate_window") or 512, max_tokens)
+        token_healing = unwrap(kwargs.get("token_healing"), False)
+        max_tokens = unwrap(kwargs.get("max_tokens"), 150)
+        stream_interval = unwrap(kwargs.get("stream_interval"), 0)
+        generate_window = min(unwrap(kwargs.get("generate_window"), 512), max_tokens)

         # Sampler settings

@@ -322,42 +323,43 @@ class ModelContainer:

         # Warn of unsupported settings if the setting is enabled

-        if (kwargs.get("mirostat") or False) and not hasattr(gen_settings, "mirostat"):
+        if (unwrap(kwargs.get("mirostat"), False)) and not hasattr(gen_settings, "mirostat"):
             print(" !! Warning: Currently installed ExLlamaV2 does not support Mirostat sampling")

-        if (kwargs.get("min_p") or 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"):
+        if (unwrap(kwargs.get("min_p"), 0.0)) not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"):
             print(" !! Warning: Currently installed ExLlamaV2 does not support min-P sampling")

-        if (kwargs.get("tfs") or 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"):
+        if (unwrap(kwargs.get("tfs"), 0.0)) not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"):
             print(" !! Warning: Currently installed ExLlamaV2 does not support tail-free sampling (TFS)")

-        if (kwargs.get("temperature_last") or False) and not hasattr(gen_settings, "temperature_last"):
+        if (unwrap(kwargs.get("temperature_last"), False)) and not hasattr(gen_settings, "temperature_last"):
             print(" !! Warning: Currently installed ExLlamaV2 does not support temperature_last")

         #Apply settings

-        gen_settings.temperature = kwargs.get("temperature") or 1.0
-        gen_settings.temperature_last = kwargs.get("temperature_last") or False
-        gen_settings.top_k = kwargs.get("top_k") or 0
-        gen_settings.top_p = kwargs.get("top_p") or 1.0
-        gen_settings.min_p = kwargs.get("min_p") or 0.0
-        gen_settings.tfs = kwargs.get("tfs") or 1.0
-        gen_settings.typical = kwargs.get("typical") or 1.0
-        gen_settings.mirostat = kwargs.get("mirostat") or False
+        gen_settings.temperature = unwrap(kwargs.get("temperature"), 1.0)
+        gen_settings.temperature_last = unwrap(kwargs.get("temperature_last"), False)
+        gen_settings.top_k = unwrap(kwargs.get("top_k"), 0)
+        gen_settings.top_p = unwrap(kwargs.get("top_p"), 1.0)
+        gen_settings.min_p = unwrap(kwargs.get("min_p"), 0.0)
+        gen_settings.tfs = unwrap(kwargs.get("tfs"), 1.0)
+        gen_settings.typical = unwrap(kwargs.get("typical"), 1.0)
+        gen_settings.mirostat = unwrap(kwargs.get("mirostat"), False)

         # Default tau and eta fallbacks don't matter if mirostat is off
-        gen_settings.mirostat_tau = kwargs.get("mirostat_tau") or 1.5
-        gen_settings.mirostat_eta = kwargs.get("mirostat_eta") or 0.1
-        gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty") or 1.0
-        gen_settings.token_repetition_range = kwargs.get("repetition_range") or self.config.max_seq_len
+        gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5)
+        gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1)
+        gen_settings.token_repetition_penalty = unwrap(kwargs.get("repetition_penalty"), 1.0)
+        gen_settings.token_repetition_range = unwrap(kwargs.get("repetition_range"), self.config.max_seq_len)

         # Always make sure the fallback is 0 if range < 0
         # It's technically fine to use -1, but this just validates the passed fallback
+        # Always default to 0 if something goes wrong
         fallback_decay = 0 if gen_settings.token_repetition_range <= 0 else gen_settings.token_repetition_range
-        gen_settings.token_repetition_decay = kwargs.get("repetition_decay") or fallback_decay or 0
+        gen_settings.token_repetition_decay = coalesce(kwargs.get("repetition_decay"), fallback_decay, 0)

-        stop_conditions: List[Union[str, int]] = kwargs.get("stop") or []
-        ban_eos_token = kwargs.get("ban_eos_token") or False
+        stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
+        ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)


         # Ban the EOS token if specified. If not, append to stop conditions as well.
@@ -383,7 +385,7 @@ class ModelContainer:

         ids = self.tokenizer.encode(
             prompt,
-            add_bos = kwargs.get("add_bos_token") or True,
+            add_bos = unwrap(kwargs.get("add_bos_token"), True),
             encode_special_tokens = True
         )
         context_len = len(ids[0])
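For illustration only (not part of the commit): the practical difference between the two patterns at the call sites above, with hypothetical kwargs values and 512 standing in for a computed fallback_decay:

from utils import coalesce, unwrap  # the helpers sketched earlier

# Hypothetical inputs, as if parsed from config.yml or a request body
kwargs = {"add_bos_token": False, "repetition_decay": 0}

# Old pattern: "or" treats every falsy value as missing
kwargs.get("add_bos_token") or True               # -> True (explicit False is lost)
kwargs.get("repetition_decay") or 512 or 0        # -> 512  (explicit 0 is lost)

# New pattern: only None falls through to the default
unwrap(kwargs.get("add_bos_token"), True)         # -> False (explicit value kept)
coalesce(kwargs.get("repetition_decay"), 512, 0)  # -> 0     (explicit value kept)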