diff --git a/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py b/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py index c3d9a162..9cedb87d 100644 --- a/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py +++ b/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py @@ -260,7 +260,10 @@ def NPToTensor(image): return out class IPAdapter(nn.Module): - def __init__(self, ipadapter_model, cross_attention_dim=1024, output_cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4, is_sdxl=False, is_plus=False, is_full=False, is_faceid=False): + def __init__(self, ipadapter_model, cross_attention_dim=1024, output_cross_attention_dim=1024, + clip_embeddings_dim=1024, clip_extra_context_tokens=4, + is_sdxl=False, is_plus=False, is_full=False, + is_faceid=False, is_instant_id=False): super().__init__() self.clip_embeddings_dim = clip_embeddings_dim @@ -270,8 +273,11 @@ class IPAdapter(nn.Module): self.is_sdxl = is_sdxl self.is_full = is_full self.is_plus = is_plus + self.is_instant_id = is_instant_id - if is_faceid: + if is_instant_id: + self.image_proj_model = self.init_proj_instantid() + elif is_faceid: self.image_proj_model = self.init_proj_faceid() elif is_plus: self.image_proj_model = self.init_proj_plus() @@ -324,6 +330,19 @@ class IPAdapter(nn.Module): ) return image_proj_model + def init_proj_instantid(self, image_emb_dim=512, num_tokens=16): + image_proj_model = Resampler( + dim=1280, + depth=4, + dim_head=64, + heads=20, + num_queries=num_tokens, + embedding_dim=image_emb_dim, + output_dim=self.cross_attention_dim, + ff_mult=4, + ) + return image_proj_model + @torch.inference_mode() def get_image_embeds(self, clip_embed, clip_embed_zeroed): image_prompt_embeds = self.image_proj_model(clip_embed) @@ -335,6 +354,18 @@ class IPAdapter(nn.Module): embeds = self.image_proj_model(face_embed, clip_embed, scale=s_scale, shortcut=shortcut) return embeds + @torch.inference_mode() + def get_image_embeds_instantid(self, prompt_image_emb): + image_proj_model_in_features = 512 + if isinstance(prompt_image_emb, torch.Tensor): + prompt_image_emb = prompt_image_emb.clone().detach() + else: + prompt_image_emb = torch.tensor(prompt_image_emb) + + prompt_image_emb = prompt_image_emb.to(device=self.device, dtype=torch.float32) + prompt_image_emb = prompt_image_emb.reshape([1, -1, image_proj_model_in_features]) + return prompt_image_emb, torch.zeros_like(prompt_image_emb) + class CrossAttentionPatch: # forward for patching def __init__(self, weight, ipadapter, number, cond, uncond, weight_type, mask=None, sigma_start=0.0, sigma_end=1.0, unfold_batch=False): @@ -591,7 +622,7 @@ class IPAdapterApply: def apply_ipadapter(self, ipadapter, model, weight, clip_vision=None, image=None, weight_type="original", noise=None, embeds=None, attn_mask=None, start_at=0.0, end_at=1.0, unfold_batch=False, - insightface=None, faceid_v2=False, weight_v2=False, using_instant_id=False): + insightface=None, faceid_v2=False, weight_v2=False, instant_id=False): self.dtype = torch.float16 if ldm_patched.modules.model_management.should_use_fp16() else torch.float32 self.device = ldm_patched.modules.model_management.get_torch_device() @@ -600,6 +631,7 @@ class IPAdapterApply: self.is_portrait = "proj.2.weight" in ipadapter["image_proj"] and not "proj.3.weight" in ipadapter["image_proj"] and not "0.to_q_lora.down.weight" in ipadapter["ip_adapter"] self.is_faceid = self.is_portrait or "0.to_q_lora.down.weight" in ipadapter["ip_adapter"] self.is_plus = (self.is_full or "latents" in ipadapter["image_proj"] or "perceiver_resampler.proj_in.weight" in ipadapter["image_proj"]) + self.is_instant_id = instant_id if self.is_faceid and not insightface: raise Exception('InsightFace must be provided for FaceID models.') @@ -614,7 +646,7 @@ class IPAdapterApply: clip_embed = embeds[0].cpu() clip_embed_zeroed = embeds[1].cpu() else: - if self.is_faceid: + if self.is_instant_id or self.is_faceid: insightface.det_model.input_size = (640,640) # reset the detection size face_img = tensorToNP(image) face_embed = [] @@ -638,8 +670,11 @@ class IPAdapterApply: image = torch.stack(face_clipvision, dim=0) neg_image = image_add_noise(image, noise) if noise > 0 else None - - if self.is_plus: + + if self.is_instant_id: + clip_embed = face_embed + clip_embed_zeroed = torch.zeros_like(clip_embed) + elif self.is_plus: clip_embed = clip_vision.encode_image(image).penultimate_hidden_states if noise > 0: clip_embed_zeroed = clip_vision.encode_image(neg_image).penultimate_hidden_states @@ -683,6 +718,7 @@ class IPAdapterApply: is_plus=self.is_plus, is_full=self.is_full, is_faceid=self.is_faceid, + is_instant_id=self.is_instant_id ) self.ipadapter.to(self.device, dtype=self.dtype) diff --git a/extensions-builtin/sd_forge_ipadapter/scripts/forge_ipadapter.py b/extensions-builtin/sd_forge_ipadapter/scripts/forge_ipadapter.py index f328449d..e1d724b3 100644 --- a/extensions-builtin/sd_forge_ipadapter/scripts/forge_ipadapter.py +++ b/extensions-builtin/sd_forge_ipadapter/scripts/forge_ipadapter.py @@ -83,7 +83,7 @@ class PreprocessorInsightFaceForInstantID(Preprocessor): embeds=None, attn_mask=None, unfold_batch=False, - using_instant_id=True + instant_id=True ) return cond