diff --git a/OAI/types/model.py b/OAI/types/model.py index 08715ad..7072840 100644 --- a/OAI/types/model.py +++ b/OAI/types/model.py @@ -4,7 +4,8 @@ from typing import List, Optional from gen_logging import LogConfig class ModelCardParameters(BaseModel): - max_seq_len: Optional[int] = 4096 + # Safe to do this since it's guaranteed to fetch a max seq len from model_container + max_seq_len: Optional[int] = None rope_scale: Optional[float] = 1.0 rope_alpha: Optional[float] = 1.0 cache_mode: Optional[str] = "FP16" @@ -32,7 +33,9 @@ class DraftModelLoadRequest(BaseModel): # TODO: Unify this with ModelCardParams class ModelLoadRequest(BaseModel): name: str - max_seq_len: Optional[int] = 4096 + + # Max seq len is defaulted when loading the model itself + max_seq_len: Optional[int] = None gpu_split_auto: Optional[bool] = True gpu_split: Optional[List[float]] = Field(default_factory=list) rope_scale: Optional[float] = 1.0 diff --git a/config_sample.yml b/config_sample.yml index 15ce81b..fb17438 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -37,8 +37,8 @@ model: # The below parameters apply only if model_name is set - # Maximum model context length (default: 4096) - max_seq_len: 4096 + # Override maximum model context length (default: None) + max_seq_len: # Automatically allocate resources to GPUs (default: True) gpu_split_auto: True diff --git a/model.py b/model.py index e39a478..e2076a1 100644 --- a/model.py +++ b/model.py @@ -79,13 +79,21 @@ class ModelContainer: self.config = ExLlamaV2Config() self.config.model_dir = str(model_directory.resolve()) + + # Make the max seq len 4096 before preparing the config + # This is a better default than 2048 + self.config.max_seq_len = 4096 self.config.prepare() + # Then override the max_seq_len if present + override_max_seq_len = kwargs.get("max_seq_len") + if override_max_seq_len: + self.config.max_seq_len = kwargs.get("max_seq_len") + # Grab the base model's sequence length before overrides for rope calculations 
base_seq_len = self.config.max_seq_len - # Then override the max_seq_len if present - self.config.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096) + # Set the rope scale self.config.scale_pos_emb = unwrap(kwargs.get("rope_scale"), 1.0) # Automatically calculate rope alpha