diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index bdb36fa0..c897e9ed 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -211,6 +211,11 @@ class AdapterConfig:
 
         self.pixtral_random_image_size: int = kwargs.get('pixtral_random_image_size', False)
         self.flux_only_double: bool = kwargs.get('flux_only_double', False)
+
+        # train and use a conv layer to pool the embedding
+        self.conv_pooling: bool = kwargs.get('conv_pooling', False)
+        self.conv_pooling_stacks: int = kwargs.get('conv_pooling_stacks', 1)
+        self.sparse_autoencoder_dim: Optional[int] = kwargs.get('sparse_autoencoder_dim', None)
 
 
 class EmbeddingConfig:
diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index 1d0e28f9..9fae4090 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -407,7 +407,7 @@ class CustomAdapter(torch.nn.Module):
         if 'vd_adapter' in state_dict:
             self.vd_adapter.load_state_dict(state_dict['vd_adapter'], strict=strict)
         if 'dvadapter' in state_dict:
-            self.vd_adapter.load_state_dict(state_dict['dvadapter'], strict=strict)
+            self.vd_adapter.load_state_dict(state_dict['dvadapter'], strict=False)
         if 'sv_adapter' in state_dict:
             self.single_value_adapter.load_state_dict(state_dict['sv_adapter'], strict=strict)
 
@@ -732,8 +732,7 @@ class CustomAdapter(torch.nn.Module):
     def train(self, mode: bool = True):
         if self.config.train_image_encoder:
             self.vision_encoder.train(mode)
-        else:
-            super().train(mode)
+        super().train(mode)
 
     def trigger_pre_te(
         self,
@@ -879,7 +878,10 @@ class CustomAdapter(torch.nn.Module):
             elif self.config.clip_layer == 'last_hidden_state':
                 clip_image_embeds = clip_output.hidden_states[-1]
             else:
-                clip_image_embeds = clip_output.image_embeds
+                if hasattr(clip_output, 'image_embeds'):
+                    clip_image_embeds = clip_output.image_embeds
+                elif hasattr(clip_output, 'pooler_output'):
+                    clip_image_embeds = clip_output.pooler_output
             # TODO should we always norm image embeds?
             # get norm embeddings
             l2_norm = torch.norm(clip_image_embeds, p=2)
@@ -931,8 +933,12 @@ class CustomAdapter(torch.nn.Module):
                 yield from attn_processor.parameters(recurse)
             if self.config.train_image_encoder:
                 yield from self.vision_encoder.parameters(recurse)
-            if self.config.num_tokens:
+            if self.vd_adapter.resampler is not None:
                 yield from self.vd_adapter.resampler.parameters(recurse)
+            if self.vd_adapter.pool is not None:
+                yield from self.vd_adapter.pool.parameters(recurse)
+            if self.vd_adapter.sparse_autoencoder is not None:
+                yield from self.vd_adapter.sparse_autoencoder.parameters(recurse)
         elif self.config.type == 'te_augmenter':
             yield from self.te_augmenter.parameters(recurse)
             if self.config.train_image_encoder:
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index 539ba479..1bba4431 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -13,7 +13,7 @@ import numpy as np
 import torch
 from safetensors.torch import load_file, save_file
 from tqdm import tqdm
-from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor
 
 from toolkit.basic import flush, value_map
 from toolkit.buckets import get_bucket_for_image_size, get_resolution
@@ -764,7 +764,8 @@ class ClipImageFileItemDTOMixin:
         return self.clip_image_path
 
     def load_clip_image(self: 'FileItemDTO'):
-        is_dynamic_size_and_aspect = isinstance(self.clip_image_processor, PixtralVisionImagePreprocessorCompatible)
+        is_dynamic_size_and_aspect = isinstance(self.clip_image_processor, PixtralVisionImagePreprocessorCompatible) or \
+            isinstance(self.clip_image_processor, SiglipImageProcessor)
 
         if self.is_vision_clip_cached:
             self.clip_image_embeds = load_file(self.get_clip_vision_embeddings_path())
@@ -794,21 +795,7 @@ class ClipImageFileItemDTOMixin:
                 img = img.transpose(Image.FLIP_TOP_BOTTOM)
 
             if is_dynamic_size_and_aspect:
-                # just match the bucket size for now
-                if self.dataset_config.buckets:
-                    # scale and crop based on file item
-                    img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
-                    # img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
-                    # crop
-                    img = img.crop((
-                        self.crop_x,
-                        self.crop_y,
-                        self.crop_x + self.crop_width,
-                        self.crop_y + self.crop_height
-                    ))
-                else:
-                    raise Exception("Control images not supported for non-bucket datasets")
-
+                pass # let the image processor handle it
             elif img.width != img.height:
                 min_size = min(img.width, img.height)
                 if self.dataset_config.square_crop:
diff --git a/toolkit/models/vd_adapter.py b/toolkit/models/vd_adapter.py
index d3f396b0..4232bf86 100644
--- a/toolkit/models/vd_adapter.py
+++ b/toolkit/models/vd_adapter.py
@@ -10,6 +10,7 @@ from collections import OrderedDict
 from diffusers import Transformer2DModel, FluxTransformer2DModel
 from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection
 from toolkit.models.pixtral_vision import PixtralVisionEncoder, PixtralVisionImagePreprocessor, VisionLanguageAdapter
+from transformers import SiglipImageProcessor, SiglipVisionModel
 from toolkit.config_modules import AdapterConfig
 from toolkit.paths import REPOS_ROOT
 
@@ -19,6 +20,52 @@ sys.path.append(REPOS_ROOT)
 if TYPE_CHECKING:
     from toolkit.stable_diffusion_model import StableDiffusion
     from toolkit.custom_adapter import CustomAdapter
+
+
+# matches distribution of randn
+class Norm(nn.Module):
+    def __init__(self, target_mean=0.0, target_std=1.0, eps=1e-6):
+        super(Norm, self).__init__()
+        self.target_mean = target_mean
+        self.target_std = target_std
+        self.eps = eps
+
+    def forward(self, x):
+        dims = tuple(range(1, x.dim()))
+        mean = x.mean(dim=dims, keepdim=True)
+        std = x.std(dim=dims, keepdim=True)
+
+        # Normalize
+        return self.target_std * (x - mean) / (std + self.eps) + self.target_mean
+
+
+class SparseAutoencoder(nn.Module):
+    def __init__(self, input_dim, hidden_dim, output_dim):
+        super(SparseAutoencoder, self).__init__()
+        self.encoder = nn.Sequential(
+            nn.Linear(input_dim, hidden_dim),
+            nn.GELU(),
+            nn.Linear(hidden_dim, output_dim),
+        )
+        self.norm = Norm()
+        self.decoder = nn.Sequential(
+            nn.Linear(output_dim, hidden_dim),
+            nn.GELU(),
+            nn.Linear(hidden_dim, input_dim),
+        )
+        self.last_run = None
+
+    def forward(self, x):
+        self.last_run = {
+            "input": x
+        }
+        x = self.encoder(x)
+        x = self.norm(x)
+        self.last_run["sparse"] = x
+        x = self.decoder(x)
+        x = self.norm(x)
+        self.last_run["output"] = x
+        return x
 
 
 class MLPR(nn.Module): # MLP with reshaping
@@ -466,12 +513,18 @@ class VisionDirectAdapter(torch.nn.Module):
         is_pixtral = self.config.image_encoder_arch == "pixtral"
 
         if adapter.config.clip_layer == "image_embeds":
-            self.token_size = vision_model.config.projection_dim
+            if isinstance(vision_model, SiglipVisionModel):
+                self.token_size = vision_model.config.hidden_size
+            else:
+                self.token_size = vision_model.config.projection_dim
         else:
             self.token_size = vision_model.config.hidden_size
 
         self.mid_size = self.token_size
 
+        if self.config.conv_pooling and self.config.conv_pooling_stacks > 1:
+            self.mid_size = self.mid_size * self.config.conv_pooling_stacks
+
         # if pixtral, use cross attn dim for more sparse representation if only doing double transformers
         if is_pixtral and self.config.flux_only_double:
             if is_flux:
@@ -677,6 +730,27 @@ class VisionDirectAdapter(torch.nn.Module):
                 in_dim=self.token_size,
                 out_dim=self.mid_size,
             )
+
+        self.pool = None
+        self.sparse_autoencoder = None
+        if self.config.conv_pooling:
+            vision_config = self.adapter_ref().vision_encoder.config
+            # sequence_length = int((vision_config.image_size / vision_config.patch_size) ** 2 + 1)
+            # siglip doesnt add 1
+            sequence_length = int((vision_config.image_size / vision_config.patch_size) ** 2)
+            self.pool = nn.Sequential(
+                nn.Conv1d(sequence_length, self.config.conv_pooling_stacks, 1, bias=False),
+                Norm(),
+            )
+        if self.config.sparse_autoencoder_dim is not None:
+            hidden_dim = self.token_size * 2
+            if hidden_dim > self.config.sparse_autoencoder_dim:
+                hidden_dim = self.config.sparse_autoencoder_dim
+            self.sparse_autoencoder = SparseAutoencoder(
+                input_dim=self.token_size,
+                hidden_dim=hidden_dim,
+                output_dim=self.config.sparse_autoencoder_dim
+            )
 
     def state_dict(self, destination=None, prefix='', keep_vars=False):
         if self.config.train_scaler:
@@ -699,6 +773,12 @@ class VisionDirectAdapter(torch.nn.Module):
             self.block_scaler.data = self.block_scaler.data.to(torch.float32)
         if self.resampler is not None:
             input = self.resampler(input)
+        if self.pool is not None:
+            input = self.pool(input)
+            if self.config.conv_pooling_stacks > 1:
+                input = torch.cat(torch.chunk(input, self.config.conv_pooling_stacks, dim=1), dim=2)
+        if self.sparse_autoencoder is not None:
+            input = self.sparse_autoencoder(input)
        return input
 
     def to(self, *args, **kwargs):
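For reference, the three new AdapterConfig options introduced in config_modules.py can be exercised directly. This is a minimal sketch only, assuming AdapterConfig is built from plain keyword arguments (as the kwargs.get defaults above suggest); the values are illustrative and not taken from this patch.

from toolkit.config_modules import AdapterConfig

# Illustrative values only; any key left unset falls back to the kwargs.get default shown above.
adapter_config = AdapterConfig(
    conv_pooling=True,             # build and train the Conv1d pool over the vision token sequence
    conv_pooling_stacks=4,         # VisionDirectAdapter then sets mid_size = token_size * 4
    sparse_autoencoder_dim=4096,   # also build the SparseAutoencoder bottleneck
)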
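And a standalone shape walk-through of the new pooling path in VisionDirectAdapter.forward. This is a sketch with illustrative dimensions only; the real module derives sequence_length from the vision encoder's image_size / patch_size and also applies the Norm module after the Conv1d.

import torch
import torch.nn as nn

batch, seq_len, token_size, stacks = 2, 256, 1152, 4  # illustrative sizes, not from the patch

# mirrors nn.Conv1d(sequence_length, conv_pooling_stacks, 1, bias=False) above
pool = nn.Conv1d(seq_len, stacks, 1, bias=False)

tokens = torch.randn(batch, seq_len, token_size)  # (batch, sequence, hidden)
pooled = pool(tokens)                             # -> (batch, stacks, hidden)

# same chunk/cat reshaping as the adapter forward when conv_pooling_stacks > 1:
# the pooled tokens are stacked along the feature dimension
stacked = torch.cat(torch.chunk(pooled, stacks, dim=1), dim=2)
print(stacked.shape)  # torch.Size([2, 1, 4608]) == (batch, 1, hidden * stacks)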