Added methods to the dataloader to automatically generate controls for line, mask, inpainting, depth, and pose.

Jaret Burkett
2025-04-09 13:35:04 -06:00
parent 615b0d0e94
commit 96ba2fd129
5 changed files with 205 additions and 11 deletions
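
For context, this is roughly how a dataset opts into the new behavior. A minimal sketch, assuming DatasetConfig accepts these keyword arguments (folder_path, caption_ext, and resolution are illustrative; only controls is new in this commit):

from toolkit.config_modules import DatasetConfig

dataset = DatasetConfig(
    folder_path='/data/my_dataset',      # hypothetical path
    caption_ext='txt',
    resolution=1024,
    controls=['depth', 'line', 'pose'],  # any subset of: depth, line, pose, inpaint, mask
)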

View File

@@ -22,7 +22,7 @@ k-diffusion
 open_clip_torch
 timm
 prodigyopt
-controlnet_aux==0.0.7
+controlnet_aux==0.0.9
 python-dotenv
 bitsandbytes
 hf_transfer
@@ -35,4 +35,5 @@ peft
 gradio
 python-slugify
 opencv-python
-pytorch-wavelets==1.3.0
+pytorch-wavelets==1.3.0
+matplotlib==3.10.1
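
A quick sanity check, not part of the commit, that the bumped pin exposes the annotators the new control code imports (assumes the packages from this file are installed):

from importlib.metadata import version
from controlnet_aux import OpenposeDetector, TEEDdetector  # detectors used by the control generation below

print(version('controlnet_aux'))  # expected: 0.0.9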

View File

@@ -687,6 +687,7 @@ class SliderConfig:
                    self.targets.append(target)
        print(f"Built {len(self.targets)} slider targets (with permutations)")

ControlTypes = Literal['depth', 'line', 'pose', 'inpaint', 'mask']

class DatasetConfig:
    """
@@ -803,6 +804,13 @@ class DatasetConfig:
        # debug the frame count and frame selection. You don't need this. It is for debugging.
        self.debug: bool = kwargs.get('debug', False)

        # automatic controls
        self.controls: List[ControlTypes] = kwargs.get('controls', [])
        if isinstance(self.controls, str):
            self.controls = [self.controls]
        # remove empty strings
        self.controls = [control for control in self.controls if control.strip() != '']

def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
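
A short sketch of what the normalization above accepts, assuming a DatasetConfig can be built standalone with just these keyword arguments (folder_path is hypothetical):

cfg_a = DatasetConfig(folder_path='/data/people', controls='pose')               # a bare string is allowed
cfg_b = DatasetConfig(folder_path='/data/people', controls=['depth', '', 'line'])

assert cfg_a.controls == ['pose']           # wrapped into a single-item list
assert cfg_b.controls == ['depth', 'line']  # empty strings are dropped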

View File

@@ -18,7 +18,7 @@ import albumentations as A
from toolkit.buckets import get_bucket_for_image_size, BucketResolution
from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin, ControlCachingMixin
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
from toolkit.print import print_acc
from toolkit.accelerator import get_accelerator
@@ -372,7 +372,7 @@ class PairedImageDataset(Dataset):
        return img, prompt, (self.neg_weight, self.pos_weight)


class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, CaptionMixin, Dataset):
class AiToolkitDataset(LatentCachingMixin, ControlCachingMixin, CLIPCachingMixin, BucketsMixin, CaptionMixin, Dataset):

    def __init__(
            self,
@@ -394,6 +394,7 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
        self.is_caching_latents_to_memory = dataset_config.cache_latents
        self.is_caching_latents_to_disk = dataset_config.cache_latents_to_disk
        self.is_caching_clip_vision_to_disk = dataset_config.cache_clip_vision_to_disk
        self.is_generating_controls = len(dataset_config.controls) > 0
        self.epoch_num = 0

        self.sd = sd
@@ -425,6 +426,9 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
                self.caption_dict = json.load(f)
            # keys are file paths
            file_list = list(self.caption_dict.keys())

            # remove items in the _controls folder
            file_list = [x for x in file_list if not os.path.basename(os.path.dirname(x)) == "_controls"]

        if self.dataset_config.num_repeats > 1:
            # repeat the list
@@ -548,6 +552,9 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
                self.cache_latents_all_latents()
            if self.is_caching_clip_vision_to_disk:
                self.cache_clip_vision_to_disk()
            if self.is_generating_controls:
                # always do this last
                self.setup_controls()
        else:
            if self.dataset_config.poi is not None:
                # handle cropping to a specific point of interest
View File

@@ -18,6 +18,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, Sigl
from toolkit.basic import flush, value_map
from toolkit.buckets import get_bucket_for_image_size, get_resolution
from toolkit.config_modules import ControlTypes
from toolkit.metadata import get_meta_for_safetensors
from toolkit.models.pixtral_vision import PixtralVisionImagePreprocessorCompatible
from toolkit.prompt_utils import inject_trigger_into_prompt
@@ -62,6 +63,7 @@ transforms_dict = {
}
caption_ext_list = ['txt', 'json', 'caption']
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
def standardize_images(images):
@@ -755,10 +757,10 @@ class InpaintControlFileItemDTOMixin:
            inpaint_path = dataset_config.inpaint_path
            # we are using control images
            img_path = kwargs.get('path', None)
            img_ext_list = ['.png', '.webp']
            img_inpaint_ext_list = ['.png', '.webp']
            file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
            for ext in img_ext_list:
            for ext in img_inpaint_ext_list:
                p = os.path.join(inpaint_path, file_name_no_ext + ext)
                if os.path.exists(p):
                    self.inpaint_path = p
@@ -842,7 +844,6 @@ class ControlFileItemDTOMixin:
            self.full_size_control_images = dataset_config.full_size_control_images
            # we are using control images
            img_path = kwargs.get('path', None)
            img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
            file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
            found_control_images = []
@@ -959,7 +960,6 @@ class ClipImageFileItemDTOMixin:
            clip_image_path = dataset_config.clip_image_path
            # we are using control images
            img_path = kwargs.get('path', None)
            img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
            file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
            for ext in img_ext_list:
                if os.path.exists(os.path.join(clip_image_path, file_name_no_ext + ext)):
@@ -1062,7 +1062,6 @@ class ClipImageFileItemDTOMixin:
            # randomly grab an image path from the same folder
            pool_folder = os.path.dirname(self.path)
            # find all images in the folder
            img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
            img_files = []
            for ext in img_ext_list:
                img_files += glob.glob(os.path.join(pool_folder, f'*{ext}'))
@@ -1281,7 +1280,6 @@ class MaskFileItemDTOMixin:
            mask_path = dataset_config.mask_path if dataset_config.mask_path is not None else dataset_config.alpha_mask
            # we are using control images
            img_path = kwargs.get('path', None)
            img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
            file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
            for ext in img_ext_list:
                if os.path.exists(os.path.join(mask_path, file_name_no_ext + ext)):
@@ -1385,7 +1383,6 @@ class UnconditionalFileItemDTOMixin:
        if dataset_config.unconditional_path is not None:
            # we are using control images
            img_path = kwargs.get('path', None)
            img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
            file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
            for ext in img_ext_list:
                if os.path.exists(os.path.join(dataset_config.unconditional_path, file_name_no_ext + ext)):
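
The hunks above drop per-method copies of img_ext_list in favor of the module-level list added near the top of the file. A sketch of the lookup pattern they all share; find_companion and the paths are hypothetical:

import os

img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']  # now defined once at module level

def find_companion(base_dir: str, img_path: str):
    # probe <base_dir>/<image name><ext> for each known extension, as the mixins above do
    file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
    for ext in img_ext_list:
        candidate = os.path.join(base_dir, file_name_no_ext + ext)
        if os.path.exists(candidate):
            return candidate
    return None

print(find_companion('/data/my_dataset/masks', '/data/my_dataset/img001.jpg'))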
@@ -1944,3 +1941,182 @@ class CLIPCachingMixin:
        # restore device state
        self.sd.restore_device_state()


class ControlCachingMixin:
    # Generates any missing control images (depth, line, pose, inpaint, mask) into a
    # "_controls" folder next to the source images and attaches them to each file item.

    def __init__(self: 'AiToolkitDataset', **kwargs):
        if hasattr(super(), '__init__'):
            super().__init__(**kwargs)
        self.control_depth_model = None
        self.control_pose_model = None
        self.control_line_model = None
        self.control_bg_remover = None

    def get_control_path(self: 'AiToolkitDataset', file_item: 'FileItemDTO', control_type: ControlTypes):
        controls_folder = os.path.join(os.path.dirname(file_item.path), '_controls')
        file_name_no_ext = os.path.splitext(os.path.basename(file_item.path))[0]
        file_name_no_ext_control = f"{file_name_no_ext}.{control_type}"
        for ext in img_ext_list:
            possible_path = os.path.join(controls_folder, file_name_no_ext_control + ext)
            if os.path.exists(possible_path):
                return possible_path
        # if we get here, we need to generate the control
        return None

    def add_control_path_to_file_item(self: 'AiToolkitDataset', file_item: 'FileItemDTO', control_path: str, control_type: ControlTypes):
        if control_type == 'inpaint':
            file_item.inpaint_path = control_path
            file_item.has_inpaint_image = True
        elif control_type == 'mask':
            file_item.mask_path = control_path
            file_item.has_mask_image = True
        else:
            if file_item.control_path is None:
                file_item.control_path = [control_path]
            elif isinstance(file_item.control_path, str):
                file_item.control_path = [file_item.control_path, control_path]
            elif isinstance(file_item.control_path, list):
                file_item.control_path.append(control_path)
            else:
                raise Exception(f"Error: control_path is not a string or list: {file_item.control_path}")
            file_item.has_control_image = True

    def setup_controls(self: 'AiToolkitDataset'):
        if not self.is_generating_controls:
            return
        with torch.no_grad():
            print_acc(f"Generating controls for {self.dataset_path}")
            has_unloaded = False
            device = self.sd.device
            # controls: 'depth', 'line', 'pose', 'inpaint', 'mask'
            # use tqdm to show progress
            i = 0
            for file_item in tqdm(self.file_list, desc='Generating Controls'):
                controls_folder = os.path.join(os.path.dirname(file_item.path), '_controls')
                file_name_no_ext = os.path.splitext(os.path.basename(file_item.path))[0]
                image: Image.Image = None
                for control_type in self.dataset_config.controls:
                    control_path = self.get_control_path(file_item, control_type)
                    if control_path is not None:
                        self.add_control_path_to_file_item(file_item, control_path, control_type)
                    else:
                        # we need to generate the control. Unload the model first if it has not been unloaded yet.
                        if not has_unloaded:
                            print("Unloading model to generate controls")
                            self.sd.set_device_state_preset('unload')
                            has_unloaded = True
                        if image is None:
                            # make sure the image is loaded if we haven't already loaded it for another control
                            image = Image.open(file_item.path).convert('RGB')
                            image = exif_transpose(image)
                            # resize to a max of 1 MP
                            max_size = 1024 * 1024
                            w, h = image.size
                            if w * h > max_size:
                                scale = math.sqrt(max_size / (w * h))
                                w = int(w * scale)
                                h = int(h * scale)
                                image = image.resize((w, h), Image.BICUBIC)
                        save_path = os.path.join(controls_folder, f"{file_name_no_ext}.{control_type}.jpg")
                        os.makedirs(controls_folder, exist_ok=True)
                        if control_type == 'depth':
                            if self.control_depth_model is None:
                                from transformers import pipeline
                                self.control_depth_model = pipeline(
                                    task="depth-estimation",
                                    model="depth-anything/Depth-Anything-V2-Large-hf",
                                    device=device,
                                    torch_dtype=torch.float16
                                )
                            img = image.copy()
                            in_size = img.size
                            output = self.control_depth_model(img)
                            out_tensor = output["predicted_depth"]  # shape (1, H, W), 0 - 255
                            out_tensor = out_tensor.clamp(0, 255)
                            out_tensor = out_tensor.squeeze(0).cpu().numpy()
                            img = Image.fromarray(out_tensor.astype('uint8'))
                            img = img.resize(in_size, Image.LANCZOS)
                            img.save(save_path)
                            self.add_control_path_to_file_item(file_item, save_path, control_type)
                        elif control_type == 'pose':
                            if self.control_pose_model is None:
                                from controlnet_aux import OpenposeDetector
                                self.control_pose_model = OpenposeDetector.from_pretrained("lllyasviel/Annotators").to(device)
                            img = image.copy()
                            detect_res = int(math.sqrt(img.size[0] * img.size[1]))
                            img = self.control_pose_model(img, hand_and_face=True, detect_resolution=detect_res, image_resolution=detect_res)
                            img = img.convert('RGB')
                            img.save(save_path)
                            self.add_control_path_to_file_item(file_item, save_path, control_type)
                        elif control_type == 'line':
                            if self.control_line_model is None:
                                from controlnet_aux import TEEDdetector
                                self.control_line_model = TEEDdetector.from_pretrained("fal-ai/teed", filename="5_model.pth").to(device)
                            img = image.copy()
                            img = self.control_line_model(img, detect_resolution=1024)
                            img = img.convert('RGB')
                            img.save(save_path)
                            self.add_control_path_to_file_item(file_item, save_path, control_type)
                        elif control_type == 'inpaint' or control_type == 'mask':
                            img = image.copy()
                            if self.control_bg_remover is None:
                                from transformers import AutoModelForImageSegmentation
                                self.control_bg_remover = AutoModelForImageSegmentation.from_pretrained(
                                    'ZhengPeng7/BiRefNet_HR',
                                    trust_remote_code=True,
                                    revision="595e212b3eaa6a1beaad56cee49749b1e00b1596",
                                    torch_dtype=torch.float16
                                ).to(device)
                                self.control_bg_remover.eval()
                            image_size = (1024, 1024)
                            transform_image = transforms.Compose([
                                transforms.Resize(image_size),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                            ])
                            input_images = transform_image(img).unsqueeze(0).to(device).to(torch.float16)
                            # Prediction
                            preds = self.control_bg_remover(input_images)[-1].sigmoid().cpu()
                            pred = preds[0].squeeze()
                            pred_pil = transforms.ToPILImage()(pred)
                            mask = pred_pil.resize(img.size)
                            if control_type == 'inpaint':
                                # the inpainting feature currently only supports an "erased" region to inpaint
                                mask = ImageOps.invert(mask)
                                img.putalpha(mask)
                                save_path = os.path.join(controls_folder, f"{file_name_no_ext}.{control_type}.webp")
                            else:
                                img = mask
                                img = img.convert('RGB')
                            img.save(save_path)
                            self.add_control_path_to_file_item(file_item, save_path, control_type)
                        else:
                            raise Exception(f"Error: unknown control type {control_type}")
                i += 1

            # release the annotator models
            self.control_depth_model = None
            self.control_pose_model = None
            self.control_line_model = None
            self.control_bg_remover = None
            flush()
            # restore device state
            if has_unloaded:
                self.sd.restore_device_state()
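
A small sketch of the on-disk naming convention that get_control_path and setup_controls agree on; the helper name and example path are hypothetical:

import os

def expected_control_path(image_path: str, control_type: str, ext: str = '.jpg') -> str:
    # controls live in a sibling "_controls" folder, named <image name>.<control type><ext>
    controls_folder = os.path.join(os.path.dirname(image_path), '_controls')
    file_name_no_ext = os.path.splitext(os.path.basename(image_path))[0]
    return os.path.join(controls_folder, f"{file_name_no_ext}.{control_type}{ext}")

print(expected_control_path('/data/my_dataset/img001.jpg', 'depth'))
# /data/my_dataset/_controls/img001.depth.jpg  (inpaint controls are saved as .webp instead)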

View File

@@ -3009,6 +3009,8 @@ class StableDiffusion:
            active_modules = ['vae']
        if device_state_preset in ['cache_clip']:
            active_modules = ['clip']
        if device_state_preset in ['unload']:
            active_modules = []
        if device_state_preset in ['generate']:
            active_modules = ['vae', 'unet', 'text_encoder', 'adapter', 'refiner_unet']
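
The new 'unload' preset is what ControlCachingMixin.setup_controls uses to free GPU memory before running the annotator models. A hedged sketch of that usage pattern; sd and run_annotators are placeholders, and the real code tracks has_unloaded instead of using try/finally:

sd.set_device_state_preset('unload')  # no modules stay active, freeing VRAM for the annotators
try:
    run_annotators()                  # placeholder for the depth/pose/line/inpaint/mask generation
finally:
    sd.restore_device_state()         # put the diffusion model back the way it was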