From 239addba515eb3d8fc07206ce965f8fa8e7777c0 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Tue, 10 Oct 2023 13:31:47 -0600 Subject: [PATCH] Fixed memory leak when caching latents to disk --- toolkit/dataloader_mixins.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index fe2cab46..6cbab537 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -281,7 +281,8 @@ class CaptionProcessingDTOMixin: class ImageProcessingDTOMixin: def load_and_process_image( self: 'FileItemDTO', - transform: Union[None, transforms.Compose] + transform: Union[None, transforms.Compose], + only_load_latents=False ): # if we are caching latents, just do that if self.is_latent_cached: @@ -363,10 +364,11 @@ class ImageProcessingDTOMixin: img = transform(img) self.tensor = img - if self.has_control_image: - self.load_control_image() - if self.has_mask_image: - self.load_mask_image() + if not only_load_latents: + if self.has_control_image: + self.load_control_image() + if self.has_mask_image: + self.load_mask_image() class ControlFileItemDTOMixin: @@ -480,9 +482,9 @@ class MaskFileItemDTOMixin: # do a flip img.transpose(Image.FLIP_TOP_BOTTOM) - # randomly apply a blur up to 10% of the size of the min (width, height) + # randomly apply a blur up to 2% of the size of the min (width, height) min_size = min(img.width, img.height) - blur_radius = int(min_size * random.random() * 0.1) + blur_radius = int(min_size * random.random() * 0.02) img = img.filter(ImageFilter.GaussianBlur(radius=blur_radius)) # make grayscale @@ -718,7 +720,7 @@ class LatentCachingMixin: else: # not saved to disk, calculate # load the image first - file_item.load_and_process_image(self.transform) + file_item.load_and_process_image(self.transform, only_load_latents=True) dtype = self.sd.torch_dtype device = self.sd.device_torch # add batch dimension @@ -742,6 +744,7 @@ class LatentCachingMixin: del 
latent del file_item.tensor + flush(garbage_collect=False) file_item.is_latent_cached = True # flush every 100