Added comparative loss when training the CLIP encoder. Allow selecting the CLIP layer on the IP adapter. Improvements to prior prediction

Jaret Burkett
2024-02-05 07:40:03 -07:00
parent 177c7130ec
commit e18e0cb5f8
7 changed files with 227 additions and 65 deletions


@@ -5,6 +5,7 @@ import sys
from PIL import Image
from torch.nn import Parameter
from torch.nn.modules.module import T
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
@@ -173,6 +174,7 @@ class IPAdapter(torch.nn.Module):
        self.input_size = 224
        self.clip_noise_zero = True
        self.unconditional: torch.Tensor = None
        self.additional_loss = None
        if self.config.image_encoder_arch == 'clip' or self.config.image_encoder_arch == 'clip+':
            try:
                self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
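For context, the processor and encoder are typically loaded as a pair from the same path. A minimal sketch using the Hugging Face transformers classes imported above; only the CLIPImageProcessor call appears in this hunk, so the encoder-loading line and the example path are assumptions:

    from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

    # load the paired preprocessor and vision encoder
    # (encoder line assumed, mirroring the from_pretrained call in the hunk above)
    image_encoder_path = 'openai/clip-vit-large-patch14'  # example path, not from this repo
    clip_image_processor = CLIPImageProcessor.from_pretrained(image_encoder_path)
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(image_encoder_path)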
@@ -451,10 +453,7 @@ class IPAdapter(torch.nn.Module):
    ):
        with torch.no_grad():
            device = self.sd_ref().unet.device
            if self.config.type.startswith('ip+'):
                clip_image_embeds = torch.cat([x['penultimate_hidden_states'] for x in image_embeds_list], dim=0)
            else:
                clip_image_embeds = torch.cat([x['image_embeds'] for x in image_embeds_list], dim=0)
            clip_image_embeds = torch.cat([x[self.config.clip_layer] for x in image_embeds_list], dim=0)
            if self.config.quad_image:
                # get the outputs of the quad
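Taken together with the encoder changes below, config.clip_layer now selects which cached encoder output each entry in image_embeds_list is read from, replacing the old hard-coded branch on config.type. A hedged sketch of the accepted values, inferred only from the keys used in this diff:

    # hypothetical config usage; values inferred from this diff
    adapter_config.clip_layer = 'image_embeds'               # projected embedding
    adapter_config.clip_layer = 'penultimate_hidden_states'  # hidden_states[-2], the old 'ip+' behavior
    adapter_config.clip_layer = 'last_hidden_state'          # hidden_states[-1]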
@@ -548,7 +547,7 @@ class IPAdapter(torch.nn.Module):
        # if drop:
        #     clip_image = clip_image * 0
        with torch.set_grad_enabled(is_training):
            if is_training:
            if is_training and self.config.train_image_encoder:
                self.image_encoder.train()
                clip_image = clip_image.requires_grad_(True)
            if self.preprocessor is not None:
@@ -565,16 +564,39 @@ class IPAdapter(torch.nn.Module):
                clip_image, output_hidden_states=True
            )
            if self.config.type.startswith('ip+'):
            if self.config.clip_layer == 'penultimate_hidden_states':
                # the reference IP-Adapter implementation skips the last layer for ip+
                # https://github.com/tencent-ailab/IP-Adapter/blob/f4b6742db35ea6d81c7b829a55b0a312c7f5a677/tutorial_train_plus.py#L403C26-L403C26
                clip_image_embeds = clip_output.hidden_states[-2]
            elif self.config.clip_layer == 'last_hidden_state':
                clip_image_embeds = clip_output.hidden_states[-1]
            else:
                clip_image_embeds = clip_output.image_embeds
            if self.config.quad_image:
                # get the outputs of the quad
                chunks = clip_image_embeds.chunk(quad_count, dim=0)
                if self.config.train_image_encoder and is_training:
                    # compute a loss across all pairs of chunks; this teaches the
                    # vision encoder to identify what makes our image pairs similar
                    # and to ignore features that do not contribute to that similarity
                    num_losses = 0
                    total_loss = None
                    for chunk in chunks:
                        for chunk2 in chunks:
                            if chunk is not chunk2:
                                loss = F.mse_loss(chunk, chunk2)
                                if total_loss is None:
                                    total_loss = loss
                                else:
                                    total_loss = total_loss + loss
                                num_losses += 1
                    if total_loss is not None:
                        total_loss = total_loss / num_losses
                        total_loss = total_loss * 1e-2
                        if self.additional_loss is not None:
                            total_loss = total_loss + self.additional_loss
                        self.additional_loss = total_loss
                chunk_sum = torch.zeros_like(chunks[0])
                for chunk in chunks:
                    chunk_sum = chunk_sum + chunk
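The pairwise loss above is easy to test in isolation. A minimal, self-contained sketch of the same comparative loss; tensor shapes are example values, not taken from this repo:

    import torch
    import torch.nn.functional as F

    def comparative_loss(chunks):
        # mean MSE over all ordered pairs of distinct chunks,
        # scaled by 1e-2 as in the diff above
        total_loss, num_losses = None, 0
        for a in chunks:
            for b in chunks:
                if a is not b:
                    loss = F.mse_loss(a, b)
                    total_loss = loss if total_loss is None else total_loss + loss
                    num_losses += 1
        return (total_loss / num_losses) * 1e-2

    # example: four embeddings split from a quad-image batch
    chunks = torch.randn(4, 257, 1280).chunk(4, dim=0)
    print(comparative_loss(chunks))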
@@ -582,7 +604,7 @@ class IPAdapter(torch.nn.Module):
                clip_image_embeds = chunk_sum / quad_count
        if not is_training:
        if not is_training or not self.config.train_image_encoder:
            clip_image_embeds = clip_image_embeds.detach()
        return clip_image_embeds
@@ -594,6 +616,17 @@ class IPAdapter(torch.nn.Module):
        embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1)
        return embeddings

    def train(self: T, mode: bool = True) -> T:
        if self.config.train_image_encoder:
            self.image_encoder.train(mode)
        if not self.config.train_only_image_encoder:
            for attn_processor in self.adapter_modules:
                attn_processor.train(mode)
            if self.image_proj_model is not None:
                self.image_proj_model.train(mode)
        return super().train(mode)

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        if self.config.train_only_image_encoder:
            yield from self.image_encoder.parameters(recurse)
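The accumulated self.additional_loss is presumably consumed by the training loop in one of the other changed files. A hypothetical sketch of how a training step could fold it into the main loss and reset it; 'adapter' and 'main_loss' are assumed names, not code from this commit:

    # hypothetical training-step fragment (names assumed, not from this commit)
    loss = main_loss  # e.g. the diffusion noise-prediction loss
    if adapter.additional_loss is not None:
        loss = loss + adapter.additional_loss  # comparative CLIP loss from encoding
        adapter.additional_loss = None         # reset to avoid double counting
    loss.backward()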