Mirror of https://github.com/theroyallab/tabbyAPI.git (synced 2026-03-14 15:57:27 +00:00)
Model: Reorder how configs are set up
Initialize the Exllama classes first, then add user-specific params.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -108,6 +108,61 @@ class ExllamaV2Container:
|
||||
|
||||
self.quiet = quiet
|
||||
|
||||
# Initialize config
|
||||
self.config = ExLlamaV2Config()
|
||||
self.config.model_dir = str(model_directory.resolve())
|
||||
|
||||
# Make the max seq len 4096 before preparing the config
|
||||
# This is a better default than 2048
|
||||
self.config.max_seq_len = 4096
|
||||
|
||||
self.config.prepare()
|
||||
|
||||
# Check if the model arch is compatible with various exl2 features
|
||||
self.config.arch_compat_overrides()
|
||||
|
||||
# Prepare the draft model config if necessary
|
||||
draft_args = unwrap(kwargs.get("draft"), {})
|
||||
draft_model_name = draft_args.get("draft_model_name")
|
||||
enable_draft = draft_args and draft_model_name
|
||||
|
||||
# Always disable draft if params are incorrectly configured
|
||||
if draft_args and draft_model_name is None:
|
||||
logger.warning(
|
||||
"Draft model is disabled because a model name "
|
||||
"wasn't provided. Please check your config.yml!"
|
||||
)
|
||||
enable_draft = False
|
||||
|
||||
if enable_draft:
|
||||
self.draft_config = ExLlamaV2Config()
|
||||
self.draft_config.no_flash_attn = self.config.no_flash_attn
|
||||
draft_model_path = pathlib.Path(
|
||||
unwrap(draft_args.get("draft_model_dir"), "models")
|
||||
)
|
||||
draft_model_path = draft_model_path / draft_model_name
|
||||
|
||||
self.draft_config.model_dir = str(draft_model_path.resolve())
|
||||
self.draft_config.prepare()
|
||||
|
||||
# Create the hf_config
|
||||
self.hf_config = HuggingFaceConfig.from_file(model_directory)
|
||||
|
||||
# Load generation config overrides
|
||||
generation_config_path = model_directory / "generation_config.json"
|
||||
if generation_config_path.exists():
|
||||
try:
|
||||
self.generation_config = GenerationConfig.from_file(
|
||||
generation_config_path.parent
|
||||
)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.warning(
|
||||
"Skipping generation config load because of an unexpected error."
|
||||
)
|
||||
|
||||
# MARK: User configuration
|
||||
|
||||
# Get cache mode
|
||||
self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")
|
||||
|
||||
@@ -161,24 +216,9 @@ class ExllamaV2Container:
|
||||
for value in autosplit_reserve_megabytes
|
||||
]
|
||||
|
||||
self.config = ExLlamaV2Config()
|
||||
self.config.model_dir = str(model_directory.resolve())
|
||||
|
||||
# Make the max seq len 4096 before preparing the config
|
||||
# This is a better default than 2048
|
||||
self.config.max_seq_len = 4096
|
||||
|
||||
# Hardcode max output length to 16
|
||||
self.config.max_output_len = 16
|
||||
|
||||
self.config.prepare()
|
||||
|
||||
# Check if the model arch is compatible with various exl2 features
|
||||
self.config.arch_compat_overrides()
|
||||
|
||||
# Create the hf_config
|
||||
self.hf_config = HuggingFaceConfig.from_file(model_directory)
|
||||
|
||||
# Then override the base_seq_len if present
|
||||
override_base_seq_len = kwargs.get("override_base_seq_len")
|
||||
if override_base_seq_len:
|
||||
@@ -264,19 +304,6 @@ class ExllamaV2Container:
|
||||
else:
|
||||
self.cache_size = self.config.max_seq_len
|
||||
|
||||
# Load generation config overrides
|
||||
generation_config_path = model_directory / "generation_config.json"
|
||||
if generation_config_path.exists():
|
||||
try:
|
||||
self.generation_config = GenerationConfig.from_file(
|
||||
generation_config_path.parent
|
||||
)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.warning(
|
||||
"Skipping generation config load because of an unexpected error."
|
||||
)
|
||||
|
||||
# Try to set prompt template
|
||||
self.prompt_template = self.find_prompt_template(
|
||||
kwargs.get("prompt_template"), model_directory
|
||||
@@ -304,29 +331,8 @@ class ExllamaV2Container:
|
||||
self.config.max_input_len = chunk_size
|
||||
self.config.max_attention_size = chunk_size**2
|
||||
|
||||
draft_args = unwrap(kwargs.get("draft"), {})
|
||||
draft_model_name = draft_args.get("draft_model_name")
|
||||
enable_draft = draft_args and draft_model_name
|
||||
|
||||
# Always disable draft if params are incorrectly configured
|
||||
if draft_args and draft_model_name is None:
|
||||
logger.warning(
|
||||
"Draft model is disabled because a model name "
|
||||
"wasn't provided. Please check your config.yml!"
|
||||
)
|
||||
enable_draft = False
|
||||
|
||||
# Set user-configured draft model values
|
||||
if enable_draft:
|
||||
self.draft_config = ExLlamaV2Config()
|
||||
self.draft_config.no_flash_attn = self.config.no_flash_attn
|
||||
draft_model_path = pathlib.Path(
|
||||
unwrap(draft_args.get("draft_model_dir"), "models")
|
||||
)
|
||||
draft_model_path = draft_model_path / draft_model_name
|
||||
|
||||
self.draft_config.model_dir = str(draft_model_path.resolve())
|
||||
self.draft_config.prepare()
|
||||
|
||||
self.draft_config.scale_pos_emb = unwrap(
|
||||
draft_args.get("draft_rope_scale"), 1.0
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user