# NOTE(review): SOURCE arrived as a newline-collapsed git patch. This unit is
# the new file toolkit/control_generator.py. The same patch also:
#   - added git+https://github.com/jaretburkett/easy_dwpose.git and bumped
#     controlnet_aux==0.0.9 -> 0.0.10 in requirements.txt (keep in sync with
#     the lazy imports below), and
#   - rewired toolkit/dataloader_mixins.py ControlCachingMixin to delegate
#     control generation to this ControlGenerator class.
import gc
import math
import os
import torch
from typing import Literal, Optional
from PIL import Image, ImageFilter, ImageOps
from PIL.ImageOps import exif_transpose
from tqdm import tqdm

from torchvision import transforms

# suppress noisy warnings from the third-party detector libs loaded lazily below
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


def flush(garbage_collect=True):
    """Release cached CUDA memory; optionally run the garbage collector too."""
    torch.cuda.empty_cache()
    if garbage_collect:
        gc.collect()


# Supported control-image kinds.
ControlTypes = Literal['depth', 'pose', 'line', 'inpaint', 'mask']

img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']


class ControlGenerator:
    """Generates control images (depth / pose / line / inpaint / mask) for a
    source image, caching results in a sibling ``_controls`` folder.

    Detector models are loaded lazily on first use and kept on ``self`` so
    repeated calls reuse them; call :meth:`cleanup` when done to free them.
    If ``sd`` is given, the model it wraps is unloaded from the device before
    the first control model loads and restored in :meth:`cleanup`.
    """

    def __init__(self, device, sd=None):
        self.device = device
        # Optional model wrapper; unloaded (if not None) before generating.
        self.sd = sd
        self.has_unloaded = False
        self.control_depth_model = None
        self.control_pose_model = None
        self.control_line_model = None
        self.control_bg_remover = None
        self.debug = False
        # When True, always regenerate instead of reusing a cached control.
        self.regen = False

    def get_control_path(self, img_path, control_type: ControlTypes):
        """Return the path of the cached control image for ``img_path``,
        generating (and caching) it first if it does not exist yet."""
        if self.regen:
            return self._generate_control(img_path, control_type)
        controls_folder = os.path.join(os.path.dirname(img_path), '_controls')
        file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
        file_name_no_ext_control = f"{file_name_no_ext}.{control_type}"
        for ext in img_ext_list:
            possible_path = os.path.join(
                controls_folder, file_name_no_ext_control + ext)
            if os.path.exists(possible_path):
                return possible_path
        # not cached yet; generate it now
        return self._generate_control(img_path, control_type)

    def debug_print(self, *args, **kwargs):
        """print() only when debug mode is enabled."""
        if self.debug:
            print(*args, **kwargs)

    def _generate_control(self, img_path, control_type):
        """Generate the requested control image for ``img_path``, save it under
        the ``_controls`` folder next to the image, and return its path.

        Raises:
            Exception: if ``control_type`` is not one of ControlTypes.
        """
        device = self.device
        image: Optional[Image.Image] = None

        controls_folder = os.path.join(os.path.dirname(img_path), '_controls')
        file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]

        # We need to generate a control: free the main model first (once) so
        # the detector models fit on the device.
        if not self.has_unloaded:
            if self.sd is not None:
                print("Unloading model to generate controls")
                self.sd.set_device_state_preset('unload')
            self.has_unloaded = True

        if image is None:
            # make sure image is loaded if we haven't loaded it with another control
            image = Image.open(img_path).convert('RGB')
            image = exif_transpose(image)

        # resize to a max of 1 megapixel, preserving aspect ratio
        max_size = 1024 * 1024
        w, h = image.size
        if w * h > max_size:
            scale = math.sqrt(max_size / (w * h))
            w = int(w * scale)
            h = int(h * scale)
            image = image.resize((w, h), Image.BICUBIC)

        save_path = os.path.join(
            controls_folder, f"{file_name_no_ext}.{control_type}.jpg")
        os.makedirs(controls_folder, exist_ok=True)
        if control_type == 'depth':
            self.debug_print("Generating depth control")
            if self.control_depth_model is None:
                from transformers import pipeline
                self.control_depth_model = pipeline(
                    task="depth-estimation",
                    model="depth-anything/Depth-Anything-V2-Large-hf",
                    device=device,
                    torch_dtype=torch.float16
                )
            img = image.copy()
            in_size = img.size
            output = self.control_depth_model(img)
            # assumed shape (1, H, W) with values 0-255 per this pipeline's
            # output — TODO confirm against the model card
            out_tensor = output["predicted_depth"]
            out_tensor = out_tensor.clamp(0, 255)
            out_tensor = out_tensor.squeeze(0).cpu().numpy()
            img = Image.fromarray(out_tensor.astype('uint8'))
            img = img.resize(in_size, Image.LANCZOS)
            img.save(save_path)
            return save_path
        elif control_type == 'pose':
            self.debug_print("Generating pose control")
            if self.control_pose_model is None:
                try:
                    import onnxruntime
                    onnxruntime.set_default_logger_severity(3)
                except ImportError:
                    raise ImportError(
                        "onnxruntime is not installed. Please install it with pip install onnxruntime or onnxruntime-gpu")
                try:
                    from easy_dwpose import DWposeDetector
                    self.control_pose_model = DWposeDetector(
                        device=str(device))
                except ImportError:
                    raise ImportError(
                        "easy-dwpose is not installed. Please install it with pip install easy-dwpose")
            img = image.copy()

            # detect at roughly the image's own resolution
            detect_res = int(math.sqrt(img.size[0] * img.size[1]))
            img = self.control_pose_model(
                img, output_type="pil", include_hands=True, include_face=True, detect_resolution=detect_res)
            img = img.convert('RGB')
            img.save(save_path)
            return save_path

        elif control_type == 'line':
            self.debug_print("Generating line control")
            if self.control_line_model is None:
                from controlnet_aux import TEEDdetector
                self.control_line_model = TEEDdetector.from_pretrained(
                    "fal-ai/teed", filename="5_model.pth").to(device)
            img = image.copy()
            img = self.control_line_model(img, detect_resolution=1024)
            # binarize: pixels above 128 become 255, everything else 0
            # img = img.filter(ImageFilter.GaussianBlur(radius=1))
            img = img.point(lambda p: p > 128 and 255)
            img = img.convert('RGB')
            img.save(save_path)
            return save_path
        elif control_type == 'inpaint' or control_type == 'mask':
            self.debug_print("Generating inpaint/mask control")
            img = image.copy()
            if self.control_bg_remover is None:
                from transformers import AutoModelForImageSegmentation
                self.control_bg_remover = AutoModelForImageSegmentation.from_pretrained(
                    'ZhengPeng7/BiRefNet_HR',
                    trust_remote_code=True,
                    revision="595e212b3eaa6a1beaad56cee49749b1e00b1596",
                    torch_dtype=torch.float16
                ).to(device)
                self.control_bg_remover.eval()

            image_size = (1024, 1024)
            transform_image = transforms.Compose([
                transforms.Resize(image_size),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [
                                     0.229, 0.224, 0.225])
            ])

            # BUGFIX: move the tensor to self.device instead of hard-coded
            # 'cuda' (the original broke on any non-default device).
            input_images = transform_image(img).unsqueeze(
                0).to(device).to(torch.float16)

            # Prediction
            preds = self.control_bg_remover(input_images)[-1].sigmoid().cpu()
            pred = preds[0].squeeze()
            pred_pil = transforms.ToPILImage()(pred)
            mask = pred_pil.resize(img.size)
            if control_type == 'inpaint':
                # inpainting feature currently only supports "erased" section desired to inpaint
                mask = ImageOps.invert(mask)
                img.putalpha(mask)
                # webp so the alpha channel survives
                save_path = os.path.join(
                    controls_folder, f"{file_name_no_ext}.{control_type}.webp")
            else:
                img = mask
                img = img.convert('RGB')
            img.save(save_path)
            return save_path
        else:
            raise Exception(f"Error: unknown control type {control_type}")

    def cleanup(self):
        """Drop all detector models, restore the wrapped model's device state
        (if one was unloaded), and flush CUDA caches."""
        self.control_depth_model = None
        self.control_pose_model = None
        self.control_line_model = None
        self.control_bg_remover = None
        if self.sd is not None and self.has_unloaded:
            self.sd.restore_device_state()
            self.has_unloaded = False

        flush()


if __name__ == "__main__":
    import argparse
    import time
    import transformers
    transformers.logging.set_verbosity_error()

    control_times = {
        'depth': 0,
        'pose': 0,
        'line': 0,
        'inpaint': 0,
        'mask': 0
    }

    controls = control_times.keys()

    parser = argparse.ArgumentParser(description="Generate control images")
    parser.add_argument("img_dir", type=str, help="Path to image directory")
    parser.add_argument('--debug', action='store_true',
                        help="Enable debug mode")
    parser.add_argument('--regen', action='store_true',
                        help="Regenerate all controls")

    args = parser.parse_args()
    img_dir = args.img_dir
    if not os.path.exists(img_dir):
        print(f"Error: {img_dir} does not exist")
        exit()
    if not os.path.isdir(img_dir):
        print(f"Error: {img_dir} is not a directory")
        exit()

    # find images (skipping already-generated controls and dotfiles)
    img_list = []
    for root, dirs, files in os.walk(img_dir):
        for file in files:
            if "_controls" in root:
                continue
            if file.startswith('.'):
                continue
            if file.lower().endswith(tuple(img_ext_list)):
                img_list.append(os.path.join(root, file))
    if len(img_list) == 0:
        print(f"Error: no images found in {img_dir}")
        exit()

    # BUGFIX: construct the generator ONCE. The original rebuilt it inside
    # the inner per-control loop, so every lazily-loaded detector model was
    # reloaded on every single call — defeating the instance caching and
    # making the warm-up-skipping timing numbers meaningless.
    control_gen = ControlGenerator(torch.device('cuda'))
    control_gen.debug = args.debug
    control_gen.regen = args.regen

    idx = 0
    for img_path in tqdm(img_list):
        for control in controls:
            start = time.time()
            control_path = control_gen.get_control_path(img_path, control)
            end = time.time()
            # don't track the first 2 images (model-load warm-up)
            if idx < 2:
                continue
            control_times[control] += end - start
        idx += 1

    control_gen.cleanup()

    # BUGFIX: guard the average — the original divided by (idx - 2) and
    # crashed with ZeroDivisionError (or printed negative garbage) whenever
    # the directory held fewer than 3 images.
    timed_count = idx - 2
    if timed_count > 0:
        for control in controls:
            control_times[control] /= timed_count
            print(
                f"Avg time for {control} control: {control_times[control]:.2f} seconds")

    print("Done")
we get here, we need to generate the control - return None + self.control_generator: ControlGenerator = None def add_control_path_to_file_item(self: 'AiToolkitDataset', file_item: 'FileItemDTO', control_path: str, control_type: ControlTypes): if control_type == 'inpaint': @@ -1989,136 +1976,23 @@ class ControlCachingMixin: return with torch.no_grad(): print_acc(f"Generating controls for {self.dataset_path}") - - has_unloaded = False device = self.sd.device - # controls 'depth', 'line', 'pose', 'inpaint', 'mask' + self.control_generator = ControlGenerator( + device=device, + sd=self.sd, + ) # use tqdm to show progress - i = 0 for file_item in tqdm(self.file_list, desc=f'Generating Controls'): - coltrols_folder = os.path.join(os.path.dirname(file_item.path), '_controls') - file_name_no_ext = os.path.splitext(os.path.basename(file_item.path))[0] - - image: Image = None - for control_type in self.dataset_config.controls: - control_path = self.get_control_path(file_item, control_type) + # generates the control if it is not already there + control_path = self.control_generator.get_control_path(file_item.path, control_type) if control_path is not None: self.add_control_path_to_file_item(file_item, control_path, control_type) - else: - # we need to generate the control. 
Unload model if not unloaded - if not has_unloaded: - print("Unloading model to generate controls") - self.sd.set_device_state_preset('unload') - has_unloaded = True - - if image is None: - # make sure image is loaded if we havent loaded it with another control - image = Image.open(file_item.path).convert('RGB') - image = exif_transpose(image) - - # resize to a max of 1mp - max_size = 1024 * 1024 - - w, h = image.size - if w * h > max_size: - scale = math.sqrt(max_size / (w * h)) - w = int(w * scale) - h = int(h * scale) - image = image.resize((w, h), Image.BICUBIC) - - save_path = os.path.join(coltrols_folder, f"{file_name_no_ext}.{control_type}.jpg") - os.makedirs(coltrols_folder, exist_ok=True) - if control_type == 'depth': - if self.control_depth_model is None: - from transformers import pipeline - self.control_depth_model = pipeline( - task="depth-estimation", - model="depth-anything/Depth-Anything-V2-Large-hf", - device=device, - torch_dtype=torch.float16 - ) - img = image.copy() - in_size = img.size - output = self.control_depth_model(img) - out_tensor = output["predicted_depth"] # shape (1, H, W) 0 - 255 - out_tensor = out_tensor.clamp(0, 255) - out_tensor = out_tensor.squeeze(0).cpu().numpy() - img = Image.fromarray(out_tensor.astype('uint8')) - img = img.resize(in_size, Image.LANCZOS) - img.save(save_path) - self.add_control_path_to_file_item(file_item, save_path, control_type) - elif control_type == 'pose': - if self.control_pose_model is None: - from controlnet_aux import OpenposeDetector - self.control_pose_model = OpenposeDetector.from_pretrained("lllyasviel/Annotators").to(device) - img = image.copy() - - detect_res = int(math.sqrt(img.size[0] * img.size[1])) - img = self.control_pose_model(img, hand_and_face=True, detect_resolution=detect_res, image_resolution=detect_res) - img = img.convert('RGB') - img.save(save_path) - self.add_control_path_to_file_item(file_item, save_path, control_type) - - elif control_type == 'line': - if 
self.control_line_model is None: - from controlnet_aux import TEEDdetector - self.control_line_model = TEEDdetector.from_pretrained("fal-ai/teed", filename="5_model.pth").to(device) - img = image.copy() - img = self.control_line_model(img, detect_resolution=1024) - img = img.convert('RGB') - img.save(save_path) - self.add_control_path_to_file_item(file_item, save_path, control_type) - elif control_type == 'inpaint' or control_type == 'mask': - img = image.copy() - if self.control_bg_remover is None: - from transformers import AutoModelForImageSegmentation - self.control_bg_remover = AutoModelForImageSegmentation.from_pretrained( - 'ZhengPeng7/BiRefNet_HR', - trust_remote_code=True, - revision="595e212b3eaa6a1beaad56cee49749b1e00b1596", - torch_dtype=torch.float16 - ).to(device) - self.control_bg_remover.eval() - - image_size = (1024, 1024) - transform_image = transforms.Compose([ - transforms.Resize(image_size), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) - ]) - - input_images = transform_image(img).unsqueeze(0).to('cuda').to(torch.float16) - - # Prediction - preds = self.control_bg_remover(input_images)[-1].sigmoid().cpu() - pred = preds[0].squeeze() - pred_pil = transforms.ToPILImage()(pred) - mask = pred_pil.resize(img.size) - if control_type == 'inpaint': - # inpainting feature currently only supports "erased" section desired to inpaint - mask = ImageOps.invert(mask) - img.putalpha(mask) - save_path = os.path.join(coltrols_folder, f"{file_name_no_ext}.{control_type}.webp") - else: - img = mask - img = img.convert('RGB') - img.save(save_path) - self.add_control_path_to_file_item(file_item, save_path, control_type) - else: - raise Exception(f"Error: unknown control type {control_type}") - i += 1 # remove models - self.control_depth_model = None - self.control_pose_model = None - self.control_line_model = None - self.control_bg_remover = None + self.control_generator.cleanup() + self.control_generator = None 
flush() - - # restore device state - if has_unloaded: - self.sd.restore_device_state()