mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-13 08:59:51 +00:00
VAE patcher and more types of unet patches
This commit is contained in:
@@ -235,14 +235,14 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
|
||||
|
||||
@torch.inference_mode()
def patched_decode_first_stage(x):
    """Decode latents *x* to image space via the forge VAE.

    Maps the latent through the UNet config's latent format, decodes with the
    forge VAE (which returns channels-last images in [0, 1]), moves channels to
    dim 1 and rescales to [-1, 1] as A1111's first-stage decode expects.
    """
    # process_out converts the model's internal latent scaling back to VAE latent space.
    sample = sd_model.forge_objects.unet.model.model_config.latent_format.process_out(x)
    # decode() yields NHWC in [0, 1]; movedim -> NCHW, then rescale to [-1, 1].
    sample = sd_model.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
    # Match the dtype/device of the incoming latent tensor.
    return sample.to(x)
|
||||
@torch.inference_mode()
def patched_encode_first_stage(x):
    """Encode images *x* to latent space via the forge VAE.

    Inverse of patched_decode_first_stage: rescales [-1, 1] NCHW images to
    [0, 1] NHWC for the forge VAE encoder, then applies the UNet config's
    latent-format scaling.
    """
    # movedim -> NHWC; rescale [-1, 1] -> [0, 1] for the VAE encoder.
    sample = sd_model.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
    # process_in applies the model's internal latent scaling.
    sample = sd_model.forge_objects.unet.model.model_config.latent_format.process_in(sample)
    # Match the dtype/device of the incoming image tensor.
    return sample.to(x)
||||
|
||||
sd_model.ema_scope = lambda *args, **kwargs: contextlib.nullcontext()
|
||||
|
||||
@@ -56,6 +56,7 @@ def cond_from_a1111_to_patched_ldm_weighted(cond, weights):
|
||||
def forge_sample(self, denoiser_params, cond_scale, cond_composition):
|
||||
model = self.inner_model.inner_model.forge_objects.unet.model
|
||||
control = self.inner_model.inner_model.forge_objects.unet.controlnet_linked_list
|
||||
extra_concat_condition = self.inner_model.inner_model.forge_objects.unet.extra_concat_condition
|
||||
x = denoiser_params.x
|
||||
timestep = denoiser_params.sigma
|
||||
uncond = cond_from_a1111_to_patched_ldm(denoiser_params.text_uncond)
|
||||
@@ -63,7 +64,11 @@ def forge_sample(self, denoiser_params, cond_scale, cond_composition):
|
||||
model_options = self.inner_model.inner_model.forge_objects.unet.model_options
|
||||
seed = self.p.seeds[0]
|
||||
|
||||
image_cond_in = denoiser_params.image_cond
|
||||
if extra_concat_condition is not None:
|
||||
image_cond_in = extra_concat_condition
|
||||
else:
|
||||
image_cond_in = denoiser_params.image_cond
|
||||
|
||||
if isinstance(image_cond_in, torch.Tensor):
|
||||
if image_cond_in.shape[0] == x.shape[0] \
|
||||
and image_cond_in.shape[2] == x.shape[2] \
|
||||
|
||||
@@ -1 +1 @@
|
||||
# Forge build version string (bumped 0.0.16 -> 0.0.17 by this commit).
version = '0.0.17v1.8.0rc'
|
||||
|
||||
@@ -12,6 +12,7 @@ class UnetPatcher(ModelPatcher):
|
||||
self.controlnet_linked_list = None
|
||||
self.extra_preserved_memory_during_sampling = 0
|
||||
self.extra_model_patchers_during_sampling = []
|
||||
self.extra_concat_condition = None
|
||||
|
||||
def clone(self):
|
||||
n = UnetPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device,
|
||||
@@ -27,6 +28,7 @@ class UnetPatcher(ModelPatcher):
|
||||
n.controlnet_linked_list = self.controlnet_linked_list
|
||||
n.extra_preserved_memory_during_sampling = self.extra_preserved_memory_during_sampling
|
||||
n.extra_model_patchers_during_sampling = self.extra_model_patchers_during_sampling.copy()
|
||||
n.extra_concat_condition = self.extra_concat_condition
|
||||
return n
|
||||
|
||||
def add_extra_preserved_memory_during_sampling(self, memory_in_bytes: int):
|
||||
@@ -176,3 +178,21 @@ class UnetPatcher(ModelPatcher):
|
||||
device=noise.device,
|
||||
prompt_type=prompt_type
|
||||
)
|
||||
|
||||
def load_frozen_patcher(self, state_dict, strength):
    """Register patches from a frozen-patcher state dict on this patcher.

    Each key in *state_dict* has the form 'model_key::patch_type::weight_index'.
    Weights sharing a model_key/patch_type are collected into a 16-slot list
    (missing slots stay None), flattened to {model_key: (patch_type, weights)},
    and handed to self.add_patches with the given patch strength.
    """
    grouped = {}
    for key, weight in state_dict.items():
        model_key, patch_type, weight_index = key.split('::')
        # setdefault builds the nested dict and the fixed 16-slot list lazily.
        slots = grouped.setdefault(model_key, {}).setdefault(patch_type, [None] * 16)
        slots[int(weight_index)] = weight

    flattened = {}
    for model_key, by_type in grouped.items():
        for patch_type, weights in by_type.items():
            # NOTE: if one model_key carries several patch types, the last one
            # wins — same behavior as the original nested-loop construction.
            flattened[model_key] = (patch_type, weights)

    self.add_patches(patches=flattened, strength_patch=float(strength), strength_model=1.0)
||||
|
||||
Reference in New Issue
Block a user