This commit is contained in:
lllyasviel
2024-01-31 18:33:44 -08:00
parent 13a649010a
commit ab6f349a00
2 changed files with 43 additions and 7 deletions

View File

@@ -260,7 +260,10 @@ def NPToTensor(image):
return out
class IPAdapter(nn.Module):
def __init__(self, ipadapter_model, cross_attention_dim=1024, output_cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4, is_sdxl=False, is_plus=False, is_full=False, is_faceid=False):
def __init__(self, ipadapter_model, cross_attention_dim=1024, output_cross_attention_dim=1024,
clip_embeddings_dim=1024, clip_extra_context_tokens=4,
is_sdxl=False, is_plus=False, is_full=False,
is_faceid=False, is_instant_id=False):
super().__init__()
self.clip_embeddings_dim = clip_embeddings_dim
@@ -270,8 +273,11 @@ class IPAdapter(nn.Module):
self.is_sdxl = is_sdxl
self.is_full = is_full
self.is_plus = is_plus
self.is_instant_id = is_instant_id
if is_faceid:
if is_instant_id:
self.image_proj_model = self.init_proj_instantid()
elif is_faceid:
self.image_proj_model = self.init_proj_faceid()
elif is_plus:
self.image_proj_model = self.init_proj_plus()
@@ -324,6 +330,19 @@ class IPAdapter(nn.Module):
)
return image_proj_model
def init_proj_instantid(self, image_emb_dim=512, num_tokens=16):
    """Build the InstantID image-projection head.

    A ``Resampler`` attends ``num_tokens`` learned queries over the input
    face embedding (``image_emb_dim`` features per token) and projects the
    result into the model's cross-attention space.
    """
    resampler_config = dict(
        dim=1280,
        depth=4,
        dim_head=64,
        heads=20,
        num_queries=num_tokens,
        embedding_dim=image_emb_dim,
        output_dim=self.cross_attention_dim,
        ff_mult=4,
    )
    return Resampler(**resampler_config)
@torch.inference_mode()
def get_image_embeds(self, clip_embed, clip_embed_zeroed):
image_prompt_embeds = self.image_proj_model(clip_embed)
@@ -335,6 +354,18 @@ class IPAdapter(nn.Module):
embeds = self.image_proj_model(face_embed, clip_embed, scale=s_scale, shortcut=shortcut)
return embeds
@torch.inference_mode()
def get_image_embeds_instantid(self, prompt_image_emb, image_emb_dim=512):
    """Prepare InstantID face-embedding tensors for conditioning.

    Parameters
    ----------
    prompt_image_emb : torch.Tensor or array-like
        Flat face embedding(s); the total number of elements must be a
        multiple of ``image_emb_dim``.
    image_emb_dim : int, optional
        Feature size of a single embedding token. Defaults to 512, matching
        the ``image_emb_dim`` default of ``init_proj_instantid`` (previously
        this was a hard-coded local, duplicating that constant).

    Returns
    -------
    tuple of torch.Tensor
        The embedding reshaped to ``[1, -1, image_emb_dim]`` on
        ``self.device`` as float32, plus a zero tensor of the same shape
        for the unconditional branch.
    """
    if isinstance(prompt_image_emb, torch.Tensor):
        # Detach a copy so downstream ops can neither mutate the caller's
        # tensor nor drag in its autograd history.
        prompt_image_emb = prompt_image_emb.clone().detach()
    else:
        prompt_image_emb = torch.tensor(prompt_image_emb)
    # NOTE(review): relies on self.device being set on this object; nn.Module
    # itself has no .device attribute — confirm it is assigned by the owner.
    prompt_image_emb = prompt_image_emb.to(device=self.device, dtype=torch.float32)
    prompt_image_emb = prompt_image_emb.reshape([1, -1, image_emb_dim])
    return prompt_image_emb, torch.zeros_like(prompt_image_emb)
class CrossAttentionPatch:
# forward for patching
def __init__(self, weight, ipadapter, number, cond, uncond, weight_type, mask=None, sigma_start=0.0, sigma_end=1.0, unfold_batch=False):
@@ -591,7 +622,7 @@ class IPAdapterApply:
def apply_ipadapter(self, ipadapter, model, weight, clip_vision=None, image=None, weight_type="original",
noise=None, embeds=None, attn_mask=None, start_at=0.0, end_at=1.0, unfold_batch=False,
insightface=None, faceid_v2=False, weight_v2=False, using_instant_id=False):
insightface=None, faceid_v2=False, weight_v2=False, instant_id=False):
self.dtype = torch.float16 if ldm_patched.modules.model_management.should_use_fp16() else torch.float32
self.device = ldm_patched.modules.model_management.get_torch_device()
@@ -600,6 +631,7 @@ class IPAdapterApply:
self.is_portrait = "proj.2.weight" in ipadapter["image_proj"] and not "proj.3.weight" in ipadapter["image_proj"] and not "0.to_q_lora.down.weight" in ipadapter["ip_adapter"]
self.is_faceid = self.is_portrait or "0.to_q_lora.down.weight" in ipadapter["ip_adapter"]
self.is_plus = (self.is_full or "latents" in ipadapter["image_proj"] or "perceiver_resampler.proj_in.weight" in ipadapter["image_proj"])
self.is_instant_id = instant_id
if self.is_faceid and not insightface:
raise Exception('InsightFace must be provided for FaceID models.')
@@ -614,7 +646,7 @@ class IPAdapterApply:
clip_embed = embeds[0].cpu()
clip_embed_zeroed = embeds[1].cpu()
else:
if self.is_faceid:
if self.is_instant_id or self.is_faceid:
insightface.det_model.input_size = (640,640) # reset the detection size
face_img = tensorToNP(image)
face_embed = []
@@ -638,8 +670,11 @@ class IPAdapterApply:
image = torch.stack(face_clipvision, dim=0)
neg_image = image_add_noise(image, noise) if noise > 0 else None
if self.is_plus:
if self.is_instant_id:
clip_embed = face_embed
clip_embed_zeroed = torch.zeros_like(clip_embed)
elif self.is_plus:
clip_embed = clip_vision.encode_image(image).penultimate_hidden_states
if noise > 0:
clip_embed_zeroed = clip_vision.encode_image(neg_image).penultimate_hidden_states
@@ -683,6 +718,7 @@ class IPAdapterApply:
is_plus=self.is_plus,
is_full=self.is_full,
is_faceid=self.is_faceid,
is_instant_id=self.is_instant_id
)
self.ipadapter.to(self.device, dtype=self.dtype)

View File

@@ -83,7 +83,7 @@ class PreprocessorInsightFaceForInstantID(Preprocessor):
embeds=None,
attn_mask=None,
unfold_batch=False,
using_instant_id=True
instant_id=True
)
return cond