diff --git a/extensions_built_in/dataset_tools/SuperTagger.py b/extensions_built_in/dataset_tools/SuperTagger.py index fc6f5a87..b444216d 100644 --- a/extensions_built_in/dataset_tools/SuperTagger.py +++ b/extensions_built_in/dataset_tools/SuperTagger.py @@ -3,15 +3,16 @@ import json import os from collections import OrderedDict import gc -from typing import Type, Literal - +import traceback import torch from PIL import Image, ImageOps from tqdm import tqdm -from .dataset_tools_config_modules import RAW_DIR, TRAIN_DIR, Step, ImgInfo -from .tools.image_tools import load_image -from .tools.llava_utils import LLaVAImageProcessor, long_prompt, short_prompt +from .tools.dataset_tools_config_modules import RAW_DIR, TRAIN_DIR, Step, ImgInfo +from .tools.fuyu_utils import FuyuImageProcessor +from .tools.image_tools import load_image, ImageProcessor, resize_to_max +from .tools.llava_utils import LLaVAImageProcessor +from .tools.caption import default_long_prompt, default_short_prompt from jobs.process import BaseExtensionProcess from .tools.sync_tools import get_img_paths @@ -23,7 +24,7 @@ def flush(): gc.collect() -VERSION = 1 +VERSION = 2 class SuperTagger(BaseExtensionProcess): @@ -34,7 +35,9 @@ class SuperTagger(BaseExtensionProcess): self.dataset_paths: list[str] = config.get('dataset_paths', []) self.device = config.get('device', 'cuda') self.steps: list[Step] = config.get('steps', []) - self.caption_method = config.get('caption_method', 'llava') + self.caption_method = config.get('caption_method', 'llava:default') + self.caption_prompt = config.get('caption_prompt', default_long_prompt) + self.caption_short_prompt = config.get('caption_short_prompt', default_short_prompt) self.force_reprocess_img = config.get('force_reprocess_img', False) self.master_dataset_dict = OrderedDict() self.dataset_master_config_file = config.get('dataset_master_config_file', None) @@ -53,7 +56,15 @@ class SuperTagger(BaseExtensionProcess): print(f"Found {len(self.dataset_paths)} dataset paths") - self.image_processor = LLaVAImageProcessor(device=self.device) + self.image_processor: ImageProcessor = self.get_image_processor() + + def get_image_processor(self): + if self.caption_method.startswith('llava'): + return LLaVAImageProcessor(device=self.device) + elif self.caption_method.startswith('fuyu'): + return FuyuImageProcessor(device=self.device) + else: + raise ValueError(f"Unknown caption method: {self.caption_method}") def process_image(self, img_path: str): root_img_dir = os.path.dirname(os.path.dirname(img_path)) @@ -70,18 +81,19 @@ class SuperTagger(BaseExtensionProcess): else: img_info = ImgInfo() - img_info.set_version(VERSION) - - # send steps to img info so it can store them + # always send steps first in case other processes need them img_info.add_steps(copy.deepcopy(self.steps)) + img_info.set_version(VERSION) + img_info.set_caption_method(self.caption_method) image: Image = None + caption_image: Image = None did_update_image = False # trigger reprocess of steps if self.force_reprocess_img: - img_info.trigger_image_reprocess(steps=self.steps) + img_info.trigger_image_reprocess() # set the image as updated if it does not exist on disk if not os.path.exists(train_img_path): @@ -91,29 +103,39 @@ class SuperTagger(BaseExtensionProcess): did_update_image = True image = load_image(img_path) - # go through the needed steps - for step in img_info.state.steps_to_complete: + for step in copy.deepcopy(img_info.state.steps_to_complete): if step == 'caption': # load image if image is None: image = load_image(img_path) + if caption_image is None: + caption_image = resize_to_max(image, 1024, 1024) if not self.image_processor.is_loaded: print('Loading Model. Takes a while, especially the first time') self.image_processor.load_model() - img_info.caption = self.image_processor.generate_caption(image, prompt=long_prompt) + img_info.caption = self.image_processor.generate_caption( + image=caption_image, + prompt=self.caption_prompt + ) img_info.mark_step_complete(step) elif step == 'caption_short': # load image if image is None: image = load_image(img_path) + if caption_image is None: + caption_image = resize_to_max(image, 1024, 1024) + if not self.image_processor.is_loaded: print('Loading Model. Takes a while, especially the first time') self.image_processor.load_model() - img_info.caption_short = self.image_processor.generate_caption(image, prompt=short_prompt) + img_info.caption_short = self.image_processor.generate_caption( + image=caption_image, + prompt=self.caption_short_prompt + ) img_info.mark_step_complete(step) elif step == 'contrast_stretch': # load image @@ -153,12 +175,13 @@ class SuperTagger(BaseExtensionProcess): print(f"Found {len(imgs_to_process)} to process") for img_path in tqdm(imgs_to_process, desc="Processing images"): - # try: - # self.process_image(img_path) - # except Exception as e: - # print(f"Error processing {img_path}: {e}") - # continue - self.process_image(img_path) + try: + self.process_image(img_path) + except Exception: + # print full stack trace + print(traceback.format_exc()) + continue + # self.process_image(img_path) if self.dataset_master_config_file is not None: # save it as json diff --git a/extensions_built_in/dataset_tools/SyncFromCollection.py b/extensions_built_in/dataset_tools/SyncFromCollection.py index ba094be9..e65a3584 100644 --- a/extensions_built_in/dataset_tools/SyncFromCollection.py +++ b/extensions_built_in/dataset_tools/SyncFromCollection.py @@ -7,7 +7,7 @@ from typing import List import torch from tqdm import tqdm -from .dataset_tools_config_modules import DatasetSyncCollectionConfig, RAW_DIR, NEW_DIR +from .tools.dataset_tools_config_modules import DatasetSyncCollectionConfig, RAW_DIR, NEW_DIR from .tools.sync_tools import get_unsplash_images, get_pexels_images, get_local_image_file_names, download_image, \ get_img_paths from jobs.process import BaseExtensionProcess diff --git a/extensions_built_in/dataset_tools/tools/caption.py b/extensions_built_in/dataset_tools/tools/caption.py new file mode 100644 index 00000000..28daaee0 --- /dev/null +++ b/extensions_built_in/dataset_tools/tools/caption.py @@ -0,0 +1,47 @@ + +caption_manipulation_steps = ['caption', 'caption_short'] + +default_long_prompt = 'caption this image. describe every single thing in the image in detail. Do not include any unnecessary words in your description for the sake of good grammar. I want many short statements that serve the single purpose of giving the most thorough description if items as possible in the smallest, comma separated way possible. be sure to describe people\'s moods, clothing, the environment, lighting, colors, and everything.' +default_short_prompt = 'caption this image in less than ten words' + +default_replacements = [ + ("the image features", ""), + ("the image shows", ""), + ("the image depicts", ""), + ("the image is", ""), + ("in this image", ""), + ("in the image", ""), +] + + +def clean_caption(cap, replacements=None): + if replacements is None: + replacements = default_replacements + + # remove any newlines + cap = cap.replace("\n", ", ") + cap = cap.replace("\r", ", ") + cap = cap.replace(".", ",") + cap = cap.replace("\"", "") + + # remove unicode characters + cap = cap.encode('ascii', 'ignore').decode('ascii') + + # make lowercase + cap = cap.lower() + # remove any extra spaces + cap = " ".join(cap.split()) + + for replacement in replacements: + cap = cap.replace(replacement[0], replacement[1]) + + cap_list = cap.split(",") + # trim whitespace + cap_list = [c.strip() for c in cap_list] + # remove empty strings + cap_list = [c for c in cap_list if c != ""] + # remove duplicates + cap_list = list(dict.fromkeys(cap_list)) + # join back together + cap = ", ".join(cap_list) + return cap \ No newline at end of file diff --git a/extensions_built_in/dataset_tools/dataset_tools_config_modules.py b/extensions_built_in/dataset_tools/tools/dataset_tools_config_modules.py similarity index 67% rename from extensions_built_in/dataset_tools/dataset_tools_config_modules.py rename to extensions_built_in/dataset_tools/tools/dataset_tools_config_modules.py index 105c8166..60c69dbb 100644 --- a/extensions_built_in/dataset_tools/dataset_tools_config_modules.py +++ b/extensions_built_in/dataset_tools/tools/dataset_tools_config_modules.py @@ -8,7 +8,8 @@ NEW_DIR = "_tmp" TRAIN_DIR = "train" DEPTH_DIR = "depth" -from .tools.image_tools import Step, img_manipulation_steps +from .image_tools import Step, img_manipulation_steps +from .caption import caption_manipulation_steps class DatasetSyncCollectionConfig: @@ -64,8 +65,11 @@ class ImgInfo: self.caption_short: str = kwargs.get('caption_short', None) self.poi = [Rect(**poi) for poi in kwargs.get('poi', [])] self.state = ImageState(**kwargs.get('state', {})) + self.caption_method = kwargs.get('caption_method', None) + self.other_captions = kwargs.get('other_captions', {}) self._upgrade_state() self.force_image_process: bool = False + self._requested_steps: list[Step] = [] self.is_dirty: bool = False @@ -77,14 +81,20 @@ class ImgInfo: if self.caption_short is not None and 'caption_short' not in self.state.steps_complete: self.mark_step_complete('caption_short') self.is_dirty = True + if self.caption_method is None and self.caption is not None: + # added caption method in version 2. Was all llava before that + self.caption_method = 'llava:default' + self.is_dirty = True def to_dict(self): return { 'version': self.version, + 'caption_method': self.caption_method, 'caption': self.caption, 'caption_short': self.caption_short, 'poi': [poi.to_dict() for poi in self.poi], - 'state': self.state.to_dict() + 'state': self.state.to_dict(), + 'other_captions': self.other_captions } def mark_step_complete(self, step: Step): @@ -98,7 +108,10 @@ class ImgInfo: if step not in self.state.steps_to_complete and step not in self.state.steps_complete: self.state.steps_to_complete.append(step) - def trigger_image_reprocess(self, steps): + def trigger_image_reprocess(self): + if self._requested_steps is None: + raise Exception("Must call add_steps before trigger_image_reprocess") + steps = self._requested_steps # remove all image manipulationf from steps_to_complete for step in img_manipulation_steps: if step in self.state.steps_to_complete: @@ -112,14 +125,13 @@ class ImgInfo: if step in img_manipulation_steps: self.add_step(step) - def add_steps(self, steps: list[Step]): + self._requested_steps = [step for step in steps] for stage in steps: self.add_step(stage) # update steps if we have any img processes not complete, we have to reprocess them all # if any steps_to_complete are in img_manipulation_steps - # TODO check if they are in a new order now ands trigger a redo is_manipulating_image = any([step in img_manipulation_steps for step in self.state.steps_to_complete]) order_has_changed = False @@ -133,8 +145,38 @@ class ImgInfo: order_has_changed = True if is_manipulating_image or order_has_changed: - self.trigger_image_reprocess(steps) + self.trigger_image_reprocess() + def set_caption_method(self, method: str): + if self._requested_steps is None: + raise Exception("Must call add_steps before set_caption_method") + if self.caption_method != method: + self.is_dirty = True + # move previous caption method to other_captions + if self.caption_method is not None and self.caption is not None or self.caption_short is not None: + self.other_captions[self.caption_method] = { + 'caption': self.caption, + 'caption_short': self.caption_short, + } + self.caption_method = method + self.caption = None + self.caption_short = None + # see if we have a caption from the new method + if method in self.other_captions: + self.caption = self.other_captions[method].get('caption', None) + self.caption_short = self.other_captions[method].get('caption_short', None) + else: + self.trigger_new_caption() + + def trigger_new_caption(self): + self.caption = None + self.caption_short = None + self.is_dirty = True + # check to see if we have any steps in the complete list and move them to the to_complete list + for step in self.state.steps_complete: + if step in caption_manipulation_steps: + self.state.steps_complete.remove(step) + self.state.steps_to_complete.append(step) def to_json(self): return json.dumps(self.to_dict()) diff --git a/extensions_built_in/dataset_tools/tools/fuyu_utils.py b/extensions_built_in/dataset_tools/tools/fuyu_utils.py new file mode 100644 index 00000000..407da10c --- /dev/null +++ b/extensions_built_in/dataset_tools/tools/fuyu_utils.py @@ -0,0 +1,66 @@ +from transformers import CLIPImageProcessor, BitsAndBytesConfig, AutoTokenizer + +from .caption import default_long_prompt, default_short_prompt, default_replacements, clean_caption +import torch +from PIL import Image + + +class FuyuImageProcessor: + def __init__(self, device='cuda'): + from transformers import FuyuProcessor, FuyuForCausalLM + self.device = device + self.model: FuyuForCausalLM = None + self.processor: FuyuProcessor = None + self.dtype = torch.bfloat16 + self.tokenizer: AutoTokenizer + self.is_loaded = False + + def load_model(self): + from transformers import FuyuProcessor, FuyuForCausalLM + model_path = "adept/fuyu-8b" + kwargs = {"device_map": self.device} + kwargs['load_in_4bit'] = True + kwargs['quantization_config'] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=self.dtype, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4' + ) + self.processor = FuyuProcessor.from_pretrained(model_path) + self.model = FuyuForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) + self.is_loaded = True + + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + self.model = FuyuForCausalLM.from_pretrained(model_path, torch_dtype=self.dtype, **kwargs) + self.processor = FuyuProcessor(image_processor=FuyuImageProcessor(), tokenizer=self.tokenizer) + + def generate_caption( + self, image: Image, + prompt: str = default_long_prompt, + replacements=default_replacements, + max_new_tokens=512 + ): + # prepare inputs for the model + # text_prompt = f"{prompt}\n" + + # image = image.convert('RGB') + model_inputs = self.processor(text=prompt, images=[image]) + model_inputs = {k: v.to(dtype=self.dtype if torch.is_floating_point(v) else v.dtype, device=self.device) for k, v in + model_inputs.items()} + + generation_output = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens) + prompt_len = model_inputs["input_ids"].shape[-1] + output = self.tokenizer.decode(generation_output[0][prompt_len:], skip_special_tokens=True) + output = clean_caption(output, replacements=replacements) + return output + + # inputs = self.processor(text=text_prompt, images=image, return_tensors="pt") + # for k, v in inputs.items(): + # inputs[k] = v.to(self.device) + + # # autoregressively generate text + # generation_output = self.model.generate(**inputs, max_new_tokens=max_new_tokens) + # generation_text = self.processor.batch_decode(generation_output[:, -max_new_tokens:], skip_special_tokens=True) + # output = generation_text[0] + # + # return clean_caption(output, replacements=replacements) diff --git a/extensions_built_in/dataset_tools/tools/image_tools.py b/extensions_built_in/dataset_tools/tools/image_tools.py index d94d8a19..d36073c0 100644 --- a/extensions_built_in/dataset_tools/tools/image_tools.py +++ b/extensions_built_in/dataset_tools/tools/image_tools.py @@ -1,4 +1,4 @@ -from typing import Literal, Type +from typing import Literal, Type, TYPE_CHECKING, Union import cv2 import numpy as np @@ -8,6 +8,14 @@ Step: Type = Literal['caption', 'caption_short', 'create_mask', 'contrast_stretc img_manipulation_steps = ['contrast_stretch'] +img_ext = ['.jpg', '.jpeg', '.png', '.webp'] + +if TYPE_CHECKING: + from .llava_utils import LLaVAImageProcessor + from .fuyu_utils import FuyuImageProcessor + +ImageProcessor = Union['LLaVAImageProcessor', 'FuyuImageProcessor'] + def pil_to_cv2(image): """Convert a PIL image to a cv2 image.""" @@ -18,6 +26,7 @@ def cv2_to_pil(image): """Convert a cv2 image to a PIL image.""" return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + def load_image(img_path: str): image = Image.open(img_path).convert('RGB') try: @@ -27,3 +36,14 @@ def load_image(img_path: str): pass return image + +def resize_to_max(image, max_width=1024, max_height=1024): + width, height = image.size + if width <= max_width and height <= max_height: + return image + + scale = min(max_width / width, max_height / height) + width = int(width * scale) + height = int(height * scale) + + return image.resize((width, height), Image.LANCZOS) diff --git a/extensions_built_in/dataset_tools/tools/llava_utils.py b/extensions_built_in/dataset_tools/tools/llava_utils.py index 7701ce52..6619da90 100644 --- a/extensions_built_in/dataset_tools/tools/llava_utils.py +++ b/extensions_built_in/dataset_tools/tools/llava_utils.py @@ -5,20 +5,7 @@ except ImportError: print("You need to manually install llava -> pip install --no-deps git+https://github.com/haotian-liu/LLaVA.git") raise -long_prompt = 'caption this image. describe every single thing in the image in detail. Do not include any unnecessary words in your description for the sake of good grammar. I want many short statements that serve the single purpose of giving the most thorough description if items as possible in the smallest, comma separated way possible. be sure to describe people\'s moods, clothing, the environment, lighting, colors, and everything.' -short_prompt = 'caption this image in less than ten words' - -prompts = [ - long_prompt, - short_prompt, -] - -replacements = [ - ("the image features", ""), - ("the image shows", ""), - ("the image depicts", ""), - ("the image is", ""), -] +from .caption import default_long_prompt, default_short_prompt, default_replacements, clean_caption import torch from PIL import Image, ImageOps @@ -61,36 +48,12 @@ class LLaVAImageProcessor: self.image_processor = vision_tower.image_processor self.is_loaded = True - def clean_caption(self, cap): - # remove any newlines - cap = cap.replace("\n", ", ") - cap = cap.replace("\r", ", ") - cap = cap.replace(".", ",") - cap = cap.replace("\"", "") - - # remove unicode characters - cap = cap.encode('ascii', 'ignore').decode('ascii') - - # make lowercase - cap = cap.lower() - # remove any extra spaces - cap = " ".join(cap.split()) - - for replacement in replacements: - cap = cap.replace(replacement[0], replacement[1]) - - cap_list = cap.split(",") - # trim whitespace - cap_list = [c.strip() for c in cap_list] - # remove empty strings - cap_list = [c for c in cap_list if c != ""] - # remove duplicates - cap_list = list(dict.fromkeys(cap_list)) - # join back together - cap = ", ".join(cap_list) - return cap - - def generate_caption(self, image: Image, prompt: str = long_prompt): + def generate_caption( + self, image: + Image, prompt: str = default_long_prompt, + replacements=default_replacements, + max_new_tokens=512 + ): # question = "how many dogs are in the picture?" disable_torch_init() conv_mode = "llava_v0" @@ -111,19 +74,10 @@ class LLaVAImageProcessor: with torch.inference_mode(): output_ids = self.model.generate( input_ids, images=image_tensor, do_sample=True, temperature=0.1, - max_new_tokens=1024, use_cache=True, stopping_criteria=[stopping_criteria], + max_new_tokens=max_new_tokens, use_cache=True, stopping_criteria=[stopping_criteria], top_p=0.9 ) outputs = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() conv.messages[-1][-1] = outputs output = outputs.rsplit('', 1)[0] - return self.clean_caption(output) - - def generate_captions(self, image: Image): - - responses = [] - for prompt in prompts: - output = self.generate_caption(image, prompt) - responses.append(output) - # replace all . with , - return responses + return clean_caption(output, replacements=replacements) diff --git a/extensions_built_in/dataset_tools/tools/sync_tools.py b/extensions_built_in/dataset_tools/tools/sync_tools.py index e0b2ad66..143cc6bb 100644 --- a/extensions_built_in/dataset_tools/tools/sync_tools.py +++ b/extensions_built_in/dataset_tools/tools/sync_tools.py @@ -9,7 +9,7 @@ def img_root_path(img_id: str): if TYPE_CHECKING: - from ..dataset_tools_config_modules import DatasetSyncCollectionConfig + from .dataset_tools_config_modules import DatasetSyncCollectionConfig img_exts = ['.jpg', '.jpeg', '.webp', '.png'] diff --git a/requirements.txt b/requirements.txt index d1133b4a..efbccc4f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ torch torchvision safetensors diffusers==0.21.3 -transformers +git+https://github.com/huggingface/transformers.git@master lycoris-lora==1.8.3 flatten_json pyyaml diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index effa014d..c9a19fad 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -176,6 +176,7 @@ class StableDiffusion: dtype=dtype, device=self.device_torch, variant="fp16", + use_safetensors=True, **load_args ) else: