mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Added an experimental clip fusion model that is showing promise for embedding concepts
This commit is contained in:
@@ -119,11 +119,12 @@ class PhotoMakerCLIPEncoder(CLIPVisionModelWithProjection):
|
||||
super().__init__(config, *model_args, **model_kwargs)
|
||||
self.visual_projection_2 = nn.Linear(1024, 1280, bias=False)
|
||||
|
||||
def forward(self, id_pixel_values, do_projection2=True):
|
||||
def forward(self, id_pixel_values, do_projection2=True, output_full=False):
|
||||
b, num_inputs, c, h, w = id_pixel_values.shape
|
||||
id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w)
|
||||
|
||||
shared_id_embeds = self.vision_model(id_pixel_values)[1]
|
||||
# last_hidden_state, 1, 257, 1024
|
||||
vision_output = self.vision_model(id_pixel_values, output_hidden_states=True)
|
||||
shared_id_embeds = vision_output[1]
|
||||
id_embeds = self.visual_projection(shared_id_embeds)
|
||||
|
||||
id_embeds = id_embeds.view(b, num_inputs, 1, -1)
|
||||
@@ -133,6 +134,8 @@ class PhotoMakerCLIPEncoder(CLIPVisionModelWithProjection):
|
||||
id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1)
|
||||
id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1)
|
||||
|
||||
if output_full:
|
||||
return id_embeds, vision_output
|
||||
return id_embeds
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user