From 96ba2fd129a31ac0e63826a535a667ad41bd4723 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Wed, 9 Apr 2025 13:35:04 -0600 Subject: [PATCH] Added methods to the dataloader to automatically generate controls for line, mask, inpainting, depth, and pose. --- requirements.txt | 5 +- toolkit/config_modules.py | 8 ++ toolkit/data_loader.py | 11 +- toolkit/dataloader_mixins.py | 190 ++++++++++++++++++++++++++++-- toolkit/stable_diffusion_model.py | 2 + 5 files changed, 205 insertions(+), 11 deletions(-) diff --git a/requirements.txt b/requirements.txt index 33e6bd2a..9a05e197 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ k-diffusion open_clip_torch timm prodigyopt -controlnet_aux==0.0.7 +controlnet_aux==0.0.9 python-dotenv bitsandbytes hf_transfer @@ -35,4 +35,5 @@ peft gradio python-slugify opencv-python -pytorch-wavelets==1.3.0 \ No newline at end of file +pytorch-wavelets==1.3.0 +matplotlib==3.10.1 \ No newline at end of file diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index de6c8ba1..56b008d2 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -687,6 +687,7 @@ class SliderConfig: self.targets.append(target) print(f"Built {len(self.targets)} slider targets (with permutations)") +ControlTypes = Literal['depth', 'line', 'pose', 'inpaint', 'mask'] class DatasetConfig: """ @@ -803,6 +804,13 @@ class DatasetConfig: # debug the frame count and frame selection. You dont need this. It is for debugging. self.debug: bool = kwargs.get('debug', False) + + # automatic controls + self.controls: List[ControlTypes] = kwargs.get('controls', []) + if isinstance(self.controls, str): + self.controls = [self.controls] + # remove empty strings + self.controls = [control for control in self.controls if control.strip() != ''] def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]: diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py index aba48b4d..47f2bfef 100644 --- a/toolkit/data_loader.py +++ b/toolkit/data_loader.py @@ -18,7 +18,7 @@ import albumentations as A from toolkit.buckets import get_bucket_for_image_size, BucketResolution from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config -from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin +from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin, ControlCachingMixin from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO from toolkit.print import print_acc from toolkit.accelerator import get_accelerator @@ -372,7 +372,7 @@ class PairedImageDataset(Dataset): return img, prompt, (self.neg_weight, self.pos_weight) -class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, CaptionMixin, Dataset): +class AiToolkitDataset(LatentCachingMixin, ControlCachingMixin, CLIPCachingMixin, BucketsMixin, CaptionMixin, Dataset): def __init__( self, @@ -394,6 +394,7 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti self.is_caching_latents_to_memory = dataset_config.cache_latents self.is_caching_latents_to_disk = dataset_config.cache_latents_to_disk self.is_caching_clip_vision_to_disk = dataset_config.cache_clip_vision_to_disk + self.is_generating_controls = len(dataset_config.controls) > 0 self.epoch_num = 0 self.sd = sd @@ -425,6 +426,9 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti self.caption_dict = json.load(f) # keys are file paths file_list = list(self.caption_dict.keys()) + + # remove items in the _controls_ folder + file_list = [x for x in file_list if not os.path.basename(os.path.dirname(x)) == "_controls"] if self.dataset_config.num_repeats > 1: # repeat the list @@ -548,6 +552,9 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti self.cache_latents_all_latents() if self.is_caching_clip_vision_to_disk: self.cache_clip_vision_to_disk() + if self.is_generating_controls: + # always do this last + self.setup_controls() else: if self.dataset_config.poi is not None: # handle cropping to a specific point of interest diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 8649129a..79131c7b 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -18,6 +18,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, Sigl from toolkit.basic import flush, value_map from toolkit.buckets import get_bucket_for_image_size, get_resolution +from toolkit.config_modules import ControlTypes from toolkit.metadata import get_meta_for_safetensors from toolkit.models.pixtral_vision import PixtralVisionImagePreprocessorCompatible from toolkit.prompt_utils import inject_trigger_into_prompt @@ -62,6 +63,7 @@ transforms_dict = { } caption_ext_list = ['txt', 'json', 'caption'] +img_ext_list = ['.jpg', '.jpeg', '.png', '.webp'] def standardize_images(images): @@ -755,10 +757,10 @@ class InpaintControlFileItemDTOMixin: inpaint_path = dataset_config.inpaint_path # we are using control images img_path = kwargs.get('path', None) - img_ext_list = ['.png', '.webp'] + img_inpaint_ext_list = ['.png', '.webp'] file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0] - for ext in img_ext_list: + for ext in img_inpaint_ext_list: p = os.path.join(inpaint_path, file_name_no_ext + ext) if os.path.exists(p): self.inpaint_path = p @@ -842,7 +844,6 @@ class ControlFileItemDTOMixin: self.full_size_control_images = dataset_config.full_size_control_images # we are using control images img_path = kwargs.get('path', None) - img_ext_list = ['.jpg', '.jpeg', '.png', '.webp'] file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0] found_control_images = [] @@ -959,7 +960,6 @@ class ClipImageFileItemDTOMixin: clip_image_path = dataset_config.clip_image_path # we are using control images img_path = kwargs.get('path', None) - img_ext_list = ['.jpg', '.jpeg', '.png', '.webp'] file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0] for ext in img_ext_list: if os.path.exists(os.path.join(clip_image_path, file_name_no_ext + ext)): @@ -1062,7 +1062,6 @@ class ClipImageFileItemDTOMixin: # randomly grab an image path from the same folder pool_folder = os.path.dirname(self.path) # find all images in the folder - img_ext_list = ['.jpg', '.jpeg', '.png', '.webp'] img_files = [] for ext in img_ext_list: img_files += glob.glob(os.path.join(pool_folder, f'*{ext}')) @@ -1281,7 +1280,6 @@ class MaskFileItemDTOMixin: mask_path = dataset_config.mask_path if dataset_config.mask_path is not None else dataset_config.alpha_mask # we are using control images img_path = kwargs.get('path', None) - img_ext_list = ['.jpg', '.jpeg', '.png', '.webp'] file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0] for ext in img_ext_list: if os.path.exists(os.path.join(mask_path, file_name_no_ext + ext)): @@ -1385,7 +1383,6 @@ class UnconditionalFileItemDTOMixin: if dataset_config.unconditional_path is not None: # we are using control images img_path = kwargs.get('path', None) - img_ext_list = ['.jpg', '.jpeg', '.png', '.webp'] file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0] for ext in img_ext_list: if os.path.exists(os.path.join(dataset_config.unconditional_path, file_name_no_ext + ext)): @@ -1944,3 +1941,182 @@ class CLIPCachingMixin: # restore device state self.sd.restore_device_state() + + + +class ControlCachingMixin: + def __init__(self: 'AiToolkitDataset', **kwargs): + if hasattr(super(), '__init__'): + super().__init__(**kwargs) + self.control_depth_model = None + self.control_pose_model = None + self.control_line_model = None + self.control_bg_remover = None + + def get_control_path(self: 'AiToolkitDataset', file_item:'FileItemDTO', control_type: ControlTypes): + coltrols_folder = os.path.join(os.path.dirname(file_item.path), '_controls') + file_name_no_ext = os.path.splitext(os.path.basename(file_item.path))[0] + file_name_no_ext_control = f"{file_name_no_ext}.{control_type}" + for ext in img_ext_list: + possible_path = os.path.join(coltrols_folder, file_name_no_ext_control + ext) + if os.path.exists(possible_path): + return possible_path + # if we get here, we need to generate the control + return None + + def add_control_path_to_file_item(self: 'AiToolkitDataset', file_item: 'FileItemDTO', control_path: str, control_type: ControlTypes): + if control_type == 'inpaint': + file_item.inpaint_path = control_path + file_item.has_inpaint_image = True + elif control_type == 'mask': + file_item.mask_path = control_path + file_item.has_mask_image = True + else: + if file_item.control_path is None: + file_item.control_path = [control_path] + elif isinstance(file_item.control_path, str): + file_item.control_path = [file_item.control_path, control_path] + elif isinstance(file_item.control_path, list): + file_item.control_path.append(control_path) + else: + raise Exception(f"Error: control_path is not a string or list: {file_item.control_path}") + file_item.has_control_image = True + + def setup_controls(self: 'AiToolkitDataset'): + if not self.is_generating_controls: + return + with torch.no_grad(): + print_acc(f"Generating controls for {self.dataset_path}") + + has_unloaded = False + device = self.sd.device + + # controls 'depth', 'line', 'pose', 'inpaint', 'mask' + + # use tqdm to show progress + i = 0 + for file_item in tqdm(self.file_list, desc=f'Generating Controls'): + coltrols_folder = os.path.join(os.path.dirname(file_item.path), '_controls') + file_name_no_ext = os.path.splitext(os.path.basename(file_item.path))[0] + + image: Image = None + + for control_type in self.dataset_config.controls: + control_path = self.get_control_path(file_item, control_type) + if control_path is not None: + self.add_control_path_to_file_item(file_item, control_path, control_type) + else: + # we need to generate the control. Unload model if not unloaded + if not has_unloaded: + print("Unloading model to generate controls") + self.sd.set_device_state_preset('unload') + has_unloaded = True + + if image is None: + # make sure image is loaded if we havent loaded it with another control + image = Image.open(file_item.path).convert('RGB') + image = exif_transpose(image) + + # resize to a max of 1mp + max_size = 1024 * 1024 + + w, h = image.size + if w * h > max_size: + scale = math.sqrt(max_size / (w * h)) + w = int(w * scale) + h = int(h * scale) + image = image.resize((w, h), Image.BICUBIC) + + save_path = os.path.join(coltrols_folder, f"{file_name_no_ext}.{control_type}.jpg") + os.makedirs(coltrols_folder, exist_ok=True) + if control_type == 'depth': + if self.control_depth_model is None: + from transformers import pipeline + self.control_depth_model = pipeline( + task="depth-estimation", + model="depth-anything/Depth-Anything-V2-Large-hf", + device=device, + torch_dtype=torch.float16 + ) + img = image.copy() + in_size = img.size + output = self.control_depth_model(img) + out_tensor = output["predicted_depth"] # shape (1, H, W) 0 - 255 + out_tensor = out_tensor.clamp(0, 255) + out_tensor = out_tensor.squeeze(0).cpu().numpy() + img = Image.fromarray(out_tensor.astype('uint8')) + img = img.resize(in_size, Image.LANCZOS) + img.save(save_path) + self.add_control_path_to_file_item(file_item, save_path, control_type) + elif control_type == 'pose': + if self.control_pose_model is None: + from controlnet_aux import OpenposeDetector + self.control_pose_model = OpenposeDetector.from_pretrained("lllyasviel/Annotators").to(device) + img = image.copy() + + detect_res = int(math.sqrt(img.size[0] * img.size[1])) + img = self.control_pose_model(img, hand_and_face=True, detect_resolution=detect_res, image_resolution=detect_res) + img = img.convert('RGB') + img.save(save_path) + self.add_control_path_to_file_item(file_item, save_path, control_type) + + elif control_type == 'line': + if self.control_line_model is None: + from controlnet_aux import TEEDdetector + self.control_line_model = TEEDdetector.from_pretrained("fal-ai/teed", filename="5_model.pth").to(device) + img = image.copy() + img = self.control_line_model(img, detect_resolution=1024) + img = img.convert('RGB') + img.save(save_path) + self.add_control_path_to_file_item(file_item, save_path, control_type) + elif control_type == 'inpaint' or control_type == 'mask': + img = image.copy() + if self.control_bg_remover is None: + from transformers import AutoModelForImageSegmentation + self.control_bg_remover = AutoModelForImageSegmentation.from_pretrained( + 'ZhengPeng7/BiRefNet_HR', + trust_remote_code=True, + revision="595e212b3eaa6a1beaad56cee49749b1e00b1596", + torch_dtype=torch.float16 + ).to(device) + self.control_bg_remover.eval() + + image_size = (1024, 1024) + transform_image = transforms.Compose([ + transforms.Resize(image_size), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + + input_images = transform_image(img).unsqueeze(0).to('cuda').to(torch.float16) + + # Prediction + preds = self.control_bg_remover(input_images)[-1].sigmoid().cpu() + pred = preds[0].squeeze() + pred_pil = transforms.ToPILImage()(pred) + mask = pred_pil.resize(img.size) + if control_type == 'inpaint': + # inpainting feature currently only supports "erased" section desired to inpaint + mask = ImageOps.invert(mask) + img.putalpha(mask) + save_path = os.path.join(coltrols_folder, f"{file_name_no_ext}.{control_type}.webp") + else: + img = mask + img = img.convert('RGB') + img.save(save_path) + self.add_control_path_to_file_item(file_item, save_path, control_type) + else: + raise Exception(f"Error: unknown control type {control_type}") + i += 1 + + # remove models + self.control_depth_model = None + self.control_pose_model = None + self.control_line_model = None + self.control_bg_remover = None + + flush() + + # restore device state + if has_unloaded: + self.sd.restore_device_state() diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index a08b043e..8a05a908 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -3009,6 +3009,8 @@ class StableDiffusion: active_modules = ['vae'] if device_state_preset in ['cache_clip']: active_modules = ['clip'] + if device_state_preset in ['unload']: + active_modules = [] if device_state_preset in ['generate']: active_modules = ['vae', 'unet', 'text_encoder', 'adapter', 'refiner_unet']