mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-19 22:08:59 +00:00
Model: Add vision loading support
Adds the ability to load the vision parts of text + image models. Requires an explicit flag in config because there isn't a way to automatically determine whether the vision tower should be used.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -20,6 +20,7 @@ from exllamav2 import (
     ExLlamaV2Cache_TP,
     ExLlamaV2Tokenizer,
     ExLlamaV2Lora,
+    ExLlamaV2VisionTower,
 )
 from exllamav2.generator import (
     ExLlamaV2Sampler,
@@ -28,6 +29,7 @@ from exllamav2.generator import (
 )
 from itertools import zip_longest
 from loguru import logger
+from PIL import Image
 from typing import List, Optional, Union
 
 from ruamel.yaml import YAML
@@ -91,6 +93,10 @@ class ExllamaV2Container:
     autosplit_reserve: List[float] = [96 * 1024**2]
     use_tp: bool = False
 
+    # Vision vars
+    use_vision: bool = False
+    vision_model: Optional[ExLlamaV2VisionTower] = None
+
     # Load state
     model_is_loading: bool = False
     model_loaded: bool = False
@@ -144,6 +150,9 @@ class ExllamaV2Container:
         # Apply a model's config overrides while respecting user settings
         kwargs = await self.set_model_overrides(**kwargs)
 
+        # Set vision state
+        self.use_vision = unwrap(kwargs.get("vision"), True)
+
         # Prepare the draft model config if necessary
         draft_args = unwrap(kwargs.get("draft_model"), {})
         draft_model_name = draft_args.get("draft_model_name")
@@ -608,6 +617,14 @@ class ExllamaV2Container:
             input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
             self.draft_model.forward(input_ids, cache=self.cache, preprocess_only=True)
 
+        # Load vision tower if it exists
+        if self.use_vision:
+            self.vision_model = ExLlamaV2VisionTower(self.config)
+
+            for value in self.vision_model.load_gen(callback_gen=progress_callback):
+                if value:
+                    yield value
+
         self.model = ExLlamaV2(self.config)
         if not self.quiet:
             logger.info("Loading model: " + self.config.model_dir)
Reference in New Issue
Block a user