OAI: Add "n" for non-streaming generations

This adds the ability to add multiple choices to a generation. This
is only available for non-streaming gens for now, it requires some
more work to port over to streaming.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-05-26 22:22:51 -04:00
committed by Brian Dashore
parent 8d31a5aed1
commit b944f8d756
3 changed files with 89 additions and 64 deletions

View File

@@ -3,7 +3,7 @@
from pydantic import BaseModel, Field
from typing import Optional
from common.sampling import BaseSamplerRequest
from common.sampling import BaseSamplerRequest, get_default_sampler_value
class UsageStats(BaseModel):
@@ -27,10 +27,13 @@ class CommonCompletionRequest(BaseSamplerRequest):
# Generation info (remainder is in BaseSamplerRequest superclass)
stream: Optional[bool] = False
logprobs: Optional[int] = 0
logprobs: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("logprobs", 0)
)
response_format: Optional[CompletionResponseFormat] = Field(
default_factory=CompletionResponseFormat
)
n: Optional[int] = Field(default_factory=lambda: get_default_sampler_value("n", 1))
# Extra OAI request stuff
best_of: Optional[int] = Field(
@@ -39,9 +42,6 @@ class CommonCompletionRequest(BaseSamplerRequest):
echo: Optional[bool] = Field(
description="Not parsed. Only used for OAI compliance.", default=False
)
n: Optional[int] = Field(
description="Not parsed. Only used for OAI compliance.", default=1
)
suffix: Optional[str] = Field(
description="Not parsed. Only used for OAI compliance.", default=None
)