From 4c0e686e7db401583c8b15a963c815417cd36176 Mon Sep 17 00:00:00 2001
From: kingbri
Date: Tue, 5 Dec 2023 23:28:16 -0500
Subject: [PATCH] Model: Cleanup and fix fallbacks

Use the standard "dict.get("key") or default" to handle fetching values
from kwargs and get a fallback value without possible errors.

Signed-off-by: kingbri
---
 model.py | 99 ++++++++++++++++++++++++++++----------------------------
 1 file changed, 49 insertions(+), 50 deletions(-)

diff --git a/model.py b/model.py
index af933cb..4f72b4d 100644
--- a/model.py
+++ b/model.py
@@ -1,6 +1,5 @@
 import gc, time, pathlib
 import torch
-from datetime import datetime
 from exllamav2 import(
     ExLlamaV2,
     ExLlamaV2Config,
@@ -29,7 +28,6 @@ class ModelContainer:
     generator: Optional[ExLlamaV2StreamingGenerator] = None
 
     cache_fp8: bool = False
-    draft_enabled: bool = False
     gpu_split_auto: bool = True
     gpu_split: list or None = None
 
@@ -44,7 +42,7 @@ class ModelContainer:
         def progress(loaded_modules: int, total_modules: int, loading_draft: bool)
         **kwargs:
             `cache_mode` (str): Sets cache mode, "FP16" or "FP8" (defaulf: "FP16")
-            'max_seq_len' (int): Override model's default max sequence length
+            'max_seq_len' (int): Override model's default max sequence length (default: 4096)
             'rope_scale' (float): Set RoPE scaling factor for model (default: 1.0)
             'rope_alpha' (float): Set RoPE alpha (NTK) factor for model (default: 1.0)
             'chunk_size' (int): Sets the maximum chunk size for the model (default: 2048)
@@ -52,19 +50,20 @@ class ModelContainer:
                 batches. This limits the size of temporary buffers needed for the hidden state and attention weights.
             'draft_model_dir' (str): Draft model directory
+            'draft_rope_scale' (float): Set RoPE scaling factor for draft model (default: 1.0)
             'draft_rope_alpha' (float): RoPE alpha (NTK) factor for draft model. By default, the draft model's
                 alpha value is calculated automatically to scale to the size of the full model.
             'gpu_split_auto' (bool): Automatically split model across available devices (default: True)
             'gpu_split' (list[float]): Allocation for weights and (some) tensors, per device
-            'no_flash_attn' (bool): Turns off flash attention (increases vram usage)
+            'no_flash_attn' (bool): Turns off flash attention (increases vram usage) (default: False)
         """
 
         self.quiet = quiet
 
         self.cache_fp8 = "cache_mode" in kwargs and kwargs["cache_mode"] == "FP8"
-        self.gpu_split = kwargs.get("gpu_split", None)
-        self.gpu_split_auto = kwargs.get("gpu_split_auto", True)
+        self.gpu_split = kwargs.get("gpu_split")
+        self.gpu_split_auto = kwargs.get("gpu_split_auto") or True
 
         self.config = ExLlamaV2Config()
         self.config.model_dir = str(model_directory.resolve())
@@ -74,13 +73,14 @@ class ModelContainer:
         base_seq_len = self.config.max_seq_len
 
         # Then override the max_seq_len if present
-        if "max_seq_len" in kwargs: self.config.max_seq_len = kwargs["max_seq_len"]
-        if "rope_scale" in kwargs: self.config.scale_pos_emb = kwargs["rope_scale"]
+        self.config.max_seq_len = kwargs.get("max_seq_len") or 4096
+        self.config.scale_pos_emb = kwargs.get("rope_scale") or 1.0
 
         # Automatically calculate rope alpha
         self.config.scale_alpha_value = kwargs.get("rope_alpha") or self.calculate_rope_alpha(base_seq_len)
 
-        if "no_flash_attn" in kwargs: self.config.no_flash_attn = kwargs["no_flash_attn"]
+        # Turn off flash attention?
+        self.config.no_flash_attn = kwargs.get("no_flash_attn") or False
 
         # low_mem is currently broken in exllamav2. Don't use it until it's fixed.
""" @@ -88,31 +88,30 @@ class ModelContainer: self.config.set_low_mem() """ - chunk_size = min(kwargs.get("chunk_size", 2048), self.config.max_seq_len) + chunk_size = min(kwargs.get("chunk_size") or 2048, self.config.max_seq_len) self.config.max_input_len = chunk_size self.config.max_attn_size = chunk_size ** 2 - draft_config = kwargs.get("draft") or {} - draft_model_name = draft_config.get("draft_model_name") - enable_draft = bool(draft_config) and draft_model_name is not None + draft_args = kwargs.get("draft") or {} + draft_model_name = draft_args.get("draft_model_name") + enable_draft = draft_args and draft_model_name - if bool(draft_config) and draft_model_name is None: + # Always disable draft if params are incorrectly configured + if draft_args and draft_model_name is None: print("A draft config was found but a model name was not given. Please check your config.yml! Skipping draft load.") - self.draft_enabled = False - else: - self.draft_enabled = enable_draft + enable_draft = False - if self.draft_enabled: + if enable_draft: self.draft_config = ExLlamaV2Config() - draft_model_path = pathlib.Path(draft_config.get("draft_model_dir") or "models") + draft_model_path = pathlib.Path(draft_args.get("draft_model_dir") or "models") draft_model_path = draft_model_path / draft_model_name self.draft_config.model_dir = str(draft_model_path.resolve()) self.draft_config.prepare() - self.draft_config.scale_pos_emb = draft_config.get("draft_rope_scale") or 1.0 - self.draft_config.scale_alpha_value = draft_config.get("draft_rope_alpha") or self.calculate_rope_alpha(self.draft_config.max_seq_len) + self.draft_config.scale_pos_emb = draft_args.get("draft_rope_scale") or 1.0 + self.draft_config.scale_alpha_value = draft_args.get("draft_rope_alpha") or self.calculate_rope_alpha(self.draft_config.max_seq_len) self.draft_config.max_seq_len = self.config.max_seq_len if "chunk_size" in kwargs: @@ -156,9 +155,9 @@ class ModelContainer: self.tokenizer = ExLlamaV2Tokenizer(self.config) - # Load draft model + # Load draft model if a config is present - if self.draft_enabled: + if self.draft_config: self.draft_model = ExLlamaV2(self.draft_config) if not self.quiet: @@ -228,13 +227,13 @@ class ModelContainer: # Assume token encoding return self.tokenizer.encode( text, - add_bos = kwargs.get("add_bos_token", True), - encode_special_tokens = kwargs.get("encode_special_tokens", True) + add_bos = kwargs.get("add_bos_token") or True, + encode_special_tokens = kwargs.get("encode_special_tokens") or True ) if ids: # Assume token decoding ids = torch.tensor([ids]) - return self.tokenizer.decode(ids, decode_special_tokens = kwargs.get("decode_special_tokens", True))[0] + return self.tokenizer.decode(ids, decode_special_tokens = kwargs.get("decode_special_tokens") or True)[0] def generate(self, prompt: str, **kwargs): @@ -274,10 +273,10 @@ class ModelContainer: """ - token_healing = kwargs.get("token_healing", False) - max_tokens = kwargs.get("max_tokens", 150) - stream_interval = kwargs.get("stream_interval", 0) - generate_window = min(kwargs.get("generate_window", 512), max_tokens) + token_healing = kwargs.get("token_healing") or False + max_tokens = kwargs.get("max_tokens") or 150 + stream_interval = kwargs.get("stream_interval") or 0 + generate_window = min(kwargs.get("generate_window") or 512, max_tokens) # Sampler settings @@ -285,43 +284,43 @@ class ModelContainer: # Warn of unsupported settings if the setting is enabled - if kwargs.get("mirostat", False) and not hasattr(gen_settings, "mirostat"): + if 
kwargs.get("mirostat") or False and not hasattr(gen_settings, "mirostat"): print(" !! Warning: Currently installed ExLlamaV2 does not support Mirostat sampling") - if kwargs.get("min_p", 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"): + if kwargs.get("min_p") or 0.0 not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"): print(" !! Warning: Currently installed ExLlamaV2 does not support min-P sampling") - if kwargs.get("tfs", 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"): + if kwargs.get("tfs") or 0.0 not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"): print(" !! Warning: Currently installed ExLlamaV2 does not support tail-free sampling (TFS)") - if kwargs.get("temperature_last", False) and not hasattr(gen_settings, "temperature_last"): + if kwargs.get("temperature_last") or False and not hasattr(gen_settings, "temperature_last"): print(" !! Warning: Currently installed ExLlamaV2 does not support temperature_last") #Apply settings - gen_settings.temperature = kwargs.get("temperature", 1.0) - gen_settings.temperature_last = kwargs.get("temperature_last", False) - gen_settings.top_k = kwargs.get("top_k", 1) - gen_settings.top_p = kwargs.get("top_p", 1.0) - gen_settings.min_p = kwargs.get("min_p", 0.0) - gen_settings.tfs = kwargs.get("tfs", 0.0) - gen_settings.typical = kwargs.get("typical", 0.0) - gen_settings.mirostat = kwargs.get("mirostat", False) + gen_settings.temperature = kwargs.get("temperature") or 1.0 + gen_settings.temperature_last = kwargs.get("temperature_last") or False + gen_settings.top_k = kwargs.get("top_k") or 1 + gen_settings.top_p = kwargs.get("top_p") or 1.0 + gen_settings.min_p = kwargs.get("min_p") or 0.0 + gen_settings.tfs = kwargs.get("tfs") or 0.0 + gen_settings.typical = kwargs.get("typical") or 0.0 + gen_settings.mirostat = kwargs.get("mirostat") or False # Default tau and eta fallbacks don't matter if mirostat is off - gen_settings.mirostat_tau = kwargs.get("mirostat_tau", 1.5) - gen_settings.mirostat_eta = kwargs.get("mirostat_eta", 0.1) - gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty", 1.0) - gen_settings.token_repetition_range = kwargs.get("repetition_range", self.config.max_seq_len) + gen_settings.mirostat_tau = kwargs.get("mirostat_tau") or 1.5 + gen_settings.mirostat_eta = kwargs.get("mirostat_eta") or 0.1 + gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty") or 1.0 + gen_settings.token_repetition_range = kwargs.get("repetition_range") or self.config.max_seq_len # Always make sure the fallback is 0 if range < 0 # It's technically fine to use -1, but this just validates the passed fallback fallback_decay = 0 if gen_settings.token_repetition_penalty <= 0 else gen_settings.token_repetition_range - gen_settings.token_repetition_decay = kwargs.get("repetition_decay", fallback_decay or 0) + gen_settings.token_repetition_decay = kwargs.get("repetition_decay") or fallback_decay or 0 - stop_conditions: List[Union[str, int]] = kwargs.get("stop", []) - ban_eos_token = kwargs.get("ban_eos_token", False) + stop_conditions: List[Union[str, int]] = kwargs.get("stop") or [] + ban_eos_token = kwargs.get("ban_eos_token") or False # Ban the EOS token if specified. If not, append to stop conditions as well. @@ -347,7 +346,7 @@ class ModelContainer: ids = self.tokenizer.encode( prompt, - add_bos=kwargs.get("add_bos_token", True), + add_bos = kwargs.get("add_bos_token") or True, encode_special_tokens = True ) context_len = len(ids[0])