diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index fe2cab46..6cbab537 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -281,7 +281,8 @@ class CaptionProcessingDTOMixin:
 class ImageProcessingDTOMixin:
     def load_and_process_image(
             self: 'FileItemDTO',
-            transform: Union[None, transforms.Compose]
+            transform: Union[None, transforms.Compose],
+            only_load_latents=False
     ):
         # if we are caching latents, just do that
         if self.is_latent_cached:
@@ -363,10 +364,11 @@ class ImageProcessingDTOMixin:
             img = transform(img)
 
         self.tensor = img
-        if self.has_control_image:
-            self.load_control_image()
-        if self.has_mask_image:
-            self.load_mask_image()
+        if not only_load_latents:
+            if self.has_control_image:
+                self.load_control_image()
+            if self.has_mask_image:
+                self.load_mask_image()
 
 
 class ControlFileItemDTOMixin:
@@ -480,9 +482,9 @@ class MaskFileItemDTOMixin:
             # do a flip
             img.transpose(Image.FLIP_TOP_BOTTOM)
 
-        # randomly apply a blur up to 10% of the size of the min (width, height)
+        # randomly apply a blur up to 2% of the size of the min (width, height)
         min_size = min(img.width, img.height)
-        blur_radius = int(min_size * random.random() * 0.1)
+        blur_radius = int(min_size * random.random() * 0.02)
         img = img.filter(ImageFilter.GaussianBlur(radius=blur_radius))
 
         # make grayscale
@@ -718,7 +720,7 @@ class LatentCachingMixin:
             else:
                 # not saved to disk, calculate
                 # load the image first
-                file_item.load_and_process_image(self.transform)
+                file_item.load_and_process_image(self.transform, only_load_latents=True)
                 dtype = self.sd.torch_dtype
                 device = self.sd.device_torch
                 # add batch dimension
@@ -742,6 +744,7 @@ class LatentCachingMixin:
 
                 del latent
                 del file_item.tensor
+                flush(garbage_collect=False)
                 file_item.is_latent_cached = True
 
                 # flush every 100