Sampling: Add rudimentary DRY support

Adds DRY support based on the current exl2 dev API. Only change for
optimization is dry_max_ngram instead of using a closed range.

Currently, DRY range is aliased to dry_max_ngram.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-09-07 00:48:42 -04:00
parent d34756dc98
commit 05c3f1194f
2 changed files with 68 additions and 1 deletions

View File

@@ -7,6 +7,7 @@ import pathlib
import traceback
import torch
import uuid
from copy import deepcopy
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Config,
@@ -944,6 +945,14 @@ class ExllamaV2Container:
Meant for dev wheels!
"""
if unwrap(kwargs.get("dry_allowed_length"), 0) > 0 and not hasattr(
ExLlamaV2Sampler.Settings, "dry_multiplier"
):
logger.warning(
"DRY sampling is not supported by the currently "
"installed ExLlamaV2 version."
)
return kwargs
async def generate_gen(
@@ -1035,6 +1044,7 @@ class ExllamaV2Container:
"Please use an ampere (30 series) or higher GPU for CFG support."
)
# Penalties
gen_settings.token_repetition_penalty = unwrap(
kwargs.get("repetition_penalty"), 1.0
)
@@ -1070,6 +1080,23 @@ class ExllamaV2Container:
kwargs.get("repetition_decay"), fallback_decay, 0
)
# DRY options
dry_allowed_length = unwrap(kwargs.get("dry_allowed_length"), 0)
# 0 = disabled
if dry_allowed_length:
gen_settings.dry_allowed_length = dry_allowed_length
gen_settings.dry_base = unwrap(kwargs.get("dry_base"), 2.0)
gen_settings.dry_multiplier = unwrap(kwargs.get("dry_multiplier"), 2.0)
gen_settings.dry_max_ngram = unwrap(kwargs.get("dry_max_ngram"), 20)
# Tokenize sequence breakers
dry_sequence_breakers_json = kwargs.get("dry_sequence_breakers")
if dry_sequence_breakers_json:
gen_settings.dry_sequence_breakers = {
self.encode_tokens(s)[-1] for s in dry_sequence_breakers_json
}
# Initialize grammar handler
grammar_handler = ExLlamaV2Grammar()
@@ -1130,7 +1157,8 @@ class ExllamaV2Container:
)
# Store the gen settings for logging purposes
gen_settings_log_dict = vars(gen_settings)
# Deepcopy to save a snapshot of vars
gen_settings_log_dict = deepcopy(vars(gen_settings))
# Set banned tokens
banned_tokens = unwrap(kwargs.get("banned_tokens"), [])