mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-13 08:59:51 +00:00
photomaker
This commit is contained in:
@@ -394,6 +394,7 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
params.model.process_after_every_sampling(p, params, *args, **kwargs)
|
||||
return
|
||||
|
||||
@torch.no_grad()
|
||||
def process(self, p, *args, **kwargs):
|
||||
self.current_params = {}
|
||||
for i, unit in enumerate(self.get_enabled_units(p)):
|
||||
@@ -403,11 +404,13 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
self.current_params[i] = params
|
||||
return
|
||||
|
||||
@torch.no_grad()
|
||||
def process_before_every_sampling(self, p, *args, **kwargs):
|
||||
for i, unit in enumerate(self.get_enabled_units(p)):
|
||||
self.process_unit_before_every_sampling(p, unit, self.current_params[i], *args, **kwargs)
|
||||
return
|
||||
|
||||
@torch.no_grad()
|
||||
def postprocess_batch_list(self, p, *args, **kwargs):
|
||||
for i, unit in enumerate(self.get_enabled_units(p)):
|
||||
self.process_unit_after_every_sampling(p, unit, self.current_params[i], *args, **kwargs)
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
from modules_forge.supported_preprocessor import PreprocessorClipVision, Preprocessor, PreprocessorParameter
|
||||
from modules_forge.shared import add_supported_preprocessor
|
||||
from modules_forge.forge_util import numpy_to_pytorch
|
||||
from modules_forge.shared import add_supported_control_model
|
||||
from modules_forge.supported_controlnet import ControlModelPatcher
|
||||
from ldm_patched.contrib.external_photomaker import PhotoMakerEncode, PhotoMakerIDEncoder
|
||||
from ldm_patched.modules.sample import convert_cond
|
||||
from ldm_patched.modules.samplers import encode_model_conds
|
||||
|
||||
|
||||
opPhotoMakerEncode = PhotoMakerEncode().apply_photomaker
|
||||
|
||||
|
||||
class PreprocessorClipvisionForPhotomaker(Preprocessor):
|
||||
def __init__(self, name):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
self.tags = ['PhotoMaker']
|
||||
self.model_filename_filters = ['PhotoMaker', 'Photo_Maker', 'Photo-Maker']
|
||||
self.sorting_priority = 20
|
||||
self.slider_resolution = PreprocessorParameter(visible=False)
|
||||
self.corp_image_with_a1111_mask_when_in_img2img_inpaint_tab = False
|
||||
self.show_control_mode = False
|
||||
|
||||
|
||||
add_supported_preprocessor(PreprocessorClipvisionForPhotomaker(
|
||||
name='ClipVision (Photomaker)',
|
||||
))
|
||||
|
||||
|
||||
class PhotomakerPatcher(ControlModelPatcher):
|
||||
@staticmethod
|
||||
def try_build_from_state_dict(state_dict, ckpt_path):
|
||||
if "id_encoder" not in state_dict:
|
||||
return None
|
||||
|
||||
state_dict = state_dict["id_encoder"]
|
||||
|
||||
photomaker_model = PhotoMakerIDEncoder()
|
||||
photomaker_model.load_state_dict(state_dict)
|
||||
|
||||
return PhotomakerPatcher(photomaker_model)
|
||||
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
return
|
||||
|
||||
def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
|
||||
unet = process.sd_model.forge_objects.unet.clone()
|
||||
clip = process.sd_model.forge_objects.clip
|
||||
text = process.prompts[0]
|
||||
|
||||
cond_modified = opPhotoMakerEncode(photomaker=self.model, image=cond.movedim(1, -1), clip=clip, text=text)[0]
|
||||
noise = kwargs['x']
|
||||
cond_modified = encode_model_conds(
|
||||
model_function=unet.model.extra_conds,
|
||||
conds=convert_cond(cond_modified),
|
||||
noise=noise,
|
||||
device=noise.device,
|
||||
prompt_type="positive"
|
||||
)[0]
|
||||
|
||||
def conditioning_modifier(model, x, timestep, uncond, cond, cond_scale, model_options, seed):
|
||||
cond = cond.copy()
|
||||
for c in cond:
|
||||
c['pooled_output'] = cond_modified['pooled_output']
|
||||
c['cross_attn'] = cond_modified['cross_attn']
|
||||
c['model_conds'].update(cond_modified['model_conds'])
|
||||
return model, x, timestep, uncond, cond, cond_scale, model_options, seed
|
||||
|
||||
unet.add_conditioning_modifier(conditioning_modifier)
|
||||
process.sd_model.forge_objects.unet = unet
|
||||
return
|
||||
|
||||
|
||||
add_supported_control_model(PhotomakerPatcher)
|
||||
@@ -168,6 +168,7 @@ class EmbeddingsWithFixes(torch.nn.Module):
|
||||
self.wrapped = wrapped
|
||||
self.embeddings = embeddings
|
||||
self.textual_inversion_key = textual_inversion_key
|
||||
self.weight = self.wrapped.weight
|
||||
|
||||
def forward(self, input_ids):
|
||||
batch_fixes = self.embeddings.fixes
|
||||
|
||||
Reference in New Issue
Block a user