mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-21 06:48:56 +00:00
Model: Add support for HuggingFace config and bad_words_ids
This is necessary for Kobold's API. Current models use bad_words_ids in generation_config.json, but for some reason, they're also present in the model's config.json. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from common import model
|
||||
from common.sampling import BaseSamplerRequest
|
||||
from common.utils import flat_map, unwrap
|
||||
|
||||
|
||||
class GenerateRequest(BaseSamplerRequest):
|
||||
@@ -14,6 +16,16 @@ class GenerateRequest(BaseSamplerRequest):
|
||||
if self.penalty_range == 0:
|
||||
self.penalty_range = -1
|
||||
|
||||
# Move badwordsids into banned tokens for generation
|
||||
if self.use_default_badwordsids:
|
||||
bad_words_ids = unwrap(
|
||||
model.container.generation_config.bad_words_ids,
|
||||
model.container.hf_config.get_badwordsids()
|
||||
)
|
||||
|
||||
if bad_words_ids:
|
||||
self.banned_tokens += flat_map(bad_words_ids)
|
||||
|
||||
return super().to_gen_params(**kwargs)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user