begin to use new vae impl

This commit is contained in:
layerdiffusion
2024-07-30 08:43:14 -06:00
parent abd4d4d83d
commit dc62b0d2d7
3 changed files with 31 additions and 40 deletions

View File

@@ -6,11 +6,25 @@ from backend.attention import AttentionProcessorForge
from diffusers.loaders.single_file_model import convert_ldm_vae_checkpoint
class BaseVAE(AutoencoderKL):
def encode(self, x, regulation=None, mode=False):
latent_dist = super().encode(x).latent_dist
if mode:
return latent_dist.mode()
elif regulation is not None:
return regulation(latent_dist)
else:
return latent_dist.sample()
def decode(self, x):
return super().decode(x).sample
def load_vae_from_state_dict(state_dict):
config = guess_vae_config(state_dict)
with using_forge_operations():
model = AutoencoderKL(**config)
model = BaseVAE(**config)
vae_state_dict = split_state_dict_with_prefix(state_dict, "first_stage_model.")
vae_state_dict = convert_ldm_vae_checkpoint(vae_state_dict, config)