mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-22 05:13:57 +00:00
DEF for fake vae and adjust scaling
This commit is contained in:
@@ -25,11 +25,12 @@ class Config:
|
||||
|
||||
|
||||
class FakeVAE(nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, scaling_factor=1.0):
|
||||
super().__init__()
|
||||
self._dtype = torch.float32
|
||||
self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.config = Config()
|
||||
self.config.scaling_factor = scaling_factor
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
|
||||
@@ -531,9 +531,9 @@ class DiffusionFeatureExtractor4(nn.Module):
|
||||
stepped_latents = torch.cat(stepped_chunks, dim=0)
|
||||
|
||||
latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
|
||||
|
||||
scaling_factor = self.vae.config['scaling_factor'] if 'scaling_factor' in self.vae.config else 1.0
|
||||
shift_factor = self.vae.config['shift_factor'] if 'shift_factor' in self.vae.config else 0.0
|
||||
|
||||
scaling_factor = self.vae.config.scaling_factor if hasattr(self.vae.config, 'scaling_factor') else 1.0
|
||||
shift_factor = self.vae.config.shift_factor if hasattr(self.vae.config, 'shift_factor') else 0.0
|
||||
latents = (latents / scaling_factor) + shift_factor
|
||||
if is_video:
|
||||
# if video, we need to unsqueeze the latents to match the vae input shape
|
||||
|
||||
Reference in New Issue
Block a user