Model: Reorder how configs are set up

Initialize the Exllama classes first then add user-specific params.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-08-29 22:41:19 -04:00
committed by Brian Dashore
parent 21712578cf
commit 523709741c

View File

@@ -108,6 +108,61 @@ class ExllamaV2Container:
self.quiet = quiet
# Initialize config
self.config = ExLlamaV2Config()
self.config.model_dir = str(model_directory.resolve())
# Make the max seq len 4096 before preparing the config
# This is a better default than 2048
self.config.max_seq_len = 4096
self.config.prepare()
# Check if the model arch is compatible with various exl2 features
self.config.arch_compat_overrides()
# Prepare the draft model config if necessary
draft_args = unwrap(kwargs.get("draft"), {})
draft_model_name = draft_args.get("draft_model_name")
enable_draft = draft_args and draft_model_name
# Always disable draft if params are incorrectly configured
if draft_args and draft_model_name is None:
logger.warning(
"Draft model is disabled because a model name "
"wasn't provided. Please check your config.yml!"
)
enable_draft = False
if enable_draft:
self.draft_config = ExLlamaV2Config()
self.draft_config.no_flash_attn = self.config.no_flash_attn
draft_model_path = pathlib.Path(
unwrap(draft_args.get("draft_model_dir"), "models")
)
draft_model_path = draft_model_path / draft_model_name
self.draft_config.model_dir = str(draft_model_path.resolve())
self.draft_config.prepare()
# Create the hf_config
self.hf_config = HuggingFaceConfig.from_file(model_directory)
# Load generation config overrides
generation_config_path = model_directory / "generation_config.json"
if generation_config_path.exists():
try:
self.generation_config = GenerationConfig.from_file(
generation_config_path.parent
)
except Exception:
logger.error(traceback.format_exc())
logger.warning(
"Skipping generation config load because of an unexpected error."
)
# MARK: User configuration
# Get cache mode
self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")
@@ -161,24 +216,9 @@ class ExllamaV2Container:
for value in autosplit_reserve_megabytes
]
self.config = ExLlamaV2Config()
self.config.model_dir = str(model_directory.resolve())
# Make the max seq len 4096 before preparing the config
# This is a better default than 2048
self.config.max_seq_len = 4096
# Hardcode max output length to 16
self.config.max_output_len = 16
self.config.prepare()
# Check if the model arch is compatible with various exl2 features
self.config.arch_compat_overrides()
# Create the hf_config
self.hf_config = HuggingFaceConfig.from_file(model_directory)
# Then override the base_seq_len if present
override_base_seq_len = kwargs.get("override_base_seq_len")
if override_base_seq_len:
@@ -264,19 +304,6 @@ class ExllamaV2Container:
else:
self.cache_size = self.config.max_seq_len
# Load generation config overrides
generation_config_path = model_directory / "generation_config.json"
if generation_config_path.exists():
try:
self.generation_config = GenerationConfig.from_file(
generation_config_path.parent
)
except Exception:
logger.error(traceback.format_exc())
logger.warning(
"Skipping generation config load because of an unexpected error."
)
# Try to set prompt template
self.prompt_template = self.find_prompt_template(
kwargs.get("prompt_template"), model_directory
@@ -304,29 +331,8 @@ class ExllamaV2Container:
self.config.max_input_len = chunk_size
self.config.max_attention_size = chunk_size**2
draft_args = unwrap(kwargs.get("draft"), {})
draft_model_name = draft_args.get("draft_model_name")
enable_draft = draft_args and draft_model_name
# Always disable draft if params are incorrectly configured
if draft_args and draft_model_name is None:
logger.warning(
"Draft model is disabled because a model name "
"wasn't provided. Please check your config.yml!"
)
enable_draft = False
# Set user-configured draft model values
if enable_draft:
self.draft_config = ExLlamaV2Config()
self.draft_config.no_flash_attn = self.config.no_flash_attn
draft_model_path = pathlib.Path(
unwrap(draft_args.get("draft_model_dir"), "models")
)
draft_model_path = draft_model_path / draft_model_name
self.draft_config.model_dir = str(draft_model_path.resolve())
self.draft_config.prepare()
self.draft_config.scale_pos_emb = unwrap(
draft_args.get("draft_rope_scale"), 1.0
)