Fix llava import

This commit is contained in:
Jaret Burkett
2023-10-29 12:50:54 -06:00
parent 48a9bac22d
commit b84e3260cb

View File

@@ -1,9 +1,3 @@
try:
from llava.model import LlavaLlamaForCausalLM
except ImportError:
# print("You need to manually install llava -> pip install --no-deps git+https://github.com/haotian-liu/LLaVA.git")
print("You need to manually install llava -> pip install --no-deps git+https://github.com/haotian-liu/LLaVA.git")
raise
from .caption import default_long_prompt, default_short_prompt, default_replacements, clean_caption
@@ -20,6 +14,13 @@ img_ext = ['.jpg', '.jpeg', '.png', '.webp']
class LLaVAImageProcessor:
def __init__(self, device='cuda'):
try:
from llava.model import LlavaLlamaForCausalLM
except ImportError:
# print("You need to manually install llava -> pip install --no-deps git+https://github.com/haotian-liu/LLaVA.git")
print(
"You need to manually install llava -> pip install --no-deps git+https://github.com/haotian-liu/LLaVA.git")
raise
self.device = device
self.model: LlavaLlamaForCausalLM = None
self.tokenizer: AutoTokenizer = None