mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
DEF for fake vae and adjust scaling
This commit is contained in:
@@ -25,11 +25,12 @@ class Config:
|
|||||||
|
|
||||||
|
|
||||||
class FakeVAE(nn.Module):
|
class FakeVAE(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self, scaling_factor=1.0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._dtype = torch.float32
|
self._dtype = torch.float32
|
||||||
self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
self.config = Config()
|
self.config = Config()
|
||||||
|
self.config.scaling_factor = scaling_factor
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self):
|
def dtype(self):
|
||||||
|
|||||||
@@ -531,9 +531,9 @@ class DiffusionFeatureExtractor4(nn.Module):
|
|||||||
stepped_latents = torch.cat(stepped_chunks, dim=0)
|
stepped_latents = torch.cat(stepped_chunks, dim=0)
|
||||||
|
|
||||||
latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
|
latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
|
||||||
|
|
||||||
scaling_factor = self.vae.config['scaling_factor'] if 'scaling_factor' in self.vae.config else 1.0
|
scaling_factor = self.vae.config.scaling_factor if hasattr(self.vae.config, 'scaling_factor') else 1.0
|
||||||
shift_factor = self.vae.config['shift_factor'] if 'shift_factor' in self.vae.config else 0.0
|
shift_factor = self.vae.config.shift_factor if hasattr(self.vae.config, 'shift_factor') else 0.0
|
||||||
latents = (latents / scaling_factor) + shift_factor
|
latents = (latents / scaling_factor) + shift_factor
|
||||||
if is_video:
|
if is_video:
|
||||||
# if video, we need to unsqueeze the latents to match the vae input shape
|
# if video, we need to unsqueeze the latents to match the vae input shape
|
||||||
|
|||||||
Reference in New Issue
Block a user