DFE for fake VAE and adjust scaling

This commit is contained in:
Jaret Burkett
2025-09-12 18:09:08 -06:00
parent b95c17dc17
commit 3666b112a8
2 changed files with 5 additions and 4 deletions

View File

@@ -25,11 +25,12 @@ class Config:
class FakeVAE(nn.Module):
def __init__(self):
def __init__(self, scaling_factor=1.0):
super().__init__()
self._dtype = torch.float32
self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.config = Config()
self.config.scaling_factor = scaling_factor
@property
def dtype(self):

View File

@@ -531,9 +531,9 @@ class DiffusionFeatureExtractor4(nn.Module):
stepped_latents = torch.cat(stepped_chunks, dim=0)
latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
scaling_factor = self.vae.config['scaling_factor'] if 'scaling_factor' in self.vae.config else 1.0
shift_factor = self.vae.config['shift_factor'] if 'shift_factor' in self.vae.config else 0.0
scaling_factor = self.vae.config.scaling_factor if hasattr(self.vae.config, 'scaling_factor') else 1.0
shift_factor = self.vae.config.shift_factor if hasattr(self.vae.config, 'shift_factor') else 0.0
latents = (latents / scaling_factor) + shift_factor
if is_video:
# if video, we need to unsqueeze the latents to match the vae input shape