Added dataset tagging and management tools using llava

This commit is contained in:
Jaret Burkett
2023-10-24 10:11:47 -06:00
parent dc36bbb3c8
commit 34eb563d55
12 changed files with 957 additions and 17 deletions

View File

@@ -0,0 +1,169 @@
import copy
import json
import os
from collections import OrderedDict
import gc
from typing import Type, Literal
import torch
from PIL import Image, ImageOps
from tqdm import tqdm
from .dataset_tools_config_modules import RAW_DIR, TRAIN_DIR, Step, ImgInfo
from .tools.image_tools import load_image
from .tools.llava_utils import LLaVAImageProcessor, long_prompt, short_prompt
from jobs.process import BaseExtensionProcess
from .tools.sync_tools import get_img_paths
# Image file extensions recognized by the dataset tools.
img_ext = ['.jpg', '.jpeg', '.png', '.webp']


def flush():
    """Release cached CUDA memory, then run the Python garbage collector."""
    torch.cuda.empty_cache()
    gc.collect()


# Bump whenever the on-disk ImgInfo schema changes so records can be migrated.
VERSION = 1
class SuperTagger(BaseExtensionProcess):
    """Caption and process dataset images with LLaVA, tracking per-image state.

    For every raw image a JSON sidecar (``ImgInfo``) is kept in the dataset's
    train directory recording which steps (captioning, contrast stretch, ...)
    have been completed, so reruns only do outstanding work. Optionally all
    records are aggregated into one master config file.
    """

    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
        parent_dir = config.get('parent_dir', None)
        self.dataset_paths: list[str] = config.get('dataset_paths', [])
        self.device = config.get('device', 'cuda')
        self.steps: list[Step] = config.get('steps', [])
        self.caption_method = config.get('caption_method', 'llava')
        # when True, already-completed image-manipulation steps are redone
        self.force_reprocess_img = config.get('force_reprocess_img', False)
        # accumulates {train_img_path: img_info_dict} for the optional master file
        self.master_dataset_dict = OrderedDict()
        self.dataset_master_config_file = config.get('dataset_master_config_file', None)
        if parent_dir is not None and len(self.dataset_paths) == 0:
            # find all folders in the parent dir; each folder is one dataset
            self.dataset_paths = [
                os.path.join(parent_dir, folder)
                for folder in os.listdir(parent_dir)
                if os.path.isdir(os.path.join(parent_dir, folder))
            ]
        else:
            # explicit paths were given; make sure they exist
            for dataset_path in self.dataset_paths:
                if not os.path.exists(dataset_path):
                    raise ValueError(f"Dataset path does not exist: {dataset_path}")
        print(f"Found {len(self.dataset_paths)} dataset paths")
        # model itself is loaded lazily inside process_image
        self.image_processor = LLaVAImageProcessor(device=self.device)

    def process_image(self, img_path: str):
        """Run all outstanding steps for one raw image and persist the results.

        Reads/writes ``<root>/train/<name>.json`` and saves the processed image
        to ``<root>/train/<name>`` whenever an image-manipulation step ran.
        """
        root_img_dir = os.path.dirname(os.path.dirname(img_path))
        filename = os.path.basename(img_path)
        filename_no_ext = os.path.splitext(filename)[0]
        train_dir = os.path.join(root_img_dir, TRAIN_DIR)
        train_img_path = os.path.join(train_dir, filename)
        json_path = os.path.join(train_dir, f"{filename_no_ext}.json")
        # check if json exists, if it does load it as image info
        if os.path.exists(json_path):
            with open(json_path, 'r') as f:
                img_info = ImgInfo(**json.load(f))
        else:
            img_info = ImgInfo()
        img_info.set_version(VERSION)
        # send steps to img info so it can store them
        img_info.add_steps(copy.deepcopy(self.steps))
        # image is loaded lazily only when a step actually needs the pixels
        image: Image.Image = None
        did_update_image = False
        # trigger reprocess of steps
        if self.force_reprocess_img:
            img_info.trigger_image_reprocess(steps=self.steps)
        # set the image as updated if it does not exist on disk yet
        if not os.path.exists(train_img_path):
            did_update_image = True
            image = load_image(img_path)
        if img_info.force_image_process:
            did_update_image = True
            image = load_image(img_path)
        # go through the needed steps
        for step in img_info.state.steps_to_complete:
            if step == 'caption':
                # load image
                if image is None:
                    image = load_image(img_path)
                if not self.image_processor.is_loaded:
                    print('Loading Model. Takes a while, especially the first time')
                    self.image_processor.load_model()
                img_info.caption = self.image_processor.generate_caption(image, prompt=long_prompt)
                img_info.mark_step_complete(step)
            elif step == 'caption_short':
                # load image
                if image is None:
                    image = load_image(img_path)
                if not self.image_processor.is_loaded:
                    print('Loading Model. Takes a while, especially the first time')
                    self.image_processor.load_model()
                img_info.caption_short = self.image_processor.generate_caption(image, prompt=short_prompt)
                img_info.mark_step_complete(step)
            elif step == 'contrast_stretch':
                # load image
                if image is None:
                    image = load_image(img_path)
                image = ImageOps.autocontrast(image, cutoff=(0.1, 0), preserve_tone=True)
                did_update_image = True
                img_info.mark_step_complete(step)
            else:
                raise ValueError(f"Unknown step: {step}")
        os.makedirs(os.path.dirname(train_img_path), exist_ok=True)
        if did_update_image:
            image.save(train_img_path)
        # only rewrite the sidecar when something actually changed
        if img_info.is_dirty:
            with open(json_path, 'w') as f:
                json.dump(img_info.to_dict(), f, indent=4)
        if self.dataset_master_config_file:
            # add to master dict
            self.master_dataset_dict[train_img_path] = img_info.to_dict()

    def run(self):
        """Collect raw images from every dataset and process them one by one."""
        super().run()
        imgs_to_process = []
        # find all images
        for dataset_path in self.dataset_paths:
            raw_dir = os.path.join(dataset_path, RAW_DIR)
            raw_image_paths = get_img_paths(raw_dir)
            for raw_image_path in raw_image_paths:
                imgs_to_process.append(raw_image_path)
        if len(imgs_to_process) == 0:
            print(f"No images to process")
        else:
            print(f"Found {len(imgs_to_process)} to process")
        for img_path in tqdm(imgs_to_process, desc="Processing images"):
            # try:
            #     self.process_image(img_path)
            # except Exception as e:
            #     print(f"Error processing {img_path}: {e}")
            #     continue
            self.process_image(img_path)
        if self.dataset_master_config_file is not None:
            # save it as json
            with open(self.dataset_master_config_file, 'w') as f:
                json.dump(self.master_dataset_dict, f, indent=4)
        del self.image_processor
        flush()

View File

@@ -0,0 +1,131 @@
import os
import shutil
from collections import OrderedDict
import gc
from typing import List
import torch
from tqdm import tqdm
from .dataset_tools_config_modules import DatasetSyncCollectionConfig, RAW_DIR, NEW_DIR
from .tools.sync_tools import get_unsplash_images, get_pexels_images, get_local_image_file_names, download_image, \
get_img_paths
from jobs.process import BaseExtensionProcess
def flush():
    """Drop cached CUDA allocations and collect garbage between heavy phases."""
    torch.cuda.empty_cache()
    gc.collect()
class SyncFromCollection(BaseExtensionProcess):
    """Sync images from remote photo collections (Unsplash / Pexels) into local dataset dirs.

    New images are downloaded into a temp dir first and promoted into ``raw/``
    only after every dataset has been synced, so a crash mid-download never
    leaves partial files in ``raw/``.
    """

    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
        # global fallbacks for datasets that don't specify their own minimums
        self.min_width = config.get('min_width', 1024)
        self.min_height = config.get('min_height', 1024)
        # propagate the global minimums into each dataset config that doesn't override them
        for dataset_config in config.get('dataset_sync', []):
            if 'min_width' not in dataset_config:
                dataset_config['min_width'] = self.min_width
            if 'min_height' not in dataset_config:
                dataset_config['min_height'] = self.min_height
        self.dataset_configs: List[DatasetSyncCollectionConfig] = [
            DatasetSyncCollectionConfig(**dataset_config)
            for dataset_config in config.get('dataset_sync', [])
        ]
        print(f"Found {len(self.dataset_configs)} dataset configs")

    def move_new_images(self, root_dir: str):
        """Promote freshly downloaded images from the temp dir into raw/ and drop the temp dir."""
        raw_dir = os.path.join(root_dir, RAW_DIR)
        new_dir = os.path.join(root_dir, NEW_DIR)
        new_images = get_img_paths(new_dir)
        for img_path in new_images:
            # move to raw
            new_path = os.path.join(raw_dir, os.path.basename(img_path))
            shutil.move(img_path, new_path)
        # remove new dir
        shutil.rmtree(new_dir)

    def sync_dataset(self, config: DatasetSyncCollectionConfig):
        """Download any images in the collection not yet present locally.

        Returns a counter dict: num_downloaded / num_skipped / bad / total
        (total excludes images that errored out).
        """
        if config.host == 'unsplash':
            get_images = get_unsplash_images
        elif config.host == 'pexels':
            get_images = get_pexels_images
        else:
            raise ValueError(f"Unknown host: {config.host}")
        results = {
            'num_downloaded': 0,
            'num_skipped': 0,
            'bad': 0,
            'total': 0,
        }
        photos = get_images(config)
        raw_dir = os.path.join(config.directory, RAW_DIR)
        new_dir = os.path.join(config.directory, NEW_DIR)
        raw_images = get_local_image_file_names(raw_dir)
        new_images = get_local_image_file_names(new_dir)
        for photo in tqdm(photos, desc=f"{config.host}-{config.collection_id}"):
            try:
                if photo.filename not in raw_images and photo.filename not in new_images:
                    # use the per-dataset minimums (defaulted from the globals in
                    # __init__), not self.min_width/self.min_height -- the old code
                    # silently ignored per-dataset overrides here
                    download_image(photo, new_dir, min_width=config.min_width, min_height=config.min_height)
                    results['num_downloaded'] += 1
                else:
                    results['num_skipped'] += 1
            except Exception as e:
                print(f" - BAD({photo.id}): {e}")
                results['bad'] += 1
                continue
            results['total'] += 1
        return results

    def print_results(self, results):
        """Print a one-line summary of a results counter dict."""
        print(
            f" - new:{results['num_downloaded']}, old:{results['num_skipped']}, bad:{results['bad']} total:{results['total']}")

    def run(self):
        """Sync every configured dataset, then promote downloads into raw/."""
        super().run()
        print(f"Syncing {len(self.dataset_configs)} datasets")
        all_results = None
        failed_datasets = []
        for dataset_config in tqdm(self.dataset_configs, desc="Syncing datasets", leave=True):
            try:
                results = self.sync_dataset(dataset_config)
                if all_results is None:
                    all_results = {**results}
                else:
                    for key, value in results.items():
                        all_results[key] += value
                self.print_results(results)
            except Exception as e:
                print(f" - FAILED: {e}")
                # requests exceptions carry a response with useful detail
                if 'response' in e.__dict__:
                    error = f"{e.response.status_code}: {e.response.text}"
                    print(f" - {error}")
                    failed_datasets.append({'dataset': dataset_config, 'error': error})
                else:
                    failed_datasets.append({'dataset': dataset_config, 'error': str(e)})
                continue
        print("Moving new images to raw")
        for dataset_config in self.dataset_configs:
            self.move_new_images(dataset_config.directory)
        print("Done syncing datasets")
        # all_results stays None when every dataset failed; don't crash on it
        if all_results is not None:
            self.print_results(all_results)
        if len(failed_datasets) > 0:
            print(f"Failed to sync {len(failed_datasets)} datasets")
            for failed in failed_datasets:
                print(f" - {failed['dataset'].host}-{failed['dataset'].collection_id}")
                print(f" - ERR: {failed['error']}")

View File

@@ -1,10 +1,7 @@
# This is an example extension for custom training. It is great for experimenting with new ideas.
from toolkit.extension import Extension
# This is for generic training (LoRA, Dreambooth, FineTuning)
class DatasetToolsExtension(Extension):
# uid must be unique, it is how the extension is identified
uid = "dataset_tools"
# name is the name of the extension for printing
@@ -19,7 +16,28 @@ class DatasetToolsExtension(Extension):
return DatasetTools
class SyncFromCollectionExtension(Extension):
    """Registers the SyncFromCollection process (remote collection downloader)."""
    # uid must be unique, it is how the extension is identified
    uid = "sync_from_collection"
    # display name used when printing/listing extensions
    name = "Sync from Collection"

    @classmethod
    def get_process(cls):
        # import your process class here so it is only loaded when needed and return it
        from .SyncFromCollection import SyncFromCollection
        return SyncFromCollection
class SuperTaggerExtension(Extension):
    """Registers the SuperTagger process (LLaVA captioning / dataset tagging)."""
    # uid must be unique, it is how the extension is identified
    uid = "super_tagger"
    # display name used when printing/listing extensions
    name = "Super Tagger"

    @classmethod
    def get_process(cls):
        # import your process class here so it is only loaded when needed and return it
        from .SuperTagger import SuperTagger
        return SuperTagger
# Extensions exported to the toolkit. Each extension class must appear exactly
# once -- the previous list accidentally contained DatasetToolsExtension twice,
# which would register the same extension a second time.
AI_TOOLKIT_EXTENSIONS = [
    DatasetToolsExtension,
    SyncFromCollectionExtension,
    SuperTaggerExtension,
]

View File

@@ -0,0 +1,145 @@
import json
from typing import Literal, Type, TYPE_CHECKING
# Supported remote photo hosts for dataset syncing.
Host: Type = Literal['unsplash', 'pexels']
# Sub-directory names used inside every dataset directory.
RAW_DIR = "raw"  # original source images
NEW_DIR = "_tmp"  # staging area for fresh downloads before promotion to raw
TRAIN_DIR = "train"  # processed images plus their JSON sidecars
DEPTH_DIR = "depth"  # reserved for depth-map annotations
from .tools.image_tools import Step, img_manipulation_steps
class DatasetSyncCollectionConfig:
    """Configuration for mirroring one remote photo collection locally.

    Required: host ('unsplash' or 'pexels'), collection_id, directory, api_key.
    Optional: min_width / min_height (both default to 1024).
    Raises ValueError when a required field is missing.
    """

    def __init__(self, **kwargs):
        self.host = kwargs.get('host', None)
        self.collection_id = kwargs.get('collection_id', None)
        self.directory = kwargs.get('directory', None)
        self.api_key = kwargs.get('api_key', None)
        self.min_width = kwargs.get('min_width', 1024)
        self.min_height = kwargs.get('min_height', 1024)

        # fail fast, in a fixed order, on any missing required field
        for required in ('host', 'collection_id', 'directory'):
            if getattr(self, required) is None:
                raise ValueError(f"{required} is required")
        if self.api_key is None:
            raise ValueError(f"api_key is required: {self.host}:{self.collection_id}")
class ImageState:
    """Tracks which processing steps an image has finished vs. still needs.

    Only ``steps_complete`` is serialized; pending steps are rebuilt from the
    current process configuration on every run.
    """

    def __init__(self, **kwargs):
        self.steps_complete = kwargs.get('steps_complete', [])
        self.steps_to_complete = kwargs.get('steps_to_complete', [])

    def to_dict(self):
        """Serialize for the JSON sidecar (completed steps only)."""
        return {'steps_complete': self.steps_complete}
class Rect:
    """Axis-aligned rectangle marking a point of interest on an image."""

    def __init__(self, **kwargs):
        # defaults to a zero-sized rect at the origin
        for attr in ('x', 'y', 'width', 'height'):
            setattr(self, attr, kwargs.get(attr, 0))

    def to_dict(self):
        """Serialize to a plain dict for JSON storage."""
        return {
            'x': self.x,
            'y': self.y,
            'width': self.width,
            'height': self.height,
        }
class ImgInfo:
    """Per-image metadata record stored as a JSON sidecar next to the train image.

    Holds captions, points of interest and step-completion state. ``is_dirty``
    means the record changed and must be written back to disk;
    ``force_image_process`` means the processed image file must be regenerated.
    """

    def __init__(self, **kwargs):
        self.version: int = kwargs.get('version', None)
        self.caption: str = kwargs.get('caption', None)
        self.caption_short: str = kwargs.get('caption_short', None)
        self.poi = [Rect(**poi) for poi in kwargs.get('poi', [])]
        self.state = ImageState(**kwargs.get('state', {}))
        self.force_image_process: bool = False
        self.is_dirty: bool = False
        # BUGFIX: the flags above must be initialized *before* upgrading.
        # _upgrade_state() sets is_dirty=True when it migrates a legacy record;
        # the old code assigned is_dirty = False afterwards, clobbering that
        # flag so upgraded records were never saved back to disk.
        self._upgrade_state()

    def _upgrade_state(self):
        """Migrate legacy records: infer completed steps from populated fields."""
        if self.caption is not None and 'caption' not in self.state.steps_complete:
            self.mark_step_complete('caption')
            self.is_dirty = True
        if self.caption_short is not None and 'caption_short' not in self.state.steps_complete:
            self.mark_step_complete('caption_short')
            self.is_dirty = True

    def to_dict(self):
        """Serialize for the JSON sidecar."""
        return {
            'version': self.version,
            'caption': self.caption,
            'caption_short': self.caption_short,
            'poi': [poi.to_dict() for poi in self.poi],
            'state': self.state.to_dict()
        }

    def mark_step_complete(self, step: Step):
        """Move a step from pending to complete and flag the record dirty."""
        if step not in self.state.steps_complete:
            self.state.steps_complete.append(step)
        if step in self.state.steps_to_complete:
            self.state.steps_to_complete.remove(step)
        self.is_dirty = True

    def add_step(self, step: Step):
        """Queue a step unless it is already pending or already done."""
        if step not in self.state.steps_to_complete and step not in self.state.steps_complete:
            self.state.steps_to_complete.append(step)

    def trigger_image_reprocess(self, steps):
        """Reset all image-manipulation steps so they re-run in the given order."""
        # remove all image manipulations from both pending and completed lists
        for step in img_manipulation_steps:
            if step in self.state.steps_to_complete:
                self.state.steps_to_complete.remove(step)
            if step in self.state.steps_complete:
                self.state.steps_complete.remove(step)
        self.force_image_process = True
        self.is_dirty = True
        # re-queue manipulations, keeping the order passed in the process file
        for step in steps:
            if step in img_manipulation_steps:
                self.add_step(step)

    def add_steps(self, steps: list[Step]):
        """Queue the configured steps, forcing a full image reprocess when needed.

        Manipulation steps are order-dependent: if any manipulation is newly
        pending, or the configured manipulation order differs from the order
        previously completed, all manipulations are redone from scratch.
        """
        for stage in steps:
            self.add_step(stage)
        is_manipulating_image = any([step in img_manipulation_steps for step in self.state.steps_to_complete])
        order_has_changed = False
        if not is_manipulating_image:
            # check to see if order has changed. No need to if already redoing it.
            # Will also detect steps that were removed from the config.
            target_img_manipulation_order = [step for step in steps if step in img_manipulation_steps]
            current_img_manipulation_order = [step for step in self.state.steps_complete if
                                              step in img_manipulation_steps]
            if target_img_manipulation_order != current_img_manipulation_order:
                order_has_changed = True
        if is_manipulating_image or order_has_changed:
            self.trigger_image_reprocess(steps)

    def to_json(self):
        """Serialize to a JSON string."""
        return json.dumps(self.to_dict())

    def set_version(self, version: int):
        """Record the tool version, dirtying the record if it changed."""
        if self.version != version:
            self.is_dirty = True
        self.version = version

View File

@@ -0,0 +1,29 @@
from typing import Literal, Type
import cv2
import numpy as np
from PIL import Image, ImageOps
# All step names a dataset process can request for an image.
Step: Type = Literal['caption', 'caption_short', 'create_mask', 'contrast_stretch']
# Steps that modify the image pixels themselves (their relative order matters;
# a changed order forces a full reprocess of the image).
img_manipulation_steps = ['contrast_stretch']
def pil_to_cv2(image):
    """Convert a PIL image to an OpenCV image (numpy array, BGR channel order)."""
    rgb_array = np.array(image)
    return cv2.cvtColor(rgb_array, cv2.COLOR_RGB2BGR)
def cv2_to_pil(image):
    """Convert an OpenCV image (BGR numpy array) back to a PIL RGB image."""
    rgb_array = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb_array)
def load_image(img_path: str):
    """Open an image file as RGB, honoring its EXIF orientation when possible."""
    img = Image.open(img_path).convert('RGB')
    try:
        # rotate/flip according to EXIF orientation; some files carry corrupt
        # EXIF data, in which case the unrotated image is returned as-is
        img = ImageOps.exif_transpose(img)
    except Exception:
        pass
    return img

View File

@@ -0,0 +1,129 @@
try:
from llava.model import LlavaLlamaForCausalLM
except ImportError:
# print("You need to manually install llava -> pip install --no-deps git+https://github.com/haotian-liu/LLaVA.git")
print("You need to manually install llava -> pip install --no-deps git+https://github.com/haotian-liu/LLaVA.git")
raise
# Verbose captioning prompt: asks for an exhaustive, comma-separated description.
long_prompt = 'caption this image. describe every single thing in the image in detail. Do not include any unnecessary words in your description for the sake of good grammar. I want many short statements that serve the single purpose of giving the most thorough description if items as possible in the smallest, comma separated way possible. be sure to describe people\'s moods, clothing, the environment, lighting, colors, and everything.'
# Terse prompt used to produce the short caption field.
short_prompt = 'caption this image in less than ten words'
prompts = [
    long_prompt,
    short_prompt,
]
# Boilerplate lead-ins stripped from model output (matched after lowercasing).
replacements = [
    ("the image features", ""),
    ("the image shows", ""),
    ("the image depicts", ""),
    ("the image is", ""),
]
import torch
from PIL import Image, ImageOps
from llava.conversation import conv_templates, SeparatorStyle
from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
from llava.utils import disable_torch_init
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria
img_ext = ['.jpg', '.jpeg', '.png', '.webp']
class LLaVAImageProcessor:
    """Wraps a 4-bit quantized LLaVA v1.5 model for image captioning.

    The model is loaded lazily via ``load_model()`` so importing this module
    stays cheap; callers should check ``is_loaded`` before generating.
    """

    def __init__(self, device='cuda'):
        self.device = device
        self.model: LlavaLlamaForCausalLM = None
        self.tokenizer: AutoTokenizer = None
        self.image_processor: CLIPImageProcessor = None
        self.is_loaded = False

    def load_model(self):
        """Load the quantized LLaVA checkpoint, its tokenizer and vision tower."""
        from llava.model import LlavaLlamaForCausalLM
        model_path = "4bit/llava-v1.5-13b-3GB"
        # kwargs = {"device_map": "auto"}
        kwargs = {"device_map": self.device}
        # NF4 double-quantized 4-bit weights with fp16 compute
        kwargs['load_in_4bit'] = True
        kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'
        )
        self.model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
        vision_tower = self.model.get_vision_tower()
        if not vision_tower.is_loaded:
            vision_tower.load_model()
        vision_tower.to(device=self.device)
        self.image_processor = vision_tower.image_processor
        self.is_loaded = True

    def clean_caption(self, cap):
        """Normalize raw model output into a lowercase, comma-separated tag list."""
        # remove any newlines
        cap = cap.replace("\n", ", ")
        cap = cap.replace("\r", ", ")
        # periods become commas so sentences merge into one tag list
        cap = cap.replace(".", ",")
        cap = cap.replace("\"", "")
        # remove unicode characters
        cap = cap.encode('ascii', 'ignore').decode('ascii')
        # make lowercase
        cap = cap.lower()
        # remove any extra spaces
        cap = " ".join(cap.split())
        # strip boilerplate lead-ins like "the image features"
        for replacement in replacements:
            cap = cap.replace(replacement[0], replacement[1])
        cap_list = cap.split(",")
        # trim whitespace
        cap_list = [c.strip() for c in cap_list]
        # remove empty strings
        cap_list = [c for c in cap_list if c != ""]
        # remove duplicates (dict.fromkeys preserves first-seen order)
        cap_list = list(dict.fromkeys(cap_list))
        # join back together
        cap = ", ".join(cap_list)
        return cap

    def generate_caption(self, image: Image, prompt: str = long_prompt):
        """Run one VQA-style generation for the image and return a cleaned caption."""
        # question = "how many dogs are in the picture?"
        disable_torch_init()
        conv_mode = "llava_v0"
        conv = conv_templates[conv_mode].copy()
        roles = conv.roles
        image_tensor = self.image_processor.preprocess([image], return_tensors='pt')['pixel_values'].half().cuda()
        inp = f"{roles[0]}: {prompt}"
        # prepend the image placeholder tokens expected by LLaVA
        inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        raw_prompt = conv.get_prompt()
        input_ids = tokenizer_image_token(raw_prompt, self.tokenizer, IMAGE_TOKEN_INDEX,
                                          return_tensors='pt').unsqueeze(0).cuda()
        # stop generation at the conversation separator
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids, images=image_tensor, do_sample=True, temperature=0.1,
                max_new_tokens=1024, use_cache=True, stopping_criteria=[stopping_criteria],
                top_p=0.9
            )
        # decode only the newly generated tokens (skip the prompt)
        outputs = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
        conv.messages[-1][-1] = outputs
        output = outputs.rsplit('</s>', 1)[0]
        return self.clean_caption(output)

    def generate_captions(self, image: Image):
        """Generate one caption per configured prompt (long and short)."""
        responses = []
        for prompt in prompts:
            output = self.generate_caption(image, prompt)
            responses.append(output)
        return responses

View File

@@ -0,0 +1,279 @@
import os
import requests
import tqdm
from typing import List, Optional, TYPE_CHECKING
def img_root_path(img_id: str):
    """Return the dataset root two levels above an image path (root/<subdir>/<file>)."""
    subdir = os.path.dirname(img_id)
    return os.path.dirname(subdir)
if TYPE_CHECKING:
from ..dataset_tools_config_modules import DatasetSyncCollectionConfig
img_exts = ['.jpg', '.jpeg', '.webp', '.png']
class Photo:
    """Lightweight record describing one remote photo to (maybe) download."""

    def __init__(
            self,
            id,
            host,
            width,
            height,
            url,
            filename
    ):
        # APIs return ids as ints or strings; normalize to str for comparisons
        self.id = str(id)
        self.host = host
        self.width = width
        self.height = height
        self.url = url
        self.filename = filename
def get_desired_size(img_width: int, img_height: int, min_width: int, min_height: int):
    """Scale (img_width, img_height) so the shorter side meets its minimum.

    Landscape images are scaled so height == min_height; portrait or square
    images so width == min_width. Aspect ratio is preserved and the resulting
    dimensions are truncated to ints.
    """
    if img_width > img_height:
        ratio = min_height / img_height
    else:
        ratio = min_width / img_width
    return int(img_width * ratio), int(img_height * ratio)
def get_pexels_images(config: 'DatasetSyncCollectionConfig') -> List[Photo]:
    """Fetch every photo in a Pexels collection, following API pagination.

    Returns Photo records whose URLs request a server-side resize to the
    minimum dimensions from the config.
    """
    media = []
    page_url = f"https://api.pexels.com/v1/collections/{config.collection_id}?page=1&per_page=80&type=photos"
    while page_url:
        response = requests.get(page_url, headers={
            "Authorization": f"{config.api_key}"
        })
        response.raise_for_status()
        payload = response.json()
        media.extend(payload['media'])
        # Pexels returns an absolute 'next_page' URL until the last page
        page_url = payload.get('next_page') or None
    photos = []
    for item in media:
        new_width, new_height = get_desired_size(item['width'], item['height'], config.min_width, config.min_height)
        resized_url = f"{item['src']['original']}?auto=compress&cs=tinysrgb&h={new_height}&w={new_width}"
        photos.append(Photo(
            id=item['id'],
            host="pexels",
            width=item['width'],
            height=item['height'],
            url=resized_url,
            filename=os.path.basename(item['src']['original'])
        ))
    return photos
def get_unsplash_images(config: 'DatasetSyncCollectionConfig') -> List[Photo]:
    """Fetch every photo in an Unsplash collection, following Link-header pagination.

    NOTE(review): the Link header is parsed by hand and relies on Unsplash's
    exact formatting -- including the odd quote placement in the
    ``link_header['rel="last']`` key below (``strip('"')`` only removes the
    trailing quote). If the header format changes this will KeyError.
    """
    headers = {
        # "Authorization": f"Client-ID {UNSPLASH_ACCESS_KEY}"
        "Authorization": f"Client-ID {config.api_key}"
    }
    # headers['Authorization'] = f"Bearer {token}"
    url = f"https://api.unsplash.com/collections/{config.collection_id}/photos?page=1&per_page=30"
    response = requests.get(url, headers=headers)
    response.raise_for_status()
    res_headers = response.headers
    # parse the link header to get the next page
    # 'Link': '<https://api.unsplash.com/collections/mIPWwLdfct8/photos?page=82>; rel="last", <https://api.unsplash.com/collections/mIPWwLdfct8/photos?page=2>; rel="next"'
    has_next_page = False
    if 'Link' in res_headers:
        has_next_page = True
        link_header = res_headers['Link']
        # split '<url>; rel="last", <url>; rel="next"' into [url, rel] pairs
        link_header = link_header.split(',')
        link_header = [link.strip() for link in link_header]
        link_header = [link.split(';') for link in link_header]
        link_header = [[link[0].strip('<>'), link[1].strip().strip('"')] for link in link_header]
        link_header = {link[1]: link[0] for link in link_header}
        # get page number from the "last" url's query string
        last_page = link_header['rel="last']
        last_page = last_page.split('?')[1]
        last_page = last_page.split('&')
        last_page = [param.split('=') for param in last_page]
        last_page = {param[0]: param[1] for param in last_page}
        last_page = int(last_page['page'])
    all_images = response.json()
    if has_next_page:
        # assume we start on page 1, so we don't need to get it again
        for page in tqdm.tqdm(range(2, last_page + 1)):
            url = f"https://api.unsplash.com/collections/{config.collection_id}/photos?page={page}&per_page=30"
            response = requests.get(url, headers=headers)
            response.raise_for_status()
            all_images.extend(response.json())
    photos = []
    for image in all_images:
        new_width, new_height = get_desired_size(image['width'], image['height'], config.min_width, config.min_height)
        # Unsplash raw URLs accept a width param; height follows the aspect ratio
        url = f"{image['urls']['raw']}&w={new_width}"
        filename = f"{image['id']}.jpg"
        photos.append(Photo(
            id=image['id'],
            host="unsplash",
            width=image['width'],
            height=image['height'],
            url=url,
            filename=filename
        ))
    return photos
def get_img_paths(dir_path: str):
    """List full paths of image files directly inside dir_path (dir is created if missing)."""
    os.makedirs(dir_path, exist_ok=True)
    entries = os.listdir(dir_path)
    # keep only files with a recognized image extension
    image_names = [name for name in entries if os.path.splitext(name)[1].lower() in img_exts]
    return [os.path.join(dir_path, name) for name in image_names]
def get_local_image_ids(dir_path: str):
    """Return the set of filename stems (e.g. Unsplash IDs) of images in dir_path."""
    os.makedirs(dir_path, exist_ok=True)
    # local files are assumed to be named after their remote IDs, e.g. 'abc123.jpg'
    return {os.path.basename(path).split('.')[0] for path in get_img_paths(dir_path)}
def get_local_image_file_names(dir_path: str):
    """Return the set of image file names (with extension) present in dir_path."""
    os.makedirs(dir_path, exist_ok=True)
    return {os.path.basename(path) for path in get_img_paths(dir_path)}
def download_image(photo: Photo, dir_path: str, min_width: int = 1024, min_height: int = 1024):
    """Download one photo into dir_path.

    Raises ValueError when the photo is smaller than the required minimums,
    and requests.HTTPError when the download fails.
    """
    img_width = photo.width
    img_height = photo.height
    if img_width < min_width or img_height < min_height:
        raise ValueError(f"Skipping {photo.id} because it is too small: {img_width}x{img_height}")
    img_response = requests.get(photo.url)
    img_response.raise_for_status()
    os.makedirs(dir_path, exist_ok=True)
    target_path = os.path.join(dir_path, photo.filename)
    with open(target_path, 'wb') as file:
        file.write(img_response.content)
def update_caption(img_path: str):
    """Migrate a legacy ``<name>.txt`` caption into the ``<name>.json`` sidecar format.

    No-op when the json sidecar already exists. Otherwise the txt caption (or
    an empty string) is written into the json file and the txt file removed.
    """
    import json  # local import: this module has no other json dependency
    dir_path = os.path.dirname(img_path)
    filename_no_ext = os.path.splitext(os.path.basename(img_path))[0]
    json_path = os.path.join(dir_path, f"{filename_no_ext}.json")
    if os.path.exists(json_path):
        # todo add poi and what not
        return  # we already have a json file
    caption = ""
    txt_path = os.path.join(dir_path, f"{filename_no_ext}.txt")
    had_txt = os.path.exists(txt_path)
    if had_txt:
        with open(txt_path, 'r') as file:
            caption = file.read()
    # BUGFIX: use a real JSON serializer -- the old f-string template produced
    # invalid JSON whenever the caption contained quotes, backslashes or newlines
    with open(json_path, 'w') as file:
        json.dump({"caption": caption}, file)
    # BUGFIX: only delete the txt file if it existed; the old unconditional
    # os.remove raised FileNotFoundError for images with no caption at all
    if had_txt:
        os.remove(txt_path)
# def equalize_img(img_path: str):
# input_path = img_path
# output_path = os.path.join(img_root_path(img_path), COLOR_CORRECTED_DIR, os.path.basename(img_path))
# os.makedirs(os.path.dirname(output_path), exist_ok=True)
# process_img(
# img_path=input_path,
# output_path=output_path,
# equalize=True,
# max_size=2056,
# white_balance=False,
# gamma_correction=False,
# strength=0.6,
# )
# def annotate_depth(img_path: str):
# # make fake args
# args = argparse.Namespace()
# args.annotator = "midas"
# args.res = 1024
#
# img = cv2.imread(img_path)
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#
# output = annotate(img, args)
#
# output = output.astype('uint8')
# output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
#
# os.makedirs(os.path.dirname(img_path), exist_ok=True)
# output_path = os.path.join(img_root_path(img_path), DEPTH_DIR, os.path.basename(img_path))
#
# cv2.imwrite(output_path, output)
# def invert_depth(img_path: str):
# img = cv2.imread(img_path)
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# # invert the colors
# img = cv2.bitwise_not(img)
#
# os.makedirs(os.path.dirname(img_path), exist_ok=True)
# output_path = os.path.join(img_root_path(img_path), INVERTED_DEPTH_DIR, os.path.basename(img_path))
# cv2.imwrite(output_path, img)
#
# # update our list of raw images
# raw_images = get_img_paths(raw_dir)
#
# # update raw captions
# for image_id in tqdm.tqdm(raw_images, desc="Updating raw captions"):
# update_caption(image_id)
#
# # equalize images
# for img_path in tqdm.tqdm(raw_images, desc="Equalizing images"):
# if img_path not in eq_images:
# equalize_img(img_path)
#
# # update our list of eq images
# eq_images = get_img_paths(eq_dir)
# # update eq captions
# for image_id in tqdm.tqdm(eq_images, desc="Updating eq captions"):
# update_caption(image_id)
#
# # annotate depth
# depth_dir = os.path.join(root_dir, DEPTH_DIR)
# depth_images = get_img_paths(depth_dir)
# for img_path in tqdm.tqdm(eq_images, desc="Annotating depth"):
# if img_path not in depth_images:
# annotate_depth(img_path)
#
# depth_images = get_img_paths(depth_dir)
#
# # invert depth
# inv_depth_dir = os.path.join(root_dir, INVERTED_DEPTH_DIR)
# inv_depth_images = get_img_paths(inv_depth_dir)
# for img_path in tqdm.tqdm(depth_images, desc="Inverting depth"):
# if img_path not in inv_depth_images:
# invert_depth(img_path)

View File

@@ -74,12 +74,18 @@ class SDTrainer(BaseSDTrainProcess):
):
loss_target = self.train_config.loss_target
prior_mask_multiplier = None
target_mask_multiplier = None
if self.train_config.inverted_mask_prior:
# we need to make the noise prediction be a masked blending of noise and prior_pred
prior_multiplier = 1.0 - mask_multiplier
target = (noise * mask_multiplier) + (prior_pred * prior_multiplier)
prior_mask_multiplier = 1.0 - mask_multiplier
# target_mask_multiplier = mask_multiplier
# mask_multiplier = 1.0
target = noise
# target = (noise * mask_multiplier) + (prior_pred * prior_mask_multiplier)
# set masked multiplier to 1.0 so we dont double apply it
mask_multiplier = 1.0
# mask_multiplier = 1.0
elif prior_pred is not None:
# matching adapter prediction
target = prior_pred
@@ -128,6 +134,16 @@ class SDTrainer(BaseSDTrainProcess):
# multiply by our mask
loss = loss * mask_multiplier
if self.train_config.inverted_mask_prior:
# add a loss on the unmasked areas of the prior for unmasked regularization
prior_loss = torch.nn.functional.mse_loss(
prior_pred.float(),
pred.float(),
reduction="none"
)
prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier
loss = loss + prior_loss
loss = loss.mean([1, 2, 3])
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr:

View File

@@ -20,4 +20,5 @@ k-diffusion
open_clip_torch
timm
prodigyopt
controlnet_aux==0.0.7
controlnet_aux==0.0.7
python-dotenv

3
run.py
View File

@@ -1,6 +1,9 @@
import os
import sys
from typing import Union, OrderedDict
from dotenv import load_dotenv
# Load the .env file if it exists
load_dotenv()
sys.path.insert(0, os.getcwd())
# must come before ANY torch or fastai imports

View File

@@ -17,6 +17,24 @@ def get_cwd_abs_path(path):
return path
def replace_env_vars_in_string(s: str) -> str:
    """
    Replace placeholders like ${VAR_NAME} with the value of the corresponding environment variable.
    If the environment variable is not set, raise an error.
    """
    placeholder = re.compile(r'\$\{([^}]+)\}')

    def substitute(match):
        var_name = match.group(1)
        try:
            return os.environ[var_name]
        except KeyError:
            raise ValueError(f"Environment variable {var_name} not set. Please ensure it's defined before proceeding.")

    return placeholder.sub(substitute, s)
def preprocess_config(config: OrderedDict, name: str = None):
if "job" not in config:
raise ValueError("config file must have a job key")
@@ -81,13 +99,14 @@ def get_config(
raise ValueError(f"Could not find config file {config_file_path}")
# if we found it, check if it is a json or yaml file
if real_config_path.endswith('.json') or real_config_path.endswith('.jsonc'):
with open(real_config_path, 'r', encoding='utf-8') as f:
config = json.load(f, object_pairs_hook=OrderedDict)
elif real_config_path.endswith('.yaml') or real_config_path.endswith('.yml'):
with open(real_config_path, 'r', encoding='utf-8') as f:
config = yaml.load(f, Loader=fixed_loader)
else:
raise ValueError(f"Config file {config_file_path} must be a json or yaml file")
with open(real_config_path, 'r', encoding='utf-8') as f:
content = f.read()
content_with_env_replaced = replace_env_vars_in_string(content)
if real_config_path.endswith('.json') or real_config_path.endswith('.jsonc'):
config = json.loads(content_with_env_replaced, object_pairs_hook=OrderedDict)
elif real_config_path.endswith('.yaml') or real_config_path.endswith('.yml'):
config = yaml.load(content_with_env_replaced, Loader=fixed_loader)
else:
raise ValueError(f"Config file {config_file_path} must be a json or yaml file")
return preprocess_config(config, name)

View File

@@ -133,6 +133,7 @@ class TrainConfig:
# we will predict noise without the LoRA network and use the prediction as a target for
# the unmasked region. It is unmasked regularization basically
self.inverted_mask_prior = kwargs.get('inverted_mask_prior', False)
self.inverted_mask_prior_multiplier = kwargs.get('inverted_mask_prior_multiplier', 0.5)
# legacy
if match_adapter_assist and self.match_adapter_chance == 0.0: