From f194d9d7b0788e576f257c4ad4176c2230556edc Mon Sep 17 00:00:00 2001
From: turboderp <11859846+turboderp@users.noreply.github.com>
Date: Mon, 19 Feb 2024 00:00:11 +0100
Subject: [PATCH] Add filter_prefer_eos option

---
 exllamav2/generator/sampler.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/exllamav2/generator/sampler.py b/exllamav2/generator/sampler.py
index e3f9e19..7a27f40 100644
--- a/exllamav2/generator/sampler.py
+++ b/exllamav2/generator/sampler.py
@@ -37,7 +37,7 @@ class ExLlamaV2Sampler:
         cfg_scale = None
 
         filters = []
-
+        filter_prefer_eos = False
 
 
     def clone(self):
@@ -168,11 +168,14 @@ class ExLlamaV2Sampler:
 
         for f in settings.filters:
 
             pt, et = f.next()
-            pass_tokens = pt if pass_tokens is None else pass_tokens & pt
-            end_tokens = et if end_tokens is None else end_tokens | et
+            if pt is not None: pass_tokens = pt if pass_tokens is None else pass_tokens & pt
+            if et is not None: end_tokens = et if end_tokens is None else end_tokens | et
 
-        assert pass_tokens, "Filter excluded all tokens"
-        ext_c.logit_filter_exclusive(logit_filter, [sorted(list(pass_tokens))])
+        if pass_tokens is not None:
+            assert pass_tokens, "Filter excluded all tokens"
+            if settings.filter_prefer_eos and tokenizer.eos_token_id in pass_tokens:
+                pass_tokens = { tokenizer.eos_token_id }
+            ext_c.logit_filter_exclusive(logit_filter, [sorted(list(pass_tokens))])
 
     # Healing
@@ -206,8 +209,8 @@ class ExLlamaV2Sampler:
 
         batch_size = logits.shape[0]
 
-        output_tokens = torch.empty((batch_size, 1), device="cpu", dtype=torch.long)
-        output_probs = torch.empty((batch_size, 1), device="cpu", dtype=torch.float)
+        output_tokens = torch.empty((batch_size, 1), device = "cpu", dtype = torch.long)
+        output_probs = torch.empty((batch_size, 1), device = "cpu", dtype = torch.float)
         if return_top_tokens == 0:
             output_ktokens = none_tensor
             output_kprobs = none_tensor
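
Below is a standalone sketch of the behaviour this patch adds. The helper name combine_filters and the literal token ids are hypothetical and used only to illustrate the logic; in the actual code this runs inside ExLlamaV2Sampler.sample() and the resulting set is passed to ext_c.logit_filter_exclusive().

# Minimal, self-contained illustration of the new filter handling:
# - a pass/end set returned as None by a filter is now skipped instead of
#   clobbering the combined sets, and
# - with filter_prefer_eos set, the allowed set collapses to just the EOS
#   token as soon as the filters permit it.

def combine_filters(filter_results, eos_token_id, filter_prefer_eos = False):

    pass_tokens = None
    end_tokens = None
    for pt, et in filter_results:
        if pt is not None: pass_tokens = pt if pass_tokens is None else pass_tokens & pt
        if et is not None: end_tokens = et if end_tokens is None else end_tokens | et

    if pass_tokens is not None:
        assert pass_tokens, "Filter excluded all tokens"
        if filter_prefer_eos and eos_token_id in pass_tokens:
            pass_tokens = { eos_token_id }

    return pass_tokens, end_tokens

# Two filters whose allowed sets overlap and include EOS (token id 2 here):
# with filter_prefer_eos the sampler is constrained to emit EOS immediately.
print(combine_filters([({2, 5, 9}, {2}), ({2, 9, 11}, None)], eos_token_id = 2, filter_prefer_eos = True))
# -> ({2}, {2})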