Model: Clean up generation length and page checks

Reduce the number of if statements and combine parts of the code.

Signed-off-by: kingbri <8082010+bdashore3@users.noreply.github.com>
This commit is contained in:
kingbri
2024-12-26 23:13:08 -05:00
parent ba2579ff74
commit b994aae995

View File

@@ -1309,10 +1309,7 @@ class ExllamaV2Container:
context_len = input_ids[0].size(dim=-1)
# The second index will be the negative prompt if CFG is enabled
if negative_prompt is not None:
negative_context_len = input_ids[1].size(dim=-1)
else:
negative_context_len = 0
negative_context_len = input_ids[1].size(dim=-1) if negative_prompt else 0
# Automatically set max_tokens to fill up the context
# This should be an OK default, but may be changed in the future
@@ -1324,34 +1321,35 @@ class ExllamaV2Container:
logger.warning("max_tokens must be a positive integer, setting to 1.")
max_tokens = 1
# Check total length of request
if context_len + max_tokens > self.config.max_seq_len:
# Determine if the negative context or the context length is bigger
context_to_check = max(negative_context_len, context_len)
# Check highest possible total length of request
if context_to_check + max_tokens > self.config.max_seq_len:
preamble = (
"Negative prompt request"
if negative_context_len > context_len
else "Request"
)
raise ValueError(
f"Request length {context_len} + {max_tokens} is greater than "
f"{preamble} length {context_to_check} + {max_tokens} is greater than "
f"max_seq_len {self.config.max_seq_len}"
)
# Check total length of negative prompt request if CFG is enabled
if negative_prompt is not None:
if context_len + max_tokens > self.config.max_seq_len:
raise ValueError(
f"Request length for negative prompt "
f"{negative_context_len} + {max_tokens} is greater than "
f"max_seq_len {self.config.max_seq_len}"
)
# Check total required pages for CFG request
if (
sum(
256 * math.ceil((context + max_tokens) / 256)
for context in (context_len, negative_context_len)
)
> self.cache_size
):
raise ValueError(
f"Total required page size for request "
f"{context_len} + {negative_context_len} + {max_tokens} * 2 "
f"is greater than cache_size {self.cache_size}"
)
# Check total required pages for CFG request to avoid overallocation
if negative_prompt and (
sum(
256 * math.ceil((context + max_tokens) / 256)
for context in (context_len, negative_context_len)
)
> self.cache_size
):
raise ValueError(
f"Total required page size for request "
f"{context_len} + {negative_context_len} + {max_tokens} * 2 "
f"is greater than cache_size {self.cache_size}"
)
# Set min_tokens to generate while keeping EOS banned
min_tokens = unwrap(kwargs.get("min_tokens"), 0)