diff --git a/OAI/types/model.py b/OAI/types/model.py
index 7072840..3ee0f02 100644
--- a/OAI/types/model.py
+++ b/OAI/types/model.py
@@ -34,8 +34,9 @@ class DraftModelLoadRequest(BaseModel):
 class ModelLoadRequest(BaseModel):
     name: str
 
-    # Max seq len is defaulted when loading the model itself
-    max_seq_len: Optional[int] = None
+    # Max seq len is fetched from config.json of the model by default
+    max_seq_len: Optional[int] = Field(description = "Leave this blank to use the model's base sequence length", default = None)
+    override_base_seq_len: Optional[int] = Field(description = "Overrides the model's base sequence length. Leave blank if unsure", default = None)
     gpu_split_auto: Optional[bool] = True
     gpu_split: Optional[List[float]] = Field(default_factory=list)
     rope_scale: Optional[float] = 1.0
diff --git a/config_sample.yml b/config_sample.yml
index fb17438..8b0817c 100644
--- a/config_sample.yml
+++ b/config_sample.yml
@@ -37,9 +37,15 @@ model:
 
   # The below parameters apply only if model_name is set
 
-  # Override maximum model context length (default: None)
+  # Max sequence length (default: None)
+  # Fetched from the model's base sequence length in config.json by default
   max_seq_len:
 
+  # Overrides base model context length (default: None)
+  # WARNING: Don't set this unless you know what you're doing!
+  # Only use this if the model's base sequence length in config.json is incorrect (ex. Mistral/Mixtral models)
+  override_base_seq_len:
+
   # Automatically allocate resources to GPUs (default: True)
   gpu_split_auto: True
 
diff --git a/model.py b/model.py
index e2076a1..a5cb436 100644
--- a/model.py
+++ b/model.py
@@ -85,14 +85,19 @@ class ModelContainer:
         self.config.max_seq_len = 4096
         self.config.prepare()
 
-        # Then override the max_seq_len if present
-        override_max_seq_len = kwargs.get("max_seq_len")
-        if override_max_seq_len:
-            self.config.max_seq_len = kwargs.get("max_seq_len")
+        # Then override the base_seq_len if present
+        override_base_seq_len = kwargs.get("override_base_seq_len")
+        if override_base_seq_len:
+            self.config.max_seq_len = override_base_seq_len
 
         # Grab the base model's sequence length before overrides for rope calculations
         base_seq_len = self.config.max_seq_len
 
+        # Set the target seq len if present
+        target_max_seq_len = kwargs.get("max_seq_len")
+        if target_max_seq_len:
+            self.config.max_seq_len = target_max_seq_len
+
         # Set the rope scale
         self.config.scale_pos_emb = unwrap(kwargs.get("rope_scale"), 1.0)
 
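
The ordering in the model.py hunk is what keeps rope calculations correct: override_base_seq_len is applied first, base_seq_len is captured from the result, and only then is the target max_seq_len set. A minimal sketch of that precedence, assuming kwargs carries the ModelLoadRequest fields; resolve_seq_lens is a hypothetical helper for illustration, not part of the diff:

def resolve_seq_lens(kwargs, config_json_seq_len):
    """Sketch of the precedence the model.py hunk implements."""
    # override_base_seq_len replaces the base length read from config.json
    # (for models like Mistral/Mixtral whose reported value is incorrect)
    base_seq_len = kwargs.get("override_base_seq_len") or config_json_seq_len

    # max_seq_len is the target context length; it is applied only after
    # base_seq_len has been captured, so rope scaling sees the true base
    max_seq_len = kwargs.get("max_seq_len") or base_seq_len

    return base_seq_len, max_seq_len

# e.g. resolve_seq_lens({"override_base_seq_len": 8192, "max_seq_len": 16384}, 32768)
# returns (8192, 16384): rope math uses 8192 while the loader targets 16384

Leaving both fields blank keeps the config.json value end to end, matching the "Leave this blank" guidance in the new Field descriptions.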