Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Created a size agnostic feature encoder (SAFE) model to be trained in place of CLIP for IP adapters. It is mostly conv layers, so it will hopefully handle facial features better than CLIP can. Also includes bug fixes.
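The new encoder is selected through AdapterConfig kwargs (see the AdapterConfig hunk below). A minimal sketch of what the adapter options might look like when passed to the config; only image_encoder_arch, safe_reducer_channels, safe_channels, train_image_encoder, and num_tokens are confirmed by this diff, the "type" key and surrounding structure are assumptions:

    # hypothetical adapter kwargs; key names other than the SAFE/encoder ones are assumptions
    adapter_kwargs = {
        "type": "ip+",                  # assumed key for adapter_type
        "num_tokens": 16,
        "train_image_encoder": True,    # SAFE starts untrained, so it likely needs training
        "image_encoder_arch": "safe",   # new option added in this commit
        "safe_reducer_channels": 512,   # default from AdapterConfig
        "safe_channels": 2048,          # default from AdapterConfig
    }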
@@ -39,8 +39,6 @@ class SDTrainer(BaseSDTrainProcess):
         self.do_prior_prediction = False
         self.do_long_prompts = False
         self.do_guided_loss = False
-        if self.train_config.inverted_mask_prior:
-            self.do_prior_prediction = True

     def before_model_load(self):
         pass
@@ -89,13 +87,15 @@ class SDTrainer(BaseSDTrainProcess):
         prior_mask_multiplier = None
         target_mask_multiplier = None

+        has_mask = batch.mask_tensor is not None
+
         if self.train_config.match_noise_norm:
             # match the norm of the noise
             noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True)
             noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
             noise_pred = noise_pred * (noise_norm / noise_pred_norm)

-        if self.train_config.inverted_mask_prior and prior_pred is not None:
+        if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
             # we need to make the noise prediction be a masked blending of noise and prior_pred
             stretched_mask_multiplier = value_map(
                 mask_multiplier,
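The match_noise_norm branch above rescales the model's prediction so that its per-sample L2 norm matches the norm of the target noise. A self-contained sketch of the same operation:

    import torch

    def match_noise_norm(noise_pred: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        # per-sample L2 norm over the (C, H, W) dims, kept for broadcasting
        noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True)
        pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
        return noise_pred * (noise_norm / pred_norm)

    noise = torch.randn(2, 4, 64, 64)
    pred = 0.5 * torch.randn(2, 4, 64, 64)
    rescaled = match_noise_norm(pred, noise)
    # the rescaled prediction now has the same per-sample norm as the target noise
    assert torch.allclose(
        torch.linalg.vector_norm(rescaled, ord=2, dim=(1, 2, 3)),
        torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3)),
        atol=1e-4,
    )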
@@ -867,7 +867,17 @@ class SDTrainer(BaseSDTrainProcess):
                     conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds)

             prior_pred = None
-            if ((has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction) and not is_reg:
+
+            do_reg_prior = False
+            if is_reg and (self.network is not None or self.adapter is not None):
+                # we are doing a reg image and we have a network or adapter
+                do_reg_prior = True
+
+            do_inverted_masked_prior = False
+            if self.train_config.inverted_mask_prior and batch.mask_tensor is not None:
+                do_inverted_masked_prior = True
+
+            if ((has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_reg_prior or do_inverted_masked_prior):
                 with self.timer('prior predict'):
                     prior_pred = self.get_prior_prediction(
                         noisy_latents=noisy_latents,
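Taken together, prior prediction now also runs for reg images when a network or adapter is active, and whenever an inverted mask prior is requested and the batch has a mask. A condensed restatement of the new gating, with plain booleans standing in for the trainer state (illustrative, not part of the diff):

    def should_run_prior_prediction(
        has_adapter_img: bool,
        assistant_adapter: bool,
        match_adapter_assist: bool,
        do_prior_prediction: bool,
        is_reg: bool,
        has_network_or_adapter: bool,
        inverted_mask_prior: bool,
        has_mask: bool,
    ) -> bool:
        do_reg_prior = is_reg and has_network_or_adapter
        do_inverted_masked_prior = inverted_mask_prior and has_mask
        return (
            (has_adapter_img and assistant_adapter and match_adapter_assist)
            or do_prior_prediction
            or do_reg_prior
            or do_inverted_masked_prior
        )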
@@ -153,7 +153,9 @@ class AdapterConfig:

         self.num_tokens: int = num_tokens
         self.train_image_encoder: bool = kwargs.get('train_image_encoder', False)
-        self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip')  # clip vit vit_hybrid
+        self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip')  # clip vit vit_hybrid, safe
+        self.safe_reducer_channels: int = kwargs.get('safe_reducer_channels', 512)
+        self.safe_channels: int = kwargs.get('safe_channels', 2048)

         # clip vision
         self.trigger = kwargs.get('trigger', 'tri993r')
@@ -32,6 +32,8 @@ from transformers import (
     ConvNextForImageClassification,
     ConvNextImageProcessor
 )
+from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
+
 from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification

 from transformers import ViTFeatureExtractor, ViTForImageClassification
@@ -175,6 +177,19 @@ class IPAdapter(torch.nn.Module):
             except EnvironmentError:
                 self.clip_image_processor = ViTFeatureExtractor()
             self.image_encoder = ViTForImageClassification.from_pretrained(adapter_config.image_encoder_path).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+        elif self.config.image_encoder_arch == 'safe':
+            try:
+                self.clip_image_processor = SAFEImageProcessor.from_pretrained(adapter_config.image_encoder_path)
+            except EnvironmentError:
+                self.clip_image_processor = SAFEImageProcessor()
+            self.image_encoder = SAFEVisionModel(
+                in_channels=3,
+                num_tokens=self.config.num_tokens if self.config.adapter_type == 'ip+' else 1,
+                num_vectors=sd.unet.config['cross_attention_dim'] if self.config.adapter_type == 'ip+' else self.config.safe_channels,
+                reducer_channels=self.config.safe_reducer_channels,
+                channels=self.config.safe_channels,
+                downscale_factor=8
+            ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
         elif self.config.image_encoder_arch == 'convnext':
             try:
                 self.clip_image_processor = ConvNextImageProcessor.from_pretrained(adapter_config.image_encoder_path)
toolkit/models/size_agnostic_feature_encoder.py (new file, 253 lines)
@@ -0,0 +1,253 @@
import os
from typing import Union, Optional

import torch
import torch.nn as nn
from transformers.image_processing_utils import BaseImageProcessor


class SAFEReducerBlock(nn.Module):
    """
    This block reduces the width and height of a feature map by half. It is designed to be iterative,
    so it is run multiple times to reduce an image to a desired dimension while carrying a shrinking
    residual along for the ride. This is done to preserve information.
    """
    def __init__(self, channels=512):
        super(SAFEReducerBlock, self).__init__()
        self.channels = channels

        activation = nn.GELU

        self.reducer = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            activation(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            activation(),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )
        self.residual_shrink = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        res = self.residual_shrink(x)
        reduced = self.reducer(x)
        return reduced + res
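Each SAFEReducerBlock halves the spatial dimensions while keeping the channel count, adding an average-pooled copy of its input as a residual. A quick shape check (illustrative usage, not part of the new file):

    import torch
    from toolkit.models.size_agnostic_feature_encoder import SAFEReducerBlock

    block = SAFEReducerBlock(channels=512).eval()
    x = torch.randn(1, 512, 32, 32)
    with torch.no_grad():
        y = block(x)
    print(y.shape)  # torch.Size([1, 512, 16, 16])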
class SizeAgnosticFeatureEncoder(nn.Module):
    def __init__(
            self,
            in_channels=3,
            num_tokens=8,
            num_vectors=768,
            reducer_channels=512,
            channels=2048,
            downscale_factor: int = 8,
    ):
        super(SizeAgnosticFeatureEncoder, self).__init__()
        self.num_tokens = num_tokens
        self.num_vectors = num_vectors
        self.channels = channels
        self.reducer_channels = reducer_channels
        self.gradient_checkpointing = False

        # input is a minimum of (bs, 3, 256, 256)

        subpixel_channels = in_channels * downscale_factor ** 2

        # PixelUnshuffle(8):  (bs, 3, 256, 256) -> (bs, 192, 32, 32)
        # PixelUnshuffle(16): (bs, 3, 256, 256) -> (bs, 768, 16, 16)

        self.unshuffle = nn.PixelUnshuffle(downscale_factor)  # (bs, 3, 256, 256) -> (bs, 192, 32, 32)

        self.conv_in = nn.Conv2d(subpixel_channels, reducer_channels, kernel_size=3, padding=1)  # (bs, 192, 32, 32) -> (bs, 512, 32, 32)

        # run as many times as needed to get to a minimum feature size of 8 on the smallest dimension
        self.reducer = SAFEReducerBlock(reducer_channels)  # (bs, 512, 32, 32) -> (bs, 512, 8, 8)

        self.reduced_out = nn.Conv2d(
            reducer_channels, self.channels, kernel_size=3, padding=1
        )  # (bs, 512, 8, 8) -> (bs, 2048, 8, 8)

        # (bs, 2048, 8, 8)
        self.block1 = SAFEReducerBlock(self.channels)  # (bs, 2048, 8, 8) -> (bs, 2048, 4, 4)
        self.block2 = SAFEReducerBlock(self.channels)  # (bs, 2048, 4, 4) -> (bs, 2048, 2, 2)

        # reduce mean of dims 2 and 3
        self.adaptive_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
        )

        # (bs, 2048)
        # linear layer to (bs, self.num_vectors * self.num_tokens)
        self.fc1 = nn.Linear(self.channels, self.num_vectors * self.num_tokens)

        # (bs, self.num_vectors * self.num_tokens) = (bs, 8 * 768) = (bs, 6144)

    def forward(self, x):
        x = self.unshuffle(x)
        x = self.conv_in(x)

        while True:
            # reduce until we get as close to 8x8 as possible without going under
            x = self.reducer(x)
            if x.shape[2] // 2 < 8 or x.shape[3] // 2 < 8:
                break

        x = self.reduced_out(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.adaptive_pool(x)
        x = self.fc1(x)

        # reshape
        x = x.view(-1, self.num_tokens, self.num_vectors)

        return x
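Because the reducer loop runs until the feature map is as close to 8x8 as possible without going under, the output shape depends only on num_tokens and num_vectors, not on the input resolution. An illustrative check with two different input sizes (both divisible by the downscale factor; not part of the new file):

    import torch
    from toolkit.models.size_agnostic_feature_encoder import SizeAgnosticFeatureEncoder

    enc = SizeAgnosticFeatureEncoder(num_tokens=8, num_vectors=768).eval()
    with torch.no_grad():
        out_a = enc(torch.randn(1, 3, 256, 256))
        out_b = enc(torch.randn(1, 3, 512, 384))
    print(out_a.shape, out_b.shape)  # both torch.Size([1, 8, 768])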
class SAFEIPReturn:
    def __init__(self, pixel_values):
        self.pixel_values = pixel_values


class SAFEImageProcessor(BaseImageProcessor):
    def __init__(
            self,
            max_size=1024,
            min_size=256,
            **kwargs
    ):
        super().__init__(**kwargs)
        self.max_size = max_size
        self.min_size = min_size

    @classmethod
    def from_pretrained(
            cls,
            pretrained_model_name_or_path: Union[str, os.PathLike],
            cache_dir: Optional[Union[str, os.PathLike]] = None,
            force_download: bool = False,
            local_files_only: bool = False,
            token: Optional[Union[str, bool]] = None,
            revision: str = "main",
            **kwargs,
    ):
        # not needed
        return cls(**kwargs)

    def __call__(
            self,
            images,
            **kwargs
    ):
        # TODO allow for random resizing
        # comes in 0 - 1 range
        # if any side is smaller than min_size, resize to min_size
        # if any side is larger than max_size, resize to max_size
        if images.min() < -0.3 or images.max() > 1.3:
            raise ValueError(
                "images fed into SAFEImageProcessor must have values between 0 and 1. Got min: {}, max: {}".format(
                    images.min(), images.max()
                ))

        # make sure we have (bs, 3, h, w)
        while len(images.shape) < 4:
            images = images.unsqueeze(0)

        # expand to 3 channels if we only have 1 channel
        if images.shape[1] == 1:
            images = torch.cat([images, images, images], dim=1)

        width = images.shape[3]
        height = images.shape[2]

        if width < self.min_size or height < self.min_size:
            # scale up so that the smallest side is min_size
            if width < height:
                new_width = self.min_size
                new_height = int(height * (self.min_size / width))
            else:
                new_height = self.min_size
                new_width = int(width * (self.min_size / height))
            images = nn.functional.interpolate(images, size=(new_height, new_width), mode='bilinear',
                                               align_corners=False)

        elif width > self.max_size or height > self.max_size:
            # scale down so that the largest side is max_size, but do not shrink the other side below min_size
            if width > height:
                new_width = self.max_size
                new_height = int(height * (self.max_size / width))
            else:
                new_height = self.max_size
                new_width = int(width * (self.max_size / height))

            if new_width < self.min_size:
                new_width = self.min_size
                new_height = int(height * (self.min_size / width))

            if new_height < self.min_size:
                new_height = self.min_size
                new_width = int(width * (self.min_size / height))

            images = nn.functional.interpolate(images, size=(new_height, new_width), mode='bilinear',
                                               align_corners=False)

        # if either side is not divisible by 16, mirror pad to make it so
        if images.shape[2] % 16 != 0:
            pad = 16 - (images.shape[2] % 16)
            pad1 = pad // 2
            pad2 = pad - pad1
            images = nn.functional.pad(images, (0, 0, pad1, pad2), mode='reflect')
        if images.shape[3] % 16 != 0:
            pad = 16 - (images.shape[3] % 16)
            pad1 = pad // 2
            pad2 = pad - pad1
            images = nn.functional.pad(images, (pad1, pad2, 0, 0), mode='reflect')

        return SAFEIPReturn(images)
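SAFEImageProcessor only rescales when an image falls outside the [min_size, max_size] range and then mirror-pads each side up to a multiple of 16, so inputs of varying aspect ratios stay close to their native resolution. For example (illustrative usage, not part of the new file):

    import torch
    from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor

    processor = SAFEImageProcessor()      # defaults: min_size=256, max_size=1024
    img = torch.rand(1, 3, 300, 500)      # values in [0, 1], shape (bs, c, h, w)
    out = processor(img)
    print(out.pixel_values.shape)         # torch.Size([1, 3, 304, 512]); padded to multiples of 16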
class SAFEVMConfig:
    def __init__(
            self,
            in_channels=3,
            num_tokens=8,
            num_vectors=768,
            reducer_channels=512,
            channels=2048,
            downscale_factor: int = 8,
            **kwargs
    ):
        self.in_channels = in_channels
        self.num_tokens = num_tokens
        self.num_vectors = num_vectors
        self.reducer_channels = reducer_channels
        self.channels = channels
        self.downscale_factor = downscale_factor

        self.hidden_size = num_vectors
        self.projection_dim = num_vectors


class SAFEVMReturn:
    def __init__(self, output):
        self.output = output
        # todo actually do hidden states. This is just for code compatibility for now
        self.hidden_states = [output for _ in range(13)]


class SAFEVisionModel(SizeAgnosticFeatureEncoder):
    def __init__(self, **kwargs):
        self.config = SAFEVMConfig(**kwargs)
        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # not needed
        return SAFEVisionModel(**kwargs)

    def forward(self, x, **kwargs):
        return SAFEVMReturn(super().forward(x))
@@ -1447,7 +1447,7 @@ class StableDiffusion:
             }
         if self.adapter is not None:
             if isinstance(self.adapter, IPAdapter):
-                requires_grad = self.adapter.adapter_modules.training
+                requires_grad = self.adapter.image_proj_model.training
                 adapter_device = self.unet.device
             elif isinstance(self.adapter, T2IAdapter):
                 requires_grad = self.adapter.adapter.conv_in.weight.requires_grad