diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py index 330c4e1..1026faf 100644 --- a/backends/exllamav3/model.py +++ b/backends/exllamav3/model.py @@ -20,7 +20,7 @@ from exllamav3 import ( Model, Tokenizer, ) -from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant +from exllamav3.cache import CacheLayer_quant from loguru import logger from backends.base_model_container import BaseModelContainer @@ -76,6 +76,7 @@ class ExllamaV3Container(BaseModelContainer): max_seq_len: int = 4096 cache_size: int = 4096 cache_mode: str = "FP16" + draft_cache_mode: str = "FP16" chunk_size: int = 2048 max_batch_size: Optional[int] = None @@ -245,13 +246,35 @@ class ExllamaV3Container(BaseModelContainer): v_bits=v_bits, ) else: - self.cache = Cache( - self.model, max_num_tokens=self.cache_size, layer_type=CacheLayer_fp16 - ) + self.cache = Cache(self.model, max_num_tokens=self.cache_size) # Draft cache if self.use_draft_model: - self.draft_cache = Cache(self.draft_model, max_num_tokens = self.cache_size) + # Set draft cache mode + self.draft_cache_mode = unwrap(draft_args.get("draft_cache_mode"), "FP16") + + # Alias Exl2 q-cache settings + match self.draft_cache_mode: + case "Q4": + self.draft_cache_mode = "4,4" + case "Q6": + self.draft_cache_mode = "6,6" + case "Q8": + self.draft_cache_mode = "8,8" + + split_draft_cache_mode = re.search(r"^([2-8])\s*,\s*([2-8])$", self.draft_cache_mode) + if split_draft_cache_mode: + draft_k_bits = int(split_draft_cache_mode.group(1)) + draft_v_bits = int(split_draft_cache_mode.group(2)) + self.draft_cache = Cache( + self.draft_model, + max_num_tokens=self.cache_size, + layer_type=CacheLayer_quant, + k_bits=draft_k_bits, + v_bits=draft_v_bits, + ) + else: + self.draft_cache = Cache(self.draft_model, max_num_tokens = self.cache_size) # Max batch size self.max_batch_size = unwrap(kwargs.get("max_batch_size"), 256)