API: Add min_tokens

Bans the EOS token until the generation reaches a minimum length. This will not prevent the model from otherwise ending the generation early by outputting other stop conditions.
This commit is contained in:
DocShotgun
2024-05-10 12:30:17 -07:00
parent 643b53e347
commit a1df22668b
3 changed files with 33 additions and 1 deletions

View File

@@ -18,6 +18,11 @@ class BaseSamplerRequest(BaseModel):
examples=[150],
)
min_tokens: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("min_tokens", 0),
examples=[0],
)
generate_window: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("generate_window"),
examples=[512],
@@ -260,6 +265,7 @@ class BaseSamplerRequest(BaseModel):
gen_params = {
"max_tokens": self.max_tokens,
"min_tokens": self.min_tokens,
"generate_window": self.generate_window,
"stop": self.stop,
"add_bos_token": self.add_bos_token,