Model: Store directory paths

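Storing the model and draft model directories as pathlib.Path
attributes on the container makes the paths easier to manipulate
in the long run, since they no longer have to be rebuilt from the
string stored in the config each time they are needed.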

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-08-29 22:49:20 -04:00
committed by Brian Dashore
parent 523709741c
commit dd55b99af5
4 changed files with 17 additions and 21 deletions

View File

@@ -63,6 +63,10 @@ except ImportError:
class ExllamaV2Container:
"""The model container class for ExLlamaV2 models."""
# Model directories
model_dir: pathlib.Path = pathlib.Path("models")
draft_model_dir: pathlib.Path = pathlib.Path("models")
# Exl2 vars
config: Optional[ExLlamaV2Config] = None
draft_config: Optional[ExLlamaV2Config] = None
@@ -110,6 +114,7 @@ class ExllamaV2Container:
# Initialize config
self.config = ExLlamaV2Config()
self.model_dir = model_directory
self.config.model_dir = str(model_directory.resolve())
# Make the max seq len 4096 before preparing the config
@@ -142,6 +147,7 @@ class ExllamaV2Container:
)
draft_model_path = draft_model_path / draft_model_name
self.draft_model_dir = draft_model_path
self.draft_config.model_dir = str(draft_model_path.resolve())
self.draft_config.prepare()
@@ -403,20 +409,9 @@ class ExllamaV2Container:
alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio**2
return alpha
def get_model_path(self, is_draft: bool = False):
"""Get the path for this model."""
# Returns the directory of the main model, or of the draft model when
# is_draft is True, as a pathlib.Path built from the matching config.
# A draft path was requested but no draft config is set: return None
# instead of dereferencing a missing config.
# NOTE(review): there is no equivalent guard for self.config, whose class
# default is None -- confirm callers only use this after a model is set up.
if is_draft and not self.draft_config:
return None
# The configs store model_dir as a string (it is assigned from
# str(<path>.resolve())), so wrap it back into a Path for callers that
# read Path attributes such as .name.
model_path = pathlib.Path(
self.draft_config.model_dir if is_draft else self.config.model_dir
)
return model_path
def get_model_parameters(self):
model_params = {
"name": self.get_model_path().name,
"name": self.model_dir.name,
"rope_scale": self.config.scale_pos_emb,
"rope_alpha": self.config.scale_alpha_value,
"max_seq_len": self.config.max_seq_len,
@@ -431,7 +426,7 @@ class ExllamaV2Container:
if self.draft_config:
draft_model_params = {
"name": self.get_model_path(is_draft=True).name,
"name": self.draft_model_dir.name,
"rope_scale": self.draft_config.scale_pos_emb,
"rope_alpha": self.draft_config.scale_alpha_value,
"max_seq_len": self.draft_config.max_seq_len,