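"""ModelPatcher for the ldm_patched backend of stable-diffusion-webui-forge
(mirrored from https://github.com/lllyasviel/stable-diffusion-webui-forge).

ModelPatcher wraps a model and tracks non-destructive modifications to it:
weight patches (diff / lora / lokr / loha / glora), object patches swapped in
via setattr(), and runtime patches stored under model_options. Pending weight
patches are merged by patch_model() and reverted by unpatch_model(), with the
original weights backed up on the offload device.
"""
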
import torch
import copy
import inspect

import ldm_patched.modules.utils
import ldm_patched.modules.model_management

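# Registry of extra weight-merge handlers: extensions can register a callable
# under a custom patch-type name, and calculate_weight() falls back to
# extra_weight_calculators[patch_type](weight, alpha, v) for types it does not
# handle itself.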
extra_weight_calculators = {}

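# Illustrative only: a hypothetical handler for a custom patch type could be
# registered like this (the "scaled_diff" name and payload layout are made up
# for the example):
#
#     def _scaled_diff(weight, alpha, v):
#         return weight + alpha * v[0].to(weight.device, weight.dtype)
#
#     extra_weight_calculators["scaled_diff"] = _scaled_diff

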
class ModelPatcher:
    def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
        self.size = size
        self.model = model
        self.patches = {}
        self.backup = {}
        self.object_patches = {}
        self.object_patches_backup = {}
        self.model_options = {"transformer_options": {}}
        self.model_size()
        self.load_device = load_device
        self.offload_device = offload_device
        if current_device is None:
            self.current_device = self.offload_device
        else:
            self.current_device = current_device

        self.weight_inplace_update = weight_inplace_update

    def model_size(self):
        if self.size > 0:
            return self.size
        model_sd = self.model.state_dict()
        self.size = ldm_patched.modules.model_management.module_size(self.model)
        self.model_keys = set(model_sd.keys())
        return self.size
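
    # clone() creates a sibling patcher that shares the same underlying model
    # (compare is_clone) but owns its own copies of the pending patch lists,
    # object patches and model options.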
    def clone(self):
        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
        n.patches = {}
        for k in self.patches:
            n.patches[k] = self.patches[k][:]

        n.object_patches = self.object_patches.copy()
        n.model_options = copy.deepcopy(self.model_options)
        n.model_keys = self.model_keys
        return n

    def is_clone(self, other):
        if hasattr(other, 'model') and self.model is other.model:
            return True
        return False

    def memory_required(self, input_shape):
        return self.model.memory_required(input_shape=input_shape)

    def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
        if len(inspect.signature(sampler_cfg_function).parameters) == 3:
            self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"])  # Old way
        else:
            self.model_options["sampler_cfg_function"] = sampler_cfg_function
        if disable_cfg1_optimization:
            self.model_options["disable_cfg1_optimization"] = True

    def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
        self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
        if disable_cfg1_optimization:
            self.model_options["disable_cfg1_optimization"] = True

    def set_model_unet_function_wrapper(self, unet_wrapper_function):
        self.model_options["model_function_wrapper"] = unet_wrapper_function
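
    # The helpers below register runtime patches under
    # model_options["transformer_options"]. Unlike add_patches(), these are not
    # merged into the weights by patch_model(); they are looked up by the model
    # code while it runs.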
    def set_model_patch(self, patch, name):
        to = self.model_options["transformer_options"]
        if "patches" not in to:
            to["patches"] = {}
        to["patches"][name] = to["patches"].get(name, []) + [patch]

    def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
        to = self.model_options["transformer_options"]
        if "patches_replace" not in to:
            to["patches_replace"] = {}
        if name not in to["patches_replace"]:
            to["patches_replace"][name] = {}
        if transformer_index is not None:
            block = (block_name, number, transformer_index)
        else:
            block = (block_name, number)
        to["patches_replace"][name][block] = patch

    def set_model_attn1_patch(self, patch):
        self.set_model_patch(patch, "attn1_patch")

    def set_model_attn2_patch(self, patch):
        self.set_model_patch(patch, "attn2_patch")

    def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
        self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)

    def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
        self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)

    def set_model_attn1_output_patch(self, patch):
        self.set_model_patch(patch, "attn1_output_patch")

    def set_model_attn2_output_patch(self, patch):
        self.set_model_patch(patch, "attn2_output_patch")

    def set_model_input_block_patch(self, patch):
        self.set_model_patch(patch, "input_block_patch")

    def set_model_input_block_patch_after_skip(self, patch):
        self.set_model_patch(patch, "input_block_patch_after_skip")

    def set_model_output_block_patch(self, patch):
        self.set_model_patch(patch, "output_block_patch")

    def add_object_patch(self, name, obj):
        self.object_patches[name] = obj
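
    # Move any registered runtime patches (and the model_function_wrapper) that
    # expose a .to() method onto the given device.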
    def model_patches_to(self, device):
        to = self.model_options["transformer_options"]
        if "patches" in to:
            patches = to["patches"]
            for name in patches:
                patch_list = patches[name]
                for i in range(len(patch_list)):
                    if hasattr(patch_list[i], "to"):
                        patch_list[i] = patch_list[i].to(device)
        if "patches_replace" in to:
            patches = to["patches_replace"]
            for name in patches:
                patch_list = patches[name]
                for k in patch_list:
                    if hasattr(patch_list[k], "to"):
                        patch_list[k] = patch_list[k].to(device)
        if "model_function_wrapper" in self.model_options:
            wrap_func = self.model_options["model_function_wrapper"]
            if hasattr(wrap_func, "to"):
                self.model_options["model_function_wrapper"] = wrap_func.to(device)

    def model_dtype(self):
        if hasattr(self.model, "get_dtype"):
            return self.model.get_dtype()

    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
        p = set()
        for k in patches:
            if k in self.model_keys:
                p.add(k)
                current_patches = self.patches.get(k, [])
                current_patches.append((strength_patch, patches[k], strength_model))
                self.patches[k] = current_patches

        return list(p)
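
    # Return, for each state-dict key (optionally filtered by prefix), the base
    # weight followed by its pending patch entries. Model clones are unloaded
    # first, presumably so the weights read here are the unpatched originals.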
    def get_key_patches(self, filter_prefix=None):
        ldm_patched.modules.model_management.unload_model_clones(self)
        model_sd = self.model_state_dict()
        p = {}
        for k in model_sd:
            if filter_prefix is not None:
                if not k.startswith(filter_prefix):
                    continue
            if k in self.patches:
                p[k] = [model_sd[k]] + self.patches[k]
            else:
                p[k] = (model_sd[k],)
        return p

    def model_state_dict(self, filter_prefix=None):
        sd = self.model.state_dict()
        keys = list(sd.keys())
        if filter_prefix is not None:
            for k in keys:
                if not k.startswith(filter_prefix):
                    sd.pop(k)
        return sd
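
    # Apply all pending modifications: object patches are swapped in with
    # setattr(), and weight patches are merged via calculate_weight(). The
    # original weights are backed up on the offload device so that
    # unpatch_model() can restore them.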
    def patch_model(self, device_to=None, patch_weights=True):
        for k in self.object_patches:
            old = getattr(self.model, k)
            if k not in self.object_patches_backup:
                self.object_patches_backup[k] = old
            setattr(self.model, k, self.object_patches[k])

        if patch_weights:
            model_sd = self.model_state_dict()
            for key in self.patches:
                if key not in model_sd:
                    print("could not patch. key doesn't exist in model:", key)
                    continue

                weight = model_sd[key]

                inplace_update = self.weight_inplace_update

                if key not in self.backup:
                    self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)

                if device_to is not None:
                    temp_weight = ldm_patched.modules.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
                else:
                    temp_weight = weight.to(torch.float32, copy=True)
                out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
                if inplace_update:
                    ldm_patched.modules.utils.copy_to_param(self.model, key, out_weight)
                else:
                    ldm_patched.modules.utils.set_attr(self.model, key, out_weight)
                del temp_weight

            if device_to is not None:
                self.model.to(device_to)
                self.current_device = device_to

        return self.model
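
    # Merge a list of patch entries into one weight tensor. Each entry is the
    # (strength_patch, value, strength_model) tuple stored by add_patches();
    # the value is either a bare tensor ("diff") or a (patch_type, data) pair
    # for "lora", "lokr", "loha", "glora", or any type registered in
    # extra_weight_calculators.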
    def calculate_weight(self, patches, weight, key):
        for p in patches:
            alpha = p[0]
            v = p[1]
            strength_model = p[2]

            if strength_model != 1.0:
                weight *= strength_model

            if isinstance(v, list):
                v = (self.calculate_weight(v[1:], v[0].clone(), key), )

            if len(v) == 1:
                patch_type = "diff"
            elif len(v) == 2:
                patch_type = v[0]
                v = v[1]

            if patch_type == "diff":
                w1 = v[0]
                if alpha != 0.0:
                    if w1.shape != weight.shape:
                        print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
                    else:
                        weight += alpha * ldm_patched.modules.model_management.cast_to_device(w1, weight.device, weight.dtype)
            elif patch_type == "lora":  # lora/locon
                mat1 = ldm_patched.modules.model_management.cast_to_device(v[0], weight.device, torch.float32)
                mat2 = ldm_patched.modules.model_management.cast_to_device(v[1], weight.device, torch.float32)
                if v[2] is not None:
                    alpha *= v[2] / mat2.shape[0]
                if v[3] is not None:
                    # locon mid weights, hopefully the math is fine because I didn't properly test it
                    mat3 = ldm_patched.modules.model_management.cast_to_device(v[3], weight.device, torch.float32)
                    final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
                    mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
                try:
                    weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
                except Exception as e:
                    print("ERROR", key, e)
            elif patch_type == "lokr":
                w1 = v[0]
                w2 = v[1]
                w1_a = v[3]
                w1_b = v[4]
                w2_a = v[5]
                w2_b = v[6]
                t2 = v[7]
                dim = None

                if w1 is None:
                    dim = w1_b.shape[0]
                    w1 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w1_a, weight.device, torch.float32),
                                  ldm_patched.modules.model_management.cast_to_device(w1_b, weight.device, torch.float32))
                else:
                    w1 = ldm_patched.modules.model_management.cast_to_device(w1, weight.device, torch.float32)

                if w2 is None:
                    dim = w2_b.shape[0]
                    if t2 is None:
                        w2 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w2_a, weight.device, torch.float32),
                                      ldm_patched.modules.model_management.cast_to_device(w2_b, weight.device, torch.float32))
                    else:
                        w2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                          ldm_patched.modules.model_management.cast_to_device(t2, weight.device, torch.float32),
                                          ldm_patched.modules.model_management.cast_to_device(w2_b, weight.device, torch.float32),
                                          ldm_patched.modules.model_management.cast_to_device(w2_a, weight.device, torch.float32))
                else:
                    w2 = ldm_patched.modules.model_management.cast_to_device(w2, weight.device, torch.float32)

                if len(w2.shape) == 4:
                    w1 = w1.unsqueeze(2).unsqueeze(2)
                if v[2] is not None and dim is not None:
                    alpha *= v[2] / dim

                try:
                    weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
                except Exception as e:
                    print("ERROR", key, e)
            elif patch_type == "loha":
                w1a = v[0]
                w1b = v[1]
                if v[2] is not None:
                    alpha *= v[2] / w1b.shape[0]
                w2a = v[3]
                w2b = v[4]
                if v[5] is not None:  # cp decomposition
                    t1 = v[5]
                    t2 = v[6]
                    m1 = torch.einsum('i j k l, j r, i p -> p r k l',
                                      ldm_patched.modules.model_management.cast_to_device(t1, weight.device, torch.float32),
                                      ldm_patched.modules.model_management.cast_to_device(w1b, weight.device, torch.float32),
                                      ldm_patched.modules.model_management.cast_to_device(w1a, weight.device, torch.float32))

                    m2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                      ldm_patched.modules.model_management.cast_to_device(t2, weight.device, torch.float32),
                                      ldm_patched.modules.model_management.cast_to_device(w2b, weight.device, torch.float32),
                                      ldm_patched.modules.model_management.cast_to_device(w2a, weight.device, torch.float32))
                else:
                    m1 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w1a, weight.device, torch.float32),
                                  ldm_patched.modules.model_management.cast_to_device(w1b, weight.device, torch.float32))
                    m2 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w2a, weight.device, torch.float32),
                                  ldm_patched.modules.model_management.cast_to_device(w2b, weight.device, torch.float32))

                try:
                    weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
                except Exception as e:
                    print("ERROR", key, e)
            elif patch_type == "glora":
                if v[4] is not None:
                    alpha *= v[4] / v[0].shape[0]

                a1 = ldm_patched.modules.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
                a2 = ldm_patched.modules.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
                b1 = ldm_patched.modules.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
                b2 = ldm_patched.modules.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)

                weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
            elif patch_type in extra_weight_calculators:
                weight = extra_weight_calculators[patch_type](weight, alpha, v)
            else:
                print("patch type not recognized", patch_type, key)

        return weight
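
    # Undo patch_model(): restore the backed-up weights, swap the original
    # objects back in, and optionally move the model to device_to.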
    def unpatch_model(self, device_to=None):
        keys = list(self.backup.keys())

        if self.weight_inplace_update:
            for k in keys:
                ldm_patched.modules.utils.copy_to_param(self.model, k, self.backup[k])
        else:
            for k in keys:
                ldm_patched.modules.utils.set_attr(self.model, k, self.backup[k])

        self.backup = {}

        if device_to is not None:
            self.model.to(device_to)
            self.current_device = device_to

        keys = list(self.object_patches_backup.keys())
        for k in keys:
            setattr(self.model, k, self.object_patches_backup[k])

        self.object_patches_backup = {}
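

# Illustrative usage sketch (not part of the upstream file; `unet` and
# `lora_patches` are hypothetical placeholders for a loaded diffusion model
# and a dict mapping weight keys to patch payloads):
#
#     patcher = ModelPatcher(unet, load_device=torch.device("cuda"),
#                            offload_device=torch.device("cpu"))
#     worker = patcher.clone()                       # per-job copy, shares the model
#     worker.add_patches(lora_patches, strength_patch=0.8)
#     worker.patch_model(device_to=worker.load_device)        # merge into weights
#     # ... run sampling with worker.model ...
#     worker.unpatch_model(device_to=worker.offload_device)   # restore originals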