Model: Add support for HuggingFace config and bad_words_ids

This is necessary for Kobold's API. Current models use bad_words_ids
in generation_config.json, but for some reason, they're also present
in the model's config.json.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-07-26 18:23:22 -04:00
parent 545e26608f
commit 7522b1447b
4 changed files with 68 additions and 9 deletions

View File

@@ -1,7 +1,9 @@
from typing import List, Optional
from pydantic import BaseModel, Field
from common import model
from common.sampling import BaseSamplerRequest
from common.utils import flat_map, unwrap
class GenerateRequest(BaseSamplerRequest):
@@ -14,6 +16,16 @@ class GenerateRequest(BaseSamplerRequest):
if self.penalty_range == 0:
self.penalty_range = -1
# Move badwordsids into banned tokens for generation
if self.use_default_badwordsids:
bad_words_ids = unwrap(
model.container.generation_config.bad_words_ids,
model.container.hf_config.get_badwordsids()
)
if bad_words_ids:
self.banned_tokens += flat_map(bad_words_ids)
return super().to_gen_params(**kwargs)