From a9212b640f0820b28f7fa733e855fb1aab9f2421 Mon Sep 17 00:00:00 2001
From: lllyasviel
Date: Thu, 1 Feb 2024 22:57:49 -0800
Subject: [PATCH] photomaker

---
Notes (ignored when the patch is applied):

What this patch does
  * controlnet.py: decorates ControlNetForForgeOfficial.process,
    process_before_every_sampling and postprocess_batch_list with
    @torch.no_grad().
  * forge_photomaker.py (new file):
      - PreprocessorClipvisionForPhotomaker is registered through
        add_supported_preprocessor under the name 'ClipVision (Photomaker)'.
      - PhotomakerPatcher is registered through add_supported_control_model.
        try_build_from_state_dict returns None unless the state dict has an
        "id_encoder" key; otherwise it loads that sub-dict into a
        PhotoMakerIDEncoder and wraps it in a PhotomakerPatcher.
      - PhotomakerPatcher.process_before_every_sampling clones the unet,
        runs opPhotoMakerEncode on process.prompts[0] with `cond` after
        movedim(1, -1), passes the result through convert_cond and
        encode_model_conds using kwargs['x'] as the noise, then installs a
        conditioning modifier on the cloned unet and stores the clone back
        on process.sd_model.forge_objects.unet.
      - The conditioning modifier overwrites 'pooled_output' and
        'cross_attn' and updates 'model_conds' on every entry of `cond`.
  * sd_hijack.py: EmbeddingsWithFixes additionally exposes
    self.weight = self.wrapped.weight.

NOTE(review):
  * PreprocessorClipVision and numpy_to_pytorch are imported by the new file
    but are not referenced anywhere in it.
  * conditioning_modifier calls cond.copy() and then assigns into each
    element. If `cond` is a list of dicts (presumably -- confirm against the
    caller), the copy is shallow and the caller's dicts are mutated in place.
  * Only process.prompts[0] is encoded; the other entries of process.prompts
    are not read by this code -- confirm that is intended for batches.
  * Line breaks and indentation in this copy of the patch were reconstructed;
    verify the hunk context lines against the tree before applying.

 .../sd_forge_controlnet/scripts/controlnet.py |  3 +
 .../scripts/forge_photomaker.py               | 77 +++++++++++++++++++
 modules/sd_hijack.py                          |  1 +
 3 files changed, 81 insertions(+)
 create mode 100644 extensions-builtin/sd_forge_photomaker/scripts/forge_photomaker.py

diff --git a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py
index c37d3d63..74017c27 100644
--- a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py
+++ b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py
@@ -394,6 +394,7 @@ class ControlNetForForgeOfficial(scripts.Script):
         params.model.process_after_every_sampling(p, params, *args, **kwargs)
         return
 
+    @torch.no_grad()
     def process(self, p, *args, **kwargs):
         self.current_params = {}
         for i, unit in enumerate(self.get_enabled_units(p)):
@@ -403,11 +404,13 @@ class ControlNetForForgeOfficial(scripts.Script):
             self.current_params[i] = params
         return
 
+    @torch.no_grad()
     def process_before_every_sampling(self, p, *args, **kwargs):
         for i, unit in enumerate(self.get_enabled_units(p)):
             self.process_unit_before_every_sampling(p, unit, self.current_params[i], *args, **kwargs)
         return
 
+    @torch.no_grad()
     def postprocess_batch_list(self, p, *args, **kwargs):
         for i, unit in enumerate(self.get_enabled_units(p)):
             self.process_unit_after_every_sampling(p, unit, self.current_params[i], *args, **kwargs)
diff --git a/extensions-builtin/sd_forge_photomaker/scripts/forge_photomaker.py b/extensions-builtin/sd_forge_photomaker/scripts/forge_photomaker.py
new file mode 100644
index 00000000..37cab46d
--- /dev/null
+++ b/extensions-builtin/sd_forge_photomaker/scripts/forge_photomaker.py
@@ -0,0 +1,77 @@
+from modules_forge.supported_preprocessor import PreprocessorClipVision, Preprocessor, PreprocessorParameter
+from modules_forge.shared import add_supported_preprocessor
+from modules_forge.forge_util import numpy_to_pytorch
+from modules_forge.shared import add_supported_control_model
+from modules_forge.supported_controlnet import ControlModelPatcher
+from ldm_patched.contrib.external_photomaker import PhotoMakerEncode, PhotoMakerIDEncoder
+from ldm_patched.modules.sample import convert_cond
+from ldm_patched.modules.samplers import encode_model_conds
+
+
+opPhotoMakerEncode = PhotoMakerEncode().apply_photomaker
+
+
+class PreprocessorClipvisionForPhotomaker(Preprocessor):
+    def __init__(self, name):
+        super().__init__()
+        self.name = name
+        self.tags = ['PhotoMaker']
+        self.model_filename_filters = ['PhotoMaker', 'Photo_Maker', 'Photo-Maker']
+        self.sorting_priority = 20
+        self.slider_resolution = PreprocessorParameter(visible=False)
+        self.corp_image_with_a1111_mask_when_in_img2img_inpaint_tab = False
+        self.show_control_mode = False
+
+
+add_supported_preprocessor(PreprocessorClipvisionForPhotomaker(
+    name='ClipVision (Photomaker)',
+))
+
+
+class PhotomakerPatcher(ControlModelPatcher):
+    @staticmethod
+    def try_build_from_state_dict(state_dict, ckpt_path):
+        if "id_encoder" not in state_dict:
+            return None
+
+        state_dict = state_dict["id_encoder"]
+
+        photomaker_model = PhotoMakerIDEncoder()
+        photomaker_model.load_state_dict(state_dict)
+
+        return PhotomakerPatcher(photomaker_model)
+
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+        return
+
+    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
+        unet = process.sd_model.forge_objects.unet.clone()
+        clip = process.sd_model.forge_objects.clip
+        text = process.prompts[0]
+
+        cond_modified = opPhotoMakerEncode(photomaker=self.model, image=cond.movedim(1, -1), clip=clip, text=text)[0]
+        noise = kwargs['x']
+        cond_modified = encode_model_conds(
+            model_function=unet.model.extra_conds,
+            conds=convert_cond(cond_modified),
+            noise=noise,
+            device=noise.device,
+            prompt_type="positive"
+        )[0]
+
+        def conditioning_modifier(model, x, timestep, uncond, cond, cond_scale, model_options, seed):
+            cond = cond.copy()
+            for c in cond:
+                c['pooled_output'] = cond_modified['pooled_output']
+                c['cross_attn'] = cond_modified['cross_attn']
+                c['model_conds'].update(cond_modified['model_conds'])
+            return model, x, timestep, uncond, cond, cond_scale, model_options, seed
+
+        unet.add_conditioning_modifier(conditioning_modifier)
+        process.sd_model.forge_objects.unet = unet
+        return
+
+
+add_supported_control_model(PhotomakerPatcher)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 4a96f164..e55cd217 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -168,6 +168,7 @@ class EmbeddingsWithFixes(torch.nn.Module):
         self.wrapped = wrapped
         self.embeddings = embeddings
         self.textual_inversion_key = textual_inversion_key
+        self.weight = self.wrapped.weight
 
     def forward(self, input_ids):
         batch_fixes = self.embeddings.fixes