diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 7fe08db..1d80062 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -7,6 +7,7 @@ import pathlib import traceback import torch import uuid +from copy import deepcopy from exllamav2 import ( ExLlamaV2, ExLlamaV2Config, @@ -944,6 +945,14 @@ class ExllamaV2Container: Meant for dev wheels! """ + if unwrap(kwargs.get("dry_allowed_length"), 0) > 0 and not hasattr( + ExLlamaV2Sampler.Settings, "dry_multiplier" + ): + logger.warning( + "DRY sampling is not supported by the currently " + "installed ExLlamaV2 version." + ) + return kwargs async def generate_gen( @@ -1035,6 +1044,7 @@ class ExllamaV2Container: "Please use an ampere (30 series) or higher GPU for CFG support." ) + # Penalties gen_settings.token_repetition_penalty = unwrap( kwargs.get("repetition_penalty"), 1.0 ) @@ -1070,6 +1080,23 @@ class ExllamaV2Container: kwargs.get("repetition_decay"), fallback_decay, 0 ) + # DRY options + dry_allowed_length = unwrap(kwargs.get("dry_allowed_length"), 0) + + # 0 = disabled + if dry_allowed_length: + gen_settings.dry_allowed_length = dry_allowed_length + gen_settings.dry_base = unwrap(kwargs.get("dry_base"), 2.0) + gen_settings.dry_multiplier = unwrap(kwargs.get("dry_multiplier"), 2.0) + gen_settings.dry_max_ngram = unwrap(kwargs.get("dry_max_ngram"), 20) + + # Tokenize sequence breakers + dry_sequence_breakers_json = kwargs.get("dry_sequence_breakers") + if dry_sequence_breakers_json: + gen_settings.dry_sequence_breakers = { + self.encode_tokens(s)[-1] for s in dry_sequence_breakers_json + } + # Initialize grammar handler grammar_handler = ExLlamaV2Grammar() @@ -1130,7 +1157,8 @@ class ExllamaV2Container: ) # Store the gen settings for logging purposes - gen_settings_log_dict = vars(gen_settings) + # Deepcopy to save a snapshot of vars + gen_settings_log_dict = deepcopy(vars(gen_settings)) # Set banned tokens banned_tokens = unwrap(kwargs.get("banned_tokens"), []) diff --git a/common/sampling.py b/common/sampling.py index 56c5b34..a3bccb3 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -1,5 +1,6 @@ """Common functions for sampling parameters""" +import json import pathlib import yaml from copy import deepcopy @@ -140,6 +141,28 @@ class BaseSamplerRequest(BaseModel): default_factory=lambda: get_default_sampler_value("repetition_decay", 0) ) + dry_allowed_length: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("dry_allowed_length", 0) + ) + + dry_base: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("dry_base", 2.0) + ) + + dry_multiplier: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("dry_multiplier", 2.0) + ) + + # TODO: Remove these aliases + dry_max_ngram: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("dry_max_ngram", 20), + alias=AliasChoices("dry_max_ngram", "dry_penalty_last_n"), + ) + + dry_sequence_breakers: Optional[str] = Field( + default_factory=lambda: get_default_sampler_value("dry_sequence_breakers", []) + ) + mirostat_mode: Optional[int] = Field( default_factory=lambda: get_default_sampler_value("mirostat_mode", 0) ) @@ -305,6 +328,17 @@ class BaseSamplerRequest(BaseModel): int(x) for x in self.allowed_tokens.split(",") if x.isdigit() ] + # Convert sequence breakers into an array of strings + # NOTE: This sampler sucks to parse. + if self.dry_sequence_breakers: + if not self.dry_sequence_breakers.startswith("["): + self.dry_sequence_breakers = f"[{self.dry_sequence_breakers}]" + + try: + self.dry_sequence_breakers = json.loads(self.dry_sequence_breakers) + except Exception: + self.dry_sequence_breakers = [] + gen_params = { "max_tokens": self.max_tokens, "min_tokens": self.min_tokens, @@ -335,6 +369,11 @@ class BaseSamplerRequest(BaseModel): "presence_penalty": self.presence_penalty, "repetition_penalty": self.repetition_penalty, "penalty_range": self.penalty_range, + "dry_allowed_length": self.dry_allowed_length, + "dry_base": self.dry_base, + "dry_max_ngram": self.dry_max_ngram, + "dry_multiplier": self.dry_multiplier, + "dry_sequence_breakers": self.dry_sequence_breakers, "repetition_decay": self.repetition_decay, "mirostat": self.mirostat_mode == 2, "mirostat_tau": self.mirostat_tau,