mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-30 11:11:15 +00:00
VAE patcher and more types of unet patches
This commit is contained in:
@@ -75,6 +75,12 @@ class ModelPatcher:
|
|||||||
def set_model_unet_function_wrapper(self, unet_wrapper_function):
|
def set_model_unet_function_wrapper(self, unet_wrapper_function):
|
||||||
self.model_options["model_function_wrapper"] = unet_wrapper_function
|
self.model_options["model_function_wrapper"] = unet_wrapper_function
|
||||||
|
|
||||||
|
def set_model_vae_encode_wrapper(self, wrapper_function):
|
||||||
|
self.model_options["model_vae_encode_wrapper"] = wrapper_function
|
||||||
|
|
||||||
|
def set_model_vae_decode_wrapper(self, wrapper_function):
|
||||||
|
self.model_options["model_vae_decode_wrapper"] = wrapper_function
|
||||||
|
|
||||||
def set_model_patch(self, patch, name):
|
def set_model_patch(self, patch, name):
|
||||||
to = self.model_options["transformer_options"]
|
to = self.model_options["transformer_options"]
|
||||||
if "patches" not in to:
|
if "patches" not in to:
|
||||||
@@ -242,7 +248,17 @@ class ModelPatcher:
|
|||||||
w1 = v[0]
|
w1 = v[0]
|
||||||
if alpha != 0.0:
|
if alpha != 0.0:
|
||||||
if w1.shape != weight.shape:
|
if w1.shape != weight.shape:
|
||||||
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
if w1.ndim == weight.ndim == 4:
|
||||||
|
new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)]
|
||||||
|
print(f'Merged with {key} channel changed to {new_shape}')
|
||||||
|
new_diff = alpha * ldm_patched.modules.model_management.cast_to_device(w1, weight.device, weight.dtype)
|
||||||
|
new_weight = torch.zeros(size=new_shape).to(weight)
|
||||||
|
new_weight[:weight.shape[0], :weight.shape[1], :weight.shape[2], :weight.shape[3]] = weight
|
||||||
|
new_weight[:new_diff.shape[0], :new_diff.shape[1], :new_diff.shape[2], :new_diff.shape[3]] += new_diff
|
||||||
|
new_weight = new_weight.contiguous().clone()
|
||||||
|
weight = new_weight
|
||||||
|
else:
|
||||||
|
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
||||||
else:
|
else:
|
||||||
weight += alpha * ldm_patched.modules.model_management.cast_to_device(w1, weight.device, weight.dtype)
|
weight += alpha * ldm_patched.modules.model_management.cast_to_device(w1, weight.device, weight.dtype)
|
||||||
elif patch_type == "lora": #lora/locon
|
elif patch_type == "lora": #lora/locon
|
||||||
|
|||||||
@@ -163,7 +163,10 @@ class CLIP:
|
|||||||
return self.patcher.get_key_patches()
|
return self.patcher.get_key_patches()
|
||||||
|
|
||||||
class VAE:
|
class VAE:
|
||||||
def __init__(self, sd=None, device=None, config=None, dtype=None):
|
def __init__(self, sd=None, device=None, config=None, dtype=None, no_init=False):
|
||||||
|
if no_init:
|
||||||
|
return
|
||||||
|
|
||||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||||
|
|
||||||
@@ -215,6 +218,19 @@ class VAE:
|
|||||||
|
|
||||||
self.patcher = ldm_patched.modules.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
self.patcher = ldm_patched.modules.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
||||||
|
|
||||||
|
def clone(self):
|
||||||
|
n = VAE(no_init=True)
|
||||||
|
n.patcher = self.patcher.clone()
|
||||||
|
n.memory_used_encode = self.memory_used_encode
|
||||||
|
n.memory_used_decode = self.memory_used_decode
|
||||||
|
n.downscale_ratio = self.downscale_ratio
|
||||||
|
n.latent_channels = self.latent_channels
|
||||||
|
n.first_stage_model = self.first_stage_model
|
||||||
|
n.device = self.device
|
||||||
|
n.vae_dtype = self.vae_dtype
|
||||||
|
n.output_device = self.output_device
|
||||||
|
return n
|
||||||
|
|
||||||
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||||
steps = samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
|
steps = samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
|
||||||
steps += samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
steps += samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||||
@@ -242,7 +258,7 @@ class VAE:
|
|||||||
samples /= 3.0
|
samples /= 3.0
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
def decode(self, samples_in):
|
def decode_inner(self, samples_in):
|
||||||
if model_management.VAE_ALWAYS_TILED:
|
if model_management.VAE_ALWAYS_TILED:
|
||||||
return self.decode_tiled(samples_in).to(self.output_device)
|
return self.decode_tiled(samples_in).to(self.output_device)
|
||||||
|
|
||||||
@@ -264,12 +280,19 @@ class VAE:
|
|||||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||||
return pixel_samples
|
return pixel_samples
|
||||||
|
|
||||||
|
def decode(self, samples_in):
|
||||||
|
wrapper = self.patcher.model_options.get('model_vae_decode_wrapper', None)
|
||||||
|
if wrapper is None:
|
||||||
|
return self.decode_inner(samples_in)
|
||||||
|
else:
|
||||||
|
return wrapper(self.decode_inner, samples_in)
|
||||||
|
|
||||||
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||||
model_management.load_model_gpu(self.patcher)
|
model_management.load_model_gpu(self.patcher)
|
||||||
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
|
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
|
||||||
return output.movedim(1,-1)
|
return output.movedim(1,-1)
|
||||||
|
|
||||||
def encode(self, pixel_samples):
|
def encode_inner(self, pixel_samples):
|
||||||
if model_management.VAE_ALWAYS_TILED:
|
if model_management.VAE_ALWAYS_TILED:
|
||||||
return self.encode_tiled(pixel_samples)
|
return self.encode_tiled(pixel_samples)
|
||||||
|
|
||||||
@@ -291,6 +314,13 @@ class VAE:
|
|||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
def encode(self, pixel_samples):
|
||||||
|
wrapper = self.patcher.model_options.get('model_vae_encode_wrapper', None)
|
||||||
|
if wrapper is None:
|
||||||
|
return self.encode_inner(pixel_samples)
|
||||||
|
else:
|
||||||
|
return wrapper(self.encode_inner, pixel_samples)
|
||||||
|
|
||||||
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||||
model_management.load_model_gpu(self.patcher)
|
model_management.load_model_gpu(self.patcher)
|
||||||
pixel_samples = pixel_samples.movedim(-1,1)
|
pixel_samples = pixel_samples.movedim(-1,1)
|
||||||
|
|||||||
@@ -235,14 +235,14 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
|
|||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def patched_decode_first_stage(x):
|
def patched_decode_first_stage(x):
|
||||||
sample = forge_objects.unet.model.model_config.latent_format.process_out(x)
|
sample = sd_model.forge_objects.unet.model.model_config.latent_format.process_out(x)
|
||||||
sample = forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
|
sample = sd_model.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
|
||||||
return sample.to(x)
|
return sample.to(x)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def patched_encode_first_stage(x):
|
def patched_encode_first_stage(x):
|
||||||
sample = forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
|
sample = sd_model.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
|
||||||
sample = forge_objects.unet.model.model_config.latent_format.process_in(sample)
|
sample = sd_model.forge_objects.unet.model.model_config.latent_format.process_in(sample)
|
||||||
return sample.to(x)
|
return sample.to(x)
|
||||||
|
|
||||||
sd_model.ema_scope = lambda *args, **kwargs: contextlib.nullcontext()
|
sd_model.ema_scope = lambda *args, **kwargs: contextlib.nullcontext()
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ def cond_from_a1111_to_patched_ldm_weighted(cond, weights):
|
|||||||
def forge_sample(self, denoiser_params, cond_scale, cond_composition):
|
def forge_sample(self, denoiser_params, cond_scale, cond_composition):
|
||||||
model = self.inner_model.inner_model.forge_objects.unet.model
|
model = self.inner_model.inner_model.forge_objects.unet.model
|
||||||
control = self.inner_model.inner_model.forge_objects.unet.controlnet_linked_list
|
control = self.inner_model.inner_model.forge_objects.unet.controlnet_linked_list
|
||||||
|
extra_concat_condition = self.inner_model.inner_model.forge_objects.unet.extra_concat_condition
|
||||||
x = denoiser_params.x
|
x = denoiser_params.x
|
||||||
timestep = denoiser_params.sigma
|
timestep = denoiser_params.sigma
|
||||||
uncond = cond_from_a1111_to_patched_ldm(denoiser_params.text_uncond)
|
uncond = cond_from_a1111_to_patched_ldm(denoiser_params.text_uncond)
|
||||||
@@ -63,7 +64,11 @@ def forge_sample(self, denoiser_params, cond_scale, cond_composition):
|
|||||||
model_options = self.inner_model.inner_model.forge_objects.unet.model_options
|
model_options = self.inner_model.inner_model.forge_objects.unet.model_options
|
||||||
seed = self.p.seeds[0]
|
seed = self.p.seeds[0]
|
||||||
|
|
||||||
image_cond_in = denoiser_params.image_cond
|
if extra_concat_condition is not None:
|
||||||
|
image_cond_in = extra_concat_condition
|
||||||
|
else:
|
||||||
|
image_cond_in = denoiser_params.image_cond
|
||||||
|
|
||||||
if isinstance(image_cond_in, torch.Tensor):
|
if isinstance(image_cond_in, torch.Tensor):
|
||||||
if image_cond_in.shape[0] == x.shape[0] \
|
if image_cond_in.shape[0] == x.shape[0] \
|
||||||
and image_cond_in.shape[2] == x.shape[2] \
|
and image_cond_in.shape[2] == x.shape[2] \
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
version = '0.0.16v1.8.0rc'
|
version = '0.0.17v1.8.0rc'
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ class UnetPatcher(ModelPatcher):
|
|||||||
self.controlnet_linked_list = None
|
self.controlnet_linked_list = None
|
||||||
self.extra_preserved_memory_during_sampling = 0
|
self.extra_preserved_memory_during_sampling = 0
|
||||||
self.extra_model_patchers_during_sampling = []
|
self.extra_model_patchers_during_sampling = []
|
||||||
|
self.extra_concat_condition = None
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
n = UnetPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device,
|
n = UnetPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device,
|
||||||
@@ -27,6 +28,7 @@ class UnetPatcher(ModelPatcher):
|
|||||||
n.controlnet_linked_list = self.controlnet_linked_list
|
n.controlnet_linked_list = self.controlnet_linked_list
|
||||||
n.extra_preserved_memory_during_sampling = self.extra_preserved_memory_during_sampling
|
n.extra_preserved_memory_during_sampling = self.extra_preserved_memory_during_sampling
|
||||||
n.extra_model_patchers_during_sampling = self.extra_model_patchers_during_sampling.copy()
|
n.extra_model_patchers_during_sampling = self.extra_model_patchers_during_sampling.copy()
|
||||||
|
n.extra_concat_condition = self.extra_concat_condition
|
||||||
return n
|
return n
|
||||||
|
|
||||||
def add_extra_preserved_memory_during_sampling(self, memory_in_bytes: int):
|
def add_extra_preserved_memory_during_sampling(self, memory_in_bytes: int):
|
||||||
@@ -176,3 +178,21 @@ class UnetPatcher(ModelPatcher):
|
|||||||
device=noise.device,
|
device=noise.device,
|
||||||
prompt_type=prompt_type
|
prompt_type=prompt_type
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def load_frozen_patcher(self, state_dict, strength):
|
||||||
|
patch_dict = {}
|
||||||
|
for k, w in state_dict.items():
|
||||||
|
model_key, patch_type, weight_index = k.split('::')
|
||||||
|
if model_key not in patch_dict:
|
||||||
|
patch_dict[model_key] = {}
|
||||||
|
if patch_type not in patch_dict[model_key]:
|
||||||
|
patch_dict[model_key][patch_type] = [None] * 16
|
||||||
|
patch_dict[model_key][patch_type][int(weight_index)] = w
|
||||||
|
|
||||||
|
patch_flat = {}
|
||||||
|
for model_key, v in patch_dict.items():
|
||||||
|
for patch_type, weight_list in v.items():
|
||||||
|
patch_flat[model_key] = (patch_type, weight_list)
|
||||||
|
|
||||||
|
self.add_patches(patches=patch_flat, strength_patch=float(strength), strength_model=1.0)
|
||||||
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user