Created a size-agnostic feature encoder (SAFE) model to be trained in place of CLIP for IP adapters. It is mostly conv layers, so it will hopefully handle facial features better than CLIP can. Also includes bug fixes.

This commit is contained in:
Jaret Burkett
2023-12-28 12:20:27 -07:00
parent d11ed7f66c
commit eeee4a1620
5 changed files with 286 additions and 6 deletions

View File

@@ -39,8 +39,6 @@ class SDTrainer(BaseSDTrainProcess):
self.do_prior_prediction = False
self.do_long_prompts = False
self.do_guided_loss = False
-if self.train_config.inverted_mask_prior:
-self.do_prior_prediction = True
def before_model_load(self):
pass
@@ -89,13 +87,15 @@ class SDTrainer(BaseSDTrainProcess):
prior_mask_multiplier = None
target_mask_multiplier = None
has_mask = batch.mask_tensor is not None
if self.train_config.match_noise_norm:
# match the norm of the noise
noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True)
noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
noise_pred = noise_pred * (noise_norm / noise_pred_norm)
-if self.train_config.inverted_mask_prior and prior_pred is not None:
+if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
# we need to make the noise prediction be a masked blending of noise and prior_pred
stretched_mask_multiplier = value_map(
mask_multiplier,
@@ -867,7 +867,17 @@ class SDTrainer(BaseSDTrainProcess):
conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds)
prior_pred = None
-if ((has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction) and not is_reg:
+do_reg_prior = False
+if is_reg and (self.network is not None or self.adapter is not None):
+# we are doing a reg image and we have a network or adapter
+do_reg_prior = True
+do_inverted_masked_prior = False
+if self.train_config.inverted_mask_prior and batch.mask_tensor is not None:
+do_inverted_masked_prior = True
+if ((has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_reg_prior or do_inverted_masked_prior):
with self.timer('prior predict'):
prior_pred = self.get_prior_prediction(
noisy_latents=noisy_latents,
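For reference, a minimal standalone sketch of the match_noise_norm rescaling shown above; the tensor shapes are illustrative, not taken from the trainer:

import torch

noise = torch.randn(2, 4, 64, 64)
noise_pred = torch.randn(2, 4, 64, 64) * 1.7  # pretend the model overshoots

# rescale the prediction so each sample's L2 norm matches the target noise
noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True)
noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
noise_pred = noise_pred * (noise_norm / noise_pred_norm)

# each sample's norm now matches the target's
assert torch.allclose(
    torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3)),
    torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3)),
)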

View File

@@ -153,7 +153,9 @@ class AdapterConfig:
self.num_tokens: int = num_tokens
self.train_image_encoder: bool = kwargs.get('train_image_encoder', False)
-self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip') # clip vit vit_hybrid
+self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip') # clip, vit, vit_hybrid, safe
+self.safe_reducer_channels: int = kwargs.get('safe_reducer_channels', 512)
+self.safe_channels: int = kwargs.get('safe_channels', 2048)
# clip vision
self.trigger = kwargs.get('trigger', 'tri993r')
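A hypothetical config snippet selecting the new encoder; the key names match the AdapterConfig fields above, while the values and the surrounding dict are illustrative:

adapter_kwargs = {
    'image_encoder_arch': 'safe',    # instead of 'clip' / 'vit' / 'vit_hybrid'
    'safe_reducer_channels': 512,    # width of the iterative reducer stage
    'safe_channels': 2048,           # width of the post-reduction blocks
    'train_image_encoder': True,     # SAFE starts untrained, so train it alongside the adapter
}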

View File

@@ -32,6 +32,8 @@ from transformers import (
ConvNextForImageClassification,
ConvNextImageProcessor
)
+from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification
from transformers import ViTFeatureExtractor, ViTForImageClassification
@@ -175,6 +177,19 @@ class IPAdapter(torch.nn.Module):
except EnvironmentError:
self.clip_image_processor = ViTFeatureExtractor()
self.image_encoder = ViTForImageClassification.from_pretrained(adapter_config.image_encoder_path).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+elif self.config.image_encoder_arch == 'safe':
+try:
+self.clip_image_processor = SAFEImageProcessor.from_pretrained(adapter_config.image_encoder_path)
+except EnvironmentError:
+self.clip_image_processor = SAFEImageProcessor()
+self.image_encoder = SAFEVisionModel(
+in_channels=3,
+num_tokens=self.config.num_tokens if self.config.adapter_type == 'ip+' else 1,
+num_vectors=sd.unet.config['cross_attention_dim'] if self.config.adapter_type == 'ip+' else self.config.safe_channels,
+reducer_channels=self.config.safe_reducer_channels,
+channels=self.config.safe_channels,
+downscale_factor=8
+).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
elif self.config.image_encoder_arch == 'convnext':
try:
self.clip_image_processor = ConvNextImageProcessor.from_pretrained(adapter_config.image_encoder_path)
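A sketch of how the 'safe' branch above sizes the encoder output; the variable names are illustrative, and 768 stands in for a cross-attention width such as SD 1.5's:

adapter_type = 'ip+'
safe_channels = 2048
cross_attention_dim = 768  # read from sd.unet.config in the real code

# 'ip+' adapters consume a sequence of image tokens at the UNet's cross-attention
# width, while plain 'ip' adapters take a single pooled vector of safe_channels
num_tokens = 16 if adapter_type == 'ip+' else 1
num_vectors = cross_attention_dim if adapter_type == 'ip+' else safe_channels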

View File

@@ -0,0 +1,253 @@
import os
from typing import Union, Optional
import torch
import torch.nn as nn
from transformers.image_processing_utils import BaseImageProcessor
class SAFEReducerBlock(nn.Module):
"""
This is the block that reduces the width and height of a feature map by half. It is designed to be iterative,
so it is run multiple times to reduce an image to a desired dimension while carrying a shrinking residual
along for the ride. This is done to preserve information.
"""
def __init__(self, channels=512):
super(SAFEReducerBlock, self).__init__()
self.channels = channels
activation = nn.GELU
self.reducer = nn.Sequential(
nn.Conv2d(channels, channels, kernel_size=3, padding=1),
nn.BatchNorm2d(channels),
activation(),
nn.Conv2d(channels, channels, kernel_size=3, padding=1),
nn.BatchNorm2d(channels),
activation(),
nn.AvgPool2d(kernel_size=2, stride=2),
)
self.residual_shrink = nn.AvgPool2d(kernel_size=2, stride=2)
def forward(self, x):
res = self.residual_shrink(x)
reduced = self.reducer(x)
return reduced + res
class SizeAgnosticFeatureEncoder(nn.Module):
def __init__(
self,
in_channels=3,
num_tokens=8,
num_vectors=768,
reducer_channels=512,
channels=2048,
downscale_factor: int = 8,
):
super(SizeAgnosticFeatureEncoder, self).__init__()
self.num_tokens = num_tokens
self.num_vectors = num_vectors
self.channels = channels
self.reducer_channels = reducer_channels
self.gradient_checkpointing = False
# input is minimum of (bs, 3, 256, 256)
subpixel_channels = in_channels * downscale_factor ** 2
# PixelUnshuffle(8): (bs, 3, 256, 256) -> (bs, 192, 32, 32)
# PixelUnshuffle(16): (bs, 3, 256, 256) -> (bs, 768, 16, 16)
self.unshuffle = nn.PixelUnshuffle(downscale_factor) # (bs, 3, 256, 256) -> (bs, 192, 32, 32)
self.conv_in = nn.Conv2d(subpixel_channels, reducer_channels, kernel_size=3, padding=1) # (bs, 192, 32, 32) -> (bs, 512, 32, 32)
# run as many times as needed to bring the smallest dimension down to a minimum feature size of 8
self.reducer = SAFEReducerBlock(reducer_channels) # (bs, 512, 32, 32) -> (bs, 512, 8, 8)
self.reduced_out = nn.Conv2d(
reducer_channels, self.channels, kernel_size=3, padding=1
) # (bs, 512, 8, 8) -> (bs, 2048, 8, 8)
# (bs, 2048, 8, 8)
self.block1 = SAFEReducerBlock(self.channels) # (bs, 2048, 8, 8) -> (bs, 2048, 4, 4)
self.block2 = SAFEReducerBlock(self.channels) # (bs, 2048, 4, 4) -> (bs, 2048, 2, 2)
# reduce mean of dims 2 and 3
self.adaptive_pool = nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
nn.Flatten(),
)
# (bs, 2048)
# linear layer to (bs, self.num_vectors * self.num_tokens)
self.fc1 = nn.Linear(self.channels, self.num_vectors * self.num_tokens)
# (bs, self.num_vectors * self.num_tokens) = (bs, 8 * 768) = (bs, 6144)
def forward(self, x):
x = self.unshuffle(x)
x = self.conv_in(x)
while True:
# reduce until we get as close to 8x8 as possible without going under
x = self.reducer(x)
if x.shape[2] // 2 < 8 or x.shape[3] // 2 < 8:
break
x = self.reduced_out(x)
x = self.block1(x)
x = self.block2(x)
x = self.adaptive_pool(x)
x = self.fc1(x)
# reshape
x = x.view(-1, self.num_tokens, self.num_vectors)
return x
class SAFEIPReturn:
def __init__(self, pixel_values):
self.pixel_values = pixel_values
class SAFEImageProcessor(BaseImageProcessor):
def __init__(
self,
max_size=1024,
min_size=256,
**kwargs
):
super().__init__(**kwargs)
self.max_size = max_size
self.min_size = min_size
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
**kwargs,
):
# not needed
return cls(**kwargs)
def __call__(
self,
images,
**kwargs
):
# TODO allow for random resizing
# comes in 0 - 1 range
# if the smaller side is under min_size, scale it up to min_size
# if the larger side is over max_size, scale it down to max_size
if images.min() < -0.3 or images.max() > 1.3:
raise ValueError(
"images fed into SAFEImageProcessor must have values between 0 and 1. Got min: {}, max: {}".format(
images.min(), images.max()
))
# make sure we have (bs, 3, h, w)
while len(images.shape) < 4:
images = images.unsqueeze(0)
# expand to 3 channels if we only have 1 channel
if images.shape[1] == 1:
images = torch.cat([images, images, images], dim=1)
width = images.shape[3]
height = images.shape[2]
if width < self.min_size or height < self.min_size:
# scale up so that the smallest size is 256
if width < height:
new_width = self.min_size
new_height = int(height * (self.min_size / width))
else:
new_height = self.min_size
new_width = int(width * (self.min_size / height))
images = nn.functional.interpolate(images, size=(new_height, new_width), mode='bilinear',
align_corners=False)
elif width > self.max_size or height > self.max_size:
# scale down so that the largest size is max_size but do not shrink the other size below 256
if width > height:
new_width = self.max_size
new_height = int(height * (self.max_size / width))
else:
new_height = self.max_size
new_width = int(width * (self.max_size / height))
if new_width < self.min_size:
new_width = self.min_size
new_height = int(height * (self.min_size / width))
if new_height < self.min_size:
new_height = self.min_size
new_width = int(width * (self.min_size / height))
images = nn.functional.interpolate(images, size=(new_height, new_width), mode='bilinear',
align_corners=False)
# if either side is not divisible by 16, mirror pad to make it so
if images.shape[2] % 16 != 0:
pad = 16 - (images.shape[2] % 16)
pad1 = pad // 2
pad2 = pad - pad1
images = nn.functional.pad(images, (0, 0, pad1, pad2), mode='reflect')
if images.shape[3] % 16 != 0:
pad = 16 - (images.shape[3] % 16)
pad1 = pad // 2
pad2 = pad - pad1
images = nn.functional.pad(images, (pad1, pad2, 0, 0), mode='reflect')
return SAFEIPReturn(images)
class SAFEVMConfig:
def __init__(
self,
in_channels=3,
num_tokens=8,
num_vectors=768,
reducer_channels=512,
channels=2048,
downscale_factor: int = 8,
**kwargs
):
self.in_channels = in_channels
self.num_tokens = num_tokens
self.num_vectors = num_vectors
self.reducer_channels = reducer_channels
self.channels = channels
self.downscale_factor = downscale_factor
self.hidden_size = num_vectors
self.projection_dim = num_vectors
class SAFEVMReturn:
def __init__(self, output):
self.output = output
# TODO actually do hidden states. This is just for code compatibility for now
self.hidden_states = [output for _ in range(13)]
class SAFEVisionModel(SizeAgnosticFeatureEncoder):
def __init__(self, **kwargs):
self.config = SAFEVMConfig(**kwargs)
super().__init__(**kwargs)
@classmethod
def from_pretrained(cls, *args, **kwargs):
# not needed
return SAFEVisionModel(**kwargs)
def forward(self, x, **kwargs):
return SAFEVMReturn(super().forward(x))
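An end-to-end smoke test of the processor and model above, assuming the module path from the import earlier in this commit; the shapes follow from the code (300x500 is within range but not divisible by 16, so it only gets mirror-padded):

import torch
from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel

processor = SAFEImageProcessor()  # min_size=256, max_size=1024
model = SAFEVisionModel(num_tokens=8, num_vectors=768)
model.eval()

images = torch.rand(1, 3, 300, 500)  # 0-1 range, as the processor requires
pixel_values = processor(images).pixel_values
print(pixel_values.shape)  # torch.Size([1, 3, 304, 512])

with torch.no_grad():
    tokens = model(pixel_values).output
print(tokens.shape)  # torch.Size([1, 8, 768])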

View File

@@ -1447,7 +1447,7 @@ class StableDiffusion:
}
if self.adapter is not None:
if isinstance(self.adapter, IPAdapter):
-requires_grad = self.adapter.adapter_modules.training
+requires_grad = self.adapter.image_proj_model.training
adapter_device = self.unet.device
elif isinstance(self.adapter, T2IAdapter):
requires_grad = self.adapter.adapter.conv_in.weight.requires_grad