OAI: Add validation to "n"

n must be greater than 1 to generate.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-05-28 00:34:32 -04:00
committed by Brian Dashore
parent e2a8b6e8ae
commit e95e67a000

View File

@@ -49,6 +49,13 @@ class CommonCompletionRequest(BaseSamplerRequest):
description="Not parsed. Only used for OAI compliance.", default=None
)
def validate_params(self):
# Temperature
if self.n < 1:
raise ValueError(f"n must be greater than or equal to 1. Got {self.n}")
return super().validate_params()
def to_gen_params(self):
extra_gen_params = {
"stream": self.stream,