mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
Bug fixes with ip adapter training. Made a clip pre processor that can be trained with ip adapter to help augment the clip input to squeeze in more detail from a larget input. moved clip processing to the dataloader for speed.
This commit is contained in:
@@ -5,12 +5,13 @@ from PIL import Image
|
||||
from torch.nn import Parameter
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
from toolkit.saving import load_ip_adapter_model
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
|
||||
sys.path.append(REPOS_ROOT)
|
||||
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List
|
||||
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional
|
||||
from collections import OrderedDict
|
||||
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
|
||||
AttnProcessor2_0
|
||||
@@ -163,7 +164,9 @@ class IPAdapter(torch.nn.Module):
|
||||
self.config = adapter_config
|
||||
self.sd_ref: weakref.ref = weakref.ref(sd)
|
||||
self.device = self.sd_ref().unet.device
|
||||
if self.config.image_encoder_arch == 'clip':
|
||||
self.preprocessor: Optional[CLIPImagePreProcessor] = None
|
||||
self.input_size = 224
|
||||
if self.config.image_encoder_arch == 'clip' or self.config.image_encoder_arch == 'clip+':
|
||||
try:
|
||||
self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
|
||||
except EnvironmentError:
|
||||
@@ -176,7 +179,8 @@ class IPAdapter(torch.nn.Module):
|
||||
self.clip_image_processor = ViTFeatureExtractor.from_pretrained(adapter_config.image_encoder_path)
|
||||
except EnvironmentError:
|
||||
self.clip_image_processor = ViTFeatureExtractor()
|
||||
self.image_encoder = ViTForImageClassification.from_pretrained(adapter_config.image_encoder_path).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
self.image_encoder = ViTForImageClassification.from_pretrained(adapter_config.image_encoder_path).to(
|
||||
self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
elif self.config.image_encoder_arch == 'safe':
|
||||
try:
|
||||
self.clip_image_processor = SAFEImageProcessor.from_pretrained(adapter_config.image_encoder_path)
|
||||
@@ -220,6 +224,26 @@ class IPAdapter(torch.nn.Module):
|
||||
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
else:
|
||||
raise ValueError(f"unknown image encoder arch: {adapter_config.image_encoder_arch}")
|
||||
|
||||
self.input_size = self.image_encoder.config.image_size
|
||||
|
||||
if self.config.image_encoder_arch == 'clip+':
|
||||
# self.clip_image_processor.config
|
||||
# We do a 3x downscale of the image, so we need to adjust the input size
|
||||
preprocessor_input_size = self.image_encoder.config.image_size * 3
|
||||
|
||||
# update the preprocessor so images come in at the right size
|
||||
self.clip_image_processor.size['shortest_edge'] = preprocessor_input_size
|
||||
self.clip_image_processor.crop_size['height'] = preprocessor_input_size
|
||||
self.clip_image_processor.crop_size['width'] = preprocessor_input_size
|
||||
|
||||
self.preprocessor = CLIPImagePreProcessor(
|
||||
input_size=preprocessor_input_size,
|
||||
clip_input_size=self.image_encoder.config.image_size,
|
||||
downscale_factor=6
|
||||
)
|
||||
|
||||
self.input_size = self.clip_image_processor.size['shortest_edge']
|
||||
self.current_scale = 1.0
|
||||
self.is_active = True
|
||||
if adapter_config.type == 'ip':
|
||||
@@ -232,12 +256,9 @@ class IPAdapter(torch.nn.Module):
|
||||
elif adapter_config.type == 'ip+':
|
||||
heads = 12 if not sd.is_xl else 20
|
||||
dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
|
||||
embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else self.image_encoder.config.hidden_sizes[-1]
|
||||
embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else \
|
||||
self.image_encoder.config.hidden_sizes[-1]
|
||||
|
||||
# if self.config.image_encoder_arch == 'safe':
|
||||
# embedding_dim = self.config.safe_tokens
|
||||
# size mismatch for latents: copying a param with shape torch.Size([1, 16, 1280]) from checkpoint, the shape in current model is torch.Size([1, 16, 2048]).
|
||||
# size mismatch for latents: copying a param with shape torch.Size([1, 32, 2048]) from checkpoint, the shape in current model is torch.Size([1, 16, 1280])
|
||||
# ip-adapter-plus
|
||||
image_proj_model = Resampler(
|
||||
dim=dim,
|
||||
@@ -312,6 +333,8 @@ class IPAdapter(torch.nn.Module):
|
||||
self.image_encoder.to(*args, **kwargs)
|
||||
self.image_proj_model.to(*args, **kwargs)
|
||||
self.adapter_modules.to(*args, **kwargs)
|
||||
if self.preprocessor is not None:
|
||||
self.preprocessor.to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
def load_ip_adapter(self, state_dict: Union[OrderedDict, dict]):
|
||||
@@ -320,6 +343,8 @@ class IPAdapter(torch.nn.Module):
|
||||
ip_layers.load_state_dict(state_dict["ip_adapter"])
|
||||
if self.config.train_image_encoder and 'image_encoder' in state_dict:
|
||||
self.image_encoder.load_state_dict(state_dict["image_encoder"])
|
||||
if self.preprocessor is not None and 'preprocessor' in state_dict:
|
||||
self.preprocessor.load_state_dict(state_dict["preprocessor"])
|
||||
|
||||
# def load_state_dict(self, state_dict: Union[OrderedDict, dict]):
|
||||
# self.load_ip_adapter(state_dict)
|
||||
@@ -330,6 +355,8 @@ class IPAdapter(torch.nn.Module):
|
||||
state_dict["ip_adapter"] = self.adapter_modules.state_dict()
|
||||
if self.config.train_image_encoder:
|
||||
state_dict["image_encoder"] = self.image_encoder.state_dict()
|
||||
if self.preprocessor is not None:
|
||||
state_dict["preprocessor"] = self.preprocessor.state_dict()
|
||||
return state_dict
|
||||
|
||||
def get_scale(self):
|
||||
@@ -341,53 +368,74 @@ class IPAdapter(torch.nn.Module):
|
||||
if isinstance(attn_processor, CustomIPAttentionProcessor):
|
||||
attn_processor.scale = scale
|
||||
|
||||
@torch.no_grad()
|
||||
def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]],
|
||||
drop=False) -> torch.Tensor:
|
||||
# todo: add support for sdxl
|
||||
if isinstance(pil_image, Image.Image):
|
||||
pil_image = [pil_image]
|
||||
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
|
||||
clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
if drop:
|
||||
clip_image = clip_image * 0
|
||||
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
|
||||
return clip_image_embeds
|
||||
# @torch.no_grad()
|
||||
# def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]],
|
||||
# drop=False) -> torch.Tensor:
|
||||
# # todo: add support for sdxl
|
||||
# if isinstance(pil_image, Image.Image):
|
||||
# pil_image = [pil_image]
|
||||
# clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
|
||||
# clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
# if drop:
|
||||
# clip_image = clip_image * 0
|
||||
# clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
|
||||
# return clip_image_embeds
|
||||
|
||||
def get_clip_image_embeds_from_tensors(self, tensors_0_1: torch.Tensor, drop=False,
|
||||
is_training=False) -> torch.Tensor:
|
||||
def get_clip_image_embeds_from_tensors(
|
||||
self,
|
||||
tensors_0_1: torch.Tensor,
|
||||
drop=False,
|
||||
is_training=False,
|
||||
has_been_preprocessed=False
|
||||
) -> torch.Tensor:
|
||||
with torch.no_grad():
|
||||
# tensors should be 0-1
|
||||
# todo: add support for sdxl
|
||||
if tensors_0_1.ndim == 3:
|
||||
tensors_0_1 = tensors_0_1.unsqueeze(0)
|
||||
# training tensors are 0 - 1
|
||||
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
|
||||
# if images are out of this range throw error
|
||||
if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
|
||||
raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
|
||||
tensors_0_1.min(), tensors_0_1.max()
|
||||
))
|
||||
|
||||
clip_image = self.clip_image_processor(
|
||||
images=tensors_0_1,
|
||||
return_tensors="pt",
|
||||
do_resize=True,
|
||||
do_rescale=False,
|
||||
).pixel_values
|
||||
# on training the clip image is created in the dataloader
|
||||
if not has_been_preprocessed:
|
||||
# tensors should be 0-1
|
||||
if tensors_0_1.ndim == 3:
|
||||
tensors_0_1 = tensors_0_1.unsqueeze(0)
|
||||
# training tensors are 0 - 1
|
||||
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
|
||||
# if images are out of this range throw error
|
||||
if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
|
||||
raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
|
||||
tensors_0_1.min(), tensors_0_1.max()
|
||||
))
|
||||
clip_image = self.clip_image_processor(
|
||||
images=tensors_0_1,
|
||||
return_tensors="pt",
|
||||
do_resize=True,
|
||||
do_rescale=False,
|
||||
).pixel_values
|
||||
else:
|
||||
clip_image = tensors_0_1
|
||||
clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
|
||||
if drop:
|
||||
clip_image = clip_image * 0
|
||||
with torch.set_grad_enabled(is_training):
|
||||
if is_training:
|
||||
self.image_encoder.train()
|
||||
clip_output = self.image_encoder(clip_image.requires_grad_(True)
|
||||
, output_hidden_states=True)
|
||||
clip_image = clip_image.requires_grad_(True)
|
||||
if self.preprocessor is not None:
|
||||
clip_image = self.preprocessor(clip_image)
|
||||
clip_output = self.image_encoder(
|
||||
clip_image,
|
||||
output_hidden_states=True
|
||||
)
|
||||
else:
|
||||
self.image_encoder.eval()
|
||||
clip_output = self.image_encoder(clip_image, output_hidden_states=True)
|
||||
if self.preprocessor is not None:
|
||||
clip_image = self.preprocessor(clip_image)
|
||||
clip_output = self.image_encoder(
|
||||
clip_image, output_hidden_states=True
|
||||
)
|
||||
|
||||
clip_image_embeds = clip_output.hidden_states[-2]
|
||||
if self.config.type.startswith('ip+'):
|
||||
# they skip last layer for ip+
|
||||
# https://github.com/tencent-ailab/IP-Adapter/blob/f4b6742db35ea6d81c7b829a55b0a312c7f5a677/tutorial_train_plus.py#L403C26-L403C26
|
||||
clip_image_embeds = clip_output.hidden_states[-2]
|
||||
else:
|
||||
clip_image_embeds = clip_output.image_embeds
|
||||
return clip_image_embeds
|
||||
|
||||
# use drop for prompt dropout, or negatives
|
||||
@@ -403,13 +451,8 @@ class IPAdapter(torch.nn.Module):
|
||||
yield from self.image_proj_model.parameters(recurse)
|
||||
if self.config.train_image_encoder:
|
||||
yield from self.image_encoder.parameters(recurse)
|
||||
# if self.config.train_image_encoder:
|
||||
# yield from self.image_encoder.parameters(recurse)
|
||||
# self.image_encoder.train()
|
||||
# else:
|
||||
# for attn_processor in self.adapter_modules:
|
||||
# yield from attn_processor.parameters(recurse)
|
||||
# yield from self.image_proj_model.parameters(recurse)
|
||||
if self.preprocessor is not None:
|
||||
yield from self.preprocessor.parameters(recurse)
|
||||
|
||||
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
|
||||
strict = False
|
||||
@@ -417,6 +460,8 @@ class IPAdapter(torch.nn.Module):
|
||||
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
|
||||
if self.config.train_image_encoder and 'image_encoder' in state_dict:
|
||||
self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict)
|
||||
if self.config.image_encoder_arch == 'clip+' and 'preprocessor' in state_dict:
|
||||
self.preprocessor.load_state_dict(state_dict["preprocessor"], strict=strict)
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
self.image_encoder.gradient_checkpointing = True
|
||||
|
||||
Reference in New Issue
Block a user