Handle multi control inputs for control lora training

This commit is contained in:
Jaret Burkett
2025-03-23 07:37:08 -06:00
parent ccb66c748f
commit f10937e6da
7 changed files with 446 additions and 75 deletions

View File

@@ -743,75 +743,100 @@ class ControlFileItemDTOMixin:
if hasattr(super(), '__init__'):
super().__init__(*args, **kwargs)
self.has_control_image = False
self.control_path: Union[str, None] = None
self.control_path: Union[str, List[str], None] = None
self.control_tensor: Union[torch.Tensor, None] = None
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
self.full_size_control_images = False
if dataset_config.control_path is not None:
# find the control image path
control_path = dataset_config.control_path
control_path_list = dataset_config.control_path
if not isinstance(control_path_list, list):
control_path_list = [control_path_list]
self.full_size_control_images = dataset_config.full_size_control_images
# we are using control images
img_path = kwargs.get('path', None)
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
for ext in img_ext_list:
if os.path.exists(os.path.join(control_path, file_name_no_ext + ext)):
self.control_path = os.path.join(control_path, file_name_no_ext + ext)
self.has_control_image = True
break
found_control_images = []
for control_path in control_path_list:
for ext in img_ext_list:
if os.path.exists(os.path.join(control_path, file_name_no_ext + ext)):
found_control_images.append(os.path.join(control_path, file_name_no_ext + ext))
self.has_control_image = True
break
self.control_path = found_control_images
if len(self.control_path) == 0:
self.control_path = None
elif len(self.control_path) == 1:
# only do one
self.control_path = self.control_path[0]
def load_control_image(self: 'FileItemDTO'):
try:
img = Image.open(self.control_path).convert('RGB')
img = exif_transpose(img)
except Exception as e:
print_acc(f"Error: {e}")
print_acc(f"Error loading image: {self.control_path}")
control_tensors = []
control_path_list = self.control_path
if not isinstance(self.control_path, list):
control_path_list = [self.control_path]
for control_path in control_path_list:
try:
img = Image.open(control_path).convert('RGB')
img = exif_transpose(img)
except Exception as e:
print_acc(f"Error: {e}")
print_acc(f"Error loading image: {control_path}")
if self.full_size_control_images:
# we just scale them to 512x512:
w, h = img.size
img = img.resize((512, 512), Image.BICUBIC)
if self.full_size_control_images:
# we just scale them to 512x512:
w, h = img.size
img = img.resize((512, 512), Image.BICUBIC)
else:
w, h = img.size
if w > h and self.scale_to_width < self.scale_to_height:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
elif h > w and self.scale_to_height < self.scale_to_width:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
if self.flip_x:
# do a flip
img = img.transpose(Image.FLIP_LEFT_RIGHT)
if self.flip_y:
# do a flip
img = img.transpose(Image.FLIP_TOP_BOTTOM)
if self.dataset_config.buckets:
# scale and crop based on file item
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
# img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
# crop
img = img.crop((
self.crop_x,
self.crop_y,
self.crop_x + self.crop_width,
self.crop_y + self.crop_height
))
else:
raise Exception("Control images not supported for non-bucket datasets")
transform = transforms.Compose([
transforms.ToTensor(),
])
if self.aug_replay_spatial_transforms:
self.control_tensor = self.augment_spatial_control(img, transform=transform)
w, h = img.size
if w > h and self.scale_to_width < self.scale_to_height:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
elif h > w and self.scale_to_height < self.scale_to_width:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
if self.flip_x:
# do a flip
img = img.transpose(Image.FLIP_LEFT_RIGHT)
if self.flip_y:
# do a flip
img = img.transpose(Image.FLIP_TOP_BOTTOM)
if self.dataset_config.buckets:
# scale and crop based on file item
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
# img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
# crop
img = img.crop((
self.crop_x,
self.crop_y,
self.crop_x + self.crop_width,
self.crop_y + self.crop_height
))
else:
raise Exception("Control images not supported for non-bucket datasets")
transform = transforms.Compose([
transforms.ToTensor(),
])
if self.aug_replay_spatial_transforms:
tensor = self.augment_spatial_control(img, transform=transform)
else:
tensor = transform(img)
control_tensors.append(tensor)
if len(control_tensors) == 0:
self.control_tensor = None
elif len(control_tensors) == 1:
self.control_tensor = control_tensors[0]
else:
self.control_tensor = transform(img)
self.control_tensor = torch.stack(control_tensors, dim=0)
def cleanup_control(self: 'FileItemDTO'):
self.control_tensor = None