diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index e0824b17..e82842d1 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -39,8 +39,6 @@ class SDTrainer(BaseSDTrainProcess):
         self.do_prior_prediction = False
         self.do_long_prompts = False
         self.do_guided_loss = False
-        if self.train_config.inverted_mask_prior:
-            self.do_prior_prediction = True
 
     def before_model_load(self):
         pass
@@ -89,13 +87,15 @@ class SDTrainer(BaseSDTrainProcess):
         prior_mask_multiplier = None
         target_mask_multiplier = None
 
+        has_mask = batch.mask_tensor is not None
+
         if self.train_config.match_noise_norm:
             # match the norm of the noise
             noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True)
             noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
             noise_pred = noise_pred * (noise_norm / noise_pred_norm)
 
-        if self.train_config.inverted_mask_prior and prior_pred is not None:
+        if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
             # we need to make the noise prediction be a masked blending of noise and prior_pred
             stretched_mask_multiplier = value_map(
                 mask_multiplier,
@@ -867,7 +867,17 @@ class SDTrainer(BaseSDTrainProcess):
                     conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds)
 
         prior_pred = None
-        if ((has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction) and not is_reg:
+
+        do_reg_prior = False
+        if is_reg and (self.network is not None or self.adapter is not None):
+            # we are doing a reg image and we have a network or adapter
+            do_reg_prior = True
+
+        do_inverted_masked_prior = False
+        if self.train_config.inverted_mask_prior and batch.mask_tensor is not None:
+            do_inverted_masked_prior = True
+
+        if ((has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_reg_prior or do_inverted_masked_prior):
             with self.timer('prior predict'):
                 prior_pred = self.get_prior_prediction(
                     noisy_latents=noisy_latents,
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index a4df25d4..87474515 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -153,7 +153,9 @@ class AdapterConfig:
         self.num_tokens: int = num_tokens
         self.train_image_encoder: bool = kwargs.get('train_image_encoder', False)
 
-        self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip')  # clip, vit, vit_hybrid
+        self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip')  # clip, vit, vit_hybrid, safe
+        self.safe_reducer_channels: int = kwargs.get('safe_reducer_channels', 512)
+        self.safe_channels: int = kwargs.get('safe_channels', 2048)
 
         # clip vision
         self.trigger = kwargs.get('trigger', 'tri993r')
diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py
index 62bd826d..2f49fa21 100644
--- a/toolkit/ip_adapter.py
+++ b/toolkit/ip_adapter.py
@@ -32,6 +32,8 @@ from transformers import (
     ConvNextForImageClassification,
     ConvNextImageProcessor
 )
+from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
+
 from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification
 
 from transformers import ViTFeatureExtractor, ViTForImageClassification
@@ -175,6 +177,19 @@ class IPAdapter(torch.nn.Module):
             except EnvironmentError:
                 self.clip_image_processor = ViTFeatureExtractor()
             self.image_encoder = ViTForImageClassification.from_pretrained(adapter_config.image_encoder_path).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+        elif self.config.image_encoder_arch == 'safe':
+            try:
+                self.clip_image_processor = SAFEImageProcessor.from_pretrained(adapter_config.image_encoder_path)
+            except EnvironmentError:
+                self.clip_image_processor = SAFEImageProcessor()
+            self.image_encoder = SAFEVisionModel(
+                in_channels=3,
+                num_tokens=self.config.num_tokens if self.config.adapter_type == 'ip+' else 1,
+                num_vectors=sd.unet.config['cross_attention_dim'] if self.config.adapter_type == 'ip+' else self.config.safe_channels,
+                reducer_channels=self.config.safe_reducer_channels,
+                channels=self.config.safe_channels,
+                downscale_factor=8
+            ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
         elif self.config.image_encoder_arch == 'convnext':
             try:
                 self.clip_image_processor = ConvNextImageProcessor.from_pretrained(adapter_config.image_encoder_path)
diff --git a/toolkit/models/size_agnostic_feature_encoder.py b/toolkit/models/size_agnostic_feature_encoder.py
new file mode 100644
index 00000000..15f7a439
--- /dev/null
+++ b/toolkit/models/size_agnostic_feature_encoder.py
@@ -0,0 +1,253 @@
+import os
+from typing import Union, Optional
+
+import torch
+import torch.nn as nn
+from transformers.image_processing_utils import BaseImageProcessor
+
+
+class SAFEReducerBlock(nn.Module):
+    """
+    This block reduces the w and h of a feature map by half. It is designed to be iterative,
+    so it is run multiple times to reduce an image to a desired dimension while carrying a shrinking residual
+    along for the ride. This is done to preserve information.
+    """
+    def __init__(self, channels=512):
+        super(SAFEReducerBlock, self).__init__()
+        self.channels = channels
+
+        activation = nn.GELU
+
+        self.reducer = nn.Sequential(
+            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
+            nn.BatchNorm2d(channels),
+            activation(),
+            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
+            nn.BatchNorm2d(channels),
+            activation(),
+            nn.AvgPool2d(kernel_size=2, stride=2),
+        )
+        self.residual_shrink = nn.AvgPool2d(kernel_size=2, stride=2)
+
+    def forward(self, x):
+        res = self.residual_shrink(x)
+        reduced = self.reducer(x)
+        return reduced + res
+
+
+class SizeAgnosticFeatureEncoder(nn.Module):
+    def __init__(
+            self,
+            in_channels=3,
+            num_tokens=8,
+            num_vectors=768,
+            reducer_channels=512,
+            channels=2048,
+            downscale_factor: int = 8,
+    ):
+        super(SizeAgnosticFeatureEncoder, self).__init__()
+        self.num_tokens = num_tokens
+        self.num_vectors = num_vectors
+        self.channels = channels
+        self.reducer_channels = reducer_channels
+        self.gradient_checkpointing = False
+
+        # input is minimum of (bs, 3, 256, 256)
+
+        subpixel_channels = in_channels * downscale_factor ** 2
+
+        # PixelUnshuffle(8):  (bs, 3, 256, 256) -> (bs, 192, 32, 32)
+        # PixelUnshuffle(16): (bs, 3, 256, 256) -> (bs, 768, 16, 16)
+
+        self.unshuffle = nn.PixelUnshuffle(downscale_factor)  # (bs, 3, 256, 256) -> (bs, 192, 32, 32)
+
+        self.conv_in = nn.Conv2d(subpixel_channels, reducer_channels, kernel_size=3, padding=1)  # (bs, 192, 32, 32) -> (bs, 512, 32, 32)
+
+        # run as many times as needed to get to a min feature size of 8 on the smallest dimension
+        self.reducer = SAFEReducerBlock(reducer_channels)  # (bs, 512, 32, 32) -> (bs, 512, 8, 8)
+
+        self.reduced_out = nn.Conv2d(
+            reducer_channels, self.channels, kernel_size=3, padding=1
+        )  # (bs, 512, 8, 8) -> (bs, 2048, 8, 8)
+
+        # (bs, 2048, 8, 8)
+        self.block1 = SAFEReducerBlock(self.channels)  # (bs, 2048, 8, 8) -> (bs, 2048, 4, 4)
+        self.block2 = SAFEReducerBlock(self.channels)  # (bs, 2048, 4, 4) -> (bs, 2048, 2, 2)
+
+        # reduce mean of dims 2 and 3
+        self.adaptive_pool = nn.Sequential(
+            nn.AdaptiveAvgPool2d((1, 1)),
+            nn.Flatten(),
+        )
+
+        # (bs, 2048)
+        # linear layer to (bs, self.num_vectors * self.num_tokens)
+        self.fc1 = nn.Linear(self.channels, self.num_vectors * self.num_tokens)
+
+        # (bs, self.num_vectors * self.num_tokens) = (bs, 8 * 768) = (bs, 6144)
+
+    def forward(self, x):
+        x = self.unshuffle(x)
+        x = self.conv_in(x)
+
+        while True:
+            # reduce until we get as close to 8x8 as possible without going under
+            x = self.reducer(x)
+            if x.shape[2] // 2 < 8 or x.shape[3] // 2 < 8:
+                break
+
+        x = self.reduced_out(x)
+        x = self.block1(x)
+        x = self.block2(x)
+        x = self.adaptive_pool(x)
+        x = self.fc1(x)
+
+        # reshape
+        x = x.view(-1, self.num_tokens, self.num_vectors)
+
+        return x
+
+
+class SAFEIPReturn:
+    def __init__(self, pixel_values):
+        self.pixel_values = pixel_values
+
+
+class SAFEImageProcessor(BaseImageProcessor):
+    def __init__(
+            self,
+            max_size=1024,
+            min_size=256,
+            **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.max_size = max_size
+        self.min_size = min_size
+
+    @classmethod
+    def from_pretrained(
+            cls,
+            pretrained_model_name_or_path: Union[str, os.PathLike],
+            cache_dir: Optional[Union[str, os.PathLike]] = None,
+            force_download: bool = False,
+            local_files_only: bool = False,
+            token: Optional[Union[str, bool]] = None,
+            revision: str = "main",
+            **kwargs,
+    ):
+        # not needed
+        return cls(**kwargs)
+
+    def __call__(
+            self,
+            images,
+            **kwargs
+    ):
+        # TODO allow for random resizing
+        # comes in 0 - 1 range
+        # if any size is smaller than 256, resize to 256
+        # if any size is larger than max_size, resize to max_size
+        if images.min() < -0.3 or images.max() > 1.3:
+            raise ValueError(
+                "images fed into SAFEImageProcessor must have values between 0 and 1. Got min: {}, max: {}".format(
+                    images.min(), images.max()
+                ))
+
+        # make sure we have (bs, 3, h, w)
+        while len(images.shape) < 4:
+            images = images.unsqueeze(0)
+
+        # expand to 3 channels if we only have 1 channel
+        if images.shape[1] == 1:
+            images = torch.cat([images, images, images], dim=1)
+
+        width = images.shape[3]
+        height = images.shape[2]
+
+        if width < self.min_size or height < self.min_size:
+            # scale up so that the smallest size is 256
+            if width < height:
+                new_width = self.min_size
+                new_height = int(height * (self.min_size / width))
+            else:
+                new_height = self.min_size
+                new_width = int(width * (self.min_size / height))
+            images = nn.functional.interpolate(images, size=(new_height, new_width), mode='bilinear',
+                                               align_corners=False)
+
+        elif width > self.max_size or height > self.max_size:
+            # scale down so that the largest size is max_size but do not shrink the other size below 256
+            if width > height:
+                new_width = self.max_size
+                new_height = int(height * (self.max_size / width))
+            else:
+                new_height = self.max_size
+                new_width = int(width * (self.max_size / height))
+
+            if new_width < self.min_size:
+                new_width = self.min_size
+                new_height = int(height * (self.min_size / width))
+
+            if new_height < self.min_size:
+                new_height = self.min_size
+                new_width = int(width * (self.min_size / height))
+
+            images = nn.functional.interpolate(images, size=(new_height, new_width), mode='bilinear',
+                                               align_corners=False)
+
+        # if either side is not divisible by 16, mirror pad to make it so
+        if images.shape[2] % 16 != 0:
+            pad = 16 - (images.shape[2] % 16)
+            pad1 = pad // 2
+            pad2 = pad - pad1
+            images = nn.functional.pad(images, (0, 0, pad1, pad2), mode='reflect')
+        if images.shape[3] % 16 != 0:
+            pad = 16 - (images.shape[3] % 16)
+            pad1 = pad // 2
+            pad2 = pad - pad1
+            images = nn.functional.pad(images, (pad1, pad2, 0, 0), mode='reflect')
+
+        return SAFEIPReturn(images)
+
+
+class SAFEVMConfig:
+    def __init__(
+            self,
+            in_channels=3,
+            num_tokens=8,
+            num_vectors=768,
+            reducer_channels=512,
+            channels=2048,
+            downscale_factor: int = 8,
+            **kwargs
+    ):
+        self.in_channels = in_channels
+        self.num_tokens = num_tokens
+        self.num_vectors = num_vectors
+        self.reducer_channels = reducer_channels
+        self.channels = channels
+        self.downscale_factor = downscale_factor
+
+        self.hidden_size = num_vectors
+        self.projection_dim = num_vectors
+
+
+class SAFEVMReturn:
+    def __init__(self, output):
+        self.output = output
+        # todo actually do hidden states. This is just for code compatibility for now
+        self.hidden_states = [output for _ in range(13)]
+
+
+class SAFEVisionModel(SizeAgnosticFeatureEncoder):
+    def __init__(self, **kwargs):
+        self.config = SAFEVMConfig(**kwargs)
+        super().__init__(**kwargs)
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        # not needed
+        return SAFEVisionModel(**kwargs)
+
+    def forward(self, x, **kwargs):
+        return SAFEVMReturn(super().forward(x))
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 6bac0aee..36fddd43 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -1447,7 +1447,7 @@ class StableDiffusion:
         }
         if self.adapter is not None:
             if isinstance(self.adapter, IPAdapter):
-                requires_grad = self.adapter.adapter_modules.training
+                requires_grad = self.adapter.image_proj_model.training
                 adapter_device = self.unet.device
            elif isinstance(self.adapter, T2IAdapter):
                requires_grad = self.adapter.adapter.conv_in.weight.requires_grad
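
Note (not part of the patch): after the SDTrainer change, prior prediction is no longer skipped for regularization images when a network or adapter is active, and it also runs for inverted-mask-prior batches that actually carry a mask. A minimal sketch of the gating as it reads above, with `wants_adapter_assist` standing in for the `(has_adapter_img and self.assistant_adapter and match_adapter_assist)` clause:

```python
# Sketch of the prior-prediction gate after this change (names mirror SDTrainer locals).
def should_run_prior_prediction(
    wants_adapter_assist: bool,
    do_prior_prediction: bool,
    is_reg: bool,
    has_network_or_adapter: bool,
    inverted_mask_prior: bool,
    has_mask: bool,
) -> bool:
    # reg images now get a prior pass when a network or adapter is being trained
    do_reg_prior = is_reg and has_network_or_adapter
    # inverted mask prior only triggers a prior pass when the batch actually has a mask
    do_inverted_masked_prior = inverted_mask_prior and has_mask
    return wants_adapter_assist or do_prior_prediction or do_reg_prior or do_inverted_masked_prior
```

Previously the same condition ended with `and not is_reg`, and `inverted_mask_prior` switched `do_prior_prediction` on unconditionally in `__init__`, even for batches with no mask tensor.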
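
Note (not part of the patch): a minimal smoke-test sketch of the new SAFE path, assuming the module lands at `toolkit/models/size_agnostic_feature_encoder.py` as above. The input size and the token/vector counts are arbitrary example values:

```python
# Exercise SAFEImageProcessor + SAFEVisionModel end to end (illustrative only).
import torch
from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel

processor = SAFEImageProcessor()  # defaults: min_size=256, max_size=1024
model = SAFEVisionModel(
    in_channels=3,
    num_tokens=4,        # arbitrary example values
    num_vectors=768,
    reducer_channels=512,
    channels=2048,
    downscale_factor=8,
).eval()

# Any reasonable input size works; pixel values are expected in roughly [0, 1].
image = torch.rand(1, 3, 300, 500)
pixel_values = processor(image).pixel_values  # reflect-padded to multiples of 16 -> (1, 3, 304, 512)

with torch.no_grad():
    out = model(pixel_values)

print(out.output.shape)  # torch.Size([1, 4, 768]) -> (bs, num_tokens, num_vectors)
```

The processor only rescales into the [min_size, max_size] range and reflect-pads both sides to multiples of 16, so the encoder sees whatever aspect ratio came in; the reducer loop then keeps pooling until another halving would drop the smaller feature dimension below 8, which is what makes the token output size-agnostic.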