diff --git a/toolkit/models/diffusion_feature_extraction.py b/toolkit/models/diffusion_feature_extraction.py
index 65c2de3c..30a19d3d 100644
--- a/toolkit/models/diffusion_feature_extraction.py
+++ b/toolkit/models/diffusion_feature_extraction.py
@@ -285,9 +285,12 @@ class DiffusionFeatureExtractor3(nn.Module):
         latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
-        latents = (
-            latents / self.vae.config['scaling_factor']) + self.vae.config['shift_factor']
-        tensors_n1p1 = self.vae.decode(latents).sample  # -1 to 1
+        scaling_factor = self.vae.config.scaling_factor if hasattr(self.vae.config, 'scaling_factor') else 1.0
+        shift_factor = self.vae.config.shift_factor if hasattr(self.vae.config, 'shift_factor') else 0.0
+        latents = (latents / scaling_factor) + shift_factor
+        tensors_n1p1 = self.vae.decode(latents)  # -1 to 1
+        if hasattr(tensors_n1p1, 'sample'):
+            tensors_n1p1 = tensors_n1p1.sample
         pred_images = (tensors_n1p1 + 1) / 2  # 0 to 1
@@ -540,7 +543,9 @@ class DiffusionFeatureExtractor4(nn.Module):
         if is_video:
             # if video, we need to unsqueeze the latents to match the vae input shape
             latents = latents.unsqueeze(2)
-        tensors_n1p1 = self.vae.decode(latents).sample  # -1 to 1
+        tensors_n1p1 = self.vae.decode(latents)  # -1 to 1
+        if hasattr(tensors_n1p1, 'sample'):
+            tensors_n1p1 = tensors_n1p1.sample
         if is_video:
             # if video, we need to squeeze the tensors to match the output shape