Added ability to add masks to dataloader and sd trainer to adjust weight of image

This commit is contained in:
Jaret Burkett
2023-10-09 11:21:00 -06:00
parent 1d3de678aa
commit bb1d3793e3
4 changed files with 127 additions and 14 deletions

View File

@@ -16,7 +16,7 @@ from toolkit.buckets import get_bucket_for_image_size
from toolkit.metadata import get_meta_for_safetensors
from toolkit.prompt_utils import inject_trigger_into_prompt
from torchvision import transforms
from PIL import Image
from PIL import Image, ImageFilter
from PIL.ImageOps import exif_transpose
from toolkit.train_tools import get_torch_dtype
@@ -288,6 +288,8 @@ class ImageProcessingDTOMixin:
self.get_latent()
if self.has_control_image:
self.load_control_image()
if self.has_mask_image:
self.load_mask_image()
return
try:
img = Image.open(self.path).convert('RGB')
@@ -363,6 +365,8 @@ class ImageProcessingDTOMixin:
self.tensor = img
if self.has_control_image:
self.load_control_image()
if self.has_mask_image:
self.load_mask_image()
class ControlFileItemDTOMixin:
@@ -430,6 +434,79 @@ class ControlFileItemDTOMixin:
self.control_tensor = None
class MaskFileItemDTOMixin:
def __init__(self: 'FileItemDTO', *args, **kwargs):
if hasattr(super(), '__init__'):
super().__init__(*args, **kwargs)
self.has_mask_image = False
self.mask_path: Union[str, None] = None
self.mask_tensor: Union[torch.Tensor, None] = None
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
if dataset_config.mask_path is not None:
# find the control image path
mask_path = dataset_config.mask_path
# we are using control images
img_path = kwargs.get('path', None)
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
for ext in img_ext_list:
if os.path.exists(os.path.join(mask_path, file_name_no_ext + ext)):
self.mask_path = os.path.join(mask_path, file_name_no_ext + ext)
self.has_mask_image = True
break
def load_mask_image(self: 'FileItemDTO'):
try:
img = Image.open(self.mask_path).convert('RGB')
img = exif_transpose(img)
except Exception as e:
print(f"Error: {e}")
print(f"Error loading image: {self.mask_path}")
w, h = img.size
if w > h and self.scale_to_width < self.scale_to_height:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
elif h > w and self.scale_to_height < self.scale_to_width:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
if self.flip_x:
# do a flip
img.transpose(Image.FLIP_LEFT_RIGHT)
if self.flip_y:
# do a flip
img.transpose(Image.FLIP_TOP_BOTTOM)
# randomly apply a blur up to 10% of the size of the min (width, height)
min_size = min(img.width, img.height)
blur_radius = int(min_size * random.random() * 0.1)
img = img.filter(ImageFilter.GaussianBlur(radius=blur_radius))
# make grayscale
img = img.convert('L')
if self.dataset_config.buckets:
# scale and crop based on file item
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
# img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
# crop
img = img.crop((
self.crop_x,
self.crop_y,
self.crop_x + self.crop_width,
self.crop_y + self.crop_height
))
else:
raise Exception("Mask images not supported for non-bucket datasets")
self.mask_tensor = transforms.ToTensor()(img)
# convert to grayscale
def cleanup_mask(self: 'FileItemDTO'):
self.mask_tensor = None
class PoiFileItemDTOMixin:
# Point of interest bounding box. Allows for dynamic cropping without cropping out the main subject
# items in the poi will always be inside the image when random cropping