Model: Initial Exl3 cache quantization support

DocShotgun
2025-05-01 22:55:51 -07:00
parent 036af02bf6
commit 68a660bdb3
4 changed files with 42 additions and 6 deletions


@@ -185,6 +185,7 @@ class ExllamaV2Container(BaseModelContainer):
         # MARK: User configuration
 
         # Get cache mode
+        # TODO: Separate validation for Exl2 and Exl3 q-cache options
         self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")
 
         # Turn off GPU split if the user is using 1 GPU


@@ -1,6 +1,7 @@
 import asyncio
 import gc
 import pathlib
+import re
 import traceback
 from typing import (
     Any,
@@ -19,6 +20,7 @@ from exllamav3 import (
     Model,
     Tokenizer,
 )
+from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
 from loguru import logger
 
 from backends.base_model_container import BaseModelContainer
@@ -73,6 +75,7 @@ class ExllamaV3Container(BaseModelContainer):
     use_tp: bool = False
     max_seq_len: int = 4096
     cache_size: int = 4096
+    cache_mode: str = "FP16"
     chunk_size: int = 2048
     max_batch_size: Optional[int] = None
@@ -219,7 +222,32 @@ class ExllamaV3Container(BaseModelContainer):
         # Cache
         user_cache_size = unwrap(kwargs.get("cache_size"), self.max_seq_len)
         self.cache_size = self.adjust_cache_size(user_cache_size)
-        self.cache = Cache(self.model, max_num_tokens=self.cache_size)
+        self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")
+
+        # Alias Exl2 q-cache settings
+        match self.cache_mode:
+            case "Q4":
+                self.cache_mode = "4,4"
+            case "Q6":
+                self.cache_mode = "6,6"
+            case "Q8":
+                self.cache_mode = "8,8"
+
+        split_cache_mode = re.search(r"^([2-8]),([2-8])$", self.cache_mode)
+        if split_cache_mode:
+            k_bits = int(split_cache_mode.group(1))
+            v_bits = int(split_cache_mode.group(2))
+            self.cache = Cache(
+                self.model,
+                max_num_tokens=self.cache_size,
+                layer_type=CacheLayer_quant,
+                k_bits=k_bits,
+                v_bits=v_bits,
+            )
+        else:
+            self.cache = Cache(
+                self.model, max_num_tokens=self.cache_size, layer_type=CacheLayer_fp16
+            )
 
         # Draft cache
         if self.use_draft_model:
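
For reference, the cache-mode handling added above reduces to the following standalone sketch. The alias map and regex are taken from the diff; the helper name `parse_cache_mode` is only for illustration and does not exist in the codebase:

```python
import re

# Mirrors the diff's logic: Exl2-style aliases map to symmetric k,v bit pairs,
# and any "k,v" string with both widths in 2..8 selects a quantized cache layer.
def parse_cache_mode(cache_mode: str) -> tuple[int, int] | None:
    aliases = {"Q4": "4,4", "Q6": "6,6", "Q8": "8,8"}
    cache_mode = aliases.get(cache_mode, cache_mode)
    split = re.search(r"^([2-8]),([2-8])$", cache_mode)
    if split:
        return int(split.group(1)), int(split.group(2))
    return None  # "FP16" (and anything unrecognized) falls back to CacheLayer_fp16

assert parse_cache_mode("Q6") == (6, 6)
assert parse_cache_mode("8,4") == (8, 4)  # asymmetric k/v widths are allowed
assert parse_cache_mode("FP16") is None
```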
@@ -314,7 +342,7 @@ class ExllamaV3Container(BaseModelContainer):
             max_seq_len=self.max_seq_len,
             cache_size=self.cache_size,
             max_batch_size=self.max_batch_size,
-            # cache_mode=self.cache_mode,
+            cache_mode=self.cache_mode,
             chunk_size=self.chunk_size,
             use_vision=self.use_vision,
         )
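
For a rough sense of the payoff: the quantized layers store each key/value element in k_bits/v_bits instead of 16 bits, so a "4,4" cache is roughly 4x smaller than FP16. This is a back-of-the-envelope estimate that ignores any per-block scale overhead the quantized layers may add, not a figure measured from exllamav3:

```python
# Back-of-the-envelope KV-cache footprint per token per layer. Ignores any
# per-block scale/offset overhead the quantized cache layers may add.
def cache_bytes_per_token_layer(num_kv_heads: int, head_dim: int,
                                k_bits: int, v_bits: int) -> float:
    return num_kv_heads * head_dim * (k_bits + v_bits) / 8

# Example with a hypothetical model shape of 8 KV heads x 128 head dims:
print(cache_bytes_per_token_layer(8, 128, 16, 16))  # FP16: 4096.0 bytes
print(cache_bytes_per_token_layer(8, 128, 4, 4))    # "4,4": 1024.0 bytes, ~4x smaller
```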