From bc21f0bbc052c6ec51392c9f7ec48c766b9d364d Mon Sep 17 00:00:00 2001 From: kingbri Date: Sun, 17 Dec 2023 21:57:46 -0500 Subject: [PATCH] OAI: Add field aliasing Repetition penalty range needs field aliases to support multiple parameter calls. Signed-off-by: kingbri --- OAI/types/common.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/OAI/types/common.py b/OAI/types/common.py index bc1e341..065ec5d 100644 --- a/OAI/types/common.py +++ b/OAI/types/common.py @@ -1,6 +1,6 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, AliasChoices from typing import List, Dict, Optional, Union -from utils import coalesce +from utils import unwrap class LogProbs(BaseModel): text_offset: List[int] = Field(default_factory=list) @@ -36,6 +36,7 @@ class CommonCompletionRequest(BaseModel): max_tokens: Optional[int] = 150 # Aliased to repetition_penalty + # TODO: Maybe make this an alias to rep pen frequency_penalty: Optional[float] = Field(description = "Aliased to Repetition Penalty", default = 0.0) # Sampling params @@ -56,20 +57,21 @@ class CommonCompletionRequest(BaseModel): ban_eos_token: Optional[bool] = False # Aliased variables - # TODO: Add a function to iterate through aliases and return a default value if all are None - repetition_range: Optional[int] = None - repetition_penalty_range: Optional[int] = None + repetition_range: Optional[int] = Field( + default = None, + validation_alias = AliasChoices('repetition_range', 'repetition_penalty_range') + ) # Converts to internal generation parameters def to_gen_params(self): # Convert stop to an array of strings if isinstance(self.stop, str): self.stop = [self.stop] - + # Set repetition_penalty to frequency_penalty if repetition_penalty isn't already defined if (self.repetition_penalty is None or self.repetition_penalty == 1.0) and self.frequency_penalty: self.repetition_penalty = self.frequency_penalty - + return { "stop": self.stop, "max_tokens": self.max_tokens, @@ -84,7 +86,7 @@ class CommonCompletionRequest(BaseModel): "min_p": self.min_p, "tfs": self.tfs, "repetition_penalty": self.repetition_penalty, - "repetition_range": coalesce(self.repetition_range, self.repetition_penalty_range, -1), + "repetition_range": unwrap(self.repetition_range, -1), "repetition_decay": self.repetition_decay, "mirostat": self.mirostat_mode == 2, "mirostat_tau": self.mirostat_tau,