Fixed memory leak when caching latents to disk

Jaret Burkett
2023-10-10 13:31:47 -06:00
parent 63ceffae24
commit 239addba51


@@ -281,7 +281,8 @@ class CaptionProcessingDTOMixin:
 class ImageProcessingDTOMixin:
     def load_and_process_image(
             self: 'FileItemDTO',
-            transform: Union[None, transforms.Compose]
+            transform: Union[None, transforms.Compose],
+            only_load_latents=False
     ):
         # if we are caching latents, just do that
         if self.is_latent_cached:
@@ -363,10 +364,11 @@ class ImageProcessingDTOMixin:
             img = transform(img)
         self.tensor = img
-        if self.has_control_image:
-            self.load_control_image()
-        if self.has_mask_image:
-            self.load_mask_image()
+        if not only_load_latents:
+            if self.has_control_image:
+                self.load_control_image()
+            if self.has_mask_image:
+                self.load_mask_image()
 class ControlFileItemDTOMixin:
@@ -480,9 +482,9 @@ class MaskFileItemDTOMixin:
                 # do a flip
                 img.transpose(Image.FLIP_TOP_BOTTOM)
-            # randomly apply a blur up to 10% of the size of the min (width, height)
+            # randomly apply a blur up to 2% of the size of the min (width, height)
             min_size = min(img.width, img.height)
-            blur_radius = int(min_size * random.random() * 0.1)
+            blur_radius = int(min_size * random.random() * 0.02)
             img = img.filter(ImageFilter.GaussianBlur(radius=blur_radius))
             # make grayscale
@@ -718,7 +720,7 @@ class LatentCachingMixin:
             else:
                 # not saved to disk, calculate
                 # load the image first
-                file_item.load_and_process_image(self.transform)
+                file_item.load_and_process_image(self.transform, only_load_latents=True)
                 dtype = self.sd.torch_dtype
                 device = self.sd.device_torch
                 # add batch dimension
@@ -742,6 +744,7 @@ class LatentCachingMixin:
                 del latent
                 del file_item.tensor
                 flush(garbage_collect=False)
+                file_item.is_latent_cached = True
                 # flush every 100
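
Read together, the hunks make the latent-caching pass load only the pixel data (only_load_latents=True) and skip control/mask images, free the pixel tensor once the latent is written, and set is_latent_cached so later passes short-circuit instead of re-loading the image. A minimal sketch of that flow, with a hypothetical Item class, encode callable, and ".latent.pt" naming standing in for the repo's actual FileItemDTO and VAE plumbing:

import gc
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Item:
    # hypothetical stand-in for the repo's FileItemDTO
    path: str
    tensor: Optional[torch.Tensor] = None
    is_latent_cached: bool = False

    def load_and_process_image(self, only_load_latents: bool = False) -> None:
        if self.is_latent_cached:
            return  # latent already cached, nothing to load
        self.tensor = torch.rand(3, 512, 512)  # placeholder for real image loading
        if not only_load_latents:
            pass  # control/mask images would be loaded here


def cache_latents_to_disk(items, encode) -> None:
    # encode: callable mapping a (1, C, H, W) pixel batch to a latent batch (assumption)
    for item in items:
        if item.is_latent_cached:
            continue
        # load only what the VAE needs; skip control/mask images
        item.load_and_process_image(only_load_latents=True)
        with torch.no_grad():
            latent = encode(item.tensor.unsqueeze(0)).squeeze(0)
        torch.save(latent.cpu(), item.path + ".latent.pt")
        # drop references so the pixel tensor and latent can be collected
        del latent
        item.tensor = None
        # mark as cached so later passes skip re-loading the image entirely
        item.is_latent_cached = True
    gc.collect()

With diffusers, encode could be something like lambda px: vae.encode(px).latent_dist.sample(), though the repo may scale or post-process latents differently.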