diff --git a/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py b/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py index 1bd55560..e8f142c1 100644 --- a/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py +++ b/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py @@ -356,15 +356,9 @@ class IPAdapter(nn.Module): @torch.inference_mode() def get_image_embeds_instantid(self, prompt_image_emb): - image_proj_model_in_features = 512 - if isinstance(prompt_image_emb, torch.Tensor): - prompt_image_emb = prompt_image_emb.clone().detach() - else: - prompt_image_emb = torch.tensor(prompt_image_emb) - - prompt_image_emb = prompt_image_emb.to(device=self.device, dtype=torch.float32) - prompt_image_emb = prompt_image_emb.reshape([1, -1, image_proj_model_in_features]) - return prompt_image_emb, torch.zeros_like(prompt_image_emb) + c = self.image_proj_model(prompt_image_emb) + uc = self.image_proj_model(torch.zeros_like(prompt_image_emb)) + return c, uc class CrossAttentionPatch: # forward for patching @@ -729,6 +723,8 @@ class IPAdapterApply: if self.is_faceid and self.is_plus: image_prompt_embeds = self.ipadapter.get_image_embeds_faceid_plus(face_embed.to(self.device, dtype=self.dtype), clip_embed.to(self.device, dtype=self.dtype), weight_v2, faceid_v2) uncond_image_prompt_embeds = self.ipadapter.get_image_embeds_faceid_plus(face_embed_zeroed.to(self.device, dtype=self.dtype), clip_embed_zeroed.to(self.device, dtype=self.dtype), weight_v2, faceid_v2) + elif self.is_instant_id: + image_prompt_embeds, uncond_image_prompt_embeds = self.ipadapter.get_image_embeds_instantid(face_embed.to(self.device, dtype=self.dtype)) else: image_prompt_embeds, uncond_image_prompt_embeds = self.ipadapter.get_image_embeds(clip_embed.to(self.device, dtype=self.dtype), clip_embed_zeroed.to(self.device, dtype=self.dtype))