Rename the `repetition_slope` sampler option to `repetition_decay`.

The request field, the generation-params dict key, the kwargs docstring entry
and the kwargs lookup in ModelContainer are renamed together so the API name
matches `gen_settings.token_repetition_decay`.

Review corrections applied to the added lines of model.py:
  * the decay fallback is now guarded on the repetition *range* (as the
    adjacent comment states) instead of on the repetition penalty;
  * the `repetition_decay` docstring entry no longer describes itself as the
    penalty range.

NOTE(review): indentation inside the hunks was reconstructed from a
whitespace-collapsed copy of this patch - verify against the target files
before applying. The post-image hash on the model.py index line predates the
two corrections above.

diff --git a/OAI/types/common.py b/OAI/types/common.py
index 56ca7a5..eeced09 100644
--- a/OAI/types/common.py
+++ b/OAI/types/common.py
@@ -48,7 +48,7 @@ class CommonCompletionRequest(BaseModel):
     tfs: Optional[float] = 1.0
     repetition_penalty: Optional[float] = 1.0
     repetition_penalty_range: Optional[int] = 0
-    repetition_slope: Optional[int] = 0
+    repetition_decay: Optional[int] = 0
     mirostat_mode: Optional[int] = 0
     mirostat_tau: Optional[float] = 1.5
     mirostat_eta: Optional[float] = 0.1
@@ -85,7 +85,7 @@ class CommonCompletionRequest(BaseModel):
             "tfs": self.tfs,
             "repetition_penalty": self.repetition_penalty,
             "repetition_range": self.repetition_range or self.repetition_penalty_range or -1,
-            "repetition_slope": self.repetition_slope,
+            "repetition_decay": self.repetition_decay,
             "mirostat": self.mirostat_mode == 2,
             "mirostat_tau": self.mirostat_tau,
             "mirostat_eta": self.mirostat_eta,
diff --git a/model.py b/model.py
index 756908c..da5fc5d 100644
--- a/model.py
+++ b/model.py
@@ -250,7 +250,7 @@ class ModelContainer:
                 'mirostat_eta' (float) Mirostat eta parameter (default: 0.1)
                 'repetition_penalty' (float): Token repetition/presence penalty (default: 1.15)
                 'repetition_range' (int): Repetition penalty range (default: whole context)
-                'repetition_slope' (int): Repetition penalty range (default: same as range)
+                'repetition_decay' (int): Repetition penalty decay (default: same as range)
                 'stop' (List[Union[str, int]]): List of stop strings/tokens to end response (default: [EOS])
                 'max_tokens' (int): Max no. tokens in response (default: 150)
                 'add_bos_token' (bool): Adds the BOS token to the start of the prompt (default: True)
@@ -302,10 +302,11 @@ class ModelContainer:
 
         gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty", 1.0)
         gen_settings.token_repetition_range = kwargs.get("repetition_range", self.config.max_seq_len)
+
         # Always make sure the fallback is 0 if range < 0
         # It's technically fine to use -1, but this just validates the passed fallback
-        fallback_slope = 0 if gen_settings.token_repetition_penalty <= 0 else gen_settings.token_repetition_range
-        gen_settings.token_repetition_decay = kwargs.get("repetition_slope", fallback_slope or 0)
+        fallback_decay = 0 if gen_settings.token_repetition_range <= 0 else gen_settings.token_repetition_range
+        gen_settings.token_repetition_decay = kwargs.get("repetition_decay", fallback_decay or 0)
 
         stop_conditions: List[Union[str, int]] = kwargs.get("stop", [])
         ban_eos_token = kwargs.get("ban_eos_token", False)