Fix check for making sure vae is on the right device.

This commit is contained in:
Jaret Burkett
2025-10-21 14:49:20 -06:00
parent 5123090f6c
commit ff14cd6343
5 changed files with 8 additions and 6 deletions

View File

@@ -414,7 +414,7 @@ class QwenImageModel(BaseModel):
dtype = self.vae_torch_dtype
# Move to vae to device if on cpu
if self.vae.device == "cpu":
if self.vae.device == torch.device("cpu"):
self.vae.to(device)
self.vae.eval()
self.vae.requires_grad_(False)

View File

@@ -1662,6 +1662,8 @@ class LatentCachingFileItemDTOMixin:
item["flip_x"] = True
if self.flip_y:
item["flip_y"] = True
if self.dataset_config.num_frames > 1:
item["num_frames"] = self.dataset_config.num_frames
return item
def get_latent_path(self: 'FileItemDTO', recalculate=False):

View File

@@ -1084,7 +1084,7 @@ class BaseModel:
latent_list = []
# Move to vae to device if on cpu
if self.vae.device == 'cpu':
if self.vae.device == torch.device("cpu"):
self.vae.to(device)
self.vae.eval()
self.vae.requires_grad_(False)
@@ -1127,7 +1127,7 @@ class BaseModel:
dtype = self.torch_dtype
# Move to vae to device if on cpu
if self.vae.device == 'cpu':
if self.vae.device == torch.device('cpu'):
self.vae.to(self.device)
latents = latents.to(device, dtype=dtype)
latents = (

View File

@@ -608,7 +608,7 @@ class Wan21(BaseModel):
if dtype is None:
dtype = self.vae_torch_dtype
if self.vae.device == 'cpu':
if self.vae.device == torch.device('cpu'):
self.vae.to(device)
self.vae.eval()
self.vae.requires_grad_(False)

View File

@@ -2516,7 +2516,7 @@ class StableDiffusion:
latent_list = []
# Move to vae to device if on cpu
if self.vae.device == 'cpu':
if self.vae.device == torch.device("cpu"):
self.vae.to(device)
self.vae.eval()
self.vae.requires_grad_(False)
@@ -2558,7 +2558,7 @@ class StableDiffusion:
dtype = self.torch_dtype
# Move to vae to device if on cpu
if self.vae.device == 'cpu':
if self.vae.device == torch.device("cpu"):
self.vae.to(self.device_torch)
latents = latents.to(self.device_torch, dtype=self.torch_dtype)
latents = (latents / self.vae.config['scaling_factor']) + self.vae.config['shift_factor']