mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 02:31:17 +00:00
Created a size agnostic feature encoder (SAFE) model to be trained in replace of CLIP for ip adapters. It is mostly conv layers so will hopefully be able to handle facial features better than clip can. Also bug fixes
This commit is contained in:
253
toolkit/models/size_agnostic_feature_encoder.py
Normal file
253
toolkit/models/size_agnostic_feature_encoder.py
Normal file
@@ -0,0 +1,253 @@
|
||||
import os
|
||||
from typing import Union, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
|
||||
|
||||
class SAFEReducerBlock(nn.Module):
|
||||
"""
|
||||
This is the block that reduces the size of an vactor w and h be half. It is designed to be iterative
|
||||
So it is run multiple times to reduce an image to a desired dimension while carrying a shrinking residual
|
||||
along for the ride. This is done to preserve information.
|
||||
"""
|
||||
def __init__(self, channels=512):
|
||||
super(SAFEReducerBlock, self).__init__()
|
||||
self.channels = channels
|
||||
|
||||
activation = nn.GELU
|
||||
|
||||
self.reducer = nn.Sequential(
|
||||
nn.Conv2d(channels, channels, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(channels),
|
||||
activation(),
|
||||
nn.Conv2d(channels, channels, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(channels),
|
||||
activation(),
|
||||
nn.AvgPool2d(kernel_size=2, stride=2),
|
||||
)
|
||||
self.residual_shrink = nn.AvgPool2d(kernel_size=2, stride=2)
|
||||
|
||||
def forward(self, x):
|
||||
res = self.residual_shrink(x)
|
||||
reduced = self.reducer(x)
|
||||
return reduced + res
|
||||
|
||||
|
||||
class SizeAgnosticFeatureEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=3,
|
||||
num_tokens=8,
|
||||
num_vectors=768,
|
||||
reducer_channels=512,
|
||||
channels=2048,
|
||||
downscale_factor: int = 8,
|
||||
):
|
||||
super(SizeAgnosticFeatureEncoder, self).__init__()
|
||||
self.num_tokens = num_tokens
|
||||
self.num_vectors = num_vectors
|
||||
self.channels = channels
|
||||
self.reducer_channels = reducer_channels
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# input is minimum of (bs, 3, 256, 256)
|
||||
|
||||
subpixel_channels = in_channels * downscale_factor ** 2
|
||||
|
||||
# PixelUnshuffle(8 = # (bs, 3, 32, 32) -> (bs, 192, 32, 32)
|
||||
# PixelUnshuffle(16 = # (bs, 3, 16, 16) -> (bs, 48, 16, 16)
|
||||
|
||||
self.unshuffle = nn.PixelUnshuffle(downscale_factor) # (bs, 3, 256, 256) -> (bs, 192, 32, 32)
|
||||
|
||||
self.conv_in = nn.Conv2d(subpixel_channels, reducer_channels, kernel_size=3, padding=1) # (bs, 192, 32, 32) -> (bs, 512, 32, 32)
|
||||
|
||||
# run as many times as needed to get to min feature of 8 on the smallest dimension
|
||||
self.reducer = SAFEReducerBlock(reducer_channels) # (bs, 512, 32, 32) -> (bs, 512, 8, 8)
|
||||
|
||||
self.reduced_out = nn.Conv2d(
|
||||
reducer_channels, self.channels, kernel_size=3, padding=1
|
||||
) # (bs, 512, 8, 8) -> (bs, 2048, 8, 8)
|
||||
|
||||
# (bs, 2048, 8, 8)
|
||||
self.block1 = SAFEReducerBlock(self.channels) # (bs, 2048, 8, 8) -> (bs, 2048, 4, 4)
|
||||
self.block2 = SAFEReducerBlock(self.channels) # (bs, 2048, 8, 8) -> (bs, 2048, 2, 2)
|
||||
|
||||
# reduce mean of dims 2 and 3
|
||||
self.adaptive_pool = nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d((1, 1)),
|
||||
nn.Flatten(),
|
||||
)
|
||||
|
||||
# (bs, 2048)
|
||||
# linear layer to (bs, self.num_vectors * self.num_tokens)
|
||||
self.fc1 = nn.Linear(self.channels, self.num_vectors * self.num_tokens)
|
||||
|
||||
# (bs, self.num_vectors * self.num_tokens) = (bs, 8 * 768) = (bs, 6144)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.unshuffle(x)
|
||||
x = self.conv_in(x)
|
||||
|
||||
while True:
|
||||
# reduce until we get as close to 8x8 as possible without going under
|
||||
x = self.reducer(x)
|
||||
if x.shape[2] // 2 < 8 or x.shape[3] // 2 < 8:
|
||||
break
|
||||
|
||||
x = self.reduced_out(x)
|
||||
x = self.block1(x)
|
||||
x = self.block2(x)
|
||||
x = self.adaptive_pool(x)
|
||||
x = self.fc1(x)
|
||||
|
||||
# reshape
|
||||
x = x.view(-1, self.num_tokens, self.num_vectors)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SAFEIPReturn:
|
||||
def __init__(self, pixel_values):
|
||||
self.pixel_values = pixel_values
|
||||
|
||||
|
||||
class SAFEImageProcessor(BaseImageProcessor):
|
||||
def __init__(
|
||||
self,
|
||||
max_size=1024,
|
||||
min_size=256,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.max_size = max_size
|
||||
self.min_size = min_size
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: Union[str, os.PathLike],
|
||||
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
||||
force_download: bool = False,
|
||||
local_files_only: bool = False,
|
||||
token: Optional[Union[str, bool]] = None,
|
||||
revision: str = "main",
|
||||
**kwargs,
|
||||
):
|
||||
# not needed
|
||||
return cls(**kwargs)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images,
|
||||
**kwargs
|
||||
):
|
||||
# TODO allow for random resizing
|
||||
# comes in 0 - 1 range
|
||||
# if any size is smaller than 256, resize to 256
|
||||
# if any size is larger than max_size, resize to max_size
|
||||
if images.min() < -0.3 or images.max() > 1.3:
|
||||
raise ValueError(
|
||||
"images fed into SAFEImageProcessor values must be between 0 and 1. Got min: {}, max: {}".format(
|
||||
images.min(), images.max()
|
||||
))
|
||||
|
||||
# make sure we have (bs, 3, h, w)
|
||||
while len(images.shape) < 4:
|
||||
images = images.unsqueeze(0)
|
||||
|
||||
# expand to 3 channels if we only have 1 channel
|
||||
if images.shape[1] == 1:
|
||||
images = torch.cat([images, images, images], dim=1)
|
||||
|
||||
width = images.shape[3]
|
||||
height = images.shape[2]
|
||||
|
||||
if width < self.min_size or height < self.min_size:
|
||||
# scale up so that the smallest size is 256
|
||||
if width < height:
|
||||
new_width = self.min_size
|
||||
new_height = int(height * (self.min_size / width))
|
||||
else:
|
||||
new_height = self.min_size
|
||||
new_width = int(width * (self.min_size / height))
|
||||
images = nn.functional.interpolate(images, size=(new_height, new_width), mode='bilinear',
|
||||
align_corners=False)
|
||||
|
||||
elif width > self.max_size or height > self.max_size:
|
||||
# scale down so that the largest size is max_size but do not shrink the other size below 256
|
||||
if width > height:
|
||||
new_width = self.max_size
|
||||
new_height = int(height * (self.max_size / width))
|
||||
else:
|
||||
new_height = self.max_size
|
||||
new_width = int(width * (self.max_size / height))
|
||||
|
||||
if new_width < self.min_size:
|
||||
new_width = self.min_size
|
||||
new_height = int(height * (self.min_size / width))
|
||||
|
||||
if new_height < self.min_size:
|
||||
new_height = self.min_size
|
||||
new_width = int(width * (self.min_size / height))
|
||||
|
||||
images = nn.functional.interpolate(images, size=(new_height, new_width), mode='bilinear',
|
||||
align_corners=False)
|
||||
|
||||
# if wither side is not divisible by 16, mirror pad to make it so
|
||||
if images.shape[2] % 16 != 0:
|
||||
pad = 16 - (images.shape[2] % 16)
|
||||
pad1 = pad // 2
|
||||
pad2 = pad - pad1
|
||||
images = nn.functional.pad(images, (0, 0, pad1, pad2), mode='reflect')
|
||||
if images.shape[3] % 16 != 0:
|
||||
pad = 16 - (images.shape[3] % 16)
|
||||
pad1 = pad // 2
|
||||
pad2 = pad - pad1
|
||||
images = nn.functional.pad(images, (pad1, pad2, 0, 0), mode='reflect')
|
||||
|
||||
return SAFEIPReturn(images)
|
||||
|
||||
|
||||
class SAFEVMConfig:
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=3,
|
||||
num_tokens=8,
|
||||
num_vectors=768,
|
||||
reducer_channels=512,
|
||||
channels=2048,
|
||||
downscale_factor: int = 8,
|
||||
**kwargs
|
||||
):
|
||||
self.in_channels = in_channels
|
||||
self.num_tokens = num_tokens
|
||||
self.num_vectors = num_vectors
|
||||
self.reducer_channels = reducer_channels
|
||||
self.channels = channels
|
||||
self.downscale_factor = downscale_factor
|
||||
|
||||
self.hidden_size = num_vectors
|
||||
self.projection_dim = num_vectors
|
||||
|
||||
|
||||
class SAFEVMReturn:
|
||||
def __init__(self, output):
|
||||
self.output = output
|
||||
# todo actually do hidden states. This is just for code compatability for now
|
||||
self.hidden_states = [output for _ in range(13)]
|
||||
|
||||
|
||||
class SAFEVisionModel(SizeAgnosticFeatureEncoder):
|
||||
def __init__(self, **kwargs):
|
||||
self.config = SAFEVMConfig(**kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
# not needed
|
||||
return SAFEVisionModel(**kwargs)
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
return SAFEVMReturn(super().forward(x))
|
||||
Reference in New Issue
Block a user