DEF for fake vae and adjust scaling

This commit is contained in:
Jaret Burkett
2025-09-12 18:09:08 -06:00
parent b95c17dc17
commit 3666b112a8
2 changed files with 5 additions and 4 deletions

View File

@@ -25,11 +25,12 @@ class Config:
class FakeVAE(nn.Module): class FakeVAE(nn.Module):
def __init__(self): def __init__(self, scaling_factor=1.0):
super().__init__() super().__init__()
self._dtype = torch.float32 self._dtype = torch.float32
self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.config = Config() self.config = Config()
self.config.scaling_factor = scaling_factor
@property @property
def dtype(self): def dtype(self):

View File

@@ -531,9 +531,9 @@ class DiffusionFeatureExtractor4(nn.Module):
stepped_latents = torch.cat(stepped_chunks, dim=0) stepped_latents = torch.cat(stepped_chunks, dim=0)
latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype) latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
scaling_factor = self.vae.config['scaling_factor'] if 'scaling_factor' in self.vae.config else 1.0 scaling_factor = self.vae.config.scaling_factor if hasattr(self.vae.config, 'scaling_factor') else 1.0
shift_factor = self.vae.config['shift_factor'] if 'shift_factor' in self.vae.config else 0.0 shift_factor = self.vae.config.shift_factor if hasattr(self.vae.config, 'shift_factor') else 0.0
latents = (latents / scaling_factor) + shift_factor latents = (latents / scaling_factor) + shift_factor
if is_video: if is_video:
# if video, we need to unsqueeze the latents to match the vae input shape # if video, we need to unsqueeze the latents to match the vae input shape