Mirror of https://github.com/ostris/ai-toolkit.git, last synced 2026-03-04 10:09:49 +00:00.
Make the DFE (Diffusion Feature Extractor) work with more VAEs
This commit is contained in:
@@ -285,9 +285,12 @@ class DiffusionFeatureExtractor3(nn.Module):
|
||||
|
||||
latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
|
||||
|
||||
latents = (
|
||||
latents / self.vae.config['scaling_factor']) + self.vae.config['shift_factor']
|
||||
tensors_n1p1 = self.vae.decode(latents).sample # -1 to 1
|
||||
scaling_factor = self.vae.config.scaling_factor if hasattr(self.vae.config, 'scaling_factor') else 1.0
|
||||
shift_factor = self.vae.config.shift_factor if hasattr(self.vae.config, 'shift_factor') else 0.0
|
||||
latents = (latents / scaling_factor) + shift_factor
|
||||
tensors_n1p1 = self.vae.decode(latents) # -1 to 1
|
||||
if hasattr(tensors_n1p1, 'sample'):
|
||||
tensors_n1p1 = tensors_n1p1.sample
|
||||
|
||||
pred_images = (tensors_n1p1 + 1) / 2 # 0 to 1
|
||||
|
||||
@@ -540,7 +543,9 @@ class DiffusionFeatureExtractor4(nn.Module):
|
||||
if is_video:
|
||||
# if video, we need to unsqueeze the latents to match the vae input shape
|
||||
latents = latents.unsqueeze(2)
|
||||
tensors_n1p1 = self.vae.decode(latents).sample # -1 to 1
|
||||
tensors_n1p1 = self.vae.decode(latents) # -1 to 1
|
||||
if hasattr(tensors_n1p1, 'sample'):
|
||||
tensors_n1p1 = tensors_n1p1.sample
|
||||
|
||||
if is_video:
|
||||
# if video, we need to squeeze the tensors to match the output shape
|
||||
|
||||
Reference in New Issue
Block a user