From 638eef401acf02e1460e1bb2c01b277089b22378 Mon Sep 17 00:00:00 2001
From: kingbri <8082010+kingbri1@users.noreply.github.com>
Date: Thu, 8 May 2025 23:10:03 -0400
Subject: [PATCH] Model: Move cache creation to a common function

Prevents repetition by moving Cache construction into a single
helper method.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
---
 backends/exllamav3/model.py | 81 ++++++++++++++-----------------------
 1 file changed, 31 insertions(+), 50 deletions(-)

diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py
index 61986c1..c386a6e 100644
--- a/backends/exllamav3/model.py
+++ b/backends/exllamav3/model.py
@@ -168,7 +168,7 @@ class ExllamaV3Container(BaseModelContainer):
             logger.info(f"Using draft model: {str(draft_model_path.resolve())}")
         else:
             self.draft_model = None
-            self.craft_cache = None
+            self.draft_cache = None
 
         # Turn off GPU split if the user is using 1 GPU
         gpu_count = torch.cuda.device_count()
@@ -222,61 +222,15 @@ class ExllamaV3Container(BaseModelContainer):
         user_cache_size = unwrap(kwargs.get("cache_size"), self.max_seq_len)
         self.cache_size = self.adjust_cache_size(user_cache_size)
         self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")
-
-        # Alias Exl2 q-cache settings
-        match self.cache_mode:
-            case "Q4":
-                self.cache_mode = "4,4"
-            case "Q6":
-                self.cache_mode = "6,6"
-            case "Q8":
-                self.cache_mode = "8,8"
-
-        split_cache_mode = re.search(r"^([2-8])\s*,\s*([2-8])$", self.cache_mode)
-        if split_cache_mode:
-            k_bits = int(split_cache_mode.group(1))
-            v_bits = int(split_cache_mode.group(2))
-            self.cache = Cache(
-                self.model,
-                max_num_tokens=self.cache_size,
-                layer_type=CacheLayer_quant,
-                k_bits=k_bits,
-                v_bits=v_bits,
-            )
-        else:
-            self.cache = Cache(self.model, max_num_tokens=self.cache_size)
+        self.cache = self.create_cache(self.cache_mode, self.model)
 
         # Draft cache
         if self.use_draft_model:
             # Set draft cache mode
             self.draft_cache_mode = unwrap(draft_args.get("draft_cache_mode"), "FP16")
-
-            # Alias Exl2 q-cache settings
-            match self.draft_cache_mode:
-                case "Q4":
-                    self.draft_cache_mode = "4,4"
-                case "Q6":
-                    self.draft_cache_mode = "6,6"
-                case "Q8":
-                    self.draft_cache_mode = "8,8"
-
-            split_draft_cache_mode = re.search(
-                r"^([2-8])\s*,\s*([2-8])$", self.draft_cache_mode
+            self.draft_cache = self.create_cache(
+                self.draft_cache_mode, self.draft_model
             )
-            if split_draft_cache_mode:
-                draft_k_bits = int(split_draft_cache_mode.group(1))
-                draft_v_bits = int(split_draft_cache_mode.group(2))
-                self.draft_cache = Cache(
-                    self.draft_model,
-                    max_num_tokens=self.cache_size,
-                    layer_type=CacheLayer_quant,
-                    k_bits=draft_k_bits,
-                    v_bits=draft_v_bits,
-                )
-            else:
-                self.draft_cache = Cache(
-                    self.draft_model, max_num_tokens=self.cache_size
-                )
 
         # Max batch size
         self.max_batch_size = unwrap(kwargs.get("max_batch_size"), 256)
@@ -355,6 +309,33 @@ class ExllamaV3Container(BaseModelContainer):
 
         return chunk_size
 
+    def create_cache(self, raw_cache_mode: str, model: Model) -> Cache:
+        # Alias exl2-style q-cache modes to exl3 "k,v" bit pairs
+        match raw_cache_mode:
+            case "Q4":
+                raw_cache_mode = "4,4"
+            case "Q6":
+                raw_cache_mode = "6,6"
+            case "Q8":
+                raw_cache_mode = "8,8"
+
+        split_cache_mode = re.search(r"^([2-8])\s*,\s*([2-8])$", raw_cache_mode)
+
+        if split_cache_mode:
+            k_bits = int(split_cache_mode.group(1))
+            v_bits = int(split_cache_mode.group(2))
+            cache = Cache(
+                model,
+                max_num_tokens=self.cache_size,
+                layer_type=CacheLayer_quant,
+                k_bits=k_bits,
+                v_bits=v_bits,
+            )
+        else:
+            cache = Cache(model, max_num_tokens=self.cache_size)
+
+        return cache
+
     def model_info(self) -> ModelCard:
         """
         Returns a dictionary of the current model's configuration parameters.
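
Reviewer note: the alias-and-parse behavior in create_cache() can be
sanity-checked without loading a model. Below is a minimal standalone
sketch of the same mode parsing; the parse_cache_mode helper is
hypothetical and not part of this patch, it only mirrors the match/regex
logic added above.

    import re

    def parse_cache_mode(raw_cache_mode: str) -> tuple[int, int] | None:
        # Alias exl2-style q-cache names to exl3 "k,v" bit pairs
        match raw_cache_mode:
            case "Q4":
                raw_cache_mode = "4,4"
            case "Q6":
                raw_cache_mode = "6,6"
            case "Q8":
                raw_cache_mode = "8,8"

        # "k,v" with each component between 2 and 8 selects a quantized layer
        split_cache_mode = re.search(r"^([2-8])\s*,\s*([2-8])$", raw_cache_mode)
        if split_cache_mode:
            return int(split_cache_mode.group(1)), int(split_cache_mode.group(2))

        # Anything else (e.g. "FP16") falls back to the default cache layer
        return None

    # Expected behavior for a few representative inputs
    assert parse_cache_mode("Q6") == (6, 6)
    assert parse_cache_mode("4, 8") == (4, 8)
    assert parse_cache_mode("FP16") is None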