diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index a40a20e..42fd7cc 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -666,7 +666,7 @@ class SDTrainer(BaseSDTrainProcess):
                 conditional_embeddings=embeds_to_use.to(self.device_torch, dtype=dtype).detach(),
                 unconditional_embeddings=unconditional_embeds,
                 timestep=timesteps,
-                guidance_scale=1.0,
+                guidance_scale=self.train_config.cfg_scale,
                 **pred_kwargs  # adapter residuals in here
             )
             if was_unet_training:
@@ -980,40 +980,45 @@ class SDTrainer(BaseSDTrainProcess):

         if self.adapter and isinstance(self.adapter, IPAdapter):
             with self.timer('encode_adapter_embeds'):
+                image_size = self.adapter.input_size
                 if is_reg:
                     # we will zero it out in the img embedder
                     clip_images = torch.zeros(
-                        (noisy_latents.shape[0], 3, 512, 512),
+                        (noisy_latents.shape[0], 3, image_size, image_size),
                         device=self.device_torch, dtype=dtype
                     ).detach()
                     # drop will zero it out
                     conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
                         clip_images,
                         drop=True,
-                        is_training=True
+                        is_training=True,
+                        has_been_preprocessed=True
                     )
                     if self.train_config.do_cfg:
                         unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
                             torch.zeros(
-                                (noisy_latents.shape[0], 3, 512, 512),
+                                (noisy_latents.shape[0], 3, image_size, image_size),
                                 device=self.device_torch, dtype=dtype
                             ).detach(),
                             is_training=True,
-                            drop=True
+                            drop=True,
+                            has_been_preprocessed=True
                         )
                 elif has_clip_image:
                     conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
                         clip_images.detach().to(self.device_torch, dtype=dtype),
-                        is_training=True
+                        is_training=True,
+                        has_been_preprocessed=True
                     )
                     if self.train_config.do_cfg:
                         unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
                             torch.zeros(
-                                (noisy_latents.shape[0], 3, 512, 512),
+                                (noisy_latents.shape[0], 3, image_size, image_size),
                                 device=self.device_torch, dtype=dtype
                             ).detach(),
                             is_training=True,
-                            drop=True
+                            drop=True,
+                            has_been_preprocessed=True
                         )
                 else:
                     raise ValueError("Adapter images now must be loaded with dataloader or be a reg image")

@@ -1094,7 +1099,7 @@ class SDTrainer(BaseSDTrainProcess):
                 conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
                 unconditional_embeddings=unconditional_embeds,
                 timestep=timesteps,
-                guidance_scale=1.0,
+                guidance_scale=self.train_config.cfg_scale,
                 **pred_kwargs
             )
             self.after_unet_predict()
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index d7c3d93..f39c187 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -678,6 +678,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
             if batch.tensor is not None:
                 imgs = batch.tensor
                 imgs = imgs.to(self.device_torch, dtype=dtype)
+                if self.train_config.img_multiplier is not None:
+                    imgs = imgs * self.train_config.img_multiplier
             if batch.latents is not None:
                 latents = batch.latents.to(self.device_torch, dtype=dtype)
                 batch.latents = latents
@@ -1113,6 +1115,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
             self.snr_gos.scale.data = torch.tensor(json_data['scale'], device=self.device_torch)
             self.snr_gos.gamma.data = torch.tensor(json_data['gamma'], device=self.device_torch)

+        # load the adapters before the dataset as they may use the clip encoders
+        if self.adapter_config is not None:
+            self.setup_adapter()
         flush()

         ### HOOk ###
@@ -1249,7 +1254,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             flush()

             if self.adapter_config is not None:
-                self.setup_adapter()
+                # self.setup_adapter()
                 # set trainable params
                 params.append({
                     'params': self.adapter.parameters(),
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 1cd2c8f..066ca6b 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -272,6 +272,7 @@ class TrainConfig:
         self.dynamic_noise_offset = kwargs.get('dynamic_noise_offset', False)
         self.do_cfg = kwargs.get('do_cfg', False)
+        self.cfg_scale = kwargs.get('cfg_scale', 1.0)


 class ModelConfig:
diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py
index 0eaaee4..aa51878 100644
--- a/toolkit/data_loader.py
+++ b/toolkit/data_loader.py
@@ -431,6 +431,7 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
         for file in tqdm(file_list):
             try:
                 file_item = FileItemDTO(
+                    sd=self.sd,
                     path=file,
                     dataset_config=dataset_config,
                     dataloader_transforms=self.transform,
diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py
index cc8699d..b06feb3 100644
--- a/toolkit/data_transfer_object/data_loader.py
+++ b/toolkit/data_transfer_object/data_loader.py
@@ -1,3 +1,5 @@
+import weakref
+from _weakref import ReferenceType
 from typing import TYPE_CHECKING, List, Union
 import torch
 import random
@@ -10,8 +12,10 @@ from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessing
     ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
     UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin

+
 if TYPE_CHECKING:
     from toolkit.config_modules import DatasetConfig
+    from toolkit.stable_diffusion_model import StableDiffusion


 printed_messages = []
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index 50fce14..3fcdc85 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -12,6 +12,7 @@ import numpy as np
 import torch
 from safetensors.torch import load_file, save_file
 from tqdm import tqdm
+from transformers import CLIPImageProcessor

 from toolkit.basic import flush, value_map
 from toolkit.buckets import get_bucket_for_image_size, get_resolution
@@ -27,7 +28,7 @@ from toolkit.train_tools import get_torch_dtype
 if TYPE_CHECKING:
     from toolkit.data_loader import AiToolkitDataset
     from toolkit.data_transfer_object.data_loader import FileItemDTO
-
+    from toolkit.stable_diffusion_model import StableDiffusion

 # def get_associated_caption_from_img_path(img_path):
 # https://demo.albumentations.ai/
@@ -565,8 +566,13 @@ class ClipImageFileItemDTOMixin:
         self.clip_image_tensor: Union[torch.Tensor, None] = None
         self.has_clip_augmentations = False
         self.clip_image_aug_transform: Union[None, A.Compose] = None
+        self.clip_image_processor: Union[None, CLIPImageProcessor] = None
         dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
         if dataset_config.clip_image_path is not None:
+            # copy the clip image processor so the dataloader can do it
+            sd = kwargs.get('sd', None)
+            if hasattr(sd.adapter, 'clip_image_processor'):
+                self.clip_image_processor = sd.adapter.clip_image_processor
             # find the control image path
             clip_image_path = dataset_config.clip_image_path
             # we are using control images
@@ -632,13 +638,22 @@ class ClipImageFileItemDTOMixin:
                 print(f"Error: {e}")
                 print(f"Error loading image: {self.clip_image_path}")

-        # we just scale them to 512x512:
-        img = img.resize((512, 512), Image.BICUBIC)

         if self.has_clip_augmentations:
             self.clip_image_tensor = self.augment_clip_image(img, transform=None)
         else:
             self.clip_image_tensor = transforms.ToTensor()(img)

+        if self.clip_image_processor is not None:
+            # run it
+            tensors_0_1 = self.clip_image_tensor.to(dtype=torch.float16)
+            clip_out = self.clip_image_processor(
+                images=tensors_0_1,
+                return_tensors="pt",
+                do_resize=True,
+                do_rescale=False,
+            ).pixel_values
+            self.clip_image_tensor = clip_out.squeeze(0).clone().detach()
+
     def cleanup_clip_image(self: 'FileItemDTO'):
         self.clip_image_tensor = None
diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py
index 49cbac2..3c0b80d 100644
--- a/toolkit/ip_adapter.py
+++ b/toolkit/ip_adapter.py
@@ -5,12 +5,13 @@ from PIL import Image
 from torch.nn import Parameter
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

+from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
 from toolkit.paths import REPOS_ROOT
 from toolkit.saving import load_ip_adapter_model
 from toolkit.train_tools import get_torch_dtype

 sys.path.append(REPOS_ROOT)
-from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List
+from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional
 from collections import OrderedDict
 from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
     AttnProcessor2_0
@@ -163,7 +164,9 @@ class IPAdapter(torch.nn.Module):
         self.config = adapter_config
         self.sd_ref: weakref.ref = weakref.ref(sd)
         self.device = self.sd_ref().unet.device
-        if self.config.image_encoder_arch == 'clip':
+        self.preprocessor: Optional[CLIPImagePreProcessor] = None
+        self.input_size = 224
+        if self.config.image_encoder_arch == 'clip' or self.config.image_encoder_arch == 'clip+':
             try:
                 self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
             except EnvironmentError:
@@ -176,7 +179,8 @@ class IPAdapter(torch.nn.Module):
                 self.clip_image_processor = ViTFeatureExtractor.from_pretrained(adapter_config.image_encoder_path)
             except EnvironmentError:
                 self.clip_image_processor = ViTFeatureExtractor()
-            self.image_encoder = ViTForImageClassification.from_pretrained(adapter_config.image_encoder_path).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+            self.image_encoder = ViTForImageClassification.from_pretrained(adapter_config.image_encoder_path).to(
+                self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
         elif self.config.image_encoder_arch == 'safe':
             try:
                 self.clip_image_processor = SAFEImageProcessor.from_pretrained(adapter_config.image_encoder_path)
@@ -220,6 +224,26 @@ class IPAdapter(torch.nn.Module):
             ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
         else:
             raise ValueError(f"unknown image encoder arch: {adapter_config.image_encoder_arch}")
+
+        self.input_size = self.image_encoder.config.image_size
+
+        if self.config.image_encoder_arch == 'clip+':
+            # self.clip_image_processor.config
+            # We do a 3x downscale of the image, so we need to adjust the input size
+            preprocessor_input_size = self.image_encoder.config.image_size * 3
+
+            # update the preprocessor so images come in at the right size
+            self.clip_image_processor.size['shortest_edge'] = preprocessor_input_size
+            self.clip_image_processor.crop_size['height'] = preprocessor_input_size
+            self.clip_image_processor.crop_size['width'] = preprocessor_input_size
+
+            self.preprocessor = CLIPImagePreProcessor(
+                input_size=preprocessor_input_size,
+                clip_input_size=self.image_encoder.config.image_size,
+                downscale_factor=6
+            )
+
+            self.input_size = self.clip_image_processor.size['shortest_edge']
         self.current_scale = 1.0
         self.is_active = True
         if adapter_config.type == 'ip':
@@ -232,12 +256,9 @@ class IPAdapter(torch.nn.Module):
         elif adapter_config.type == 'ip+':
             heads = 12 if not sd.is_xl else 20
             dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
-            embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else self.image_encoder.config.hidden_sizes[-1]
+            embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else \
+                self.image_encoder.config.hidden_sizes[-1]

-            # if self.config.image_encoder_arch == 'safe':
-            #     embedding_dim = self.config.safe_tokens
-            # size mismatch for latents: copying a param with shape torch.Size([1, 16, 1280]) from checkpoint, the shape in current model is torch.Size([1, 16, 2048]).
-            # size mismatch for latents: copying a param with shape torch.Size([1, 32, 2048]) from checkpoint, the shape in current model is torch.Size([1, 16, 1280])
             # ip-adapter-plus
             image_proj_model = Resampler(
                 dim=dim,
@@ -312,6 +333,8 @@ class IPAdapter(torch.nn.Module):
         self.image_encoder.to(*args, **kwargs)
         self.image_proj_model.to(*args, **kwargs)
         self.adapter_modules.to(*args, **kwargs)
+        if self.preprocessor is not None:
+            self.preprocessor.to(*args, **kwargs)
         return self

     def load_ip_adapter(self, state_dict: Union[OrderedDict, dict]):
@@ -320,6 +343,8 @@ class IPAdapter(torch.nn.Module):
         ip_layers.load_state_dict(state_dict["ip_adapter"])
         if self.config.train_image_encoder and 'image_encoder' in state_dict:
             self.image_encoder.load_state_dict(state_dict["image_encoder"])
+        if self.preprocessor is not None and 'preprocessor' in state_dict:
+            self.preprocessor.load_state_dict(state_dict["preprocessor"])

     # def load_state_dict(self, state_dict: Union[OrderedDict, dict]):
     #     self.load_ip_adapter(state_dict)
@@ -330,6 +355,8 @@ class IPAdapter(torch.nn.Module):
         state_dict["ip_adapter"] = self.adapter_modules.state_dict()
         if self.config.train_image_encoder:
             state_dict["image_encoder"] = self.image_encoder.state_dict()
+        if self.preprocessor is not None:
+            state_dict["preprocessor"] = self.preprocessor.state_dict()
         return state_dict

     def get_scale(self):
@@ -341,53 +368,74 @@ class IPAdapter(torch.nn.Module):
             if isinstance(attn_processor, CustomIPAttentionProcessor):
                 attn_processor.scale = scale

-    @torch.no_grad()
-    def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]],
-                                       drop=False) -> torch.Tensor:
-        # todo: add support for sdxl
-        if isinstance(pil_image, Image.Image):
-            pil_image = [pil_image]
-        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
-        clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
-        if drop:
-            clip_image = clip_image * 0
-        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
-        return clip_image_embeds
+    # @torch.no_grad()
+    # def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]],
+    #                                    drop=False) -> torch.Tensor:
+    #     # todo: add support for sdxl
+    #     if isinstance(pil_image, Image.Image):
+    #         pil_image = [pil_image]
+    #     clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
+    #     clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+    #     if drop:
+    #         clip_image = clip_image * 0
+    #     clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
+    #     return clip_image_embeds

-    def get_clip_image_embeds_from_tensors(self, tensors_0_1: torch.Tensor, drop=False,
-                                           is_training=False) -> torch.Tensor:
+    def get_clip_image_embeds_from_tensors(
+            self,
+            tensors_0_1: torch.Tensor,
+            drop=False,
+            is_training=False,
+            has_been_preprocessed=False
+    ) -> torch.Tensor:
         with torch.no_grad():
-            # tensors should be 0-1
-            # todo: add support for sdxl
-            if tensors_0_1.ndim == 3:
-                tensors_0_1 = tensors_0_1.unsqueeze(0)
-            # training tensors are 0 - 1
-            tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
-            # if images are out of this range throw error
-            if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
-                raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
-                    tensors_0_1.min(), tensors_0_1.max()
-                ))
-
-            clip_image = self.clip_image_processor(
-                images=tensors_0_1,
-                return_tensors="pt",
-                do_resize=True,
-                do_rescale=False,
-            ).pixel_values
+            # on training the clip image is created in the dataloader
+            if not has_been_preprocessed:
+                # tensors should be 0-1
+                if tensors_0_1.ndim == 3:
+                    tensors_0_1 = tensors_0_1.unsqueeze(0)
+                # training tensors are 0 - 1
+                tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
+                # if images are out of this range throw error
+                if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
+                    raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
+                        tensors_0_1.min(), tensors_0_1.max()
+                    ))
+                clip_image = self.clip_image_processor(
+                    images=tensors_0_1,
+                    return_tensors="pt",
+                    do_resize=True,
+                    do_rescale=False,
+                ).pixel_values
+            else:
+                clip_image = tensors_0_1
             clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
             if drop:
                 clip_image = clip_image * 0
         with torch.set_grad_enabled(is_training):
             if is_training:
                 self.image_encoder.train()
-                clip_output = self.image_encoder(clip_image.requires_grad_(True)
-                                                 , output_hidden_states=True)
+                clip_image = clip_image.requires_grad_(True)
+                if self.preprocessor is not None:
+                    clip_image = self.preprocessor(clip_image)
+                clip_output = self.image_encoder(
+                    clip_image,
+                    output_hidden_states=True
+                )
             else:
                 self.image_encoder.eval()
-                clip_output = self.image_encoder(clip_image, output_hidden_states=True)
+                if self.preprocessor is not None:
+                    clip_image = self.preprocessor(clip_image)
+                clip_output = self.image_encoder(
+                    clip_image, output_hidden_states=True
+                )

-            clip_image_embeds = clip_output.hidden_states[-2]
+            if self.config.type.startswith('ip+'):
+                # they skip last layer for ip+
+                # https://github.com/tencent-ailab/IP-Adapter/blob/f4b6742db35ea6d81c7b829a55b0a312c7f5a677/tutorial_train_plus.py#L403C26-L403C26
+                clip_image_embeds = clip_output.hidden_states[-2]
+            else:
+                clip_image_embeds = clip_output.image_embeds
         return clip_image_embeds

     # use drop for prompt dropout, or negatives
@@ -403,13 +451,8 @@ class IPAdapter(torch.nn.Module):
             yield from self.image_proj_model.parameters(recurse)
             if self.config.train_image_encoder:
                 yield from self.image_encoder.parameters(recurse)
-        # if self.config.train_image_encoder:
-        #     yield from self.image_encoder.parameters(recurse)
-        #     self.image_encoder.train()
-        # else:
-        #     for attn_processor in self.adapter_modules:
-        #         yield from attn_processor.parameters(recurse)
-        #     yield from self.image_proj_model.parameters(recurse)
+            if self.preprocessor is not None:
+                yield from self.preprocessor.parameters(recurse)

     def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
         strict = False
@@ -417,6 +460,8 @@ class IPAdapter(torch.nn.Module):
         self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
         if self.config.train_image_encoder and 'image_encoder' in state_dict:
             self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict)
+        if self.config.image_encoder_arch == 'clip+' and 'preprocessor' in state_dict:
+            self.preprocessor.load_state_dict(state_dict["preprocessor"], strict=strict)

     def enable_gradient_checkpointing(self):
         self.image_encoder.gradient_checkpointing = True
diff --git a/toolkit/models/clip_pre_processor.py b/toolkit/models/clip_pre_processor.py
new file mode 100644
index 0000000..851b0f1
--- /dev/null
+++ b/toolkit/models/clip_pre_processor.py
@@ -0,0 +1,113 @@
+import torch
+import torch.nn as nn
+
+
+class UpsampleBlock(nn.Module):
+    def __init__(
+            self,
+            in_channels: int,
+            out_channels: int,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.conv_in = nn.Sequential(
+            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
+            nn.GELU()
+        )
+        self.conv_up = nn.Sequential(
+            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
+            nn.GELU()
+        )
+
+        self.conv_out = nn.Sequential(
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
+        )
+
+    def forward(self, x):
+        x = self.conv_in(x)
+        x = self.conv_up(x)
+        x = self.conv_out(x)
+        return x
+
+
+class CLIPImagePreProcessor(nn.Module):
+    def __init__(
+            self,
+            input_size=672,
+            clip_input_size=224,
+            downscale_factor: int = 6,
+            channels=None,  # 108
+    ):
+        super().__init__()
+        # make sure they are evenly divisible
+        assert input_size % clip_input_size == 0
+        in_channels = 3
+
+        self.input_size = input_size
+        self.clip_input_size = clip_input_size
+        self.downscale_factor = downscale_factor
+
+        subpixel_channels = in_channels * downscale_factor ** 2  # 3 * 6 ** 2 = 108
+
+        if channels is None:
+            channels = subpixel_channels
+
+        upscale_factor = downscale_factor / int((input_size / clip_input_size))  # 6 / (672 / 224) = 2
+
+        num_upsample_blocks = int(upscale_factor // 2)  # 2 // 2 = 1
+
+        # do a pooling layer to downscale the input to 1/3 of the size
+        # (bs, 3, 672, 672) -> (bs, 3, 224, 224)
+        kernel_size = input_size // clip_input_size
+        self.res_down = nn.AvgPool2d(
+            kernel_size=kernel_size,
+            stride=kernel_size
+        )  # (bs, 3, 672, 672) -> (bs, 3, 224, 224)
+
+        # make a blending for output residual with near 0 weight
+        self.res_blend = nn.Parameter(torch.tensor(0.001))  # (bs, 3, 224, 224) -> (bs, 3, 224, 224)
+
+        self.unshuffle = nn.PixelUnshuffle(downscale_factor)  # (bs, 3, 672, 672) -> (bs, 108, 112, 112)
+
+        self.conv_in = nn.Sequential(
+            nn.Conv2d(
+                subpixel_channels,
+                channels,
+                kernel_size=3,
+                padding=1
+            ),
+            nn.GELU()
+        )  # (bs, 108, 112, 112) -> (bs, 108, 112, 112)
+
+        self.upsample_blocks = nn.ModuleList()
+        current_channels = channels
+        for _ in range(num_upsample_blocks):
+            out_channels = current_channels // 2
+            self.upsample_blocks.append(UpsampleBlock(current_channels, out_channels))
+            current_channels = out_channels
+
+        # (bs, 108, 112, 112) -> (bs, 54, 224, 224)
+
+        self.conv_out = nn.Conv2d(
+            current_channels,
+            out_channels=3,
+            kernel_size=3,
+            padding=1
+        )  # (bs, 54, 224, 224) -> (bs, 3, 224, 224)
+
+
+    def forward(self, x):
+        # resize to input_size x input_size
+        x = nn.functional.interpolate(x, size=(self.input_size, self.input_size), mode='bicubic')
+
+        res = self.res_down(x)
+
+        x = self.unshuffle(x)
+        x = self.conv_in(x)
+        for up in self.upsample_blocks:
+            x = up(x)
+        x = self.conv_out(x)
+        # blend residual
+        x = x * self.res_blend + res
+        return x
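Reviewer notes (not part of the patch):

The new CLIPImagePreProcessor can be sanity-checked in isolation. The sketch below is not from the patch; it assumes only the defaults defined in toolkit/models/clip_pre_processor.py above and walks a dummy batch through the shape path the comments describe (interpolate to 672, PixelUnshuffle(6) to 108x112x112, one UpsampleBlock to 54x224x224, conv_out to 3x224x224, blended with the AvgPool residual).

import torch

from toolkit.models.clip_pre_processor import CLIPImagePreProcessor

# defaults match the new module: 672 in, 224 out for the CLIP vision tower
pre = CLIPImagePreProcessor(input_size=672, clip_input_size=224, downscale_factor=6)

# any 0-1 image batch works; forward() bicubic-resizes it to 672x672 first
dummy = torch.rand(2, 3, 512, 512)
out = pre(dummy)

assert out.shape == (2, 3, 224, 224)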
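The dataloader change in ClipImageFileItemDTOMixin moves the CLIPImageProcessor call out of IPAdapter and into load_clip_image; the trainer then passes has_been_preprocessed=True so get_clip_image_embeds_from_tensors skips its own processor call. A standalone sketch of that handshake, assuming a stock CLIPImageProcessor rather than the one copied from sd.adapter.clip_image_processor:

import torch
from PIL import Image
from torchvision import transforms
from transformers import CLIPImageProcessor

processor = CLIPImageProcessor()            # stand-in for sd.adapter.clip_image_processor
img = Image.new("RGB", (768, 512))          # stand-in for the clip image on disk

tensor_0_1 = transforms.ToTensor()(img)     # 0-1 float tensor, as in load_clip_image()

pixel_values = processor(
    images=tensor_0_1,
    return_tensors="pt",
    do_resize=True,
    do_rescale=False,   # tensor is already 0-1, so only resize / crop / normalize
).pixel_values

clip_image_tensor = pixel_values.squeeze(0)  # cached on the FileItemDTO; (3, 224, 224) for stock CLIP

# later, the adapter receives the batched pixel values and skips its own preprocessing:
# adapter.get_clip_image_embeds_from_tensors(batch, is_training=True, has_been_preprocessed=True)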
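The training-time predict_noise calls in SDTrainer are no longer pinned to guidance_scale=1.0; they now read the new TrainConfig.cfg_scale key, which defaults to 1.0 so existing configs behave as before. A minimal sketch of how the key surfaces, assuming TrainConfig can be constructed directly from keyword arguments as in toolkit/config_modules.py:

from toolkit.config_modules import TrainConfig

train_config = TrainConfig(
    do_cfg=True,    # existing flag: also encode unconditional embeds each step
    cfg_scale=3.0,  # new key: forwarded as guidance_scale to predict_noise during training
)

assert train_config.do_cfg is True
assert train_config.cfg_scale == 3.0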