Mirror of https://github.com/turboderp-org/exllamav2.git (synced 2026-03-15 00:07:26 +00:00)
Add filter_prefer_eos option
This commit is contained in:
@@ -37,7 +37,7 @@ class ExLlamaV2Sampler:
|
||||
cfg_scale = None
|
||||
|
||||
filters = []
|
||||
|
||||
+ filter_prefer_eos = False
|
||||
|
||||
def clone(self):
|
||||
|
||||
@@ -168,11 +168,14 @@ class ExLlamaV2Sampler:
|
||||
for f in settings.filters:
|
||||
|
||||
pt, et = f.next()
|
||||
- pass_tokens = pt if pass_tokens is None else pass_tokens & pt
- end_tokens = et if end_tokens is None else end_tokens | et
+ if pt is not None: pass_tokens = pt if pass_tokens is None else pass_tokens & pt
+ if et is not None: end_tokens = et if end_tokens is None else end_tokens | et
|
||||
|
||||
- assert pass_tokens, "Filter excluded all tokens"
- ext_c.logit_filter_exclusive(logit_filter, [sorted(list(pass_tokens))])
+ if pass_tokens is not None:
+     assert pass_tokens, "Filter excluded all tokens"
+     if settings.filter_prefer_eos and tokenizer.eos_token_id in pass_tokens:
+         pass_tokens = { tokenizer.eos_token_id }
+     ext_c.logit_filter_exclusive(logit_filter, [sorted(list(pass_tokens))])
|
||||
|
||||
# Healing
|
||||
|
||||
@@ -206,8 +209,8 @@ class ExLlamaV2Sampler:
|
||||
|
||||
batch_size = logits.shape[0]
|
||||
|
||||
- output_tokens = torch.empty((batch_size, 1), device="cpu", dtype=torch.long)
- output_probs = torch.empty((batch_size, 1), device="cpu", dtype=torch.float)
+ output_tokens = torch.empty((batch_size, 1), device = "cpu", dtype = torch.long)
+ output_probs = torch.empty((batch_size, 1), device = "cpu", dtype = torch.float)
|
||||
if return_top_tokens == 0:
|
||||
output_ktokens = none_tensor
|
||||
output_kprobs = none_tensor
|
||||
|
||||
Reference in New Issue
Block a user