mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-28 10:11:39 +00:00
Sampling: Add XTC support
Matches with upstream. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -958,6 +958,13 @@ class ExllamaV2Container:
|
|||||||
Meant for dev wheels!
|
Meant for dev wheels!
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if unwrap(kwargs.get("xtc_probability"), 0.0) > 0.0 and not hasattr(
|
||||||
|
ExLlamaV2Sampler.Settings, "xtc_probability"
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
"XTC is not supported by the currently " "installed ExLlamaV2 version."
|
||||||
|
)
|
||||||
|
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
async def generate_gen(
|
async def generate_gen(
|
||||||
@@ -1003,6 +1010,14 @@ class ExllamaV2Container:
|
|||||||
gen_settings.mirostat = unwrap(kwargs.get("mirostat"), False)
|
gen_settings.mirostat = unwrap(kwargs.get("mirostat"), False)
|
||||||
gen_settings.skew = unwrap(kwargs.get("skew"), 0)
|
gen_settings.skew = unwrap(kwargs.get("skew"), 0)
|
||||||
|
|
||||||
|
# XTC
|
||||||
|
xtc_probability = unwrap(kwargs.get("xtc_probability"), 0.0)
|
||||||
|
if xtc_probability > 0.0:
|
||||||
|
gen_settings.xtc_probability = xtc_probability
|
||||||
|
|
||||||
|
# 0.1 is the default for this value
|
||||||
|
gen_settings.xtc_threshold = unwrap(kwargs.get("xtc_threshold", 0.1))
|
||||||
|
|
||||||
# DynaTemp settings
|
# DynaTemp settings
|
||||||
max_temp = unwrap(kwargs.get("max_temp"), 1.0)
|
max_temp = unwrap(kwargs.get("max_temp"), 1.0)
|
||||||
min_temp = unwrap(kwargs.get("min_temp"), 1.0)
|
min_temp = unwrap(kwargs.get("min_temp"), 1.0)
|
||||||
|
|||||||
@@ -110,6 +110,14 @@ class BaseSamplerRequest(BaseModel):
|
|||||||
examples=[0.0],
|
examples=[0.0],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
xtc_probability: Optional[float] = Field(
|
||||||
|
default_factory=lambda: get_default_sampler_value("xtc_probability", 0.0),
|
||||||
|
)
|
||||||
|
|
||||||
|
xtc_threshold: Optional[float] = Field(
|
||||||
|
default_factory=lambda: get_default_sampler_value("xtc_threshold", 0.1)
|
||||||
|
)
|
||||||
|
|
||||||
frequency_penalty: Optional[float] = Field(
|
frequency_penalty: Optional[float] = Field(
|
||||||
default_factory=lambda: get_default_sampler_value("frequency_penalty", 0.0)
|
default_factory=lambda: get_default_sampler_value("frequency_penalty", 0.0)
|
||||||
)
|
)
|
||||||
@@ -366,6 +374,8 @@ class BaseSamplerRequest(BaseModel):
|
|||||||
"min_p": self.min_p,
|
"min_p": self.min_p,
|
||||||
"tfs": self.tfs,
|
"tfs": self.tfs,
|
||||||
"skew": self.skew,
|
"skew": self.skew,
|
||||||
|
"xtc_probability": self.xtc_probability,
|
||||||
|
"xtc_threshold": self.xtc_threshold,
|
||||||
"frequency_penalty": self.frequency_penalty,
|
"frequency_penalty": self.frequency_penalty,
|
||||||
"presence_penalty": self.presence_penalty,
|
"presence_penalty": self.presence_penalty,
|
||||||
"repetition_penalty": self.repetition_penalty,
|
"repetition_penalty": self.repetition_penalty,
|
||||||
|
|||||||
@@ -79,6 +79,12 @@ typical:
|
|||||||
skew:
|
skew:
|
||||||
override: 0.0
|
override: 0.0
|
||||||
force: false
|
force: false
|
||||||
|
xtc_probability:
|
||||||
|
override: 0.0
|
||||||
|
force: false
|
||||||
|
xtc_threshold:
|
||||||
|
override: 0.1
|
||||||
|
force: false
|
||||||
|
|
||||||
# MARK: Penalty settings
|
# MARK: Penalty settings
|
||||||
frequency_penalty:
|
frequency_penalty:
|
||||||
|
|||||||
Reference in New Issue
Block a user