support inpaint models from fooocus

put inpaint_v26.fooocus.patch in models\ControlNet, control SDXL models only
To get same algorithm as Fooocus, set "Stop at" (Ending Control Step) to 0.5
Fooocus always uses 0.5, but in Forge users may choose other values.
Results are best when "Stop at" is below 0.7; the model is not optimized for ending timesteps above 0.7.
Supports inpaint_global_harmonious, inpaint_only, inpaint_only+lama.
In theory, inpaint_only+lama always outperforms Fooocus in the object-removal task (but not in all tasks).
This commit is contained in:
lllyasviel
2024-02-09 17:08:48 -08:00
parent d2af6d1b44
commit fb2e271668
6 changed files with 140 additions and 4 deletions

View File

@@ -7,7 +7,7 @@ from lib_controlnet.enums import StableDiffusionVersion
from modules_forge.shared import controlnet_dir, supported_preprocessors
CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin"]
CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin", ".patch"]
def traverse_all_files(curr_path, model_list):

View File

@@ -417,7 +417,7 @@ class ControlNetForForgeOfficial(scripts.Script):
cond = params.control_cond
mask = params.control_mask
kwargs.update(dict(unit=unit, params=params))
kwargs.update(dict(unit=unit, params=params, cond_original=cond.clone(), mask_original=mask.clone()))
params.model.strength = float(unit.weight)
params.model.start_percent = float(unit.guidance_start)

View File

@@ -0,0 +1,131 @@
import os
import torch
import copy
from modules_forge.shared import add_supported_control_model
from modules_forge.supported_controlnet import ControlModelPatcher
from modules_forge.forge_sampler import sampling_prepare
from ldm_patched.modules.utils import load_torch_file
from ldm_patched.modules import model_patcher
from ldm_patched.modules.model_management import cast_to_device, LoadedModel, current_loaded_models
from ldm_patched.modules.lora import model_lora_keys_unet
def is_model_loaded(model):
    """Return True if *model* already has a LoadedModel entry in the global load list."""
    candidate = LoadedModel(model)
    return any(candidate == entry for entry in current_loaded_models)
class InpaintHead(torch.nn.Module):
    """Small convolution head that projects the 5-channel [mask, latent] feed
    into a 320-channel feature map added to the UNet's first input block.

    The weight is created empty here and filled later via ``load_state_dict``
    from the bundled ``fooocus_inpaint_head`` file.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # (out_channels=320, in_channels=5 [1 mask + 4 latent], 3x3 kernel), no bias.
        self.head = torch.nn.Parameter(torch.empty(size=(320, 5, 3, 3), device="cpu"))

    def forward(self, x):
        """Replicate-pad by 1 on each side, then apply the 3x3 convolution.

        Defined as ``forward`` (rather than overriding ``__call__`` as the
        original did) so nn.Module's hook/dispatch machinery stays intact;
        callers still invoke the module directly, e.g. ``head(feed)``.
        """
        x = torch.nn.functional.pad(x, (1, 1, 1, 1), "replicate")
        return torch.nn.functional.conv2d(input=x, weight=self.head)
def load_fooocus_patch(lora: dict, to_load: dict):
    """Convert a Fooocus inpaint patch state dict into ModelPatcher patch entries.

    Args:
        lora: raw patch state dict (model key -> stored patch value, typically
            a quantized (w, w_min, w_max) triple).
        to_load: mapping whose *values* are the model keys to look up in ``lora``.

    Returns:
        dict mapping model key -> ("fooocus", value) for every key found.
        Keys present in ``lora`` but never requested are only counted and logged.
    """
    patch_dict = {}
    loaded_keys = set()
    for key in to_load.values():
        # Explicit None check instead of truthiness: truthiness would silently
        # skip falsy-but-present values and raises on multi-element tensors.
        value = lora.get(key, None)
        if value is not None:
            patch_dict[key] = ("fooocus", value)
            loaded_keys.add(key)
    not_loaded = sum(1 for x in lora if x not in loaded_keys)
    print(f"[Fooocus Patch Loader] {len(loaded_keys)} keys loaded, {not_loaded} remaining keys not found in model.")
    return patch_dict
def calculate_weight_fooocus(weight, alpha, v):
    """Dequantize a Fooocus patch entry and merge it into *weight* in place.

    ``v`` is (quantized, w_min, w_max): stored values are mapped from [0, 255]
    back to [w_min, w_max], scaled by ``alpha`` and added to ``weight``.
    Shape-mismatched entries are reported and left unmerged.
    """
    quantized = cast_to_device(v[0], weight.device, torch.float32)
    # Guard clause: refuse to merge when the stored tensor does not match.
    if quantized.shape != weight.shape:
        print(f"[Fooocus Patch Loader] weight not merged ({quantized.shape} != {weight.shape})")
        return weight
    w_min = cast_to_device(v[1], weight.device, torch.float32)
    w_max = cast_to_device(v[2], weight.device, torch.float32)
    dequantized = (quantized / 255.0) * (w_max - w_min) + w_min
    weight += alpha * cast_to_device(dequantized, weight.device, weight.dtype)
    return weight
class FooocusInpaintPatcher(ControlModelPatcher):
    """Loads a Fooocus ``inpaint_v26.fooocus.patch`` file through the ControlNet
    model slot and applies it as a set of UNet weight patches plus an inpaint
    head injected into the first input block.
    """

    @staticmethod
    def try_build_from_state_dict(state_dict, ckpt_path):
        # Fooocus patch files store each entry as a 3-element quantized triple,
        # so a time_embed weight of length 3 identifies the format
        # (a normal checkpoint weight would have a much larger first dimension).
        if 'diffusion_model.time_embed.0.weight' in state_dict:
            if len(state_dict['diffusion_model.time_embed.0.weight']) == 3:
                return FooocusInpaintPatcher(state_dict)
        return None

    def __init__(self, state_dict):
        super().__init__()
        # Raw patch state dict; converted to ModelPatcher patches lazily at sampling time.
        self.state_dict = state_dict
        # Small conv head kept on CPU in fp32; its weights ship next to this file.
        self.inpaint_head = InpaintHead().to(device=torch.device('cpu'), dtype=torch.float32)
        self.inpaint_head.load_state_dict(load_torch_file(os.path.join(os.path.dirname(__file__), 'fooocus_inpaint_head')))
        return

    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
        """Patch the UNet before each sampling run.

        Builds a head feature from the original image+mask, injects it into the
        first input block, merges the Fooocus weight patches, and installs a
        conditioning modifier that swaps back to the unpatched UNet outside the
        unit's [guidance_start, guidance_end] sigma window.
        """
        # Unresized originals passed through by the ControlNet script
        # (see the cond_original/mask_original kwargs added at the call site).
        cond_original = kwargs['cond_original']
        mask_original = kwargs['mask_original']
        # Two clones: one to patch, one pristine for timesteps outside the window.
        unet_original = process.sd_model.forge_objects.unet.clone()
        unet = process.sd_model.forge_objects.unet.clone()
        vae = process.sd_model.forge_objects.vae
        # movedim(1, -1): VAE encode expects channels-last here — TODO confirm against vae.encode.
        latent_image = vae.encode(cond_original.movedim(1, -1))
        latent_image = process.sd_model.forge_objects.unet.model.latent_format.process_in(latent_image)
        # Downsample the pixel mask by 8x to latent resolution and binarize.
        latent_mask = torch.nn.functional.max_pool2d(mask_original, (8, 8)).round().to(cond)
        # Head input: 1 mask channel + 4 latent channels, on CPU in fp32 to match the head.
        feed = torch.cat([
            latent_mask.to(device=torch.device('cpu'), dtype=torch.float32),
            latent_image.to(device=torch.device('cpu'), dtype=torch.float32)
        ], dim=1)
        inpaint_head_feature = self.inpaint_head(feed)

        def input_block_patch(h, transformer_options):
            # Add the head feature only at the very first input block.
            if transformer_options["block"][1] == 0:
                h = h + inpaint_head_feature.to(h)
            return h

        unet.set_model_input_block_patch(input_block_patch)
        # Accept both lora-style keys and raw state-dict keys when matching patch entries.
        lora_keys = model_lora_keys_unet(unet.model, {})
        lora_keys.update({x: x for x in unet.model.state_dict().keys()})
        loaded_lora = load_fooocus_patch(self.state_dict, lora_keys)
        patched = unet.add_patches(loaded_lora, 1.0)
        not_patched_count = sum(1 for x in loaded_lora if x not in patched)
        if not_patched_count > 0:
            print(f"[Fooocus Patch Loader] Failed to load {not_patched_count} keys")
        # Convert the unit's start/end percents into the sampler's sigma scale.
        sigma_start = unet.model.model_sampling.percent_to_sigma(self.start_percent)
        sigma_end = unet.model.model_sampling.percent_to_sigma(self.end_percent)

        def conditioning_modifier(model, x, timestep, uncond, cond, cond_scale, model_options, seed):
            # Sigmas decrease over the run: a timestep above sigma_start is
            # before the window, below sigma_end is after it.
            if timestep > sigma_start or timestep < sigma_end:
                target_model = unet_original
                # Deep-copy so stripping the input_block_patch does not
                # mutate the shared options used inside the window.
                model_options = copy.deepcopy(model_options)
                if 'transformer_options' in model_options:
                    if 'patches' in model_options['transformer_options']:
                        if 'input_block_patch' in model_options['transformer_options']['patches']:
                            del model_options['transformer_options']['patches']['input_block_patch']
            else:
                target_model = unet
            # Ensure the chosen model is resident before handing it to the sampler.
            if not is_model_loaded(target_model):
                sampling_prepare(target_model, x)
            return target_model.model, x, timestep, uncond, cond, cond_scale, model_options, seed

        unet.add_conditioning_modifier(conditioning_modifier)
        process.sd_model.forge_objects.unet = unet
        return
# Register the "fooocus" patch type so ModelPatcher can merge these quantized
# entries, and expose the patcher to the ControlNet model-detection chain.
model_patcher.extra_weight_calculators['fooocus'] = calculate_weight_fooocus
add_supported_control_model(FooocusInpaintPatcher)

View File

@@ -5,7 +5,6 @@ from ldm_patched.modules.args_parser import args
import ldm_patched.modules.utils
import torch
import sys
import os
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
@@ -59,7 +58,7 @@ try:
except:
pass
if args.always_cpu or os.environ.get("FORGE_CQ_TEST", ""):
if args.always_cpu:
cpu_state = CPUState.CPU
def is_intel_xpu():

View File

@@ -5,6 +5,10 @@ import inspect
import ldm_patched.modules.utils
import ldm_patched.modules.model_management
extra_weight_calculators = {}
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
self.size = size
@@ -329,6 +333,8 @@ class ModelPatcher:
b2 = ldm_patched.modules.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
elif patch_type in extra_weight_calculators:
weight = extra_weight_calculators[patch_type](weight, alpha, v)
else:
print("patch type not recognized", patch_type, key)