mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Added comparative loss when training the clip encoder. Allow selecting the clip layer on ip adapter. Improvements to prior prediction.
This commit is contained in:
@@ -5,6 +5,7 @@ import sys
|
||||
|
||||
from PIL import Image
|
||||
from torch.nn import Parameter
|
||||
from torch.nn.modules.module import T
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
|
||||
@@ -173,6 +174,7 @@ class IPAdapter(torch.nn.Module):
|
||||
self.input_size = 224
|
||||
self.clip_noise_zero = True
|
||||
self.unconditional: torch.Tensor = None
|
||||
self.additional_loss = None
|
||||
if self.config.image_encoder_arch == 'clip' or self.config.image_encoder_arch == 'clip+':
|
||||
try:
|
||||
self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
|
||||
@@ -451,10 +453,7 @@ class IPAdapter(torch.nn.Module):
|
||||
):
|
||||
with torch.no_grad():
|
||||
device = self.sd_ref().unet.device
|
||||
if self.config.type.startswith('ip+'):
|
||||
clip_image_embeds = torch.cat([x['penultimate_hidden_states'] for x in image_embeds_list], dim=0)
|
||||
else:
|
||||
clip_image_embeds = torch.cat([x['image_embeds'] for x in image_embeds_list], dim=0)
|
||||
clip_image_embeds = torch.cat([x[self.config.clip_layer] for x in image_embeds_list], dim=0)
|
||||
|
||||
if self.config.quad_image:
|
||||
# get the outputs of the quad
|
||||
@@ -548,7 +547,7 @@ class IPAdapter(torch.nn.Module):
|
||||
# if drop:
|
||||
# clip_image = clip_image * 0
|
||||
with torch.set_grad_enabled(is_training):
|
||||
if is_training:
|
||||
if is_training and self.config.train_image_encoder:
|
||||
self.image_encoder.train()
|
||||
clip_image = clip_image.requires_grad_(True)
|
||||
if self.preprocessor is not None:
|
||||
@@ -565,16 +564,39 @@ class IPAdapter(torch.nn.Module):
|
||||
clip_image, output_hidden_states=True
|
||||
)
|
||||
|
||||
if self.config.type.startswith('ip+'):
|
||||
if self.config.clip_layer == 'penultimate_hidden_states':
|
||||
# they skip last layer for ip+
|
||||
# https://github.com/tencent-ailab/IP-Adapter/blob/f4b6742db35ea6d81c7b829a55b0a312c7f5a677/tutorial_train_plus.py#L403C26-L403C26
|
||||
clip_image_embeds = clip_output.hidden_states[-2]
|
||||
elif self.config.clip_layer == 'last_hidden_state':
|
||||
clip_image_embeds = clip_output.hidden_states[-1]
|
||||
else:
|
||||
clip_image_embeds = clip_output.image_embeds
|
||||
|
||||
if self.config.quad_image:
|
||||
# get the outputs of the quad
|
||||
chunks = clip_image_embeds.chunk(quad_count, dim=0)
|
||||
if self.config.train_image_encoder and is_training:
|
||||
# perform a loss across all chunks; this will teach the vision encoder to
|
||||
# identify similarities in our pairs of images and ignore things that do not make them similar
|
||||
num_losses = 0
|
||||
total_loss = None
|
||||
for chunk in chunks:
|
||||
for chunk2 in chunks:
|
||||
if chunk is not chunk2:
|
||||
loss = F.mse_loss(chunk, chunk2)
|
||||
if total_loss is None:
|
||||
total_loss = loss
|
||||
else:
|
||||
total_loss = total_loss + loss
|
||||
num_losses += 1
|
||||
if total_loss is not None:
|
||||
total_loss = total_loss / num_losses
|
||||
total_loss = total_loss * 1e-2
|
||||
if self.additional_loss is not None:
|
||||
total_loss = total_loss + self.additional_loss
|
||||
self.additional_loss = total_loss
|
||||
|
||||
chunk_sum = torch.zeros_like(chunks[0])
|
||||
for chunk in chunks:
|
||||
chunk_sum = chunk_sum + chunk
|
||||
@@ -582,7 +604,7 @@ class IPAdapter(torch.nn.Module):
|
||||
|
||||
clip_image_embeds = chunk_sum / quad_count
|
||||
|
||||
if not is_training:
|
||||
if not is_training or not self.config.train_image_encoder:
|
||||
clip_image_embeds = clip_image_embeds.detach()
|
||||
|
||||
return clip_image_embeds
|
||||
@@ -594,6 +616,17 @@ class IPAdapter(torch.nn.Module):
|
||||
embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1)
|
||||
return embeddings
|
||||
|
||||
|
||||
# Set training/eval mode on this adapter and on selected sub-modules.
#
# Args:
#     mode: forwarded unchanged to every `.train(...)` call below.
# Returns:
#     Whatever `super().train(mode)` returns.
#
# NOTE(review): indentation is lost in this diff rendering, so it is not
# visible whether the `image_proj_model` check is nested inside the
# `train_only_image_encoder` guard or sits beside it — confirm against the
# real file before relying on either reading.
def train(self: T, mode: bool = True) -> T:
|
||||
# the image encoder only receives the mode when the config asks to train it
if self.config.train_image_encoder:
|
||||
self.image_encoder.train(mode)
|
||||
# adapter modules are skipped when only the image encoder is being trained
if not self.config.train_only_image_encoder:
|
||||
for attn_processor in self.adapter_modules:
|
||||
attn_processor.train(mode)
|
||||
# the projection model is optional, so guard against None
if self.image_proj_model is not None:
|
||||
self.image_proj_model.train(mode)
|
||||
# finally delegate to the parent class implementation with the same mode
return super().train(mode)
|
||||
|
||||
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
|
||||
if self.config.train_only_image_encoder:
|
||||
yield from self.image_encoder.parameters(recurse)
|
||||
|
||||
Reference in New Issue
Block a user