From ff14cd6343281ef9706e14d82570930215234c6a Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Tue, 21 Oct 2025 14:49:20 -0600 Subject: [PATCH] Fix check for making sure vae is on the right device. --- extensions_built_in/diffusion_models/qwen_image/qwen_image.py | 2 +- toolkit/dataloader_mixins.py | 2 ++ toolkit/models/base_model.py | 4 ++-- toolkit/models/wan21/wan21.py | 2 +- toolkit/stable_diffusion_model.py | 4 ++-- 5 files changed, 8 insertions(+), 6 deletions(-) diff --git a/extensions_built_in/diffusion_models/qwen_image/qwen_image.py b/extensions_built_in/diffusion_models/qwen_image/qwen_image.py index b9c78374..6e6813ff 100644 --- a/extensions_built_in/diffusion_models/qwen_image/qwen_image.py +++ b/extensions_built_in/diffusion_models/qwen_image/qwen_image.py @@ -414,7 +414,7 @@ class QwenImageModel(BaseModel): dtype = self.vae_torch_dtype # Move to vae to device if on cpu - if self.vae.device == "cpu": + if self.vae.device == torch.device("cpu"): self.vae.to(device) self.vae.eval() self.vae.requires_grad_(False) diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 3fe10592..3490806b 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -1662,6 +1662,8 @@ class LatentCachingFileItemDTOMixin: item["flip_x"] = True if self.flip_y: item["flip_y"] = True + if self.dataset_config.num_frames > 1: + item["num_frames"] = self.dataset_config.num_frames return item def get_latent_path(self: 'FileItemDTO', recalculate=False): diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index c81141da..31f18c1b 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -1084,7 +1084,7 @@ class BaseModel: latent_list = [] # Move to vae to device if on cpu - if self.vae.device == 'cpu': + if self.vae.device == torch.device("cpu"): self.vae.to(device) self.vae.eval() self.vae.requires_grad_(False) @@ -1127,7 +1127,7 @@ class BaseModel: dtype = self.torch_dtype # Move to vae to device if on cpu - if self.vae.device == 'cpu': + if self.vae.device == torch.device('cpu'): self.vae.to(self.device) latents = latents.to(device, dtype=dtype) latents = ( diff --git a/toolkit/models/wan21/wan21.py b/toolkit/models/wan21/wan21.py index 71d3deca..998b2312 100644 --- a/toolkit/models/wan21/wan21.py +++ b/toolkit/models/wan21/wan21.py @@ -608,7 +608,7 @@ class Wan21(BaseModel): if dtype is None: dtype = self.vae_torch_dtype - if self.vae.device == 'cpu': + if self.vae.device == torch.device('cpu'): self.vae.to(device) self.vae.eval() self.vae.requires_grad_(False) diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 78960ed1..11bae8b1 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -2516,7 +2516,7 @@ class StableDiffusion: latent_list = [] # Move to vae to device if on cpu - if self.vae.device == 'cpu': + if self.vae.device == torch.device("cpu"): self.vae.to(device) self.vae.eval() self.vae.requires_grad_(False) @@ -2558,7 +2558,7 @@ class StableDiffusion: dtype = self.torch_dtype # Move to vae to device if on cpu - if self.vae.device == 'cpu': + if self.vae.device == torch.device("cpu"): self.vae.to(self.device_torch) latents = latents.to(self.device_torch, dtype=self.torch_dtype) latents = (latents / self.vae.config['scaling_factor']) + self.vae.config['shift_factor']