Bug fixes with ip adapter training. Made a clip pre processor that can be trained with ip adapter to help augment the clip input to squeeze in more detail from a larget input. moved clip processing to the dataloader for speed.

This commit is contained in:
Jaret Burkett
2024-01-04 12:59:38 -07:00
parent 65c08b09c3
commit 645b27f97a
8 changed files with 253 additions and 64 deletions

View File

@@ -666,7 +666,7 @@ class SDTrainer(BaseSDTrainProcess):
conditional_embeddings=embeds_to_use.to(self.device_torch, dtype=dtype).detach(),
unconditional_embeddings=unconditional_embeds,
timestep=timesteps,
guidance_scale=1.0,
guidance_scale=self.train_config.cfg_scale,
**pred_kwargs # adapter residuals in here
)
if was_unet_training:
@@ -980,40 +980,45 @@ class SDTrainer(BaseSDTrainProcess):
if self.adapter and isinstance(self.adapter, IPAdapter):
with self.timer('encode_adapter_embeds'):
image_size = self.adapter.input_size
if is_reg:
# we will zero it out in the img embedder
clip_images = torch.zeros(
(noisy_latents.shape[0], 3, 512, 512),
(noisy_latents.shape[0], 3, image_size, image_size),
device=self.device_torch, dtype=dtype
).detach()
# drop will zero it out
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
clip_images,
drop=True,
is_training=True
is_training=True,
has_been_preprocessed=True
)
if self.train_config.do_cfg:
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
torch.zeros(
(noisy_latents.shape[0], 3, 512, 512),
(noisy_latents.shape[0], 3, image_size, image_size),
device=self.device_torch, dtype=dtype
).detach(),
is_training=True,
drop=True
drop=True,
has_been_preprocessed=True
)
elif has_clip_image:
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
clip_images.detach().to(self.device_torch, dtype=dtype),
is_training=True
is_training=True,
has_been_preprocessed=True
)
if self.train_config.do_cfg:
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
torch.zeros(
(noisy_latents.shape[0], 3, 512, 512),
(noisy_latents.shape[0], 3, image_size, image_size),
device=self.device_torch, dtype=dtype
).detach(),
is_training=True,
drop=True
drop=True,
has_been_preprocessed=True
)
else:
raise ValueError("Adapter images now must be loaded with dataloader or be a reg image")
@@ -1094,7 +1099,7 @@ class SDTrainer(BaseSDTrainProcess):
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
unconditional_embeddings=unconditional_embeds,
timestep=timesteps,
guidance_scale=1.0,
guidance_scale=self.train_config.cfg_scale,
**pred_kwargs
)
self.after_unet_predict()

View File

@@ -678,6 +678,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
if batch.tensor is not None:
imgs = batch.tensor
imgs = imgs.to(self.device_torch, dtype=dtype)
if self.train_config.img_multiplier is not None:
imgs = imgs * self.train_config.img_multiplier
if batch.latents is not None:
latents = batch.latents.to(self.device_torch, dtype=dtype)
batch.latents = latents
@@ -1113,6 +1115,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.snr_gos.scale.data = torch.tensor(json_data['scale'], device=self.device_torch)
self.snr_gos.gamma.data = torch.tensor(json_data['gamma'], device=self.device_torch)
# load the adapters before the dataset as they may use the clip encoders
if self.adapter_config is not None:
self.setup_adapter()
flush()
### HOOk ###
@@ -1249,7 +1254,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
flush()
if self.adapter_config is not None:
self.setup_adapter()
# self.setup_adapter()
# set trainable params
params.append({
'params': self.adapter.parameters(),

View File

@@ -272,6 +272,7 @@ class TrainConfig:
self.dynamic_noise_offset = kwargs.get('dynamic_noise_offset', False)
self.do_cfg = kwargs.get('do_cfg', False)
self.cfg_scale = kwargs.get('cfg_scale', 1.0)
class ModelConfig:

View File

@@ -431,6 +431,7 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
for file in tqdm(file_list):
try:
file_item = FileItemDTO(
sd=self.sd,
path=file,
dataset_config=dataset_config,
dataloader_transforms=self.transform,

View File

@@ -1,3 +1,5 @@
import weakref
from _weakref import ReferenceType
from typing import TYPE_CHECKING, List, Union
import torch
import random
@@ -10,8 +12,10 @@ from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessing
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin
if TYPE_CHECKING:
from toolkit.config_modules import DatasetConfig
from toolkit.stable_diffusion_model import StableDiffusion
printed_messages = []

View File

@@ -12,6 +12,7 @@ import numpy as np
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from transformers import CLIPImageProcessor
from toolkit.basic import flush, value_map
from toolkit.buckets import get_bucket_for_image_size, get_resolution
@@ -27,7 +28,7 @@ from toolkit.train_tools import get_torch_dtype
if TYPE_CHECKING:
from toolkit.data_loader import AiToolkitDataset
from toolkit.data_transfer_object.data_loader import FileItemDTO
from toolkit.stable_diffusion_model import StableDiffusion
# def get_associated_caption_from_img_path(img_path):
# https://demo.albumentations.ai/
@@ -565,8 +566,13 @@ class ClipImageFileItemDTOMixin:
self.clip_image_tensor: Union[torch.Tensor, None] = None
self.has_clip_augmentations = False
self.clip_image_aug_transform: Union[None, A.Compose] = None
self.clip_image_processor: Union[None, CLIPImageProcessor] = None
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
if dataset_config.clip_image_path is not None:
# copy the clip image processor so the dataloader can do it
sd = kwargs.get('sd', None)
if hasattr(sd.adapter, 'clip_image_processor'):
self.clip_image_processor = sd.adapter.clip_image_processor
# find the control image path
clip_image_path = dataset_config.clip_image_path
# we are using control images
@@ -632,13 +638,22 @@ class ClipImageFileItemDTOMixin:
print(f"Error: {e}")
print(f"Error loading image: {self.clip_image_path}")
# we just scale them to 512x512:
img = img.resize((512, 512), Image.BICUBIC)
if self.has_clip_augmentations:
self.clip_image_tensor = self.augment_clip_image(img, transform=None)
else:
self.clip_image_tensor = transforms.ToTensor()(img)
if self.clip_image_processor is not None:
# run it
tensors_0_1 = self.clip_image_tensor.to(dtype=torch.float16)
clip_out = self.clip_image_processor(
images=tensors_0_1,
return_tensors="pt",
do_resize=True,
do_rescale=False,
).pixel_values
self.clip_image_tensor = clip_out.squeeze(0).clone().detach()
def cleanup_clip_image(self: 'FileItemDTO'):
self.clip_image_tensor = None

View File

@@ -5,12 +5,13 @@ from PIL import Image
from torch.nn import Parameter
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
from toolkit.paths import REPOS_ROOT
from toolkit.saving import load_ip_adapter_model
from toolkit.train_tools import get_torch_dtype
sys.path.append(REPOS_ROOT)
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional
from collections import OrderedDict
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
AttnProcessor2_0
@@ -163,7 +164,9 @@ class IPAdapter(torch.nn.Module):
self.config = adapter_config
self.sd_ref: weakref.ref = weakref.ref(sd)
self.device = self.sd_ref().unet.device
if self.config.image_encoder_arch == 'clip':
self.preprocessor: Optional[CLIPImagePreProcessor] = None
self.input_size = 224
if self.config.image_encoder_arch == 'clip' or self.config.image_encoder_arch == 'clip+':
try:
self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
except EnvironmentError:
@@ -176,7 +179,8 @@ class IPAdapter(torch.nn.Module):
self.clip_image_processor = ViTFeatureExtractor.from_pretrained(adapter_config.image_encoder_path)
except EnvironmentError:
self.clip_image_processor = ViTFeatureExtractor()
self.image_encoder = ViTForImageClassification.from_pretrained(adapter_config.image_encoder_path).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
self.image_encoder = ViTForImageClassification.from_pretrained(adapter_config.image_encoder_path).to(
self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
elif self.config.image_encoder_arch == 'safe':
try:
self.clip_image_processor = SAFEImageProcessor.from_pretrained(adapter_config.image_encoder_path)
@@ -220,6 +224,26 @@ class IPAdapter(torch.nn.Module):
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
else:
raise ValueError(f"unknown image encoder arch: {adapter_config.image_encoder_arch}")
self.input_size = self.image_encoder.config.image_size
if self.config.image_encoder_arch == 'clip+':
# self.clip_image_processor.config
# We do a 3x downscale of the image, so we need to adjust the input size
preprocessor_input_size = self.image_encoder.config.image_size * 3
# update the preprocessor so images come in at the right size
self.clip_image_processor.size['shortest_edge'] = preprocessor_input_size
self.clip_image_processor.crop_size['height'] = preprocessor_input_size
self.clip_image_processor.crop_size['width'] = preprocessor_input_size
self.preprocessor = CLIPImagePreProcessor(
input_size=preprocessor_input_size,
clip_input_size=self.image_encoder.config.image_size,
downscale_factor=6
)
self.input_size = self.clip_image_processor.size['shortest_edge']
self.current_scale = 1.0
self.is_active = True
if adapter_config.type == 'ip':
@@ -232,12 +256,9 @@ class IPAdapter(torch.nn.Module):
elif adapter_config.type == 'ip+':
heads = 12 if not sd.is_xl else 20
dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else self.image_encoder.config.hidden_sizes[-1]
embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else \
self.image_encoder.config.hidden_sizes[-1]
# if self.config.image_encoder_arch == 'safe':
# embedding_dim = self.config.safe_tokens
# size mismatch for latents: copying a param with shape torch.Size([1, 16, 1280]) from checkpoint, the shape in current model is torch.Size([1, 16, 2048]).
# size mismatch for latents: copying a param with shape torch.Size([1, 32, 2048]) from checkpoint, the shape in current model is torch.Size([1, 16, 1280])
# ip-adapter-plus
image_proj_model = Resampler(
dim=dim,
@@ -312,6 +333,8 @@ class IPAdapter(torch.nn.Module):
self.image_encoder.to(*args, **kwargs)
self.image_proj_model.to(*args, **kwargs)
self.adapter_modules.to(*args, **kwargs)
if self.preprocessor is not None:
self.preprocessor.to(*args, **kwargs)
return self
def load_ip_adapter(self, state_dict: Union[OrderedDict, dict]):
@@ -320,6 +343,8 @@ class IPAdapter(torch.nn.Module):
ip_layers.load_state_dict(state_dict["ip_adapter"])
if self.config.train_image_encoder and 'image_encoder' in state_dict:
self.image_encoder.load_state_dict(state_dict["image_encoder"])
if self.preprocessor is not None and 'preprocessor' in state_dict:
self.preprocessor.load_state_dict(state_dict["preprocessor"])
# def load_state_dict(self, state_dict: Union[OrderedDict, dict]):
# self.load_ip_adapter(state_dict)
@@ -330,6 +355,8 @@ class IPAdapter(torch.nn.Module):
state_dict["ip_adapter"] = self.adapter_modules.state_dict()
if self.config.train_image_encoder:
state_dict["image_encoder"] = self.image_encoder.state_dict()
if self.preprocessor is not None:
state_dict["preprocessor"] = self.preprocessor.state_dict()
return state_dict
def get_scale(self):
@@ -341,53 +368,74 @@ class IPAdapter(torch.nn.Module):
if isinstance(attn_processor, CustomIPAttentionProcessor):
attn_processor.scale = scale
@torch.no_grad()
def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]],
drop=False) -> torch.Tensor:
# todo: add support for sdxl
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
if drop:
clip_image = clip_image * 0
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
return clip_image_embeds
# @torch.no_grad()
# def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]],
# drop=False) -> torch.Tensor:
# # todo: add support for sdxl
# if isinstance(pil_image, Image.Image):
# pil_image = [pil_image]
# clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
# clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
# if drop:
# clip_image = clip_image * 0
# clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
# return clip_image_embeds
def get_clip_image_embeds_from_tensors(self, tensors_0_1: torch.Tensor, drop=False,
is_training=False) -> torch.Tensor:
def get_clip_image_embeds_from_tensors(
self,
tensors_0_1: torch.Tensor,
drop=False,
is_training=False,
has_been_preprocessed=False
) -> torch.Tensor:
with torch.no_grad():
# tensors should be 0-1
# todo: add support for sdxl
if tensors_0_1.ndim == 3:
tensors_0_1 = tensors_0_1.unsqueeze(0)
# training tensors are 0 - 1
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
# if images are out of this range throw error
if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
tensors_0_1.min(), tensors_0_1.max()
))
clip_image = self.clip_image_processor(
images=tensors_0_1,
return_tensors="pt",
do_resize=True,
do_rescale=False,
).pixel_values
# on training the clip image is created in the dataloader
if not has_been_preprocessed:
# tensors should be 0-1
if tensors_0_1.ndim == 3:
tensors_0_1 = tensors_0_1.unsqueeze(0)
# training tensors are 0 - 1
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
# if images are out of this range throw error
if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
tensors_0_1.min(), tensors_0_1.max()
))
clip_image = self.clip_image_processor(
images=tensors_0_1,
return_tensors="pt",
do_resize=True,
do_rescale=False,
).pixel_values
else:
clip_image = tensors_0_1
clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
if drop:
clip_image = clip_image * 0
with torch.set_grad_enabled(is_training):
if is_training:
self.image_encoder.train()
clip_output = self.image_encoder(clip_image.requires_grad_(True)
, output_hidden_states=True)
clip_image = clip_image.requires_grad_(True)
if self.preprocessor is not None:
clip_image = self.preprocessor(clip_image)
clip_output = self.image_encoder(
clip_image,
output_hidden_states=True
)
else:
self.image_encoder.eval()
clip_output = self.image_encoder(clip_image, output_hidden_states=True)
if self.preprocessor is not None:
clip_image = self.preprocessor(clip_image)
clip_output = self.image_encoder(
clip_image, output_hidden_states=True
)
clip_image_embeds = clip_output.hidden_states[-2]
if self.config.type.startswith('ip+'):
# they skip last layer for ip+
# https://github.com/tencent-ailab/IP-Adapter/blob/f4b6742db35ea6d81c7b829a55b0a312c7f5a677/tutorial_train_plus.py#L403C26-L403C26
clip_image_embeds = clip_output.hidden_states[-2]
else:
clip_image_embeds = clip_output.image_embeds
return clip_image_embeds
# use drop for prompt dropout, or negatives
@@ -403,13 +451,8 @@ class IPAdapter(torch.nn.Module):
yield from self.image_proj_model.parameters(recurse)
if self.config.train_image_encoder:
yield from self.image_encoder.parameters(recurse)
# if self.config.train_image_encoder:
# yield from self.image_encoder.parameters(recurse)
# self.image_encoder.train()
# else:
# for attn_processor in self.adapter_modules:
# yield from attn_processor.parameters(recurse)
# yield from self.image_proj_model.parameters(recurse)
if self.preprocessor is not None:
yield from self.preprocessor.parameters(recurse)
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
strict = False
@@ -417,6 +460,8 @@ class IPAdapter(torch.nn.Module):
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
if self.config.train_image_encoder and 'image_encoder' in state_dict:
self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict)
if self.config.image_encoder_arch == 'clip+' and 'preprocessor' in state_dict:
self.preprocessor.load_state_dict(state_dict["preprocessor"], strict=strict)
def enable_gradient_checkpointing(self):
self.image_encoder.gradient_checkpointing = True

View File

@@ -0,0 +1,113 @@
import torch
import torch.nn as nn
class UpsampleBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.conv_in = nn.Sequential(
nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
nn.GELU()
)
self.conv_up = nn.Sequential(
nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
nn.GELU()
)
self.conv_out = nn.Sequential(
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
)
def forward(self, x):
x = self.conv_in(x)
x = self.conv_up(x)
x = self.conv_out(x)
return x
class CLIPImagePreProcessor(nn.Module):
def __init__(
self,
input_size=672,
clip_input_size=224,
downscale_factor: int = 6,
channels=None, # 108
):
super().__init__()
# make sure they are evenly divisible
assert input_size % clip_input_size == 0
in_channels = 3
self.input_size = input_size
self.clip_input_size = clip_input_size
self.downscale_factor = downscale_factor
subpixel_channels = in_channels * downscale_factor ** 2 # 3 * 6 ** 2 = 108
if channels is None:
channels = subpixel_channels
upscale_factor = downscale_factor / int((input_size / clip_input_size)) # 6 / (672 / 224) = 2
num_upsample_blocks = int(upscale_factor // 2) # 2 // 2 = 1
# do a pooling layer to downscale the input to 1/3 of the size
# (bs, 3, 672, 672) -> (bs, 3, 224, 224)
kernel_size = input_size // clip_input_size
self.res_down = nn.AvgPool2d(
kernel_size=kernel_size,
stride=kernel_size
) # (bs, 3, 672, 672) -> (bs, 3, 224, 224)
# make a blending for output residual with near 0 weight
self.res_blend = nn.Parameter(torch.tensor(0.001)) # (bs, 3, 224, 224) -> (bs, 3, 224, 224)
self.unshuffle = nn.PixelUnshuffle(downscale_factor) # (bs, 3, 672, 672) -> (bs, 108, 112, 112)
self.conv_in = nn.Sequential(
nn.Conv2d(
subpixel_channels,
channels,
kernel_size=3,
padding=1
),
nn.GELU()
) # (bs, 108, 112, 112) -> (bs, 108, 112, 112)
self.upsample_blocks = nn.ModuleList()
current_channels = channels
for _ in range(num_upsample_blocks):
out_channels = current_channels // 2
self.upsample_blocks.append(UpsampleBlock(current_channels, out_channels))
current_channels = out_channels
# (bs, 108, 112, 112) -> (bs, 54, 224, 224)
self.conv_out = nn.Conv2d(
current_channels,
out_channels=3,
kernel_size=3,
padding=1
) # (bs, 54, 224, 224) -> (bs, 3, 224, 224)
def forward(self, x):
# resize to input_size x input_size
x = nn.functional.interpolate(x, size=(self.input_size, self.input_size), mode='bicubic')
res = self.res_down(x)
x = self.unshuffle(x)
x = self.conv_in(x)
for up in self.upsample_blocks:
x = up(x)
x = self.conv_out(x)
# blend residual
x = x * self.res_blend + res
return x