Model: Add vision loading support

Adds the ability to load vision parts of text + image models. Requires
an explicit flag in config because there isn't a way to automatically
determine whether the vision tower should be used.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-11-11 12:04:40 -05:00
parent cc2516790d
commit 69ac0eb8aa
5 changed files with 42 additions and 5 deletions

View File

@@ -20,6 +20,7 @@ from exllamav2 import (
ExLlamaV2Cache_TP,
ExLlamaV2Tokenizer,
ExLlamaV2Lora,
ExLlamaV2VisionTower,
)
from exllamav2.generator import (
ExLlamaV2Sampler,
@@ -28,6 +29,7 @@ from exllamav2.generator import (
)
from itertools import zip_longest
from loguru import logger
from PIL import Image
from typing import List, Optional, Union
from ruamel.yaml import YAML
@@ -91,6 +93,10 @@ class ExllamaV2Container:
autosplit_reserve: List[float] = [96 * 1024**2]
use_tp: bool = False
# Vision vars
use_vision: bool = False
vision_model: Optional[ExLlamaV2VisionTower] = None
# Load state
model_is_loading: bool = False
model_loaded: bool = False
@@ -144,6 +150,9 @@ class ExllamaV2Container:
# Apply a model's config overrides while respecting user settings
kwargs = await self.set_model_overrides(**kwargs)
# Set vision state
self.use_vision = unwrap(kwargs.get("vision"), True)
# Prepare the draft model config if necessary
draft_args = unwrap(kwargs.get("draft_model"), {})
draft_model_name = draft_args.get("draft_model_name")
@@ -608,6 +617,14 @@ class ExllamaV2Container:
input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
self.draft_model.forward(input_ids, cache=self.cache, preprocess_only=True)
# Load vision tower if it exists
if self.use_vision:
self.vision_model = ExLlamaV2VisionTower(self.config)
for value in self.vision_model.load_gen(callback_gen=progress_callback):
if value:
yield value
self.model = ExLlamaV2(self.config)
if not self.quiet:
logger.info("Loading model: " + self.config.model_dir)