From 116cf56c8754763b12a89ce9be7a6fac7f4d0bd1 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sun, 26 May 2024 21:24:54 -0400 Subject: [PATCH] Model: Auto-round cache size on init Cache size must be a multiple of 256 to work properly in ExllamaV2. Take the config value and set the cache size to one multiple above the remainder of the cache size divided by 256. This is because cache size can never be lower than max_seq_len. If max_seq_len isn't a multiple of 256, this method will never yield a number that's lower than max_seq_len since it's no longer a source of truth. Signed-off-by: kingbri --- backends/exllamav2/model.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 7ffc9ce..64f571f 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -193,14 +193,33 @@ class ExllamaV2Container: ) # Set k/v cache size - self.cache_size = unwrap(kwargs.get("cache_size"), self.config.max_seq_len) - if self.cache_size < self.config.max_seq_len: + cache_size = unwrap(kwargs.get("cache_size"), self.config.max_seq_len) + + if cache_size < self.config.max_seq_len: logger.warning( - "Your specified cache_size is smaller than your " - "desired context length. \n" - "Defaulting cache_size to max_seq_len." + f"The given cache_size ({cache_size}) is smaller than the " + "desired context length.\n" + "Overriding cache_size to max_seq_len. " ) - self.cache_size = self.config.max_seq_len + + cache_size = self.config.max_seq_len + + # Enforce a multiple of 256 for cache size + # Overestimate to ensure that the cache isn't below max_seq_len + cache_remainder = cache_size % 256 + if cache_remainder != 0: + rounded_cache_size = int(256 * ((cache_size - cache_remainder) / 256 + 1)) + + logger.warning( + f"The given cache size ({cache_size}) is " + "not a multiple of 256.\n" + "Overriding cache_size with an overestimated value of " + f"{rounded_cache_size} tokens." + ) + + cache_size = rounded_cache_size + + self.cache_size = cache_size # Enable fasttensors loading if present self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)