Make DFE (DiffusionFeatureExtractor) work with more VAEs

This commit is contained in:
Jaret Burkett
2026-02-18 09:46:37 -07:00
parent a055947d56
commit 3632656cda

View File

@@ -285,9 +285,12 @@ class DiffusionFeatureExtractor3(nn.Module):
latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
latents = (
latents / self.vae.config['scaling_factor']) + self.vae.config['shift_factor']
tensors_n1p1 = self.vae.decode(latents).sample # -1 to 1
scaling_factor = self.vae.config.scaling_factor if hasattr(self.vae.config, 'scaling_factor') else 1.0
shift_factor = self.vae.config.shift_factor if hasattr(self.vae.config, 'shift_factor') else 0.0
latents = (latents / scaling_factor) + shift_factor
tensors_n1p1 = self.vae.decode(latents) # -1 to 1
if hasattr(tensors_n1p1, 'sample'):
tensors_n1p1 = tensors_n1p1.sample
pred_images = (tensors_n1p1 + 1) / 2 # 0 to 1
@@ -540,7 +543,9 @@ class DiffusionFeatureExtractor4(nn.Module):
if is_video:
# if video, we need to unsqueeze the latents to match the vae input shape
latents = latents.unsqueeze(2)
tensors_n1p1 = self.vae.decode(latents).sample # -1 to 1
tensors_n1p1 = self.vae.decode(latents) # -1 to 1
if hasattr(tensors_n1p1, 'sample'):
tensors_n1p1 = tensors_n1p1.sample
if is_video:
# if video, we need to squeeze the tensors to match the output shape