diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index a8bca56..501b3ab 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -108,6 +108,61 @@ class ExllamaV2Container: self.quiet = quiet + # Initialize config + self.config = ExLlamaV2Config() + self.config.model_dir = str(model_directory.resolve()) + + # Make the max seq len 4096 before preparing the config + # This is a better default than 2048 + self.config.max_seq_len = 4096 + + self.config.prepare() + + # Check if the model arch is compatible with various exl2 features + self.config.arch_compat_overrides() + + # Prepare the draft model config if necessary + draft_args = unwrap(kwargs.get("draft"), {}) + draft_model_name = draft_args.get("draft_model_name") + enable_draft = draft_args and draft_model_name + + # Always disable draft if params are incorrectly configured + if draft_args and draft_model_name is None: + logger.warning( + "Draft model is disabled because a model name " + "wasn't provided. Please check your config.yml!" + ) + enable_draft = False + + if enable_draft: + self.draft_config = ExLlamaV2Config() + self.draft_config.no_flash_attn = self.config.no_flash_attn + draft_model_path = pathlib.Path( + unwrap(draft_args.get("draft_model_dir"), "models") + ) + draft_model_path = draft_model_path / draft_model_name + + self.draft_config.model_dir = str(draft_model_path.resolve()) + self.draft_config.prepare() + + # Create the hf_config + self.hf_config = HuggingFaceConfig.from_file(model_directory) + + # Load generation config overrides + generation_config_path = model_directory / "generation_config.json" + if generation_config_path.exists(): + try: + self.generation_config = GenerationConfig.from_file( + generation_config_path.parent + ) + except Exception: + logger.error(traceback.format_exc()) + logger.warning( + "Skipping generation config load because of an unexpected error." + ) + + # MARK: User configuration + # Get cache mode self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16") @@ -161,24 +216,9 @@ class ExllamaV2Container: for value in autosplit_reserve_megabytes ] - self.config = ExLlamaV2Config() - self.config.model_dir = str(model_directory.resolve()) - - # Make the max seq len 4096 before preparing the config - # This is a better default than 2048 - self.config.max_seq_len = 4096 - # Hardcode max output length to 16 self.config.max_output_len = 16 - self.config.prepare() - - # Check if the model arch is compatible with various exl2 features - self.config.arch_compat_overrides() - - # Create the hf_config - self.hf_config = HuggingFaceConfig.from_file(model_directory) - # Then override the base_seq_len if present override_base_seq_len = kwargs.get("override_base_seq_len") if override_base_seq_len: @@ -264,19 +304,6 @@ class ExllamaV2Container: else: self.cache_size = self.config.max_seq_len - # Load generation config overrides - generation_config_path = model_directory / "generation_config.json" - if generation_config_path.exists(): - try: - self.generation_config = GenerationConfig.from_file( - generation_config_path.parent - ) - except Exception: - logger.error(traceback.format_exc()) - logger.warning( - "Skipping generation config load because of an unexpected error." - ) - # Try to set prompt template self.prompt_template = self.find_prompt_template( kwargs.get("prompt_template"), model_directory @@ -304,29 +331,8 @@ class ExllamaV2Container: self.config.max_input_len = chunk_size self.config.max_attention_size = chunk_size**2 - draft_args = unwrap(kwargs.get("draft"), {}) - draft_model_name = draft_args.get("draft_model_name") - enable_draft = draft_args and draft_model_name - - # Always disable draft if params are incorrectly configured - if draft_args and draft_model_name is None: - logger.warning( - "Draft model is disabled because a model name " - "wasn't provided. Please check your config.yml!" - ) - enable_draft = False - + # Set user-configured draft model values if enable_draft: - self.draft_config = ExLlamaV2Config() - self.draft_config.no_flash_attn = self.config.no_flash_attn - draft_model_path = pathlib.Path( - unwrap(draft_args.get("draft_model_dir"), "models") - ) - draft_model_path = draft_model_path / draft_model_name - - self.draft_config.model_dir = str(draft_model_path.resolve()) - self.draft_config.prepare() - self.draft_config.scale_pos_emb = unwrap( draft_args.get("draft_rope_scale"), 1.0 )