mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-05 04:59:49 +00:00
support inpaint models from fooocus
Put inpaint_v26.fooocus.patch in models\ControlNet. It controls SDXL models only. To get the same algorithm as Fooocus, set "Stop at" (Ending Control Step) to 0.5. Fooocus always uses 0.5, but in Forge users may use other values. Results are best when "Stop at" is < 0.7; the model is not optimized for ending timesteps > 0.7. Supports inpaint_global_harmonious, inpaint_only, and inpaint_only+lama. In theory, inpaint_only+lama always outperforms Fooocus in the object removal task (but not in all tasks).
This commit is contained in:
@@ -7,7 +7,7 @@ from lib_controlnet.enums import StableDiffusionVersion
|
||||
from modules_forge.shared import controlnet_dir, supported_preprocessors
|
||||
|
||||
|
||||
CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin"]
|
||||
CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin", ".patch"]
|
||||
|
||||
|
||||
def traverse_all_files(curr_path, model_list):
|
||||
|
||||
@@ -417,7 +417,7 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
cond = params.control_cond
|
||||
mask = params.control_mask
|
||||
|
||||
kwargs.update(dict(unit=unit, params=params))
|
||||
kwargs.update(dict(unit=unit, params=params, cond_original=cond.clone(), mask_original=mask.clone()))
|
||||
|
||||
params.model.strength = float(unit.weight)
|
||||
params.model.start_percent = float(unit.guidance_start)
|
||||
|
||||
Binary file not shown.
@@ -0,0 +1,131 @@
|
||||
import os
|
||||
import torch
|
||||
import copy
|
||||
|
||||
from modules_forge.shared import add_supported_control_model
|
||||
from modules_forge.supported_controlnet import ControlModelPatcher
|
||||
from modules_forge.forge_sampler import sampling_prepare
|
||||
from ldm_patched.modules.utils import load_torch_file
|
||||
from ldm_patched.modules import model_patcher
|
||||
from ldm_patched.modules.model_management import cast_to_device, LoadedModel, current_loaded_models
|
||||
from ldm_patched.modules.lora import model_lora_keys_unet
|
||||
|
||||
|
||||
def is_model_loaded(model):
    """Report whether ``model`` is present in ``current_loaded_models``.

    The model is wrapped in ``LoadedModel`` and looked up with the ``in``
    operator, so the match relies on whatever equality ``LoadedModel``
    defines.
    """
    wrapped = LoadedModel(model)
    return wrapped in current_loaded_models
|
||||
|
||||
|
||||
class InpaintHead(torch.nn.Module):
    """Single 3x3 convolution mapping a 5-channel input to 320 channels.

    ``head`` is allocated uninitialised (``torch.empty``) on the CPU; real
    weights are expected to arrive through ``load_state_dict``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Conv weight laid out as (out_channels, in_channels, kH, kW).
        self.head = torch.nn.Parameter(torch.empty(size=(320, 5, 3, 3), device="cpu"))

    def forward(self, x):
        """Apply the head to ``x`` and return a tensor with 320 channels.

        Defined as ``forward`` rather than ``__call__`` so that
        ``torch.nn.Module.__call__`` stays in charge of dispatch and any
        registered hooks run; calling the instance works as before.
        """
        # Replicate-pad one pixel on every side so the unpadded 3x3 conv
        # keeps the input's spatial size.
        x = torch.nn.functional.pad(x, (1, 1, 1, 1), "replicate")
        return torch.nn.functional.conv2d(input=x, weight=self.head)
|
||||
|
||||
|
||||
def load_fooocus_patch(lora: dict, to_load: dict):
    """Select the entries of a Fooocus patch that match known model keys.

    Args:
        lora: The patch contents, keyed by weight name.
        to_load: Mapping whose *values* are the weight names to look up in
            ``lora``.

    Returns:
        A dict mapping each matched weight name to ``("fooocus", value)``.
        A one-line summary of matched and unmatched keys is printed.
    """
    patch_dict = {}
    for key in to_load.values():
        value = lora.get(key)
        # Test for presence explicitly. Truthiness would raise on a
        # multi-element tensor and would silently drop a falsy value.
        if value is not None:
            patch_dict[key] = ("fooocus", value)

    # Count the patch entries that no model key asked for.
    not_loaded = sum(1 for x in lora if x not in patch_dict)
    print(f"[Fooocus Patch Loader] {len(patch_dict)} keys loaded, {not_loaded} remaining keys not found in model.")
    return patch_dict
|
||||
|
||||
|
||||
def calculate_weight_fooocus(weight, alpha, v):
    """Merge one Fooocus patch entry into ``weight`` in place.

    ``v`` is indexed as ``v[0]``, ``v[1]``, ``v[2]``. ``v[0]`` is rescaled
    from a 0..255 range into ``[v[1], v[2]]`` and added to ``weight`` scaled
    by ``alpha``. If ``v[0]`` does not have the same shape as ``weight``,
    nothing is merged and a message is printed.

    Returns:
        ``weight`` (the same tensor object, updated in place when merged).
    """
    quantized = cast_to_device(v[0], weight.device, torch.float32)

    # Guard: a shape mismatch means this entry does not belong to this weight.
    if quantized.shape != weight.shape:
        print(f"[Fooocus Patch Loader] weight not merged ({quantized.shape} != {weight.shape})")
        return weight

    lower = cast_to_device(v[1], weight.device, torch.float32)
    upper = cast_to_device(v[2], weight.device, torch.float32)
    # Dequantise in float32, then cast to the weight's dtype for the merge.
    delta = (quantized / 255.0) * (upper - lower) + lower
    weight += alpha * cast_to_device(delta, weight.device, weight.dtype)
    return weight
|
||||
|
||||
|
||||
class FooocusInpaintPatcher(ControlModelPatcher):
    """Control-model patcher that applies a Fooocus inpaint patch to the UNet.

    It holds the raw patch state dict and an ``InpaintHead`` whose weights
    are loaded from the ``fooocus_inpaint_head`` file next to this module.
    """

    @staticmethod
    def try_build_from_state_dict(state_dict, ckpt_path):
        """Return a patcher if ``state_dict`` looks like a Fooocus patch.

        The check is that ``'diffusion_model.time_embed.0.weight'`` exists
        and has length 3. Returns ``None`` otherwise. ``ckpt_path`` is not
        used.
        """
        probe_key = 'diffusion_model.time_embed.0.weight'
        if probe_key not in state_dict:
            return None
        if len(state_dict[probe_key]) != 3:
            return None
        return FooocusInpaintPatcher(state_dict)

    def __init__(self, state_dict):
        """Store the patch and load the inpaint head weights on CPU/float32."""
        super().__init__()
        self.state_dict = state_dict

        head_path = os.path.join(os.path.dirname(__file__), 'fooocus_inpaint_head')
        self.inpaint_head = InpaintHead().to(device=torch.device('cpu'), dtype=torch.float32)
        self.inpaint_head.load_state_dict(load_torch_file(head_path))

    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
        """Install the patched UNet on ``process.sd_model.forge_objects``.

        Reads ``kwargs['cond_original']`` and ``kwargs['mask_original']``,
        builds the inpaint head feature from them, adds the Fooocus weight
        patches to a clone of the UNet, and registers a conditioning modifier
        that chooses between the patched and unpatched clone at each step.
        """
        cond_original = kwargs['cond_original']
        mask_original = kwargs['mask_original']

        forge_objects = process.sd_model.forge_objects
        # Two clones: one stays unpatched for steps outside the control window.
        unet_original = forge_objects.unet.clone()
        unet = forge_objects.unet.clone()
        vae = forge_objects.vae

        latent_image = vae.encode(cond_original.movedim(1, -1))
        latent_image = forge_objects.unet.model.latent_format.process_in(latent_image)
        # 8x8 max-pool of the mask, rounded, then matched to ``cond``.
        # Presumably this brings the mask to latent resolution; confirm.
        latent_mask = torch.nn.functional.max_pool2d(mask_original, (8, 8)).round().to(cond)

        cpu_fp32 = dict(device=torch.device('cpu'), dtype=torch.float32)
        feed = torch.cat(
            [latent_mask.to(**cpu_fp32), latent_image.to(**cpu_fp32)],
            dim=1,
        )
        inpaint_head_feature = self.inpaint_head(feed)

        def input_block_patch(h, transformer_options):
            # Add the head feature only where the block's second index is 0.
            if transformer_options["block"][1] == 0:
                return h + inpaint_head_feature.to(h)
            return h

        unet.set_model_input_block_patch(input_block_patch)

        # Candidate keys: LoRA-style names plus every raw state-dict name.
        lora_keys = model_lora_keys_unet(unet.model, {})
        lora_keys.update({name: name for name in unet.model.state_dict()})
        loaded_lora = load_fooocus_patch(self.state_dict, lora_keys)

        patched = unet.add_patches(loaded_lora, 1.0)

        rejected = [name for name in loaded_lora if name not in patched]
        not_patched_count = len(rejected)
        if not_patched_count > 0:
            print(f"[Fooocus Patch Loader] Failed to load {not_patched_count} keys")

        model_sampling = unet.model.model_sampling
        sigma_start = model_sampling.percent_to_sigma(self.start_percent)
        sigma_end = model_sampling.percent_to_sigma(self.end_percent)

        def conditioning_modifier(model, x, timestep, uncond, cond, cond_scale, model_options, seed):
            # NOTE(review): the comparison is used directly as a truth value.
            # Confirm the caller never passes a multi-element ``timestep``.
            if timestep > sigma_start or timestep < sigma_end:
                # Outside the window: use the unpatched UNet, and drop the
                # input-block patch from a private copy of the options.
                target_model = unet_original
                model_options = copy.deepcopy(model_options)
                patches = model_options.get('transformer_options', {}).get('patches', {})
                patches.pop('input_block_patch', None)
            else:
                target_model = unet

            if not is_model_loaded(target_model):
                sampling_prepare(target_model, x)

            return target_model.model, x, timestep, uncond, cond, cond_scale, model_options, seed

        unet.add_conditioning_modifier(conditioning_modifier)

        forge_objects.unet = unet
|
||||
|
||||
|
||||
# Register the Fooocus weight calculator under the 'fooocus' patch type, the
# same tag load_fooocus_patch attaches to every patch entry it returns.
model_patcher.extra_weight_calculators['fooocus'] = calculate_weight_fooocus
|
||||
# Register FooocusInpaintPatcher as a supported control model.
# NOTE(review): presumably the loader probes it via try_build_from_state_dict;
# confirm in modules_forge.shared.
add_supported_control_model(FooocusInpaintPatcher)
|
||||
@@ -5,7 +5,6 @@ from ldm_patched.modules.args_parser import args
|
||||
import ldm_patched.modules.utils
|
||||
import torch
|
||||
import sys
|
||||
import os
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
@@ -59,7 +58,7 @@ try:
|
||||
except:
|
||||
pass
|
||||
|
||||
if args.always_cpu or os.environ.get("FORGE_CQ_TEST", ""):
|
||||
if args.always_cpu:
|
||||
cpu_state = CPUState.CPU
|
||||
|
||||
def is_intel_xpu():
|
||||
|
||||
@@ -5,6 +5,10 @@ import inspect
|
||||
import ldm_patched.modules.utils
|
||||
import ldm_patched.modules.model_management
|
||||
|
||||
|
||||
extra_weight_calculators = {}
|
||||
|
||||
|
||||
class ModelPatcher:
|
||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
|
||||
self.size = size
|
||||
@@ -329,6 +333,8 @@ class ModelPatcher:
|
||||
b2 = ldm_patched.modules.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
|
||||
|
||||
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
|
||||
elif patch_type in extra_weight_calculators:
|
||||
weight = extra_weight_calculators[patch_type](weight, alpha, v)
|
||||
else:
|
||||
print("patch type not recognized", patch_type, key)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user