Model: Cleanup and fix fallbacks

Use the standard "dict.get("key") or default" to handle fetching values
from kwargs and get a fallback value without possible errors.

Signed-off-by: kingbri <bdashore3@proton.me>
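A minimal sketch (not part of the commit) of the fallback pattern described in the message above; the dict contents are illustrative, not taken from the repo:

    # dict.get(key, default) only falls back when the key is missing,
    # so an explicit None passed in from a config slips through.
    kwargs = {"rope_scale": None, "chunk_size": 1024}
    scale = kwargs.get("rope_scale", 1.0)     # -> None
    # dict.get(key) or default also falls back when the stored value is
    # falsy (None, 0, "", False), so a null entry still gets a default.
    scale = kwargs.get("rope_scale") or 1.0   # -> 1.0
    chunk = kwargs.get("chunk_size") or 2048  # -> 1024 (truthy value is kept)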
Author: kingbri
Date: 2023-12-05 23:28:16 -05:00
parent 0ef2fe9b95
commit 4c0e686e7d


@@ -1,6 +1,5 @@
import gc, time, pathlib
import torch
from datetime import datetime
from exllamav2 import(
ExLlamaV2,
ExLlamaV2Config,
@@ -29,7 +28,6 @@ class ModelContainer:
generator: Optional[ExLlamaV2StreamingGenerator] = None
cache_fp8: bool = False
draft_enabled: bool = False
gpu_split_auto: bool = True
gpu_split: list or None = None
@@ -44,7 +42,7 @@ class ModelContainer:
def progress(loaded_modules: int, total_modules: int, loading_draft: bool)
**kwargs:
`cache_mode` (str): Sets cache mode, "FP16" or "FP8" (default: "FP16")
'max_seq_len' (int): Override model's default max sequence length
'max_seq_len' (int): Override model's default max sequence length (default: 4096)
'rope_scale' (float): Set RoPE scaling factor for model (default: 1.0)
'rope_alpha' (float): Set RoPE alpha (NTK) factor for model (default: 1.0)
'chunk_size' (int): Sets the maximum chunk size for the model (default: 2048)
@@ -52,19 +50,20 @@ class ModelContainer:
batches. This limits the size of temporary buffers needed for the hidden state and attention
weights.
'draft_model_dir' (str): Draft model directory
'draft_rope_scale' (float): Set RoPE scaling factor for draft model (default: 1.0)
'draft_rope_alpha' (float): RoPE alpha (NTK) factor for draft model.
By default, the draft model's alpha value is calculated automatically to scale to the size of the
full model.
'gpu_split_auto' (bool): Automatically split model across available devices (default: True)
'gpu_split' (list[float]): Allocation for weights and (some) tensors, per device
'no_flash_attn' (bool): Turns off flash attention (increases vram usage)
'no_flash_attn' (bool): Turns off flash attention (increases vram usage) (default: False)
"""
self.quiet = quiet
self.cache_fp8 = "cache_mode" in kwargs and kwargs["cache_mode"] == "FP8"
self.gpu_split = kwargs.get("gpu_split", None)
self.gpu_split_auto = kwargs.get("gpu_split_auto", True)
self.gpu_split = kwargs.get("gpu_split")
self.gpu_split_auto = kwargs.get("gpu_split_auto") or True
self.config = ExLlamaV2Config()
self.config.model_dir = str(model_directory.resolve())
@@ -74,13 +73,14 @@ class ModelContainer:
base_seq_len = self.config.max_seq_len
# Then override the max_seq_len if present
if "max_seq_len" in kwargs: self.config.max_seq_len = kwargs["max_seq_len"]
if "rope_scale" in kwargs: self.config.scale_pos_emb = kwargs["rope_scale"]
self.config.max_seq_len = kwargs.get("max_seq_len") or 4096
self.config.scale_pos_emb = kwargs.get("rope_scale") or 1.0
# Automatically calculate rope alpha
self.config.scale_alpha_value = kwargs.get("rope_alpha") or self.calculate_rope_alpha(base_seq_len)
if "no_flash_attn" in kwargs: self.config.no_flash_attn = kwargs["no_flash_attn"]
# Turn off flash attention?
self.config.no_flash_attn = kwargs.get("no_flash_attn") or False
# low_mem is currently broken in exllamav2. Don't use it until it's fixed.
"""
@@ -88,31 +88,30 @@ class ModelContainer:
self.config.set_low_mem()
"""
chunk_size = min(kwargs.get("chunk_size", 2048), self.config.max_seq_len)
chunk_size = min(kwargs.get("chunk_size") or 2048, self.config.max_seq_len)
self.config.max_input_len = chunk_size
self.config.max_attn_size = chunk_size ** 2
draft_config = kwargs.get("draft") or {}
draft_model_name = draft_config.get("draft_model_name")
enable_draft = bool(draft_config) and draft_model_name is not None
draft_args = kwargs.get("draft") or {}
draft_model_name = draft_args.get("draft_model_name")
enable_draft = draft_args and draft_model_name
if bool(draft_config) and draft_model_name is None:
# Always disable draft if params are incorrectly configured
if draft_args and draft_model_name is None:
print("A draft config was found but a model name was not given. Please check your config.yml! Skipping draft load.")
self.draft_enabled = False
else:
self.draft_enabled = enable_draft
enable_draft = False
if self.draft_enabled:
if enable_draft:
self.draft_config = ExLlamaV2Config()
draft_model_path = pathlib.Path(draft_config.get("draft_model_dir") or "models")
draft_model_path = pathlib.Path(draft_args.get("draft_model_dir") or "models")
draft_model_path = draft_model_path / draft_model_name
self.draft_config.model_dir = str(draft_model_path.resolve())
self.draft_config.prepare()
self.draft_config.scale_pos_emb = draft_config.get("draft_rope_scale") or 1.0
self.draft_config.scale_alpha_value = draft_config.get("draft_rope_alpha") or self.calculate_rope_alpha(self.draft_config.max_seq_len)
self.draft_config.scale_pos_emb = draft_args.get("draft_rope_scale") or 1.0
self.draft_config.scale_alpha_value = draft_args.get("draft_rope_alpha") or self.calculate_rope_alpha(self.draft_config.max_seq_len)
self.draft_config.max_seq_len = self.config.max_seq_len
if "chunk_size" in kwargs:
@@ -156,9 +155,9 @@ class ModelContainer:
self.tokenizer = ExLlamaV2Tokenizer(self.config)
# Load draft model
# Load draft model if a config is present
if self.draft_enabled:
if self.draft_config:
self.draft_model = ExLlamaV2(self.draft_config)
if not self.quiet:
@@ -228,13 +227,13 @@ class ModelContainer:
# Assume token encoding
return self.tokenizer.encode(
text,
add_bos = kwargs.get("add_bos_token", True),
encode_special_tokens = kwargs.get("encode_special_tokens", True)
add_bos = kwargs.get("add_bos_token") or True,
encode_special_tokens = kwargs.get("encode_special_tokens") or True
)
if ids:
# Assume token decoding
ids = torch.tensor([ids])
return self.tokenizer.decode(ids, decode_special_tokens = kwargs.get("decode_special_tokens", True))[0]
return self.tokenizer.decode(ids, decode_special_tokens = kwargs.get("decode_special_tokens") or True)[0]
def generate(self, prompt: str, **kwargs):
@@ -274,10 +273,10 @@ class ModelContainer:
"""
token_healing = kwargs.get("token_healing", False)
max_tokens = kwargs.get("max_tokens", 150)
stream_interval = kwargs.get("stream_interval", 0)
generate_window = min(kwargs.get("generate_window", 512), max_tokens)
token_healing = kwargs.get("token_healing") or False
max_tokens = kwargs.get("max_tokens") or 150
stream_interval = kwargs.get("stream_interval") or 0
generate_window = min(kwargs.get("generate_window") or 512, max_tokens)
# Sampler settings
@@ -285,43 +284,43 @@ class ModelContainer:
# Warn of unsupported settings if the setting is enabled
if kwargs.get("mirostat", False) and not hasattr(gen_settings, "mirostat"):
if kwargs.get("mirostat") or False and not hasattr(gen_settings, "mirostat"):
print(" !! Warning: Currently installed ExLlamaV2 does not support Mirostat sampling")
if kwargs.get("min_p", 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"):
if kwargs.get("min_p") or 0.0 not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"):
print(" !! Warning: Currently installed ExLlamaV2 does not support min-P sampling")
if kwargs.get("tfs", 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"):
if kwargs.get("tfs") or 0.0 not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"):
print(" !! Warning: Currently installed ExLlamaV2 does not support tail-free sampling (TFS)")
if kwargs.get("temperature_last", False) and not hasattr(gen_settings, "temperature_last"):
if kwargs.get("temperature_last") or False and not hasattr(gen_settings, "temperature_last"):
print(" !! Warning: Currently installed ExLlamaV2 does not support temperature_last")
#Apply settings
gen_settings.temperature = kwargs.get("temperature", 1.0)
gen_settings.temperature_last = kwargs.get("temperature_last", False)
gen_settings.top_k = kwargs.get("top_k", 1)
gen_settings.top_p = kwargs.get("top_p", 1.0)
gen_settings.min_p = kwargs.get("min_p", 0.0)
gen_settings.tfs = kwargs.get("tfs", 0.0)
gen_settings.typical = kwargs.get("typical", 0.0)
gen_settings.mirostat = kwargs.get("mirostat", False)
gen_settings.temperature = kwargs.get("temperature") or 1.0
gen_settings.temperature_last = kwargs.get("temperature_last") or False
gen_settings.top_k = kwargs.get("top_k") or 1
gen_settings.top_p = kwargs.get("top_p") or 1.0
gen_settings.min_p = kwargs.get("min_p") or 0.0
gen_settings.tfs = kwargs.get("tfs") or 0.0
gen_settings.typical = kwargs.get("typical") or 0.0
gen_settings.mirostat = kwargs.get("mirostat") or False
# Default tau and eta fallbacks don't matter if mirostat is off
gen_settings.mirostat_tau = kwargs.get("mirostat_tau", 1.5)
gen_settings.mirostat_eta = kwargs.get("mirostat_eta", 0.1)
gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty", 1.0)
gen_settings.token_repetition_range = kwargs.get("repetition_range", self.config.max_seq_len)
gen_settings.mirostat_tau = kwargs.get("mirostat_tau") or 1.5
gen_settings.mirostat_eta = kwargs.get("mirostat_eta") or 0.1
gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty") or 1.0
gen_settings.token_repetition_range = kwargs.get("repetition_range") or self.config.max_seq_len
# Always make sure the fallback is 0 if range < 0
# It's technically fine to use -1, but this just validates the passed fallback
fallback_decay = 0 if gen_settings.token_repetition_penalty <= 0 else gen_settings.token_repetition_range
gen_settings.token_repetition_decay = kwargs.get("repetition_decay", fallback_decay or 0)
gen_settings.token_repetition_decay = kwargs.get("repetition_decay") or fallback_decay or 0
stop_conditions: List[Union[str, int]] = kwargs.get("stop", [])
ban_eos_token = kwargs.get("ban_eos_token", False)
stop_conditions: List[Union[str, int]] = kwargs.get("stop") or []
ban_eos_token = kwargs.get("ban_eos_token") or False
# Ban the EOS token if specified. If not, append to stop conditions as well.
@@ -347,7 +346,7 @@ class ModelContainer:
ids = self.tokenizer.encode(
prompt,
add_bos=kwargs.get("add_bos_token", True),
add_bos = kwargs.get("add_bos_token") or True,
encode_special_tokens = True
)
context_len = len(ids[0])