Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Reworked the control generator. It is now significantly faster and uses a better pose model with a better license.
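For orientation, a minimal usage sketch of the reworked generator, based on the new toolkit/control_generator.py below (the image path is a placeholder and a CUDA device is assumed):

import torch
from toolkit.control_generator import ControlGenerator

# build one generator; annotator models are loaded lazily on first use
gen = ControlGenerator(device=torch.device('cuda'))

# returns a cached control from the sibling "_controls" folder,
# generating it on a cache miss ('depth', 'pose', 'line', 'inpaint', or 'mask')
depth_path = gen.get_control_path('/path/to/image.jpg', 'depth')

# drop the annotator models and free CUDA memory when finished
gen.cleanup()

The module can also be run directly, as the __main__ block below shows, to pre-generate controls for a directory of images.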
@@ -2,6 +2,7 @@ torch==2.6.0
 torchvision==0.21.0
 torchao==0.9.0
 safetensors
+git+https://github.com/jaretburkett/easy_dwpose.git
 git+https://github.com/huggingface/diffusers@363d1ab7e24c5ed6c190abb00df66d9edb74383b
 transformers==4.49.0
 lycoris-lora==1.8.3
@@ -22,7 +23,7 @@ k-diffusion
 open_clip_torch
 timm
 prodigyopt
-controlnet_aux==0.0.9
+controlnet_aux==0.0.10
 python-dotenv
 bitsandbytes
 hf_transfer
toolkit/control_generator.py (new file, 279 lines)
@@ -0,0 +1,279 @@
import gc
import math
import os
import torch
from typing import Literal
from PIL import Image, ImageFilter, ImageOps
from PIL.ImageOps import exif_transpose
from tqdm import tqdm

from torchvision import transforms

# suppress all warnings
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


def flush(garbage_collect=True):
    torch.cuda.empty_cache()
    if garbage_collect:
        gc.collect()


ControlTypes = Literal['depth', 'pose', 'line', 'inpaint', 'mask']

img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']


class ControlGenerator:
    def __init__(self, device, sd=None):
        self.device = device
        self.sd = sd  # optional; if set, it is unloaded while controls are generated
        self.has_unloaded = False
        self.control_depth_model = None
        self.control_pose_model = None
        self.control_line_model = None
        self.control_bg_remover = None
        self.debug = False
        self.regen = False

    def get_control_path(self, img_path, control_type: ControlTypes):
        if self.regen:
            return self._generate_control(img_path, control_type)
        coltrols_folder = os.path.join(os.path.dirname(img_path), '_controls')
        file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
        file_name_no_ext_control = f"{file_name_no_ext}.{control_type}"
        for ext in img_ext_list:
            possible_path = os.path.join(
                coltrols_folder, file_name_no_ext_control + ext)
            if os.path.exists(possible_path):
                return possible_path
        # if we get here, we need to generate the control
        return self._generate_control(img_path, control_type)
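    # Generated controls are cached in a sibling "_controls" folder next to each image,
    # named "<image_name>.<control_type>.<ext>" (.jpg for most types, .webp for inpaint).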

    def debug_print(self, *args, **kwargs):
        if self.debug:
            print(*args, **kwargs)

    def _generate_control(self, img_path, control_type):
        device = self.device
        image: Image = None

        coltrols_folder = os.path.join(os.path.dirname(img_path), '_controls')
        file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]

        # we need to generate the control; unload the main model if it has not been unloaded yet
        if not self.has_unloaded:
            if self.sd is not None:
                print("Unloading model to generate controls")
                self.sd.set_device_state_preset('unload')
            self.has_unloaded = True

        if image is None:
            # make sure the image is loaded if we haven't already loaded it for another control
            image = Image.open(img_path).convert('RGB')
            image = exif_transpose(image)

        # resize to a max of 1 MP
        max_size = 1024 * 1024

        w, h = image.size
        if w * h > max_size:
            scale = math.sqrt(max_size / (w * h))
            w = int(w * scale)
            h = int(h * scale)
            image = image.resize((w, h), Image.BICUBIC)

        save_path = os.path.join(
            coltrols_folder, f"{file_name_no_ext}.{control_type}.jpg")
        os.makedirs(coltrols_folder, exist_ok=True)
        if control_type == 'depth':
            self.debug_print("Generating depth control")
            if self.control_depth_model is None:
                from transformers import pipeline
                self.control_depth_model = pipeline(
                    task="depth-estimation",
                    model="depth-anything/Depth-Anything-V2-Large-hf",
                    device=device,
                    torch_dtype=torch.float16
                )
            img = image.copy()
            in_size = img.size
            output = self.control_depth_model(img)
            out_tensor = output["predicted_depth"]  # shape (1, H, W), 0 - 255
            out_tensor = out_tensor.clamp(0, 255)
            out_tensor = out_tensor.squeeze(0).cpu().numpy()
            img = Image.fromarray(out_tensor.astype('uint8'))
            img = img.resize(in_size, Image.LANCZOS)
            img.save(save_path)
            return save_path
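        # Pose detection now uses DWpose via the easy_dwpose package (added to the requirements above),
        # replacing the controlnet_aux OpenposeDetector used previously (see the dataloader diff below).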
        elif control_type == 'pose':
            self.debug_print("Generating pose control")
            if self.control_pose_model is None:
                try:
                    import onnxruntime
                    onnxruntime.set_default_logger_severity(3)
                except ImportError:
                    raise ImportError(
                        "onnxruntime is not installed. Please install it with pip install onnxruntime or onnxruntime-gpu")
                try:
                    from easy_dwpose import DWposeDetector
                    self.control_pose_model = DWposeDetector(
                        device=str(device))
                except ImportError:
                    raise ImportError(
                        "easy-dwpose is not installed. Please install it with pip install easy-dwpose")
            img = image.copy()

            detect_res = int(math.sqrt(img.size[0] * img.size[1]))
            img = self.control_pose_model(
                img, output_type="pil", include_hands=True, include_face=True, detect_resolution=detect_res)
            img = img.convert('RGB')
            img.save(save_path)
            return save_path

        elif control_type == 'line':
            self.debug_print("Generating line control")
            if self.control_line_model is None:
                from controlnet_aux import TEEDdetector
                self.control_line_model = TEEDdetector.from_pretrained(
                    "fal-ai/teed", filename="5_model.pth").to(device)
            img = image.copy()
            img = self.control_line_model(img, detect_resolution=1024)
            # apply threshold: pixels above 128 become 255, everything else 0
            # img = img.filter(ImageFilter.GaussianBlur(radius=1))
            img = img.point(lambda p: p > 128 and 255)
            img = img.convert('RGB')
            img.save(save_path)
            return save_path
        elif control_type == 'inpaint' or control_type == 'mask':
            self.debug_print("Generating inpaint/mask control")
            img = image.copy()
            if self.control_bg_remover is None:
                from transformers import AutoModelForImageSegmentation
                self.control_bg_remover = AutoModelForImageSegmentation.from_pretrained(
                    'ZhengPeng7/BiRefNet_HR',
                    trust_remote_code=True,
                    revision="595e212b3eaa6a1beaad56cee49749b1e00b1596",
                    torch_dtype=torch.float16
                ).to(device)
                self.control_bg_remover.eval()

            image_size = (1024, 1024)
            transform_image = transforms.Compose([
                transforms.Resize(image_size),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [
                    0.229, 0.224, 0.225])
            ])

            input_images = transform_image(img).unsqueeze(
                0).to('cuda').to(torch.float16)

            # Prediction
            preds = self.control_bg_remover(input_images)[-1].sigmoid().cpu()
            pred = preds[0].squeeze()
            pred_pil = transforms.ToPILImage()(pred)
            mask = pred_pil.resize(img.size)
            if control_type == 'inpaint':
                # the inpaint control currently expects the region to be inpainted to be "erased" (transparent)
                mask = ImageOps.invert(mask)
                img.putalpha(mask)
                save_path = os.path.join(
                    coltrols_folder, f"{file_name_no_ext}.{control_type}.webp")
            else:
                img = mask
                img = img.convert('RGB')
            img.save(save_path)
            return save_path
        else:
            raise Exception(f"Error: unknown control type {control_type}")

    def cleanup(self):
        if self.control_depth_model is not None:
            self.control_depth_model = None
        if self.control_pose_model is not None:
            self.control_pose_model = None
        if self.control_line_model is not None:
            self.control_line_model = None
        if self.control_bg_remover is not None:
            self.control_bg_remover = None
        if self.sd is not None and self.has_unloaded:
            self.sd.restore_device_state()
            self.has_unloaded = False

        flush()

if __name__ == "__main__":
    import sys
    import argparse
    import time
    import transformers
    transformers.logging.set_verbosity_error()

    control_times = {
        'depth': 0,
        'pose': 0,
        'line': 0,
        'inpaint': 0,
        'mask': 0
    }

    controls = control_times.keys()

    parser = argparse.ArgumentParser(description="Generate control images")
    parser.add_argument("img_dir", type=str, help="Path to image directory")
    parser.add_argument('--debug', action='store_true',
                        help="Enable debug mode")
    parser.add_argument('--regen', action='store_true',
                        help="Regenerate all controls")

    args = parser.parse_args()
    img_dir = args.img_dir
    if not os.path.exists(img_dir):
        print(f"Error: {img_dir} does not exist")
        exit()
    if not os.path.isdir(img_dir):
        print(f"Error: {img_dir} is not a directory")
        exit()

    # find images
    img_list = []
    for root, dirs, files in os.walk(img_dir):
        for file in files:
            if "_controls" in root:
                continue
            if file.startswith('.'):
                continue
            if file.lower().endswith(tuple(img_ext_list)):
                img_list.append(os.path.join(root, file))
    if len(img_list) == 0:
        print(f"Error: no images found in {img_dir}")
        exit()

    # load model
    idx = 0
    for img_path in tqdm(img_list):
        for control in controls:
            start = time.time()
            control_gen = ControlGenerator(torch.device('cuda'))
            control_gen.debug = args.debug
            control_gen.regen = args.regen
            control_path = control_gen.get_control_path(img_path, control)
            end = time.time()
            # don't track timings for the first 2 images (warm-up)
            if idx < 2:
                continue
            control_times[control] += end - start
        idx += 1

    # determine avg time
    for control in controls:
        control_times[control] /= (idx - 2)
        print(
            f"Avg time for {control} control: {control_times[control]:.2f} seconds")

    print("Done")
@@ -19,6 +19,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, Sigl
 from toolkit.basic import flush, value_map
 from toolkit.buckets import get_bucket_for_image_size, get_resolution
 from toolkit.config_modules import ControlTypes
+from toolkit.control_generator import ControlGenerator
 from toolkit.metadata import get_meta_for_safetensors
 from toolkit.models.pixtral_vision import PixtralVisionImagePreprocessorCompatible
 from toolkit.prompt_utils import inject_trigger_into_prompt
@@ -1950,21 +1951,7 @@ class ControlCachingMixin:
     def __init__(self: 'AiToolkitDataset', **kwargs):
         if hasattr(super(), '__init__'):
             super().__init__(**kwargs)
-        self.control_depth_model = None
-        self.control_pose_model = None
-        self.control_line_model = None
-        self.control_bg_remover = None
-
-    def get_control_path(self: 'AiToolkitDataset', file_item:'FileItemDTO', control_type: ControlTypes):
-        coltrols_folder = os.path.join(os.path.dirname(file_item.path), '_controls')
-        file_name_no_ext = os.path.splitext(os.path.basename(file_item.path))[0]
-        file_name_no_ext_control = f"{file_name_no_ext}.{control_type}"
-        for ext in img_ext_list:
-            possible_path = os.path.join(coltrols_folder, file_name_no_ext_control + ext)
-            if os.path.exists(possible_path):
-                return possible_path
-        # if we get here, we need to generate the control
-        return None
+        self.control_generator: ControlGenerator = None
 
     def add_control_path_to_file_item(self: 'AiToolkitDataset', file_item: 'FileItemDTO', control_path: str, control_type: ControlTypes):
         if control_type == 'inpaint':
@@ -1989,136 +1976,23 @@ class ControlCachingMixin:
             return
         with torch.no_grad():
             print_acc(f"Generating controls for {self.dataset_path}")
 
-            has_unloaded = False
             device = self.sd.device
 
             # controls 'depth', 'line', 'pose', 'inpaint', 'mask'
+            self.control_generator = ControlGenerator(
+                device=device,
+                sd=self.sd,
+            )
 
             # use tqdm to show progress
-            i = 0
             for file_item in tqdm(self.file_list, desc=f'Generating Controls'):
-                coltrols_folder = os.path.join(os.path.dirname(file_item.path), '_controls')
-                file_name_no_ext = os.path.splitext(os.path.basename(file_item.path))[0]
-
-                image: Image = None
-
                 for control_type in self.dataset_config.controls:
-                    control_path = self.get_control_path(file_item, control_type)
+                    # generates the control if it is not already there
+                    control_path = self.control_generator.get_control_path(file_item.path, control_type)
                     if control_path is not None:
                         self.add_control_path_to_file_item(file_item, control_path, control_type)
-                    else:
-                        # we need to generate the control. Unload model if not unloaded
-                        if not has_unloaded:
-                            print("Unloading model to generate controls")
-                            self.sd.set_device_state_preset('unload')
-                            has_unloaded = True
-
-                        if image is None:
-                            # make sure image is loaded if we havent loaded it with another control
-                            image = Image.open(file_item.path).convert('RGB')
-                            image = exif_transpose(image)
-
-                        # resize to a max of 1mp
-                        max_size = 1024 * 1024
-
-                        w, h = image.size
-                        if w * h > max_size:
-                            scale = math.sqrt(max_size / (w * h))
-                            w = int(w * scale)
-                            h = int(h * scale)
-                            image = image.resize((w, h), Image.BICUBIC)
-
-                        save_path = os.path.join(coltrols_folder, f"{file_name_no_ext}.{control_type}.jpg")
-                        os.makedirs(coltrols_folder, exist_ok=True)
-                        if control_type == 'depth':
-                            if self.control_depth_model is None:
-                                from transformers import pipeline
-                                self.control_depth_model = pipeline(
-                                    task="depth-estimation",
-                                    model="depth-anything/Depth-Anything-V2-Large-hf",
-                                    device=device,
-                                    torch_dtype=torch.float16
-                                )
-                            img = image.copy()
-                            in_size = img.size
-                            output = self.control_depth_model(img)
-                            out_tensor = output["predicted_depth"]  # shape (1, H, W) 0 - 255
-                            out_tensor = out_tensor.clamp(0, 255)
-                            out_tensor = out_tensor.squeeze(0).cpu().numpy()
-                            img = Image.fromarray(out_tensor.astype('uint8'))
-                            img = img.resize(in_size, Image.LANCZOS)
-                            img.save(save_path)
-                            self.add_control_path_to_file_item(file_item, save_path, control_type)
-                        elif control_type == 'pose':
-                            if self.control_pose_model is None:
-                                from controlnet_aux import OpenposeDetector
-                                self.control_pose_model = OpenposeDetector.from_pretrained("lllyasviel/Annotators").to(device)
-                            img = image.copy()
-
-                            detect_res = int(math.sqrt(img.size[0] * img.size[1]))
-                            img = self.control_pose_model(img, hand_and_face=True, detect_resolution=detect_res, image_resolution=detect_res)
-                            img = img.convert('RGB')
-                            img.save(save_path)
-                            self.add_control_path_to_file_item(file_item, save_path, control_type)
-
-                        elif control_type == 'line':
-                            if self.control_line_model is None:
-                                from controlnet_aux import TEEDdetector
-                                self.control_line_model = TEEDdetector.from_pretrained("fal-ai/teed", filename="5_model.pth").to(device)
-                            img = image.copy()
-                            img = self.control_line_model(img, detect_resolution=1024)
-                            img = img.convert('RGB')
-                            img.save(save_path)
-                            self.add_control_path_to_file_item(file_item, save_path, control_type)
-                        elif control_type == 'inpaint' or control_type == 'mask':
-                            img = image.copy()
-                            if self.control_bg_remover is None:
-                                from transformers import AutoModelForImageSegmentation
-                                self.control_bg_remover = AutoModelForImageSegmentation.from_pretrained(
-                                    'ZhengPeng7/BiRefNet_HR',
-                                    trust_remote_code=True,
-                                    revision="595e212b3eaa6a1beaad56cee49749b1e00b1596",
-                                    torch_dtype=torch.float16
-                                ).to(device)
-                                self.control_bg_remover.eval()
-
-                            image_size = (1024, 1024)
-                            transform_image = transforms.Compose([
-                                transforms.Resize(image_size),
-                                transforms.ToTensor(),
-                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
-                            ])
-
-                            input_images = transform_image(img).unsqueeze(0).to('cuda').to(torch.float16)
-
-                            # Prediction
-                            preds = self.control_bg_remover(input_images)[-1].sigmoid().cpu()
-                            pred = preds[0].squeeze()
-                            pred_pil = transforms.ToPILImage()(pred)
-                            mask = pred_pil.resize(img.size)
-                            if control_type == 'inpaint':
-                                # inpainting feature currently only supports "erased" section desired to inpaint
-                                mask = ImageOps.invert(mask)
-                                img.putalpha(mask)
-                                save_path = os.path.join(coltrols_folder, f"{file_name_no_ext}.{control_type}.webp")
-                            else:
-                                img = mask
-                                img = img.convert('RGB')
-                            img.save(save_path)
-                            self.add_control_path_to_file_item(file_item, save_path, control_type)
-                        else:
-                            raise Exception(f"Error: unknown control type {control_type}")
-                i += 1
-
-            # remove models
-            self.control_depth_model = None
-            self.control_pose_model = None
-            self.control_line_model = None
-            self.control_bg_remover = None
+            self.control_generator.cleanup()
+            self.control_generator = None
 
             flush()
-
-            # restore device state
-            if has_unloaded:
-                self.sd.restore_device_state()