mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-28 08:13:58 +00:00
Added some additional experimental things to the vision direct encoder
This commit is contained in:
@@ -211,6 +211,11 @@ class AdapterConfig:
|
||||
self.pixtral_random_image_size: int = kwargs.get('pixtral_random_image_size', False)
|
||||
|
||||
self.flux_only_double: bool = kwargs.get('flux_only_double', False)
|
||||
|
||||
# train and use a conv layer to pool the embedding
|
||||
self.conv_pooling: bool = kwargs.get('conv_pooling', False)
|
||||
self.conv_pooling_stacks: int = kwargs.get('conv_pooling_stacks', 1)
|
||||
self.sparse_autoencoder_dim: Optional[int] = kwargs.get('sparse_autoencoder_dim', None)
|
||||
|
||||
|
||||
class EmbeddingConfig:
|
||||
|
||||
@@ -407,7 +407,7 @@ class CustomAdapter(torch.nn.Module):
|
||||
if 'vd_adapter' in state_dict:
|
||||
self.vd_adapter.load_state_dict(state_dict['vd_adapter'], strict=strict)
|
||||
if 'dvadapter' in state_dict:
|
||||
self.vd_adapter.load_state_dict(state_dict['dvadapter'], strict=strict)
|
||||
self.vd_adapter.load_state_dict(state_dict['dvadapter'], strict=False)
|
||||
|
||||
if 'sv_adapter' in state_dict:
|
||||
self.single_value_adapter.load_state_dict(state_dict['sv_adapter'], strict=strict)
|
||||
@@ -732,8 +732,7 @@ class CustomAdapter(torch.nn.Module):
|
||||
def train(self, mode: bool = True):
|
||||
if self.config.train_image_encoder:
|
||||
self.vision_encoder.train(mode)
|
||||
else:
|
||||
super().train(mode)
|
||||
super().train(mode)
|
||||
|
||||
def trigger_pre_te(
|
||||
self,
|
||||
@@ -879,7 +878,10 @@ class CustomAdapter(torch.nn.Module):
|
||||
elif self.config.clip_layer == 'last_hidden_state':
|
||||
clip_image_embeds = clip_output.hidden_states[-1]
|
||||
else:
|
||||
clip_image_embeds = clip_output.image_embeds
|
||||
if hasattr(clip_output, 'image_embeds'):
|
||||
clip_image_embeds = clip_output.image_embeds
|
||||
elif hasattr(clip_output, 'pooler_output'):
|
||||
clip_image_embeds = clip_output.pooler_output
|
||||
# TODO should we always norm image embeds?
|
||||
# get norm embeddings
|
||||
l2_norm = torch.norm(clip_image_embeds, p=2)
|
||||
@@ -931,8 +933,12 @@ class CustomAdapter(torch.nn.Module):
|
||||
yield from attn_processor.parameters(recurse)
|
||||
if self.config.train_image_encoder:
|
||||
yield from self.vision_encoder.parameters(recurse)
|
||||
if self.config.num_tokens:
|
||||
if self.vd_adapter.resampler is not None:
|
||||
yield from self.vd_adapter.resampler.parameters(recurse)
|
||||
if self.vd_adapter.pool is not None:
|
||||
yield from self.vd_adapter.pool.parameters(recurse)
|
||||
if self.vd_adapter.sparse_autoencoder is not None:
|
||||
yield from self.vd_adapter.sparse_autoencoder.parameters(recurse)
|
||||
elif self.config.type == 'te_augmenter':
|
||||
yield from self.te_augmenter.parameters(recurse)
|
||||
if self.config.train_image_encoder:
|
||||
|
||||
@@ -13,7 +13,7 @@ import numpy as np
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor
|
||||
|
||||
from toolkit.basic import flush, value_map
|
||||
from toolkit.buckets import get_bucket_for_image_size, get_resolution
|
||||
@@ -764,7 +764,8 @@ class ClipImageFileItemDTOMixin:
|
||||
return self.clip_image_path
|
||||
|
||||
def load_clip_image(self: 'FileItemDTO'):
|
||||
is_dynamic_size_and_aspect = isinstance(self.clip_image_processor, PixtralVisionImagePreprocessorCompatible)
|
||||
is_dynamic_size_and_aspect = isinstance(self.clip_image_processor, PixtralVisionImagePreprocessorCompatible) or \
|
||||
isinstance(self.clip_image_processor, SiglipImageProcessor)
|
||||
if self.is_vision_clip_cached:
|
||||
self.clip_image_embeds = load_file(self.get_clip_vision_embeddings_path())
|
||||
|
||||
@@ -794,21 +795,7 @@ class ClipImageFileItemDTOMixin:
|
||||
img = img.transpose(Image.FLIP_TOP_BOTTOM)
|
||||
|
||||
if is_dynamic_size_and_aspect:
|
||||
# just match the bucket size for now
|
||||
if self.dataset_config.buckets:
|
||||
# scale and crop based on file item
|
||||
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
|
||||
# img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
|
||||
# crop
|
||||
img = img.crop((
|
||||
self.crop_x,
|
||||
self.crop_y,
|
||||
self.crop_x + self.crop_width,
|
||||
self.crop_y + self.crop_height
|
||||
))
|
||||
else:
|
||||
raise Exception("Control images not supported for non-bucket datasets")
|
||||
|
||||
pass # let the image processor handle it
|
||||
elif img.width != img.height:
|
||||
min_size = min(img.width, img.height)
|
||||
if self.dataset_config.square_crop:
|
||||
|
||||
@@ -10,6 +10,7 @@ from collections import OrderedDict
|
||||
from diffusers import Transformer2DModel, FluxTransformer2DModel
|
||||
from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection
|
||||
from toolkit.models.pixtral_vision import PixtralVisionEncoder, PixtralVisionImagePreprocessor, VisionLanguageAdapter
|
||||
from transformers import SiglipImageProcessor, SiglipVisionModel
|
||||
|
||||
from toolkit.config_modules import AdapterConfig
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
@@ -19,6 +20,52 @@ sys.path.append(REPOS_ROOT)
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
from toolkit.custom_adapter import CustomAdapter
|
||||
|
||||
|
||||
# matches distribution of randn
class Norm(nn.Module):
    """Renormalize each sample to a target mean and standard deviation.

    Statistics are computed per batch element over all non-batch
    dimensions, so every sample in the output approximately matches
    ``target_mean`` / ``target_std`` (defaults match ``torch.randn``).
    """

    def __init__(self, target_mean=0.0, target_std=1.0, eps=1e-6):
        super(Norm, self).__init__()
        self.target_mean = target_mean
        self.target_std = target_std
        # small constant keeps the division stable when std is ~0
        self.eps = eps

    def forward(self, x):
        # reduce over every dimension except the batch dimension
        reduce_dims = tuple(range(1, x.dim()))
        sample_mean = x.mean(dim=reduce_dims, keepdim=True)
        sample_std = x.std(dim=reduce_dims, keepdim=True)

        # standardize, then rescale/shift to the target distribution
        standardized = (x - sample_mean) / (sample_std + self.eps)
        return standardized * self.target_std + self.target_mean
|
||||
|
||||
|
||||
class SparseAutoencoder(nn.Module):
    """Autoencoder mapping embeddings into a sparse latent space and back.

    The encoder is a GELU MLP ``input_dim -> hidden_dim -> output_dim``;
    the decoder mirrors it back to ``input_dim``. Both the sparse latent
    and the reconstruction are renormalized with ``Norm``. Intermediate
    activations are stashed in ``self.last_run`` so callers can attach
    auxiliary losses (e.g. sparsity / reconstruction terms).
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super(SparseAutoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim),
        )
        self.norm = Norm()
        self.decoder = nn.Sequential(
            nn.Linear(output_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, input_dim),
        )
        # populated on every forward pass with keys "input", "sparse", "output"
        self.last_run = None

    def forward(self, x):
        run = {"input": x}
        self.last_run = run

        # encode to the sparse space, renormalized to ~N(0, 1) per sample
        sparse = self.norm(self.encoder(x))
        run["sparse"] = sparse

        # decode back to the input dimension, renormalized the same way
        reconstruction = self.norm(self.decoder(sparse))
        run["output"] = reconstruction
        return reconstruction
|
||||
|
||||
|
||||
class MLPR(nn.Module): # MLP with reshaping
|
||||
@@ -466,12 +513,18 @@ class VisionDirectAdapter(torch.nn.Module):
|
||||
is_pixtral = self.config.image_encoder_arch == "pixtral"
|
||||
|
||||
if adapter.config.clip_layer == "image_embeds":
|
||||
self.token_size = vision_model.config.projection_dim
|
||||
if isinstance(vision_model, SiglipVisionModel):
|
||||
self.token_size = vision_model.config.hidden_size
|
||||
else:
|
||||
self.token_size = vision_model.config.projection_dim
|
||||
else:
|
||||
self.token_size = vision_model.config.hidden_size
|
||||
|
||||
self.mid_size = self.token_size
|
||||
|
||||
if self.config.conv_pooling and self.config.conv_pooling_stacks > 1:
|
||||
self.mid_size = self.mid_size * self.config.conv_pooling_stacks
|
||||
|
||||
# if pixtral, use cross attn dim for more sparse representation if only doing double transformers
|
||||
if is_pixtral and self.config.flux_only_double:
|
||||
if is_flux:
|
||||
@@ -677,6 +730,27 @@ class VisionDirectAdapter(torch.nn.Module):
|
||||
in_dim=self.token_size,
|
||||
out_dim=self.mid_size,
|
||||
)
|
||||
|
||||
self.pool = None
|
||||
self.sparse_autoencoder = None
|
||||
if self.config.conv_pooling:
|
||||
vision_config = self.adapter_ref().vision_encoder.config
|
||||
# sequence_length = int((vision_config.image_size / vision_config.patch_size) ** 2 + 1)
|
||||
# siglip doesnt add 1
|
||||
sequence_length = int((vision_config.image_size / vision_config.patch_size) ** 2)
|
||||
self.pool = nn.Sequential(
|
||||
nn.Conv1d(sequence_length, self.config.conv_pooling_stacks, 1, bias=False),
|
||||
Norm(),
|
||||
)
|
||||
if self.config.sparse_autoencoder_dim is not None:
|
||||
hidden_dim = self.token_size * 2
|
||||
if hidden_dim > self.config.sparse_autoencoder_dim:
|
||||
hidden_dim = self.config.sparse_autoencoder_dim
|
||||
self.sparse_autoencoder = SparseAutoencoder(
|
||||
input_dim=self.token_size,
|
||||
hidden_dim=hidden_dim,
|
||||
output_dim=self.config.sparse_autoencoder_dim
|
||||
)
|
||||
|
||||
def state_dict(self, destination=None, prefix='', keep_vars=False):
|
||||
if self.config.train_scaler:
|
||||
@@ -699,6 +773,12 @@ class VisionDirectAdapter(torch.nn.Module):
|
||||
self.block_scaler.data = self.block_scaler.data.to(torch.float32)
|
||||
if self.resampler is not None:
|
||||
input = self.resampler(input)
|
||||
if self.pool is not None:
|
||||
input = self.pool(input)
|
||||
if self.config.conv_pooling_stacks > 1:
|
||||
input = torch.cat(torch.chunk(input, self.config.conv_pooling_stacks, dim=1), dim=2)
|
||||
if self.sparse_autoencoder is not None:
|
||||
input = self.sparse_autoencoder(input)
|
||||
return input
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user