mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-14 15:57:27 +00:00
Model: Change cache_size/max_seq_len behavior
- Cache size is now given only by the cache_size config option. The default is 4096 (users should always override it to make full use of VRAM).
- max_seq_len, if not overridden in the config, now defaults to the value in the model's config.json.
- max_seq_len is reduced so that it is no larger than the cache size.
This commit is contained in:
@@ -238,7 +238,7 @@ class ExllamaV2Container(BaseModelContainer):
|
||||
base_seq_len = hf_model.hf_config.max_position_embeddings
|
||||
|
||||
# Set the target seq len if present
|
||||
target_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
|
||||
target_seq_len = unwrap(kwargs.get("max_seq_len"), base_seq_len)
|
||||
|
||||
# Set the rope scale
|
||||
self.config.scale_pos_emb = unwrap(
|
||||
@@ -289,16 +289,7 @@ class ExllamaV2Container(BaseModelContainer):
|
||||
# Set k/v cache size
|
||||
# cache_size is only relevant when paged mode is enabled
|
||||
if self.paged:
|
||||
cache_size = unwrap(kwargs.get("cache_size"), self.config.max_seq_len)
|
||||
|
||||
if cache_size < self.config.max_seq_len:
|
||||
logger.warning(
|
||||
f"The given cache_size ({cache_size}) is smaller than the "
|
||||
"desired context length.\n"
|
||||
"Overriding cache_size to max_seq_len. "
|
||||
)
|
||||
|
||||
cache_size = self.config.max_seq_len
|
||||
cache_size = unwrap(kwargs.get("cache_size"), 4096)
|
||||
|
||||
# Enforce a multiple of 256 for cache size
|
||||
# Overestimate to ensure that the cache isn't below max_seq_len
|
||||
@@ -317,6 +308,13 @@ class ExllamaV2Container(BaseModelContainer):
|
||||
|
||||
cache_size = rounded_cache_size
|
||||
|
||||
if self.config.max_seq_len > cache_size:
|
||||
logger.warning(
|
||||
f"The given max_seq_len ({self.config.max_seq_len}) is larger than the "
|
||||
f"cache size and will be limited to {cache_size} tokens."
|
||||
)
|
||||
self.config.max_seq_len = cache_size
|
||||
|
||||
# Warn user if cache size may be inadequate for CFG
|
||||
if cache_size < 2 * self.config.max_seq_len:
|
||||
logger.warning(
|
||||
|
||||
@@ -105,12 +105,7 @@ class ExllamaV3Container(BaseModelContainer):
|
||||
self = cls()
|
||||
|
||||
# Make sure ExllamaV3 is up to date
|
||||
check_package_version("exllamav3", "0.0.4")
|
||||
|
||||
logger.warning(
|
||||
"ExllamaV3 is currently in an alpha state. "
|
||||
"Please note that all config options may not work."
|
||||
)
|
||||
check_package_version("exllamav3", "0.0.7")
|
||||
|
||||
self.model_dir = model_directory
|
||||
self.hf_model = hf_model
|
||||
@@ -131,9 +126,6 @@ class ExllamaV3Container(BaseModelContainer):
|
||||
self.vision_model = None
|
||||
self.use_vision = False
|
||||
|
||||
# Fallback to 4096 since exl3 can't fetch from HF's config.json
|
||||
self.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
|
||||
|
||||
# Prepare the draft model config if necessary
|
||||
draft_args = unwrap(kwargs.get("draft_model"), {})
|
||||
draft_model_name = draft_args.get("draft_model_name")
|
||||
@@ -231,11 +223,15 @@ class ExllamaV3Container(BaseModelContainer):
|
||||
raise RuntimeError(gpu_unsupported_message)
|
||||
|
||||
# Cache
|
||||
user_cache_size = unwrap(kwargs.get("cache_size"), self.max_seq_len)
|
||||
user_cache_size = unwrap(kwargs.get("cache_size"), 4096)
|
||||
self.cache_size = self.adjust_cache_size(user_cache_size)
|
||||
self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")
|
||||
self.cache = self.create_cache(self.cache_mode, self.model)
|
||||
|
||||
# Limit max_seq_len to prevent sequences larger than the cache
|
||||
max_seq_len = unwrap(kwargs.get("max_seq_len"), hf_model.hf_config.max_position_embeddings)
|
||||
self.max_seq_len = self.adjust_max_seq_len(max_seq_len)
|
||||
|
||||
# Draft cache
|
||||
if self.use_draft_model:
|
||||
# Set draft cache mode
|
||||
@@ -274,21 +270,11 @@ class ExllamaV3Container(BaseModelContainer):
|
||||
return self
|
||||
|
||||
def adjust_cache_size(self, cache_size):
|
||||
if cache_size < self.max_seq_len:
|
||||
logger.warning(
|
||||
f"The given cache_size ({cache_size}) is smaller than the "
|
||||
"desired context length.\n"
|
||||
"Overriding cache_size to max_seq_len. "
|
||||
)
|
||||
|
||||
cache_size = self.max_seq_len
|
||||
|
||||
# Enforce a multiple of 256 for cache size
|
||||
# Overestimate to ensure that the cache isn't below max_seq_len
|
||||
cache_remainder = cache_size % 256
|
||||
if cache_remainder != 0:
|
||||
rounded_cache_size = int(256 * ((cache_size - cache_remainder) / 256 + 1))
|
||||
|
||||
logger.warning(
|
||||
f"The given cache size ({cache_size}) is "
|
||||
"not a multiple of 256.\n"
|
||||
@@ -298,22 +284,22 @@ class ExllamaV3Container(BaseModelContainer):
|
||||
|
||||
cache_size = rounded_cache_size
|
||||
|
||||
# Warn user if cache size may be inadequate for CFG
|
||||
if cache_size < 2 * self.max_seq_len:
|
||||
logger.warning(
|
||||
f"The given cache_size ({cache_size}) is less than 2 * max_seq_len "
|
||||
"and may be too small for requests using CFG. \n"
|
||||
"Ignore this warning if you do not plan on using CFG."
|
||||
)
|
||||
|
||||
return cache_size
|
||||
|
||||
def adjust_chunk_size(self, user_chunk_size: int):
|
||||
chunk_size = sorted((256, user_chunk_size, self.max_seq_len))[1]
|
||||
chunk_remainder = chunk_size % 256
|
||||
if chunk_remainder != 0:
|
||||
rounded_chunk_size = int(256 * ((chunk_size - chunk_remainder) / 256 + 1))
|
||||
def adjust_max_seq_len(self, max_seq_len):
|
||||
if max_seq_len > self.cache_size:
|
||||
logger.warning(
|
||||
f"The given max_seq_len ({max_seq_len}) is larger than the cache size "
|
||||
f"and will be limited to {self.cache_size} tokens."
|
||||
)
|
||||
max_seq_len = self.cache_size
|
||||
|
||||
return max_seq_len
|
||||
|
||||
def adjust_chunk_size(self, user_chunk_size: int):
|
||||
chunk_size = max(256, user_chunk_size)
|
||||
rounded_chunk_size = (chunk_size + 255) // 256 * 256
|
||||
if chunk_size != rounded_chunk_size:
|
||||
logger.warning(
|
||||
f"The given chunk size ({chunk_size}) is "
|
||||
"not a multiple of 256.\n"
|
||||
@@ -950,24 +936,18 @@ class ExllamaV3Container(BaseModelContainer):
|
||||
context_len = input_ids[0].size(dim=-1)
|
||||
|
||||
# Automatically set max_tokens to fill up the context
|
||||
# This should be an OK default, but may be changed in the future
|
||||
max_tokens = unwrap(
|
||||
params.max_tokens,
|
||||
self.max_seq_len - context_len,
|
||||
params.max_tokens if params.max_tokens > 0 else None,
|
||||
self.max_seq_len - context_len - 1,
|
||||
)
|
||||
if max_tokens < 1:
|
||||
logger.warning("max_tokens must be a positive integer, setting to 1.")
|
||||
max_tokens = 1
|
||||
|
||||
# Determine if the negative context or the context length is bigger
|
||||
context_to_check = context_len
|
||||
|
||||
# Check total length of prompt against max context length
|
||||
if context_to_check > self.max_seq_len:
|
||||
preamble = "Prompt"
|
||||
|
||||
if context_len > self.max_seq_len:
|
||||
raise ValueError(
|
||||
f"{preamble} length {context_to_check} is greater than "
|
||||
f"Prompt length {context_len} is greater than "
|
||||
f"max_seq_len {self.max_seq_len}"
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user