Bug fixes with ip adapter training. Made a clip pre processor that can be trained with ip adapter to help augment the clip input to squeeze in more detail from a larget input. moved clip processing to the dataloader for speed.

This commit is contained in:
Jaret Burkett
2024-01-04 12:59:38 -07:00
parent 65c08b09c3
commit 645b27f97a
8 changed files with 253 additions and 64 deletions

View File

@@ -5,12 +5,13 @@ from PIL import Image
from torch.nn import Parameter
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
from toolkit.paths import REPOS_ROOT
from toolkit.saving import load_ip_adapter_model
from toolkit.train_tools import get_torch_dtype
sys.path.append(REPOS_ROOT)
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional
from collections import OrderedDict
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
AttnProcessor2_0
@@ -163,7 +164,9 @@ class IPAdapter(torch.nn.Module):
self.config = adapter_config
self.sd_ref: weakref.ref = weakref.ref(sd)
self.device = self.sd_ref().unet.device
if self.config.image_encoder_arch == 'clip':
self.preprocessor: Optional[CLIPImagePreProcessor] = None
self.input_size = 224
if self.config.image_encoder_arch == 'clip' or self.config.image_encoder_arch == 'clip+':
try:
self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
except EnvironmentError:
@@ -176,7 +179,8 @@ class IPAdapter(torch.nn.Module):
self.clip_image_processor = ViTFeatureExtractor.from_pretrained(adapter_config.image_encoder_path)
except EnvironmentError:
self.clip_image_processor = ViTFeatureExtractor()
self.image_encoder = ViTForImageClassification.from_pretrained(adapter_config.image_encoder_path).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
self.image_encoder = ViTForImageClassification.from_pretrained(adapter_config.image_encoder_path).to(
self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
elif self.config.image_encoder_arch == 'safe':
try:
self.clip_image_processor = SAFEImageProcessor.from_pretrained(adapter_config.image_encoder_path)
@@ -220,6 +224,26 @@ class IPAdapter(torch.nn.Module):
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
else:
raise ValueError(f"unknown image encoder arch: {adapter_config.image_encoder_arch}")
self.input_size = self.image_encoder.config.image_size
if self.config.image_encoder_arch == 'clip+':
# self.clip_image_processor.config
# We do a 3x downscale of the image, so we need to adjust the input size
preprocessor_input_size = self.image_encoder.config.image_size * 3
# update the preprocessor so images come in at the right size
self.clip_image_processor.size['shortest_edge'] = preprocessor_input_size
self.clip_image_processor.crop_size['height'] = preprocessor_input_size
self.clip_image_processor.crop_size['width'] = preprocessor_input_size
self.preprocessor = CLIPImagePreProcessor(
input_size=preprocessor_input_size,
clip_input_size=self.image_encoder.config.image_size,
downscale_factor=6
)
self.input_size = self.clip_image_processor.size['shortest_edge']
self.current_scale = 1.0
self.is_active = True
if adapter_config.type == 'ip':
@@ -232,12 +256,9 @@ class IPAdapter(torch.nn.Module):
elif adapter_config.type == 'ip+':
heads = 12 if not sd.is_xl else 20
dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else self.image_encoder.config.hidden_sizes[-1]
embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else \
self.image_encoder.config.hidden_sizes[-1]
# if self.config.image_encoder_arch == 'safe':
# embedding_dim = self.config.safe_tokens
# size mismatch for latents: copying a param with shape torch.Size([1, 16, 1280]) from checkpoint, the shape in current model is torch.Size([1, 16, 2048]).
# size mismatch for latents: copying a param with shape torch.Size([1, 32, 2048]) from checkpoint, the shape in current model is torch.Size([1, 16, 1280])
# ip-adapter-plus
image_proj_model = Resampler(
dim=dim,
@@ -312,6 +333,8 @@ class IPAdapter(torch.nn.Module):
self.image_encoder.to(*args, **kwargs)
self.image_proj_model.to(*args, **kwargs)
self.adapter_modules.to(*args, **kwargs)
if self.preprocessor is not None:
self.preprocessor.to(*args, **kwargs)
return self
def load_ip_adapter(self, state_dict: Union[OrderedDict, dict]):
@@ -320,6 +343,8 @@ class IPAdapter(torch.nn.Module):
ip_layers.load_state_dict(state_dict["ip_adapter"])
if self.config.train_image_encoder and 'image_encoder' in state_dict:
self.image_encoder.load_state_dict(state_dict["image_encoder"])
if self.preprocessor is not None and 'preprocessor' in state_dict:
self.preprocessor.load_state_dict(state_dict["preprocessor"])
# def load_state_dict(self, state_dict: Union[OrderedDict, dict]):
# self.load_ip_adapter(state_dict)
@@ -330,6 +355,8 @@ class IPAdapter(torch.nn.Module):
state_dict["ip_adapter"] = self.adapter_modules.state_dict()
if self.config.train_image_encoder:
state_dict["image_encoder"] = self.image_encoder.state_dict()
if self.preprocessor is not None:
state_dict["preprocessor"] = self.preprocessor.state_dict()
return state_dict
def get_scale(self):
@@ -341,53 +368,74 @@ class IPAdapter(torch.nn.Module):
if isinstance(attn_processor, CustomIPAttentionProcessor):
attn_processor.scale = scale
@torch.no_grad()
def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]],
drop=False) -> torch.Tensor:
# todo: add support for sdxl
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
if drop:
clip_image = clip_image * 0
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
return clip_image_embeds
# @torch.no_grad()
# def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]],
# drop=False) -> torch.Tensor:
# # todo: add support for sdxl
# if isinstance(pil_image, Image.Image):
# pil_image = [pil_image]
# clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
# clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
# if drop:
# clip_image = clip_image * 0
# clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
# return clip_image_embeds
def get_clip_image_embeds_from_tensors(self, tensors_0_1: torch.Tensor, drop=False,
is_training=False) -> torch.Tensor:
def get_clip_image_embeds_from_tensors(
self,
tensors_0_1: torch.Tensor,
drop=False,
is_training=False,
has_been_preprocessed=False
) -> torch.Tensor:
with torch.no_grad():
# tensors should be 0-1
# todo: add support for sdxl
if tensors_0_1.ndim == 3:
tensors_0_1 = tensors_0_1.unsqueeze(0)
# training tensors are 0 - 1
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
# if images are out of this range throw error
if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
tensors_0_1.min(), tensors_0_1.max()
))
clip_image = self.clip_image_processor(
images=tensors_0_1,
return_tensors="pt",
do_resize=True,
do_rescale=False,
).pixel_values
# on training the clip image is created in the dataloader
if not has_been_preprocessed:
# tensors should be 0-1
if tensors_0_1.ndim == 3:
tensors_0_1 = tensors_0_1.unsqueeze(0)
# training tensors are 0 - 1
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
# if images are out of this range throw error
if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
tensors_0_1.min(), tensors_0_1.max()
))
clip_image = self.clip_image_processor(
images=tensors_0_1,
return_tensors="pt",
do_resize=True,
do_rescale=False,
).pixel_values
else:
clip_image = tensors_0_1
clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
if drop:
clip_image = clip_image * 0
with torch.set_grad_enabled(is_training):
if is_training:
self.image_encoder.train()
clip_output = self.image_encoder(clip_image.requires_grad_(True)
, output_hidden_states=True)
clip_image = clip_image.requires_grad_(True)
if self.preprocessor is not None:
clip_image = self.preprocessor(clip_image)
clip_output = self.image_encoder(
clip_image,
output_hidden_states=True
)
else:
self.image_encoder.eval()
clip_output = self.image_encoder(clip_image, output_hidden_states=True)
if self.preprocessor is not None:
clip_image = self.preprocessor(clip_image)
clip_output = self.image_encoder(
clip_image, output_hidden_states=True
)
clip_image_embeds = clip_output.hidden_states[-2]
if self.config.type.startswith('ip+'):
# they skip last layer for ip+
# https://github.com/tencent-ailab/IP-Adapter/blob/f4b6742db35ea6d81c7b829a55b0a312c7f5a677/tutorial_train_plus.py#L403C26-L403C26
clip_image_embeds = clip_output.hidden_states[-2]
else:
clip_image_embeds = clip_output.image_embeds
return clip_image_embeds
# use drop for prompt dropout, or negatives
@@ -403,13 +451,8 @@ class IPAdapter(torch.nn.Module):
yield from self.image_proj_model.parameters(recurse)
if self.config.train_image_encoder:
yield from self.image_encoder.parameters(recurse)
# if self.config.train_image_encoder:
# yield from self.image_encoder.parameters(recurse)
# self.image_encoder.train()
# else:
# for attn_processor in self.adapter_modules:
# yield from attn_processor.parameters(recurse)
# yield from self.image_proj_model.parameters(recurse)
if self.preprocessor is not None:
yield from self.preprocessor.parameters(recurse)
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
strict = False
@@ -417,6 +460,8 @@ class IPAdapter(torch.nn.Module):
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
if self.config.train_image_encoder and 'image_encoder' in state_dict:
self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict)
if self.config.image_encoder_arch == 'clip+' and 'preprocessor' in state_dict:
self.preprocessor.load_state_dict(state_dict["preprocessor"], strict=strict)
def enable_gradient_checkpointing(self):
self.image_encoder.gradient_checkpointing = True