mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-30 11:11:15 +00:00
VAE patcher and more types of unet patches
This commit is contained in:
@@ -75,6 +75,12 @@ class ModelPatcher:
|
||||
def set_model_unet_function_wrapper(self, unet_wrapper_function):
|
||||
self.model_options["model_function_wrapper"] = unet_wrapper_function
|
||||
|
||||
def set_model_vae_encode_wrapper(self, wrapper_function):
|
||||
self.model_options["model_vae_encode_wrapper"] = wrapper_function
|
||||
|
||||
def set_model_vae_decode_wrapper(self, wrapper_function):
|
||||
self.model_options["model_vae_decode_wrapper"] = wrapper_function
|
||||
|
||||
def set_model_patch(self, patch, name):
|
||||
to = self.model_options["transformer_options"]
|
||||
if "patches" not in to:
|
||||
@@ -242,7 +248,17 @@ class ModelPatcher:
|
||||
w1 = v[0]
|
||||
if alpha != 0.0:
|
||||
if w1.shape != weight.shape:
|
||||
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
||||
if w1.ndim == weight.ndim == 4:
|
||||
new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)]
|
||||
print(f'Merged with {key} channel changed to {new_shape}')
|
||||
new_diff = alpha * ldm_patched.modules.model_management.cast_to_device(w1, weight.device, weight.dtype)
|
||||
new_weight = torch.zeros(size=new_shape).to(weight)
|
||||
new_weight[:weight.shape[0], :weight.shape[1], :weight.shape[2], :weight.shape[3]] = weight
|
||||
new_weight[:new_diff.shape[0], :new_diff.shape[1], :new_diff.shape[2], :new_diff.shape[3]] += new_diff
|
||||
new_weight = new_weight.contiguous().clone()
|
||||
weight = new_weight
|
||||
else:
|
||||
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
||||
else:
|
||||
weight += alpha * ldm_patched.modules.model_management.cast_to_device(w1, weight.device, weight.dtype)
|
||||
elif patch_type == "lora": #lora/locon
|
||||
|
||||
Reference in New Issue
Block a user