From eeee4a1620f884da3739b7c9a47f20ee19428efd Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Thu, 28 Dec 2023 12:20:27 -0700 Subject: [PATCH] Created a size agnostic feature encoder (SAFE) model to be trained in replace of CLIP for ip adapters. It is mostly conv layers so will hopefully be able to handle facial features better than clip can. Also bug fixes --- extensions_built_in/sd_trainer/SDTrainer.py | 18 +- toolkit/config_modules.py | 4 +- toolkit/ip_adapter.py | 15 ++ .../models/size_agnostic_feature_encoder.py | 253 ++++++++++++++++++ toolkit/stable_diffusion_model.py | 2 +- 5 files changed, 286 insertions(+), 6 deletions(-) create mode 100644 toolkit/models/size_agnostic_feature_encoder.py diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index e0824b17..e82842d1 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -39,8 +39,6 @@ class SDTrainer(BaseSDTrainProcess): self.do_prior_prediction = False self.do_long_prompts = False self.do_guided_loss = False - if self.train_config.inverted_mask_prior: - self.do_prior_prediction = True def before_model_load(self): pass @@ -89,13 +87,15 @@ class SDTrainer(BaseSDTrainProcess): prior_mask_multiplier = None target_mask_multiplier = None + has_mask = batch.mask_tensor is not None + if self.train_config.match_noise_norm: # match the norm of the noise noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True) noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True) noise_pred = noise_pred * (noise_norm / noise_pred_norm) - if self.train_config.inverted_mask_prior and prior_pred is not None: + if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask: # we need to make the noise prediction be a masked blending of noise and prior_pred stretched_mask_multiplier = value_map( mask_multiplier, @@ -867,7 +867,17 @@ class SDTrainer(BaseSDTrainProcess): conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds) prior_pred = None - if ((has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction) and not is_reg: + + do_reg_prior = False + if is_reg and (self.network is not None or self.adapter is not None): + # we are doing a reg image and we have a network or adapter + do_reg_prior = True + + do_inverted_masked_prior = False + if self.train_config.inverted_mask_prior and batch.mask_tensor is not None: + do_inverted_masked_prior = True + + if ((has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_reg_prior or do_inverted_masked_prior): with self.timer('prior predict'): prior_pred = self.get_prior_prediction( noisy_latents=noisy_latents, diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index a4df25d4..87474515 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -153,7 +153,9 @@ class AdapterConfig: self.num_tokens: int = num_tokens self.train_image_encoder: bool = kwargs.get('train_image_encoder', False) - self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip') # clip vit vit_hybrid + self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip') # clip vit vit_hybrid, safe + self.safe_reducer_channels: int = kwargs.get('safe_reducer_channels', 512) + self.safe_channels: int = kwargs.get('safe_channels', 2048) # clip vision self.trigger = kwargs.get('trigger', 'tri993r') diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py index 62bd826d..2f49fa21 100644 --- a/toolkit/ip_adapter.py +++ b/toolkit/ip_adapter.py @@ -32,6 +32,8 @@ from transformers import ( ConvNextForImageClassification, ConvNextImageProcessor ) +from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel + from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification from transformers import ViTFeatureExtractor, ViTForImageClassification @@ -175,6 +177,19 @@ class IPAdapter(torch.nn.Module): except EnvironmentError: self.clip_image_processor = ViTFeatureExtractor() self.image_encoder = ViTForImageClassification.from_pretrained(adapter_config.image_encoder_path).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + elif self.config.image_encoder_arch == 'safe': + try: + self.clip_image_processor = SAFEImageProcessor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + self.clip_image_processor = SAFEImageProcessor() + self.image_encoder = SAFEVisionModel( + in_channels=3, + num_tokens=self.config.num_tokens if self.config.adapter_type == 'ip+' else 1, + num_vectors=sd.unet.config['cross_attention_dim'] if self.config.adapter_type == 'ip+' else self.config.safe_channels, + reducer_channels=self.config.safe_reducer_channels, + channels=self.config.safe_channels, + downscale_factor=8 + ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) elif self.config.image_encoder_arch == 'convnext': try: self.clip_image_processor = ConvNextImageProcessor.from_pretrained(adapter_config.image_encoder_path) diff --git a/toolkit/models/size_agnostic_feature_encoder.py b/toolkit/models/size_agnostic_feature_encoder.py new file mode 100644 index 00000000..15f7a439 --- /dev/null +++ b/toolkit/models/size_agnostic_feature_encoder.py @@ -0,0 +1,253 @@ +import os +from typing import Union, Optional + +import torch +import torch.nn as nn +from transformers.image_processing_utils import BaseImageProcessor + + +class SAFEReducerBlock(nn.Module): + """ + This is the block that reduces the size of an vactor w and h be half. It is designed to be iterative + So it is run multiple times to reduce an image to a desired dimension while carrying a shrinking residual + along for the ride. This is done to preserve information. + """ + def __init__(self, channels=512): + super(SAFEReducerBlock, self).__init__() + self.channels = channels + + activation = nn.GELU + + self.reducer = nn.Sequential( + nn.Conv2d(channels, channels, kernel_size=3, padding=1), + nn.BatchNorm2d(channels), + activation(), + nn.Conv2d(channels, channels, kernel_size=3, padding=1), + nn.BatchNorm2d(channels), + activation(), + nn.AvgPool2d(kernel_size=2, stride=2), + ) + self.residual_shrink = nn.AvgPool2d(kernel_size=2, stride=2) + + def forward(self, x): + res = self.residual_shrink(x) + reduced = self.reducer(x) + return reduced + res + + +class SizeAgnosticFeatureEncoder(nn.Module): + def __init__( + self, + in_channels=3, + num_tokens=8, + num_vectors=768, + reducer_channels=512, + channels=2048, + downscale_factor: int = 8, + ): + super(SizeAgnosticFeatureEncoder, self).__init__() + self.num_tokens = num_tokens + self.num_vectors = num_vectors + self.channels = channels + self.reducer_channels = reducer_channels + self.gradient_checkpointing = False + + # input is minimum of (bs, 3, 256, 256) + + subpixel_channels = in_channels * downscale_factor ** 2 + + # PixelUnshuffle(8 = # (bs, 3, 32, 32) -> (bs, 192, 32, 32) + # PixelUnshuffle(16 = # (bs, 3, 16, 16) -> (bs, 48, 16, 16) + + self.unshuffle = nn.PixelUnshuffle(downscale_factor) # (bs, 3, 256, 256) -> (bs, 192, 32, 32) + + self.conv_in = nn.Conv2d(subpixel_channels, reducer_channels, kernel_size=3, padding=1) # (bs, 192, 32, 32) -> (bs, 512, 32, 32) + + # run as many times as needed to get to min feature of 8 on the smallest dimension + self.reducer = SAFEReducerBlock(reducer_channels) # (bs, 512, 32, 32) -> (bs, 512, 8, 8) + + self.reduced_out = nn.Conv2d( + reducer_channels, self.channels, kernel_size=3, padding=1 + ) # (bs, 512, 8, 8) -> (bs, 2048, 8, 8) + + # (bs, 2048, 8, 8) + self.block1 = SAFEReducerBlock(self.channels) # (bs, 2048, 8, 8) -> (bs, 2048, 4, 4) + self.block2 = SAFEReducerBlock(self.channels) # (bs, 2048, 8, 8) -> (bs, 2048, 2, 2) + + # reduce mean of dims 2 and 3 + self.adaptive_pool = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + nn.Flatten(), + ) + + # (bs, 2048) + # linear layer to (bs, self.num_vectors * self.num_tokens) + self.fc1 = nn.Linear(self.channels, self.num_vectors * self.num_tokens) + + # (bs, self.num_vectors * self.num_tokens) = (bs, 8 * 768) = (bs, 6144) + + def forward(self, x): + x = self.unshuffle(x) + x = self.conv_in(x) + + while True: + # reduce until we get as close to 8x8 as possible without going under + x = self.reducer(x) + if x.shape[2] // 2 < 8 or x.shape[3] // 2 < 8: + break + + x = self.reduced_out(x) + x = self.block1(x) + x = self.block2(x) + x = self.adaptive_pool(x) + x = self.fc1(x) + + # reshape + x = x.view(-1, self.num_tokens, self.num_vectors) + + return x + + +class SAFEIPReturn: + def __init__(self, pixel_values): + self.pixel_values = pixel_values + + +class SAFEImageProcessor(BaseImageProcessor): + def __init__( + self, + max_size=1024, + min_size=256, + **kwargs + ): + super().__init__(**kwargs) + self.max_size = max_size + self.min_size = min_size + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + **kwargs, + ): + # not needed + return cls(**kwargs) + + def __call__( + self, + images, + **kwargs + ): + # TODO allow for random resizing + # comes in 0 - 1 range + # if any size is smaller than 256, resize to 256 + # if any size is larger than max_size, resize to max_size + if images.min() < -0.3 or images.max() > 1.3: + raise ValueError( + "images fed into SAFEImageProcessor values must be between 0 and 1. Got min: {}, max: {}".format( + images.min(), images.max() + )) + + # make sure we have (bs, 3, h, w) + while len(images.shape) < 4: + images = images.unsqueeze(0) + + # expand to 3 channels if we only have 1 channel + if images.shape[1] == 1: + images = torch.cat([images, images, images], dim=1) + + width = images.shape[3] + height = images.shape[2] + + if width < self.min_size or height < self.min_size: + # scale up so that the smallest size is 256 + if width < height: + new_width = self.min_size + new_height = int(height * (self.min_size / width)) + else: + new_height = self.min_size + new_width = int(width * (self.min_size / height)) + images = nn.functional.interpolate(images, size=(new_height, new_width), mode='bilinear', + align_corners=False) + + elif width > self.max_size or height > self.max_size: + # scale down so that the largest size is max_size but do not shrink the other size below 256 + if width > height: + new_width = self.max_size + new_height = int(height * (self.max_size / width)) + else: + new_height = self.max_size + new_width = int(width * (self.max_size / height)) + + if new_width < self.min_size: + new_width = self.min_size + new_height = int(height * (self.min_size / width)) + + if new_height < self.min_size: + new_height = self.min_size + new_width = int(width * (self.min_size / height)) + + images = nn.functional.interpolate(images, size=(new_height, new_width), mode='bilinear', + align_corners=False) + + # if wither side is not divisible by 16, mirror pad to make it so + if images.shape[2] % 16 != 0: + pad = 16 - (images.shape[2] % 16) + pad1 = pad // 2 + pad2 = pad - pad1 + images = nn.functional.pad(images, (0, 0, pad1, pad2), mode='reflect') + if images.shape[3] % 16 != 0: + pad = 16 - (images.shape[3] % 16) + pad1 = pad // 2 + pad2 = pad - pad1 + images = nn.functional.pad(images, (pad1, pad2, 0, 0), mode='reflect') + + return SAFEIPReturn(images) + + +class SAFEVMConfig: + def __init__( + self, + in_channels=3, + num_tokens=8, + num_vectors=768, + reducer_channels=512, + channels=2048, + downscale_factor: int = 8, + **kwargs + ): + self.in_channels = in_channels + self.num_tokens = num_tokens + self.num_vectors = num_vectors + self.reducer_channels = reducer_channels + self.channels = channels + self.downscale_factor = downscale_factor + + self.hidden_size = num_vectors + self.projection_dim = num_vectors + + +class SAFEVMReturn: + def __init__(self, output): + self.output = output + # todo actually do hidden states. This is just for code compatability for now + self.hidden_states = [output for _ in range(13)] + + +class SAFEVisionModel(SizeAgnosticFeatureEncoder): + def __init__(self, **kwargs): + self.config = SAFEVMConfig(**kwargs) + super().__init__(**kwargs) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + # not needed + return SAFEVisionModel(**kwargs) + + def forward(self, x, **kwargs): + return SAFEVMReturn(super().forward(x)) diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 6bac0aee..36fddd43 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -1447,7 +1447,7 @@ class StableDiffusion: } if self.adapter is not None: if isinstance(self.adapter, IPAdapter): - requires_grad = self.adapter.adapter_modules.training + requires_grad = self.adapter.image_proj_model.training adapter_device = self.unet.device elif isinstance(self.adapter, T2IAdapter): requires_grad = self.adapter.adapter.conv_in.weight.requires_grad