mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 06:29:48 +00:00
Huge memory optimizations, many big fixes
This commit is contained in:
@@ -610,6 +610,7 @@ class StableDiffusion:
|
||||
)
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def encode_images(
|
||||
self,
|
||||
image_list: List[torch.Tensor],
|
||||
@@ -625,6 +626,8 @@ class StableDiffusion:
|
||||
# Move to vae to device if on cpu
|
||||
if self.vae.device == 'cpu':
|
||||
self.vae.to(self.device)
|
||||
self.vae.eval()
|
||||
self.vae.requires_grad_(False)
|
||||
# move to device and dtype
|
||||
image_list = [image.to(self.device, dtype=self.torch_dtype) for image in image_list]
|
||||
|
||||
@@ -635,8 +638,9 @@ class StableDiffusion:
|
||||
image_list[i] = Resize((image.shape[1] // 8 * 8, image.shape[2] // 8 * 8))(image)
|
||||
|
||||
images = torch.stack(image_list)
|
||||
flush()
|
||||
latents = self.vae.encode(images).latent_dist.sample()
|
||||
latents = latents * 0.18215
|
||||
latents = latents * self.vae.config['scaling_factor']
|
||||
latents = latents.to(device, dtype=dtype)
|
||||
|
||||
return latents
|
||||
|
||||
Reference in New Issue
Block a user