mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-27 01:38:56 +00:00
Model: Robust request length checking in generator
* Ensure that length of positive/negative prompt + max_tokens does not exceed max_seq_len
* Ensure that total required pages for CFG request does not exceed allocated cache_size
This commit is contained in:
@@ -1301,17 +1301,51 @@ class ExllamaV2Container:
|
|||||||
|
|
||||||
# The first index will always be the positive prompt
|
# The first index will always be the positive prompt
|
||||||
context_len = input_ids[0].size(dim=-1)
|
context_len = input_ids[0].size(dim=-1)
|
||||||
if context_len > self.config.max_seq_len:
|
|
||||||
raise ValueError(
|
# The second index will be the negative prompt if CFG is enabled
|
||||||
f"Context length {context_len} is greater than max_seq_len "
|
if negative_prompt is not None:
|
||||||
f"{self.config.max_seq_len}"
|
negative_context_len = input_ids[1].size(dim=-1)
|
||||||
)
|
else:
|
||||||
|
negative_context_len = 0
|
||||||
|
|
||||||
# Automatically set max_tokens to fill up the context
|
# Automatically set max_tokens to fill up the context
|
||||||
# This should be an OK default, but may be changed in the future
|
# This should be an OK default, but may be changed in the future
|
||||||
max_tokens = unwrap(
|
max_tokens = unwrap(
|
||||||
kwargs.get("max_tokens"), self.config.max_seq_len - context_len
|
kwargs.get("max_tokens"),
|
||||||
|
self.config.max_seq_len - max(context_len, negative_context_len),
|
||||||
)
|
)
|
||||||
|
if max_tokens < 1:
|
||||||
|
logger.warning("max_tokens must be a positive integer, " "setting to 1.")
|
||||||
|
max_tokens = 1
|
||||||
|
|
||||||
|
# Check total length of request
|
||||||
|
if context_len + max_tokens > self.config.max_seq_len:
|
||||||
|
raise ValueError(
|
||||||
|
f"Request length {context_len} + {max_tokens} is greater than "
|
||||||
|
f"max_seq_len {self.config.max_seq_len}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check total length of negative prompt request if CFG is enabled
|
||||||
|
if negative_prompt is not None:
|
||||||
|
if context_len + max_tokens > self.config.max_seq_len:
|
||||||
|
raise ValueError(
|
||||||
|
f"Request length for negative prompt "
|
||||||
|
f"{negative_context_len} + {max_tokens} is greater than "
|
||||||
|
f"max_seq_len {self.config.max_seq_len}"
|
||||||
|
)
|
||||||
|
# Check total required pages for CFG request
|
||||||
|
if (
|
||||||
|
sum(
|
||||||
|
256 * math.ceil((context + max_tokens) / 256)
|
||||||
|
for context in (context_len, negative_context_len)
|
||||||
|
)
|
||||||
|
> self.cache_size
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Total required page size for request "
|
||||||
|
f"{context_len} + {negative_context_len} + {max_tokens} * 2 "
|
||||||
|
f"is greater than cache_size {self.cache_size}"
|
||||||
|
)
|
||||||
|
|
||||||
# Set min_tokens to generate while keeping EOS banned
|
# Set min_tokens to generate while keeping EOS banned
|
||||||
min_tokens = unwrap(kwargs.get("min_tokens"), 0)
|
min_tokens = unwrap(kwargs.get("min_tokens"), 0)
|
||||||
|
|||||||
Reference in New Issue
Block a user