Model: Switch to begin_stream_ex

Allows logprob parameters to be passed dynamically per request instead
of being fixed when the generator is initialized.

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri committed 2024-03-17 14:41:16 -04:00
parent 08bcc6307a
commit c9a6d9ae1f

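The shape of the change is roughly the following (a minimal sketch, not an exact tabbyAPI excerpt; `unwrap`, `kwargs`, `active_ids`, `gen_settings`, and `token_healing` are taken from the diff below, and the surrounding method body is assumed):

    # Before this commit: logprob output was forced on for every request
    # when the generator was created.
    self.generator.return_probabilities = True
    self.generator.return_logits = True

    # After this commit: the flags are derived from the request and passed
    # through begin_stream_ex on every generation.
    request_logprobs = unwrap(kwargs.get("logprobs"), 0)  # 0 disables logprobs

    self.generator.begin_stream_ex(
        active_ids,
        gen_settings,
        token_healing=token_healing,
        loras=self.active_loras,
        return_probabilities=request_logprobs > 0,  # per-token probabilities
        return_top_tokens=request_logprobs,         # number of top candidates to report
        return_logits=request_logprobs > 0,         # raw logits as well
    )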

@@ -478,10 +478,6 @@ class ExllamaV2Container:
                 self.draft_cache,
             )
 
-        # Always return logprobs and logits
-        self.generator.return_probabilities = True
-        self.generator.return_logits = True
-
         # Clean up any extra vram usage from torch and cuda
         # (Helps reduce VRAM bottlenecking on Windows)
         gc.collect()
@@ -789,7 +785,6 @@ class ExllamaV2Container:
 
         # Logprobs
         request_logprobs = unwrap(kwargs.get("logprobs"), 0)
-        self.generator.return_top_tokens = request_logprobs
 
         # Speculative Ngram
         self.generator.speculative_ngram = unwrap(
@@ -917,20 +912,26 @@ class ExllamaV2Container:
 
         # Split for exllama versions that have CFG
         if self.use_cfg:
-            self.generator.begin_stream(
+            self.generator.begin_stream_ex(
                 active_ids,
                 gen_settings,
                 token_healing=token_healing,
                 loras=self.active_loras,
                 input_mask=mask,
                 position_offsets=offsets,
+                return_probabilities=request_logprobs > 0,
+                return_top_tokens=request_logprobs,
+                return_logits=request_logprobs > 0,
             )
         else:
-            self.generator.begin_stream(
+            self.generator.begin_stream_ex(
                 active_ids,
                 gen_settings,
                 token_healing=token_healing,
                 loras=self.active_loras,
+                return_probabilities=request_logprobs > 0,
+                return_top_tokens=request_logprobs,
+                return_logits=request_logprobs > 0,
            )
 
         # Reset offsets for subsequent passes if the context is truncated
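
On the consumer side, these extras surface in the result dict yielded by the streaming call. A minimal sketch, assuming stream_ex() is the dict-returning counterpart to begin_stream_ex; the exact result keys ("chunk", "eos", and the top-token fields) are assumptions here and can differ between exllamav2 versions:

    # Hypothetical drain loop; key names are assumptions, check the
    # installed exllamav2 version for the exact result schema.
    generated = ""
    while True:
        result = self.generator.stream_ex()

        generated += result.get("chunk", "")

        if request_logprobs > 0:
            # Only present when the matching return_* flag was passed
            # to begin_stream_ex above.
            top_tokens = result.get("top_tokens")  # top-k candidate tokens
            top_probs = result.get("top_probs")    # their probabilities
            logits = result.get("logits")          # raw logits, if requested

        if result.get("eos", False):
            break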