Add experimental conv-pooling and sparse-autoencoder options to the vision direct encoder

This commit is contained in:
Jaret Burkett
2024-10-10 19:42:26 +00:00
parent ab22674980
commit 3922981996
4 changed files with 101 additions and 23 deletions

View File

@@ -211,6 +211,11 @@ class AdapterConfig:
self.pixtral_random_image_size: int = kwargs.get('pixtral_random_image_size', False)
self.flux_only_double: bool = kwargs.get('flux_only_double', False)
# train and use a conv layer to pool the embedding
self.conv_pooling: bool = kwargs.get('conv_pooling', False)
self.conv_pooling_stacks: int = kwargs.get('conv_pooling_stacks', 1)
self.sparse_autoencoder_dim: Optional[int] = kwargs.get('sparse_autoencoder_dim', None)
class EmbeddingConfig:

View File

@@ -407,7 +407,7 @@ class CustomAdapter(torch.nn.Module):
if 'vd_adapter' in state_dict:
self.vd_adapter.load_state_dict(state_dict['vd_adapter'], strict=strict)
if 'dvadapter' in state_dict:
self.vd_adapter.load_state_dict(state_dict['dvadapter'], strict=strict)
self.vd_adapter.load_state_dict(state_dict['dvadapter'], strict=False)
if 'sv_adapter' in state_dict:
self.single_value_adapter.load_state_dict(state_dict['sv_adapter'], strict=strict)
@@ -732,8 +732,7 @@ class CustomAdapter(torch.nn.Module):
def train(self, mode: bool = True):
if self.config.train_image_encoder:
self.vision_encoder.train(mode)
else:
super().train(mode)
super().train(mode)
def trigger_pre_te(
self,
@@ -879,7 +878,10 @@ class CustomAdapter(torch.nn.Module):
elif self.config.clip_layer == 'last_hidden_state':
clip_image_embeds = clip_output.hidden_states[-1]
else:
clip_image_embeds = clip_output.image_embeds
if hasattr(clip_output, 'image_embeds'):
clip_image_embeds = clip_output.image_embeds
elif hasattr(clip_output, 'pooler_output'):
clip_image_embeds = clip_output.pooler_output
# TODO should we always norm image embeds?
# get norm embeddings
l2_norm = torch.norm(clip_image_embeds, p=2)
@@ -931,8 +933,12 @@ class CustomAdapter(torch.nn.Module):
yield from attn_processor.parameters(recurse)
if self.config.train_image_encoder:
yield from self.vision_encoder.parameters(recurse)
if self.config.num_tokens:
if self.vd_adapter.resampler is not None:
yield from self.vd_adapter.resampler.parameters(recurse)
if self.vd_adapter.pool is not None:
yield from self.vd_adapter.pool.parameters(recurse)
if self.vd_adapter.sparse_autoencoder is not None:
yield from self.vd_adapter.sparse_autoencoder.parameters(recurse)
elif self.config.type == 'te_augmenter':
yield from self.te_augmenter.parameters(recurse)
if self.config.train_image_encoder:

View File

@@ -13,7 +13,7 @@ import numpy as np
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor
from toolkit.basic import flush, value_map
from toolkit.buckets import get_bucket_for_image_size, get_resolution
@@ -764,7 +764,8 @@ class ClipImageFileItemDTOMixin:
return self.clip_image_path
def load_clip_image(self: 'FileItemDTO'):
is_dynamic_size_and_aspect = isinstance(self.clip_image_processor, PixtralVisionImagePreprocessorCompatible)
is_dynamic_size_and_aspect = isinstance(self.clip_image_processor, PixtralVisionImagePreprocessorCompatible) or \
isinstance(self.clip_image_processor, SiglipImageProcessor)
if self.is_vision_clip_cached:
self.clip_image_embeds = load_file(self.get_clip_vision_embeddings_path())
@@ -794,21 +795,7 @@ class ClipImageFileItemDTOMixin:
img = img.transpose(Image.FLIP_TOP_BOTTOM)
if is_dynamic_size_and_aspect:
# just match the bucket size for now
if self.dataset_config.buckets:
# scale and crop based on file item
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
# img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
# crop
img = img.crop((
self.crop_x,
self.crop_y,
self.crop_x + self.crop_width,
self.crop_y + self.crop_height
))
else:
raise Exception("Control images not supported for non-bucket datasets")
pass # let the image processor handle it
elif img.width != img.height:
min_size = min(img.width, img.height)
if self.dataset_config.square_crop:

View File

@@ -10,6 +10,7 @@ from collections import OrderedDict
from diffusers import Transformer2DModel, FluxTransformer2DModel
from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection
from toolkit.models.pixtral_vision import PixtralVisionEncoder, PixtralVisionImagePreprocessor, VisionLanguageAdapter
from transformers import SiglipImageProcessor, SiglipVisionModel
from toolkit.config_modules import AdapterConfig
from toolkit.paths import REPOS_ROOT
@@ -19,6 +20,52 @@ sys.path.append(REPOS_ROOT)
if TYPE_CHECKING:
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.custom_adapter import CustomAdapter
# matches distribution of randn
class Norm(nn.Module):
    """Renormalize each sample to a target mean and standard deviation.

    Statistics are computed per sample over every non-batch dimension, so
    with the defaults (mean 0, std 1) the output approximates the
    distribution of ``torch.randn``.
    """

    def __init__(self, target_mean=0.0, target_std=1.0, eps=1e-6):
        super().__init__()
        self.target_mean = target_mean
        self.target_std = target_std
        # small constant to avoid division by zero for near-constant inputs
        self.eps = eps

    def forward(self, x):
        # reduce over all dims except the batch dim (0)
        reduce_dims = tuple(range(1, x.dim()))
        sample_mean = x.mean(dim=reduce_dims, keepdim=True)
        sample_std = x.std(dim=reduce_dims, keepdim=True)
        # standardize, then shift/scale to the requested distribution
        return self.target_std * (x - sample_mean) / (sample_std + self.eps) + self.target_mean
class SparseAutoencoder(nn.Module):
    """MLP autoencoder whose code and reconstruction are renormalized.

    Both the intermediate code and the decoder output are passed through
    ``Norm`` (default: zero mean, unit std per sample).  The input, the
    normalized code, and the reconstruction from the most recent forward
    pass are stashed in ``self.last_run`` under the keys ``"input"``,
    ``"sparse"``, and ``"output"`` for use by auxiliary losses.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # NOTE: attribute names (encoder/norm/decoder) define the
        # state_dict layout — keep them stable.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim),
        )
        self.norm = Norm()
        self.decoder = nn.Sequential(
            nn.Linear(output_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, input_dim),
        )
        # tensors from the latest forward(); None until first call
        self.last_run = None

    def forward(self, x):
        self.last_run = {"input": x}
        code = self.norm(self.encoder(x))
        self.last_run["sparse"] = code
        reconstruction = self.norm(self.decoder(code))
        self.last_run["output"] = reconstruction
        return reconstruction
class MLPR(nn.Module): # MLP with reshaping
@@ -466,12 +513,18 @@ class VisionDirectAdapter(torch.nn.Module):
is_pixtral = self.config.image_encoder_arch == "pixtral"
if adapter.config.clip_layer == "image_embeds":
self.token_size = vision_model.config.projection_dim
if isinstance(vision_model, SiglipVisionModel):
self.token_size = vision_model.config.hidden_size
else:
self.token_size = vision_model.config.projection_dim
else:
self.token_size = vision_model.config.hidden_size
self.mid_size = self.token_size
if self.config.conv_pooling and self.config.conv_pooling_stacks > 1:
self.mid_size = self.mid_size * self.config.conv_pooling_stacks
# if pixtral, use cross attn dim for more sparse representation if only doing double transformers
if is_pixtral and self.config.flux_only_double:
if is_flux:
@@ -677,6 +730,27 @@ class VisionDirectAdapter(torch.nn.Module):
in_dim=self.token_size,
out_dim=self.mid_size,
)
self.pool = None
self.sparse_autoencoder = None
if self.config.conv_pooling:
vision_config = self.adapter_ref().vision_encoder.config
# sequence_length = int((vision_config.image_size / vision_config.patch_size) ** 2 + 1)
# siglip doesnt add 1
sequence_length = int((vision_config.image_size / vision_config.patch_size) ** 2)
self.pool = nn.Sequential(
nn.Conv1d(sequence_length, self.config.conv_pooling_stacks, 1, bias=False),
Norm(),
)
if self.config.sparse_autoencoder_dim is not None:
hidden_dim = self.token_size * 2
if hidden_dim > self.config.sparse_autoencoder_dim:
hidden_dim = self.config.sparse_autoencoder_dim
self.sparse_autoencoder = SparseAutoencoder(
input_dim=self.token_size,
hidden_dim=hidden_dim,
output_dim=self.config.sparse_autoencoder_dim
)
def state_dict(self, destination=None, prefix='', keep_vars=False):
if self.config.train_scaler:
@@ -699,6 +773,12 @@ class VisionDirectAdapter(torch.nn.Module):
self.block_scaler.data = self.block_scaler.data.to(torch.float32)
if self.resampler is not None:
input = self.resampler(input)
if self.pool is not None:
input = self.pool(input)
if self.config.conv_pooling_stacks > 1:
input = torch.cat(torch.chunk(input, self.config.conv_pooling_stacks, dim=1), dim=2)
if self.sparse_autoencoder is not None:
input = self.sparse_autoencoder(input)
return input
def to(self, *args, **kwargs):