From 3666b112a89e8a82bff82c06dab59a2ead3072e3 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Fri, 12 Sep 2025 18:09:08 -0600 Subject: [PATCH] DEF for fake vae and adjust scaling --- toolkit/models/FakeVAE.py | 3 ++- toolkit/models/diffusion_feature_extraction.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/toolkit/models/FakeVAE.py b/toolkit/models/FakeVAE.py index 86ca4730..90e3d507 100644 --- a/toolkit/models/FakeVAE.py +++ b/toolkit/models/FakeVAE.py @@ -25,11 +25,12 @@ class Config: class FakeVAE(nn.Module): - def __init__(self): + def __init__(self, scaling_factor=1.0): super().__init__() self._dtype = torch.float32 self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.config = Config() + self.config.scaling_factor = scaling_factor @property def dtype(self): diff --git a/toolkit/models/diffusion_feature_extraction.py b/toolkit/models/diffusion_feature_extraction.py index 5edff00d..78a85c78 100644 --- a/toolkit/models/diffusion_feature_extraction.py +++ b/toolkit/models/diffusion_feature_extraction.py @@ -531,9 +531,9 @@ class DiffusionFeatureExtractor4(nn.Module): stepped_latents = torch.cat(stepped_chunks, dim=0) latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype) - - scaling_factor = self.vae.config['scaling_factor'] if 'scaling_factor' in self.vae.config else 1.0 - shift_factor = self.vae.config['shift_factor'] if 'shift_factor' in self.vae.config else 0.0 + + scaling_factor = self.vae.config.scaling_factor if hasattr(self.vae.config, 'scaling_factor') else 1.0 + shift_factor = self.vae.config.shift_factor if hasattr(self.vae.config, 'shift_factor') else 0.0 latents = (latents / scaling_factor) + shift_factor if is_video: # if video, we need to unsqueeze the latents to match the vae input shape