Update IPAdapterPlus.py

This commit is contained in:
lllyasviel
2024-01-31 18:50:28 -08:00
parent 9b9e800c6c
commit 83a33aef32

View File

@@ -356,15 +356,9 @@ class IPAdapter(nn.Module):
@torch.inference_mode()
def get_image_embeds_instantid(self, prompt_image_emb):
image_proj_model_in_features = 512
if isinstance(prompt_image_emb, torch.Tensor):
prompt_image_emb = prompt_image_emb.clone().detach()
else:
prompt_image_emb = torch.tensor(prompt_image_emb)
prompt_image_emb = prompt_image_emb.to(device=self.device, dtype=torch.float32)
prompt_image_emb = prompt_image_emb.reshape([1, -1, image_proj_model_in_features])
return prompt_image_emb, torch.zeros_like(prompt_image_emb)
c = self.image_proj_model(prompt_image_emb)
uc = self.image_proj_model(torch.zeros_like(prompt_image_emb))
return c, uc
class CrossAttentionPatch:
# forward for patching
@@ -729,6 +723,8 @@ class IPAdapterApply:
if self.is_faceid and self.is_plus:
image_prompt_embeds = self.ipadapter.get_image_embeds_faceid_plus(face_embed.to(self.device, dtype=self.dtype), clip_embed.to(self.device, dtype=self.dtype), weight_v2, faceid_v2)
uncond_image_prompt_embeds = self.ipadapter.get_image_embeds_faceid_plus(face_embed_zeroed.to(self.device, dtype=self.dtype), clip_embed_zeroed.to(self.device, dtype=self.dtype), weight_v2, faceid_v2)
elif self.is_instant_id:
image_prompt_embeds, uncond_image_prompt_embeds = self.ipadapter.get_image_embeds_instantid(face_embed.to(self.device, dtype=self.dtype))
else:
image_prompt_embeds, uncond_image_prompt_embeds = self.ipadapter.get_image_embeds(clip_embed.to(self.device, dtype=self.dtype), clip_embed_zeroed.to(self.device, dtype=self.dtype))