mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Handle inpainting training for control_lora adapter
This commit is contained in:
@@ -635,6 +635,8 @@ class ImageProcessingDTOMixin:
|
||||
self.get_latent()
|
||||
if self.has_control_image:
|
||||
self.load_control_image()
|
||||
if self.has_inpaint_image:
|
||||
self.load_inpaint_image()
|
||||
if self.has_clip_image:
|
||||
self.load_clip_image()
|
||||
if self.has_mask_image:
|
||||
@@ -730,6 +732,8 @@ class ImageProcessingDTOMixin:
|
||||
if not only_load_latents:
|
||||
if self.has_control_image:
|
||||
self.load_control_image()
|
||||
if self.has_inpaint_image:
|
||||
self.load_inpaint_image()
|
||||
if self.has_clip_image:
|
||||
self.load_clip_image()
|
||||
if self.has_mask_image:
|
||||
@@ -738,6 +742,89 @@ class ImageProcessingDTOMixin:
|
||||
self.load_unconditional_image()
|
||||
|
||||
|
||||
class InpaintControlFileItemDTOMixin:
|
||||
def __init__(self: 'FileItemDTO', *args, **kwargs):
|
||||
if hasattr(super(), '__init__'):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.has_inpaint_image = False
|
||||
self.inpaint_path: Union[str, None] = None
|
||||
self.inpaint_tensor: Union[torch.Tensor, None] = None
|
||||
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
|
||||
if dataset_config.inpaint_path is not None:
|
||||
# find the control image path
|
||||
inpaint_path = dataset_config.inpaint_path
|
||||
# we are using control images
|
||||
img_path = kwargs.get('path', None)
|
||||
img_ext_list = ['.png', '.webp']
|
||||
file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
|
||||
|
||||
for ext in img_ext_list:
|
||||
p = os.path.join(inpaint_path, file_name_no_ext + ext)
|
||||
if os.path.exists(p):
|
||||
self.inpaint_path = p
|
||||
self.has_inpaint_image = True
|
||||
break
|
||||
|
||||
def load_inpaint_image(self: 'FileItemDTO'):
|
||||
try:
|
||||
# image must have alpha channel for inpaint
|
||||
img = Image.open(self.inpaint_path)
|
||||
# make sure has aplha
|
||||
if img.mode != 'RGBA':
|
||||
raise ValueError(f"Image must have alpha channel for inpaint: {self.inpaint_path}")
|
||||
img = exif_transpose(img)
|
||||
|
||||
w, h = img.size
|
||||
if w > h and self.scale_to_width < self.scale_to_height:
|
||||
# throw error, they should match
|
||||
raise ValueError(
|
||||
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
||||
elif h > w and self.scale_to_height < self.scale_to_width:
|
||||
# throw error, they should match
|
||||
raise ValueError(
|
||||
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
||||
|
||||
if self.flip_x:
|
||||
# do a flip
|
||||
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
if self.flip_y:
|
||||
# do a flip
|
||||
img = img.transpose(Image.FLIP_TOP_BOTTOM)
|
||||
|
||||
if self.dataset_config.buckets:
|
||||
# scale and crop based on file item
|
||||
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
|
||||
# img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
|
||||
# crop
|
||||
img = img.crop((
|
||||
self.crop_x,
|
||||
self.crop_y,
|
||||
self.crop_x + self.crop_width,
|
||||
self.crop_y + self.crop_height
|
||||
))
|
||||
else:
|
||||
raise Exception("Inpaint images not supported for non-bucket datasets")
|
||||
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
if self.aug_replay_spatial_transforms:
|
||||
tensor = self.augment_spatial_control(img, transform=transform)
|
||||
else:
|
||||
tensor = transform(img)
|
||||
|
||||
# is 0 to 1 with alpha
|
||||
self.inpaint_tensor = tensor
|
||||
|
||||
except Exception as e:
|
||||
print_acc(f"Error: {e}")
|
||||
print_acc(f"Error loading image: {self.inpaint_path}")
|
||||
|
||||
|
||||
def cleanup_inpaint(self: 'FileItemDTO'):
|
||||
self.inpaint_tensor = None
|
||||
|
||||
|
||||
class ControlFileItemDTOMixin:
|
||||
def __init__(self: 'FileItemDTO', *args, **kwargs):
|
||||
if hasattr(super(), '__init__'):
|
||||
@@ -786,7 +873,7 @@ class ControlFileItemDTOMixin:
|
||||
print_acc(f"Error: {e}")
|
||||
print_acc(f"Error loading image: {control_path}")
|
||||
|
||||
if self.full_size_control_images:
|
||||
if not self.full_size_control_images:
|
||||
# we just scale them to 512x512:
|
||||
w, h = img.size
|
||||
img = img.resize((512, 512), Image.BICUBIC)
|
||||
|
||||
Reference in New Issue
Block a user