mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-22 15:28:56 +00:00
Sampling: Rewrite mirostat_mode parameter
Apparently the "mirostat" parameter has been updated by frontends to pass a number. ExllamaV2 expects a boolean, but most pass a number anyway, so just alias mirostat_mode and mirostat together. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
@@ -1091,7 +1091,7 @@ class ExllamaV2Container:
|
|||||||
gen_settings.min_p = unwrap(kwargs.get("min_p"), 0.0)
|
gen_settings.min_p = unwrap(kwargs.get("min_p"), 0.0)
|
||||||
gen_settings.tfs = unwrap(kwargs.get("tfs"), 1.0)
|
gen_settings.tfs = unwrap(kwargs.get("tfs"), 1.0)
|
||||||
gen_settings.typical = unwrap(kwargs.get("typical"), 1.0)
|
gen_settings.typical = unwrap(kwargs.get("typical"), 1.0)
|
||||||
gen_settings.mirostat = unwrap(kwargs.get("mirostat"), False)
|
gen_settings.mirostat = unwrap(kwargs.get("mirostat_mode"), 0) == 2
|
||||||
gen_settings.skew = unwrap(kwargs.get("skew"), 0)
|
gen_settings.skew = unwrap(kwargs.get("skew"), 0)
|
||||||
|
|
||||||
# XTC
|
# XTC
|
||||||
|
|||||||
@@ -195,10 +195,9 @@ class BaseSamplerRequest(BaseModel):
|
|||||||
default_factory=lambda: get_default_sampler_value("dry_sequence_breakers", [])
|
default_factory=lambda: get_default_sampler_value("dry_sequence_breakers", [])
|
||||||
)
|
)
|
||||||
|
|
||||||
mirostat: Optional[bool] = False
|
|
||||||
|
|
||||||
mirostat_mode: Optional[int] = Field(
|
mirostat_mode: Optional[int] = Field(
|
||||||
default_factory=lambda: get_default_sampler_value("mirostat_mode", 0)
|
default_factory=lambda: get_default_sampler_value("mirostat_mode", 0),
|
||||||
|
alias=AliasChoices("mirostat_mode", "mirostat"),
|
||||||
)
|
)
|
||||||
|
|
||||||
mirostat_tau: Optional[float] = Field(
|
mirostat_tau: Optional[float] = Field(
|
||||||
@@ -325,15 +324,6 @@ class BaseSamplerRequest(BaseModel):
|
|||||||
)
|
)
|
||||||
return [] # Return empty list if parsing fails
|
return [] # Return empty list if parsing fails
|
||||||
|
|
||||||
@field_validator("mirostat_mode", mode="before")
|
|
||||||
def convert_mirostat(cls, v, field_info):
|
|
||||||
"""Mirostat is enabled if mirostat_mode == 2."""
|
|
||||||
|
|
||||||
if v == 2:
|
|
||||||
field_info.data["mirostat"] = True
|
|
||||||
|
|
||||||
return v
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def after_validate(self):
|
def after_validate(self):
|
||||||
# FIXME: find a better way to register this
|
# FIXME: find a better way to register this
|
||||||
|
|||||||
Reference in New Issue
Block a user