mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Added fuyu captioning
This commit is contained in:
47
extensions_built_in/dataset_tools/tools/caption.py
Normal file
47
extensions_built_in/dataset_tools/tools/caption.py
Normal file
@@ -0,0 +1,47 @@
|
||||
|
||||
caption_manipulation_steps = ['caption', 'caption_short']
|
||||
|
||||
default_long_prompt = 'caption this image. describe every single thing in the image in detail. Do not include any unnecessary words in your description for the sake of good grammar. I want many short statements that serve the single purpose of giving the most thorough description if items as possible in the smallest, comma separated way possible. be sure to describe people\'s moods, clothing, the environment, lighting, colors, and everything.'
|
||||
default_short_prompt = 'caption this image in less than ten words'
|
||||
|
||||
default_replacements = [
|
||||
("the image features", ""),
|
||||
("the image shows", ""),
|
||||
("the image depicts", ""),
|
||||
("the image is", ""),
|
||||
("in this image", ""),
|
||||
("in the image", ""),
|
||||
]
|
||||
|
||||
|
||||
def clean_caption(cap, replacements=None):
|
||||
if replacements is None:
|
||||
replacements = default_replacements
|
||||
|
||||
# remove any newlines
|
||||
cap = cap.replace("\n", ", ")
|
||||
cap = cap.replace("\r", ", ")
|
||||
cap = cap.replace(".", ",")
|
||||
cap = cap.replace("\"", "")
|
||||
|
||||
# remove unicode characters
|
||||
cap = cap.encode('ascii', 'ignore').decode('ascii')
|
||||
|
||||
# make lowercase
|
||||
cap = cap.lower()
|
||||
# remove any extra spaces
|
||||
cap = " ".join(cap.split())
|
||||
|
||||
for replacement in replacements:
|
||||
cap = cap.replace(replacement[0], replacement[1])
|
||||
|
||||
cap_list = cap.split(",")
|
||||
# trim whitespace
|
||||
cap_list = [c.strip() for c in cap_list]
|
||||
# remove empty strings
|
||||
cap_list = [c for c in cap_list if c != ""]
|
||||
# remove duplicates
|
||||
cap_list = list(dict.fromkeys(cap_list))
|
||||
# join back together
|
||||
cap = ", ".join(cap_list)
|
||||
return cap
|
||||
@@ -0,0 +1,187 @@
|
||||
import json
|
||||
from typing import Literal, Type, TYPE_CHECKING
|
||||
|
||||
Host: Type = Literal['unsplash', 'pexels']
|
||||
|
||||
RAW_DIR = "raw"
|
||||
NEW_DIR = "_tmp"
|
||||
TRAIN_DIR = "train"
|
||||
DEPTH_DIR = "depth"
|
||||
|
||||
from .image_tools import Step, img_manipulation_steps
|
||||
from .caption import caption_manipulation_steps
|
||||
|
||||
|
||||
class DatasetSyncCollectionConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.host: Host = kwargs.get('host', None)
|
||||
self.collection_id: str = kwargs.get('collection_id', None)
|
||||
self.directory: str = kwargs.get('directory', None)
|
||||
self.api_key: str = kwargs.get('api_key', None)
|
||||
self.min_width: int = kwargs.get('min_width', 1024)
|
||||
self.min_height: int = kwargs.get('min_height', 1024)
|
||||
|
||||
if self.host is None:
|
||||
raise ValueError("host is required")
|
||||
if self.collection_id is None:
|
||||
raise ValueError("collection_id is required")
|
||||
if self.directory is None:
|
||||
raise ValueError("directory is required")
|
||||
if self.api_key is None:
|
||||
raise ValueError(f"api_key is required: {self.host}:{self.collection_id}")
|
||||
|
||||
|
||||
class ImageState:
|
||||
def __init__(self, **kwargs):
|
||||
self.steps_complete: list[Step] = kwargs.get('steps_complete', [])
|
||||
self.steps_to_complete: list[Step] = kwargs.get('steps_to_complete', [])
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'steps_complete': self.steps_complete
|
||||
}
|
||||
|
||||
|
||||
class Rect:
|
||||
def __init__(self, **kwargs):
|
||||
self.x = kwargs.get('x', 0)
|
||||
self.y = kwargs.get('y', 0)
|
||||
self.width = kwargs.get('width', 0)
|
||||
self.height = kwargs.get('height', 0)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'x': self.x,
|
||||
'y': self.y,
|
||||
'width': self.width,
|
||||
'height': self.height
|
||||
}
|
||||
|
||||
|
||||
class ImgInfo:
|
||||
def __init__(self, **kwargs):
|
||||
self.version: int = kwargs.get('version', None)
|
||||
self.caption: str = kwargs.get('caption', None)
|
||||
self.caption_short: str = kwargs.get('caption_short', None)
|
||||
self.poi = [Rect(**poi) for poi in kwargs.get('poi', [])]
|
||||
self.state = ImageState(**kwargs.get('state', {}))
|
||||
self.caption_method = kwargs.get('caption_method', None)
|
||||
self.other_captions = kwargs.get('other_captions', {})
|
||||
self._upgrade_state()
|
||||
self.force_image_process: bool = False
|
||||
self._requested_steps: list[Step] = []
|
||||
|
||||
self.is_dirty: bool = False
|
||||
|
||||
def _upgrade_state(self):
|
||||
# upgrades older states
|
||||
if self.caption is not None and 'caption' not in self.state.steps_complete:
|
||||
self.mark_step_complete('caption')
|
||||
self.is_dirty = True
|
||||
if self.caption_short is not None and 'caption_short' not in self.state.steps_complete:
|
||||
self.mark_step_complete('caption_short')
|
||||
self.is_dirty = True
|
||||
if self.caption_method is None and self.caption is not None:
|
||||
# added caption method in version 2. Was all llava before that
|
||||
self.caption_method = 'llava:default'
|
||||
self.is_dirty = True
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'version': self.version,
|
||||
'caption_method': self.caption_method,
|
||||
'caption': self.caption,
|
||||
'caption_short': self.caption_short,
|
||||
'poi': [poi.to_dict() for poi in self.poi],
|
||||
'state': self.state.to_dict(),
|
||||
'other_captions': self.other_captions
|
||||
}
|
||||
|
||||
def mark_step_complete(self, step: Step):
|
||||
if step not in self.state.steps_complete:
|
||||
self.state.steps_complete.append(step)
|
||||
if step in self.state.steps_to_complete:
|
||||
self.state.steps_to_complete.remove(step)
|
||||
self.is_dirty = True
|
||||
|
||||
def add_step(self, step: Step):
|
||||
if step not in self.state.steps_to_complete and step not in self.state.steps_complete:
|
||||
self.state.steps_to_complete.append(step)
|
||||
|
||||
def trigger_image_reprocess(self):
|
||||
if self._requested_steps is None:
|
||||
raise Exception("Must call add_steps before trigger_image_reprocess")
|
||||
steps = self._requested_steps
|
||||
# remove all image manipulationf from steps_to_complete
|
||||
for step in img_manipulation_steps:
|
||||
if step in self.state.steps_to_complete:
|
||||
self.state.steps_to_complete.remove(step)
|
||||
if step in self.state.steps_complete:
|
||||
self.state.steps_complete.remove(step)
|
||||
self.force_image_process = True
|
||||
self.is_dirty = True
|
||||
# we want to keep the order passed in process file
|
||||
for step in steps:
|
||||
if step in img_manipulation_steps:
|
||||
self.add_step(step)
|
||||
|
||||
def add_steps(self, steps: list[Step]):
|
||||
self._requested_steps = [step for step in steps]
|
||||
for stage in steps:
|
||||
self.add_step(stage)
|
||||
|
||||
# update steps if we have any img processes not complete, we have to reprocess them all
|
||||
# if any steps_to_complete are in img_manipulation_steps
|
||||
|
||||
is_manipulating_image = any([step in img_manipulation_steps for step in self.state.steps_to_complete])
|
||||
order_has_changed = False
|
||||
|
||||
if not is_manipulating_image:
|
||||
# check to see if order has changed. No need to if already redoing it. Will detect if ones are removed
|
||||
target_img_manipulation_order = [step for step in steps if step in img_manipulation_steps]
|
||||
current_img_manipulation_order = [step for step in self.state.steps_complete if
|
||||
step in img_manipulation_steps]
|
||||
if target_img_manipulation_order != current_img_manipulation_order:
|
||||
order_has_changed = True
|
||||
|
||||
if is_manipulating_image or order_has_changed:
|
||||
self.trigger_image_reprocess()
|
||||
|
||||
def set_caption_method(self, method: str):
|
||||
if self._requested_steps is None:
|
||||
raise Exception("Must call add_steps before set_caption_method")
|
||||
if self.caption_method != method:
|
||||
self.is_dirty = True
|
||||
# move previous caption method to other_captions
|
||||
if self.caption_method is not None and self.caption is not None or self.caption_short is not None:
|
||||
self.other_captions[self.caption_method] = {
|
||||
'caption': self.caption,
|
||||
'caption_short': self.caption_short,
|
||||
}
|
||||
self.caption_method = method
|
||||
self.caption = None
|
||||
self.caption_short = None
|
||||
# see if we have a caption from the new method
|
||||
if method in self.other_captions:
|
||||
self.caption = self.other_captions[method].get('caption', None)
|
||||
self.caption_short = self.other_captions[method].get('caption_short', None)
|
||||
else:
|
||||
self.trigger_new_caption()
|
||||
|
||||
def trigger_new_caption(self):
|
||||
self.caption = None
|
||||
self.caption_short = None
|
||||
self.is_dirty = True
|
||||
# check to see if we have any steps in the complete list and move them to the to_complete list
|
||||
for step in self.state.steps_complete:
|
||||
if step in caption_manipulation_steps:
|
||||
self.state.steps_complete.remove(step)
|
||||
self.state.steps_to_complete.append(step)
|
||||
|
||||
def to_json(self):
|
||||
return json.dumps(self.to_dict())
|
||||
|
||||
def set_version(self, version: int):
|
||||
if self.version != version:
|
||||
self.is_dirty = True
|
||||
self.version = version
|
||||
66
extensions_built_in/dataset_tools/tools/fuyu_utils.py
Normal file
66
extensions_built_in/dataset_tools/tools/fuyu_utils.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from transformers import CLIPImageProcessor, BitsAndBytesConfig, AutoTokenizer
|
||||
|
||||
from .caption import default_long_prompt, default_short_prompt, default_replacements, clean_caption
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class FuyuImageProcessor:
|
||||
def __init__(self, device='cuda'):
|
||||
from transformers import FuyuProcessor, FuyuForCausalLM
|
||||
self.device = device
|
||||
self.model: FuyuForCausalLM = None
|
||||
self.processor: FuyuProcessor = None
|
||||
self.dtype = torch.bfloat16
|
||||
self.tokenizer: AutoTokenizer
|
||||
self.is_loaded = False
|
||||
|
||||
def load_model(self):
|
||||
from transformers import FuyuProcessor, FuyuForCausalLM
|
||||
model_path = "adept/fuyu-8b"
|
||||
kwargs = {"device_map": self.device}
|
||||
kwargs['load_in_4bit'] = True
|
||||
kwargs['quantization_config'] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=self.dtype,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type='nf4'
|
||||
)
|
||||
self.processor = FuyuProcessor.from_pretrained(model_path)
|
||||
self.model = FuyuForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
||||
self.is_loaded = True
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
self.model = FuyuForCausalLM.from_pretrained(model_path, torch_dtype=self.dtype, **kwargs)
|
||||
self.processor = FuyuProcessor(image_processor=FuyuImageProcessor(), tokenizer=self.tokenizer)
|
||||
|
||||
def generate_caption(
|
||||
self, image: Image,
|
||||
prompt: str = default_long_prompt,
|
||||
replacements=default_replacements,
|
||||
max_new_tokens=512
|
||||
):
|
||||
# prepare inputs for the model
|
||||
# text_prompt = f"{prompt}\n"
|
||||
|
||||
# image = image.convert('RGB')
|
||||
model_inputs = self.processor(text=prompt, images=[image])
|
||||
model_inputs = {k: v.to(dtype=self.dtype if torch.is_floating_point(v) else v.dtype, device=self.device) for k, v in
|
||||
model_inputs.items()}
|
||||
|
||||
generation_output = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens)
|
||||
prompt_len = model_inputs["input_ids"].shape[-1]
|
||||
output = self.tokenizer.decode(generation_output[0][prompt_len:], skip_special_tokens=True)
|
||||
output = clean_caption(output, replacements=replacements)
|
||||
return output
|
||||
|
||||
# inputs = self.processor(text=text_prompt, images=image, return_tensors="pt")
|
||||
# for k, v in inputs.items():
|
||||
# inputs[k] = v.to(self.device)
|
||||
|
||||
# # autoregressively generate text
|
||||
# generation_output = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
|
||||
# generation_text = self.processor.batch_decode(generation_output[:, -max_new_tokens:], skip_special_tokens=True)
|
||||
# output = generation_text[0]
|
||||
#
|
||||
# return clean_caption(output, replacements=replacements)
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Literal, Type
|
||||
from typing import Literal, Type, TYPE_CHECKING, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@@ -8,6 +8,14 @@ Step: Type = Literal['caption', 'caption_short', 'create_mask', 'contrast_stretc
|
||||
|
||||
img_manipulation_steps = ['contrast_stretch']
|
||||
|
||||
img_ext = ['.jpg', '.jpeg', '.png', '.webp']
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .llava_utils import LLaVAImageProcessor
|
||||
from .fuyu_utils import FuyuImageProcessor
|
||||
|
||||
ImageProcessor = Union['LLaVAImageProcessor', 'FuyuImageProcessor']
|
||||
|
||||
|
||||
def pil_to_cv2(image):
|
||||
"""Convert a PIL image to a cv2 image."""
|
||||
@@ -18,6 +26,7 @@ def cv2_to_pil(image):
|
||||
"""Convert a cv2 image to a PIL image."""
|
||||
return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
||||
|
||||
|
||||
def load_image(img_path: str):
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
try:
|
||||
@@ -27,3 +36,14 @@ def load_image(img_path: str):
|
||||
pass
|
||||
return image
|
||||
|
||||
|
||||
def resize_to_max(image, max_width=1024, max_height=1024):
|
||||
width, height = image.size
|
||||
if width <= max_width and height <= max_height:
|
||||
return image
|
||||
|
||||
scale = min(max_width / width, max_height / height)
|
||||
width = int(width * scale)
|
||||
height = int(height * scale)
|
||||
|
||||
return image.resize((width, height), Image.LANCZOS)
|
||||
|
||||
@@ -5,20 +5,7 @@ except ImportError:
|
||||
print("You need to manually install llava -> pip install --no-deps git+https://github.com/haotian-liu/LLaVA.git")
|
||||
raise
|
||||
|
||||
long_prompt = 'caption this image. describe every single thing in the image in detail. Do not include any unnecessary words in your description for the sake of good grammar. I want many short statements that serve the single purpose of giving the most thorough description if items as possible in the smallest, comma separated way possible. be sure to describe people\'s moods, clothing, the environment, lighting, colors, and everything.'
|
||||
short_prompt = 'caption this image in less than ten words'
|
||||
|
||||
prompts = [
|
||||
long_prompt,
|
||||
short_prompt,
|
||||
]
|
||||
|
||||
replacements = [
|
||||
("the image features", ""),
|
||||
("the image shows", ""),
|
||||
("the image depicts", ""),
|
||||
("the image is", ""),
|
||||
]
|
||||
from .caption import default_long_prompt, default_short_prompt, default_replacements, clean_caption
|
||||
|
||||
import torch
|
||||
from PIL import Image, ImageOps
|
||||
@@ -61,36 +48,12 @@ class LLaVAImageProcessor:
|
||||
self.image_processor = vision_tower.image_processor
|
||||
self.is_loaded = True
|
||||
|
||||
def clean_caption(self, cap):
|
||||
# remove any newlines
|
||||
cap = cap.replace("\n", ", ")
|
||||
cap = cap.replace("\r", ", ")
|
||||
cap = cap.replace(".", ",")
|
||||
cap = cap.replace("\"", "")
|
||||
|
||||
# remove unicode characters
|
||||
cap = cap.encode('ascii', 'ignore').decode('ascii')
|
||||
|
||||
# make lowercase
|
||||
cap = cap.lower()
|
||||
# remove any extra spaces
|
||||
cap = " ".join(cap.split())
|
||||
|
||||
for replacement in replacements:
|
||||
cap = cap.replace(replacement[0], replacement[1])
|
||||
|
||||
cap_list = cap.split(",")
|
||||
# trim whitespace
|
||||
cap_list = [c.strip() for c in cap_list]
|
||||
# remove empty strings
|
||||
cap_list = [c for c in cap_list if c != ""]
|
||||
# remove duplicates
|
||||
cap_list = list(dict.fromkeys(cap_list))
|
||||
# join back together
|
||||
cap = ", ".join(cap_list)
|
||||
return cap
|
||||
|
||||
def generate_caption(self, image: Image, prompt: str = long_prompt):
|
||||
def generate_caption(
|
||||
self, image:
|
||||
Image, prompt: str = default_long_prompt,
|
||||
replacements=default_replacements,
|
||||
max_new_tokens=512
|
||||
):
|
||||
# question = "how many dogs are in the picture?"
|
||||
disable_torch_init()
|
||||
conv_mode = "llava_v0"
|
||||
@@ -111,19 +74,10 @@ class LLaVAImageProcessor:
|
||||
with torch.inference_mode():
|
||||
output_ids = self.model.generate(
|
||||
input_ids, images=image_tensor, do_sample=True, temperature=0.1,
|
||||
max_new_tokens=1024, use_cache=True, stopping_criteria=[stopping_criteria],
|
||||
max_new_tokens=max_new_tokens, use_cache=True, stopping_criteria=[stopping_criteria],
|
||||
top_p=0.9
|
||||
)
|
||||
outputs = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
|
||||
conv.messages[-1][-1] = outputs
|
||||
output = outputs.rsplit('</s>', 1)[0]
|
||||
return self.clean_caption(output)
|
||||
|
||||
def generate_captions(self, image: Image):
|
||||
|
||||
responses = []
|
||||
for prompt in prompts:
|
||||
output = self.generate_caption(image, prompt)
|
||||
responses.append(output)
|
||||
# replace all . with ,
|
||||
return responses
|
||||
return clean_caption(output, replacements=replacements)
|
||||
|
||||
@@ -9,7 +9,7 @@ def img_root_path(img_id: str):
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..dataset_tools_config_modules import DatasetSyncCollectionConfig
|
||||
from .dataset_tools_config_modules import DatasetSyncCollectionConfig
|
||||
|
||||
img_exts = ['.jpg', '.jpeg', '.webp', '.png']
|
||||
|
||||
|
||||
Reference in New Issue
Block a user