Mirror of https://github.com/theroyallab/tabbyAPI.git (synced 2026-04-28 02:01:24 +00:00)
Model: Split cache creation into a common function

Unifies the switch statement across both draft and model caches.

Signed-off-by: kingbri <bdashore3@proton.me>
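Both cache setups previously carried their own four-way branch over the cache mode; after this change each path reduces to a single call into the new helper. As a quick orientation before the diff, this is the resulting call pattern, copied from the hunks below (the keyword names and container attributes are exactly those in the diff):

    # Draft cache: loaded via autosplit, so the cache is always lazy
    self.draft_cache = self.create_cache(
        cache_mode=self.draft_cache_mode,
        autosplit=True,
    )

    # Model cache: lazy only when GPU autosplit is enabled
    self.cache = self.create_cache(
        cache_mode=self.cache_mode,
        autosplit=self.gpu_split_auto,
    )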
@@ -548,30 +548,11 @@ class ExllamaV2Container:
             if not self.quiet:
                 logger.info("Loading draft model: " + self.draft_config.model_dir)
 
-            if self.draft_cache_mode == "Q4":
-                self.draft_cache = ExLlamaV2Cache_Q4(
-                    self.draft_model,
-                    max_seq_len=self.cache_size,
-                    lazy=True,
-                )
-            elif self.draft_cache_mode == "Q6":
-                self.draft_cache = ExLlamaV2Cache_Q6(
-                    self.draft_model,
-                    max_seq_len=self.cache_size,
-                    lazy=True,
-                )
-            elif self.draft_cache_mode == "Q8":
-                self.draft_cache = ExLlamaV2Cache_Q8(
-                    self.draft_model,
-                    max_seq_len=self.cache_size,
-                    lazy=True,
-                )
-            else:
-                self.draft_cache = ExLlamaV2Cache(
-                    self.draft_model,
-                    max_seq_len=self.cache_size,
-                    lazy=True,
-                )
+            self.draft_cache = self.create_cache(
+                cache_mode=self.draft_cache_mode,
+                autosplit=True,
+            )
+
             for value in self.draft_model.load_autosplit_gen(
                 self.draft_cache,
                 reserve_vram=autosplit_reserve,
@@ -601,34 +582,10 @@ class ExllamaV2Container:
                 if value:
                     yield value
 
-        if self.cache_mode == "Q4":
-            self.cache = ExLlamaV2Cache_Q4(
-                self.model,
-                max_seq_len=self.cache_size,
-                lazy=self.gpu_split_auto,
-                batch_size=1,
-            )
-        elif self.cache_mode == "Q6":
-            self.cache = ExLlamaV2Cache_Q6(
-                self.model,
-                max_seq_len=self.cache_size,
-                lazy=self.gpu_split_auto,
-                batch_size=1,
-            )
-        elif self.cache_mode == "Q8":
-            self.cache = ExLlamaV2Cache_Q8(
-                self.model,
-                max_seq_len=self.cache_size,
-                lazy=self.gpu_split_auto,
-                batch_size=1,
-            )
-        else:
-            self.cache = ExLlamaV2Cache(
-                self.model,
-                max_seq_len=self.cache_size,
-                lazy=self.gpu_split_auto,
-                batch_size=1,
-            )
+        self.cache = self.create_cache(
+            cache_mode=self.cache_mode,
+            autosplit=self.gpu_split_auto,
+        )
 
         # Load model with autosplit
         if self.gpu_split_auto:
@@ -647,6 +604,37 @@ class ExllamaV2Container:
             input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
             self.model.forward(input_ids, cache=self.cache, preprocess_only=True)
 
+    def create_cache(self, cache_mode: str, autosplit: bool):
+        match cache_mode:
+            case "Q4":
+                return ExLlamaV2Cache_Q4(
+                    self.model,
+                    max_seq_len=self.cache_size,
+                    lazy=autosplit,
+                    batch_size=1,
+                )
+            case "Q6":
+                return ExLlamaV2Cache_Q6(
+                    self.model,
+                    max_seq_len=self.cache_size,
+                    lazy=self.gpu_split_auto,
+                    batch_size=1,
+                )
+            case "Q8":
+                return ExLlamaV2Cache_Q8(
+                    self.model,
+                    max_seq_len=self.cache_size,
+                    lazy=autosplit,
+                    batch_size=1,
+                )
+            case _:
+                return ExLlamaV2Cache(
+                    self.model,
+                    max_seq_len=self.cache_size,
+                    lazy=self.gpu_split_auto,
+                    batch_size=1,
+                )
+
     async def create_generator(self):
         try:
             # Don't acquire locks unless a model is loaded
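Two details of the helper are worth noting. First, the Q6 and default branches pass lazy=self.gpu_split_auto while the Q4 and Q8 branches use the autosplit parameter; the two are equivalent for the model cache (which passes autosplit=self.gpu_split_auto) but not for the draft cache (which passes autosplit=True, matching the removed lazy=True). Second, at least as this diff renders, every branch constructs the cache on self.model, whereas the removed draft branches used self.draft_model. A table-driven variant that parameterizes both is sketched below; it is hypothetical rather than part of this commit, and it assumes lazy=autosplit is the intended behavior for every mode and that the cache classes are importable from exllamav2 as elsewhere in the codebase:

    from exllamav2 import (
        ExLlamaV2Cache,
        ExLlamaV2Cache_Q4,
        ExLlamaV2Cache_Q6,
        ExLlamaV2Cache_Q8,
    )

    # Map each quantized cache mode to its class; any other mode string
    # falls through to the unquantized ExLlamaV2Cache, matching "case _".
    CACHE_CLASSES = {
        "Q4": ExLlamaV2Cache_Q4,
        "Q6": ExLlamaV2Cache_Q6,
        "Q8": ExLlamaV2Cache_Q8,
    }

    def create_cache(self, cache_mode: str, autosplit: bool, model=None):
        # "model" is a hypothetical extra parameter so the draft path can
        # pass its own model; it defaults to the main model.
        cache_class = CACHE_CLASSES.get(cache_mode, ExLlamaV2Cache)
        return cache_class(
            model or self.model,
            max_seq_len=self.cache_size,
            lazy=autosplit,
            batch_size=1,
        )

With that variant, the draft call site would read self.draft_cache = self.create_cache(cache_mode=self.draft_cache_mode, autosplit=True, model=self.draft_model).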