Mirror of https://github.com/theroyallab/tabbyAPI.git (synced 2026-03-14 15:57:27 +00:00)
ExllamaV3: Handle max_seq_len defined and cache_size undefined case
The previous changes broke existing configs: max_seq_len was force-overridden to 4096. This fix helps single-user setups, since they do not really benefit from the split cache_size/max_seq_len mechanism (except when batching). cache_size is still the prime mover in exl3 due to its paging mechanism. Ideally, for multi-user setups, cache_size should take as much VRAM as possible and max_seq_len should be limited.

Breakdown:
- cache_size and max_seq_len specified -> use the given values as-is
- only cache_size or only max_seq_len specified -> max_seq_len = cache_size and vice versa
- neither specified -> cache_size = 4096, max_seq_len = min(max_position_embeddings, cache_size)

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
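To illustrate the resolution order described above, here is a minimal standalone sketch. The resolve_cache_and_seq_len helper and its defaults are hypothetical; they only mirror the fallback behavior shown in the diff below, not the actual container code.

# Hypothetical sketch of the cache_size/max_seq_len fallback order.
def resolve_cache_and_seq_len(cache_size=None, max_seq_len=None, max_position_embeddings=4096):
    # cache_size falls back to max_seq_len, then to the 4096 default
    resolved_cache_size = next(v for v in (cache_size, max_seq_len, 4096) if v is not None)

    # max_seq_len falls back to the smaller of max_position_embeddings and cache_size
    if max_seq_len is None:
        max_seq_len = min(max_position_embeddings, resolved_cache_size)

    return resolved_cache_size, max_seq_len

# Both specified -> values used as-is (large cache, limited sequence length)
print(resolve_cache_and_seq_len(cache_size=65536, max_seq_len=8192))   # (65536, 8192)
# Only max_seq_len specified -> cache_size follows it
print(resolve_cache_and_seq_len(max_seq_len=16384))                    # (16384, 16384)
# Neither specified -> cache_size = 4096, max_seq_len clamped to it
print(resolve_cache_and_seq_len(max_position_embeddings=131072))       # (4096, 4096)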
@@ -222,17 +222,19 @@ class ExllamaV3Container(BaseModelContainer):
         raise RuntimeError(gpu_unsupported_message)

-        # Cache
-        user_cache_size = unwrap(kwargs.get("cache_size"), 4096)
+        # Store the max_seq_len arg
+        user_max_seq_len = kwargs.get("max_seq_len")
+
+        # Cache creation
+
+        # If undefined, cache_size should be max_seq_len, otherwise use 4096 default
+        user_cache_size = coalesce(kwargs.get("cache_size"), user_max_seq_len, 4096)
         self.cache_size = self.adjust_cache_size(user_cache_size)
         self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")
         self.cache = self.create_cache(self.cache_mode, self.model)

         # Limit max_seq_len to prevent sequences larger than the cache
-        max_seq_len = unwrap(
-            kwargs.get("max_seq_len"), hf_model.hf_config.max_position_embeddings
-        )
-        self.max_seq_len = self.adjust_max_seq_len(max_seq_len)
+        self.max_seq_len = self.adjust_max_seq_len(user_max_seq_len)

         # Draft cache
         if self.use_draft_model:
@@ -271,9 +273,9 @@ class ExllamaV3Container(BaseModelContainer):
         return self

-    # Enforce a multiple of 256 for cache size
-    # Overestimate to ensure that the cache isn't below max_seq_len
     def adjust_cache_size(self, cache_size):
+        # Enforce a multiple of 256 for cache size
+        # Overestimate to ensure that the cache isn't below max_seq_len
         cache_remainder = cache_size % 256
         if cache_remainder != 0:
             rounded_cache_size = int(256 * ((cache_size - cache_remainder) / 256 + 1))
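For reference, the rounding above behaves like this minimal standalone sketch; round_cache_size is a hypothetical stand-in for the method body, not the actual container code.

# Hypothetical stand-in for the rounding logic above: round cache_size up to
# the next multiple of 256 so the cache never ends up smaller than max_seq_len.
def round_cache_size(cache_size: int) -> int:
    cache_remainder = cache_size % 256
    if cache_remainder != 0:
        # Overestimate: strip the remainder, then add one extra 256-token block
        cache_size = int(256 * ((cache_size - cache_remainder) / 256 + 1))
    return cache_size

print(round_cache_size(5000))  # 5120 (5000 % 256 == 136, rounded up)
print(round_cache_size(4096))  # 4096 (already a multiple of 256)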
@@ -288,8 +290,20 @@ class ExllamaV3Container(BaseModelContainer):
         return cache_size

+    # Make sure max_seq_len's upper limit is cache_size
+    # If max_seq_len isn't specified, override it to
+    # cache_size/max_pos_embeddings, whichever is smaller
     def adjust_max_seq_len(self, max_seq_len):
-        if max_seq_len > self.cache_size:
+        if not max_seq_len:
+            default_max_seq_len = min(
+                self.hf_model.hf_config.max_position_embeddings, self.cache_size
+            )
+
+            logger.warning(
+                f"max_seq_len is undefined. Overriding to {default_max_seq_len} tokens."
+            )
+            max_seq_len = default_max_seq_len
+        elif max_seq_len > self.cache_size:
             logger.warning(
                 f"The given max_seq_len ({max_seq_len}) is larger than the cache size "
                 f"and will be limited to {self.cache_size} tokens."
             )
@@ -157,7 +157,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     # Override the max sequence length based on user
     max_seq_len = kwargs.get("max_seq_len")
-    if max_seq_len == -1 or max_seq_len is None:
+    if max_seq_len == -1:
         kwargs["max_seq_len"] = hf_model.hf_config.max_position_embeddings

     # Create a new container and check if the right dependencies are installed
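The load_model_gen hunk above narrows the early override: only an explicit -1 still expands to the model's native context length, while an unset max_seq_len is now passed through as None so the container's fallback logic can resolve it. A minimal hypothetical sketch of that behavior:

# Hypothetical sketch of the narrowed override in load_model_gen.
def normalize_max_seq_len(max_seq_len, max_position_embeddings):
    if max_seq_len == -1:
        return max_position_embeddings
    # None is passed through and resolved later by the container
    return max_seq_len

print(normalize_max_seq_len(-1, 131072))    # 131072
print(normalize_max_seq_len(None, 131072))  # None
print(normalize_max_seq_len(8192, 131072))  # 8192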