mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-23 13:53:57 +00:00
Added dataset tagging and management tools using LLaVA
This commit is contained in:
169
extensions_built_in/dataset_tools/SuperTagger.py
Normal file
169
extensions_built_in/dataset_tools/SuperTagger.py
Normal file
@@ -0,0 +1,169 @@
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
import gc
|
||||
from typing import Type, Literal
|
||||
|
||||
import torch
|
||||
from PIL import Image, ImageOps
|
||||
from tqdm import tqdm
|
||||
|
||||
from .dataset_tools_config_modules import RAW_DIR, TRAIN_DIR, Step, ImgInfo
|
||||
from .tools.image_tools import load_image
|
||||
from .tools.llava_utils import LLaVAImageProcessor, long_prompt, short_prompt
|
||||
from jobs.process import BaseExtensionProcess
|
||||
from .tools.sync_tools import get_img_paths
|
||||
|
||||
# Image file extensions recognized when scanning dataset directories.
img_ext = ['.jpg', '.jpeg', '.png', '.webp']


def flush():
    """Release cached CUDA memory and run the garbage collector.

    Called after heavy model usage to keep VRAM/RAM pressure down.
    """
    torch.cuda.empty_cache()
    gc.collect()


# Bumped when the processing pipeline changes; stored in each image's
# ImgInfo record so stale metadata can be detected and re-saved.
VERSION = 1
|
||||
|
||||
|
||||
class SuperTagger(BaseExtensionProcess):
    """Runs outstanding processing steps (LLaVA captioning, contrast
    stretching, ...) over every raw image in the configured datasets.

    Per-image state is persisted as ``<name>.json`` (an ImgInfo record)
    next to the processed copy in the dataset's train directory, so work
    already done is skipped on subsequent runs.
    """

    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
        parent_dir = config.get('parent_dir', None)
        self.dataset_paths: list[str] = config.get('dataset_paths', [])
        self.device = config.get('device', 'cuda')
        self.steps: list[Step] = config.get('steps', [])
        self.caption_method = config.get('caption_method', 'llava')
        self.force_reprocess_img = config.get('force_reprocess_img', False)
        self.master_dataset_dict = OrderedDict()
        self.dataset_master_config_file = config.get('dataset_master_config_file', None)
        if parent_dir is not None and len(self.dataset_paths) == 0:
            # find all folders in the parent dataset path
            self.dataset_paths = [
                os.path.join(parent_dir, folder)
                for folder in os.listdir(parent_dir)
                if os.path.isdir(os.path.join(parent_dir, folder))
            ]
        else:
            # make sure explicitly configured paths exist
            for dataset_path in self.dataset_paths:
                if not os.path.exists(dataset_path):
                    raise ValueError(f"Dataset path does not exist: {dataset_path}")

        print(f"Found {len(self.dataset_paths)} dataset paths")

        # model is loaded lazily on first caption step
        self.image_processor = LLaVAImageProcessor(device=self.device)

    def process_image(self, img_path: str):
        """Run any incomplete steps for one raw image and persist results.

        Writes the processed image to the dataset's train dir (only when
        pixels changed) and the ImgInfo json (only when state changed).
        """
        root_img_dir = os.path.dirname(os.path.dirname(img_path))
        filename = os.path.basename(img_path)
        filename_no_ext = os.path.splitext(filename)[0]
        train_dir = os.path.join(root_img_dir, TRAIN_DIR)
        train_img_path = os.path.join(train_dir, filename)
        json_path = os.path.join(train_dir, f"{filename_no_ext}.json")

        # check if json exists, if it does load it as image info
        if os.path.exists(json_path):
            with open(json_path, 'r') as f:
                img_info = ImgInfo(**json.load(f))
        else:
            img_info = ImgInfo()

        img_info.set_version(VERSION)

        # send steps to img info so it can store them
        img_info.add_steps(copy.deepcopy(self.steps))

        image = None  # PIL image, loaded lazily
        did_update_image = False

        # trigger reprocess of image-manipulation steps
        if self.force_reprocess_img:
            img_info.trigger_image_reprocess(steps=self.steps)

        # reload the source if the train copy is missing or a reprocess was
        # requested (fix: previously the image was loaded twice when both
        # conditions held)
        if not os.path.exists(train_img_path) or img_info.force_image_process:
            did_update_image = True
            image = load_image(img_path)

        # go through the needed steps. Iterate a snapshot: mark_step_complete
        # removes entries from steps_to_complete, and mutating the list while
        # iterating it skipped every other step (fix).
        for step in list(img_info.state.steps_to_complete):
            if step in ('caption', 'caption_short'):
                # load image
                if image is None:
                    image = load_image(img_path)

                if not self.image_processor.is_loaded:
                    print('Loading Model. Takes a while, especially the first time')
                    self.image_processor.load_model()

                prompt = long_prompt if step == 'caption' else short_prompt
                caption = self.image_processor.generate_caption(image, prompt=prompt)
                if step == 'caption':
                    img_info.caption = caption
                else:
                    img_info.caption_short = caption
                img_info.mark_step_complete(step)
            elif step == 'contrast_stretch':
                # load image
                if image is None:
                    image = load_image(img_path)
                image = ImageOps.autocontrast(image, cutoff=(0.1, 0), preserve_tone=True)
                did_update_image = True
                img_info.mark_step_complete(step)
            else:
                raise ValueError(f"Unknown step: {step}")

        os.makedirs(os.path.dirname(train_img_path), exist_ok=True)
        if did_update_image:
            image.save(train_img_path)

        if img_info.is_dirty:
            with open(json_path, 'w') as f:
                json.dump(img_info.to_dict(), f, indent=4)

        if self.dataset_master_config_file:
            # add to master dict
            self.master_dataset_dict[train_img_path] = img_info.to_dict()

    def run(self):
        """Find all raw images across datasets and process each one."""
        super().run()
        imgs_to_process = []
        # find all images
        for dataset_path in self.dataset_paths:
            raw_dir = os.path.join(dataset_path, RAW_DIR)
            imgs_to_process.extend(get_img_paths(raw_dir))

        if len(imgs_to_process) == 0:
            print("No images to process")
        else:
            print(f"Found {len(imgs_to_process)} to process")

        for img_path in tqdm(imgs_to_process, desc="Processing images"):
            self.process_image(img_path)

        if self.dataset_master_config_file is not None:
            # save the combined per-image metadata as one json
            with open(self.dataset_master_config_file, 'w') as f:
                json.dump(self.master_dataset_dict, f, indent=4)

        # drop the (potentially loaded) model and free GPU memory
        del self.image_processor
        flush()
|
||||
131
extensions_built_in/dataset_tools/SyncFromCollection.py
Normal file
131
extensions_built_in/dataset_tools/SyncFromCollection.py
Normal file
@@ -0,0 +1,131 @@
|
||||
import os
|
||||
import shutil
|
||||
from collections import OrderedDict
|
||||
import gc
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from .dataset_tools_config_modules import DatasetSyncCollectionConfig, RAW_DIR, NEW_DIR
|
||||
from .tools.sync_tools import get_unsplash_images, get_pexels_images, get_local_image_file_names, download_image, \
|
||||
get_img_paths
|
||||
from jobs.process import BaseExtensionProcess
|
||||
|
||||
|
||||
def flush():
    """Release cached CUDA memory and run the garbage collector."""
    torch.cuda.empty_cache()
    gc.collect()
|
||||
|
||||
|
||||
class SyncFromCollection(BaseExtensionProcess):
    """Downloads new photos from configured Unsplash/Pexels collections.

    New files are staged in each dataset's temp dir and moved into the raw
    dir once all datasets have been synced.
    """

    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)

        self.min_width = config.get('min_width', 1024)
        self.min_height = config.get('min_height', 1024)

        # add our min_width and min_height to each dataset config if they don't exist
        for dataset_config in config.get('dataset_sync', []):
            dataset_config.setdefault('min_width', self.min_width)
            dataset_config.setdefault('min_height', self.min_height)

        self.dataset_configs: List[DatasetSyncCollectionConfig] = [
            DatasetSyncCollectionConfig(**dataset_config)
            for dataset_config in config.get('dataset_sync', [])
        ]
        print(f"Found {len(self.dataset_configs)} dataset configs")

    def move_new_images(self, root_dir: str):
        """Move staged downloads from the temp dir into the raw dir, then
        remove the temp dir."""
        raw_dir = os.path.join(root_dir, RAW_DIR)
        new_dir = os.path.join(root_dir, NEW_DIR)
        for img_path in get_img_paths(new_dir):
            shutil.move(img_path, os.path.join(raw_dir, os.path.basename(img_path)))
        # remove new dir
        shutil.rmtree(new_dir)

    def sync_dataset(self, config: DatasetSyncCollectionConfig):
        """Download any photos from one collection that are not yet local.

        Returns a dict of counters: num_downloaded / num_skipped / bad /
        total (total excludes photos that failed to download).
        """
        if config.host == 'unsplash':
            get_images = get_unsplash_images
        elif config.host == 'pexels':
            get_images = get_pexels_images
        else:
            raise ValueError(f"Unknown host: {config.host}")

        results = {
            'num_downloaded': 0,
            'num_skipped': 0,
            'bad': 0,
            'total': 0,
        }

        photos = get_images(config)
        raw_dir = os.path.join(config.directory, RAW_DIR)
        new_dir = os.path.join(config.directory, NEW_DIR)
        raw_images = get_local_image_file_names(raw_dir)
        new_images = get_local_image_file_names(new_dir)

        for photo in tqdm(photos, desc=f"{config.host}-{config.collection_id}"):
            try:
                if photo.filename not in raw_images and photo.filename not in new_images:
                    download_image(photo, new_dir, min_width=self.min_width, min_height=self.min_height)
                    results['num_downloaded'] += 1
                else:
                    results['num_skipped'] += 1
            except Exception as e:
                # best-effort: one bad photo must not abort the collection
                print(f" - BAD({photo.id}): {e}")
                results['bad'] += 1
                continue
            results['total'] += 1

        return results

    def print_results(self, results):
        """Print one summary line for a results counter dict."""
        print(
            f" - new:{results['num_downloaded']}, old:{results['num_skipped']}, bad:{results['bad']} total:{results['total']}")

    def run(self):
        """Sync every configured collection, then promote staged images."""
        super().run()
        print(f"Syncing {len(self.dataset_configs)} datasets")
        all_results = None
        failed_datasets = []
        for dataset_config in tqdm(self.dataset_configs, desc="Syncing datasets", leave=True):
            try:
                results = self.sync_dataset(dataset_config)
                if all_results is None:
                    all_results = {**results}
                else:
                    for key, value in results.items():
                        all_results[key] += value

                self.print_results(results)
            except Exception as e:
                print(f" - FAILED: {e}")
                # HTTP errors (e.g. requests.HTTPError) carry the response;
                # surface its status and body when available
                response = getattr(e, 'response', None)
                if response is not None:
                    error = f"{response.status_code}: {response.text}"
                    print(f" - {error}")
                else:
                    error = str(e)
                failed_datasets.append({'dataset': dataset_config, 'error': error})
                continue

        print("Moving new images to raw")
        for dataset_config in self.dataset_configs:
            self.move_new_images(dataset_config.directory)

        print("Done syncing datasets")
        # fix: previously crashed with TypeError when every dataset failed
        # (all_results stayed None)
        if all_results is not None:
            self.print_results(all_results)

        if len(failed_datasets) > 0:
            print(f"Failed to sync {len(failed_datasets)} datasets")
            for failed in failed_datasets:
                print(f" - {failed['dataset'].host}-{failed['dataset'].collection_id}")
                print(f" - ERR: {failed['error']}")
||||
@@ -1,10 +1,7 @@
|
||||
# This is an example extension for custom training. It is great for experimenting with new ideas.
|
||||
from toolkit.extension import Extension
|
||||
|
||||
|
||||
# This is for generic training (LoRA, Dreambooth, FineTuning)
|
||||
class DatasetToolsExtension(Extension):
|
||||
# uid must be unique, it is how the extension is identified
|
||||
uid = "dataset_tools"
|
||||
|
||||
# name is the name of the extension for printing
|
||||
@@ -19,7 +16,28 @@ class DatasetToolsExtension(Extension):
|
||||
return DatasetTools
|
||||
|
||||
|
||||
class SyncFromCollectionExtension(Extension):
    """Registers the SyncFromCollection process with the toolkit."""
    # uid must be unique, it is how the extension is identified
    uid = "sync_from_collection"
    # name is the human-readable name used for printing
    name = "Sync from Collection"

    @classmethod
    def get_process(cls):
        # import your process class here so it is only loaded when needed and return it
        from .SyncFromCollection import SyncFromCollection
        return SyncFromCollection
|
||||
|
||||
|
||||
class SuperTaggerExtension(Extension):
    """Registers the SuperTagger captioning/tagging process with the toolkit."""
    # uid must be unique, it is how the extension is identified
    uid = "super_tagger"
    # name is the human-readable name used for printing
    name = "Super Tagger"

    @classmethod
    def get_process(cls):
        # import your process class here so it is only loaded when needed and return it
        from .SuperTagger import SuperTagger
        return SuperTagger
|
||||
|
||||
|
||||
# All extensions this module contributes to the toolkit.
# Fix: DatasetToolsExtension was previously listed twice.
AI_TOOLKIT_EXTENSIONS = [
    DatasetToolsExtension,
    SyncFromCollectionExtension,
    SuperTaggerExtension,
]
|
||||
|
||||
@@ -0,0 +1,145 @@
|
||||
import json
|
||||
from typing import Literal, Type, TYPE_CHECKING
|
||||
|
||||
# Photo-hosting services supported for dataset syncing.
Host: Type = Literal['unsplash', 'pexels']

# Standard sub-directory names inside each dataset directory.
RAW_DIR = "raw"      # original downloaded images
NEW_DIR = "_tmp"     # staging area for freshly downloaded images
TRAIN_DIR = "train"  # processed images + per-image json metadata
DEPTH_DIR = "depth"  # depth-map annotations

from .tools.image_tools import Step, img_manipulation_steps
|
||||
|
||||
|
||||
class DatasetSyncCollectionConfig:
    """Configuration for syncing one remote photo collection to a local
    dataset directory.

    Required: host, collection_id, directory, api_key.  Minimum image
    dimensions default to 1024x1024.
    """

    def __init__(self, **kwargs):
        self.host: Host = kwargs.get('host', None)
        self.collection_id: str = kwargs.get('collection_id', None)
        self.directory: str = kwargs.get('directory', None)
        self.api_key: str = kwargs.get('api_key', None)
        self.min_width: int = kwargs.get('min_width', 1024)
        self.min_height: int = kwargs.get('min_height', 1024)

        # validate required fields, in a fixed order so error messages
        # are deterministic
        for attr in ('host', 'collection_id', 'directory'):
            if getattr(self, attr) is None:
                raise ValueError(f"{attr} is required")
        if self.api_key is None:
            raise ValueError(f"api_key is required: {self.host}:{self.collection_id}")
|
||||
|
||||
|
||||
class ImageState:
    """Tracks which processing steps an image has finished vs. still needs.

    Only ``steps_complete`` is serialized; pending steps are recomputed
    from the process configuration on each run.
    """

    def __init__(self, **kwargs):
        self.steps_complete: list[Step] = kwargs.get('steps_complete', [])
        self.steps_to_complete: list[Step] = kwargs.get('steps_to_complete', [])

    def to_dict(self):
        """Serialize for persistence; pending steps are intentionally omitted."""
        return dict(steps_complete=self.steps_complete)
|
||||
|
||||
|
||||
class Rect:
    """Axis-aligned rectangle: top-left corner (x, y) plus width/height.

    All fields default to 0; ``Rect(**d)`` is the inverse of ``to_dict``.
    """

    _FIELDS = ('x', 'y', 'width', 'height')

    def __init__(self, **kwargs):
        for field in self._FIELDS:
            setattr(self, field, kwargs.get(field, 0))

    def to_dict(self):
        """Serialize to a plain dict suitable for JSON."""
        return {field: getattr(self, field) for field in self._FIELDS}
|
||||
|
||||
|
||||
class ImgInfo:
    """Per-image metadata record, persisted as JSON next to each train image.

    Tracks captions, points of interest, and which processing steps are
    complete vs. pending.  ``is_dirty`` flags that the record needs to be
    re-saved; ``force_image_process`` flags that pixel-level steps must be
    re-run from the raw image.
    """

    def __init__(self, **kwargs):
        # pipeline version that produced this record (None for new records)
        self.version: int = kwargs.get('version', None)
        self.caption: str = kwargs.get('caption', None)
        self.caption_short: str = kwargs.get('caption_short', None)
        # points of interest, deserialized into Rect objects
        self.poi = [Rect(**poi) for poi in kwargs.get('poi', [])]
        self.state = ImageState(**kwargs.get('state', {}))
        self._upgrade_state()
        self.force_image_process: bool = False

        # NOTE(review): _upgrade_state() above may set is_dirty, but it is
        # unconditionally reset to False here — an upgraded record is only
        # persisted if something else dirties it later; confirm intended.
        self.is_dirty: bool = False

    def _upgrade_state(self):
        # upgrades older states: records that predate step tracking have
        # captions but no corresponding entry in steps_complete
        if self.caption is not None and 'caption' not in self.state.steps_complete:
            self.mark_step_complete('caption')
            self.is_dirty = True
        if self.caption_short is not None and 'caption_short' not in self.state.steps_complete:
            self.mark_step_complete('caption_short')
            self.is_dirty = True

    def to_dict(self):
        """Serialize for JSON persistence (inverse of ``ImgInfo(**d)``)."""
        return {
            'version': self.version,
            'caption': self.caption,
            'caption_short': self.caption_short,
            'poi': [poi.to_dict() for poi in self.poi],
            'state': self.state.to_dict()
        }

    def mark_step_complete(self, step: Step):
        """Move *step* from pending to complete and flag for re-save."""
        if step not in self.state.steps_complete:
            self.state.steps_complete.append(step)
        if step in self.state.steps_to_complete:
            self.state.steps_to_complete.remove(step)
        self.is_dirty = True

    def add_step(self, step: Step):
        """Queue *step* unless it is already queued or already done."""
        if step not in self.state.steps_to_complete and step not in self.state.steps_complete:
            self.state.steps_to_complete.append(step)

    def trigger_image_reprocess(self, steps):
        """Force all image-manipulation steps to run again, in the order
        given by *steps* (the order configured in the process file)."""
        # remove all image manipulations from steps_to_complete
        for step in img_manipulation_steps:
            if step in self.state.steps_to_complete:
                self.state.steps_to_complete.remove(step)
            if step in self.state.steps_complete:
                self.state.steps_complete.remove(step)
        self.force_image_process = True
        self.is_dirty = True
        # we want to keep the order passed in the process file
        for step in steps:
            if step in img_manipulation_steps:
                self.add_step(step)


    def add_steps(self, steps: list[Step]):
        """Queue *steps*, re-triggering image manipulation when required."""
        for stage in steps:
            self.add_step(stage)

        # update steps if we have any img processes not complete, we have to reprocess them all
        # if any steps_to_complete are in img_manipulation_steps
        # TODO check if they are in a new order now and trigger a redo

        is_manipulating_image = any([step in img_manipulation_steps for step in self.state.steps_to_complete])
        order_has_changed = False

        if not is_manipulating_image:
            # check to see if order has changed. No need to if already redoing it. Will detect if ones are removed
            target_img_manipulation_order = [step for step in steps if step in img_manipulation_steps]
            current_img_manipulation_order = [step for step in self.state.steps_complete if
                                              step in img_manipulation_steps]
            if target_img_manipulation_order != current_img_manipulation_order:
                order_has_changed = True

        if is_manipulating_image or order_has_changed:
            self.trigger_image_reprocess(steps)


    def to_json(self):
        """Return the record as a JSON string."""
        return json.dumps(self.to_dict())

    def set_version(self, version: int):
        """Record the pipeline version, flagging a re-save when it changes."""
        if self.version != version:
            self.is_dirty = True
        self.version = version
|
||||
29
extensions_built_in/dataset_tools/tools/image_tools.py
Normal file
29
extensions_built_in/dataset_tools/tools/image_tools.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from typing import Literal, Type
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image, ImageOps
|
||||
|
||||
# A processing step that can be applied to a dataset image.
Step: Type = Literal['caption', 'caption_short', 'create_mask', 'contrast_stretch']

# Steps that alter the image pixels themselves (as opposed to metadata);
# these are re-run together whenever their set or order changes.
img_manipulation_steps = ['contrast_stretch']
|
||||
|
||||
|
||||
def pil_to_cv2(image):
    """Convert a PIL image to a cv2 image (RGB -> BGR ndarray)."""
    return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
||||
|
||||
|
||||
def cv2_to_pil(image):
    """Convert a cv2 image (BGR ndarray) to a PIL image (RGB)."""
    return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
||||
|
||||
def load_image(img_path: str):
    """Load an image file as RGB, applying EXIF orientation when possible.

    Args:
        img_path: path to an image readable by Pillow.

    Returns:
        A PIL ``Image`` in RGB mode, rotated/flipped per its EXIF
        orientation tag when that tag can be read.
    """
    image = Image.open(img_path).convert('RGB')
    try:
        # transpose with exif data — deliberately best-effort: corrupt or
        # missing EXIF data must not prevent loading the image itself
        image = ImageOps.exif_transpose(image)
    except Exception:
        pass
    return image
|
||||
|
||||
129
extensions_built_in/dataset_tools/tools/llava_utils.py
Normal file
129
extensions_built_in/dataset_tools/tools/llava_utils.py
Normal file
@@ -0,0 +1,129 @@
|
||||
try:
|
||||
from llava.model import LlavaLlamaForCausalLM
|
||||
except ImportError:
|
||||
# print("You need to manually install llava -> pip install --no-deps git+https://github.com/haotian-liu/LLaVA.git")
|
||||
print("You need to manually install llava -> pip install --no-deps git+https://github.com/haotian-liu/LLaVA.git")
|
||||
raise
|
||||
|
||||
# Prompt asking LLaVA for an exhaustive, comma-separated description.
long_prompt = 'caption this image. describe every single thing in the image in detail. Do not include any unnecessary words in your description for the sake of good grammar. I want many short statements that serve the single purpose of giving the most thorough description if items as possible in the smallest, comma separated way possible. be sure to describe people\'s moods, clothing, the environment, lighting, colors, and everything.'
# Prompt asking for a terse one-liner.
short_prompt = 'caption this image in less than ten words'

# Prompts used by generate_captions, in output order (long, then short).
prompts = [
    long_prompt,
    short_prompt,
]

# Boilerplate lead-ins stripped from generated captions by clean_caption.
replacements = [
    ("the image features", ""),
    ("the image shows", ""),
    ("the image depicts", ""),
    ("the image is", ""),
]
|
||||
|
||||
import torch
|
||||
from PIL import Image, ImageOps
|
||||
from llava.conversation import conv_templates, SeparatorStyle
|
||||
from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
|
||||
from llava.utils import disable_torch_init
|
||||
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
||||
from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria
|
||||
|
||||
img_ext = ['.jpg', '.jpeg', '.png', '.webp']
|
||||
|
||||
|
||||
class LLaVAImageProcessor:
    """Wraps a 4-bit quantized LLaVA model for image captioning.

    The model is loaded lazily via ``load_model()``; callers should check
    ``is_loaded`` first.  Captions are post-processed into a lowercase,
    comma-separated, de-duplicated tag-style string.
    """

    def __init__(self, device='cuda'):
        self.device = device
        self.model: LlavaLlamaForCausalLM = None
        self.tokenizer: AutoTokenizer = None
        self.image_processor: CLIPImageProcessor = None
        self.is_loaded = False

    def load_model(self):
        """Load the quantized LLaVA checkpoint, tokenizer, and vision tower."""
        from llava.model import LlavaLlamaForCausalLM

        model_path = "4bit/llava-v1.5-13b-3GB"
        kwargs = {"device_map": self.device}
        # 4-bit NF4 quantization keeps the 13B model within consumer VRAM
        kwargs['load_in_4bit'] = True
        kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'
        )
        self.model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
        vision_tower = self.model.get_vision_tower()
        if not vision_tower.is_loaded:
            vision_tower.load_model()
        vision_tower.to(device=self.device)
        self.image_processor = vision_tower.image_processor
        self.is_loaded = True

    def clean_caption(self, cap):
        """Normalize a raw model response into a comma-separated caption."""
        # flatten newlines/periods into comma separators, drop quotes
        cap = cap.replace("\n", ", ")
        cap = cap.replace("\r", ", ")
        cap = cap.replace(".", ",")
        cap = cap.replace("\"", "")

        # remove unicode characters
        cap = cap.encode('ascii', 'ignore').decode('ascii')

        # make lowercase
        cap = cap.lower()
        # remove any extra spaces
        cap = " ".join(cap.split())

        # strip boilerplate lead-ins like "the image shows"
        for replacement in replacements:
            cap = cap.replace(replacement[0], replacement[1])

        cap_list = cap.split(",")
        # trim whitespace
        cap_list = [c.strip() for c in cap_list]
        # remove empty strings
        cap_list = [c for c in cap_list if c != ""]
        # remove duplicates while preserving order
        cap_list = list(dict.fromkeys(cap_list))
        # join back together
        cap = ", ".join(cap_list)
        return cap

    def generate_caption(self, image: Image, prompt: str = long_prompt):
        """Caption one PIL image with *prompt* and clean the output.

        Fix: the image and token tensors were previously hard-coded to
        ``.cuda()``, silently ignoring the ``device`` chosen in __init__.
        """
        disable_torch_init()
        conv_mode = "llava_v0"
        conv = conv_templates[conv_mode].copy()
        roles = conv.roles
        image_tensor = self.image_processor.preprocess(
            [image], return_tensors='pt')['pixel_values'].half().to(self.device)

        inp = f"{roles[0]}: {prompt}"
        inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        raw_prompt = conv.get_prompt()
        input_ids = tokenizer_image_token(raw_prompt, self.tokenizer, IMAGE_TOKEN_INDEX,
                                          return_tensors='pt').unsqueeze(0).to(self.device)
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids, images=image_tensor, do_sample=True, temperature=0.1,
                max_new_tokens=1024, use_cache=True, stopping_criteria=[stopping_criteria],
                top_p=0.9
            )
        # decode only the newly generated tokens (skip the prompt)
        outputs = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
        conv.messages[-1][-1] = outputs
        output = outputs.rsplit('</s>', 1)[0]
        return self.clean_caption(output)

    def generate_captions(self, image: Image):
        """Return one cleaned caption per configured prompt (long, short)."""
        responses = []
        for prompt in prompts:
            responses.append(self.generate_caption(image, prompt))
        return responses
|
||||
279
extensions_built_in/dataset_tools/tools/sync_tools.py
Normal file
279
extensions_built_in/dataset_tools/tools/sync_tools.py
Normal file
@@ -0,0 +1,279 @@
|
||||
import os
|
||||
import requests
|
||||
import tqdm
|
||||
from typing import List, Optional, TYPE_CHECKING
|
||||
|
||||
|
||||
def img_root_path(img_id: str):
    """Return the directory two levels above *img_id* (the dataset root)."""
    parent = os.path.dirname(img_id)
    return os.path.dirname(parent)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..dataset_tools_config_modules import DatasetSyncCollectionConfig
|
||||
|
||||
# Image file extensions recognized by the local-file scanners below.
img_exts = ['.jpg', '.jpeg', '.webp', '.png']
|
||||
|
||||
class Photo:
    """Minimal record describing one photo on a remote hosting service."""

    def __init__(self, id, host, width, height, url, filename):
        # ids may arrive as ints from some APIs; normalize to str
        self.id = str(id)
        self.host = host
        self.width = width
        self.height = height
        self.url = url
        self.filename = filename
|
||||
|
||||
|
||||
def get_desired_size(img_width: int, img_height: int, min_width: int, min_height: int):
    """Scale (img_width, img_height) uniformly so the shorter side meets
    its minimum, returning the new (width, height) as ints.

    Landscape images are scaled by height, portrait/square by width; the
    aspect ratio is preserved (up to int truncation).
    """
    landscape = img_width > img_height
    scale = (min_height / img_height) if landscape else (min_width / img_width)
    return int(img_width * scale), int(img_height * scale)
|
||||
|
||||
|
||||
def get_pexels_images(config: 'DatasetSyncCollectionConfig') -> List[Photo]:
    """Fetch all photos of a Pexels collection as Photo records.

    Follows the API's ``next_page`` links until exhausted, then builds a
    resized download URL for each photo based on the config minimums.
    Raises ``requests.HTTPError`` on any failed API call.
    """
    all_images = []
    next_page = f"https://api.pexels.com/v1/collections/{config.collection_id}?page=1&per_page=80&type=photos"

    # page through the collection; Pexels returns an absolute next_page url
    while True:
        response = requests.get(next_page, headers={
            "Authorization": f"{config.api_key}"
        })
        response.raise_for_status()
        data = response.json()
        all_images.extend(data['media'])
        if 'next_page' in data and data['next_page']:
            next_page = data['next_page']
        else:
            break

    photos = []
    for image in all_images:
        new_width, new_height = get_desired_size(image['width'], image['height'], config.min_width, config.min_height)
        # request a server-side resized variant instead of the original
        url = f"{image['src']['original']}?auto=compress&cs=tinysrgb&h={new_height}&w={new_width}"
        filename = os.path.basename(image['src']['original'])

        photos.append(Photo(
            id=image['id'],
            host="pexels",
            width=image['width'],
            height=image['height'],
            url=url,
            filename=filename
        ))

    return photos
|
||||
|
||||
|
||||
def get_unsplash_images(config: 'DatasetSyncCollectionConfig') -> List[Photo]:
    """Fetch all photos of an Unsplash collection as Photo records.

    Determines the page count from the response's ``Link`` header, then
    fetches the remaining pages.  Raises ``requests.HTTPError`` on any
    failed API call.
    """
    headers = {
        "Authorization": f"Client-ID {config.api_key}"
    }

    url = f"https://api.unsplash.com/collections/{config.collection_id}/photos?page=1&per_page=30"
    response = requests.get(url, headers=headers)
    response.raise_for_status()
    res_headers = response.headers
    # parse the link header to get the next page
    # 'Link': '<https://api.unsplash.com/collections/mIPWwLdfct8/photos?page=82>; rel="last", <https://api.unsplash.com/collections/mIPWwLdfct8/photos?page=2>; rel="next"'
    has_next_page = False
    if 'Link' in res_headers:
        has_next_page = True
        link_header = res_headers['Link']
        link_header = link_header.split(',')
        link_header = [link.strip() for link in link_header]
        link_header = [link.split(';') for link in link_header]
        # NOTE(review): .strip('"') only removes the trailing quote of
        # 'rel="last"', so the dict keys keep the embedded quote (e.g.
        # 'rel="last') — the lookup below depends on this quirk.
        link_header = [[link[0].strip('<>'), link[1].strip().strip('"')] for link in link_header]
        link_header = {link[1]: link[0] for link in link_header}

        # get page number from the rel="last" url's query string
        last_page = link_header['rel="last']
        last_page = last_page.split('?')[1]
        last_page = last_page.split('&')
        last_page = [param.split('=') for param in last_page]
        last_page = {param[0]: param[1] for param in last_page}
        last_page = int(last_page['page'])

    all_images = response.json()

    if has_next_page:
        # assume we start on page 1, so we don't need to get it again
        for page in tqdm.tqdm(range(2, last_page + 1)):
            url = f"https://api.unsplash.com/collections/{config.collection_id}/photos?page={page}&per_page=30"
            response = requests.get(url, headers=headers)
            response.raise_for_status()
            all_images.extend(response.json())

    photos = []
    for image in all_images:
        new_width, new_height = get_desired_size(image['width'], image['height'], config.min_width, config.min_height)
        # request a server-side resized variant of the raw image
        url = f"{image['urls']['raw']}&w={new_width}"
        filename = f"{image['id']}.jpg"

        photos.append(Photo(
            id=image['id'],
            host="unsplash",
            width=image['width'],
            height=image['height'],
            url=url,
            filename=filename
        ))

    return photos
|
||||
|
||||
|
||||
def get_img_paths(dir_path: str):
    """List full paths of image files directly inside *dir_path*.

    Creates the directory first if it is missing, so callers can treat a
    brand-new dataset directory the same as an existing one.
    """
    os.makedirs(dir_path, exist_ok=True)
    return [
        os.path.join(dir_path, name)
        for name in os.listdir(dir_path)
        if os.path.splitext(name)[1].lower() in img_exts
    ]
|
||||
|
||||
|
||||
def get_local_image_ids(dir_path: str):
    """Return the set of remote photo ids already present in *dir_path*.

    Local files are assumed to be named after their remote ids, e.g.
    ``abc123.jpg`` -> ``abc123`` (everything before the first dot).
    """
    # get_img_paths already creates dir_path when missing, so the extra
    # os.makedirs call here was redundant
    local_files = get_img_paths(dir_path)
    return {os.path.basename(file).split('.')[0] for file in local_files}
|
||||
|
||||
|
||||
def get_local_image_file_names(dir_path: str):
    """Return the set of image file names (with extension) in *dir_path*."""
    # get_img_paths already creates dir_path when missing, so the extra
    # os.makedirs call here was redundant
    local_files = get_img_paths(dir_path)
    return {os.path.basename(file) for file in local_files}
|
||||
|
||||
|
||||
def download_image(photo: Photo, dir_path: str, min_width: int = 1024, min_height: int = 1024,
                   timeout: float = 60.0):
    """Download *photo* into *dir_path* as ``photo.filename``.

    Args:
        photo: remote photo descriptor (its url is already sized).
        dir_path: destination directory, created if missing.
        min_width / min_height: reject photos smaller than these.
        timeout: seconds before the HTTP request is aborted (new parameter,
            default 60 — previously a stalled connection could hang forever).

    Raises:
        ValueError: if the photo is smaller than the minimum dimensions.
        requests.HTTPError: if the download request fails.
    """
    img_width = photo.width
    img_height = photo.height

    if img_width < min_width or img_height < min_height:
        raise ValueError(f"Skipping {photo.id} because it is too small: {img_width}x{img_height}")

    img_response = requests.get(photo.url, timeout=timeout)
    img_response.raise_for_status()
    os.makedirs(dir_path, exist_ok=True)

    filename = os.path.join(dir_path, photo.filename)
    with open(filename, 'wb') as file:
        file.write(img_response.content)
|
||||
|
||||
|
||||
def update_caption(img_path: str):
    """Migrate a legacy sibling .txt caption for img_path into a .json caption file.

    If a .json caption already exists next to the image, nothing is done.
    Otherwise the caption is read from a sibling .txt file (empty string when
    none exists), written as {"caption": ...} JSON, and the .txt file removed.
    """
    import json

    dir_path = os.path.dirname(img_path)
    filename_no_ext = os.path.splitext(os.path.basename(img_path))[0]
    json_path = os.path.join(dir_path, f"{filename_no_ext}.json")
    txt_path = os.path.join(dir_path, f"{filename_no_ext}.txt")

    # see if it exists
    if os.path.exists(json_path):
        # todo add poi and what not
        return  # we have a json file

    caption = ""
    # see if txt file exists
    if os.path.exists(txt_path):
        # read it
        with open(txt_path, 'r') as file:
            caption = file.read()

    # write json file; json.dump escapes quotes/newlines — the old f-string
    # template produced invalid JSON whenever a caption contained a double quote
    with open(json_path, 'w') as file:
        json.dump({"caption": caption}, file)

    # delete txt file only if it was actually there (the old unconditional
    # os.remove raised FileNotFoundError for images that never had a txt caption)
    if os.path.exists(txt_path):
        os.remove(txt_path)
|
||||
|
||||
|
||||
# def equalize_img(img_path: str):
|
||||
# input_path = img_path
|
||||
# output_path = os.path.join(img_root_path(img_path), COLOR_CORRECTED_DIR, os.path.basename(img_path))
|
||||
# os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
# process_img(
|
||||
# img_path=input_path,
|
||||
# output_path=output_path,
|
||||
# equalize=True,
|
||||
# max_size=2056,
|
||||
# white_balance=False,
|
||||
# gamma_correction=False,
|
||||
# strength=0.6,
|
||||
# )
|
||||
|
||||
|
||||
# def annotate_depth(img_path: str):
|
||||
# # make fake args
|
||||
# args = argparse.Namespace()
|
||||
# args.annotator = "midas"
|
||||
# args.res = 1024
|
||||
#
|
||||
# img = cv2.imread(img_path)
|
||||
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
#
|
||||
# output = annotate(img, args)
|
||||
#
|
||||
# output = output.astype('uint8')
|
||||
# output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||
#
|
||||
# os.makedirs(os.path.dirname(img_path), exist_ok=True)
|
||||
# output_path = os.path.join(img_root_path(img_path), DEPTH_DIR, os.path.basename(img_path))
|
||||
#
|
||||
# cv2.imwrite(output_path, output)
|
||||
|
||||
|
||||
# def invert_depth(img_path: str):
|
||||
# img = cv2.imread(img_path)
|
||||
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
# # invert the colors
|
||||
# img = cv2.bitwise_not(img)
|
||||
#
|
||||
# os.makedirs(os.path.dirname(img_path), exist_ok=True)
|
||||
# output_path = os.path.join(img_root_path(img_path), INVERTED_DEPTH_DIR, os.path.basename(img_path))
|
||||
# cv2.imwrite(output_path, img)
|
||||
|
||||
|
||||
#
|
||||
# # update our list of raw images
|
||||
# raw_images = get_img_paths(raw_dir)
|
||||
#
|
||||
# # update raw captions
|
||||
# for image_id in tqdm.tqdm(raw_images, desc="Updating raw captions"):
|
||||
# update_caption(image_id)
|
||||
#
|
||||
# # equalize images
|
||||
# for img_path in tqdm.tqdm(raw_images, desc="Equalizing images"):
|
||||
# if img_path not in eq_images:
|
||||
# equalize_img(img_path)
|
||||
#
|
||||
# # update our list of eq images
|
||||
# eq_images = get_img_paths(eq_dir)
|
||||
# # update eq captions
|
||||
# for image_id in tqdm.tqdm(eq_images, desc="Updating eq captions"):
|
||||
# update_caption(image_id)
|
||||
#
|
||||
# # annotate depth
|
||||
# depth_dir = os.path.join(root_dir, DEPTH_DIR)
|
||||
# depth_images = get_img_paths(depth_dir)
|
||||
# for img_path in tqdm.tqdm(eq_images, desc="Annotating depth"):
|
||||
# if img_path not in depth_images:
|
||||
# annotate_depth(img_path)
|
||||
#
|
||||
# depth_images = get_img_paths(depth_dir)
|
||||
#
|
||||
# # invert depth
|
||||
# inv_depth_dir = os.path.join(root_dir, INVERTED_DEPTH_DIR)
|
||||
# inv_depth_images = get_img_paths(inv_depth_dir)
|
||||
# for img_path in tqdm.tqdm(depth_images, desc="Inverting depth"):
|
||||
# if img_path not in inv_depth_images:
|
||||
# invert_depth(img_path)
|
||||
@@ -74,12 +74,18 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
):
|
||||
loss_target = self.train_config.loss_target
|
||||
|
||||
prior_mask_multiplier = None
|
||||
target_mask_multiplier = None
|
||||
|
||||
if self.train_config.inverted_mask_prior:
|
||||
# we need to make the noise prediction be a masked blending of noise and prior_pred
|
||||
prior_multiplier = 1.0 - mask_multiplier
|
||||
target = (noise * mask_multiplier) + (prior_pred * prior_multiplier)
|
||||
prior_mask_multiplier = 1.0 - mask_multiplier
|
||||
# target_mask_multiplier = mask_multiplier
|
||||
# mask_multiplier = 1.0
|
||||
target = noise
|
||||
# target = (noise * mask_multiplier) + (prior_pred * prior_mask_multiplier)
|
||||
# set masked multiplier to 1.0 so we dont double apply it
|
||||
mask_multiplier = 1.0
|
||||
# mask_multiplier = 1.0
|
||||
elif prior_pred is not None:
|
||||
# matching adapter prediction
|
||||
target = prior_pred
|
||||
@@ -128,6 +134,16 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# multiply by our mask
|
||||
loss = loss * mask_multiplier
|
||||
|
||||
if self.train_config.inverted_mask_prior:
|
||||
# add a loss to unmasked areas of the prior for unmasked regularization
|
||||
prior_loss = torch.nn.functional.mse_loss(
|
||||
prior_pred.float(),
|
||||
pred.float(),
|
||||
reduction="none"
|
||||
)
|
||||
prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier
|
||||
loss = loss + prior_loss
|
||||
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr:
|
||||
|
||||
@@ -20,4 +20,5 @@ k-diffusion
|
||||
open_clip_torch
|
||||
timm
|
||||
prodigyopt
|
||||
controlnet_aux==0.0.7
|
||||
controlnet_aux==0.0.7
|
||||
python-dotenv
|
||||
3
run.py
3
run.py
@@ -1,6 +1,9 @@
|
||||
import os
|
||||
import sys
|
||||
from typing import Union, OrderedDict
|
||||
from dotenv import load_dotenv
|
||||
# Load the .env file if it exists
|
||||
load_dotenv()
|
||||
|
||||
sys.path.insert(0, os.getcwd())
|
||||
# must come before ANY torch or fastai imports
|
||||
|
||||
@@ -17,6 +17,24 @@ def get_cwd_abs_path(path):
|
||||
return path
|
||||
|
||||
|
||||
def replace_env_vars_in_string(s: str) -> str:
    """
    Replace placeholders like ${VAR_NAME} with the value of the corresponding environment variable.
    If the environment variable is not set, raise an error.
    """
    pattern = re.compile(r'\$\{([^}]+)\}')

    def substitute(match: "re.Match") -> str:
        var_name = match.group(1)
        try:
            return os.environ[var_name]
        except KeyError:
            raise ValueError(
                f"Environment variable {var_name} not set. Please ensure it's defined before proceeding."
            ) from None

    return pattern.sub(substitute, s)
|
||||
|
||||
|
||||
def preprocess_config(config: OrderedDict, name: str = None):
|
||||
if "job" not in config:
|
||||
raise ValueError("config file must have a job key")
|
||||
@@ -81,13 +99,14 @@ def get_config(
|
||||
raise ValueError(f"Could not find config file {config_file_path}")
|
||||
|
||||
# if we found it, check if it is a json or yaml file
|
||||
if real_config_path.endswith('.json') or real_config_path.endswith('.jsonc'):
|
||||
with open(real_config_path, 'r', encoding='utf-8') as f:
|
||||
config = json.load(f, object_pairs_hook=OrderedDict)
|
||||
elif real_config_path.endswith('.yaml') or real_config_path.endswith('.yml'):
|
||||
with open(real_config_path, 'r', encoding='utf-8') as f:
|
||||
config = yaml.load(f, Loader=fixed_loader)
|
||||
else:
|
||||
raise ValueError(f"Config file {config_file_path} must be a json or yaml file")
|
||||
with open(real_config_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
content_with_env_replaced = replace_env_vars_in_string(content)
|
||||
if real_config_path.endswith('.json') or real_config_path.endswith('.jsonc'):
|
||||
config = json.loads(content_with_env_replaced, object_pairs_hook=OrderedDict)
|
||||
elif real_config_path.endswith('.yaml') or real_config_path.endswith('.yml'):
|
||||
config = yaml.load(content_with_env_replaced, Loader=fixed_loader)
|
||||
else:
|
||||
raise ValueError(f"Config file {config_file_path} must be a json or yaml file")
|
||||
|
||||
return preprocess_config(config, name)
|
||||
|
||||
@@ -133,6 +133,7 @@ class TrainConfig:
|
||||
# we will predict noise without the LoRA network and use the prediction as a target for
# unmasked region. It is unmasked regularization basically
|
||||
self.inverted_mask_prior = kwargs.get('inverted_mask_prior', False)
|
||||
self.inverted_mask_prior_multiplier = kwargs.get('inverted_mask_prior_multiplier', 0.5)
|
||||
|
||||
# legacy
|
||||
if match_adapter_assist and self.match_adapter_chance == 0.0:
|
||||
|
||||
Reference in New Issue
Block a user