mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Fix check for making sure vae is on the right device.
This commit is contained in:
@@ -414,7 +414,7 @@ class QwenImageModel(BaseModel):
|
|||||||
dtype = self.vae_torch_dtype
|
dtype = self.vae_torch_dtype
|
||||||
|
|
||||||
# Move to vae to device if on cpu
|
# Move to vae to device if on cpu
|
||||||
if self.vae.device == "cpu":
|
if self.vae.device == torch.device("cpu"):
|
||||||
self.vae.to(device)
|
self.vae.to(device)
|
||||||
self.vae.eval()
|
self.vae.eval()
|
||||||
self.vae.requires_grad_(False)
|
self.vae.requires_grad_(False)
|
||||||
|
|||||||
@@ -1662,6 +1662,8 @@ class LatentCachingFileItemDTOMixin:
|
|||||||
item["flip_x"] = True
|
item["flip_x"] = True
|
||||||
if self.flip_y:
|
if self.flip_y:
|
||||||
item["flip_y"] = True
|
item["flip_y"] = True
|
||||||
|
if self.dataset_config.num_frames > 1:
|
||||||
|
item["num_frames"] = self.dataset_config.num_frames
|
||||||
return item
|
return item
|
||||||
|
|
||||||
def get_latent_path(self: 'FileItemDTO', recalculate=False):
|
def get_latent_path(self: 'FileItemDTO', recalculate=False):
|
||||||
|
|||||||
@@ -1084,7 +1084,7 @@ class BaseModel:
|
|||||||
|
|
||||||
latent_list = []
|
latent_list = []
|
||||||
# Move to vae to device if on cpu
|
# Move to vae to device if on cpu
|
||||||
if self.vae.device == 'cpu':
|
if self.vae.device == torch.device("cpu"):
|
||||||
self.vae.to(device)
|
self.vae.to(device)
|
||||||
self.vae.eval()
|
self.vae.eval()
|
||||||
self.vae.requires_grad_(False)
|
self.vae.requires_grad_(False)
|
||||||
@@ -1127,7 +1127,7 @@ class BaseModel:
|
|||||||
dtype = self.torch_dtype
|
dtype = self.torch_dtype
|
||||||
|
|
||||||
# Move to vae to device if on cpu
|
# Move to vae to device if on cpu
|
||||||
if self.vae.device == 'cpu':
|
if self.vae.device == torch.device('cpu'):
|
||||||
self.vae.to(self.device)
|
self.vae.to(self.device)
|
||||||
latents = latents.to(device, dtype=dtype)
|
latents = latents.to(device, dtype=dtype)
|
||||||
latents = (
|
latents = (
|
||||||
|
|||||||
@@ -608,7 +608,7 @@ class Wan21(BaseModel):
|
|||||||
if dtype is None:
|
if dtype is None:
|
||||||
dtype = self.vae_torch_dtype
|
dtype = self.vae_torch_dtype
|
||||||
|
|
||||||
if self.vae.device == 'cpu':
|
if self.vae.device == torch.device('cpu'):
|
||||||
self.vae.to(device)
|
self.vae.to(device)
|
||||||
self.vae.eval()
|
self.vae.eval()
|
||||||
self.vae.requires_grad_(False)
|
self.vae.requires_grad_(False)
|
||||||
|
|||||||
@@ -2516,7 +2516,7 @@ class StableDiffusion:
|
|||||||
|
|
||||||
latent_list = []
|
latent_list = []
|
||||||
# Move to vae to device if on cpu
|
# Move to vae to device if on cpu
|
||||||
if self.vae.device == 'cpu':
|
if self.vae.device == torch.device("cpu"):
|
||||||
self.vae.to(device)
|
self.vae.to(device)
|
||||||
self.vae.eval()
|
self.vae.eval()
|
||||||
self.vae.requires_grad_(False)
|
self.vae.requires_grad_(False)
|
||||||
@@ -2558,7 +2558,7 @@ class StableDiffusion:
|
|||||||
dtype = self.torch_dtype
|
dtype = self.torch_dtype
|
||||||
|
|
||||||
# Move to vae to device if on cpu
|
# Move to vae to device if on cpu
|
||||||
if self.vae.device == 'cpu':
|
if self.vae.device == torch.device("cpu"):
|
||||||
self.vae.to(self.device_torch)
|
self.vae.to(self.device_torch)
|
||||||
latents = latents.to(self.device_torch, dtype=self.torch_dtype)
|
latents = latents.to(self.device_torch, dtype=self.torch_dtype)
|
||||||
latents = (latents / self.vae.config['scaling_factor']) + self.vae.config['shift_factor']
|
latents = (latents / self.vae.config['scaling_factor']) + self.vae.config['shift_factor']
|
||||||
|
|||||||
Reference in New Issue
Block a user