Huge memory optimizations, many big fixes

This commit is contained in:
Jaret Burkett
2023-08-27 17:48:02 -06:00
parent cc49786ee9
commit c446f768ea
15 changed files with 86 additions and 78 deletions

View File

@@ -610,6 +610,7 @@ class StableDiffusion:
)
)
@torch.no_grad()
def encode_images(
self,
image_list: List[torch.Tensor],
@@ -625,6 +626,8 @@ class StableDiffusion:
# Move to vae to device if on cpu
if self.vae.device == 'cpu':
self.vae.to(self.device)
self.vae.eval()
self.vae.requires_grad_(False)
# move to device and dtype
image_list = [image.to(self.device, dtype=self.torch_dtype) for image in image_list]
@@ -635,8 +638,9 @@ class StableDiffusion:
image_list[i] = Resize((image.shape[1] // 8 * 8, image.shape[2] // 8 * 8))(image)
images = torch.stack(image_list)
flush()
latents = self.vae.encode(images).latent_dist.sample()
latents = latents * 0.18215
latents = latents * self.vae.config['scaling_factor']
latents = latents.to(device, dtype=dtype)
return latents