Added fuyu captioning

This commit is contained in:
Jaret Burkett
2023-10-25 14:14:53 -06:00
parent d742792ee4
commit 9636194c09
10 changed files with 240 additions and 87 deletions

View File

@@ -3,15 +3,16 @@ import json
import os
from collections import OrderedDict
import gc
from typing import Type, Literal
import traceback
import torch
from PIL import Image, ImageOps
from tqdm import tqdm
from .dataset_tools_config_modules import RAW_DIR, TRAIN_DIR, Step, ImgInfo
from .tools.image_tools import load_image
from .tools.llava_utils import LLaVAImageProcessor, long_prompt, short_prompt
from .tools.dataset_tools_config_modules import RAW_DIR, TRAIN_DIR, Step, ImgInfo
from .tools.fuyu_utils import FuyuImageProcessor
from .tools.image_tools import load_image, ImageProcessor, resize_to_max
from .tools.llava_utils import LLaVAImageProcessor
from .tools.caption import default_long_prompt, default_short_prompt
from jobs.process import BaseExtensionProcess
from .tools.sync_tools import get_img_paths
@@ -23,7 +24,7 @@ def flush():
gc.collect()
VERSION = 1
VERSION = 2
class SuperTagger(BaseExtensionProcess):
@@ -34,7 +35,9 @@ class SuperTagger(BaseExtensionProcess):
self.dataset_paths: list[str] = config.get('dataset_paths', [])
self.device = config.get('device', 'cuda')
self.steps: list[Step] = config.get('steps', [])
self.caption_method = config.get('caption_method', 'llava')
self.caption_method = config.get('caption_method', 'llava:default')
self.caption_prompt = config.get('caption_prompt', default_long_prompt)
self.caption_short_prompt = config.get('caption_short_prompt', default_short_prompt)
self.force_reprocess_img = config.get('force_reprocess_img', False)
self.master_dataset_dict = OrderedDict()
self.dataset_master_config_file = config.get('dataset_master_config_file', None)
@@ -53,7 +56,15 @@ class SuperTagger(BaseExtensionProcess):
print(f"Found {len(self.dataset_paths)} dataset paths")
self.image_processor = LLaVAImageProcessor(device=self.device)
self.image_processor: ImageProcessor = self.get_image_processor()
def get_image_processor(self):
if self.caption_method.startswith('llava'):
return LLaVAImageProcessor(device=self.device)
elif self.caption_method.startswith('fuyu'):
return FuyuImageProcessor(device=self.device)
else:
raise ValueError(f"Unknown caption method: {self.caption_method}")
def process_image(self, img_path: str):
root_img_dir = os.path.dirname(os.path.dirname(img_path))
@@ -70,18 +81,19 @@ class SuperTagger(BaseExtensionProcess):
else:
img_info = ImgInfo()
img_info.set_version(VERSION)
# send steps to img info so it can store them
# always send steps first in case other processes need them
img_info.add_steps(copy.deepcopy(self.steps))
img_info.set_version(VERSION)
img_info.set_caption_method(self.caption_method)
image: Image = None
caption_image: Image = None
did_update_image = False
# trigger reprocess of steps
if self.force_reprocess_img:
img_info.trigger_image_reprocess(steps=self.steps)
img_info.trigger_image_reprocess()
# set the image as updated if it does not exist on disk
if not os.path.exists(train_img_path):
@@ -91,29 +103,39 @@ class SuperTagger(BaseExtensionProcess):
did_update_image = True
image = load_image(img_path)
# go through the needed steps
for step in img_info.state.steps_to_complete:
for step in copy.deepcopy(img_info.state.steps_to_complete):
if step == 'caption':
# load image
if image is None:
image = load_image(img_path)
if caption_image is None:
caption_image = resize_to_max(image, 1024, 1024)
if not self.image_processor.is_loaded:
print('Loading Model. Takes a while, especially the first time')
self.image_processor.load_model()
img_info.caption = self.image_processor.generate_caption(image, prompt=long_prompt)
img_info.caption = self.image_processor.generate_caption(
image=caption_image,
prompt=self.caption_prompt
)
img_info.mark_step_complete(step)
elif step == 'caption_short':
# load image
if image is None:
image = load_image(img_path)
if caption_image is None:
caption_image = resize_to_max(image, 1024, 1024)
if not self.image_processor.is_loaded:
print('Loading Model. Takes a while, especially the first time')
self.image_processor.load_model()
img_info.caption_short = self.image_processor.generate_caption(image, prompt=short_prompt)
img_info.caption_short = self.image_processor.generate_caption(
image=caption_image,
prompt=self.caption_short_prompt
)
img_info.mark_step_complete(step)
elif step == 'contrast_stretch':
# load image
@@ -153,12 +175,13 @@ class SuperTagger(BaseExtensionProcess):
print(f"Found {len(imgs_to_process)} to process")
for img_path in tqdm(imgs_to_process, desc="Processing images"):
# try:
# self.process_image(img_path)
# except Exception as e:
# print(f"Error processing {img_path}: {e}")
# continue
self.process_image(img_path)
try:
self.process_image(img_path)
except Exception:
# print full stack trace
print(traceback.format_exc())
continue
# self.process_image(img_path)
if self.dataset_master_config_file is not None:
# save it as json

View File

@@ -7,7 +7,7 @@ from typing import List
import torch
from tqdm import tqdm
from .dataset_tools_config_modules import DatasetSyncCollectionConfig, RAW_DIR, NEW_DIR
from .tools.dataset_tools_config_modules import DatasetSyncCollectionConfig, RAW_DIR, NEW_DIR
from .tools.sync_tools import get_unsplash_images, get_pexels_images, get_local_image_file_names, download_image, \
get_img_paths
from jobs.process import BaseExtensionProcess

View File

@@ -0,0 +1,47 @@
caption_manipulation_steps = ['caption', 'caption_short']
default_long_prompt = 'caption this image. describe every single thing in the image in detail. Do not include any unnecessary words in your description for the sake of good grammar. I want many short statements that serve the single purpose of giving the most thorough description if items as possible in the smallest, comma separated way possible. be sure to describe people\'s moods, clothing, the environment, lighting, colors, and everything.'
default_short_prompt = 'caption this image in less than ten words'
default_replacements = [
("the image features", ""),
("the image shows", ""),
("the image depicts", ""),
("the image is", ""),
("in this image", ""),
("in the image", ""),
]
def clean_caption(cap, replacements=None):
if replacements is None:
replacements = default_replacements
# remove any newlines
cap = cap.replace("\n", ", ")
cap = cap.replace("\r", ", ")
cap = cap.replace(".", ",")
cap = cap.replace("\"", "")
# remove unicode characters
cap = cap.encode('ascii', 'ignore').decode('ascii')
# make lowercase
cap = cap.lower()
# remove any extra spaces
cap = " ".join(cap.split())
for replacement in replacements:
cap = cap.replace(replacement[0], replacement[1])
cap_list = cap.split(",")
# trim whitespace
cap_list = [c.strip() for c in cap_list]
# remove empty strings
cap_list = [c for c in cap_list if c != ""]
# remove duplicates
cap_list = list(dict.fromkeys(cap_list))
# join back together
cap = ", ".join(cap_list)
return cap

View File

@@ -8,7 +8,8 @@ NEW_DIR = "_tmp"
TRAIN_DIR = "train"
DEPTH_DIR = "depth"
from .tools.image_tools import Step, img_manipulation_steps
from .image_tools import Step, img_manipulation_steps
from .caption import caption_manipulation_steps
class DatasetSyncCollectionConfig:
@@ -64,8 +65,11 @@ class ImgInfo:
self.caption_short: str = kwargs.get('caption_short', None)
self.poi = [Rect(**poi) for poi in kwargs.get('poi', [])]
self.state = ImageState(**kwargs.get('state', {}))
self.caption_method = kwargs.get('caption_method', None)
self.other_captions = kwargs.get('other_captions', {})
self._upgrade_state()
self.force_image_process: bool = False
self._requested_steps: list[Step] = []
self.is_dirty: bool = False
@@ -77,14 +81,20 @@ class ImgInfo:
if self.caption_short is not None and 'caption_short' not in self.state.steps_complete:
self.mark_step_complete('caption_short')
self.is_dirty = True
if self.caption_method is None and self.caption is not None:
# added caption method in version 2. Was all llava before that
self.caption_method = 'llava:default'
self.is_dirty = True
def to_dict(self):
return {
'version': self.version,
'caption_method': self.caption_method,
'caption': self.caption,
'caption_short': self.caption_short,
'poi': [poi.to_dict() for poi in self.poi],
'state': self.state.to_dict()
'state': self.state.to_dict(),
'other_captions': self.other_captions
}
def mark_step_complete(self, step: Step):
@@ -98,7 +108,10 @@ class ImgInfo:
if step not in self.state.steps_to_complete and step not in self.state.steps_complete:
self.state.steps_to_complete.append(step)
def trigger_image_reprocess(self, steps):
def trigger_image_reprocess(self):
if self._requested_steps is None:
raise Exception("Must call add_steps before trigger_image_reprocess")
steps = self._requested_steps
# remove all image manipulationf from steps_to_complete
for step in img_manipulation_steps:
if step in self.state.steps_to_complete:
@@ -112,14 +125,13 @@ class ImgInfo:
if step in img_manipulation_steps:
self.add_step(step)
def add_steps(self, steps: list[Step]):
self._requested_steps = [step for step in steps]
for stage in steps:
self.add_step(stage)
# update steps if we have any img processes not complete, we have to reprocess them all
# if any steps_to_complete are in img_manipulation_steps
# TODO check if they are in a new order now ands trigger a redo
is_manipulating_image = any([step in img_manipulation_steps for step in self.state.steps_to_complete])
order_has_changed = False
@@ -133,8 +145,38 @@ class ImgInfo:
order_has_changed = True
if is_manipulating_image or order_has_changed:
self.trigger_image_reprocess(steps)
self.trigger_image_reprocess()
def set_caption_method(self, method: str):
if self._requested_steps is None:
raise Exception("Must call add_steps before set_caption_method")
if self.caption_method != method:
self.is_dirty = True
# move previous caption method to other_captions
if self.caption_method is not None and self.caption is not None or self.caption_short is not None:
self.other_captions[self.caption_method] = {
'caption': self.caption,
'caption_short': self.caption_short,
}
self.caption_method = method
self.caption = None
self.caption_short = None
# see if we have a caption from the new method
if method in self.other_captions:
self.caption = self.other_captions[method].get('caption', None)
self.caption_short = self.other_captions[method].get('caption_short', None)
else:
self.trigger_new_caption()
def trigger_new_caption(self):
self.caption = None
self.caption_short = None
self.is_dirty = True
# check to see if we have any steps in the complete list and move them to the to_complete list
for step in self.state.steps_complete:
if step in caption_manipulation_steps:
self.state.steps_complete.remove(step)
self.state.steps_to_complete.append(step)
def to_json(self):
return json.dumps(self.to_dict())

View File

@@ -0,0 +1,66 @@
from transformers import CLIPImageProcessor, BitsAndBytesConfig, AutoTokenizer
from .caption import default_long_prompt, default_short_prompt, default_replacements, clean_caption
import torch
from PIL import Image
class FuyuImageProcessor:
def __init__(self, device='cuda'):
from transformers import FuyuProcessor, FuyuForCausalLM
self.device = device
self.model: FuyuForCausalLM = None
self.processor: FuyuProcessor = None
self.dtype = torch.bfloat16
self.tokenizer: AutoTokenizer
self.is_loaded = False
def load_model(self):
from transformers import FuyuProcessor, FuyuForCausalLM
model_path = "adept/fuyu-8b"
kwargs = {"device_map": self.device}
kwargs['load_in_4bit'] = True
kwargs['quantization_config'] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=self.dtype,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type='nf4'
)
self.processor = FuyuProcessor.from_pretrained(model_path)
self.model = FuyuForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
self.is_loaded = True
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.model = FuyuForCausalLM.from_pretrained(model_path, torch_dtype=self.dtype, **kwargs)
self.processor = FuyuProcessor(image_processor=FuyuImageProcessor(), tokenizer=self.tokenizer)
def generate_caption(
self, image: Image,
prompt: str = default_long_prompt,
replacements=default_replacements,
max_new_tokens=512
):
# prepare inputs for the model
# text_prompt = f"{prompt}\n"
# image = image.convert('RGB')
model_inputs = self.processor(text=prompt, images=[image])
model_inputs = {k: v.to(dtype=self.dtype if torch.is_floating_point(v) else v.dtype, device=self.device) for k, v in
model_inputs.items()}
generation_output = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens)
prompt_len = model_inputs["input_ids"].shape[-1]
output = self.tokenizer.decode(generation_output[0][prompt_len:], skip_special_tokens=True)
output = clean_caption(output, replacements=replacements)
return output
# inputs = self.processor(text=text_prompt, images=image, return_tensors="pt")
# for k, v in inputs.items():
# inputs[k] = v.to(self.device)
# # autoregressively generate text
# generation_output = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
# generation_text = self.processor.batch_decode(generation_output[:, -max_new_tokens:], skip_special_tokens=True)
# output = generation_text[0]
#
# return clean_caption(output, replacements=replacements)

View File

@@ -1,4 +1,4 @@
from typing import Literal, Type
from typing import Literal, Type, TYPE_CHECKING, Union
import cv2
import numpy as np
@@ -8,6 +8,14 @@ Step: Type = Literal['caption', 'caption_short', 'create_mask', 'contrast_stretc
img_manipulation_steps = ['contrast_stretch']
img_ext = ['.jpg', '.jpeg', '.png', '.webp']
if TYPE_CHECKING:
from .llava_utils import LLaVAImageProcessor
from .fuyu_utils import FuyuImageProcessor
ImageProcessor = Union['LLaVAImageProcessor', 'FuyuImageProcessor']
def pil_to_cv2(image):
"""Convert a PIL image to a cv2 image."""
@@ -18,6 +26,7 @@ def cv2_to_pil(image):
"""Convert a cv2 image to a PIL image."""
return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
def load_image(img_path: str):
image = Image.open(img_path).convert('RGB')
try:
@@ -27,3 +36,14 @@ def load_image(img_path: str):
pass
return image
def resize_to_max(image, max_width=1024, max_height=1024):
width, height = image.size
if width <= max_width and height <= max_height:
return image
scale = min(max_width / width, max_height / height)
width = int(width * scale)
height = int(height * scale)
return image.resize((width, height), Image.LANCZOS)

View File

@@ -5,20 +5,7 @@ except ImportError:
print("You need to manually install llava -> pip install --no-deps git+https://github.com/haotian-liu/LLaVA.git")
raise
long_prompt = 'caption this image. describe every single thing in the image in detail. Do not include any unnecessary words in your description for the sake of good grammar. I want many short statements that serve the single purpose of giving the most thorough description if items as possible in the smallest, comma separated way possible. be sure to describe people\'s moods, clothing, the environment, lighting, colors, and everything.'
short_prompt = 'caption this image in less than ten words'
prompts = [
long_prompt,
short_prompt,
]
replacements = [
("the image features", ""),
("the image shows", ""),
("the image depicts", ""),
("the image is", ""),
]
from .caption import default_long_prompt, default_short_prompt, default_replacements, clean_caption
import torch
from PIL import Image, ImageOps
@@ -61,36 +48,12 @@ class LLaVAImageProcessor:
self.image_processor = vision_tower.image_processor
self.is_loaded = True
def clean_caption(self, cap):
# remove any newlines
cap = cap.replace("\n", ", ")
cap = cap.replace("\r", ", ")
cap = cap.replace(".", ",")
cap = cap.replace("\"", "")
# remove unicode characters
cap = cap.encode('ascii', 'ignore').decode('ascii')
# make lowercase
cap = cap.lower()
# remove any extra spaces
cap = " ".join(cap.split())
for replacement in replacements:
cap = cap.replace(replacement[0], replacement[1])
cap_list = cap.split(",")
# trim whitespace
cap_list = [c.strip() for c in cap_list]
# remove empty strings
cap_list = [c for c in cap_list if c != ""]
# remove duplicates
cap_list = list(dict.fromkeys(cap_list))
# join back together
cap = ", ".join(cap_list)
return cap
def generate_caption(self, image: Image, prompt: str = long_prompt):
def generate_caption(
self, image:
Image, prompt: str = default_long_prompt,
replacements=default_replacements,
max_new_tokens=512
):
# question = "how many dogs are in the picture?"
disable_torch_init()
conv_mode = "llava_v0"
@@ -111,19 +74,10 @@ class LLaVAImageProcessor:
with torch.inference_mode():
output_ids = self.model.generate(
input_ids, images=image_tensor, do_sample=True, temperature=0.1,
max_new_tokens=1024, use_cache=True, stopping_criteria=[stopping_criteria],
max_new_tokens=max_new_tokens, use_cache=True, stopping_criteria=[stopping_criteria],
top_p=0.9
)
outputs = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
conv.messages[-1][-1] = outputs
output = outputs.rsplit('</s>', 1)[0]
return self.clean_caption(output)
def generate_captions(self, image: Image):
responses = []
for prompt in prompts:
output = self.generate_caption(image, prompt)
responses.append(output)
# replace all . with ,
return responses
return clean_caption(output, replacements=replacements)

View File

@@ -9,7 +9,7 @@ def img_root_path(img_id: str):
if TYPE_CHECKING:
from ..dataset_tools_config_modules import DatasetSyncCollectionConfig
from .dataset_tools_config_modules import DatasetSyncCollectionConfig
img_exts = ['.jpg', '.jpeg', '.webp', '.png']

View File

@@ -2,7 +2,7 @@ torch
torchvision
safetensors
diffusers==0.21.3
transformers
git+https://github.com/huggingface/transformers.git@master
lycoris-lora==1.8.3
flatten_json
pyyaml

View File

@@ -176,6 +176,7 @@ class StableDiffusion:
dtype=dtype,
device=self.device_torch,
variant="fp16",
use_safetensors=True,
**load_args
)
else: