diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 9cc223c..1628444 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -98,18 +98,31 @@ class SDTrainer(BaseSDTrainProcess):
         noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
         network_weight_list = batch.get_network_weight_list()
 
-        adapter_images = None
-        sigmas = None
-        if self.adapter:
-            # todo: move this to the data loader
-            if batch.control_tensor is not None:
-                adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach()
-            else:
-                adapter_images = self.get_adapter_images(batch)
-            # not 100% sure what this does, but they do it here:
-            # https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170
-            # sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
-            # noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
+        with torch.no_grad():
+            adapter_images = None
+            sigmas = None
+            if self.adapter:
+                # todo: move this to the data loader
+                if batch.control_tensor is not None:
+                    adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach()
+                else:
+                    adapter_images = self.get_adapter_images(batch)
+                # not 100% sure what this does, but they do it here:
+                # https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170
+                # sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
+                # noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
+
+            mask_multiplier = 1.0
+            if batch.mask_tensor is not None:
+                # upsampling is not supported for bfloat16, so use float16 here
+                mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach()
+                # scale down to the size of the latents; mask_multiplier is (bs, 1, width, height),
+                # noisy_latents is (bs, channels, width, height)
+                mask_multiplier = torch.nn.functional.interpolate(
+                    mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3])
+                )
+                # expand to match the latent channels
+                mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
+                mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach()
 
         # flush()
         self.optimizer.zero_grad()
@@ -188,6 +201,9 @@ class SDTrainer(BaseSDTrainProcess):
             else:
                 target = noise
             loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+            # weight the per-element loss by our mask
+            loss = loss * mask_multiplier
+
             loss = loss.mean([1, 2, 3])
 
             if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
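For reference, the core of this change is weighting the element-wise MSE before it is reduced: with reduction="none", mse_loss returns one value per latent element, and multiplying by the resized mask makes white regions dominate each sample's mean. A minimal standalone sketch of the same weighting (shapes and tensor names here are illustrative, not the trainer's actual API):

    import torch
    import torch.nn.functional as F

    bs, ch, lat_h, lat_w = 2, 4, 64, 64
    noise_pred = torch.randn(bs, ch, lat_h, lat_w)
    target = torch.randn(bs, ch, lat_h, lat_w)
    # pixel-space mask in [0, 1], shape (bs, 1, H, W); white (1.0) keeps full loss
    mask = torch.rand(bs, 1, 512, 512)

    # downscale to latent resolution and broadcast across the latent channels
    mask = F.interpolate(mask, size=(lat_h, lat_w))
    mask = mask.expand(-1, ch, -1, -1)

    loss = F.mse_loss(noise_pred, target, reduction="none")  # (bs, ch, lat_h, lat_w)
    loss = (loss * mask).mean([1, 2, 3])                     # one scalar per sample

Note that the masked loss is still averaged over all elements, so a mostly-black mask also shrinks a sample's overall loss magnitude rather than purely refocusing it.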
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 3f87521..ed8ea4c 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -235,6 +235,7 @@ class DatasetConfig:
         self.flip_y: bool = kwargs.get('flip_y', False)
         self.augments: List[str] = kwargs.get('augments', [])
         self.control_path: str = kwargs.get('control_path', None)  # depth maps, etc
+        self.mask_path: str = kwargs.get('mask_path', None)  # focus mask (black and white; white gets higher loss than black)
         self.poi: Union[str, None] = kwargs.get('poi', None)  # if set and present in the json data, used as the auto-crop scale point of interest
         # cache latents will store them in memory
diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py
index 684383f..a36df5d 100644
--- a/toolkit/data_transfer_object/data_loader.py
+++ b/toolkit/data_transfer_object/data_loader.py
@@ -7,7 +7,7 @@ from PIL.ImageOps import exif_transpose
 from toolkit import image_utils
 from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
-    ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin
+    ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin
 
 if TYPE_CHECKING:
     from toolkit.config_modules import DatasetConfig
@@ -27,6 +27,7 @@ class FileItemDTO(
     CaptionProcessingDTOMixin,
     ImageProcessingDTOMixin,
     ControlFileItemDTOMixin,
+    MaskFileItemDTOMixin,
     PoiFileItemDTOMixin,
     ArgBreakMixin,
 ):
@@ -67,6 +68,7 @@ class FileItemDTO(
         self.tensor = None
         self.cleanup_latent()
         self.cleanup_control()
+        self.cleanup_mask()
 
 
 class DataLoaderBatchDTO:
@@ -76,6 +78,8 @@ class DataLoaderBatchDTO:
             is_latents_cached = self.file_items[0].is_latent_cached
             self.tensor: Union[torch.Tensor, None] = None
             self.latents: Union[torch.Tensor, None] = None
+            self.control_tensor: Union[torch.Tensor, None] = None
+            self.mask_tensor: Union[torch.Tensor, None] = None
             if not is_latents_cached:
                 # only return a tensor if latents are not cached
                 self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
@@ -100,6 +104,21 @@ class DataLoaderBatchDTO:
                 else:
                     control_tensors.append(x.control_tensor)
                 self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])
+
+            if any([x.mask_tensor is not None for x in self.file_items]):
+                # find one mask to use as a size/shape base for the empty ones
+                base_mask_tensor = None
+                for x in self.file_items:
+                    if x.mask_tensor is not None:
+                        base_mask_tensor = x.mask_tensor
+                        break
+                mask_tensors = []
+                for x in self.file_items:
+                    if x.mask_tensor is None:
+                        mask_tensors.append(torch.zeros_like(base_mask_tensor))
+                    else:
+                        mask_tensors.append(x.mask_tensor)
+                self.mask_tensor = torch.cat([x.unsqueeze(0) for x in mask_tensors])
         except Exception as e:
             print(e)
             raise e
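The collation above substitutes an all-zeros mask for any batch item that has none, sized from the first available mask. A condensed sketch of that behavior with bare tensors (the list contents are illustrative):

    import torch

    # stand-ins for FileItemDTO.mask_tensor: first item masked, second not
    item_masks = [torch.rand(1, 512, 512), None]

    base = next(m for m in item_masks if m is not None)
    padded = [m if m is not None else torch.zeros_like(base) for m in item_masks]
    mask_tensor = torch.cat([m.unsqueeze(0) for m in padded])  # (bs, 1, 512, 512)

Because the filler is zeros_like rather than ones_like, an unmasked item batched alongside masked ones contributes no loss at all, whereas a ones_like filler would give it full weight.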
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index 285c96c..bebfd4b 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -16,7 +16,7 @@ from toolkit.buckets import get_bucket_for_image_size
 from toolkit.metadata import get_meta_for_safetensors
 from toolkit.prompt_utils import inject_trigger_into_prompt
 from torchvision import transforms
-from PIL import Image
+from PIL import Image, ImageFilter
 from PIL.ImageOps import exif_transpose
 from toolkit.train_tools import get_torch_dtype
 
@@ -288,6 +288,8 @@ class ImageProcessingDTOMixin:
             self.get_latent()
             if self.has_control_image:
                 self.load_control_image()
+            if self.has_mask_image:
+                self.load_mask_image()
             return
         try:
             img = Image.open(self.path).convert('RGB')
@@ -363,6 +365,8 @@ class ImageProcessingDTOMixin:
         self.tensor = img
         if self.has_control_image:
             self.load_control_image()
+        if self.has_mask_image:
+            self.load_mask_image()
 
 
 class ControlFileItemDTOMixin:
@@ -430,6 +434,79 @@ class ControlFileItemDTOMixin:
         self.control_tensor = None
 
 
+class MaskFileItemDTOMixin:
+    def __init__(self: 'FileItemDTO', *args, **kwargs):
+        if hasattr(super(), '__init__'):
+            super().__init__(*args, **kwargs)
+        self.has_mask_image = False
+        self.mask_path: Union[str, None] = None
+        self.mask_tensor: Union[torch.Tensor, None] = None
+        dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
+        if dataset_config.mask_path is not None:
+            # we are using mask images; find the matching mask image path
+            mask_path = dataset_config.mask_path
+            img_path = kwargs.get('path', None)
+            img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
+            file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
+            for ext in img_ext_list:
+                if os.path.exists(os.path.join(mask_path, file_name_no_ext + ext)):
+                    self.mask_path = os.path.join(mask_path, file_name_no_ext + ext)
+                    self.has_mask_image = True
+                    break
+
+    def load_mask_image(self: 'FileItemDTO'):
+        try:
+            img = Image.open(self.mask_path).convert('RGB')
+            img = exif_transpose(img)
+        except Exception as e:
+            print(f"Error: {e}")
+            print(f"Error loading image: {self.mask_path}")
+            raise e
+        w, h = img.size
+        if w > h and self.scale_to_width < self.scale_to_height:
+            # orientations should match; throw an error
+            raise ValueError(
+                f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
+        elif h > w and self.scale_to_height < self.scale_to_width:
+            # orientations should match; throw an error
+            raise ValueError(
+                f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
+
+        if self.flip_x:
+            img = img.transpose(Image.FLIP_LEFT_RIGHT)
+        if self.flip_y:
+            img = img.transpose(Image.FLIP_TOP_BOTTOM)
+
+        # randomly apply a blur with a radius up to 10% of min(width, height)
+        min_size = min(img.width, img.height)
+        blur_radius = int(min_size * random.random() * 0.1)
+        img = img.filter(ImageFilter.GaussianBlur(radius=blur_radius))
+
+        # make grayscale
+        img = img.convert('L')
+
+        if self.dataset_config.buckets:
+            # scale and crop based on this file item's bucket
+            img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
+            img = img.crop((
+                self.crop_x,
+                self.crop_y,
+                self.crop_x + self.crop_width,
+                self.crop_y + self.crop_height
+            ))
+        else:
+            raise Exception("Mask images are not supported for non-bucket datasets")
+
+        self.mask_tensor = transforms.ToTensor()(img)
+
+    def cleanup_mask(self: 'FileItemDTO'):
+        self.mask_tensor = None
+
+
 class PoiFileItemDTOMixin:
     # Point of interest bounding box. Allows for dynamic cropping without cropping out the main subject
     # items in the poi will always be inside the image when random cropping
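Masks are discovered by matching the training image's base filename against dataset_config.mask_path with any of the accepted extensions. A hypothetical layout and the config that would wire it up (folder_path and buckets are assumed to behave as the other DatasetConfig kwargs do elsewhere in the toolkit):

    # dataset/images/dog_01.jpg   <- training image
    # dataset/masks/dog_01.png    <- grayscale focus mask (white = higher loss)

    from toolkit.config_modules import DatasetConfig

    dataset_config = DatasetConfig(
        folder_path='dataset/images',
        mask_path='dataset/masks',  # enables the MaskFileItemDTOMixin lookup
        buckets=True,               # load_mask_image requires bucketed datasets
    )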