photomaker

This commit is contained in:
lllyasviel
2024-02-01 22:57:49 -08:00
parent 4b27699e65
commit a9212b640f
3 changed files with 81 additions and 0 deletions

View File

@@ -394,6 +394,7 @@ class ControlNetForForgeOfficial(scripts.Script):
params.model.process_after_every_sampling(p, params, *args, **kwargs)
return
@torch.no_grad()
def process(self, p, *args, **kwargs):
self.current_params = {}
for i, unit in enumerate(self.get_enabled_units(p)):
@@ -403,11 +404,13 @@ class ControlNetForForgeOfficial(scripts.Script):
self.current_params[i] = params
return
@torch.no_grad()
def process_before_every_sampling(self, p, *args, **kwargs):
    """Dispatch the pre-sampling hook to every enabled unit.

    Each unit receives the params object that `process` cached for it in
    `self.current_params`, keyed by the unit's enable-order index.
    """
    for index, unit in enumerate(self.get_enabled_units(p)):
        cached_params = self.current_params[index]
        self.process_unit_before_every_sampling(p, unit, cached_params, *args, **kwargs)
    return
@torch.no_grad()
def postprocess_batch_list(self, p, *args, **kwargs):
    """Run each enabled unit's after-sampling hook over the batch results."""
    enabled_units = self.get_enabled_units(p)
    for index, unit in enumerate(enabled_units):
        self.process_unit_after_every_sampling(p, unit, self.current_params[index], *args, **kwargs)

View File

@@ -0,0 +1,77 @@
from modules_forge.supported_preprocessor import PreprocessorClipVision, Preprocessor, PreprocessorParameter
from modules_forge.shared import add_supported_preprocessor
from modules_forge.forge_util import numpy_to_pytorch
from modules_forge.shared import add_supported_control_model
from modules_forge.supported_controlnet import ControlModelPatcher
from ldm_patched.contrib.external_photomaker import PhotoMakerEncode, PhotoMakerIDEncoder
from ldm_patched.modules.sample import convert_cond
from ldm_patched.modules.samplers import encode_model_conds
opPhotoMakerEncode = PhotoMakerEncode().apply_photomaker
class PreprocessorClipvisionForPhotomaker(Preprocessor):
    """UI-facing preprocessor entry for PhotoMaker ID images.

    Pure configuration: sets name/tags/UI flags so PhotoMaker checkpoints are
    matched to this preprocessor and irrelevant controls are hidden.
    """

    def __init__(self, name):
        super().__init__()
        self.name = name
        # Tag group under which this preprocessor is listed.
        self.tags = ['PhotoMaker']
        # Filename substrings that associate control models with this entry.
        self.model_filename_filters = ['PhotoMaker', 'Photo_Maker', 'Photo-Maker']
        self.sorting_priority = 20
        # Hide the resolution slider for this preprocessor.
        self.slider_resolution = PreprocessorParameter(visible=False)
        # NOTE(review): 'corp' (sic) — presumably matches the attribute
        # spelling expected by the Preprocessor base class; do not "fix"
        # the name without checking the base class.
        self.corp_image_with_a1111_mask_when_in_img2img_inpaint_tab = False
        self.show_control_mode = False
# Register the PhotoMaker CLIP-Vision entry with the extension's
# preprocessor registry.
add_supported_preprocessor(PreprocessorClipvisionForPhotomaker(
    name='ClipVision (Photomaker)',
))
class PhotomakerPatcher(ControlModelPatcher):
    """Control-model patcher that encodes ID images with the PhotoMaker
    ID-encoder and injects the result into the UNet's positive conditioning
    before every sampling run."""

    @staticmethod
    def try_build_from_state_dict(state_dict, ckpt_path):
        # PhotoMaker checkpoints carry an "id_encoder" sub-dict; anything
        # else is not a PhotoMaker model, so signal "no match" with None.
        if "id_encoder" not in state_dict:
            return None
        state_dict = state_dict["id_encoder"]
        photomaker_model = PhotoMakerIDEncoder()
        photomaker_model.load_state_dict(state_dict)
        return PhotomakerPatcher(photomaker_model)

    def __init__(self, model):
        super().__init__()
        # The loaded PhotoMakerIDEncoder instance used at sampling time.
        self.model = model
        return

    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
        """Encode `cond` with PhotoMaker and patch a cloned UNet so the
        resulting embedding overwrites the positive conditioning.

        # assumes `cond` is an image batch in NCHW layout — movedim(1, -1)
        # converts it to channel-last for the encoder; TODO confirm caller.
        """
        unet = process.sd_model.forge_objects.unet.clone()
        clip = process.sd_model.forge_objects.clip
        # Only the first prompt of the batch is used for the PhotoMaker text.
        text = process.prompts[0]
        cond_modified = opPhotoMakerEncode(photomaker=self.model, image=cond.movedim(1, -1), clip=clip, text=text)[0]
        noise = kwargs['x']
        # Re-encode so the new conditioning carries the model-specific extra
        # conds produced by `unet.model.extra_conds` for this noise tensor.
        cond_modified = encode_model_conds(
            model_function=unet.model.extra_conds,
            conds=convert_cond(cond_modified),
            noise=noise,
            device=noise.device,
            prompt_type="positive"
        )[0]

        def conditioning_modifier(model, x, timestep, uncond, cond, cond_scale, model_options, seed):
            # Shallow copy: the list object is new, but the entry dicts are
            # the caller's and are mutated in place below.
            # NOTE(review): presumably intentional (entries are meant to be
            # overwritten with the PhotoMaker embedding) — verify.
            cond = cond.copy()
            for c in cond:
                c['pooled_output'] = cond_modified['pooled_output']
                c['cross_attn'] = cond_modified['cross_attn']
                c['model_conds'].update(cond_modified['model_conds'])
            return model, x, timestep, uncond, cond, cond_scale, model_options, seed

        unet.add_conditioning_modifier(conditioning_modifier)
        process.sd_model.forge_objects.unet = unet
        return
add_supported_control_model(PhotomakerPatcher)

View File

@@ -168,6 +168,7 @@ class EmbeddingsWithFixes(torch.nn.Module):
self.wrapped = wrapped
self.embeddings = embeddings
self.textual_inversion_key = textual_inversion_key
self.weight = self.wrapped.weight
def forward(self, input_ids):
batch_fixes = self.embeddings.fixes