Tree: Use unwrap and coalesce for optional handling

Python doesn't have first-class optional handling. The only ways to
handle a possibly-None value are checking it against None with an if
statement or falling back to a default with the "or" keyword.

Previously, I used "or" to unwrap, but this caused issues because any
falsy value fell back to the default, not just None. This is
especially bad with booleans, where an explicit "False" silently
became "True".
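
A minimal sketch of the pitfall (the kwargs contents here are
hypothetical):

    # "or" falls back on ANY falsy value, not just None
    kwargs = {"gpu_split_auto": False, "top_k": 0}
    kwargs.get("gpu_split_auto") or True  # -> True; the caller's False is lost
    kwargs.get("top_k") or 50             # -> 50; an explicit 0 is lost too
    kwargs.get("max_tokens") or 150       # -> 150; fine only because the key is absent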

Instead, add two new functions, unwrap and coalesce, which implement
None coalescing properly: they fall back only when a value is
actually None.
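
The implementations live in utils (see the import in the diff below)
and are not shown here. A minimal sketch of what they could look
like, inferred from the call sites (an assumption, not the committed
code):

    from typing import Any, Optional, TypeVar

    T = TypeVar("T")

    def unwrap(value: Optional[T], default: T) -> T:
        # Fall back only when the value is actually None
        return value if value is not None else default

    def coalesce(*args: Optional[Any]) -> Optional[Any]:
        # Return the first non-None argument, or None if all are None
        return next((arg for arg in args if arg is not None), None)

With these, unwrap(kwargs.get("gpu_split_auto"), True) preserves an
explicit False, and coalesce(kwargs.get("repetition_decay"),
fallback_decay, 0) walks the candidates until one is not None.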

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri
2023-12-09 21:52:17 -05:00
parent 7380a3b79a
commit 5ae2a91c04
5 changed files with 83 additions and 68 deletions

@@ -13,6 +13,7 @@ from exllamav2.generator import(
ExLlamaV2Sampler
)
from typing import List, Optional, Union
+from utils import coalesce, unwrap
# Bytes to reserve on first device when loading with auto split
auto_split_reserve_bytes = 96 * 1024**2
@@ -30,7 +31,7 @@ class ModelContainer:
cache_fp8: bool = False
gpu_split_auto: bool = True
-gpu_split: list or None = None
+gpu_split: Optional[list] = None
active_loras: List[ExLlamaV2Lora] = []
@@ -68,7 +69,7 @@ class ModelContainer:
self.cache_fp8 = "cache_mode" in kwargs and kwargs["cache_mode"] == "FP8"
self.gpu_split = kwargs.get("gpu_split")
-self.gpu_split_auto = kwargs.get("gpu_split_auto") or True
+self.gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
self.config = ExLlamaV2Config()
self.config.model_dir = str(model_directory.resolve())
@@ -78,14 +79,14 @@ class ModelContainer:
base_seq_len = self.config.max_seq_len
# Then override the max_seq_len if present
-self.config.max_seq_len = kwargs.get("max_seq_len") or 4096
-self.config.scale_pos_emb = kwargs.get("rope_scale") or 1.0
+self.config.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
+self.config.scale_pos_emb = unwrap(kwargs.get("rope_scale"), 1.0)
# Automatically calculate rope alpha
-self.config.scale_alpha_value = kwargs.get("rope_alpha") or self.calculate_rope_alpha(base_seq_len)
+self.config.scale_alpha_value = unwrap(kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len))
# Turn off flash attention?
-self.config.no_flash_attn = kwargs.get("no_flash_attn") or False
+self.config.no_flash_attn = unwrap(kwargs.get("no_flash_attn"), False)
# low_mem is currently broken in exllamav2. Don't use it until it's fixed.
"""
@@ -93,11 +94,11 @@ class ModelContainer:
self.config.set_low_mem()
"""
-chunk_size = min(kwargs.get("chunk_size") or 2048, self.config.max_seq_len)
+chunk_size = min(unwrap(kwargs.get("chunk_size"), 2048), self.config.max_seq_len)
self.config.max_input_len = chunk_size
self.config.max_attn_size = chunk_size ** 2
-draft_args = kwargs.get("draft") or {}
+draft_args = unwrap(kwargs.get("draft"), {})
draft_model_name = draft_args.get("draft_model_name")
enable_draft = draft_args and draft_model_name
@@ -109,14 +110,14 @@ class ModelContainer:
if enable_draft:
self.draft_config = ExLlamaV2Config()
-draft_model_path = pathlib.Path(draft_args.get("draft_model_dir") or "models")
+draft_model_path = pathlib.Path(unwrap(draft_args.get("draft_model_dir"), "models"))
draft_model_path = draft_model_path / draft_model_name
self.draft_config.model_dir = str(draft_model_path.resolve())
self.draft_config.prepare()
-self.draft_config.scale_pos_emb = draft_args.get("draft_rope_scale") or 1.0
-self.draft_config.scale_alpha_value = draft_args.get("draft_rope_alpha") or self.calculate_rope_alpha(self.draft_config.max_seq_len)
+self.draft_config.scale_pos_emb = unwrap(draft_args.get("draft_rope_scale"), 1.0)
+self.draft_config.scale_alpha_value = unwrap(draft_args.get("draft_rope_alpha"), self.calculate_rope_alpha(self.draft_config.max_seq_len))
self.draft_config.max_seq_len = self.config.max_seq_len
if "chunk_size" in kwargs:
@@ -151,13 +152,13 @@ class ModelContainer:
Load loras
"""
-loras = kwargs.get("loras") or []
+loras = unwrap(kwargs.get("loras"), [])
success: List[str] = []
failure: List[str] = []
for lora in loras:
lora_name = lora.get("name") or None
lora_scaling = lora.get("scaling") or 1.0
lora_name = lora.get("name")
lora_scaling = unwrap(lora.get("scaling"), 1.0)
if lora_name is None:
print("One of your loras does not have a name. Please check your config.yml! Skipping lora load.")
@@ -265,13 +266,13 @@ class ModelContainer:
# Assume token encoding
return self.tokenizer.encode(
text,
-add_bos = kwargs.get("add_bos_token") or True,
-encode_special_tokens = kwargs.get("encode_special_tokens") or True
+add_bos = unwrap(kwargs.get("add_bos_token"), True),
+encode_special_tokens = unwrap(kwargs.get("encode_special_tokens"), True)
)
if ids:
# Assume token decoding
ids = torch.tensor([ids])
-return self.tokenizer.decode(ids, decode_special_tokens = kwargs.get("decode_special_tokens") or True)[0]
+return self.tokenizer.decode(ids, decode_special_tokens = unwrap(kwargs.get("decode_special_tokens"), True))[0]
def generate(self, prompt: str, **kwargs):
@@ -311,10 +312,10 @@ class ModelContainer:
"""
-token_healing = kwargs.get("token_healing") or False
-max_tokens = kwargs.get("max_tokens") or 150
-stream_interval = kwargs.get("stream_interval") or 0
-generate_window = min(kwargs.get("generate_window") or 512, max_tokens)
+token_healing = unwrap(kwargs.get("token_healing"), False)
+max_tokens = unwrap(kwargs.get("max_tokens"), 150)
+stream_interval = unwrap(kwargs.get("stream_interval"), 0)
+generate_window = min(unwrap(kwargs.get("generate_window"), 512), max_tokens)
# Sampler settings
@@ -322,42 +323,43 @@ class ModelContainer:
# Warn of unsupported settings if the setting is enabled
if (kwargs.get("mirostat") or False) and not hasattr(gen_settings, "mirostat"):
if (unwrap(kwargs.get("mirostat"), False)) and not hasattr(gen_settings, "mirostat"):
print(" !! Warning: Currently installed ExLlamaV2 does not support Mirostat sampling")
if (kwargs.get("min_p") or 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"):
if (unwrap(kwargs.get("min_p"), 0.0)) not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"):
print(" !! Warning: Currently installed ExLlamaV2 does not support min-P sampling")
if (kwargs.get("tfs") or 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"):
if (unwrap(kwargs.get("tfs"), 0.0)) not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"):
print(" !! Warning: Currently installed ExLlamaV2 does not support tail-free sampling (TFS)")
if (kwargs.get("temperature_last") or False) and not hasattr(gen_settings, "temperature_last"):
if (unwrap(kwargs.get("temperature_last"), False)) and not hasattr(gen_settings, "temperature_last"):
print(" !! Warning: Currently installed ExLlamaV2 does not support temperature_last")
#Apply settings
-gen_settings.temperature = kwargs.get("temperature") or 1.0
-gen_settings.temperature_last = kwargs.get("temperature_last") or False
-gen_settings.top_k = kwargs.get("top_k") or 0
-gen_settings.top_p = kwargs.get("top_p") or 1.0
-gen_settings.min_p = kwargs.get("min_p") or 0.0
-gen_settings.tfs = kwargs.get("tfs") or 1.0
-gen_settings.typical = kwargs.get("typical") or 1.0
-gen_settings.mirostat = kwargs.get("mirostat") or False
+gen_settings.temperature = unwrap(kwargs.get("temperature"), 1.0)
+gen_settings.temperature_last = unwrap(kwargs.get("temperature_last"), False)
+gen_settings.top_k = unwrap(kwargs.get("top_k"), 0)
+gen_settings.top_p = unwrap(kwargs.get("top_p"), 1.0)
+gen_settings.min_p = unwrap(kwargs.get("min_p"), 0.0)
+gen_settings.tfs = unwrap(kwargs.get("tfs"), 1.0)
+gen_settings.typical = unwrap(kwargs.get("typical"), 1.0)
+gen_settings.mirostat = unwrap(kwargs.get("mirostat"), False)
# Default tau and eta fallbacks don't matter if mirostat is off
-gen_settings.mirostat_tau = kwargs.get("mirostat_tau") or 1.5
-gen_settings.mirostat_eta = kwargs.get("mirostat_eta") or 0.1
-gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty") or 1.0
-gen_settings.token_repetition_range = kwargs.get("repetition_range") or self.config.max_seq_len
+gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5)
+gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1)
+gen_settings.token_repetition_penalty = unwrap(kwargs.get("repetition_penalty"), 1.0)
+gen_settings.token_repetition_range = unwrap(kwargs.get("repetition_range"), self.config.max_seq_len)
# Always make sure the fallback is 0 if range < 0
# It's technically fine to use -1, but this just validates the passed fallback
# Always default to 0 if something goes wrong
fallback_decay = 0 if gen_settings.token_repetition_range <= 0 else gen_settings.token_repetition_range
-gen_settings.token_repetition_decay = kwargs.get("repetition_decay") or fallback_decay or 0
+gen_settings.token_repetition_decay = coalesce(kwargs.get("repetition_decay"), fallback_decay, 0)
-stop_conditions: List[Union[str, int]] = kwargs.get("stop") or []
-ban_eos_token = kwargs.get("ban_eos_token") or False
+stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
+ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)
# Ban the EOS token if specified. If not, append to stop conditions as well.
@@ -383,7 +385,7 @@ class ModelContainer:
ids = self.tokenizer.encode(
prompt,
-add_bos = kwargs.get("add_bos_token") or True,
+add_bos = unwrap(kwargs.get("add_bos_token"), True),
encode_special_tokens = True
)
context_len = len(ids[0])